Skip to content

Commit

Permalink
Fix potential race condition in GetResultOrRunClassInitialize
Browse files Browse the repository at this point in the history
  • Loading branch information
Youssef1313 committed Jan 8, 2025
1 parent 1db1afd commit 124e797
Showing 1 changed file with 54 additions and 36 deletions.
90 changes: 54 additions & 36 deletions src/Adapter/MSTest.TestAdapter/Execution/TestClassInfo.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using Microsoft.VisualStudio.TestPlatform.MSTest.TestAdapter.Extensions;
Expand Down Expand Up @@ -385,52 +385,70 @@ internal UnitTestResult GetResultOrRunClassInitialize(ITestContext testContext,
return clonedInitializeResult;
}

DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available.");

// For optimization purposes, return right away if there is nothing to execute.
// For STA, this avoids starting a thread when we know it will do nothing.
// But we still return early even not STA.
if (ClassInitializeMethod is null && BaseClassInitMethods.Count == 0)
// At this point, maybe class initialize was executed by another thread such
// that TryGetClonedCachedClassInitializeResult would return non-null.
// Now, we need to check again, but under a lock.
// Note that we are duplicating the logic above.
// We could keep the logic in lock only and not duplicate, but we don't want to pay
// the lock cost unnecessarily for a common case.
lock (_testClassExecuteSyncObject)
{
IsClassInitializeExecuted = true;
return _classInitializeResult = new(ObjectModelUnitTestOutcome.Passed, null);
}
clonedInitializeResult = TryGetClonedCachedClassInitializeResult();

bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(ClassAttribute);
bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
if (isSTATestClass
&& isWindowsOS
&& Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
{
UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete");
Thread entryPointThread = new(() => result = DoRun())
// Optimization: If we already ran before and know the result, return it.
if (clonedInitializeResult is not null)
{
Name = "MSTest STATestClass ClassInitialize",
};
DebugEx.Assert(IsClassInitializeExecuted, "Class initialize result should be available if and only if class initialize was executed");
return clonedInitializeResult;
}

entryPointThread.SetApartmentState(ApartmentState.STA);
entryPointThread.Start();
DebugEx.Assert(!IsClassInitializeExecuted, "If class initialize was executed, we should have been in the previous if were we have a result available.");

try
// For optimization purposes, return right away if there is nothing to execute.
// For STA, this avoids starting a thread when we know it will do nothing.
// But we still return early even not STA.
if (ClassInitializeMethod is null && BaseClassInitMethods.Count == 0)
{
entryPointThread.Join();
return result;
IsClassInitializeExecuted = true;
return _classInitializeResult = new(ObjectModelUnitTestOutcome.Passed, null);
}
catch (Exception ex)

bool isSTATestClass = AttributeComparer.IsDerived<STATestClassAttribute>(ClassAttribute);
bool isWindowsOS = RuntimeInformation.IsOSPlatform(OSPlatform.Windows);
if (isSTATestClass
&& isWindowsOS
&& Thread.CurrentThread.GetApartmentState() != ApartmentState.STA)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString());
return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation()));
UnitTestResult result = new(ObjectModelUnitTestOutcome.Error, "MSTest STATestClass ClassInitialize didn't complete");
Thread entryPointThread = new(() => result = DoRun())
{
Name = "MSTest STATestClass ClassInitialize",
};

entryPointThread.SetApartmentState(ApartmentState.STA);
entryPointThread.Start();

try
{
entryPointThread.Join();
return result;
}
catch (Exception ex)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogError(ex.ToString());
return new UnitTestResult(new TestFailedException(ObjectModelUnitTestOutcome.Error, ex.TryGetMessage(), ex.TryGetStackTraceInformation()));
}
}
}
else
{
// If the requested apartment state is STA and the OS is not Windows, then warn the user.
if (!isWindowsOS && isSTATestClass)
else
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}
// If the requested apartment state is STA and the OS is not Windows, then warn the user.
if (!isWindowsOS && isSTATestClass)
{
PlatformServiceProvider.Instance.AdapterTraceLogger.LogWarning(Resource.STAIsOnlySupportedOnWindowsWarning);
}

return DoRun();
return DoRun();
}
}

// Local functions
Expand Down

0 comments on commit 124e797

Please sign in to comment.