Async implementation of LLamaExecutors #834

Merged
29 changes: 29 additions & 0 deletions LLama/LLamaContext.cs
@@ -564,6 +564,35 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationTo
{
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
/// Decode a list of tokens for the given sequence, updating <paramref name="n_past"/> as they are processed.
/// This forwards directly to the native handle.
/// </summary>
/// <param name="tokens">The tokens to decode.</param>
/// <param name="id">The sequence the tokens belong to.</param>
/// <param name="batch">The batch used to submit tokens to the model.</param>
/// <param name="n_past">The number of tokens already evaluated; updated by this call.</param>
/// <returns>The decode result and the number of tokens that have not yet been decoded.</returns>
public (DecodeResult, int) Decode(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, ref int n_past)
{
return NativeHandle.Decode(tokens, id, batch, ref n_past);
}

/// <summary>
/// Asynchronously decode a list of tokens for the given sequence on a background thread.
/// Because async methods cannot use <c>ref</c> parameters, the updated past-token count is
/// returned as the final element of the tuple rather than written back.
/// </summary>
/// <param name="tokens">The tokens to decode.</param>
/// <param name="id">The sequence the tokens belong to.</param>
/// <param name="batch">The batch used to submit tokens to the model.</param>
/// <param name="n_past">The number of tokens already evaluated.</param>
/// <returns>The decode result, the number of tokens that have not yet been decoded, and the updated past-token count.</returns>
public Task<(DecodeResult, int, int)> DecodeAsync(List<LLamaToken> tokens, LLamaSeqId id, LLamaBatch batch, int n_past)
{
return Task.Run(() =>
{
var past = n_past;
var res = NativeHandle.Decode(tokens, id, batch, ref past);
return (res.Item1, res.Item2, past);
});
}
#endregion

/// <inheritdoc />
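For orientation, here is a minimal sketch of how a caller consumes the new overload; `context`, `tokens`, and the surrounding scope are illustrative stand-ins, not part of this diff. Because async methods cannot take `ref` parameters, `DecodeAsync` hands back the updated past-token count as the third tuple element instead of mutating it in place:

```csharp
// Hypothetical caller - mirrors the pattern the executors below adopt.
var batch = new LLamaBatch();
var nPast = 0;

var (result, _, newPast) = await context.DecodeAsync(tokens, LLamaSeqId.Zero, batch, nPast);
nPast = newPast; // carry the updated count into the next decode call

if (result != DecodeResult.Ok)
    throw new LLamaDecodeError(result);
```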
8 changes: 5 additions & 3 deletions LLama/LLamaInstructExecutor.cs
@@ -106,7 +106,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);

Check warning on line 109 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: linux-release, windows-release, osx-release): Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -147,11 +147,11 @@
}

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 150 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: linux-release, windows-release, osx-release): This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 154 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: linux-release, windows-release, osx-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
@@ -177,7 +177,7 @@
}

/// <inheritdoc />
-protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
var batch = new LLamaBatch();

@@ -194,7 +194,9 @@

TryReuseMatchingPrefix();

-var (result, _) = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+var (result, _, pastTokensCount) = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+_pastTokensCount = pastTokensCount;

if (result != DecodeResult.Ok)
throw new LLamaDecodeError(result);

Expand All @@ -215,7 +217,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

Check warning on line 220 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: osx-release): Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

LLamaToken id;
@@ -259,7 +261,7 @@
}
}

-return Task.CompletedTask;
+return;
}
/// <summary>
/// The descriptor of the state of the instruct executor.
@@ -275,7 +277,7 @@
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public LLamaToken[] InputPrefixTokens { get; set; }

Check warning on line 280 in LLama/LLamaInstructExecutor.cs (GitHub Actions / Test: osx-release): Non-nullable property 'InputPrefixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
/// <summary>
/// Instruction suffix tokens.
/// </summary>
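The swap from `return Task.CompletedTask;` to a bare `return;` falls out of the signature change: once `InferInternal` is declared `async`, the compiler materialises the `Task`, so returning one by hand would no longer compile. A minimal illustration of the two shapes (the class and method names here are hypothetical):

```csharp
using System.Threading.Tasks;

class Sketch
{
    // Before: a synchronous body satisfies a Task-returning signature
    // by handing back an already-completed task.
    Task RunSync()
    {
        // ... synchronous work ...
        return Task.CompletedTask;
    }

    // After: with `async`, the compiler wraps completion in a Task,
    // so the method ends with a bare return (or simply falls off the end).
    async Task RunAsync()
    {
        await Task.Yield(); // stand-in for the awaited DecodeAsync calls
        return;
    }
}
```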
17 changes: 11 additions & 6 deletions LLama/LLamaInteractExecutor.cs
@@ -98,7 +98,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

Check warning on line 101 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release, osx-release): Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -159,7 +159,7 @@
{
foreach (var image in Images)
{
_imageEmbedHandles.Add(SafeLlavaImageEmbedHandle.CreateFromMemory(ClipModel.NativeHandle, Context, image));

Check warning on line 162 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release): Dereference of a possibly null reference.
}

int imageIndex = text.IndexOf("<image>");
@@ -196,11 +196,11 @@
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 199 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release): This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 203 in LLama/LLamaInteractExecutor.cs (GitHub Actions / Test: linux-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
@@ -222,7 +222,7 @@
}

/// <inheritdoc />
-protected override Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
+protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
var batch = new LLamaBatch();

@@ -250,27 +250,32 @@

// Changes to support Multi-Modal LLMs.
//
-(DecodeResult, int) header, end, result;
+(DecodeResult, int, int) header, end, result;
if (IsMultiModal && _EmbedImagePosition > 0)
{
// Tokens previous to the images
-header = Context.NativeHandle.Decode(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+header = await Context.DecodeAsync(_embeds.GetRange(0, _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+_pastTokensCount = header.Item3;

if (header.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(header.Item1);

// Images
foreach( var image in _imageEmbedHandles )
ClipModel.EvalImageEmbed(Context, image, ref _pastTokensCount);

// Post-image Tokens
-end = Context.NativeHandle.Decode(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, ref _pastTokensCount);
+end = await Context.DecodeAsync(_embeds.GetRange(_EmbedImagePosition, _embeds.Count - _EmbedImagePosition), LLamaSeqId.Zero, batch, _pastTokensCount);
+_pastTokensCount = end.Item3;

_EmbedImagePosition = -1;
_imageEmbedHandles.Clear();
Images.Clear();
}
else
{
-result = Context.NativeHandle.Decode(_embeds, LLamaSeqId.Zero, batch, ref _pastTokensCount);
+result = await Context.DecodeAsync(_embeds, LLamaSeqId.Zero, batch, _pastTokensCount);
+_pastTokensCount = result.Item3;

if (result.Item1 != DecodeResult.Ok) throw new LLamaDecodeError(result.Item1);
}

@@ -346,7 +351,7 @@
}
}

-return Task.CompletedTask;
+return;
}

/// <summary>
4 changes: 3 additions & 1 deletion LLama/LLamaStatelessExecutor.cs
@@ -96,7 +96,9 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams

// Evaluate the prompt, in chunks smaller than the max batch size
var n_past = 0;
-var (r, _) = Context.NativeHandle.Decode(tokens, LLamaSeqId.Zero, _batch, ref n_past);
+var (r, _, past) = await Context.DecodeAsync(tokens, LLamaSeqId.Zero, _batch, n_past);
+n_past = past;

if (r != DecodeResult.Ok)
throw new LLamaDecodeError(r);

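Since the point of the PR is that inference no longer blocks the calling thread, a short consumption sketch may help; the `executor`, `prompt`, `inferenceParams`, and `cancellationToken` values are assumed, and only the `InferAsync` signature visible in the hunk header above is taken as given:

```csharp
// Hypothetical caller - streams tokens from the now fully-async stateless executor.
await foreach (var token in executor.InferAsync(prompt, inferenceParams, cancellationToken))
{
    Console.Write(token);
}
```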