Skip to content

Commit

Permalink
Add support for model version in Vertex AI (huggingface#1496)
Browse files Browse the repository at this point in the history
* Add support for model version in Vertex AI

* Update src/lib/server/endpoints/google/endpointVertex.ts

Co-authored-by: goupilew <[email protected]>

* fix: optional chaining on extraBody

---------

Co-authored-by: goupilew <[email protected]>
Co-authored-by: Nathan Sarrazin <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2024
1 parent 379ab35 commit 79b1875
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
5 changes: 3 additions & 2 deletions docs/source/configuration/models/providers/google.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ MODELS=`[
"type": "vertex",
"project": "abc-xyz",
"location": "europe-west3",
"model": "gemini-1.5-pro-preview-0409", // model-name

"extraBody": {
"model_version": "gemini-1.5-pro-002",
},
// Optional
"safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
"apiEndpoint": "", // alternative api endpoint url,
Expand Down
5 changes: 3 additions & 2 deletions src/lib/server/endpoints/google/endpointVertex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export const endpointVertexParametersSchema = z.object({
model: z.any(), // allow optional and validate against emptiness
type: z.literal("vertex"),
location: z.string().default("europe-west1"),
extraBody: z.object({ model_version: z.string() }).optional(),
project: z.string(),
apiEndpoint: z.string().optional(),
safetyThreshold: z
Expand Down Expand Up @@ -49,7 +50,7 @@ export const endpointVertexParametersSchema = z.object({
});

export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal } =
const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal, extraBody } =
endpointVertexParametersSchema.parse(input);

const vertex_ai = new VertexAI({
Expand All @@ -64,7 +65,7 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
const hasFiles = messages.some((message) => message.files && message.files.length > 0);

const generativeModel = vertex_ai.getGenerativeModel({
model: model.id ?? model.name,
model: extraBody?.model_version ?? model.id ?? model.name,
safetySettings: safetyThreshold
? [
{
Expand Down

0 comments on commit 79b1875

Please sign in to comment.