diff --git a/plugins/wasm-go/extensions/ai-quota/main.go b/plugins/wasm-go/extensions/ai-quota/main.go index 8d6e57dc45..2facd912bc 100644 --- a/plugins/wasm-go/extensions/ai-quota/main.go +++ b/plugins/wasm-go/extensions/ai-quota/main.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "encoding/json" "errors" "fmt" @@ -215,35 +216,51 @@ func onHttpStreamingResponseBody(ctx wrapper.HttpContext, config QuotaConfig, da if chatMode == ChatModeNone || chatMode == ChatModeAdmin { return data } + var inputToken, outputToken int64 + var consumer string + if inputToken, outputToken, ok := getUsage(data); ok { + ctx.SetContext("input_token", inputToken) + ctx.SetContext("output_token", outputToken) + } + // chat completion mode if !endOfStream { return data } - inputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.input_token"}) - if err != nil { - return data - } - outputTokenStr, err := proxywasm.GetProperty([]string{"filter_state", "wasm.output_token"}) - if err != nil { - return data - } - inputToken, err := strconv.Atoi(string(inputTokenStr)) - if err != nil { - return data - } - outputToken, err := strconv.Atoi(string(outputTokenStr)) - if err != nil { + + if ctx.GetContext("input_token") == nil || ctx.GetContext("output_token") == nil || ctx.GetContext("consumer") == nil { return data } - consumer, ok := ctx.GetContext("consumer").(string) - if ok { - totalToken := int(inputToken + outputToken) - log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken) - config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil) - } + + inputToken = ctx.GetContext("input_token").(int64) + outputToken = ctx.GetContext("output_token").(int64) + consumer = ctx.GetContext("consumer").(string) + totalToken := int(inputToken + outputToken) + log.Debugf("update consumer:%s, totalToken:%d", consumer, totalToken) + config.redisClient.DecrBy(config.RedisKeyPrefix+consumer, totalToken, nil) return data } +func getUsage(data []byte) (inputTokenUsage int64, outputTokenUsage int64, ok bool) { + chunks := bytes.Split(bytes.TrimSpace(data), []byte("\n\n")) + for _, chunk := range chunks { + // the feature strings are used to identify the usage data, like: + // {"model":"gpt2","usage":{"prompt_tokens":1,"completion_tokens":1}} + if !bytes.Contains(chunk, []byte("prompt_tokens")) || !bytes.Contains(chunk, []byte("completion_tokens")) { + continue + } + inputTokenObj := gjson.GetBytes(chunk, "usage.prompt_tokens") + outputTokenObj := gjson.GetBytes(chunk, "usage.completion_tokens") + if inputTokenObj.Exists() && outputTokenObj.Exists() { + inputTokenUsage = inputTokenObj.Int() + outputTokenUsage = outputTokenObj.Int() + ok = true + return + } + } + return +} + func deniedNoKeyAuthData() types.Action { util.SendResponse(http.StatusUnauthorized, "ai-quota.no_key", "text/plain", "Request denied by ai quota check. No Key Authentication information found.") return types.ActionContinue