Skip to content

Commit

Permalink
Add unit tests for Text Search AOT enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
markwallace-microsoft committed Jan 9, 2025
1 parent 2a5e51b commit 938c6c2
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json.Serialization;
using Microsoft.SemanticKernel.Data;
using SemanticKernel.AotTests.Plugins;

namespace SemanticKernel.AotTests.JsonSerializerContexts;

[JsonSerializable(typeof(CustomResult))]
[JsonSerializable(typeof(int))]
[JsonSerializable(typeof(KernelSearchResults<string>))]
[JsonSerializable(typeof(KernelSearchResults<TextSearchResult>))]
[JsonSerializable(typeof(KernelSearchResults<object>))]
internal sealed partial class CustomResultJsonSerializerContext : JsonSerializerContext
{
}
12 changes: 12 additions & 0 deletions dotnet/src/SemanticKernel.AotTests/Plugins/CustomResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft. All rights reserved.

namespace SemanticKernel.AotTests.Plugins;
internal sealed class CustomResult
{
public string Value { get; set; }

public CustomResult(string value)
{
this.Value = value;
}
}
5 changes: 5 additions & 0 deletions dotnet/src/SemanticKernel.AotTests/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ private static async Task<int> Main(string[] args)

// Tests for text search
VectorStoreTextSearchTests.GetTextSearchResultsAsync,
VectorStoreTextSearchTests.AddVectorStoreTextSearch,

TextSearchExtensionsTests.CreateWithSearch,
TextSearchExtensionsTests.CreateWithGetTextSearchResults,
TextSearchExtensionsTests.CreateWithGetSearchResults,
];

private static async Task<bool> RunUnitTestsAsync(IEnumerable<Func<Task>> functionsToRun)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
<PackageReference Include="Microsoft.Extensions.Configuration" />
<PackageReference Include="Microsoft.Extensions.Configuration.UserSecrets" />
<PackageReference Include="MSTest.TestFramework" />
<PackageReference Include="System.Linq.Async" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.SemanticKernel.Data;

namespace SemanticKernel.AotTests.UnitTests.Search;

internal sealed class MockTextSearch : ITextSearch
{
private readonly KernelSearchResults<object>? _objectResults;
private readonly KernelSearchResults<TextSearchResult>? _textSearchResults;
private readonly KernelSearchResults<string>? _stringResults;

public MockTextSearch(KernelSearchResults<object>? objectResults)
{
this._objectResults = objectResults;
}

public MockTextSearch(KernelSearchResults<TextSearchResult>? textSearchResults)
{
this._textSearchResults = textSearchResults;
}

public MockTextSearch(KernelSearchResults<string>? stringResults)
{
this._stringResults = stringResults;
}

public Task<KernelSearchResults<object>> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult(this._objectResults!);
}

public Task<KernelSearchResults<TextSearchResult>> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult(this._textSearchResults!);
}

public Task<KernelSearchResults<string>> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default)
{
return Task.FromResult(this._stringResults!);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Text.Json;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using SemanticKernel.AotTests.JsonSerializerContexts;
using SemanticKernel.AotTests.Plugins;

namespace SemanticKernel.AotTests.UnitTests.Search;

internal sealed class TextSearchExtensionsTests
{
private static readonly JsonSerializerOptions s_jsonSerializerOptions = new()
{
TypeInfoResolverChain = { CustomResultJsonSerializerContext.Default }
};

public static async Task CreateWithSearch()
{
// Arrange
var testData = new List<string> { "test-value" };
KernelSearchResults<string> results = new(testData.ToAsyncEnumerable());
ITextSearch textSearch = new MockTextSearch(results);

// Act
var plugin = textSearch.CreateWithSearch("SearchPlugin", s_jsonSerializerOptions);

// Assert
await AssertSearchFunctionSchemaAndInvocationResult<string>(plugin["Search"], testData[0]);
}

public static async Task CreateWithGetTextSearchResults()
{
// Arrange
var testData = new List<TextSearchResult> { new("test-value") };
KernelSearchResults<TextSearchResult> results = new(testData.ToAsyncEnumerable());
ITextSearch textSearch = new MockTextSearch(results);

// Act
var plugin = textSearch.CreateWithGetTextSearchResults("SearchPlugin", s_jsonSerializerOptions);

// Assert
await AssertSearchFunctionSchemaAndInvocationResult<TextSearchResult>(plugin["GetTextSearchResults"], testData[0]);
}

public static async Task CreateWithGetSearchResults()
{
// Arrange
var testData = new List<CustomResult> { new("test-value") };
KernelSearchResults<object> results = new(testData.ToAsyncEnumerable());
ITextSearch textSearch = new MockTextSearch(results);

// Act
var plugin = textSearch.CreateWithGetSearchResults("SearchPlugin", s_jsonSerializerOptions);

// Assert
await AssertSearchFunctionSchemaAndInvocationResult<object>(plugin["GetSearchResults"], testData[0]);
}

#region assert
internal static async Task AssertSearchFunctionSchemaAndInvocationResult<T>(KernelFunction function, T expectedResult)
{
// Assert input parameter schema
AssertSearchFunctionMetadata<T>(function.Metadata);

// Assert the function result
FunctionResult functionResult = await function.InvokeAsync(new(), new() { ["query"] = "Mock Query" });

var result = functionResult.GetValue<List<T>>()!;
Assert.AreEqual(1, result.Count);
Assert.AreEqual(expectedResult, result[0]);
}

internal static void AssertSearchFunctionMetadata<T>(KernelFunctionMetadata metadata)
{
// Assert input parameter schema
Assert.AreEqual(3, metadata.Parameters.Count);
Assert.AreEqual("{\"description\":\"What to search for\",\"type\":\"string\"}", metadata.Parameters[0].Schema!.ToString());
Assert.AreEqual("{\"description\":\"Number of results (default value: 2)\",\"type\":\"integer\"}", metadata.Parameters[1].Schema!.ToString());
Assert.AreEqual("{\"description\":\"Number of results to skip (default value: 0)\",\"type\":\"integer\"}", metadata.Parameters[2].Schema!.ToString());

// Assert return type schema
var type = typeof(T).Name;
var expectedSchema = type switch
{
"String" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}},\"required\":[\"Results\"]}",
"TextSearchResult" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}",
_ => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}",
};
Assert.AreEqual(expectedSchema, metadata.ReturnParameter.Schema!.ToString());
}
#endregion
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Data;
using Microsoft.VisualStudio.TestTools.UnitTesting;

