diff --git a/dotnet/src/SemanticKernel.AotTests/JsonSerializerContexts/CustomResultJsonSerializerContext.cs b/dotnet/src/SemanticKernel.AotTests/JsonSerializerContexts/CustomResultJsonSerializerContext.cs new file mode 100644 index 000000000000..c5a0d599864c --- /dev/null +++ b/dotnet/src/SemanticKernel.AotTests/JsonSerializerContexts/CustomResultJsonSerializerContext.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json.Serialization; +using Microsoft.SemanticKernel.Data; +using SemanticKernel.AotTests.Plugins; + +namespace SemanticKernel.AotTests.JsonSerializerContexts; + +[JsonSerializable(typeof(CustomResult))] +[JsonSerializable(typeof(int))] +[JsonSerializable(typeof(KernelSearchResults))] +[JsonSerializable(typeof(KernelSearchResults))] +[JsonSerializable(typeof(KernelSearchResults))] +internal sealed partial class CustomResultJsonSerializerContext : JsonSerializerContext +{ +} diff --git a/dotnet/src/SemanticKernel.AotTests/Plugins/CustomResult.cs b/dotnet/src/SemanticKernel.AotTests/Plugins/CustomResult.cs new file mode 100644 index 000000000000..4ab3d1045218 --- /dev/null +++ b/dotnet/src/SemanticKernel.AotTests/Plugins/CustomResult.cs @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace SemanticKernel.AotTests.Plugins; +internal sealed class CustomResult +{ + public string Value { get; set; } + + public CustomResult(string value) + { + this.Value = value; + } +} diff --git a/dotnet/src/SemanticKernel.AotTests/Program.cs b/dotnet/src/SemanticKernel.AotTests/Program.cs index 49fae7449dd4..a9fa29b9a2a3 100644 --- a/dotnet/src/SemanticKernel.AotTests/Program.cs +++ b/dotnet/src/SemanticKernel.AotTests/Program.cs @@ -59,6 +59,11 @@ private static async Task Main(string[] args) // Tests for text search VectorStoreTextSearchTests.GetTextSearchResultsAsync, + VectorStoreTextSearchTests.AddVectorStoreTextSearch, + + TextSearchExtensionsTests.CreateWithSearch, + TextSearchExtensionsTests.CreateWithGetTextSearchResults, + TextSearchExtensionsTests.CreateWithGetSearchResults, ]; private static async Task RunUnitTestsAsync(IEnumerable> functionsToRun) diff --git a/dotnet/src/SemanticKernel.AotTests/SemanticKernel.AotTests.csproj b/dotnet/src/SemanticKernel.AotTests/SemanticKernel.AotTests.csproj index 6b54614b9ca5..9da3b544ac88 100644 --- a/dotnet/src/SemanticKernel.AotTests/SemanticKernel.AotTests.csproj +++ b/dotnet/src/SemanticKernel.AotTests/SemanticKernel.AotTests.csproj @@ -18,6 +18,7 @@ + diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockTextSearch.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockTextSearch.cs new file mode 100644 index 000000000000..72aa218239f9 --- /dev/null +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/MockTextSearch.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.SemanticKernel.Data; + +namespace SemanticKernel.AotTests.UnitTests.Search; + +internal sealed class MockTextSearch : ITextSearch +{ + private readonly KernelSearchResults? _objectResults; + private readonly KernelSearchResults? _textSearchResults; + private readonly KernelSearchResults? _stringResults; + + public MockTextSearch(KernelSearchResults? objectResults) + { + this._objectResults = objectResults; + } + + public MockTextSearch(KernelSearchResults? textSearchResults) + { + this._textSearchResults = textSearchResults; + } + + public MockTextSearch(KernelSearchResults? stringResults) + { + this._stringResults = stringResults; + } + + public Task> GetSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + return Task.FromResult(this._objectResults!); + } + + public Task> GetTextSearchResultsAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + return Task.FromResult(this._textSearchResults!); + } + + public Task> SearchAsync(string query, TextSearchOptions? searchOptions = null, CancellationToken cancellationToken = default) + { + return Task.FromResult(this._stringResults!); + } +} diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/TextSearchExtensionsTests.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/TextSearchExtensionsTests.cs new file mode 100644 index 000000000000..603b95a939f5 --- /dev/null +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/TextSearchExtensionsTests.cs @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text.Json; +using Microsoft.SemanticKernel; +using Microsoft.SemanticKernel.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using SemanticKernel.AotTests.JsonSerializerContexts; +using SemanticKernel.AotTests.Plugins; + +namespace SemanticKernel.AotTests.UnitTests.Search; + +internal sealed class TextSearchExtensionsTests +{ + private static readonly JsonSerializerOptions s_jsonSerializerOptions = new() + { + TypeInfoResolverChain = { CustomResultJsonSerializerContext.Default } + }; + + public static async Task CreateWithSearch() + { + // Arrange + var testData = new List { "test-value" }; + KernelSearchResults results = new(testData.ToAsyncEnumerable()); + ITextSearch textSearch = new MockTextSearch(results); + + // Act + var plugin = textSearch.CreateWithSearch("SearchPlugin", s_jsonSerializerOptions); + + // Assert + await AssertSearchFunctionSchemaAndInvocationResult(plugin["Search"], testData[0]); + } + + public static async Task CreateWithGetTextSearchResults() + { + // Arrange + var testData = new List { new("test-value") }; + KernelSearchResults results = new(testData.ToAsyncEnumerable()); + ITextSearch textSearch = new MockTextSearch(results); + + // Act + var plugin = textSearch.CreateWithGetTextSearchResults("SearchPlugin", s_jsonSerializerOptions); + + // Assert + await AssertSearchFunctionSchemaAndInvocationResult(plugin["GetTextSearchResults"], testData[0]); + } + + public static async Task CreateWithGetSearchResults() + { + // Arrange + var testData = new List { new("test-value") }; + KernelSearchResults results = new(testData.ToAsyncEnumerable()); + ITextSearch textSearch = new MockTextSearch(results); + + // Act + var plugin = textSearch.CreateWithGetSearchResults("SearchPlugin", s_jsonSerializerOptions); + + // Assert + await AssertSearchFunctionSchemaAndInvocationResult(plugin["GetSearchResults"], testData[0]); + } + + #region assert + internal static async Task AssertSearchFunctionSchemaAndInvocationResult(KernelFunction function, T expectedResult) + { + // Assert input parameter schema + AssertSearchFunctionMetadata(function.Metadata); + + // Assert the function result + FunctionResult functionResult = await function.InvokeAsync(new(), new() { ["query"] = "Mock Query" }); + + var result = functionResult.GetValue>()!; + Assert.AreEqual(1, result.Count); + Assert.AreEqual(expectedResult, result[0]); + } + + internal static void AssertSearchFunctionMetadata(KernelFunctionMetadata metadata) + { + // Assert input parameter schema + Assert.AreEqual(3, metadata.Parameters.Count); + Assert.AreEqual("{\"description\":\"What to search for\",\"type\":\"string\"}", metadata.Parameters[0].Schema!.ToString()); + Assert.AreEqual("{\"description\":\"Number of results (default value: 2)\",\"type\":\"integer\"}", metadata.Parameters[1].Schema!.ToString()); + Assert.AreEqual("{\"description\":\"Number of results to skip (default value: 0)\",\"type\":\"integer\"}", metadata.Parameters[2].Schema!.ToString()); + + // Assert return type schema + var type = typeof(T).Name; + var expectedSchema = type switch + { + "String" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"string\"}}},\"required\":[\"Results\"]}", + "TextSearchResult" => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}", + _ => "{\"type\":\"object\",\"properties\":{\"TotalCount\":{\"type\":[\"integer\",\"null\"],\"default\":null},\"Metadata\":{\"type\":[\"object\",\"null\"],\"default\":null},\"Results\":{\"type\":\"array\",\"items\":{\"type\":\"object\",\"properties\":{\"Name\":{\"type\":[\"string\",\"null\"]},\"Link\":{\"type\":[\"string\",\"null\"]},\"Value\":{\"type\":\"string\"}},\"required\":[\"Value\"]}}},\"required\":[\"Results\"]}", + }; + Assert.AreEqual(expectedSchema, metadata.ReturnParameter.Schema!.ToString()); + } + #endregion +} diff --git a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs index 01fb449d455b..eee8ae4db55e 100644 --- a/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs +++ b/dotnet/src/SemanticKernel.AotTests/UnitTests/Search/VectorStoreTextSearchTests.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft. All rights reserved. +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.VectorData; +using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Data; using Microsoft.VisualStudio.TestTools.UnitTesting; @@ -35,6 +37,39 @@ public static async Task GetTextSearchResultsAsync() Assert.AreEqual("test-link", results[0].Link); } + public static async Task AddVectorStoreTextSearch() + { + // Arrange + var testData = new List> + { + new(new DataModel { Key = "test-name", Text = "test-result", Link = "test-link" }, 0.5) + }; + var vectorizableTextSearch = new MockVectorizableTextSearch(testData); + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton>(vectorizableTextSearch); + + // Act + serviceCollection.AddVectorStoreTextSearch(); + var textSearch = serviceCollection.BuildServiceProvider().GetService>(); + Assert.IsNotNull(textSearch); + + // Assert + KernelSearchResults searchResults = await textSearch.GetTextSearchResultsAsync("query"); + + List results = []; + + await foreach (TextSearchResult result in searchResults.Results) + { + results.Add(result); + } + + // Assert + Assert.AreEqual(1, results.Count); + Assert.AreEqual("test-name", results[0].Name); + Assert.AreEqual("test-result", results[0].Value); + Assert.AreEqual("test-link", results[0].Link); + } + private sealed class DataModel { [TextSearchResultName] diff --git a/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchExtensions.cs b/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchExtensions.cs index 5487ad7fac3f..55e2df49e046 100644 --- a/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchExtensions.cs +++ b/dotnet/src/SemanticKernel.Core/Data/TextSearch/TextSearchExtensions.cs @@ -519,9 +519,9 @@ private static KernelFunctionFromMethodOptions DefaultGetSearchResultsMethodOpti private static IEnumerable CreateDefaultKernelParameterMetadata(JsonSerializerOptions jsonSerializerOptions) { return [ - new KernelParameterMetadata("query", jsonSerializerOptions) { Description = "What to search for", IsRequired = true }, - new KernelParameterMetadata("count", jsonSerializerOptions) { Description = "Number of results", IsRequired = false, DefaultValue = 2 }, - new KernelParameterMetadata("skip", jsonSerializerOptions) { Description = "Number of results to skip", IsRequired = false, DefaultValue = 0 }, + new KernelParameterMetadata("query", jsonSerializerOptions) { Description = "What to search for", ParameterType = typeof(string), IsRequired = true }, + new KernelParameterMetadata("count", jsonSerializerOptions) { Description = "Number of results", ParameterType = typeof(int), IsRequired = false, DefaultValue = 2 }, + new KernelParameterMetadata("skip", jsonSerializerOptions) { Description = "Number of results to skip", ParameterType = typeof(int), IsRequired = false, DefaultValue = 0 }, ]; } @@ -530,9 +530,9 @@ private static IEnumerable CreateDefaultKernelParameter private static IEnumerable GetDefaultKernelParameterMetadata() { return s_kernelParameterMetadata ??= [ - new KernelParameterMetadata("query") { Description = "What to search for", IsRequired = true }, - new KernelParameterMetadata("count") { Description = "Number of results", IsRequired = false, DefaultValue = 2 }, - new KernelParameterMetadata("skip") { Description = "Number of results to skip", IsRequired = false, DefaultValue = 0 }, + new KernelParameterMetadata("query") { Description = "What to search for", ParameterType = typeof(string), IsRequired = true }, + new KernelParameterMetadata("count") { Description = "Number of results", ParameterType = typeof(int), IsRequired = false, DefaultValue = 2 }, + new KernelParameterMetadata("skip") { Description = "Number of results to skip", ParameterType = typeof(int), IsRequired = false, DefaultValue = 0 }, ]; }