Skip to content

Commit

Permalink
Add integration test case for listing model adapters
Browse files Browse the repository at this point in the history
  • Loading branch information
nstogner committed Nov 13, 2024
1 parent 71e1551 commit 2f81826
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 35 deletions.
32 changes: 22 additions & 10 deletions internal/openaiserver/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,9 @@ func (h *Handler) getModels(w http.ResponseWriter, r *http.Request) {
}
}

models := make([]Model, len(k8sModels))
for i, k8sModel := range k8sModels {
model := Model{}
model.FromK8sModel(&k8sModel)
models[i] = model
models := make([]Model, 0)
for _, k8sModel := range k8sModels {
models = append(models, k8sModelToOpenAIModels(k8sModel)...)
}

// Wrapper struct to match the desired output format
Expand Down Expand Up @@ -90,10 +88,24 @@ type Model struct {
Features []kubeaiv1.ModelFeature `json:"features,omitempty"`
}

func (m *Model) FromK8sModel(model *kubeaiv1.Model) {
m.ID = model.Name
m.Created = model.CreationTimestamp.Unix()
func k8sModelToOpenAIModels(k8sM kubeaiv1.Model) []Model {
models := make([]Model, 1+len(k8sM.Spec.Adapters))
models[0] = constructOpenAIModel(k8sM, "")
for i, adapter := range k8sM.Spec.Adapters {
models[i+1] = constructOpenAIModel(k8sM, adapter.ID)
}
return models
}

func constructOpenAIModel(k8sM kubeaiv1.Model, adapter string) Model {
m := Model{}
m.ID = k8sM.Name
if adapter != "" {
m.ID += "/" + adapter
}
m.Created = k8sM.CreationTimestamp.Unix()
m.Object = "model"
m.OwnedBy = model.Spec.Owner
m.Features = model.Spec.Features
m.OwnedBy = k8sM.Spec.Owner
m.Features = k8sM.Spec.Features
return m
}
25 changes: 25 additions & 0 deletions test/integration/adapter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package integration

import (
"testing"

"github.com/stretchr/testify/require"
v1 "github.com/substratusai/kubeai/api/v1"
)

func TestAdapters(t *testing.T) {
sysCfg := baseSysCfg(t)
initTest(t, sysCfg)
m := modelForTest(t)
m.Spec.Adapters = []v1.Adapter{
{ID: "adapter1", URL: "hf://test-repo/test-adapter"},
{ID: "adapter2", URL: "hf://test-repo/test-adapter"},
}
require.NoError(t, testK8sClient.Create(testCtx, m))

requireOpenAIModelList(t, []string{modelLabelSelectorForTest(t)}, 3, []string{
m.Name,
m.Name + "/adapter1",
m.Name + "/adapter2",
}, "Model list should contain the model and its adapters")
}
8 changes: 1 addition & 7 deletions test/integration/selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,7 @@ func TestSelector(t *testing.T) {
for name, c := range listTestCases {
t.Run("list "+name, func(t *testing.T) {
t.Parallel()
list := sendOpenAIListModelsRequest(t, c.selectorHeaders, http.StatusOK, name)
require.Len(t, list, c.expLen)
ids := make([]string, len(list))
for i, m := range list {
ids[i] = m.ID
}
require.ElementsMatch(t, c.expModels, ids)
requireOpenAIModelList(t, c.selectorHeaders, c.expLen, c.expModels, name)
})
}
}
53 changes: 35 additions & 18 deletions test/integration/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ import (
"sigs.k8s.io/yaml"
)

func modelLabelSelectorForTest(t *testing.T) string {
return fmt.Sprintf("test-case-name=%s", strings.ToLower(t.Name()))
}

func modelForTest(t *testing.T) *v1.Model {
return &v1.Model{
ObjectMeta: metav1.ObjectMeta{
Expand Down Expand Up @@ -169,27 +173,40 @@ func sendOpenAIInferenceRequest(t *testing.T, modelName string, selectorHeaders
}
}

func sendOpenAIListModelsRequest(t *testing.T, selectorHeaders []string, expCode int, msg string) []openaiserver.Model {
t.Helper()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8000/openai/v1/models", nil)
require.NoError(t, err, msg)
for _, selector := range selectorHeaders {
req.Header.Add("X-Label-Selector", selector)
}
func requireOpenAIModelList(t *testing.T, selectorHeaders []string, expLen int, expIDs []string, msg string) {
require.EventuallyWithT(t, func(t *assert.CollectT) {
//t.Helper()
req, err := http.NewRequest(http.MethodGet, "http://localhost:8000/openai/v1/models", nil)
if !assert.NoError(t, err, msg) {
return
}
for _, selector := range selectorHeaders {
req.Header.Add("X-Label-Selector", selector)
}

res, err := testHTTPClient.Do(req)
require.NoError(t, err, msg)
require.Equal(t, expCode, res.StatusCode, msg)
defer res.Body.Close()
res, err := testHTTPClient.Do(req)
if !assert.NoError(t, err, msg) {
return
}
if !assert.Equal(t, http.StatusOK, res.StatusCode, msg) {
return
}
defer res.Body.Close()

var respBody struct {
Data []openaiserver.Model `json:"data"`
}
if err := json.NewDecoder(res.Body).Decode(&respBody); err != nil {
require.NoError(t, err, msg)
}
var respBody struct {
Data []openaiserver.Model `json:"data"`
}
if !assert.NoError(t, json.NewDecoder(res.Body).Decode(&respBody), msg) {
return
}

ids := make([]string, len(respBody.Data))
for i, m := range respBody.Data {
ids[i] = m.ID
}

return respBody.Data
assert.ElementsMatch(t, expIDs, ids)
}, 5*time.Second, time.Second/10, msg)
}

func closeChannels(c chan struct{}, n int) {
Expand Down

0 comments on commit 2f81826

Please sign in to comment.