Skip to content

Commit

Permalink
feat: add embeddings examples
Browse files Browse the repository at this point in the history
  • Loading branch information
sgomez committed May 17, 2024
1 parent 430dc15 commit feed2db
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 8 deletions.
9 changes: 9 additions & 0 deletions examples/NOTICE
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
This folder includes adapted examples from https://github.com/vercel/ai, which is licensed under the Apache License, Version 2.0.

You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

You may obtain the original code at

https://github.com/vercel/ai
6 changes: 0 additions & 6 deletions examples/README.md

This file was deleted.

21 changes: 21 additions & 0 deletions examples/ai-core/src/complex/semantic-router/cosine-similarity.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
export function cosineSimilarity(a: number[], b: number[]) {
if (a.length !== b.length) {
throw new Error(
`Vectors must have the same length (a: ${a.length}, b: ${b.length})`,
)
}

return dotProduct(a, b) / (magnitude(a) * magnitude(b))
}

function dotProduct(a: number[], b: number[]) {
return a.reduce(
(accumulator: number, value: number, index: number) =>
accumulator + value * b[index]!,
0,
)
}

function magnitude(a: number[]) {
return Math.sqrt(dotProduct(a, a))
}
54 changes: 54 additions & 0 deletions examples/ai-core/src/complex/semantic-router/main.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#! /usr/bin/env -S pnpm tsx

import { ollama } from 'ollama-ai-provider'
import { OllamaEmbeddingModelId } from 'ollama-ai-provider/src/ollama-embedding-settings'

import { buildProgram } from '../../tools/command'
import { SemanticRouter } from './semantic-router'

async function main(model: OllamaEmbeddingModelId) {
const router = new SemanticRouter({
embeddingModel: ollama.embedding(model),
routes: [
{
name: 'sports' as const,
values: [
"who's your favorite football team?",
'The World Cup is the most exciting event.',
'I enjoy running marathons on weekends.',
],
},
{
name: 'music' as const,
values: [
"what's your favorite genre of music?",
'Classical music helps me concentrate.',
'I recently attended a jazz festival.',
],
},
],
similarityThreshold: 0.2,
})

// topic is strongly typed
const topic = await router.route(
'Many consider Michael Jordan the greatest basketball player ever.',
)

switch (topic) {
case 'sports': {
console.log('sports')
break
}
case 'music': {
console.log('music')
break
}
case null: {
console.log('no topic found')
break
}
}
}

buildProgram('nomic-embed-text', main).catch(console.error)
96 changes: 96 additions & 0 deletions examples/ai-core/src/complex/semantic-router/semantic-router.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { embed, Embedding, EmbeddingModel, embedMany } from 'ai'

import { cosineSimilarity } from './cosine-similarity'

export interface Route<NAME extends string> {
name: NAME
values: string[]
}

/**
* Routes values based on their distance to the values from a set of clusters.
* When the distance is below a certain threshold, the value is classified as belonging to the route,
* and the route name is returned. Otherwise, the value is classified as null.
*/
export class SemanticRouter<ROUTES extends Array<Route<string>>> {
readonly routes: ROUTES
readonly embeddingModel: EmbeddingModel<string>
readonly similarityThreshold: number

private routeValues:
| Array<{ embedding: Embedding; routeName: string; routeValue: string }>
| undefined

constructor({
embeddingModel,
routes,
similarityThreshold,
}: {
embeddingModel: EmbeddingModel<string>
routes: ROUTES
similarityThreshold: number
}) {
this.routes = routes
this.embeddingModel = embeddingModel
this.similarityThreshold = similarityThreshold
}

private async getRouteValues(): Promise<
Array<{ embedding: Embedding; routeName: string; routeValue: string }>
> {
if (this.routeValues !== undefined) {
return this.routeValues
}

this.routeValues = []

for (const route of this.routes) {
const { embeddings } = await embedMany({
model: this.embeddingModel,
values: route.values,
})

for (const [index, embedding] of embeddings.entries()) {
this.routeValues.push({
embedding: embedding,
routeName: route.name,
routeValue: route.values[index],
})
}
}

return this.routeValues
}

async route(value: string) {
const { embedding } = await embed({ model: this.embeddingModel, value })
const routeValues = await this.getRouteValues()

const allMatches: Array<{
routeName: string
routeValue: string
similarity: number
}> = []

for (const routeValue of routeValues) {
const similarity = cosineSimilarity(embedding, routeValue.embedding)

if (similarity >= this.similarityThreshold) {
allMatches.push({
routeName: routeValue.routeName,
routeValue: routeValue.routeValue,
similarity,
})
}
}

// sort (highest similarity first)
allMatches.sort((a, b) => b.similarity - a.similarity)

return allMatches.length > 0
? (allMatches[0].routeName as unknown as RouteNames<ROUTES>)
: null
}
}

type RouteNames<ROUTES> = ROUTES extends Array<Route<infer NAME>> ? NAME : never
2 changes: 1 addition & 1 deletion examples/ai-core/src/embed-many/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ async function main(model: OllamaEmbeddingModelId) {
console.log(embeddings)
}

buildProgram('all-minilm', main).catch(console.error)
buildProgram('nomic-embed-text', main).catch(console.error)
2 changes: 1 addition & 1 deletion examples/ai-core/src/embed/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ async function main(model: OllamaEmbeddingModelId) {
console.log(embedding)
}

buildProgram('all-minilm', main).catch(console.error)
buildProgram('nomic-embed-text', main).catch(console.error)

0 comments on commit feed2db

Please sign in to comment.