Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backend based airls from cp to dp #1193

Merged
merged 4 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apim-apk-agent/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ require (
github.com/pelletier/go-toml v1.9.5
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.9.0
github.com/wso2/apk/common-go-libs v0.0.0-20240920041902-85449a1c0150
github.com/wso2/apk/common-go-libs v0.0.0-20240923143402-ff7fdb0366f9
google.golang.org/grpc v1.62.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v2 v2.4.0
Expand Down
4 changes: 2 additions & 2 deletions apim-apk-agent/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ github.com/vektah/gqlparser v1.3.1 h1:8b0IcD3qZKWJQHSzynbDlrtP3IxVydZ2DZepCGofqf
github.com/vektah/gqlparser v1.3.1/go.mod h1:bkVf0FX+Stjg/MHnm8mEyubuaArhNEqfQhF+OTiAL74=
github.com/wso2/apk/adapter v0.0.0-20240408123538-86a74d977eee h1:g0ivVkzybfcEkB0vBGTAXTUuMZpsF3zOTVtAgmW851s=
github.com/wso2/apk/adapter v0.0.0-20240408123538-86a74d977eee/go.mod h1:xYS5auF/YxnyRykw7NBSn/YR2FHD4hTeyav4Nhec8d0=
github.com/wso2/apk/common-go-libs v0.0.0-20240920041902-85449a1c0150 h1:X3OezAh2UOxmQIRxsAua87nNqmoIGXx1yfQIvc4a+G4=
github.com/wso2/apk/common-go-libs v0.0.0-20240920041902-85449a1c0150/go.mod h1:SbZVA1jeiVG9dqk9fGcY/bB0JgEaQgtXqFAlxAfN0Lk=
github.com/wso2/apk/common-go-libs v0.0.0-20240923143402-ff7fdb0366f9 h1:MwQqG+/ODDIfLfc3xNMYk6jM+hB2ttjwZnaDBeiMOJI=
github.com/wso2/apk/common-go-libs v0.0.0-20240923143402-ff7fdb0366f9/go.mod h1:SbZVA1jeiVG9dqk9fGcY/bB0JgEaQgtXqFAlxAfN0Lk=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
Expand Down
4 changes: 2 additions & 2 deletions apim-apk-agent/internal/eventhub/dataloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"strconv"
"time"

dpv1alpha2 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha2"
dpv1alpha3 "github.com/wso2/apk/common-go-libs/apis/dp/v1alpha3"
"github.com/wso2/product-apim-tooling/apim-apk-agent/config"
internalk8sClient "github.com/wso2/product-apim-tooling/apim-apk-agent/internal/k8sClient"
logger "github.com/wso2/product-apim-tooling/apim-apk-agent/internal/loggers"
Expand Down Expand Up @@ -240,7 +240,7 @@ func FetchAPIsOnStartUp(conf *config.Config, k8sClient client.Client) {
if err != nil {
logger.LoggerEventhub.Errorf("Error occurred while fetching APIs from control plane %v", err)
}
removeApis := make([]dpv1alpha2.API, 0)
removeApis := make([]dpv1alpha3.API, 0)
for _, k8sAPI := range k8sAPIS {
found := false
if apis != nil {
Expand Down
74 changes: 67 additions & 7 deletions apim-apk-agent/internal/k8sClient/k8s_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ import (
)

// DeployAPICR applies the given API struct to the Kubernetes cluster.
func DeployAPICR(api *dpv1alpha2.API, k8sClient client.Client) {
crAPI := &dpv1alpha2.API{}
func DeployAPICR(api *dpv1alpha3.API, k8sClient client.Client) {
crAPI := &dpv1alpha3.API{}
if err := k8sClient.Get(context.Background(), client.ObjectKey{Namespace: api.ObjectMeta.Namespace, Name: api.Name}, crAPI); err != nil {
if !k8error.IsNotFound(err) {
loggers.LoggerK8sClient.Error("Unable to get API CR: " + err.Error())
Expand All @@ -66,7 +66,7 @@ func DeployAPICR(api *dpv1alpha2.API, k8sClient client.Client) {
}

// UndeployK8sAPICR removes the API Custom Resource from the Kubernetes cluster based on API ID label.
func UndeployK8sAPICR(k8sClient client.Client, k8sAPI dpv1alpha2.API) error {
func UndeployK8sAPICR(k8sClient client.Client, k8sAPI dpv1alpha3.API) error {
err := k8sClient.Delete(context.Background(), &k8sAPI, &client.DeleteOptions{})
if err != nil {
loggers.LoggerK8sClient.Errorf("Unable to delete API CR: %v", err)
Expand All @@ -82,7 +82,7 @@ func UndeployAPICR(apiID string, k8sClient client.Client) {
if errReadConfig != nil {
loggers.LoggerK8sClient.Errorf("Error reading configurations: %v", errReadConfig)
}
apiList := &dpv1alpha2.APIList{}
apiList := &dpv1alpha3.APIList{}
err := k8sClient.List(context.Background(), apiList, &client.ListOptions{Namespace: conf.DataPlane.Namespace, LabelSelector: labels.SelectorFromSet(map[string]string{"apiUUID": apiID})})
// Retrieve all API CRs from the Kubernetes cluster
if err != nil {
Expand Down Expand Up @@ -429,6 +429,66 @@ func DeploySubscriptionRateLimitPolicyCR(policy eventhubTypes.SubscriptionPolicy

}

// DeployAIRateLimitPolicyCR applies the given AIRateLimitPolicies struct to the Kubernetes cluster.
//
// The AI quota of the subscription policy (policy.DefaultLimit.AiAPIQuota) is
// translated into an AIRateLimitPolicy CR that is named after the policy and
// placed in the configured data plane namespace. If a CR with that name already
// exists its spec and labels are overwritten; otherwise a new CR is created.
//
// Parameters:
//   - policy: subscription policy carrying the AI quota. When its AiAPIQuota is
//     nil there is nothing to deploy; the problem is logged and the call returns.
//   - k8sClient: client used to get, create and update the CR.
//
// Failures are logged and not returned to the caller.
func DeployAIRateLimitPolicyCR(policy eventhubTypes.SubscriptionPolicy, k8sClient client.Client) {
	quota := policy.DefaultLimit.AiAPIQuota
	if quota == nil {
		// This function is exported, so it must not depend on every caller
		// having checked the pointer before calling.
		loggers.LoggerK8sClient.Error("Unable to deploy AIRateLimitPolicies CR: no AI API quota found for policy " + policy.Name)
		return
	}

	conf, errReadConfig := config.ReadConfigs()
	if errReadConfig != nil {
		loggers.LoggerK8sClient.Errorf("Error reading configurations: %v", errReadConfig)
	}

	// A token limit is only emitted when all three token counts are present;
	// otherwise the field stays nil.
	// NOTE(review): the int -> uint32 conversions below wrap for negative or
	// very large values — confirm the control plane never sends such counts.
	var tokenCount *dpv1alpha3.TokenCount
	if quota.PromptTokenCount != nil && quota.CompletionTokenCount != nil && quota.TotalTokenCount != nil {
		tokenCount = &dpv1alpha3.TokenCount{
			Unit:               quota.TimeUnit,
			RequestTokenCount:  uint32(*quota.PromptTokenCount),
			ResponseTokenCount: uint32(*quota.CompletionTokenCount),
			TotalTokenCount:    uint32(*quota.TotalTokenCount),
		}
	}

	// The request limit is independent of the token limit.
	var requestCount *dpv1alpha3.RequestCount
	if quota.RequestCount != nil {
		requestCount = &dpv1alpha3.RequestCount{
			RequestsPerUnit: uint32(*quota.RequestCount),
			Unit:            quota.TimeUnit,
		}
	}

	crRateLimitPolicies := dpv1alpha3.AIRateLimitPolicy{
		ObjectMeta: metav1.ObjectMeta{
			Name:      policy.Name,
			Namespace: conf.DataPlane.Namespace,
		},
		Spec: dpv1alpha3.AIRateLimitPolicySpec{
			Override: &dpv1alpha3.AIRateLimit{
				Organization: policy.TenantDomain,
				TokenCount:   tokenCount,
				RequestCount: requestCount,
			},
			TargetRef: gwapiv1b1.PolicyTargetReference{Group: constants.GatewayGroup, Kind: "Subscription", Name: "default"},
		},
	}

	ctx := context.Background()
	crKey := client.ObjectKey{Namespace: crRateLimitPolicies.ObjectMeta.Namespace, Name: crRateLimitPolicies.Name}
	crRateLimitPolicyFetched := &dpv1alpha3.AIRateLimitPolicy{}
	if err := k8sClient.Get(ctx, crKey, crRateLimitPolicyFetched); err != nil {
		if !k8error.IsNotFound(err) {
			loggers.LoggerK8sClient.Error("Unable to get AiratelimitPolicy CR: " + err.Error())
		}
		// A create is attempted after any lookup failure, not only NotFound;
		// the other Deploy*CR helpers in this file follow the same flow.
		if err := k8sClient.Create(ctx, &crRateLimitPolicies); err != nil {
			loggers.LoggerK8sClient.Error("Unable to create AIRateLimitPolicies CR: " + err.Error())
		} else {
			loggers.LoggerK8sClient.Info("AIRateLimitPolicies CR created: " + crRateLimitPolicies.Name)
		}
		return
	}

	// The CR already exists: replace its spec and labels with the freshly built ones.
	crRateLimitPolicyFetched.Spec = crRateLimitPolicies.Spec
	crRateLimitPolicyFetched.ObjectMeta.Labels = crRateLimitPolicies.ObjectMeta.Labels
	if err := k8sClient.Update(ctx, crRateLimitPolicyFetched); err != nil {
		loggers.LoggerK8sClient.Error("Unable to update AiRatelimitPolicy CR: " + err.Error())
	} else {
		loggers.LoggerK8sClient.Info("AiRatelimitPolicy CR updated: " + crRateLimitPolicyFetched.Name)
	}
}

// DeployBackendCR applies the given Backends struct to the Kubernetes cluster.
func DeployBackendCR(backends *dpv1alpha2.Backend, k8sClient client.Client) {
crBackends := &dpv1alpha2.Backend{}
Expand Down Expand Up @@ -625,10 +685,10 @@ func getSha1Value(input string) string {
}

// RetrieveAllAPISFromK8s retrieves all the API CRs from the Kubernetes cluster
func RetrieveAllAPISFromK8s(k8sClient client.Client, nextToken string) ([]dpv1alpha2.API, string, error) {
func RetrieveAllAPISFromK8s(k8sClient client.Client, nextToken string) ([]dpv1alpha3.API, string, error) {
conf, _ := config.ReadConfigs()
apiList := dpv1alpha2.APIList{}
resolvedAPIList := make([]dpv1alpha2.API, 0)
apiList := dpv1alpha3.APIList{}
resolvedAPIList := make([]dpv1alpha3.API, 0)
var err error
if nextToken == "" {
err = k8sClient.List(context.Background(), &apiList, &client.ListOptions{Namespace: conf.DataPlane.Namespace})
Expand Down
87 changes: 59 additions & 28 deletions apim-apk-agent/internal/synchronizer/ratelimit_policy_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ func FetchRateLimitPoliciesOnEvent(ratelimitName string, organization string, c

// FetchSubscriptionRateLimitPoliciesOnEvent fetches the policies from the control plane on the start up and notification event updates
func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organization string, c client.Client) {
logger.LoggerSynchronizer.Info("Fetching RateLimit Policies from Control Plane.")
logger.LoggerSynchronizer.Info("Fetching Subscription RateLimit Policies from Control Plane.")

// Read configurations and derive the eventHub details
conf, errReadConfig := config.ReadConfigs()
Expand All @@ -189,7 +189,7 @@ func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organizatio
}
}

logger.LoggerSynchronizer.Infof("Fetching RateLimit Policies from the URL %v: ", ehURL)
logger.LoggerSynchronizer.Infof("Fetching Subscription RateLimit Policies from the URL %v: ", ehURL)

ehUname := ehConfigs.Username
ehPass := ehConfigs.Password
Expand All @@ -201,19 +201,9 @@ func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organizatio
// Create a HTTP request
req, err := http.NewRequest("GET", ehURL, nil)
if err != nil {
logger.LoggerSynchronizer.Errorf("Error while creating http request for RateLimit Policies Endpoint : %v", err)
logger.LoggerSynchronizer.Errorf("Error while creating http request for Subscription RateLimit Policies Endpoint : %v", err)
}

var queryParamMap map[string]string

if queryParamMap != nil && len(queryParamMap) > 0 {
q := req.URL.Query()
// Making necessary query parameters for the request
for queryParamKey, queryParamValue := range queryParamMap {
q.Add(queryParamKey, queryParamValue)
}
req.URL.RawQuery = q.Encode()
}
// Setting authorization header
req.Header.Set(sync.Authorization, basicAuth)

Expand All @@ -231,45 +221,74 @@ func FetchSubscriptionRateLimitPoliciesOnEvent(ratelimitName string, organizatio
var errorMsg string
if err != nil {
errorMsg = "Error occurred while calling the REST API: " + policiesEndpoint
go retryRLPFetchData(conf, errorMsg, err, c)
go retrySubscriptionRLPFetchData(conf, errorMsg, err, c)
return
}
responseBytes, err := ioutil.ReadAll(resp.Body)
logger.LoggerSynchronizer.Debugf("Response String received for Policies: %v", string(responseBytes))

if err != nil {
errorMsg = "Error occurred while reading the response received for: " + policiesEndpoint
go retryRLPFetchData(conf, errorMsg, err, c)
go retrySubscriptionRLPFetchData(conf, errorMsg, err, c)
return
}

if resp.StatusCode == http.StatusOK {
var rateLimitPolicyList eventhubTypes.SubscriptionPolicyList
err := json.Unmarshal(responseBytes, &rateLimitPolicyList)
if err != nil {
logger.LoggerSynchronizer.Errorf("Error occurred while unmarshelling RateLimit Policies event data %v", err)
logger.LoggerSynchronizer.Errorf("Error occurred while unmarshelling Subscription RateLimit Policies event data %v", err)
return
}
logger.LoggerSynchronizer.Debugf("Policies received: %v", rateLimitPolicyList.List)
var rateLimitPolicies []eventhubTypes.SubscriptionPolicy = rateLimitPolicyList.List
for _, policy := range rateLimitPolicies {
if policy.DefaultLimit.RequestCount.TimeUnit == "min" {
policy.DefaultLimit.RequestCount.TimeUnit = "Minute"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "hours" {
policy.DefaultLimit.RequestCount.TimeUnit = "Hour"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "days" {
policy.DefaultLimit.RequestCount.TimeUnit = "Day"
if policy.QuotaType == "aiApiQuota" {
if policy.DefaultLimit.AiAPIQuota != nil {
switch policy.DefaultLimit.AiAPIQuota.TimeUnit {
case "min":
policy.DefaultLimit.AiAPIQuota.TimeUnit = "Minute"
case "hours":
policy.DefaultLimit.AiAPIQuota.TimeUnit = "Hour"
case "days":
policy.DefaultLimit.AiAPIQuota.TimeUnit = "Day"
default:
logger.LoggerSynchronizer.Errorf("Unsupported timeunit %s", policy.DefaultLimit.AiAPIQuota.TimeUnit)
continue
}
if policy.DefaultLimit.AiAPIQuota.PromptTokenCount == nil && policy.DefaultLimit.AiAPIQuota.TotalTokenCount != nil {
policy.DefaultLimit.AiAPIQuota.PromptTokenCount = policy.DefaultLimit.AiAPIQuota.TotalTokenCount
}
if policy.DefaultLimit.AiAPIQuota.CompletionTokenCount == nil && policy.DefaultLimit.AiAPIQuota.TotalTokenCount != nil {
policy.DefaultLimit.AiAPIQuota.CompletionTokenCount = policy.DefaultLimit.AiAPIQuota.TotalTokenCount
}
if policy.DefaultLimit.AiAPIQuota.TotalTokenCount == nil && policy.DefaultLimit.AiAPIQuota.PromptTokenCount != nil && policy.DefaultLimit.AiAPIQuota.CompletionTokenCount != nil {
total := *policy.DefaultLimit.AiAPIQuota.PromptTokenCount + *policy.DefaultLimit.AiAPIQuota.CompletionTokenCount
policy.DefaultLimit.AiAPIQuota.TotalTokenCount = &total
}
managementserver.AddSubscriptionPolicy(policy)
k8sclient.DeployAIRateLimitPolicyCR(policy, c)
} else {
logger.LoggerSynchronizer.Errorf("AIQuota type response recieved but no data found. %+v", policy.DefaultLimit)
}
} else {
if policy.DefaultLimit.RequestCount.TimeUnit == "min" {
policy.DefaultLimit.RequestCount.TimeUnit = "Minute"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "hours" {
policy.DefaultLimit.RequestCount.TimeUnit = "Hour"
} else if policy.DefaultLimit.RequestCount.TimeUnit == "days" {
policy.DefaultLimit.RequestCount.TimeUnit = "Day"
}
managementserver.AddSubscriptionPolicy(policy)
logger.LoggerSynchronizer.Infof("RateLimit Policy added to internal map: %v", policy)
// Update the exisitng rate limit policies with current policy
k8sclient.DeploySubscriptionRateLimitPolicyCR(policy, c)
}
managementserver.AddSubscriptionPolicy(policy)
logger.LoggerSynchronizer.Infof("RateLimit Policy added to internal map: %v", policy)
// Update the exisitng rate limit policies with current policy
k8sclient.DeploySubscriptionRateLimitPolicyCR(policy, c)

}
} else {
errorMsg = "Failed to fetch data! " + policiesEndpoint + " responded with " +
strconv.Itoa(resp.StatusCode)
go retryRLPFetchData(conf, errorMsg, err, c)
go retrySubscriptionRLPFetchData(conf, errorMsg, err, c)
}
}

Expand All @@ -284,3 +303,15 @@ func retryRLPFetchData(conf *config.Config, errorMessage string, err error, c cl
return
}
}

// retrySubscriptionRLPFetchData waits for the configured control plane retry
// interval, triggers the subscription rate limit policy fetch again and
// increments the package-level retry counter. Once the counter has reached
// retryCount the failure that caused the retry is logged as an error.
//
// Parameters:
//   - conf: configuration supplying ControlPlane.RetryInterval (in seconds).
//   - errorMessage: description of the failure that triggered the retry.
//   - err: underlying error; may be nil (e.g. when the endpoint answered with
//     a non-OK status code).
//   - c: Kubernetes client handed on to the fetch.
//
// NOTE(review): the retry limit is checked only after the fetch has been
// re-triggered, and retryAttempt is package-level state changed without
// synchronisation, so it does not look like it bounds the retries — confirm
// whether a hard limit is intended.
func retrySubscriptionRLPFetchData(conf *config.Config, errorMessage string, err error, c client.Client) {
	retryInterval := conf.ControlPlane.RetryInterval * time.Second
	logger.LoggerSynchronizer.Debugf("Time Duration for retrying: %v", retryInterval)
	time.Sleep(retryInterval)
	FetchSubscriptionRateLimitPoliciesOnEvent("", "", c)
	retryAttempt++
	if retryAttempt < retryCount {
		return
	}
	// errorMessage is built at runtime and embeds an endpoint URL, so it must
	// not be used as the format string: a '%' in it would be misread as a verb
	// and err would be rendered as "%!(EXTRA ...)".
	if err != nil {
		logger.LoggerSynchronizer.Errorf("%s: %v", errorMessage, err)
		return
	}
	logger.LoggerSynchronizer.Errorf("%s", errorMessage)
}
13 changes: 12 additions & 1 deletion apim-apk-agent/pkg/eventhub/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,8 @@ type ConditionGroup struct {

// DefaultLimit represents the default limit within the response.
type DefaultLimit struct {
QuotaType string `json:"quotaType"`
AiAPIQuota *AiAPIQuota `json:"aiApiQuota"`
QuotaType string `json:"quotaType"`
RequestCount struct {
TimeUnit string `json:"timeUnit"`
UnitTime int `json:"unitTime"`
Expand All @@ -206,6 +207,16 @@ type DefaultLimit struct {
EventCount interface{} `json:"eventCount"`
}

// AiAPIQuota contains the AI ratelimit configurations, decoded from the
// "aiApiQuota" object of a default limit. The count fields are pointers, so a
// count that is missing (or null) in the JSON stays nil instead of becoming 0.
type AiAPIQuota struct {
	// CompletionTokenCount maps to the "completionTokenCount" JSON key.
	CompletionTokenCount *int `json:"completionTokenCount"`

	// PromptTokenCount maps to the "promptTokenCount" JSON key.
	PromptTokenCount *int `json:"promptTokenCount"`

	// RequestCount maps to the "requestCount" JSON key.
	RequestCount *int `json:"requestCount"`

	// TimeUnit maps to the "timeUnit" JSON key.
	TimeUnit string `json:"timeUnit"`

	// TotalTokenCount maps to the "totalTokenCount" JSON key.
	TotalTokenCount *int `json:"totalTokenCount"`

	// UnitTime maps to the "unitTime" JSON key.
	UnitTime int `json:"unitTime"`
}

// Scope for struct Scope
type Scope struct {
Name string `json:"name"`
Expand Down
2 changes: 1 addition & 1 deletion apim-apk-agent/pkg/synchronizer/apis_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func SendRequestToControlPlane(req *http.Request, apiID *string, gwLabels []stri
if apiID != nil {
logger.LoggerSync.Debugf("Sending the control plane request for the API: %q", *apiID)
} else {
logger.LoggerSync.Debug("Sending the control plane request")
logger.LoggerSync.Debugf("Sending the control plane request, url: %s", req.URL.String())
}
resp, err := client.Do(req)

Expand Down
41 changes: 41 additions & 0 deletions apim-apk-agent/pkg/transformer/api_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ type APIMApi struct {
APIThrottlingPolicy string `yaml:"apiThrottlingPolicy"`
APIPolicies APIMOperationPolicies `yaml:"apiPolicies"`
AIConfiguration APIMAIConfiguration `yaml:"aiConfiguration"`
MaxTps *MaxTps `yaml:"maxTps"`
}

// APIMAIConfiguration holds the configuration details for AI providers
Expand All @@ -192,6 +193,46 @@ type APIYaml struct {
Data APIMApi `json:"data"`
}

// MaxTps represents the maximum transactions per second (TPS) settings for both
// production and sandbox environments, together with an optional configuration
// for token-based throttling. Every field is a pointer and is read from the
// YAML key named in its tag.
type MaxTps struct {
	// Production is the maximum TPS for the production environment.
	Production *int `yaml:"production"`

	// ProductionTimeUnit is the time unit for the production TPS limit
	// (e.g., seconds, minutes).
	ProductionTimeUnit *string `yaml:"productionTimeUnit"`

	// Sandbox is the maximum TPS for the sandbox environment.
	Sandbox *int `yaml:"sandbox"`

	// SandboxTimeUnit is the time unit for the sandbox TPS limit.
	SandboxTimeUnit *string `yaml:"sandboxTimeUnit"`

	// TokenBasedThrottlingConfiguration is the configuration for token-based
	// throttling.
	TokenBasedThrottlingConfiguration *TokenBasedThrottlingConfig `yaml:"tokenBasedThrottlingConfiguration"`
}

// TokenBasedThrottlingConfig defines the token-based throttling limits for
// both production and sandbox environments. Token-based throttling places
// a limit on the number of prompt and completion tokens that can be used.
// Every field is a pointer and is read from the YAML key named in its tag.
type TokenBasedThrottlingConfig struct {
	// ProductionMaxPromptTokenCount is the maximum number of prompt tokens
	// for production.
	ProductionMaxPromptTokenCount *int `yaml:"productionMaxPromptTokenCount"`

	// ProductionMaxCompletionTokenCount is the maximum number of completion
	// tokens for production.
	ProductionMaxCompletionTokenCount *int `yaml:"productionMaxCompletionTokenCount"`

	// ProductionMaxTotalTokenCount is the maximum total token count
	// (prompt + completion) for production.
	ProductionMaxTotalTokenCount *int `yaml:"productionMaxTotalTokenCount"`

	// SandboxMaxPromptTokenCount is the maximum number of prompt tokens for
	// sandbox.
	SandboxMaxPromptTokenCount *int `yaml:"sandboxMaxPromptTokenCount"`

	// SandboxMaxCompletionTokenCount is the maximum number of completion
	// tokens for sandbox.
	SandboxMaxCompletionTokenCount *int `yaml:"sandboxMaxCompletionTokenCount"`

	// SandboxMaxTotalTokenCount is the maximum total token count
	// (prompt + completion) for sandbox.
	SandboxMaxTotalTokenCount *int `yaml:"sandboxMaxTotalTokenCount"`

	// IsTokenBasedThrottlingEnabled is the flag to enable or disable
	// token-based throttling.
	IsTokenBasedThrottlingEnabled *bool `yaml:"isTokenBasedThrottlingEnabled"`
}

// APIArtifact represents the artifact details of an API, including api details, environment configuration,
// Swagger definition, deployment descriptor, and revision ID extracted from the API Project Zip.
type APIArtifact struct {
Expand Down
Loading
Loading