Add Microsoft.Extensions.AI support for IChatClient / IEmbeddingGenerator #964

Merged · 2 commits · Nov 1, 2024
22 changes: 22 additions & 0 deletions LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,6 +1,7 @@
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.AI;
using Xunit.Abstractions;

namespace LLama.Unittest;
@@ -41,6 +42,27 @@ private async Task CompareEmbeddings(string modelPath)
var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, spoon);

var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
Assert.NotNull(generator.Metadata);
Assert.Equal(nameof(LLamaEmbedder), generator.Metadata.ProviderName);
Assert.NotNull(generator.Metadata.ModelId);
Assert.NotEmpty(generator.Metadata.ModelId);
Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Null(generator.GetService<string>());

var embeddings = await generator.GenerateAsync(
[
"The cat is cute",
"The kitten is cute",
"The spoon is not real"
]);
Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.True(embeddings.Usage?.InputTokenCount is 19 or 20);
Assert.True(embeddings.Usage?.TotalTokenCount is 19 or 20);

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
169 changes: 169 additions & 0 deletions LLama/Extensions/LLamaExecutorExtensions.cs
@@ -0,0 +1,169 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using LLama.Common;
using LLama.Sampling;
using Microsoft.Extensions.AI;

namespace LLama.Abstractions;

/// <summary>
/// Extension methods for the <see cref="ILLamaExecutor" /> interface.
/// </summary>
public static class LLamaExecutorExtensions
{
/// <summary>Gets an <see cref="IChatClient"/> instance for the specified <see cref="ILLamaExecutor"/>.</summary>
/// <param name="executor">The executor.</param>
/// <param name="historyTransform">The <see cref="IHistoryTransform"/> to use to transform an input list messages into a prompt.</param>
/// <param name="outputTransform">The <see cref="ITextStreamTransform"/> to use to transform the output into text.</param>
/// <returns>An <see cref="IChatClient"/> instance for the provided <see cref="ILLamaExecutor" />.</returns>
/// <exception cref="ArgumentNullException"><paramref name="executor"/> is null.</exception>
public static IChatClient AsChatClient(
this ILLamaExecutor executor,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null) =>
new LLamaExecutorChatClient(
executor ?? throw new ArgumentNullException(nameof(executor)),
historyTransform,
outputTransform);

private sealed class LLamaExecutorChatClient(
ILLamaExecutor executor,
IHistoryTransform? historyTransform = null,
ITextStreamTransform? outputTransform = null) : IChatClient
{
private static readonly InferenceParams s_defaultParams = new();
private static readonly DefaultSamplingPipeline s_defaultPipeline = new();
private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];
[ThreadStatic]
private static Random? t_random;

private readonly ILLamaExecutor _executor = executor;
private readonly IHistoryTransform _historyTransform = historyTransform ?? new AppendAssistantHistoryTransform();
private readonly ITextStreamTransform _outputTransform = outputTransform ??
new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);

/// <inheritdoc/>
public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));

/// <inheritdoc/>
public void Dispose() { }

/// <inheritdoc/>
public TService? GetService<TService>(object? key = null) where TService : class =>
typeof(TService) == typeof(ILLamaExecutor) ? (TService)_executor :
this as TService;

/// <inheritdoc/>
public async Task<ChatCompletion> CompleteAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
{
var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

StringBuilder text = new();
await foreach (var token in _outputTransform.TransformAsync(result))
{
text.Append(token);
}

return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
{
CreatedAt = DateTime.UtcNow,
};
}

/// <inheritdoc/>
public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);

await foreach (var token in _outputTransform.TransformAsync(result))
{
yield return new()
{
CreatedAt = DateTime.UtcNow,
Role = ChatRole.Assistant,
Text = token,
};
}
}

/// <summary>Format the chat messages into a string prompt.</summary>
private string CreatePrompt(IList<ChatMessage> messages)
{
if (messages is null)
{
throw new ArgumentNullException(nameof(messages));
}

ChatHistory history = new();

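// A stateless executor, or a stateful one that has not yet consumed its initial prompt, gets the full transformed history.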
if (_executor is not StatefulExecutorBase seb ||
seb.GetStateData() is InteractiveExecutor.InteractiveExecutorState { IsPromptRun: true })
{
foreach (var message in messages)
{
history.AddMessage(
message.Role == ChatRole.System ? AuthorRole.System :
message.Role == ChatRole.Assistant ? AuthorRole.Assistant :
AuthorRole.User,
string.Concat(message.Contents.OfType<TextContent>()));
}
}
else
{
// Stateful executor that has already processed its prompt: only the most recent message needs to be sent.
history.AddMessage(AuthorRole.User, string.Concat(messages.LastOrDefault()?.Contents.OfType<TextContent>() ?? []));
}

return _historyTransform.HistoryToText(history);
}

/// <summary>Convert the chat options to inference parameters.</summary>
private static InferenceParams? CreateInferenceParams(ChatOptions? options)
{
List<string> antiPrompts = new(s_antiPrompts);
if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
{
antiPrompts.AddRange(anti);
}

return new()
{
AntiPrompts = antiPrompts,
TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
SamplingPipeline = new DefaultSamplingPipeline()
{
AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
Seed = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Seed), out uint seed) is true ? seed : (uint)(t_random ??= new()).Next(),
TailFreeZ = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TailFreeZ), out float tfz) is true ? tfz : s_defaultPipeline.TailFreeZ,
Temperature = options?.Temperature ?? 0,
TopP = options?.TopP ?? 0,
TopK = options?.TopK ?? s_defaultPipeline.TopK,
TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
},
};
}

/// <summary>A default transform that appends "Assistant: " to the end.</summary>
private sealed class AppendAssistantHistoryTransform : LLamaTransforms.DefaultHistoryTransform
{
public override string HistoryToText(ChatHistory history) =>
$"{base.HistoryToText(history)}{AuthorRole.Assistant}: ";
}
}
}
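
For context, a minimal consumption sketch of the new surface area; the model path is a placeholder, the sampling values are arbitrary, and the loading boilerplate assumes the standard LLamaSharp ModelParams/LLamaWeights APIs:

using LLama;
using LLama.Abstractions;
using LLama.Common;
using LLama.Sampling;
using Microsoft.Extensions.AI;

// Load a model and wrap an executor as an IChatClient.
var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
IChatClient client = new InteractiveExecutor(context).AsChatClient();

var response = await client.CompleteAsync(
    [new ChatMessage(ChatRole.User, "What is an embedding?")],
    new ChatOptions
    {
        MaxOutputTokens = 128,
        Temperature = 0.7f,
        // Settings without a ChatOptions equivalent flow through AdditionalProperties.
        AdditionalProperties = new() { [nameof(DefaultSamplingPipeline.RepeatPenalty)] = 1.1f },
    });

Console.WriteLine(response.Message.Text);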
12 changes: 12 additions & 0 deletions LLama/Extensions/SpanNormalizationExtensions.cs
@@ -81,6 +81,18 @@ public static Span<float> EuclideanNormalization(this Span<float> vector)
return vector;
}

/// <summary>
/// Creates a new array containing an L2 normalization of the input vector.
/// </summary>
/// <param name="vector">The vector to normalize.</param>
/// <returns>A new array containing the normalized values; the input is not modified.</returns>
public static float[] EuclideanNormalization(this ReadOnlySpan<float> vector)
{
var result = new float[vector.Length];
TensorPrimitives.Divide(vector, TensorPrimitives.Norm(vector), result);
return result;
}

/// <summary>
/// <b>In-place</b> apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
/// <list type="bullet">
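
For reference, the new ReadOnlySpan overload leaves its input untouched, unlike the in-place Span<float> version above it: a [3, 4] vector has an L2 norm of 5, so it normalizes to [0.6, 0.8]. A quick sketch (assuming a using LLama.Extensions; directive):

// The ReadOnlySpan<float> overload allocates a fresh array rather than normalizing in place.
ReadOnlySpan<float> v = stackalloc float[] { 3f, 4f };
float[] unit = v.EuclideanNormalization(); // [0.6f, 0.8f]; 'v' is unchanged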
54 changes: 54 additions & 0 deletions LLama/LLamaEmbedder.EmbeddingGenerator.cs
@@ -0,0 +1,54 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using Microsoft.Extensions.AI;

namespace LLama;

public partial class LLamaEmbedder
: IEmbeddingGenerator<string, Embedding<float>>
{
private EmbeddingGeneratorMetadata? _metadata;

/// <inheritdoc />
EmbeddingGeneratorMetadata IEmbeddingGenerator<string, Embedding<float>>.Metadata =>
_metadata ??= new(
nameof(LLamaEmbedder),
modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
dimensions: EmbeddingSize);

/// <inheritdoc />
TService? IEmbeddingGenerator<string, Embedding<float>>.GetService<TService>(object? key) where TService : class =>
typeof(TService) == typeof(LLamaContext) ? (TService)(object)Context :
this as TService;

/// <inheritdoc />
async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
{
if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
{
throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
}

GeneratedEmbeddings<Embedding<float>> results = new()
{
Usage = new() { InputTokenCount = 0 },
};

foreach (var value in values)
{
var (embeddings, tokenCount) = await GetEmbeddingsWithTokenCount(value, cancellationToken).ConfigureAwait(false);
Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding when pooling is enabled.");

results.Usage.InputTokenCount += tokenCount;
results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
}

results.Usage.TotalTokenCount = results.Usage.InputTokenCount;

return results;
}
}
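
A minimal consumption sketch for the generator side; the model path is a placeholder, and the Embeddings flag on ModelParams is assumed to be how embedding mode is enabled in the current API:

using LLama;
using LLama.Common;
using Microsoft.Extensions.AI;

// Load an embedding-capable model; pooling must not be LLamaPoolingType.None.
var parameters = new ModelParams("path/to/embedding-model.gguf") { Embeddings = true };
using var weights = LLamaWeights.LoadFromFile(parameters);
using var embedder = new LLamaEmbedder(weights, parameters);

// The same instance is consumable through the Microsoft.Extensions.AI abstraction.
IEmbeddingGenerator<string, Embedding<float>> generator = embedder;
var embeddings = await generator.GenerateAsync(["The cat is cute", "The kitten is cute"]);
Console.WriteLine($"{embeddings[0].Vector.Length} dimensions; {embeddings.Usage?.InputTokenCount} input tokens");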
9 changes: 6 additions & 3 deletions LLama/LLamaEmbedder.cs
Expand Up @@ -12,7 +12,7 @@ namespace LLama;
/// <summary>
/// Generate high dimensional embedding vectors from text
/// </summary>
public sealed class LLamaEmbedder
public sealed partial class LLamaEmbedder
: IDisposable
{
/// <summary>
@@ -58,7 +58,10 @@ public void Dispose()
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
/// <exception cref="NotSupportedException"></exception>
public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default) =>
(await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings;

private async Task<(IReadOnlyList<float[]> Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default)
{
// Add all of the tokens to the batch
var tokens = Context.Tokenize(input);
@@ -113,6 +116,6 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)

Context.NativeHandle.KvCacheClear();

return results;
return (results, tokens.Length);
}
}
2 changes: 2 additions & 0 deletions LLama/LLamaSharp.csproj
@@ -49,6 +49,8 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" Version="8.0.0" />
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="8.0.1" />
<PackageReference Include="System.Numerics.Tensors" Version="8.0.0" />
</ItemGroup>