-
-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
182 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
This folder includes adapted examples from https://github.com/vercel/ai, which is licensed under the Apache License, Version 2.0. | ||
|
||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
You may obtain the original code at | ||
|
||
https://github.com/vercel/ai |
This file was deleted.
Oops, something went wrong.
21 changes: 21 additions & 0 deletions
21
examples/ai-core/src/complex/semantic-router/cosine-similarity.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
export function cosineSimilarity(a: number[], b: number[]) { | ||
if (a.length !== b.length) { | ||
throw new Error( | ||
`Vectors must have the same length (a: ${a.length}, b: ${b.length})`, | ||
) | ||
} | ||
|
||
return dotProduct(a, b) / (magnitude(a) * magnitude(b)) | ||
} | ||
|
||
function dotProduct(a: number[], b: number[]) { | ||
return a.reduce( | ||
(accumulator: number, value: number, index: number) => | ||
accumulator + value * b[index]!, | ||
0, | ||
) | ||
} | ||
|
||
function magnitude(a: number[]) { | ||
return Math.sqrt(dotProduct(a, a)) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#! /usr/bin/env -S pnpm tsx | ||
|
||
import { ollama } from 'ollama-ai-provider' | ||
import { OllamaEmbeddingModelId } from 'ollama-ai-provider/src/ollama-embedding-settings' | ||
|
||
import { buildProgram } from '../../tools/command' | ||
import { SemanticRouter } from './semantic-router' | ||
|
||
async function main(model: OllamaEmbeddingModelId) { | ||
const router = new SemanticRouter({ | ||
embeddingModel: ollama.embedding(model), | ||
routes: [ | ||
{ | ||
name: 'sports' as const, | ||
values: [ | ||
"who's your favorite football team?", | ||
'The World Cup is the most exciting event.', | ||
'I enjoy running marathons on weekends.', | ||
], | ||
}, | ||
{ | ||
name: 'music' as const, | ||
values: [ | ||
"what's your favorite genre of music?", | ||
'Classical music helps me concentrate.', | ||
'I recently attended a jazz festival.', | ||
], | ||
}, | ||
], | ||
similarityThreshold: 0.2, | ||
}) | ||
|
||
// topic is strongly typed | ||
const topic = await router.route( | ||
'Many consider Michael Jordan the greatest basketball player ever.', | ||
) | ||
|
||
switch (topic) { | ||
case 'sports': { | ||
console.log('sports') | ||
break | ||
} | ||
case 'music': { | ||
console.log('music') | ||
break | ||
} | ||
case null: { | ||
console.log('no topic found') | ||
break | ||
} | ||
} | ||
} | ||
|
||
buildProgram('nomic-embed-text', main).catch(console.error) |
96 changes: 96 additions & 0 deletions
96
examples/ai-core/src/complex/semantic-router/semantic-router.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
import { embed, Embedding, EmbeddingModel, embedMany } from 'ai' | ||
|
||
import { cosineSimilarity } from './cosine-similarity' | ||
|
||
export interface Route<NAME extends string> { | ||
name: NAME | ||
values: string[] | ||
} | ||
|
||
/** | ||
* Routes values based on their distance to the values from a set of clusters. | ||
* When the distance is below a certain threshold, the value is classified as belonging to the route, | ||
* and the route name is returned. Otherwise, the value is classified as null. | ||
*/ | ||
export class SemanticRouter<ROUTES extends Array<Route<string>>> { | ||
readonly routes: ROUTES | ||
readonly embeddingModel: EmbeddingModel<string> | ||
readonly similarityThreshold: number | ||
|
||
private routeValues: | ||
| Array<{ embedding: Embedding; routeName: string; routeValue: string }> | ||
| undefined | ||
|
||
constructor({ | ||
embeddingModel, | ||
routes, | ||
similarityThreshold, | ||
}: { | ||
embeddingModel: EmbeddingModel<string> | ||
routes: ROUTES | ||
similarityThreshold: number | ||
}) { | ||
this.routes = routes | ||
this.embeddingModel = embeddingModel | ||
this.similarityThreshold = similarityThreshold | ||
} | ||
|
||
private async getRouteValues(): Promise< | ||
Array<{ embedding: Embedding; routeName: string; routeValue: string }> | ||
> { | ||
if (this.routeValues !== undefined) { | ||
return this.routeValues | ||
} | ||
|
||
this.routeValues = [] | ||
|
||
for (const route of this.routes) { | ||
const { embeddings } = await embedMany({ | ||
model: this.embeddingModel, | ||
values: route.values, | ||
}) | ||
|
||
for (const [index, embedding] of embeddings.entries()) { | ||
this.routeValues.push({ | ||
embedding: embedding, | ||
routeName: route.name, | ||
routeValue: route.values[index], | ||
}) | ||
} | ||
} | ||
|
||
return this.routeValues | ||
} | ||
|
||
async route(value: string) { | ||
const { embedding } = await embed({ model: this.embeddingModel, value }) | ||
const routeValues = await this.getRouteValues() | ||
|
||
const allMatches: Array<{ | ||
routeName: string | ||
routeValue: string | ||
similarity: number | ||
}> = [] | ||
|
||
for (const routeValue of routeValues) { | ||
const similarity = cosineSimilarity(embedding, routeValue.embedding) | ||
|
||
if (similarity >= this.similarityThreshold) { | ||
allMatches.push({ | ||
routeName: routeValue.routeName, | ||
routeValue: routeValue.routeValue, | ||
similarity, | ||
}) | ||
} | ||
} | ||
|
||
// sort (highest similarity first) | ||
allMatches.sort((a, b) => b.similarity - a.similarity) | ||
|
||
return allMatches.length > 0 | ||
? (allMatches[0].routeName as unknown as RouteNames<ROUTES>) | ||
: null | ||
} | ||
} | ||
|
||
type RouteNames<ROUTES> = ROUTES extends Array<Route<infer NAME>> ? NAME : never |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters