diff --git a/mira/dkg/client.py b/mira/dkg/client.py index 43b1b87d3..c1c9d4299 100644 --- a/mira/dkg/client.py +++ b/mira/dkg/client.py @@ -550,7 +550,7 @@ def get_grounder_terms(self, prefix: str) -> List["gilda.term.Term"]: query = dedent( f"""\ MATCH (n:{prefix}) - WHERE NOT n.obsolete and EXISTS(n.name) + WHERE NOT n.obsolete and n.name IS NOT NULL RETURN n.id, n.name, n.synonyms """ ) @@ -564,7 +564,7 @@ def get_grounder_terms(self, prefix: str) -> List["gilda.term.Term"]: def get_lexical(self) -> List[Entity]: """Get Lexical information for all entities.""" - query = f"MATCH (n) WHERE NOT n.obsolete and EXISTS(n.name) RETURN n" + query = f"MATCH (n) WHERE NOT n.obsolete and n.name IS NOT NULL RETURN n" return [Entity.from_data(n) for n, in self.query_tx(query) or []] def get_grounder(self, prefix: Union[str, List[str]]) -> "gilda.grounder.Grounder": @@ -660,7 +660,7 @@ def _search(self, query: str) -> List[Entity]: f"""\ MATCH (n) WHERE - EXISTS(n.name) + n.name IS NOT NULL AND ( replace(replace(toLower(n.name), '-', ''), '_', '') CONTAINS '{query_lower}' OR any( diff --git a/mira/dkg/templates/home.html b/mira/dkg/templates/home.html index ee860bbc3..b9f218783 100644 --- a/mira/dkg/templates/home.html +++ b/mira/dkg/templates/home.html @@ -27,6 +27,14 @@

Grounding

to see grounding results such as /api/ground/vaccine

+      {% if llm_ui %}
+      LLM UI
+      The LLM UI provides a user interface to extract an ODE model from an image
+      of a system of ODEs. Click here to access the LLM UI.
+      {% endif %}

Summary

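A note on the mira/dkg/client.py hunks at the top of this diff: `EXISTS(n.name)` was the Neo4j 4.x way of testing for a property and was removed in Neo4j 5, so the queries now use `n.name IS NOT NULL` instead. Below is a minimal sketch of the updated query shape, using the same `query_tx` helper the DKG client calls internally; the `ido` label is only an illustrative prefix.

```python
# Sketch only: assumes an existing Neo4jClient instance `client`, as constructed
# in mira/dkg/wsgi.py, and an illustrative node label.
query = """\
MATCH (n:ido)
WHERE NOT n.obsolete AND n.name IS NOT NULL
RETURN n.id, n.name, n.synonyms
"""
# The old predicate `EXISTS(n.name)` no longer parses on Neo4j 5, where exists()
# on a single property was removed in favour of IS NOT NULL.
for row in client.query_tx(query) or []:
    print(row)
```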
diff --git a/mira/dkg/ui.py b/mira/dkg/ui.py index 873b86790..04961a4b2 100644 --- a/mira/dkg/ui.py +++ b/mira/dkg/ui.py @@ -1,6 +1,5 @@ -import random - -from flask import Blueprint, render_template +from flask import Blueprint, render_template, current_app +from mira.sources.sympy_ode.proxies import OPEN_AI_CLIENT from .proxies import client, grounder @@ -14,11 +13,13 @@ def home(): """Render the home page.""" node_counter = client.get_node_counter() node_total = sum(node_counter.values()) + llm_ui = OPEN_AI_CLIENT in current_app.extensions return render_template( "home.html", number_terms=len(grounder.entries), node_counter=node_counter, node_total=node_total, + llm_ui=llm_ui, ) diff --git a/mira/dkg/wsgi.py b/mira/dkg/wsgi.py index 8fcccc03a..cb29ed102 100644 --- a/mira/dkg/wsgi.py +++ b/mira/dkg/wsgi.py @@ -108,6 +108,19 @@ def startup_event(): for curie, *parts in reader } + # If the OpenAI API key is set, enable the LLM UI + if api_key := os.environ.get("OPENAI_API_KEY"): + from mira.openai import OpenAIClient + from mira.sources.sympy_ode.llm_ui import llm_ui_blueprint + from mira.sources.sympy_ode.proxies import OPEN_AI_CLIENT + openai_client = OpenAIClient(api_key) + flask_app.extensions[OPEN_AI_CLIENT] = openai_client + flask_app.register_blueprint(llm_ui_blueprint) + else: + logger.warning( + "OpenAI API key not found in environment, LLM capabilities will be disabled" + ) + # Set MIRA_NEO4J_URL in the environment # to point this somewhere specific client = Neo4jClient() diff --git a/mira/openai/__init__.py b/mira/openai/__init__.py new file mode 100644 index 000000000..3dd784eed --- /dev/null +++ b/mira/openai/__init__.py @@ -0,0 +1,11 @@ +try: + import openai + from .client import OpenAIClient +except ImportError as ierr: + if 'openai' in str(ierr): + raise ImportError( + "The openai python package, which is needed to use the mira openai module, " + "is not installed. Run `pip install openai` to install it." + ) from ierr + else: + raise ierr diff --git a/mira/openai/client.py b/mira/openai/client.py new file mode 100644 index 000000000..12154b7ce --- /dev/null +++ b/mira/openai/client.py @@ -0,0 +1,124 @@ +import base64 +from typing import Literal + +from openai import OpenAI + + +ImageFmts = Literal["jpeg", "jpg", "png", "webp", "gif"] + + +class OpenAIClient: + + def __init__(self, api_key: str = None): + self.client = OpenAI(api_key=api_key) + + def run_chat_completion_with_image( + self, + message: str, + base64_image: str, + model: str = "gpt-4o-mini", + image_format: ImageFmts = "jpeg", + max_tokens: int = 2048, + ): + """Run the OpenAI chat completion with an image + + Parameters + ---------- + message : + The prompt to send for chat completion together with the image + base64_image : + The image data as a base64 string + model : + The model to use. The default is the gpt-4o-mini model. + image_format : + The format of the image. The default is "jpeg". Currently supports + "jpeg", "jpg", "png", "webp", "gif". GIF images cannot be animated. + max_tokens : + The maximum number of tokens to generate for chat completion. One + token is roughly one word in plain text, however it can be more per + word in some cases. The default is 2048. + + Returns + ------- + : + The first choice of the chat completion response from OpenAI.
+ """ + response = self.client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": message, + }, + { + "type": "image_url", + "image_url": { + # Supports PNG, JPEG, WEBP, non-animated GIF + "url": f"data:image/{image_format};base64,{base64_image}" + }, + }, + ], + } + ], + max_tokens=max_tokens, + ) + return response.choices[0] + + def run_chat_completion_with_image_url( + self, + message: str, + image_url: str, + model: str = "gpt-4o-mini", + max_tokens: int = 2048, + ): + """Run the OpenAI chat completion with an image URL + + Parameters + ---------- + message : + The prompt to send for chat completion together with the image + image_url : + The URL of the image + model : + The model to use. The default is the gpt-4o-mini model. + max_tokens : + The maximum number of tokens to generate for chat completion. One + token is roughly one word in plain text, however it can be more per + word in some cases. The default is 150. + + Returns + ------- + : + The response from OpenAI + """ + response = self.client.chat.completions.create( + model=model, + messages=[ + { + "role": "user", + "content": [ + { + "type": "text", + "text": message, + }, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + ], + } + ], + max_tokens=max_tokens, + ) + return response.choices[0] + + +# encode an image file +def encode_image(image_path: str): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") diff --git a/mira/sources/sympy_ode/app.py b/mira/sources/sympy_ode/app.py new file mode 100644 index 000000000..4bbc9d11a --- /dev/null +++ b/mira/sources/sympy_ode/app.py @@ -0,0 +1,29 @@ +""" +To run the LLM UI as a standalone app: +1. Set the OPENAI_API_KEY environment variable to your OpenAI API key. +2. Have the openai Python package installed (pip install openai). +3. Run with `python -m mira.sources.sympy_ode.app`. Optionally, pass in `debug` + as an argument to run in debug mode (will reload the server on changes). 
+""" +import os +from flask import Flask +from mira.openai import OpenAIClient +from .llm_ui import llm_ui_blueprint +from .proxies import OPEN_AI_CLIENT + +try: + os.environ["OPENAI_API_KEY"] +except KeyError: + raise ValueError("Set the OPENAI_API_KEY environment variable to run the app") + + +app = Flask(__name__) +app.extensions[OPEN_AI_CLIENT] = OpenAIClient() + +app.register_blueprint(llm_ui_blueprint) + + +if __name__ == "__main__": + import sys + debug = len(sys.argv) > 1 and sys.argv[1].lower() == "debug" + app.run(debug=debug, port=5000) diff --git a/mira/sources/sympy_ode/llm_ui.py b/mira/sources/sympy_ode/llm_ui.py new file mode 100644 index 000000000..8bc6092d9 --- /dev/null +++ b/mira/sources/sympy_ode/llm_ui.py @@ -0,0 +1,125 @@ +import base64 +from typing import List + +from flask import Blueprint, render_template, request +from sympy import latex + +from mira.openai import OpenAIClient +from mira.modeling import Model +from mira.metamodel import TemplateModel +from mira.modeling.ode import OdeModel +from mira.modeling.amr.petrinet import AMRPetriNetModel +from mira.sources.sympy_ode import template_model_from_sympy_odes + +from .proxies import openai_client + + +llm_ui_blueprint = Blueprint("llm", __name__, url_prefix="/llm") + +# Attach the template in this module to the blueprint +llm_ui_blueprint.template_folder = "templates" + + +def convert(base64_image, image_format, client: OpenAIClient, prompt: str = None): + if prompt is None: + prompt = """Transform these equations into a sympy representation based on the example style below + +```python +# Define time variable +t = sympy.symbols("t") + +# Define the time-dependent variables +S, E, I, R = sympy.symbols("S E I R", cls=sympy.Function) + +# Define the parameters +b, g, r = sympy.symbols("b g r") + +odes = [ + sympy.Eq(S(t).diff(t), - b * S(t) * I(t)), + sympy.Eq(E(t).diff(t), b * S(t) * I(t) - r * E(t)), + sympy.Eq(I(t).diff(t), r * E(t) - g * I(t)), + sympy.Eq(R(t).diff(t), g * I(t)) +] +``` + +Instead of using unicode characters, spell out in symbols in lowercase like theta, omega, etc. 
+Also, provide the code snippet only and no explanation.""" + + choice = client.run_chat_completion_with_image( + message=prompt, + base64_image=base64_image, + image_format=image_format, + ) + text_response = choice.message.content + if "```python" in text_response: + text_response = text_response.replace("```python", "", 1) + if "```" in text_response: + text_response = text_response.replace("```", "", 1) + return text_response + + +def execute_template_model_from_sympy_odes(ode_str) -> TemplateModel: + # FixMe, for now use `exec` on the code, but need to find a safer way to execute + # the code + # Import sympy just in case the code snippet does not import it + import sympy + odes: List[sympy.Eq] = None + # Execute the code and expose the `odes` variable to the local scope + local_dict = locals() + exec(ode_str, globals(), local_dict) + # `odes` should now be defined in the local scope + odes = local_dict.get("odes") + assert odes is not None, "The code should define a variable called `odes`" + return template_model_from_sympy_odes(odes) + + +@llm_ui_blueprint.route("/", methods=["GET", "POST"]) +def upload_image(): + result_text = None + ode_latex = None + petrinet_json_str = None + if request.method == "POST": + # Get the result_text from the form or the file uploaded + result_text = request.form.get("result_text") + file = request.files.get("file") + # If no file is selected or there is no result_text in the request + if not file and not result_text: + return render_template("index.html", error="No file part") + + # If a file is selected but the filename is empty and there is no result_text + if file is not None and file.filename == '' and not result_text: + return render_template("index.html", error="No selected file") + + # User uploaded a file but there is no result_text + if file and not result_text: + # Convert file to base64 + image_data = file.read() + base64_image = base64.b64encode(image_data).decode('utf-8') + # get the image format + image_format = file.content_type.split("/")[-1] + # Call the 'convert' function + result_text = convert( + base64_image=base64_image, + client=openai_client, + image_format=image_format + ) + + # User submitted a result_text for processing + elif result_text: + template_model = execute_template_model_from_sympy_odes(result_text) + # Get the OdeModel + om = OdeModel(model=Model(template_model=template_model), initialized=False) + ode_system = om.get_interpretable_kinetics() + # Make LaTeX representation of the ODE system + ode_latex = latex(ode_system) + + # Get the PetriNet JSON + petrinet_json_str = AMRPetriNetModel(Model(template_model)).to_json_str(indent=2) + + return render_template( + "index.html", + result_text=result_text, + sympy_input=result_text, + ode_latex=ode_latex, + petrinet_json=petrinet_json_str, + ) diff --git a/mira/sources/sympy_ode/proxies.py b/mira/sources/sympy_ode/proxies.py new file mode 100644 index 000000000..385bd7ff9 --- /dev/null +++ b/mira/sources/sympy_ode/proxies.py @@ -0,0 +1,8 @@ +from mira.openai import OpenAIClient + +from flask import current_app +from werkzeug.local import LocalProxy + + +OPEN_AI_CLIENT = "openai_client" +openai_client: OpenAIClient = LocalProxy(lambda: current_app.extensions[OPEN_AI_CLIENT]) diff --git a/mira/sources/sympy_ode/templates/index.html b/mira/sources/sympy_ode/templates/index.html new file mode 100644 index 000000000..ff981ebb8 --- /dev/null +++ b/mira/sources/sympy_ode/templates/index.html @@ -0,0 +1,184 @@ + + + + + + Equation image to MIRA model + + + + + + + + + + + + + 
[New template mira/sources/sympy_ode/templates/index.html (184 lines), summarized: a page titled "Equation image to MIRA model" with an image-upload form; it conditionally renders the extracted sympy code (result_text / sympy_input), the LaTeX form of the extracted ODE system (ode_latex), the generated Petri net JSON (petrinet_json), and any error message. A programmatic sketch of the same pipeline follows.]
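For reference, a minimal sketch of the pipeline the /llm view wires together, driven from a plain Python session instead of the browser. It only uses functions introduced or imported in this diff (`encode_image`, `OpenAIClient.run_chat_completion_with_image`, `template_model_from_sympy_odes`, `OdeModel`, `AMRPetriNetModel`); the image path and the shortened prompt are placeholders, and the `odes` list is written out by hand rather than `exec`'d from the model's reply as `execute_template_model_from_sympy_odes` does.

```python
import sympy

from mira.openai import OpenAIClient
from mira.openai.client import encode_image
from mira.sources.sympy_ode import template_model_from_sympy_odes
from mira.modeling import Model
from mira.modeling.ode import OdeModel
from mira.modeling.amr.petrinet import AMRPetriNetModel

# 1. Ask the model to transcribe an image of ODEs into sympy code
#    (same call convert() makes; assumes OPENAI_API_KEY is exported).
client = OpenAIClient()
choice = client.run_chat_completion_with_image(
    message="Transform these equations into a sympy representation ...",  # abbreviated prompt
    base64_image=encode_image("tests/ode_system.png"),
    image_format="png",
)
generated_code = choice.message.content  # python snippet that defines `odes`

# 2. Build a MIRA template model from a sympy ODE system; here the odes are
#    written by hand instead of executing the generated snippet.
t = sympy.symbols("t")
S, I, R = sympy.symbols("S I R", cls=sympy.Function)
b, g = sympy.symbols("b g")
odes = [
    sympy.Eq(S(t).diff(t), -b * S(t) * I(t)),
    sympy.Eq(I(t).diff(t), b * S(t) * I(t) - g * I(t)),
    sympy.Eq(R(t).diff(t), g * I(t)),
]
template_model = template_model_from_sympy_odes(odes)

# 3. The same two artifacts the page renders: LaTeX of the kinetics and the
#    AMR petrinet JSON.
ode_model = OdeModel(model=Model(template_model=template_model), initialized=False)
ode_latex = sympy.latex(ode_model.get_interpretable_kinetics())
petrinet_json = AMRPetriNetModel(Model(template_model=template_model)).to_json_str(indent=2)
print(ode_latex)
print(petrinet_json)
```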
diff --git a/setup.cfg b/setup.cfg index 1680ee3bd..eff9dbaf4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -86,6 +86,7 @@ web = bioregistry scipy numpy + openai uvicorn = uvicorn gunicorn = @@ -119,6 +120,8 @@ biomodels = pystow viz = pygraphviz +llm = + openai [mypy] plugins = pydantic.mypy diff --git a/tests/ode_system.png b/tests/ode_system.png new file mode 100644 index 000000000..bb7734631 Binary files /dev/null and b/tests/ode_system.png differ diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py new file mode 100644 index 000000000..36d7fa6cc --- /dev/null +++ b/tests/test_openai_client.py @@ -0,0 +1,87 @@ +import base64 +import os +import unittest +from pathlib import Path + +import requests + +from mira.openai import OpenAIClient + + +HERE = Path(__file__).parent.resolve() + + +@unittest.skipIf(os.environ.get("GITHUB_ACTIONS") is not None, + reason="Meant to be run locally") +@unittest.skipIf(os.environ.get("OPENAI_API_KEY") is None, + reason="Need OPENAI_API_KEY environment variable to run") +def test_explain_image(): + bananas_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9b/Cavendish_Banana_DS.jpg/640px-Cavendish_Banana_DS.jpg" + res = requests.get(bananas_url) + base64_image = base64.b64encode(res.content).decode("utf-8") + + client = OpenAIClient() + response = client.run_chat_completion_with_image( + message="What is in this image?", + base64_image=base64_image, + image_format="jpeg", + ) + assert "banana" in response.message.content, response.message.content + + +@unittest.skipIf(os.environ.get("GITHUB_ACTIONS") is not None, + reason="Meant to be run locally") +@unittest.skipIf(os.environ.get("OPENAI_API_KEY") is None, + reason="Need OPENAI_API_KEY environment variable to run") +def test_explain_image_url(): + bananas_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9b/Cavendish_Banana_DS.jpg/640px-Cavendish_Banana_DS.jpg" + client = OpenAIClient() + response = client.run_chat_completion_with_image_url( + message="What is in this image?", + image_url=bananas_url, + ) + assert "banana" in response.message.content, response.message.content + + +@unittest.skipIf(os.environ.get("GITHUB_ACTIONS") is not None, + reason="Meant to be run locally") +@unittest.skipIf(os.environ.get("OPENAI_API_KEY") is None, + reason="Need OPENAI_API_KEY environment variable to run") +def test_extract_odes(): + equations_image = HERE / "ode_system.png" + prompt = """Transform these equations into a sympy representation based on the example style below + +```python +# Define time variable +t = sympy.symbols("t") + +# Define the time-dependent variables +S, E, I, R = sympy.symbols("S E I R", cls=sympy.Function) + +# Define the parameters +b, g, r = sympy.symbols("b g r") + +odes = [ + sympy.Eq(S(t).diff(t), - b * S(t) * I(t)), + sympy.Eq(E(t).diff(t), b * S(t) * I(t) - r * E(t)), + sympy.Eq(I(t).diff(t), r * E(t) - g * I(t)), + sympy.Eq(R(t).diff(t), g * I(t)) +] +``` + +Instead of using unicode characters, spell out in symbols in lowercase like theta, omega, etc. + """ + client = OpenAIClient() + + # Load image and base64 encode it + with open(equations_image, "rb") as f: + base64_image = base64.b64encode(f.read()).decode("utf-8") + + response = client.run_chat_completion_with_image( + message=prompt, + image_format="png", + base64_image=base64_image, + max_tokens=1024 * 8, + ) + + print(response.message.content)
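Finally, assuming the standalone app from mira/sources/sympy_ode/app.py is running locally (e.g. after `pip install -e .[llm]` from a checkout and exporting OPENAI_API_KEY), the /llm form endpoint can also be driven with `requests`. This is only a sketch: the `file` and `result_text` field names come from upload_image() above, and the responses are rendered HTML pages rather than JSON.

```python
import requests

BASE = "http://localhost:5000/llm/"  # llm_ui_blueprint is mounted under /llm

# Step 1: upload an image of the ODE system; the returned page echoes the
# extracted sympy code back in its result_text / sympy_input fields.
with open("tests/ode_system.png", "rb") as fh:
    first_page = requests.post(
        BASE, files={"file": ("ode_system.png", fh, "image/png")}
    )

# Step 2: post the (possibly hand-corrected) sympy code back as result_text;
# the rendered page then contains the LaTeX ODE system and the petrinet JSON.
sympy_code = "..."  # paste the snippet extracted in step 1
second_page = requests.post(BASE, data={"result_text": sympy_code})
print(second_page.status_code)
```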