Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net: Add shared integration tests for checking vector search scores #10144

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.VectorData;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Memory;

/// <summary>
/// Base class for common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <typeparam name="TKey">The type of key to use with the record collection.</typeparam>
public abstract class BaseVectorStoreRecordCollectionTests<TKey>
where TKey : notnull
{
protected abstract TKey Key1 { get; }
protected abstract TKey Key2 { get; }
protected abstract TKey Key3 { get; }
protected abstract TKey Key4 { get; }

protected abstract HashSet<string> GetSupportedDistanceFunctions();

protected abstract IVectorStoreRecordCollection<TKey, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition);

protected virtual int DelayAfterIndexCreateInMilliseconds { get; } = 0;

protected virtual int DelayAfterUploadInMilliseconds { get; } = 0;

[Theory]
[InlineData(DistanceFunction.CosineDistance, 0, 2)]
[InlineData(DistanceFunction.CosineSimilarity, 1, -1)]
[InlineData(DistanceFunction.DotProductSimilarity, 1, -1)]
[InlineData(DistanceFunction.EuclideanDistance, 0, 2)]
[InlineData(DistanceFunction.EuclideanSquaredDistance, 0, 4)]
[InlineData(DistanceFunction.Hamming, 0, 1)]
[InlineData(DistanceFunction.ManhattanDistance, 0, 2)]
public async Task VectorSearchShouldReturnExpectedScoresAsync(string distanceFunction, double expectedExactMatchScore, double expectedOppositeScore)
{
// Don't test unsupported distance functions.
var supportedDistanceFunctions = this.GetSupportedDistanceFunctions();
if (!supportedDistanceFunctions.Contains(distanceFunction))
{
return;
}

// Arrange
var definition = CreateKeyWithVectorRecordDefinition(4, distanceFunction);
var sut = this.GetTargetRecordCollection<KeyWithVectorRecord<TKey>>(
$"scorebydistancefunction{distanceFunction}",
definition);

await sut.CreateCollectionIfNotExistsAsync();
await Task.Delay(this.DelayAfterIndexCreateInMilliseconds);

// Create two vectors that are opposite to each other and records that use these.
var baseVector = new ReadOnlyMemory<float>([1, 0, 0, 0]);
var oppositeVector = new ReadOnlyMemory<float>([-1, 0, 0, 0]);

var baseRecord = new KeyWithVectorRecord<TKey>
{
Key = this.Key1,
Vector = baseVector,
};

var oppositeRecord = new KeyWithVectorRecord<TKey>
{
Key = this.Key2,
Vector = oppositeVector,
};

await sut.UpsertBatchAsync([baseRecord, oppositeRecord]).ToListAsync();
await Task.Delay(this.DelayAfterUploadInMilliseconds);

// Act
var searchResult = await sut.VectorizedSearchAsync(baseVector);

// Assert
var results = await searchResult.Results.ToListAsync();
Assert.Equal(2, results.Count);

Assert.Equal(this.Key1, results[0].Record.Key);
Assert.Equal(expectedExactMatchScore, results[0].Score);

Assert.Equal(this.Key2, results[1].Record.Key);
Assert.Equal(expectedOppositeScore, results[1].Score);

// Cleanup
await sut.DeleteCollectionAsync();
}

private static VectorStoreRecordDefinition CreateKeyWithVectorRecordDefinition(int vectorDimensions, string distanceFunction)
{
var definition = new VectorStoreRecordDefinition
{
Properties =
[
new VectorStoreRecordKeyProperty("Key", typeof(TKey)),
new VectorStoreRecordVectorProperty("Vector", typeof(ReadOnlyMemory<float>)) { Dimensions = vectorDimensions, DistanceFunction = distanceFunction },
],
};

return definition;
}

private class KeyWithVectorRecord<TRecordKey>

Check failure on line 108 in dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, ubuntu-latest, Release, true, integration)

Type 'KeyWithVectorRecord' can be sealed because it has no subtypes in its containing assembly and is not externally visible (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1852)

Check failure on line 108 in dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, ubuntu-latest, Release, true, integration)

Type 'KeyWithVectorRecord' can be sealed because it has no subtypes in its containing assembly and is not externally visible (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1852)

Check failure on line 108 in dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Debug)

Type 'KeyWithVectorRecord' can be sealed because it has no subtypes in its containing assembly and is not externally visible (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1852)

Check failure on line 108 in dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Debug)

Type 'KeyWithVectorRecord' can be sealed because it has no subtypes in its containing assembly and is not externally visible (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1852)

Check failure on line 108 in dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Release)

Type 'KeyWithVectorRecord' can be sealed because it has no subtypes in its containing assembly and is not externally visible (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1852)

Check failure on line 108 in dotnet/src/IntegrationTests/Connectors/Memory/BaseVectorStoreRecordCollectionTests.cs

View workflow job for this annotation

GitHub Actions / dotnet-build-and-test (8.0, windows-latest, Release)

