diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index ca2e45b89..93c3e74ab 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -564,6 +564,24 @@ public Task DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
         {
             return Task.Run(() => Decode(batch), cancellationToken);
         }
+
+        /// <summary>
+        /// Decode a list of tokens in a background thread, returning the updated count of past tokens.
+        /// </summary>
+        /// <param name="tokens">The tokens to decode.</param>
+        /// <param name="id">The sequence to decode into.</param>
+        /// <param name="batch">The batch to decode with.</param>
+        /// <param name="n_past">The number of tokens decoded for this sequence so far.</param>
+        /// <returns>A tuple, containing the decode result, the number of tokens that have not been decoded yet and the total number of tokens that have been decoded.</returns>
+        public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
+        {
+            return Task.Run(() =>
+            {
+                var past = n_past;
+                var res = NativeHandle.Decode(tokens, id, batch, ref past);
+                return (res.Item1, res.Item2, past);
+            });
+        }
         #endregion
 
         /// <inheritdoc />
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index d6e24530f..d8a5c530d 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -177,7 +177,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
         }
 
         /// <inheritdoc />
-        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
         {
             var batch = new LLamaBatch();
 
@@ -194,7 +194,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
 
                 TryReuseMatchingPrefix();
 
-                var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+                _pastTokensCount = pastTokensCount;
+
                 if (result != DecodeResult.Ok)
                     throw new LLamaDecodeError(result);
 
@@ -259,7 +261,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
             }
 
-            return Task.CompletedTask;
+            return;
         }
 
         /// <summary>
         /// The descriptor of the state of the instruct executor.
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index c4893e5b9..f4a4ca965 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -222,7 +222,7 @@ private Task PreprocessLlava(string text, InferStateArgs args, bool addBos = tru
         }
 
         /// <inheritdoc />
-        protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+        protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
        {
            var batch = new LLamaBatch();
 
@@ -250,11 +250,13 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
 
                // Changes to support Multi-Modal LLMs.
                //
-                (DecodeResult, int) header, end, result;
+                (DecodeResult, int, int) header, end, result;
                if (IsMultiModal && _EmbedImagePosition > 0)
                {
                    // Tokens previous to the images
-                    header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                    header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+                    _pastTokensCount = header.Item3;
+
                    if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);
 
                    // Images
@@ -262,7 +264,8 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                        ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);
 
                    // Post-image Tokens
-                    end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                    end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+                    _pastTokensCount = end.Item3;
 
                    _EmbedImagePosition = -1;
                    _imageEmbedHandles.Clear();
@@ -270,7 +273,9 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                }
                else
                {
-                    result = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+                    result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+                    _pastTokensCount = result.Item3;
+
                    if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
                }
 
@@ -346,7 +351,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                }
            }
 
-            return Task.CompletedTask;
+            return;
        }
 
        /// <summary>
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 433d9cd16..ca868b77d 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -96,7 +96,9 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
 
            // Evaluate the prompt, in chunks smaller than the max batch size
            var n_past = 0;
-            var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
+            var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
+            n_past = past;
+
            if (r != DecodeResult.Ok)
                throw new LLamaDecodeError(r);
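
Note for reviewers: the new DecodeAsync overload returns the updated past-token count as the third tuple element, rather than taking ref int n_past like the synchronous NativeHandle.Decode, because a ref parameter cannot be captured by the Task.Run lambda. Each call site therefore awaits the tuple and writes the count back itself. A minimal sketch of the resulting calling pattern, assuming the usual LLamaSharp model-loading API; the model path and prompt below are placeholders, not part of this diff:

using System.Linq;
using LLama;
using LLama.Common;
using LLama.Exceptions;
using LLama.Native;

// Placeholder setup (not part of this diff): load a model and create a context.
var parameters = new ModelParams("model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);
using var context = weights.CreateContext(parameters);

var tokens = context.Tokenize("Hello, world!").ToList();
var batch = new LLamaBatch();
var nPast = 0;

// Decode on a background thread. The third tuple element is the updated
// past-token count, replacing the old "ref n_past" pattern.
var (result, _, past) = await context.DecodeAsync(tokens, LLamaSeqId.Zero, batch, nPast);
nPast = past;

if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);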