Expand Down Expand Up @@ -35,6 +37,39 @@ public static async Task GetTextSearchResultsAsync()
Assert.AreEqual("test-link", results[0].Link);
}

public static async Task AddVectorStoreTextSearch()
{
// Arrange
var testData = new List<VectorSearchResult<DataModel>>
{
new(new DataModel { Key = "test-name", Text = "test-result", Link = "test-link" }, 0.5)
};
var vectorizableTextSearch = new MockVectorizableTextSearch<DataModel>(testData);
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IVectorizableTextSearch<DataModel>>(vectorizableTextSearch);

// Act
serviceCollection.AddVectorStoreTextSearch<DataModel>();
var textSearch = serviceCollection.BuildServiceProvider().GetService<VectorStoreTextSearch<DataModel>>();
Assert.IsNotNull(textSearch);

// Assert
KernelSearchResults<TextSearchResult> searchResults = await textSearch.GetTextSearchResultsAsync("query");

List<TextSearchResult> results = [];

await foreach (TextSearchResult result in searchResults.Results)
{
results.Add(result);
}

// Assert
Assert.AreEqual(1, results.Count);
Assert.AreEqual("test-name", results[0].Name);
Assert.AreEqual("test-result", results[0].Value);
Assert.AreEqual("test-link", results[0].Link);
}

private sealed class DataModel
{
[TextSearchResultName]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,9 @@ private static KernelFunctionFromMethodOptions DefaultGetSearchResultsMethodOpti
private static IEnumerable<KernelParameterMetadata> CreateDefaultKernelParameterMetadata(JsonSerializerOptions jsonSerializerOptions)
{
return [
new KernelParameterMetadata("query", jsonSerializerOptions) { Description = "What to search for", IsRequired = true },
new KernelParameterMetadata("count", jsonSerializerOptions) { Description = "Number of results", IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip", jsonSerializerOptions) { Description = "Number of results to skip", IsRequired = false, DefaultValue = 0 },
new KernelParameterMetadata("query", jsonSerializerOptions) { Description = "What to search for", ParameterType = typeof(string), IsRequired = true },
new KernelParameterMetadata("count", jsonSerializerOptions) { Description = "Number of results", ParameterType = typeof(int), IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip", jsonSerializerOptions) { Description = "Number of results to skip", ParameterType = typeof(int), IsRequired = false, DefaultValue = 0 },
];
}

Expand All @@ -530,9 +530,9 @@ private static IEnumerable<KernelParameterMetadata> CreateDefaultKernelParameter
private static IEnumerable<KernelParameterMetadata> GetDefaultKernelParameterMetadata()
{
return s_kernelParameterMetadata ??= [
new KernelParameterMetadata("query") { Description = "What to search for", IsRequired = true },
new KernelParameterMetadata("count") { Description = "Number of results", IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip") { Description = "Number of results to skip", IsRequired = false, DefaultValue = 0 },
new KernelParameterMetadata("query") { Description = "What to search for", ParameterType = typeof(string), IsRequired = true },
new KernelParameterMetadata("count") { Description = "Number of results", ParameterType = typeof(int), IsRequired = false, DefaultValue = 2 },
new KernelParameterMetadata("skip") { Description = "Number of results to skip", ParameterType = typeof(int), IsRequired = false, DefaultValue = 0 },
];
}

Expand Down

0 comments on commit 938c6c2

Please sign in to comment.