Type 'KeyWithVectorRecord' can be sealed because it has no subtypes in its containing assembly and is not externally visible (https://learn.microsoft.com/dotnet/fundamentals/code-analysis/quality-rules/ca1852)
{
public required TRecordKey Key { get; set; }

public ReadOnlyMemory<float> Vector { get; set; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.InMemory;

namespace SemanticKernel.IntegrationTests.Connectors.Memory.InMemory;

/// <summary>
/// Inherits common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
public class CommonInMemoryVectorStoreRecordCollectionTests() : BaseVectorStoreRecordCollectionTests<string>
{
protected override string Key1 => "1";
protected override string Key2 => "2";
protected override string Key3 => "3";
protected override string Key4 => "4";

protected override IVectorStoreRecordCollection<string, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
{
return new InMemoryVectorStoreRecordCollection<string, TRecord>(recordCollectionName, new()
{
VectorStoreRecordDefinition = vectorStoreRecordDefinition
});
}

protected override HashSet<string> GetSupportedDistanceFunctions()
{
return [DistanceFunction.CosineDistance, DistanceFunction.CosineSimilarity, DistanceFunction.DotProductSimilarity, DistanceFunction.EuclideanDistance];
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.Postgres;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Memory.Postgres;

/// <summary>
/// Inherits common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <param name="fixture">Postres setup and teardown.</param>
[Collection("PostgresVectorStoreCollection")]
public class CommonPostgresVectorStoreRecordCollectionTests(PostgresVectorStoreFixture fixture) : BaseVectorStoreRecordCollectionTests<string>
{
protected override string Key1 => "1";
protected override string Key2 => "2";
protected override string Key3 => "3";
protected override string Key4 => "4";

protected override IVectorStoreRecordCollection<string, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
{
return new PostgresVectorStoreRecordCollection<string, TRecord>(fixture.DataSource!, recordCollectionName, new()
{
VectorStoreRecordDefinition = vectorStoreRecordDefinition
});
}

protected override HashSet<string> GetSupportedDistanceFunctions()
{
return [DistanceFunction.CosineDistance, DistanceFunction.CosineSimilarity, DistanceFunction.DotProductSimilarity, DistanceFunction.EuclideanDistance, DistanceFunction.ManhattanDistance];
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,22 @@ public PostgresVectorStoreFixture()
/// <summary>
/// Holds the Npgsql data source to use for tests.
/// </summary>
private NpgsqlDataSource? _dataSource;
public NpgsqlDataSource? DataSource { get; private set; }

private string _connectionString = null!;
private string _databaseName = null!;

/// <summary>
/// Gets a vector store to use for tests.
/// </summary>
public IVectorStore VectorStore => new PostgresVectorStore(this._dataSource!);
public IVectorStore VectorStore => new PostgresVectorStore(this.DataSource!);

/// <summary>
/// Get a database connection
/// </summary>
public NpgsqlConnection GetConnection()
{
return this._dataSource!.OpenConnection();
return this.DataSource!.OpenConnection();
}

public IVectorStoreRecordCollection<TKey, TRecord> GetCollection<TKey, TRecord>(
Expand Down Expand Up @@ -81,7 +81,7 @@ public async Task InitializeAsync()
NpgsqlDataSourceBuilder dataSourceBuilder = new(connectionStringBuilder.ToString());
dataSourceBuilder.UseVector();

this._dataSource = dataSourceBuilder.Build();
this.DataSource = dataSourceBuilder.Build();

// Wait for the postgres container to be ready and create the test database using the initial data source.
var initialDataSource = NpgsqlDataSource.Create(this._connectionString);
Expand Down Expand Up @@ -124,7 +124,7 @@ public async Task InitializeAsync()

private async Task CreateTableAsync()
{
NpgsqlConnection connection = await this._dataSource!.OpenConnectionAsync().ConfigureAwait(false);
NpgsqlConnection connection = await this.DataSource!.OpenConnectionAsync().ConfigureAwait(false);

await using (connection)
{
Expand All @@ -150,9 +150,9 @@ DescriptionEmbedding VECTOR(4) NOT NULL,
/// <returns>An async task.</returns>
public async Task DisposeAsync()
{
if (this._dataSource != null)
if (this.DataSource != null)
{
this._dataSource.Dispose();
this.DataSource.Dispose();
}

await this.DropDatabaseAsync();
Expand Down Expand Up @@ -218,7 +218,7 @@ private async Task CreateDatabaseAsync(NpgsqlDataSource initialDataSource)
await command.ExecuteNonQueryAsync();
}

await using (NpgsqlConnection conn = await this._dataSource!.OpenConnectionAsync())
await using (NpgsqlConnection conn = await this.DataSource!.OpenConnectionAsync())
{
await using (NpgsqlCommand command = new("CREATE EXTENSION vector", conn))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;
using Microsoft.Extensions.VectorData;
using Microsoft.SemanticKernel.Connectors.Qdrant;
using Xunit;

namespace SemanticKernel.IntegrationTests.Connectors.Memory.Qdrant;

/// <summary>
/// Inherits common integration tests that should pass for any <see cref="IVectorStoreRecordCollection{TKey, TRecord}"/>.
/// </summary>
/// <param name="fixture">Qdrant setup and teardown.</param>
[Collection("QdrantVectorStoreCollection")]
public class CommonQdrantVectorStoreRecordCollectionTests(QdrantVectorStoreFixture fixture) : BaseVectorStoreRecordCollectionTests<ulong>
{
protected override ulong Key1 => 1;
protected override ulong Key2 => 2;
protected override ulong Key3 => 3;
protected override ulong Key4 => 4;

protected override IVectorStoreRecordCollection<ulong, TRecord> GetTargetRecordCollection<TRecord>(string recordCollectionName, VectorStoreRecordDefinition? vectorStoreRecordDefinition)
{
return new QdrantVectorStoreRecordCollection<TRecord>(fixture.QdrantClient, recordCollectionName, new()
{
HasNamedVectors = true,
VectorStoreRecordDefinition = vectorStoreRecordDefinition
});
}

protected override HashSet<string> GetSupportedDistanceFunctions()
{
return [DistanceFunction.CosineSimilarity, DistanceFunction.EuclideanDistance, DistanceFunction.ManhattanDistance];
}
}
Loading