diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index f48d1ef45..f70a3c830 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -1,6 +1,7 @@
 using LLama.Common;
 using LLama.Extensions;
 using LLama.Native;
+using Microsoft.Extensions.AI;
 using Xunit.Abstractions;
 
 namespace LLama.Unittest;
@@ -41,6 +42,27 @@ private async Task CompareEmbeddings(string modelPath)
         var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
         Assert.DoesNotContain(float.NaN, spoon);
 
+        var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
+        Assert.NotNull(generator.Metadata);
+        Assert.Equal(nameof(LLamaEmbedder), generator.Metadata.ProviderName);
+        Assert.NotNull(generator.Metadata.ModelId);
+        Assert.NotEmpty(generator.Metadata.ModelId);
+        Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
+        Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
+        Assert.Null(generator.GetService<string>());
+
+        var embeddings = await generator.GenerateAsync(
+            [
+                "The cat is cute",
+                "The kitten is cute",
+                "The spoon is not real"
+            ]);
+        Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
+        Assert.True(embeddings.Usage?.InputTokenCount is 19 or 20);
+        Assert.True(embeddings.Usage?.TotalTokenCount is 19 or 20);
+
         _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
         _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
         _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
diff --git a/LLama/Extensions/LLamaExecutorExtensions.cs b/LLama/Extensions/LLamaExecutorExtensions.cs
new file mode 100644
index 000000000..7fe6cc871
--- /dev/null
+++ b/LLama/Extensions/LLamaExecutorExtensions.cs
@@ -0,0 +1,169 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Runtime.CompilerServices;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Common;
+using LLama.Sampling;
+using Microsoft.Extensions.AI;
+
+namespace LLama.Abstractions;
+
+/// <summary>
+/// Extension methods to the <see cref="ILLamaExecutor"/> interface.
+/// </summary>
+public static class LLamaExecutorExtensions
+{
+    /// <summary>Gets an <see cref="IChatClient"/> instance for the specified <see cref="ILLamaExecutor"/>.</summary>
+    /// <param name="executor">The executor.</param>
+    /// <param name="historyTransform">The <see cref="IHistoryTransform"/> to use to transform an input list of messages into a prompt.</param>
+    /// <param name="outputTransform">The <see cref="ITextStreamTransform"/> to use to transform the output into text.</param>
+    /// <returns>An <see cref="IChatClient"/> instance for the provided <see cref="ILLamaExecutor"/>.</returns>
+    /// <exception cref="ArgumentNullException"><paramref name="executor"/> is null.</exception>
+    public static IChatClient AsChatClient(
+        this ILLamaExecutor executor,
+        IHistoryTransform? historyTransform = null,
+        ITextStreamTransform? outputTransform = null) =>
+        new LLamaExecutorChatClient(
+            executor ?? throw new ArgumentNullException(nameof(executor)),
+            historyTransform,
+            outputTransform);
+
+    private sealed class LLamaExecutorChatClient(
+        ILLamaExecutor executor,
+        IHistoryTransform? historyTransform = null,
+        ITextStreamTransform? outputTransform = null) : IChatClient
+    {
+        private static readonly InferenceParams s_defaultParams = new();
+        private static readonly DefaultSamplingPipeline s_defaultPipeline = new();
+        private static readonly string[] s_antiPrompts = ["User:", "Assistant:", "System:"];
+        [ThreadStatic]
+        private static Random? t_random;
+
+        private readonly ILLamaExecutor _executor = executor;
+        private readonly IHistoryTransform _historyTransform = historyTransform ?? new AppendAssistantHistoryTransform();
+        private readonly ITextStreamTransform _outputTransform = outputTransform ??
+            new LLamaTransforms.KeywordTextOutputStreamTransform(s_antiPrompts);
+
+        /// <inheritdoc/>
+        public ChatClientMetadata Metadata { get; } = new(nameof(LLamaExecutorChatClient));
+
+        /// <inheritdoc/>
+        public void Dispose() { }
+
+        /// <inheritdoc/>
+        public TService? GetService<TService>(object? key = null) where TService : class =>
+            typeof(TService) == typeof(ILLamaExecutor) ? (TService)_executor :
+            this as TService;
+
+        /// <inheritdoc/>
+        public async Task<ChatCompletion> CompleteAsync(
+            IList<ChatMessage> chatMessages, ChatOptions? options = null, CancellationToken cancellationToken = default)
+        {
+            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);
+
+            StringBuilder text = new();
+            await foreach (var token in _outputTransform.TransformAsync(result))
+            {
+                text.Append(token);
+            }
+
+            return new(new ChatMessage(ChatRole.Assistant, text.ToString()))
+            {
+                CreatedAt = DateTime.UtcNow,
+            };
+        }
+
+        /// <inheritdoc/>
+        public async IAsyncEnumerable<StreamingChatCompletionUpdate> CompleteStreamingAsync(
+            IList<ChatMessage> chatMessages, ChatOptions? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+        {
+            var result = _executor.InferAsync(CreatePrompt(chatMessages), CreateInferenceParams(options), cancellationToken);
+
+            await foreach (var token in _outputTransform.TransformAsync(result))
+            {
+                yield return new()
+                {
+                    CreatedAt = DateTime.UtcNow,
+                    Role = ChatRole.Assistant,
+                    Text = token,
+                };
+            }
+        }
+
+        /// <summary>Format the chat messages into a string prompt.</summary>
+        private string CreatePrompt(IList<ChatMessage> messages)
+        {
+            if (messages is null)
+            {
+                throw new ArgumentNullException(nameof(messages));
+            }
+
+            ChatHistory history = new();
+
+            if (_executor is not StatefulExecutorBase seb ||
+                seb.GetStateData() is InteractiveExecutor.InteractiveExecutorState { IsPromptRun: true })
+            {
+                foreach (var message in messages)
+                {
+                    history.AddMessage(
+                        message.Role == ChatRole.System ? AuthorRole.System :
+                        message.Role == ChatRole.Assistant ? AuthorRole.Assistant :
+                        AuthorRole.User,
+                        string.Concat(message.Contents.OfType<TextContent>()));
+                }
+            }
+            else
+            {
+                // Stateful executor that has already processed its prompt (IsPromptRun is false): send only the most recent message.
+                history.AddMessage(AuthorRole.User, string.Concat(messages.LastOrDefault()?.Contents.OfType<TextContent>() ?? []));
+            }
+
+            return _historyTransform.HistoryToText(history);
+        }
+
+        /// <summary>Convert the chat options to inference parameters.</summary>
+        private static InferenceParams? CreateInferenceParams(ChatOptions? options)
+        {
+            List<string> antiPrompts = new(s_antiPrompts);
+            if (options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.AntiPrompts), out IReadOnlyList<string>? anti) is true)
+            {
+                antiPrompts.AddRange(anti);
+            }
+
+            return new()
+            {
+                AntiPrompts = antiPrompts,
+                TokensKeep = options?.AdditionalProperties?.TryGetValue(nameof(InferenceParams.TokensKeep), out int tk) is true ? tk : s_defaultParams.TokensKeep,
+                MaxTokens = options?.MaxOutputTokens ?? 256, // arbitrary upper limit
+                SamplingPipeline = new DefaultSamplingPipeline()
+                {
+                    AlphaFrequency = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaFrequency), out float af) is true ? af : s_defaultPipeline.AlphaFrequency,
+                    AlphaPresence = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.AlphaPresence), out float ap) is true ? ap : s_defaultPipeline.AlphaPresence,
+                    PenalizeEOS = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeEOS), out bool eos) is true ? eos : s_defaultPipeline.PenalizeEOS,
+                    PenalizeNewline = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.PenalizeNewline), out bool pnl) is true ? pnl : s_defaultPipeline.PenalizeNewline,
+                    RepeatPenalty = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenalty), out float rp) is true ? rp : s_defaultPipeline.RepeatPenalty,
+                    RepeatPenaltyCount = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.RepeatPenaltyCount), out int rpc) is true ? rpc : s_defaultPipeline.RepeatPenaltyCount,
+                    Grammar = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Grammar), out Grammar? g) is true ? g : s_defaultPipeline.Grammar,
+                    MinKeep = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinKeep), out int mk) is true ? mk : s_defaultPipeline.MinKeep,
+                    MinP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.MinP), out float mp) is true ? mp : s_defaultPipeline.MinP,
+                    Seed = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.Seed), out uint seed) is true ? seed : (uint)(t_random ??= new()).Next(),
+                    TailFreeZ = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TailFreeZ), out float tfz) is true ? tfz : s_defaultPipeline.TailFreeZ,
+                    Temperature = options?.Temperature ?? 0,
+                    TopP = options?.TopP ?? 0,
+                    TopK = options?.TopK ?? s_defaultPipeline.TopK,
+                    TypicalP = options?.AdditionalProperties?.TryGetValue(nameof(DefaultSamplingPipeline.TypicalP), out float tp) is true ? tp : s_defaultPipeline.TypicalP,
+                },
+            };
+        }
+
+        /// <summary>A default transform that appends "Assistant: " to the end.</summary>
+        private sealed class AppendAssistantHistoryTransform : LLamaTransforms.DefaultHistoryTransform
+        {
+            public override string HistoryToText(ChatHistory history) =>
+                $"{base.HistoryToText(history)}{AuthorRole.Assistant}: ";
+        }
+    }
+}
\ No newline at end of file
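Not part of the patch: a minimal sketch of how the new `AsChatClient` extension might be consumed. The model path, the `ModelParams`/`InteractiveExecutor` setup, and the specific option values are assumptions; only `AsChatClient`, `ChatOptions`, and the `AdditionalProperties` keys handled by `CreateInferenceParams` come from the diff above.

```csharp
using System;
using LLama;
using LLama.Abstractions;
using LLama.Common;
using LLama.Sampling;
using Microsoft.Extensions.AI;

// Assumed setup: a local GGUF model loaded with the usual LLamaSharp APIs.
var parameters = new ModelParams("models/llama.gguf"); // hypothetical path
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);
var executor = new InteractiveExecutor(context);

// Wrap any ILLamaExecutor in an M.E.AI IChatClient.
IChatClient client = executor.AsChatClient();

// Standard options (MaxOutputTokens, Temperature, ...) map directly; LLamaSharp-specific
// settings ride along in AdditionalProperties, keyed by the property names CreateInferenceParams reads.
var options = new ChatOptions
{
    MaxOutputTokens = 128,
    Temperature = 0.7f,
    AdditionalProperties = new() { [nameof(DefaultSamplingPipeline.RepeatPenalty)] = 1.1f },
};

await foreach (var update in client.CompleteStreamingAsync(
    [new ChatMessage(ChatRole.User, "Write a haiku about spans.")], options))
{
    Console.Write(update.Text);
}
```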
diff --git a/LLama/Extensions/SpanNormalizationExtensions.cs b/LLama/Extensions/SpanNormalizationExtensions.cs
index 8ed827b64..42eaaf163 100644
--- a/LLama/Extensions/SpanNormalizationExtensions.cs
+++ b/LLama/Extensions/SpanNormalizationExtensions.cs
@@ -81,6 +81,18 @@ public static Span<float> EuclideanNormalization(this Span<float> vector)
         return vector;
     }
 
+    /// <summary>
+    /// Creates a new array containing an L2 normalization of the input vector.
+    /// </summary>
+    /// <param name="vector"></param>
+    /// <returns>A new array containing the normalized vector</returns>
+    public static float[] EuclideanNormalization(this ReadOnlySpan<float> vector)
+    {
+        var result = new float[vector.Length];
+        TensorPrimitives.Divide(vector, TensorPrimitives.Norm(vector), result);
+        return result;
+    }
+
     /// <summary>
     /// In-place apply p-normalization. https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
     /// </summary>
diff --git a/LLama/LLamaEmbedder.EmbeddingGenerator.cs b/LLama/LLamaEmbedder.EmbeddingGenerator.cs
new file mode 100644
index 000000000..c404d7b3e
--- /dev/null
+++ b/LLama/LLamaEmbedder.EmbeddingGenerator.cs
@@ -0,0 +1,54 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Threading;
+using System.Threading.Tasks;
+using LLama.Native;
+using Microsoft.Extensions.AI;
+
+namespace LLama;
+
+public partial class LLamaEmbedder
+    : IEmbeddingGenerator<string, Embedding<float>>
+{
+    private EmbeddingGeneratorMetadata? _metadata;
+
+    /// <inheritdoc/>
+    EmbeddingGeneratorMetadata IEmbeddingGenerator<string, Embedding<float>>.Metadata =>
+        _metadata ??= new(
+            nameof(LLamaEmbedder),
+            modelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
+            dimensions: EmbeddingSize);
+
+    /// <inheritdoc/>
+    TService? IEmbeddingGenerator<string, Embedding<float>>.GetService<TService>(object? key) where TService : class =>
+        typeof(TService) == typeof(LLamaContext) ? (TService)(object)Context :
+        this as TService;
+
+    /// <inheritdoc/>
+    async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
+    {
+        if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
+        {
+            throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
+        }
+
+        GeneratedEmbeddings<Embedding<float>> results = new()
+        {
+            Usage = new() { InputTokenCount = 0 },
+        };
+
+        foreach (var value in values)
+        {
+            var (embeddings, tokenCount) = await GetEmbeddingsWithTokenCount(value, cancellationToken).ConfigureAwait(false);
+            Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding when pooling is enabled.");
+
+            results.Usage.InputTokenCount += tokenCount;
+            results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
+        }
+
+        results.Usage.TotalTokenCount = results.Usage.InputTokenCount;
+
+        return results;
+    }
+}
\ No newline at end of file
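Likewise hedged and not part of the patch: a sketch of using `LLamaEmbedder` through `IEmbeddingGenerator<string, Embedding<float>>`. The model path and the `Embeddings = true` configuration are assumptions; `GenerateAsync`, the `Usage` token counts, and the new `ReadOnlySpan<float>` `EuclideanNormalization` overload are the pieces introduced above.

```csharp
using System;
using System.Linq;
using LLama;
using LLama.Common;
using LLama.Extensions;
using Microsoft.Extensions.AI;

// Assumed setup: a local GGUF embedding model; the context must use a pooling type other than LLamaPoolingType.None.
var parameters = new ModelParams("models/embedding.gguf") { Embeddings = true }; // hypothetical path/configuration
using var weights = LLamaWeights.LoadFromFile(parameters);
using var embedder = new LLamaEmbedder(weights, parameters);

// LLamaEmbedder is now itself an IEmbeddingGenerator<string, Embedding<float>>.
IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

var embeddings = await generator.GenerateAsync(["The cat is cute", "The kitten is cute"]);
Console.WriteLine($"Tokens processed: {embeddings.Usage?.InputTokenCount}");

// Compare the two sentences via cosine similarity of the L2-normalized vectors,
// using the new ReadOnlySpan<float> EuclideanNormalization overload.
var a = embeddings[0].Vector.Span.EuclideanNormalization();
var b = embeddings[1].Vector.Span.EuclideanNormalization();
var similarity = a.Zip(b, (x, y) => x * y).Sum();
Console.WriteLine($"Cosine similarity: {similarity}");
```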
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index ed6240359..9584eee4f 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -12,7 +12,7 @@ namespace LLama;
 /// <summary>
 /// Generate high dimensional embedding vectors from text
 /// </summary>
-public sealed class LLamaEmbedder
+public sealed partial class LLamaEmbedder
     : IDisposable
 {
     /// <summary>
@@ -58,7 +58,10 @@ public void Dispose()
     /// <param name="input"></param>
     /// <param name="cancellationToken"></param>
     /// <returns></returns>
-    public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default)
+    public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, CancellationToken cancellationToken = default) =>
+        (await GetEmbeddingsWithTokenCount(input, cancellationToken).ConfigureAwait(false)).Embeddings;
+
+    private async Task<(IReadOnlyList<float[]> Embeddings, int Tokens)> GetEmbeddingsWithTokenCount(string input, CancellationToken cancellationToken = default)
     {
         // Add all of the tokens to the batch
         var tokens = Context.Tokenize(input);
@@ -113,6 +116,6 @@ public async Task<IReadOnlyList<float[]>> GetEmbeddings(string input, Cancellati
 
         Context.NativeHandle.KvCacheClear();
 
-        return results;
+        return (results, tokens.Length);
     }
 }
\ No newline at end of file
diff --git a/LLama/LLamaSharp.csproj b/LLama/LLamaSharp.csproj
index a38174b13..b7dbf5bcd 100644
--- a/LLama/LLamaSharp.csproj
+++ b/LLama/LLamaSharp.csproj
@@ -49,6 +49,8 @@
+    <PackageReference Include="Microsoft.Extensions.AI.Abstractions" />
+
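Assuming the csproj hunk adds the Microsoft.Extensions.AI.Abstractions package reference required by the new files (its version is not shown above), application code can depend solely on the M.E.AI abstractions and accept LLamaSharp's implementations without referencing any concrete LLamaSharp type. A hedged sketch; `AiHelpers`, `AnswerAsync`, and `EmbedAsync` are hypothetical names, not part of this PR.

```csharp
using System.Threading.Tasks;
using Microsoft.Extensions.AI;

// Hypothetical consumer that only sees IChatClient / IEmbeddingGenerator; the LLamaSharp
// implementations from this PR (executor.AsChatClient(), LLamaEmbedder) can be passed in directly.
public static class AiHelpers
{
    public static async Task<string> AnswerAsync(IChatClient chat, string question) =>
        (await chat.CompleteAsync([new ChatMessage(ChatRole.User, question)])).Message.Text ?? string.Empty;

    public static async Task<float[]> EmbedAsync(IEmbeddingGenerator<string, Embedding<float>> generator, string text) =>
        (await generator.GenerateAsync([text]))[0].Vector.ToArray();
}
```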