diff --git a/docs/asciidoc/modules/ROOT/pages/ml/vertexai.adoc b/docs/asciidoc/modules/ROOT/pages/ml/vertexai.adoc index efd3a6a6d7..4e8ede14bf 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/vertexai.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/vertexai.adoc @@ -4,6 +4,47 @@ NOTE: You need to create a Google Cloud project in your account an enable the Vertex.AI services. As an access-token you can run `gcloud auth print-access-token`. Using these services will incurr costs on your Google Cloud account. + +All the following procedures can have the following APOC config, i.e. in `apoc.conf` or via docker env variable +.Apoc configuration +|=== +|key | description | default +| apoc.ml.vertexai.url | the OpenAI endpoint base url | the `endpoint` configuration parameter value +|=== + +Moreover, they can have the following configuration keys, as the last parameter. + +.Common configuration parameter +[opts=header] +|=== +| key | description | default +| endpoint | analogous to `apoc.ml.vertexai.url` APOC config | https://\{region\}-aiplatform.googleapis.com/v1/projects/\{project\}/locations/\{region\}/publishers/google/models/\{model\}:\{resource\} +| headers | to add or edit the HTTP default header | + `{``Content-Type``: "application/json", ``Accept``: "application/json", ``Authorization``: "Bearer " + <$accessToken, i.e. 2nd parameter> }` +| model | The Vertex AI model | depends on the procedure +| region | The Vertex AI region | us-central1 +| resource | The Vertex AI resource (see below) | depends on the procedure +| temperature, maxOutputTokens, maxDecodeSteps, topP, topK | Optional parameter which can be passed into the HTTP request. Depend on the API used | + {temperature: 0.3, maxOutputTokens: 256, maxDecodeSteps: 200, topP: 0.8, topK: 40} +|=== + +We can define the `endpoint` configuration as a full URL, e.g. `https://us-central1-aiplatform.googleapis.com/v1/projects/myVertexAIProject/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent`, +or define it via parameters that will then be replaced by the other configurations. + +For example, if we define no `endpoint` config., +the default one `https://\{region\}-aiplatform.googleapis.com/v1/projects/\{project\}/locations/\{region\}/publishers/google/models/\{model\}:\{resource\}` will be valued, where: + +- `\{model\}` will be model defined by the `model` configuration +- `\{region\}` defined by `region` configuration +- `\{project\}` defined by the 3rd parameter (`project`) +- `\{resource\}` defined by `resource` configuration + +Or else, we can define an `endpoint` as `https://us-central1-aiplatform.googleapis.com/v1/projects/\{project\}/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent`, +and in this case we just substitute `\{project\}` with the 3rd parameter. + + +Let's see some example. + == Generate Embeddings API This procedure `apoc.ml.vertexai.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 768 element vector. @@ -128,3 +169,156 @@ yield value |name | description | value | result entry from Vertex.AI (containing candidates(author, content), safetyAttributes(categories, scores, blocked), recitationResults(recitationAction, recitations)) |=== + + +== Streaming API + +This procedure `apoc.ml.vertexai.stream` takes a list of maps of contents exchanges between assistant and user (with optional system context), and will return the next message in the flow. + +By default, it uses the https://cloud.google.com/vertex-ai/docs/generative-ai/multimodal/overview[Gemini AI APIs]. + + +[source,cypher] +---- +CALL apoc.ml.vertexai.stream([{role: "user", parts: [{text: "translate book in italian"}]}], '', '') +---- + +.Results +[opts="header"] +|=== +| value +| `{finishReason:"STOP", safetyRatings:[{probability:"NEGLIGIBLE", category:"HARM_CATEGORY_HARASSMENT"}, {probability:"NEGLIGIBLE", category:"HARM_CATEGORY_HATE_SPEECH"}, {probability:"NEGLIGIBLE", category:"HARM_CATEGORY_SEXUALLY_EXPLICIT"}, {probability:"NEGLIGIBLE", category:"HARM_CATEGORY_DANGEROUS_CONTENT"}], content:{role:"model", parts:[{text:"Libro"}]}}` +|=== + + +We can adjust the parameter, for example `temperature` + +[source,cypher] +---- +CALL apoc.ml.vertexai.stream([{role: "user", parts: [{text: "translate book in italian"}]}], '', '', + {temperature: 0}) +---- + +which corresponds to the following Http body request, where `maxOutputTokens`, `topP` and `topK` have the default values specified above (`Common configuration parameter`): +---- +{ + "contents": [ + { + "role": "user", + "parts": [ + { + "text": "translate book in italian" + } + ] + } + ], + "generation_config": { + "temperature": 0, + "maxOutputTokens": 256, + "topP": 0.8, + "topK": 40 + } +} +---- + + +== Custom API + +Using this procedure we can potentially invoke any API available with vertex AI. + +To permit maximum flexibility, in this case the first parameter is not manipulated and exactly matches the body of the HTTP request, +and the return type is `ANY`. + + +.Gemini Pro Vision example +[source,cypher] +---- +CALL apoc.ml.vertexai.custom({ + contents: [ + { + role: "user", + parts: [ + {text: "What is this?"}, + {inlineData: { + mimeType: "image/png", + data: ''} + } + ] + } + ] +}, +"", +"", +{model: 'gemini-pro-vision'} +) +---- + +.Results +[opts="header"] +|=== +| value +| `[{usageMetadata: {promptTokenCount: 262, totalTokenCount: 272, candidatesTokenCount: 10}, candidates: [{content: {role: "model", parts: [{text: " This is a photo of a book..."}]}, finishReason: "STOP", safetyRatings: [{category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability│ +│: "NEGLIGIBLE"}]}]}]` +|=== + + + + +[source,cypher] +---- +CALL apoc.ml.vertexai.custom({contents: {role: "user", parts: [{text: "translate book in italian"}]}}, + "", + "", + {endpoint: "https://us-central1-aiplatform.googleapis.com/v1/projects/{project}/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent"} +) +---- + +.Results +[opts="header"] +|=== +| value +| `[{usageMetadata: {promptTokenCount: 4, totalTokenCount: 5, candidatesTokenCount: 1}, candidates: [{content: {role: "model", parts: [{text: "libro"}]}, finishReason: "STOP", safetyRatings: [{category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE"}]}]}]` +|=== + + + +[source,cypher] +---- +CALL apoc.ml.vertexai.custom({contents: {role: "user", parts: [{text: "translate book in italian"}]}}, + "", + null, + {endpoint: "https://us-central1-aiplatform.googleapis.com/v1/projects/vertex-project-413513/locations/us-central1/publishers/google/models/gemini-pro-vision:streamGenerateContent"} +) +---- + +.Results +[opts="header"] +|=== +| value +| `[{usageMetadata: {promptTokenCount: 4, totalTokenCount: 5, candidatesTokenCount: 1}, candidates: [{content: {role: "model", parts: [{text: "libro"}]}, finishReason: "STOP", safetyRatings: [{category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE"}, {category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE"}]}]}]` +|=== + + +Moreover, we can use with other Google API with endpoints that don't start with `https://-aiplatform.googleapis.com`, +for example we can use the https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize[Text-to-Speech API]: + +[source,cypher] +---- +CALL apoc.ml.vertexai.custom( +{ + input:{ + text:'just a test' + }, + voice:{ + languageCode:'en-US', + name:'en-US-Studio-O' + }, + audioConfig:{ + audioEncoding:'LINEAR16', + speakingRate:1 + } +}, +"", +"", +{endpoint: "https://texttospeech.googleapis.com/v1/text:synthesize"}) +---- \ No newline at end of file diff --git a/extended/src/main/java/apoc/ExtendedApocConfig.java b/extended/src/main/java/apoc/ExtendedApocConfig.java index 3e30fb4e78..953d79a49d 100644 --- a/extended/src/main/java/apoc/ExtendedApocConfig.java +++ b/extended/src/main/java/apoc/ExtendedApocConfig.java @@ -33,6 +33,7 @@ public class ExtendedApocConfig extends LifecycleAdapter public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url"; public static final String APOC_ML_OPENAI_TYPE = "apoc.ml.openai.type"; public static final String APOC_ML_OPENAI_AZURE_VERSION = "apoc.ml.azure.api.version"; + public static final String APOC_ML_VERTEXAI_URL = "apoc.ml.vertexai.url"; public static final String APOC_ML_WATSON_PROJECT_ID = "apoc.ml.watson.project.id"; public static final String APOC_ML_WATSON_URL = "apoc.ml.watson.url"; public static final String APOC_AWS_KEY_ID = "apoc.aws.key.id"; diff --git a/extended/src/main/java/apoc/ml/VertexAI.java b/extended/src/main/java/apoc/ml/VertexAI.java index 413d2fdf41..c2c50fbbe4 100644 --- a/extended/src/main/java/apoc/ml/VertexAI.java +++ b/extended/src/main/java/apoc/ml/VertexAI.java @@ -1,10 +1,13 @@ package apoc.ml; +import apoc.ApocConfig; import apoc.Extended; import apoc.result.MapResult; +import apoc.result.ObjectResult; import apoc.util.JsonUtil; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; + import org.neo4j.graphdb.security.URLAccessChecker; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; @@ -13,6 +16,7 @@ import java.net.MalformedURLException; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -25,10 +29,8 @@ public class VertexAI { @Context public URLAccessChecker urlAccessChecker; - // "https://${region}-aiplatform.googleapis.com/v1/projects/${project}/locations/${region}/publishers/google/models/${model}:predict" - private static final String BASE_URL = "https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict"; - public static final String APOC_ML_VERTEXAI_URL = "apoc.ml.vertexai.url"; - public static final String DEFAULT_REGION = "us-central1"; + @Context + public ApocConfig apocConfig; public static class EmbeddingResult { public final long index; @@ -42,27 +44,27 @@ public EmbeddingResult(long index, String text, List embedding) { } } - private static Stream executeRequest(String accessToken, String project, Map configuration, String defaultModel, Object inputs, String jsonPath, Collection retainConfigKeys, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException { + private Stream executeRequest(String accessToken, String project, Map configuration, String defaultModel, Object inputs, Collection retainConfigKeys, URLAccessChecker urlAccessChecker) throws JsonProcessingException, MalformedURLException { + return executeRequest(accessToken, project, configuration, defaultModel, inputs, retainConfigKeys, urlAccessChecker, VertexAIHandler.Type.PREDICT); + } + + private Stream executeRequest(String accessToken, String project, Map configuration, String defaultModel, Object inputs, Collection retainConfigKeys, URLAccessChecker urlAccessChecker, + VertexAIHandler.Type vertexAIHandlerType) throws JsonProcessingException { if (accessToken == null || accessToken.isBlank()) throw new IllegalArgumentException("Access Token must not be empty"); - if (project == null || project.isBlank()) - throw new IllegalArgumentException("Project must not be empty"); - String urlTemplate = System.getProperty(APOC_ML_VERTEXAI_URL, BASE_URL); - - String model = configuration.getOrDefault("model", defaultModel).toString(); - String region = configuration.getOrDefault("region", DEFAULT_REGION).toString(); - String endpoint = String.format(urlTemplate, region, project, region, model); - Map headers = Map.of( - "Content-Type", "application/json", - "Accept", "application/json", - "Authorization", "Bearer " + accessToken - ); + Map headers = (Map) configuration.getOrDefault("headers", new HashMap<>()); + headers.putIfAbsent("Content-Type", "application/json"); + headers.putIfAbsent("Accept", "application/json"); + headers.putIfAbsent("Authorization", "Bearer " + accessToken); - Map data = Map.of("instances", inputs, "parameters", getParameters(configuration, retainConfigKeys)); + VertexAIHandler vertexAIHandler = vertexAIHandlerType.get(); + Map data = vertexAIHandler.getBody(inputs, configuration, retainConfigKeys); String payload = new ObjectMapper().writeValueAsString(data); - return JsonUtil.loadJson(endpoint, headers, payload, jsonPath, true, List.of(), urlAccessChecker); + String url = vertexAIHandler.getFullUrl(configuration, apocConfig, defaultModel, project); + String jsonPath = vertexAIHandler.getJsonPath(); + return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of(), urlAccessChecker); } @Procedure("apoc.ml.vertexai.embedding") @@ -93,7 +95,7 @@ public Stream getEmbedding(@Name("texts") List texts, @ } */ Object inputs = texts.stream().map(text -> Map.of("content", text)).toList(); - Stream resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, "$.predictions", List.of(), urlAccessChecker); + Stream resultStream = executeRequest(accessToken, project, configuration, "textembedding-gecko", inputs, List.of(), urlAccessChecker); AtomicInteger ai = new AtomicInteger(); return resultStream .flatMap(v -> ((List>) v).stream()) @@ -153,13 +155,13 @@ public Stream completion(@Name("prompt") String prompt, @Name("access */ Object input = List.of(Map.of("prompt",prompt)); var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens"); - var resultStream = executeRequest(accessToken, project, configuration, "text-bison", input, "$.predictions", parameterKeys, urlAccessChecker); + var resultStream = executeRequest(accessToken, project, configuration, "text-bison", input, parameterKeys, urlAccessChecker); return resultStream .flatMap(v -> ((List>) v).stream()) .map(v -> (Map) v).map(MapResult::new); } - private static Map getParameters(Map config, Collection retainKeys) { + public static Map getParameters(Map config, Collection retainKeys) { /* "temperature": TEMPERATURE, "maxOutputTokens": MAX_OUTPUT_TOKENS, @@ -192,7 +194,7 @@ public Stream chatCompletion(@Name("messages") List ((List>) v).stream()) .map(v -> (Map) v).map(MapResult::new); // POST https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/chat-bison:predict @@ -271,4 +273,27 @@ public Stream chatCompletion(@Name("messages") List stream(@Name("messages") List> contents, + @Name("accessToken") String accessToken, + @Name("project") String project, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + var parameterKeys = List.of("temperature", "topK", "topP", "maxOutputTokens"); + + return executeRequest(accessToken, project, configuration, "gemini-pro", contents, parameterKeys, urlAccessChecker, VertexAIHandler.Type.STREAM) + .flatMap(v -> ((List>) v).stream()) + .map(MapResult::new); + } + + @Procedure("apoc.ml.vertexai.custom") + @Description("apoc.ml.vertexai.custom(contents, accessToken, project, configuration) - prompts a customizable API") + public Stream custom(@Name(value = "body") Map body, + @Name("accessToken") String accessToken, + @Name("project") String project, + @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { + return executeRequest(accessToken, project, configuration, "gemini-pro", body, Collections.emptyList(), urlAccessChecker, VertexAIHandler.Type.CUSTOM) + .map(ObjectResult::new); + } } \ No newline at end of file diff --git a/extended/src/main/java/apoc/ml/VertexAIHandler.java b/extended/src/main/java/apoc/ml/VertexAIHandler.java new file mode 100644 index 0000000000..0e1314279c --- /dev/null +++ b/extended/src/main/java/apoc/ml/VertexAIHandler.java @@ -0,0 +1,122 @@ +package apoc.ml; + + +import apoc.ApocConfig; + +import java.util.Collection; +import java.util.Map; +import java.util.Objects; + +import static apoc.ExtendedApocConfig.APOC_ML_VERTEXAI_URL; +import static apoc.ml.VertexAI.getParameters; +import static org.apache.commons.lang3.StringUtils.isBlank; + +public abstract class VertexAIHandler { + public static final String ENDPOINT_CONF_KEY = "endpoint"; + public static final String MODEL_CONF_KEY = "model"; + public static final String RESOURCE_CONF_KEY = "resource"; + + public static final String STREAM_RESOURCE = "streamGenerateContent"; + public static final String PREDICT_RESOURCE = "predict"; + + private static final String DEFAULT_BASE_URL = "https://{region}-aiplatform.googleapis.com/v1/projects/{project}/locations/{region}/publishers/google/models/{model}:{resource}"; + public static final String DEFAULT_REGION = "us-central1"; + + public abstract String getDefaultResource(); + + public abstract Map getBody(Object inputs, Map configuration, Collection retainKeys); + + public abstract String getJsonPath(); + + public String getFullUrl(Map configuration, ApocConfig apocConfig, String defaultModel, String project) { + String model = configuration.getOrDefault(MODEL_CONF_KEY, defaultModel).toString(); + String region = configuration.getOrDefault("region", DEFAULT_REGION).toString(); + String resource = configuration.getOrDefault(RESOURCE_CONF_KEY, getDefaultResource()).toString(); + project = Objects.toString(project, ""); + String endpoint = getUrlTemplate(configuration, apocConfig); + + if (isBlank(endpoint) && isBlank(project)) { + throw new IllegalArgumentException("Either project parameter or endpoint config. must not be empty"); + } + return endpoint.replace("{region}", region) + .replace("{project}", project) + .replace("{model}", model) + .replace("{resource}", resource); + } + + private String getUrlTemplate(Map procConfig, ApocConfig apocConfig) { + return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY, + apocConfig.getString(APOC_ML_VERTEXAI_URL, System.getProperty(APOC_ML_VERTEXAI_URL, DEFAULT_BASE_URL))); + } + + enum Type { + PREDICT(new Predict()), + STREAM(new Stream()), + CUSTOM(new Custom()); + + private final VertexAIHandler handler; + Type(VertexAIHandler handler) { + this.handler = handler; + } + + public VertexAIHandler get() { + return handler; + } + } + + private static class Predict extends VertexAIHandler { + + @Override + public String getDefaultResource() { + return PREDICT_RESOURCE; + } + + @Override + public Map getBody(Object inputs, Map configuration, Collection retainKeys) { + return Map.of("instances", inputs, "parameters", getParameters(configuration, retainKeys)); + } + + @Override + public String getJsonPath() { + return "$.predictions"; + } + } + + private static class Stream extends VertexAIHandler { + + @Override + public String getDefaultResource() { + return STREAM_RESOURCE; + } + + @Override + public Map getBody(Object inputs, Map configuration, Collection retainKeys) { + return Map.of("contents", inputs, "generation_config", getParameters(configuration, retainKeys)); + } + + @Override + public String getJsonPath() { + return "$[0].candidates"; + } + } + + private static class Custom extends VertexAIHandler { + + @Override + public String getDefaultResource() { + return STREAM_RESOURCE; + } + + @Override + public Map getBody(Object inputs, Map configuration, Collection retainKeys) { + return (Map) inputs; + } + + @Override + public String getJsonPath() { + return null; + } + } +} + + \ No newline at end of file diff --git a/extended/src/main/resources/extended.txt b/extended/src/main/resources/extended.txt index e22c91371c..586fedecfb 100644 --- a/extended/src/main/resources/extended.txt +++ b/extended/src/main/resources/extended.txt @@ -109,7 +109,9 @@ apoc.ml.sagemaker.custom apoc.ml.sagemaker.embedding apoc.ml.vertexai.chat apoc.ml.vertexai.completion +apoc.ml.vertexai.custom apoc.ml.vertexai.embedding +apoc.ml.vertexai.stream apoc.ml.watson.chat apoc.ml.watson.completion apoc.model.jdbc diff --git a/extended/src/test/java/apoc/ml/VertexAIIT.java b/extended/src/test/java/apoc/ml/VertexAIIT.java index e70f49fd61..103ccddd05 100644 --- a/extended/src/test/java/apoc/ml/VertexAIIT.java +++ b/extended/src/test/java/apoc/ml/VertexAIIT.java @@ -1,6 +1,7 @@ package apoc.ml; import apoc.util.TestUtil; +import org.apache.commons.io.FileUtils; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -8,16 +9,33 @@ import org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; +import java.io.File; +import java.io.IOException; +import java.util.Base64; +import java.util.HashMap; import java.util.List; import java.util.Map; +import static apoc.ml.VertexAIHandler.ENDPOINT_CONF_KEY; +import static apoc.ml.VertexAIHandler.MODEL_CONF_KEY; +import static apoc.ml.VertexAIHandler.PREDICT_RESOURCE; +import static apoc.ml.VertexAIHandler.RESOURCE_CONF_KEY; +import static apoc.ml.VertexAIHandler.STREAM_RESOURCE; import static apoc.util.TestUtil.testCall; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; public class VertexAIIT { private String vertexAiKey; private String vertexAiProject; + + private final List> streamContents = List.of( + Map.of("role", "user", + "parts", List.of(Map.of("text", "translate book in italian")) + ) + ); @Rule public DbmsRule db = new ImpermanentDbmsRule(); @@ -51,52 +69,144 @@ public void getEmbedding() { @Test public void completion() { testCall(db, "CALL apoc.ml.vertexai.completion('What color is the sky? Answer in one word: ', $apiKey, $project)", parameters,(row) -> { - System.out.println("row = " + row); - var result = (Map)row.get("value"); - // {value={safetyAttributes={blocked=false, scores=[0.1], categories=[Sexual]}, recitationResult={recitations=[], recitationAction=NO_ACTION}, content=blue}} - assertEquals(true, result.containsKey("safetyAttributes")); - var safetyAttributes = (Map) result.get("safetyAttributes"); - assertEquals(false, safetyAttributes.get("blocked")); - assertEquals(true, safetyAttributes.containsKey("categories")); - - String text = (String) result.get("content"); - assertEquals(true, text != null && !text.isBlank()); - assertEquals(true, text.toLowerCase().contains("blue")); - - assertEquals(true, result.containsKey("recitationResult")); - assertEquals("NO_ACTION", ((Map)result.get("recitationResult")).get("recitationAction")); + assertCorrectResponse(row, "blue"); }); } @Test public void chatCompletion() { testCall(db, """ -CALL apoc.ml.vertexai.chat([ -{author:"user", content:"What planet do timelords live on?"} -], $apiKey, $project, {temperature:0}, -"Fictional universe of Doctor Who. Only answer with a single word!", -[{input:{content:"What planet do humans live on?"}, output:{content:"Earth"}}]) -""", parameters, (row) -> { - System.out.println("row = " + row); - // {value={candidates=[{author=1, content=Gallifrey.}], safetyAttributes={blocked=false, scores=[0.1, 0.1, 0.1], categories=[Religion & Belief, Sexual, Toxic]}, recitationResults=[{recitations=[], recitationAction=NO_ACTION}]}} - var result = (Map)row.get("value"); + CALL apoc.ml.vertexai.chat([ + {author:"user", content:"What planet do timelords live on?"} + ], $apiKey, $project, {temperature:0}, + "Fictional universe of Doctor Who. Only answer with a single word!", + [{input:{content:"What planet do humans live on?"}, output:{content:"Earth"}}])""", + parameters, + (row) -> assertCorrectResponse(row, "gallifrey")); + } - assertEquals(true, result.containsKey("safetyAttributes")); - var safetyAttributes = (Map) result.get("safetyAttributes"); - assertEquals(false, safetyAttributes.get("blocked")); - assertEquals(true, safetyAttributes.containsKey("categories")); - assertEquals(3, ((List)safetyAttributes.get("categories")).size()); + @Test + public void stream() { + HashMap params = new HashMap<>(parameters); + params.put("contents", streamContents); + testCall(db, "CALL apoc.ml.vertexai.stream($contents,$apiKey, $project)", + params, (row) -> { + assertCorrectResponse(row, "libro"); + }); + } + + @Test + public void customWithCompleteString() { + HashMap params = new HashMap<>(parameters); + params.put("contents", streamContents); + String endpoint = "https://us-central1-aiplatform.googleapis.com/v1/projects/" + vertexAiProject + "/locations/us-central1/publishers/google/models/gemini-pro-vision:" + STREAM_RESOURCE; + params.put(ENDPOINT_CONF_KEY, endpoint); + testCall(db, " CALL apoc.ml.vertexai.custom({contents: $contents}, $apiKey, null, {endpoint: $endpoint})", + params, + (row) -> assertCorrectResponse(row, "libro")); + } + + @Test + public void customWithStringFormat() { + HashMap params = new HashMap<>(parameters); + params.put("contents", streamContents); + String endpoint = "https://us-central1-aiplatform.googleapis.com/v1/projects/{project}/locations/us-central1/publishers/google/models/gemini-pro-vision:" + STREAM_RESOURCE; + params.put(ENDPOINT_CONF_KEY, endpoint); + testCall(db, "CALL apoc.ml.vertexai.custom({contents: $contents}, $apiKey, $project, {endpoint: $endpoint})", + params, + (row) -> assertCorrectResponse(row, "libro")); + } - assertEquals(true, result.containsKey("recitationResults")); - assertEquals("NO_ACTION", ((List)result.get("recitationResults")).get(0).get("recitationAction")); + @Test + public void customWithGeminiVisionMultiType() throws IOException { + String path = Thread.currentThread().getContextClassLoader().getResource("tarallo.png").getPath(); + + byte[] fileContent = FileUtils.readFileToByteArray(new File(path)); + String base64Image = Base64.getEncoder().encodeToString(fileContent); + + List> parts = List.of( + Map.of("text", "What is this?"), + Map.of("inlineData", Map.of( + "mimeType", "image/png", "data", base64Image)) + ); + List> contents = List.of( + Map.of("role", "user", "parts", parts) + ); + Map params = new HashMap<>(parameters); + params.put("contents", contents); + params.put("conf", Map.of(MODEL_CONF_KEY, "gemini-pro-vision")); - var candidates = (List>)result.get("candidates"); - var author = candidates.get(0).get("author"); - assertEquals("1", author); + testCall(db, """ + CALL apoc.ml.vertexai.custom({contents: $contents}, + $apiKey, + $project, + $conf)""", + params, + (row) -> assertCorrectResponse(row, "tarall")); + } - var text = (String)candidates.get(0).get("content"); - assertEquals(true, text != null && !text.isBlank()); - assertEquals(true, text.toLowerCase().contains("gallifrey")); - }); + @Test + public void customWithSuffix() { + HashMap params = new HashMap<>(parameters); + params.put("contents", streamContents); + + testCall(db, "CALL apoc.ml.vertexai.custom({contents: $contents}, $apiKey, $project)", + params, + (row) -> assertCorrectResponse(row, "libro")); + } + + @Test + public void customWithCodeBison() { + Map params = new HashMap<>(parameters); + params.put("conf", Map.of(MODEL_CONF_KEY, "codechat-bison", RESOURCE_CONF_KEY, PREDICT_RESOURCE)); + + testCall(db, """ + CALL apoc.ml.vertexai.custom({instances: + [{messages: [{author: "user", content: "Who are you?"}]}] + }, + $apiKey, $project, $conf)""", + params, + (row) -> assertCorrectResponse(row, "language model")); + } + + @Test + public void customWithChatCompletion() { + Map params = new HashMap<>(parameters); + params.put("conf", Map.of(MODEL_CONF_KEY, "chat-bison", RESOURCE_CONF_KEY, PREDICT_RESOURCE)); + + testCall(db, """ + CALL apoc.ml.vertexai.custom({instances: + [{messages: [{author: "user", content: "What planet do human live on?"}]}] + }, + $apiKey, $project, $conf)""", + params, + (row) -> assertCorrectResponse(row, "earth")); + } + + @Test + public void customWithWrongHeader() { + Map headers = Map.of("Content-Type", "invalid", + "Authorization", "invalid"); + + try { + testCall(db, """ + CALL apoc.ml.vertexai.custom( + { + contents: $contents + }, $apiKey, $project, {headers: $headers}) + """, Map.of("apiKey", vertexAiKey, + "project", vertexAiProject, + "headers", headers, + "contents", streamContents), (row) -> fail("Should fail due to 401 response")); + } catch (RuntimeException e) { + String errMsg = e.getMessage(); + assertTrue(errMsg.contains("Server returned HTTP response code: 401"), "Current err. message is:" + errMsg); + } + } + + private void assertCorrectResponse(Map row, String expected) { + String stringRow = row.toString(); + assertTrue(stringRow.toLowerCase().contains(expected), + "Actual result is: " + stringRow); } } \ No newline at end of file diff --git a/extended/src/test/java/apoc/ml/VertexAITest.java b/extended/src/test/java/apoc/ml/VertexAITest.java index 205e668154..c64fb5ef41 100644 --- a/extended/src/test/java/apoc/ml/VertexAITest.java +++ b/extended/src/test/java/apoc/ml/VertexAITest.java @@ -1,5 +1,6 @@ package apoc.ml; +import apoc.ExtendedApocConfig; import apoc.util.TestUtil; import org.apache.commons.io.FileUtils; import org.junit.BeforeClass; @@ -51,9 +52,9 @@ public static void startServer() throws URISyntaxException { mockServer = startClientAndServer(1080); var path = Paths.get(getUrlFileName(VERTEX_MOCK_FOLDER + EMBEDDINGS).toURI()).getParent().toUri(); - // %2$s will be substituted by project parameter, - // see String.format(urlTemplate, region, project, region, model) in VertexAi.executeRequest method - System.setProperty(VertexAI.APOC_ML_VERTEXAI_URL, path + "%2$s"); + // {project} will be substituted by project parameter, + // see getFullUrl method in VertexAIHandler.java + System.setProperty(ExtendedApocConfig.APOC_ML_VERTEXAI_URL, path + "{project}"); Stream.of(EMBEDDINGS, COMPLETION, CHAT_COMPLETION) .forEach(VertexAITest::setRequestResponse); diff --git a/extended/src/test/resources/tarallo.png b/extended/src/test/resources/tarallo.png new file mode 100644 index 0000000000..80a8274af1 Binary files /dev/null and b/extended/src/test/resources/tarallo.png differ