Fixes #3939: vertexai gemini has a different URL, we should support that in the vertexai procedure #3947

Merged
merged 5 commits into from
Feb 23, 2024
194 changes: 194 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/vertexai.adoc
@@ -4,6 +4,47 @@

NOTE: You need to create a Google Cloud project in your account and enable the Vertex.AI services. To obtain an access token, you can run `gcloud auth print-access-token`. Using these services will incur costs on your Google Cloud account.


All the following procedures can take the APOC configuration below, set e.g. in `apoc.conf` or via a Docker environment variable.
.Apoc configuration
|===
|key | description | default
| apoc.ml.vertexai.url | the Vertex AI endpoint base URL | the `endpoint` configuration parameter value
|===

Moreover, each procedure accepts the following configuration keys as its last parameter.

.Common configuration parameter
[opts=header]
|===
| key | description | default
| endpoint | analogous to `apoc.ml.vertexai.url` APOC config | https://\{region\}-aiplatform.googleapis.com/v1/projects/\{project\}/locations/\{region\}/publishers/google/models/\{model\}:\{resource\}
| headers | adds to or overrides the default HTTP headers |
`{``Content-Type``: "application/json", ``Accept``: "application/json", ``Authorization``: "Bearer " + <$accessToken, i.e. 2nd parameter> }`
| model | The Vertex AI model | depends on the procedure
| region | The Vertex AI region | us-central1
| resource | The Vertex AI resource (see below) | depends on the procedure
| temperature, maxOutputTokens, maxDecodeSteps, topP, topK | Optional parameters passed through to the HTTP request; which ones apply depends on the API used |
{temperature: 0.3, maxOutputTokens: 256, maxDecodeSteps: 200, topP: 0.8, topK: 40}
|===
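The header behavior described above (user-supplied entries win, missing defaults are filled in) can be illustrated with a minimal Python sketch. This is not part of APOC; the function name is purely illustrative:

```python
def merge_headers(user_headers, access_token):
    """Merge user-supplied headers with the documented defaults.

    Mirrors the described behavior: a key set via the `headers`
    configuration overrides the default, otherwise the default applies.
    """
    headers = dict(user_headers or {})
    headers.setdefault("Content-Type", "application/json")
    headers.setdefault("Accept", "application/json")
    headers.setdefault("Authorization", "Bearer " + access_token)
    return headers
```

For example, passing `{Accept: "text/plain"}` keeps that value while the other two defaults are still added.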

We can define the `endpoint` configuration as a full URL, e.g. `https://us-central1-aiplatform.googleapis.com/v1/projects/myVertexAIProject/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent`,
or define it via placeholders that are then replaced by the other configuration values.

For example, if we define no `endpoint` configuration,
the default one `https://\{region\}-aiplatform.googleapis.com/v1/projects/\{project\}/locations/\{region\}/publishers/google/models/\{model\}:\{resource\}` is used, where:

- `\{model\}` is replaced by the `model` configuration
- `\{region\}` by the `region` configuration
- `\{project\}` by the 3rd parameter (`project`)
- `\{resource\}` by the `resource` configuration

Alternatively, we can define an `endpoint` such as `https://us-central1-aiplatform.googleapis.com/v1/projects/\{project\}/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent`,
in which case only `\{project\}` is substituted with the 3rd parameter.
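The substitution rules above can be sketched in a few lines of Python. This is an illustration only, not APOC's implementation; the fallback values for `model` and `resource` are hypothetical, since per the tables above they actually depend on the procedure:

```python
DEFAULT_ENDPOINT = (
    "https://{region}-aiplatform.googleapis.com/v1/projects/{project}"
    "/locations/{region}/publishers/google/models/{model}:{resource}"
)

def resolve_endpoint(project, config):
    """Fill the endpoint template from the configuration map.

    An `endpoint` entry (full URL or partial template) overrides the
    default; any placeholder left in it is still substituted.
    """
    template = config.get("endpoint", DEFAULT_ENDPOINT)
    return (template
            .replace("{region}", config.get("region", "us-central1"))
            .replace("{project}", project)
            # illustrative fallbacks: the real defaults are procedure-specific
            .replace("{model}", config.get("model", "gemini-pro"))
            .replace("{resource}", config.get("resource", "streamGenerateContent")))
```

With `model: 'gemini-pro-vision'` and the default region, this yields exactly the full URL shown earlier.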


Let's see some examples.

== Generate Embeddings API

This procedure, `apoc.ml.vertexai.embedding`, takes a list of text strings and returns one row per string, with the embedding data as a 768-element vector.
@@ -128,3 +169,156 @@ yield value
|name | description
| value | result entry from Vertex.AI (containing candidates(author, content), safetyAttributes(categories, scores, blocked), recitationResults(recitationAction, recitations))
|===


== Streaming API

This procedure, `apoc.ml.vertexai.stream`, takes a list of content maps representing the exchange between assistant and user (with an optional system context), and returns the next message in the flow.

By default, it uses the https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/overview[Gemini AI APIs].


[source,cypher]
----
CALL apoc.ml.vertexai.stream([{role: "user", parts: [{text: "translate book in italian"}]}], '<accessToken>', '<projectID>')
----

.Results
[opts="header"]
|===
| value
| `{finishReason:"STOP", safetyRatings:[{probability:"NEGLIGIBLE", category:"HARM_CATEGORY_HARASSMENT"}, {probability:"NEGLIGIBLE", category:"HARM_CATEGORY_HATE_SPEECH"}, {probability:"NEGLIGIBLE", category:"HARM_CATEGORY_SEXUALLY_EXPLICIT"}, {probability:"NEGLIGIBLE", category:"HARM_CATEGORY_DANGEROUS_CONTENT"}], content:{role:"model", parts:[{text:"Libro"}]}}`
|===


We can adjust the parameters, for example `temperature`:

[source,cypher]
----
CALL apoc.ml.vertexai.stream([{role: "user", parts: [{text: "translate book in italian"}]}], '<accessToken>', '<projectID>',
{temperature: 0})
----

which corresponds to the following HTTP request body, where `maxOutputTokens`, `topP` and `topK` take the default values specified above (see `Common configuration parameter`):
----
{
"contents": [
{
"role": "user",
"parts": [
{
"text": "translate book in italian"
}
]
}
],
"generation_config": {
"temperature": 0,
"maxOutputTokens": 256,
"topP": 0.8,
"topK": 40
}
}
----
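How the user configuration and the documented defaults combine into that body can be sketched as follows. This is an illustrative Python snippet, not APOC code, restricted to the four keys appearing in the body above:

```python
def build_stream_body(contents, config):
    """Assemble the request body shown above.

    User-supplied generation parameters override the documented
    defaults; anything not supplied falls back to the default.
    """
    defaults = {"temperature": 0.3, "maxOutputTokens": 256, "topP": 0.8, "topK": 40}
    generation_config = {**defaults, **{k: config[k] for k in defaults if k in config}}
    return {"contents": contents, "generation_config": generation_config}
```

Passing `{temperature: 0}` therefore produces `temperature: 0` alongside the default `maxOutputTokens`, `topP` and `topK`, matching the body above.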


== Custom API

Using this procedure we can potentially invoke any API available in Vertex AI.

To allow maximum flexibility, in this case the first parameter is not manipulated but used verbatim as the body of the HTTP request,
and the return type is `ANY`.


.Gemini Pro Vision example
[source,cypher]
----
CALL apoc.ml.vertexai.custom({
contents: [
{
role: "user",
parts: [
{text: "What is this?"},
{inlineData: {
mimeType: "image/png",
data: '<base64Image>'}
}
]
}
]
},
"<accessToken>",
"<projectId>",
{model: 'gemini-pro-vision'}
)
----

.Results
[opts="header"]
|===
| value
| `[{usageMetadata: {promptTokenCount: 262, totalTokenCount: 272, candidatesTokenCount: 10}, candidates: [{content: {role: "model", parts: [{text: " This is a photo of a book..."}]}, finishReason: "STOP", safetyRatings: [{category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE"}]}]}]`
|===




[source,cypher]
----
CALL apoc.ml.vertexai.custom({contents: {role: "user", parts: [{text: "translate book in italian"}]}},
"<accessToken>",
"<projectId>",
{endpoint: "https://us-central1-aiplatform.googleapis.com/v1/projects/{project}/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent"}
)
----

.Results
[opts="header"]
|===
| value
| `[{usageMetadata: {promptTokenCount: 4, totalTokenCount: 5, candidatesTokenCount: 1}, candidates: [{content: {role: "model", parts: [{text: "libro"}]}, finishReason: "STOP", safetyRatings: [{category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE"}]}]}]`
|===



[source,cypher]
----
CALL apoc.ml.vertexai.custom({contents: {role: "user", parts: [{text: "translate book in italian"}]}},
"<accessToken>",
null,
{endpoint: "https://us-central1-aiplatform.googleapis.com/v1/projects/vertex-project-413513/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent"}
)
----

.Results
[opts="header"]
|===
| value
| `[{usageMetadata: {promptTokenCount: 4, totalTokenCount: 5, candidatesTokenCount: 1}, candidates: [{content: {role: "model", parts: [{text: "libro"}]}, finishReason: "STOP", safetyRatings: [{category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE"}]}]}]`
|===


Moreover, we can use it with other Google APIs whose endpoints don't start with `https://<region>-aiplatform.googleapis.com`;
for example, the https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize[Text-to-Speech API]:

[source,cypher]
----
CALL apoc.ml.vertexai.custom(
{
input:{
text:'just a test'
},
voice:{
languageCode:'en-US',
name:'en-US-Studio-O'
},
audioConfig:{
audioEncoding:'LINEAR16',
speakingRate:1
}
},
"<accessToken>",
"<projectId>",
{endpoint: "https://texttospeech.googleapis.com/v1/text:synthesize"})
----
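For comparison, the equivalent plain HTTP request for that Text-to-Speech call can be sketched in Python. This is a hypothetical standalone snippet, not part of APOC; it only builds the request object without sending it:

```python
import json
import urllib.request

def tts_request(access_token, body,
                endpoint="https://texttospeech.googleapis.com/v1/text:synthesize"):
    """Build (but do not send) the HTTP request matching the call above:
    JSON body, bearer-token auth, POST to the text:synthesize endpoint."""
    return urllib.request.Request(
        endpoint,
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json",
                 "Authorization": "Bearer " + access_token},
        method="POST")
```

Sending it with `urllib.request.urlopen(req)` would return the same JSON response the procedure yields.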
1 change: 1 addition & 0 deletions extended/src/main/java/apoc/ExtendedApocConfig.java
@@ -33,6 +33,7 @@ public class ExtendedApocConfig extends LifecycleAdapter
public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url";
public static final String APOC_ML_OPENAI_TYPE = "apoc.ml.openai.type";
public static final String APOC_ML_OPENAI_AZURE_VERSION = "apoc.ml.azure.api.version";
public static final String APOC_ML_VERTEXAI_URL = "apoc.ml.vertexai.url";
public static final String APOC_ML_WATSON_PROJECT_ID = "apoc.ml.watson.project.id";
public static final String APOC_ML_WATSON_URL = "apoc.ml.watson.url";
public static final String APOC_AWS_KEY_ID = "apoc.aws.key.id";
71 changes: 48 additions & 23 deletions extended/src/main/java/apoc/ml/VertexAI.java
@@ -1,10 +1,13 @@
package apoc.ml;

import apoc.ApocConfig;
import apoc.Extended;
import apoc.result.MapResult;
import apoc.result.ObjectResult;
import apoc.util.JsonUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.procedure.Context;
import org.neo4j.procedure.Description;
@@ -13,6 +16,7 @@

import java.net.MalformedURLException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -25,10 +29,8 @@ public class VertexAI {
@Context
public URLAccessChecker urlAccessChecker;

// "https://${region}-aiplatform.googleapis.com/v1/projects/${project}/locations/${region}/publishers/google/models/${model}:predict"
private static final String BASE_URL = "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict";
public static final String APOC_ML_VERTEXAI_URL = "apoc.ml.vertexai.url";
public static final String DEFAULT_REGION = "us-central1";
@Context
public ApocConfig apocConfig;

public static class EmbeddingResult {
public final long index;
@@ -42,27 +44,27 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {
}
}

private static Stream<Object> executeRequest(String accessToken, String project, Map<String, Object> configuration, String defaultModel, Object inputs, String jsonPath, Collection<String> retainConfigKeys, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
private Stream<Object> executeRequest(String accessToken, String project, Map<String, Object> configuration, String defaultModel, Object inputs, Collection<String> retainConfigKeys, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException {
return executeRequest(accessToken, project, configuration, defaultModel, inputs, retainConfigKeys, urlAccessChecker, VertexAIHandler.Type.PREDICT);
}

private Stream<Object> executeRequest(String accessToken, String project, Map<String, Object> configuration, String defaultModel, Object inputs, Collection<String> retainConfigKeys, URLAccessChecker urlAccessChecker,
VertexAIHandler.Type vertexAIHandlerType) throws JsonProcessingException {
if (accessToken == null || accessToken.isBlank())
throw new IllegalArgumentException("Access Token must not be empty");
if (project == null || project.isBlank())
throw new IllegalArgumentException("Project must not be empty");
String urlTemplate = System.getProperty(APOC_ML_VERTEXAI_URL, BASE_URL);

String model = configuration.getOrDefault("model", defaultModel).toString();
String region = configuration.getOrDefault("region", DEFAULT_REGION).toString();
String endpoint = String.format(urlTemplate, region, project, region, model);

Map<String, Object> headers = Map.of(
"Content-Type", "application/json",
"Accept", "application/json",
"Authorization", "Bearer " + accessToken
);
Map<String, Object> headers = (Map<String, Object>) configuration.getOrDefault("headers", new HashMap<>());
headers.putIfAbsent("Content-Type", "application/json");
headers.putIfAbsent("Accept", "application/json");
headers.putIfAbsent("Authorization", "Bearer " + accessToken);

Map<String, Object> data = Map.of("instances", inputs, "parameters", getParameters(configuration, retainConfigKeys));
VertexAIHandler vertexAIHandler = vertexAIHandlerType.get();
Map<String, Object> data = vertexAIHandler.getBody(inputs, configuration, retainConfigKeys);
String payload = new ObjectMapper().writeValueAsString(data);

return JsonUtil.loadJson(endpoint, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
String url = vertexAIHandler.getFullUrl(configuration, apocConfig, defaultModel, project);
String jsonPath = vertexAIHandler.getJsonPath();
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker);
}

@Procedure("apoc.ml.vertexai.embedding")
@@ -93,7 +95,7 @@ public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @
}
*/
Object inputs = texts.stream().map(text -> Map.of("content", text)).toList();
Stream<Object> resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, "$.predictions", List.of(), urlAccessChecker);
Stream<Object> resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, List.of(), urlAccessChecker);
AtomicInteger ai = new AtomicInteger();
return resultStream
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
@@ -153,13 +155,13 @@ public Stream<MapResult> completion(@Name("prompt") String prompt, @Name("access
*/
Object input = List.of(Map.of("prompt",prompt));
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
var resultStream = executeRequest(accessToken, project, configuration, "text-bison", input, "$.predictions", parameterKeys, urlAccessChecker);
var resultStream = executeRequest(accessToken, project, configuration, "text-bison", input, parameterKeys, urlAccessChecker);
return resultStream
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(v -> (Map<String, Object>) v).map(MapResult::new);
}

private static Map<String, Object> getParameters(Map<String, Object> config, Collection<String> retainKeys) {
public static Map<String, Object> getParameters(Map<String, Object> config, Collection<String> retainKeys) {
/*
"temperature": TEMPERATURE,
"maxOutputTokens": MAX_OUTPUT_TOKENS,
@@ -192,7 +194,7 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Strin
) throws Exception {
Object inputs = List.of(Map.of("context",context, "examples",examples, "messages", messages));
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");
return executeRequest(accessToken, project, configuration, "chat-bison", inputs, "$.predictions", parameterKeys, urlAccessChecker)
return executeRequest(accessToken, project, configuration, "chat-bison", inputs, parameterKeys, urlAccessChecker)
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(v -> (Map<String, Object>) v).map(MapResult::new);
// POST https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/chat-bison:predict
@@ -271,4 +273,27 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Strin
}
*/
}

@Procedure("apoc.ml.vertexai.stream")
@Description("apoc.ml.vertexai.stream(contents, accessToken, project, configuration) - prompts the streaming API")
public Stream<MapResult> stream(@Name("messages") List<Map<String, String>> contents,
@Name("accessToken") String accessToken,
@Name("project") String project,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens");

return executeRequest(accessToken, project, configuration, "gemini-pro", contents, parameterKeys, urlAccessChecker, VertexAIHandler.Type.STREAM)
.flatMap(v -> ((List<Map<String, Object>>) v).stream())
.map(MapResult::new);
}

@Procedure("apoc.ml.vertexai.custom")
@Description("apoc.ml.vertexai.custom(contents, accessToken, project, configuration) - prompts a customizable API")
public Stream<ObjectResult> custom(@Name(value = "body") Map<String, Object> body,
@Name("accessToken") String accessToken,
@Name("project") String project,
@Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
return executeRequest(accessToken, project, configuration, "gemini-pro", body, Collections.emptyList(), urlAccessChecker, VertexAIHandler.Type.CUSTOM)
.map(ObjectResult::new);
}
}