Merge pull request #413 from gyorilab/image-to-sympy
Image to sympy initial implementation
bgyori authored Jan 3, 2025
2 parents f302c6c + 37cab43 commit 63d4308
Showing 13 changed files with 599 additions and 6 deletions.
6 changes: 3 additions & 3 deletions mira/dkg/client.py
@@ -550,7 +550,7 @@ def get_grounder_terms(self, prefix: str) -> List["gilda.term.Term"]:
query = dedent(
f"""\
MATCH (n:{prefix})
-            WHERE NOT n.obsolete and EXISTS(n.name)
+            WHERE NOT n.obsolete and n.name IS NOT NULL
RETURN n.id, n.name, n.synonyms
"""
)
@@ -564,7 +564,7 @@ def get_grounder_terms(self, prefix: str) -> List["gilda.term.Term"]:

def get_lexical(self) -> List[Entity]:
"""Get Lexical information for all entities."""
query = f"MATCH (n) WHERE NOT n.obsolete and EXISTS(n.name) RETURN n"
query = f"MATCH (n) WHERE NOT n.obsolete and n.name IS NOT NULL RETURN n"
return [Entity.from_data(n) for n, in self.query_tx(query) or []]

def get_grounder(self, prefix: Union[str, List[str]]) -> "gilda.grounder.Grounder":
@@ -660,7 +660,7 @@ def _search(self, query: str) -> List[Entity]:
f"""\
MATCH (n)
WHERE
-                EXISTS(n.name)
+                n.name IS NOT NULL
AND (
replace(replace(toLower(n.name), '-', ''), '_', '') CONTAINS '{query_lower}'
OR any(
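All three query edits in this file make the same substitution: the `EXISTS(n.name)` function syntax for property-existence checks was deprecated in Neo4j 4.x and removed in Neo4j 5, where the `n.name IS NOT NULL` predicate is the supported form. A minimal sketch of the updated pattern; the helper name and the `askemo` label are illustrative, only the `WHERE` clause is taken from the diff:

```python
from textwrap import dedent


def build_named_node_query(prefix: str) -> str:
    """Return a Cypher query for named, non-obsolete nodes with a given label.

    ``n.name IS NOT NULL`` replaces the ``EXISTS(n.name)`` function syntax
    that was removed in Neo4j 5.
    """
    return dedent(
        f"""\
        MATCH (n:{prefix})
        WHERE NOT n.obsolete AND n.name IS NOT NULL
        RETURN n.id, n.name, n.synonyms
        """
    )


# Example: build_named_node_query("askemo") matches named, non-obsolete
# nodes labeled askemo.
```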
8 changes: 8 additions & 0 deletions mira/dkg/templates/home.html
@@ -27,6 +27,14 @@ <h4>Grounding</h4>
to see grounding results such as <a href="/api/ground/vaccine">
<code>/api/ground/vaccine</code></a>
</p>
+{% if llm_ui %}
+<h3>LLM UI</h3>
+<p>
+    The LLM UI provides a user interface to extract an ODE model from an image
+    of a system of ODEs. Click <a href="{{ url_for("llm.upload_image") }}">here</a>
+    to access the LLM UI.
+</p>
+{% endif %}
<h3>Summary</h3>
<div>
<p>
7 changes: 4 additions & 3 deletions mira/dkg/ui.py
@@ -1,6 +1,5 @@
import random

-from flask import Blueprint, render_template
+from flask import Blueprint, render_template, current_app
+from mira.sources.sympy_ode.proxies import OPEN_AI_CLIENT

from .proxies import client, grounder

@@ -14,11 +13,13 @@ def home():
"""Render the home page."""
node_counter = client.get_node_counter()
node_total = sum(node_counter.values())
+    llm_ui = OPEN_AI_CLIENT in current_app.extensions
return render_template(
"home.html",
number_terms=len(grounder.entries),
node_counter=node_counter,
node_total=node_total,
+        llm_ui=llm_ui,
)


13 changes: 13 additions & 0 deletions mira/dkg/wsgi.py
@@ -108,6 +108,19 @@ def startup_event():
for curie, *parts in reader
}

+    # If the OpenAI API key is set, enable the LLM UI
+    if api_key := os.environ.get("OPENAI_API_KEY"):
+        from mira.openai import OpenAIClient
+        from mira.sources.sympy_ode.llm_ui import llm_ui_blueprint
+        from mira.sources.sympy_ode.proxies import OPEN_AI_CLIENT
+        openai_client = OpenAIClient(api_key)
+        flask_app.extensions[OPEN_AI_CLIENT] = openai_client
+        flask_app.register_blueprint(llm_ui_blueprint)
+    else:
+        logger.warning(
+            "OpenAI API key not found in environment, LLM capabilities will be disabled"
+        )

# Set MIRA_NEO4J_URL in the environment
# to point this somewhere specific
client = Neo4jClient()
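With this change the LLM UI is opt-in at startup: the blueprint and the OpenAI client are only wired up when `OPENAI_API_KEY` is present, so deployments without a key keep the rest of the DKG service intact. A sketch of the same wiring factored into a helper (the function name is hypothetical, the calls mirror the diff):

```python
import os

from flask import Flask


def enable_llm_ui(flask_app: Flask) -> bool:
    """Register the LLM UI blueprint iff an OpenAI API key is configured."""
    if api_key := os.environ.get("OPENAI_API_KEY"):
        from mira.openai import OpenAIClient
        from mira.sources.sympy_ode.llm_ui import llm_ui_blueprint
        from mira.sources.sympy_ode.proxies import OPEN_AI_CLIENT

        # Store the client where the LocalProxy in proxies.py will find it
        flask_app.extensions[OPEN_AI_CLIENT] = OpenAIClient(api_key)
        flask_app.register_blueprint(llm_ui_blueprint)
        return True
    return False
```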
11 changes: 11 additions & 0 deletions mira/openai/__init__.py
@@ -0,0 +1,11 @@
try:
    import openai
    from .client import OpenAIClient
except ImportError as ierr:
    if 'openai' in str(ierr):
        raise ImportError(
            "The openai Python package, which is needed to use the mira "
            "openai module, is not installed. Run `pip install openai` to "
            "install it."
        ) from ierr
    else:
        raise ierr
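The guarded import turns a bare `ModuleNotFoundError` for the optional `openai` dependency into an actionable message, while re-raising unrelated import failures (say, one raised inside `.client` itself) untouched. A quick sketch of what a caller sees, assuming the `openai` package is not installed:

```python
# Assuming the `openai` package is absent from the environment:
try:
    from mira.openai import OpenAIClient
except ImportError as err:
    # Points at the fix instead of a bare "No module named 'openai'"
    print(err)
```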
124 changes: 124 additions & 0 deletions mira/openai/client.py
@@ -0,0 +1,124 @@
import base64
from typing import Literal, Optional

from openai import OpenAI


ImageFmts = Literal["jpeg", "jpg", "png", "webp", "gif"]


class OpenAIClient:
    """A thin wrapper around the OpenAI API client for image-based chat
    completions."""

    def __init__(self, api_key: Optional[str] = None):
        self.client = OpenAI(api_key=api_key)

def run_chat_completion_with_image(
self,
message: str,
base64_image: str,
model: str = "gpt-4o-mini",
image_format: ImageFmts = "jpeg",
max_tokens: int = 2048,
):
"""Run the OpenAI chat completion with an image
Parameters
----------
message :
The prompt to send for chat completion together with the image
base64_image :
The image data as a base64 string
model :
The model to use. The default is the gpt-4o-mini model.
image_format :
The format of the image. The default is "jpeg". Currently supports
"jpeg", "jpg", "png", "webp", "gif". GIF images cannot be animated.
        max_tokens :
            The maximum number of tokens to generate for chat completion. One
            token is roughly one word in plain text, however it can be more per
            word in some cases. The default is 2048.

        Returns
        -------
        :
            The first choice from the OpenAI chat completion response.
"""
response = self.client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": message,
},
{
"type": "image_url",
"image_url": {
# Supports PNG, JPEG, WEBP, non-animated GIF
"url": f"data:image/{image_format};base64,{base64_image}"
},
},
],
}
],
max_tokens=max_tokens,
)
return response.choices[0]

def run_chat_completion_with_image_url(
self,
message: str,
image_url: str,
model: str = "gpt-4o-mini",
max_tokens: int = 2048,
):
"""Run the OpenAI chat completion with an image URL
Parameters
----------
message :
The prompt to send for chat completion together with the image
image_url :
The URL of the image
model :
The model to use. The default is the gpt-4o-mini model.
        max_tokens :
            The maximum number of tokens to generate for chat completion. One
            token is roughly one word in plain text, however it can be more per
            word in some cases. The default is 2048.

        Returns
        -------
        :
            The first choice from the OpenAI chat completion response.
"""
response = self.client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": message,
},
{
"type": "image_url",
"image_url": {
"url": image_url,
},
},
],
}
],
max_tokens=max_tokens,
)
return response.choices[0]


def encode_image(image_path: str) -> str:
    """Encode an image file as a base64 string."""
    with open(image_path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
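A minimal usage sketch of the client; the image path is hypothetical, and `OPENAI_API_KEY` is assumed to be set (the OpenAI SDK falls back to it when `api_key` is `None`):

```python
from mira.openai.client import OpenAIClient, encode_image

client = OpenAIClient()  # picks up OPENAI_API_KEY from the environment

# "ode_system.png" is a placeholder image of a system of ODEs
base64_image = encode_image("ode_system.png")
choice = client.run_chat_completion_with_image(
    message="Transcribe the equations in this image into sympy code.",
    base64_image=base64_image,
    image_format="png",
)
print(choice.message.content)
```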
29 changes: 29 additions & 0 deletions mira/sources/sympy_ode/app.py
@@ -0,0 +1,29 @@
"""
To run the LLM UI as a standalone app:
1. Set the OPENAI_API_KEY environment variable to your OpenAI API key.
2. Have the openai Python package installed (pip install openai).
3. Run with `python -m mira.sources.sympy_ode.app`. Optionally, pass in `debug`
as an argument to run in debug mode (will reload the server on changes).
"""
import os
from flask import Flask
from mira.openai import OpenAIClient
from .llm_ui import llm_ui_blueprint
from .proxies import OPEN_AI_CLIENT

try:
os.environ["OPENAI_API_KEY"]
except KeyError:
raise ValueError("Set the OPENAI_API_KEY environment variable to run the app")


app = Flask(__name__)
app.extensions[OPEN_AI_CLIENT] = OpenAIClient()

app.register_blueprint(llm_ui_blueprint)


if __name__ == "__main__":
import sys
debug = len(sys.argv) > 1 and sys.argv[1].lower() == "debug"
app.run(debug=debug, port=5000)
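Equivalently, as a sketch, the app can be launched from Python; the key value below is a placeholder, and it must be set before the import because the module checks for it at import time:

```python
import os

# Placeholder key; set before importing, since the module checks it at import
os.environ.setdefault("OPENAI_API_KEY", "sk-...")

from mira.sources.sympy_ode.app import app

app.run(debug=True, port=5000)
```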
125 changes: 125 additions & 0 deletions mira/sources/sympy_ode/llm_ui.py
@@ -0,0 +1,125 @@
import base64
from typing import List, Optional

from flask import Blueprint, render_template, request
from sympy import latex

from mira.openai import OpenAIClient
from mira.modeling import Model
from mira.metamodel import TemplateModel
from mira.modeling.ode import OdeModel
from mira.modeling.amr.petrinet import AMRPetriNetModel
from mira.sources.sympy_ode import template_model_from_sympy_odes

from .proxies import openai_client


llm_ui_blueprint = Blueprint("llm", __name__, url_prefix="/llm")

# Attach the template in this module to the blueprint
llm_ui_blueprint.template_folder = "templates"


def convert(base64_image: str, image_format: str, client: OpenAIClient, prompt: Optional[str] = None):
if prompt is None:
prompt = """Transform these equations into a sympy representation based on the example style below
```python
# Define time variable
t = sympy.symbols("t")
# Define the time-dependent variables
S, E, I, R = sympy.symbols("S E I R", cls=sympy.Function)
# Define the parameters
b, g, r = sympy.symbols("b g r")
odes = [
sympy.Eq(S(t).diff(t), - b * S(t) * I(t)),
sympy.Eq(E(t).diff(t), b * S(t) * I(t) - r * E(t)),
sympy.Eq(I(t).diff(t), r * E(t) - g * I(t)),
sympy.Eq(R(t).diff(t), g * I(t))
]
```
Instead of using unicode characters, spell out the symbols in lowercase like theta, omega, etc.
Also, provide the code snippet only and no explanation."""

choice = client.run_chat_completion_with_image(
message=prompt,
base64_image=base64_image,
image_format=image_format,
)
text_response = choice.message.content
if "```python" in text_response:
text_response = text_response.replace("```python", "", 1)
if "```" in text_response:
text_response = text_response.replace("```", "", 1)
return text_response


def execute_template_model_from_sympy_odes(ode_str) -> TemplateModel:
    # FIXME: for now we use `exec` on the code, but we need to find a safer
    #  way to execute it
# Import sympy just in case the code snippet does not import it
import sympy
    odes: Optional[List[sympy.Eq]] = None
# Execute the code and expose the `odes` variable to the local scope
local_dict = locals()
exec(ode_str, globals(), local_dict)
# `odes` should now be defined in the local scope
odes = local_dict.get("odes")
assert odes is not None, "The code should define a variable called `odes`"
return template_model_from_sympy_odes(odes)


@llm_ui_blueprint.route("/", methods=["GET", "POST"])
def upload_image():
result_text = None
ode_latex = None
petrinet_json_str = None
if request.method == "POST":
# Get the result_text from the form or the file uploaded
result_text = request.form.get("result_text")
file = request.files.get("file")
        # If neither a file nor result_text is present in the request
        if not file and not result_text:
            return render_template("index.html", error="No file or text provided")

# If a file is selected but the filename is empty and there is no result_text
if file is not None and file.filename == '' and not result_text:
return render_template("index.html", error="No selected file")

# User uploaded a file but there is no result_text
if file and not result_text:
# Convert file to base64
image_data = file.read()
base64_image = base64.b64encode(image_data).decode('utf-8')
# get the image format
image_format = file.content_type.split("/")[-1]
# Call the 'convert' function
result_text = convert(
base64_image=base64_image,
client=openai_client,
image_format=image_format
)

# User submitted a result_text for processing
elif result_text:
template_model = execute_template_model_from_sympy_odes(result_text)
# Get the OdeModel
om = OdeModel(model=Model(template_model=template_model), initialized=False)
ode_system = om.get_interpretable_kinetics()
# Make LaTeX representation of the ODE system
ode_latex = latex(ode_system)

# Get the PetriNet JSON
petrinet_json_str = AMRPetriNetModel(Model(template_model)).to_json_str(indent=2)

return render_template(
"index.html",
result_text=result_text,
sympy_input=result_text,
ode_latex=ode_latex,
petrinet_json=petrinet_json_str,
)
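To see what the `result_text` branch of `upload_image` produces, here is a sketch that pushes the SEIR snippet from the prompt through the same pipeline outside Flask, using only calls that appear in this diff:

```python
import sympy
from sympy import latex

from mira.modeling import Model
from mira.modeling.ode import OdeModel
from mira.modeling.amr.petrinet import AMRPetriNetModel
from mira.sources.sympy_ode import template_model_from_sympy_odes

# The SEIR example from the prompt above
t = sympy.symbols("t")
S, E, I, R = sympy.symbols("S E I R", cls=sympy.Function)
b, g, r = sympy.symbols("b g r")
odes = [
    sympy.Eq(S(t).diff(t), -b * S(t) * I(t)),
    sympy.Eq(E(t).diff(t), b * S(t) * I(t) - r * E(t)),
    sympy.Eq(I(t).diff(t), r * E(t) - g * I(t)),
    sympy.Eq(R(t).diff(t), g * I(t)),
]

template_model = template_model_from_sympy_odes(odes)
om = OdeModel(model=Model(template_model=template_model), initialized=False)
print(latex(om.get_interpretable_kinetics()))  # LaTeX for the ODE system
print(AMRPetriNetModel(Model(template_model)).to_json_str(indent=2))  # PetriNet JSON
```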
8 changes: 8 additions & 0 deletions mira/sources/sympy_ode/proxies.py
@@ -0,0 +1,8 @@
from mira.openai import OpenAIClient

from flask import current_app
from werkzeug.local import LocalProxy


OPEN_AI_CLIENT = "openai_client"
openai_client: OpenAIClient = LocalProxy(lambda: current_app.extensions[OPEN_AI_CLIENT])
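`LocalProxy` defers the lookup to access time, so modules can import `openai_client` before any app exists; all it needs is an application context with a client stored under the `OPEN_AI_CLIENT` key. A minimal sketch of that contract, assuming `OPENAI_API_KEY` is set:

```python
from flask import Flask

from mira.openai import OpenAIClient
from mira.sources.sympy_ode.proxies import OPEN_AI_CLIENT, openai_client

app = Flask(__name__)
app.extensions[OPEN_AI_CLIENT] = OpenAIClient()  # assumes OPENAI_API_KEY is set

with app.app_context():
    # The proxy resolves to the client stored on the current app
    assert openai_client._get_current_object() is app.extensions[OPEN_AI_CLIENT]
```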
