diff --git a/apps/audit/tests.py b/apps/audit/tests.py index a09f27686..5155f6ae6 100644 --- a/apps/audit/tests.py +++ b/apps/audit/tests.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass from unittest import mock from field_audit.models import USER_TYPE_REQUEST @@ -19,7 +20,7 @@ def test_change_context_returns_none_without_request(): def test_change_context_returns_none_without_request_with_team(): - with current_team(AuthedRequest.Team()): + with current_team(Team()): context = AuditContextProvider().change_context(None) assert context["user_type"] != USER_TYPE_REQUEST assert context["team"] == 17 @@ -27,18 +28,18 @@ def test_change_context_returns_none_without_request_with_team(): def test_change_context_returns_value_for_unauthorized_req(): request = AuthedRequest(auth=False) - assert AuditContextProvider().change_context(request) == {} + assert "user_type" not in AuditContextProvider().change_context(request) def test_change_context_returns_value_for_unauthorized_team_req(): request = AuthedRequest(auth=False) - with current_team(AuthedRequest.Team()): + with current_team(Team()): assert AuditContextProvider().change_context(request) == {"team": 17} def test_change_context_returns_value_for_authorized_team_req(): request = AuthedRequest(auth=True) - with current_team(AuthedRequest.Team()): + with current_team(Team()): assert AuditContextProvider().change_context(request) == { "user_type": USER_TYPE_REQUEST, "username": "test@example.com", @@ -49,40 +50,47 @@ def test_change_context_returns_value_for_authorized_team_req(): @mock.patch("apps.audit.auditors._get_hijack_username", return_value="admin@example.com") def test_change_context_hijacked_request(_): request = AuthedRequest(session={"hijack_history": [1]}) - assert AuditContextProvider().change_context(request) == { - "user_type": USER_TYPE_REQUEST, - "username": "admin@example.com", - "as_username": request.user.username, - } + with current_team(None): + assert AuditContextProvider().change_context(request) == { + "user_type": USER_TYPE_REQUEST, + "username": "admin@example.com", + "as_username": request.user.username, + } @mock.patch("apps.audit.auditors._get_hijack_username", return_value=None) def test_change_context_hijacked_request__no_hijacked_user(_): request = AuthedRequest(session={"hijack_history": [1]}) - assert AuditContextProvider().change_context(request) == { - "user_type": USER_TYPE_REQUEST, - "username": "test@example.com", - } + with current_team(None): + assert AuditContextProvider().change_context(request) == { + "user_type": USER_TYPE_REQUEST, + "username": "test@example.com", + } def test_change_context_hijacked_request__bad_hijack_history(): request = AuthedRequest(session={"hijack_history": ["not a number"]}) - assert AuditContextProvider().change_context(request) == { - "user_type": USER_TYPE_REQUEST, - "username": "test@example.com", - } + with current_team(None): + assert AuditContextProvider().change_context(request) == { + "user_type": USER_TYPE_REQUEST, + "username": "test@example.com", + } class AuthedRequest: - class User: - username = "test@example.com" - is_authenticated = True - - class Team: - id = 17 - slug = "seventeen" - def __init__(self, auth=True, session=None): - self.user = self.User() + self.user = User() self.session = session or {} self.user.is_authenticated = auth + + +@dataclass +class User: + username: str = "test@example.com" + is_authenticated: str = True + + +@dataclass +class Team: + id: int = 17 + slug: str = "seventeen" diff --git a/apps/events/tests/test_actions.py 
b/apps/events/tests/test_actions.py index ce9cc3e7a..30b0fd3cb 100644 --- a/apps/events/tests/test_actions.py +++ b/apps/events/tests/test_actions.py @@ -65,6 +65,7 @@ def test_end_conversation_runs_pipeline(session, pipeline): "end": {"message": f"human: {input}"}, }, "experiment_session": session.id, + "shared_state": {"user_input": output_message, "outputs": {"start": output_message, "end": output_message}}, } ) assert pipeline.runs.count() == 1 diff --git a/apps/pipelines/admin.py b/apps/pipelines/admin.py index 842f33e03..d5537f827 100644 --- a/apps/pipelines/admin.py +++ b/apps/pipelines/admin.py @@ -1,3 +1,6 @@ +import json + +from django import forms from django.contrib import admin from .models import Node, Pipeline, PipelineChatHistory, PipelineChatMessages, PipelineRun @@ -13,8 +16,18 @@ class PipelineNodeInline(admin.TabularInline): extra = 0 +class PrettyJSONEncoder(json.JSONEncoder): + def __init__(self, *args, indent, sort_keys, **kwargs): + super().__init__(*args, indent=4, sort_keys=True, **kwargs) + + +class PipelineAdminForm(forms.ModelForm): + data = forms.JSONField(encoder=PrettyJSONEncoder) + + @admin.register(Pipeline) class PipelineAdmin(admin.ModelAdmin): + form = PipelineAdminForm inlines = [PipelineNodeInline, PipelineRunInline] diff --git a/apps/pipelines/migrations/0012_auto_20250116_1508.py b/apps/pipelines/migrations/0012_auto_20250116_1508.py new file mode 100644 index 000000000..51d38a5c4 --- /dev/null +++ b/apps/pipelines/migrations/0012_auto_20250116_1508.py @@ -0,0 +1,49 @@ +# Generated by Django 5.1.2 on 2025-01-16 15:08 + +from django.db import migrations + +def _add_default_name_to_nodes(apps, schema_editor): + Pipeline = apps.get_model("pipelines", "Pipeline") + Node = apps.get_model("pipelines", "Node") + + for pipeline in Pipeline.objects.all(): + nodes = pipeline.node_set.all() + data = pipeline.data + for node in nodes: + if "name" in node.params: + continue + + name = node.flow_id + + if node.type == "StartNode": + name = "start" + if node.type == "EndNode": + name = "end" + + node.params["name"] = name + node.save() + + +def _remove_node_names(apps, schema_editor): + Pipeline = apps.get_model("pipelines", "Pipeline") + Node = apps.get_model("pipelines", "Node") + + for pipeline in Pipeline.objects.all(): + nodes = pipeline.node_set.all() + data = pipeline.data + + for node in nodes: + if "name" in node.params: + del node.params["name"] + node.save() + + +class Migration(migrations.Migration): + + dependencies = [ + ('pipelines', '0011_migrate_assistant_id'), + ] + + operations = [ + migrations.RunPython(_add_default_name_to_nodes, reverse_code=_remove_node_names) + ] diff --git a/apps/pipelines/models.py b/apps/pipelines/models.py index c8ef19c95..a8bdd854b 100644 --- a/apps/pipelines/models.py +++ b/apps/pipelines/models.py @@ -1,3 +1,4 @@ +from collections import defaultdict from collections.abc import Iterator from datetime import datetime from functools import cached_property @@ -88,14 +89,14 @@ def create_default(cls, team): "x": -200, "y": 200, }, - data=FlowNodeData(id=start_id, type=StartNode.__name__), + data=FlowNodeData(id=start_id, type=StartNode.__name__, params={"name": "start"}), ) end_id = str(uuid4()) end_node = FlowNode( id=end_id, type="endNode", position={"x": 1000, "y": 200}, - data=FlowNodeData(id=end_id, type=EndNode.__name__), + data=FlowNodeData(id=end_id, type=EndNode.__name__, params={"name": "end"}), ) default_nodes = [start_node.model_dump(), end_node.model_dump()] new_pipeline = cls.objects.create( @@ -136,14 
+137,24 @@ def validate(self) -> dict: from apps.pipelines.graph import PipelineGraph from apps.pipelines.nodes import nodes as pipeline_nodes - errors = {} - - for node in self.node_set.all(): + errors = defaultdict(dict) + nodes = self.node_set.all() + for node in nodes: node_class = getattr(pipeline_nodes, node.type) try: node_class.model_validate(node.params) except pydantic.ValidationError as e: errors[node.flow_id] = {err["loc"][0]: err["msg"] for err in e.errors()} + + name_to_flow_id = defaultdict(list) + for node in nodes: + name_to_flow_id[node.params.get("name")].append(node.flow_id) + + for name, flow_ids in name_to_flow_id.items(): + if len(flow_ids) > 1: + for flow_id in flow_ids: + errors[flow_id].update({"name": "All node names must be unique"}) + if errors: return {"node": errors} diff --git a/apps/pipelines/nodes/base.py b/apps/pipelines/nodes/base.py index 1140f280b..3cc86719e 100644 --- a/apps/pipelines/nodes/base.py +++ b/apps/pipelines/nodes/base.py @@ -6,7 +6,7 @@ from typing import Annotated, Any, Literal, Self from langchain_core.runnables import RunnableConfig -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from pydantic.config import JsonDict from apps.experiments.models import ExperimentSession @@ -27,11 +27,25 @@ def add_messages(left: dict, right: dict): return output +def add_shared_state_messages(left: dict, right: dict): + output = {**left} + try: + output["outputs"].update(right["outputs"]) + except KeyError: + output["outputs"] = right.get("outputs", {}) + for key, value in right.items(): + if key != "outputs": + output[key] = value + + return output + + class PipelineState(dict): messages: Annotated[Sequence[Any], operator.add] outputs: Annotated[dict, add_messages] experiment_session: ExperimentSession pipeline_version: int + shared_state: Annotated[dict, add_shared_state_messages] ai_message_id: int | None = None message_metadata: dict | None = None attachments: list | None = None @@ -44,8 +58,9 @@ def json_safe(self): return copy @classmethod - def from_node_output(cls, node_id: str, output: Any = None, **kwargs) -> Self: + def from_node_output(cls, node_name: str, node_id: str, output: Any = None, **kwargs) -> Self: kwargs["outputs"] = {node_id: {"message": output}} + kwargs["shared_state"] = {"outputs": {node_name: output}} if output is not None: kwargs["messages"] = [output] return cls(**kwargs) @@ -77,6 +92,7 @@ def _process(self, state: PipelineState) -> PipelineState: model_config = ConfigDict(arbitrary_types_allowed=True) _config: RunnableConfig | None = None + name: str = Field(title="Node Name", json_schema_extra={"ui:widget": "node_name"}) def process(self, node_id: str, incoming_edges: list, state: PipelineState, config) -> PipelineState: self._config = config @@ -92,6 +108,7 @@ def process(self, node_id: str, incoming_edges: list, state: PipelineState, conf break else: # This is the first node in the graph input = state["messages"][-1] + state["shared_state"]["user_input"] = input return self._process(input=input, state=state, node_id=node_id) def process_conditional(self, state: PipelineState, node_id: str | None = None) -> str: @@ -101,7 +118,7 @@ def process_conditional(self, state: PipelineState, node_id: str | None = None) state["outputs"][node_id]["output_handle"] = output_handle return conditional_branch - def _process(self, input: str, state: PipelineState, node_id: str) -> str: + def _process(self, input: str, state: PipelineState, node_id: str) -> 
PipelineState: """The method that executes node specific functionality""" raise NotImplementedError @@ -169,6 +186,14 @@ class NodeSchema(BaseModel): can_add: bool = None deprecated: bool = False deprecation_message: str = None + field_order: list[str] = Field( + None, + description=( + "The order of the fields in the UI. " + "Any field not in this list will be appended to the end. " + "The 'name' field is always displayed first regardless of its position in this list." + ), + ) @model_validator(mode="after") def update_metadata_fields(self) -> Self: @@ -190,6 +215,8 @@ def __call__(self, schema: JsonDict): schema["ui:deprecated"] = self.deprecated if self.deprecated and self.deprecation_message: schema["ui:deprecation_message"] = self.deprecation_message + if self.field_order: + schema["ui:order"] = self.field_order def deprecated_node(cls=None, *, message=None): diff --git a/apps/pipelines/nodes/helpers.py b/apps/pipelines/nodes/helpers.py index c1d7443b5..9ccbbcdeb 100644 --- a/apps/pipelines/nodes/helpers.py +++ b/apps/pipelines/nodes/helpers.py @@ -1,10 +1,12 @@ from contextlib import contextmanager +from typing import Self +from django.contrib.contenttypes.models import ContentType from django.db import transaction from apps.channels.models import ChannelPlatform, ExperimentChannel from apps.chat.models import Chat -from apps.experiments.models import ConsentForm, Experiment, ExperimentSession, Participant +from apps.experiments.models import ConsentForm, Experiment, ExperimentSession, Participant, ParticipantData from apps.teams.models import Team from apps.teams.utils import current_team from apps.users.models import CustomUser @@ -31,3 +33,35 @@ def temporary_session(team: Team, user_id: int): ) yield experiment_session transaction.set_rollback(True) + + +class ParticipantDataProxy: + """Allows multiple access without needing to re-fetch from the DB""" + + @classmethod + def from_state(cls, pipeline_state) -> Self: + # using `.get` here for the sake of tests. 
In practice the session should always be present + return cls(pipeline_state.get("experiment_session")) + + def __init__(self, experiment_session): + self.session = experiment_session + self._participant_data = None + + def _get_db_object(self): + if not self._participant_data: + content_type = ContentType.objects.get_for_model(Experiment) + self._participant_data, _ = ParticipantData.objects.get_or_create( + participant_id=self.session.participant_id, + content_type=content_type, + object_id=self.session.experiment_id, + team_id=self.session.experiment.team_id, + ) + return self._participant_data + + def get(self): + return self._get_db_object().data + + def set(self, data): + participant_data = self._get_db_object() + participant_data.data = data + participant_data.save(update_fields=["data"]) diff --git a/apps/pipelines/nodes/nodes.py b/apps/pipelines/nodes/nodes.py index 7625da575..80a491889 100644 --- a/apps/pipelines/nodes/nodes.py +++ b/apps/pipelines/nodes/nodes.py @@ -5,9 +5,9 @@ from typing import Literal import tiktoken -from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ValidationError from django.core.validators import validate_email +from django.db.models import TextChoices from jinja2 import meta from jinja2.sandbox import SandboxedEnvironment from langchain_core.messages import BaseMessage @@ -25,7 +25,7 @@ from apps.chat.agent.tools import get_node_tools from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history from apps.chat.models import ChatMessageType -from apps.experiments.models import Experiment, ExperimentSession, ParticipantData +from apps.experiments.models import ExperimentSession, ParticipantData from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError from apps.pipelines.models import Node, PipelineChatHistory, PipelineChatHistoryTypes from apps.pipelines.nodes.base import ( @@ -37,6 +37,7 @@ Widgets, deprecated_node, ) +from apps.pipelines.nodes.helpers import ParticipantDataProxy from apps.pipelines.tasks import send_email_from_pipeline from apps.service_providers.exceptions import ServiceProviderConfigError from apps.service_providers.llm_service.adapters import AssistantAdapter, ChatAdapter @@ -82,7 +83,7 @@ def all_variables(in_): content = all_variables(input) template = SandboxedEnvironment().from_string(self.template_string) output = template.render(content) - return PipelineState.from_node_output(node_id=node_id, output=output) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=output) class LLMResponseMixin(BaseModel): @@ -185,7 +186,7 @@ class LLMResponse(PipelineNode, LLMResponseMixin): def _process(self, input, node_id: str, **kwargs) -> PipelineState: llm = self.get_chat_model() output = llm.invoke(input, config=self._config) - return PipelineState.from_node_output(node_id=node_id, output=output.content) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=output.content) class LLMResponseWithPrompt(LLMResponse, HistoryMixin): @@ -255,7 +256,7 @@ def _process(self, input, state: PipelineState, node_id: str) -> PipelineState: # Invoke runnable result = chat.invoke(input=input) - return PipelineState.from_node_output(node_id=node_id, output=result.output) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=result.output) def tools_enabled(self) -> bool: return len(self.tools) > 0 or len(self.custom_actions) > 0 @@ -283,7 +284,7 @@ def _process(self, 
input, node_id: str, **kwargs) -> PipelineState: send_email_from_pipeline.delay( recipient_list=self.recipient_list.split(","), subject=self.subject, message=input ) - return PipelineState.from_node_output(node_id=node_id, output=None) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=None) class Passthrough(PipelineNode): @@ -294,18 +295,20 @@ class Passthrough(PipelineNode): def _process(self, input, state: PipelineState, node_id: str) -> PipelineState: if self.logger: self.logger.debug(f"Returning input: '{input}' without modification", input=input, output=input) - return PipelineState.from_node_output(node_id=node_id, output=input) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=input) class StartNode(Passthrough): """The start of the pipeline""" + name: str = "start" model_config = ConfigDict(json_schema_extra=NodeSchema(label="Start", flow_node_type="startNode")) class EndNode(Passthrough): """The end of the pipeline""" + name: str = "end" model_config = ConfigDict(json_schema_extra=NodeSchema(label="End", flow_node_type="endNode")) @@ -327,16 +330,7 @@ def get_output_map(self): return {"output_0": "true", "output_1": "false"} -class RouterNode(Passthrough, HistoryMixin): - """Routes the input to one of the linked nodes""" - - model_config = ConfigDict(json_schema_extra=NodeSchema(label="Router")) - - prompt: str = Field( - default="You are an extremely helpful router", - min_length=1, - json_schema_extra=UiSchema(widget=Widgets.expandable_text), - ) +class RouterMixin(BaseModel): num_outputs: int = Field(2, json_schema_extra=UiSchema(widget=Widgets.none)) keywords: list[str] = Field(default_factory=list, json_schema_extra=UiSchema(widget=Widgets.keywords)) @@ -351,6 +345,36 @@ def ensure_keywords_exist(cls, value, info: FieldValidationInfo): return value[:num_outputs] # Ensure the number of keywords matches the number of outputs + def _get_keyword(self, result: str): + keyword = result.lower().strip() + if keyword in [k.lower() for k in self.keywords]: + return keyword.lower() + else: + return self.keywords[0].lower() + + def get_output_map(self): + """Returns a mapping of the form: + {"output_1": "keyword 1", "output_2": "keyword_2", ...} where keywords are defined by the user + """ + return {f"output_{output_num}": keyword.lower() for output_num, keyword in enumerate(self.keywords)} + + +class RouterNode(RouterMixin, Passthrough, HistoryMixin): + """Routes the input to one of the linked nodes using an LLM""" + + model_config = ConfigDict( + json_schema_extra=NodeSchema( + label="LLM Router", + field_order=["llm_provider_id", "llm_temperature", "history_type", "prompt", "keywords"], + ) + ) + + prompt: str = Field( + default="You are an extremely helpful router", + min_length=1, + json_schema_extra=UiSchema(widget=Widgets.expandable_text), + ) + def _process_conditional(self, state: PipelineState, node_id=None): prompt = OcsPromptTemplate.from_messages( [("system", self.prompt), MessagesPlaceholder("history", optional=True), ("human", "{input}")] @@ -367,23 +391,48 @@ def _process_conditional(self, state: PipelineState, node_id=None): chain = prompt | self.get_chat_model() result = chain.invoke(context, config=self._config) - keyword = self._get_keyword(result) + keyword = self._get_keyword(result.content) if session: self._save_history(session, node_id, node_input, keyword) return keyword - def _get_keyword(self, result): - keyword = result.content.lower().strip() - if keyword in [k.lower() for k in 
self.keywords]: - return keyword.lower() + +class StaticRouterNode(RouterMixin, Passthrough): + """Routes the input to a linked node using the shared state of the pipeline""" + + class DataSource(TextChoices): + participant_data = "participant_data", "Participant Data" + shared_state = "shared_state", "Shared State" + + model_config = ConfigDict( + json_schema_extra=NodeSchema( + label="Static Router", + field_order=["data_source", "route_key", "keywords"], + ) + ) + + data_source: DataSource = Field( + DataSource.participant_data, + description="The source of the data to use for routing", + json_schema_extra=UiSchema(enum_labels=DataSource.labels), + ) + route_key: str = Field(..., description="The key in the data to use for routing") + + def _process_conditional(self, state: PipelineState, node_id=None): + from apps.service_providers.llm_service.prompt_context import SafeAccessWrapper + + if self.data_source == self.DataSource.participant_data: + data = ParticipantDataProxy.from_state(state).get() else: - return self.keywords[0].lower() + data = state["shared_state"] - def get_output_map(self): - """Returns a mapping of the form: - {"output_1": "keyword 1", "output_2": "keyword_2", ...} where keywords are defined by the user - """ - return {f"output_{output_num}": keyword.lower() for output_num, keyword in enumerate(self.keywords)} + formatted_key = f"{{data.{self.route_key}}}" + try: + result = formatted_key.format(data=SafeAccessWrapper(data)) + except KeyError: + result = "" + + return self._get_keyword(result) class ExtractStructuredDataNodeMixin: @@ -425,7 +474,7 @@ def _process(self, input, state: PipelineState, node_id: str, **kwargs) -> Pipel self.post_extraction_hook(new_reference_data, state) output = json.dumps(new_reference_data) - return PipelineState.from_node_output(node_id=node_id, output=output) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=output) def post_extraction_hook(self, output, state): pass @@ -638,6 +687,7 @@ def _process(self, input, state: PipelineState, node_id: str, **kwargs) -> Pipel output = chain_output.output return PipelineState.from_node_output( + node_name=self.name, node_id=node_id, output=output, message_metadata={ @@ -661,6 +711,8 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen # Available functions: # - get_participant_data() -> dict # - set_participant_data(data: Any) -> None +# - get_state_key(key_name: str) -> str | None +# - set_state_key(key_name: str, data: Any) -> None def main(input: str, **kwargs) -> str: return input @@ -724,7 +776,7 @@ def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineSt result = str(custom_locals[function_name](input)) except Exception as exc: raise PipelineNodeRunError(exc) from exc - return PipelineState.from_node_output(node_id=node_id, output=result) + return PipelineState.from_node_output(node_name=self.name, node_id=node_id, output=result) def _get_custom_globals(self, state: PipelineState): from RestrictedPython.Eval import ( @@ -734,34 +786,7 @@ def _get_custom_globals(self, state: PipelineState): custom_globals = safe_globals.copy() - class ParticipantDataProxy: - """Allows multiple access without needing to re-fetch from the DB""" - - def __init__(self, state): - self.state = state - self._participant_data = None - - def _get_db_object(self): - if not self._participant_data: - content_type = ContentType.objects.get_for_model(Experiment) - session = state["experiment_session"] - self._participant_data, _ = 
ParticipantData.objects.get_or_create( - participant_id=session.participant_id, - content_type=content_type, - object_id=session.experiment_id, - team_id=session.experiment.team_id, - ) - return self._participant_data - - def get(self): - return self._get_db_object().data - - def set(self, data): - participant_data = self._get_db_object() - participant_data.data = data - participant_data.save(update_fields=["data"]) - - participant_data_proxy = ParticipantDataProxy(state) + participant_data_proxy = ParticipantDataProxy.from_state(state) custom_globals.update( { "__builtins__": self._get_custom_builtins(), @@ -773,10 +798,26 @@ def set(self, data): "_write_": lambda x: x, "get_participant_data": participant_data_proxy.get, "set_participant_data": participant_data_proxy.set, + "get_state_key": self._get_state_key(state), + "set_state_key": self._set_state_key(state), } ) return custom_globals + def _get_state_key(self, state: PipelineState): + def get_state_key(key_name: str): + return state["shared_state"].get(key_name) + + return get_state_key + + def _set_state_key(self, state: PipelineState): + def set_state_key(key_name: str, value): + if key_name in {"user_input", "outputs"}: + raise PipelineNodeRunError(f"Cannot set the '{key_name}' key of the shared state") + state["shared_state"][key_name] = value + + return set_state_key + def _get_custom_builtins(self): allowed_modules = { "json", diff --git a/apps/pipelines/tests/data/AssistantNode.json b/apps/pipelines/tests/data/AssistantNode.json index ff0226a2d..88ea799e4 100644 --- a/apps/pipelines/tests/data/AssistantNode.json +++ b/apps/pipelines/tests/data/AssistantNode.json @@ -1,6 +1,11 @@ { "description": "Calls an OpenAI assistant", "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "assistant_id": { "title": "Assistant Id", "type": "integer", @@ -22,6 +27,7 @@ } }, "required": [ + "name", "assistant_id" ], "title": "AssistantNode", diff --git a/apps/pipelines/tests/data/BooleanNode.json b/apps/pipelines/tests/data/BooleanNode.json index ed49ec7b6..957ee079b 100644 --- a/apps/pipelines/tests/data/BooleanNode.json +++ b/apps/pipelines/tests/data/BooleanNode.json @@ -1,12 +1,18 @@ { "description": "Branches based whether the input matches a certain value", "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "input_equals": { "title": "Input Equals", "type": "string" } }, "required": [ + "name", "input_equals" ], "title": "BooleanNode", diff --git a/apps/pipelines/tests/data/CodeNode.json b/apps/pipelines/tests/data/CodeNode.json index 23a94a204..b50817ec6 100644 --- a/apps/pipelines/tests/data/CodeNode.json +++ b/apps/pipelines/tests/data/CodeNode.json @@ -1,14 +1,22 @@ { "description": "Runs python", "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "code": { - "default": "# You must define a main function, which takes the node input as a string.\n# Return a string to pass to the next node.\n\n# Available functions:\n# - get_participant_data() -> dict\n# - set_participant_data(data: Any) -> None\n\ndef main(input: str, **kwargs) -> str:\n return input\n", + "default": "# You must define a main function, which takes the node input as a string.\n# Return a string to pass to the next node.\n\n# Available functions:\n# - get_participant_data() -> dict\n# - set_participant_data(data: Any) -> None\n# - get_state_key(key_name: str) -> str | None\n# - set_state_key(key_name: str, data: 
Any) -> None\n\ndef main(input: str, **kwargs) -> str:\n return input\n", "description": "The code to run", "title": "Code", "type": "string", "ui:widget": "code" } }, + "required": [ + "name" + ], "title": "CodeNode", "type": "object", "ui:can_add": true, diff --git a/apps/pipelines/tests/data/EndNode.json b/apps/pipelines/tests/data/EndNode.json index 7b98672f9..60bb1558d 100644 --- a/apps/pipelines/tests/data/EndNode.json +++ b/apps/pipelines/tests/data/EndNode.json @@ -1,6 +1,12 @@ { "description": "The end of the pipeline", - "properties": {}, + "properties": { + "name": { + "default": "end", + "title": "Name", + "type": "string" + } + }, "title": "EndNode", "type": "object", "ui:can_add": false, diff --git a/apps/pipelines/tests/data/ExtractParticipantData.json b/apps/pipelines/tests/data/ExtractParticipantData.json index b32bda172..82faee41b 100644 --- a/apps/pipelines/tests/data/ExtractParticipantData.json +++ b/apps/pipelines/tests/data/ExtractParticipantData.json @@ -19,6 +19,11 @@ "type": "number", "ui:widget": "range" }, + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "data_schema": { "default": "{\"name\": \"the name of the user\"}", "description": "A JSON object structure where the key is the name of the field and the value the description", @@ -34,7 +39,8 @@ }, "required": [ "llm_provider_id", - "llm_provider_model_id" + "llm_provider_model_id", + "name" ], "title": "ExtractParticipantData", "type": "object", diff --git a/apps/pipelines/tests/data/ExtractStructuredData.json b/apps/pipelines/tests/data/ExtractStructuredData.json index 2acb1dbef..c03e796eb 100644 --- a/apps/pipelines/tests/data/ExtractStructuredData.json +++ b/apps/pipelines/tests/data/ExtractStructuredData.json @@ -19,6 +19,11 @@ "type": "number", "ui:widget": "range" }, + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "data_schema": { "default": "{\"name\": \"the name of the user\"}", "description": "A JSON object structure where the key is the name of the field and the value the description", @@ -29,7 +34,8 @@ }, "required": [ "llm_provider_id", - "llm_provider_model_id" + "llm_provider_model_id", + "name" ], "title": "ExtractStructuredData", "type": "object", diff --git a/apps/pipelines/tests/data/LLMResponse.json b/apps/pipelines/tests/data/LLMResponse.json index 9e23f17bb..fdfe34dfb 100644 --- a/apps/pipelines/tests/data/LLMResponse.json +++ b/apps/pipelines/tests/data/LLMResponse.json @@ -18,11 +18,17 @@ "title": "Temperature", "type": "number", "ui:widget": "range" + }, + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" } }, "required": [ "llm_provider_id", - "llm_provider_model_id" + "llm_provider_model_id", + "name" ], "title": "LLMResponse", "type": "object", diff --git a/apps/pipelines/tests/data/LLMResponseWithPrompt.json b/apps/pipelines/tests/data/LLMResponseWithPrompt.json index 614191500..ab2df7751 100644 --- a/apps/pipelines/tests/data/LLMResponseWithPrompt.json +++ b/apps/pipelines/tests/data/LLMResponseWithPrompt.json @@ -43,6 +43,11 @@ "ui:widget": "none", "type": "string" }, + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "source_material_id": { "default": null, "title": "Source Material Id", @@ -79,7 +84,8 @@ }, "required": [ "llm_provider_id", - "llm_provider_model_id" + "llm_provider_model_id", + "name" ], "title": "LLMResponseWithPrompt", "type": "object", diff --git a/apps/pipelines/tests/data/Passthrough.json 
b/apps/pipelines/tests/data/Passthrough.json index 7bf86ab82..42554a519 100644 --- a/apps/pipelines/tests/data/Passthrough.json +++ b/apps/pipelines/tests/data/Passthrough.json @@ -1,6 +1,15 @@ { "description": "Returns the input without modification", - "properties": {}, + "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + } + }, + "required": [ + "name" + ], "title": "Passthrough", "type": "object", "ui:can_add": false, diff --git a/apps/pipelines/tests/data/RenderTemplate.json b/apps/pipelines/tests/data/RenderTemplate.json index 3aeda17bd..93f35aced 100644 --- a/apps/pipelines/tests/data/RenderTemplate.json +++ b/apps/pipelines/tests/data/RenderTemplate.json @@ -1,6 +1,11 @@ { "description": "Renders a Jinja template", "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "template_string": { "description": "Use {{your_variable_name}} to refer to designate input", "title": "Template String", @@ -9,6 +14,7 @@ } }, "required": [ + "name", "template_string" ], "title": "RenderTemplate", diff --git a/apps/pipelines/tests/data/RouterNode.json b/apps/pipelines/tests/data/RouterNode.json index af5c17ccf..3060e0f4f 100644 --- a/apps/pipelines/tests/data/RouterNode.json +++ b/apps/pipelines/tests/data/RouterNode.json @@ -1,5 +1,5 @@ { - "description": "Routes the input to one of the linked nodes", + "description": "Routes the input to one of the linked nodes using an LLM", "properties": { "llm_provider_id": { "title": "LLM Model", @@ -43,12 +43,10 @@ "ui:widget": "none", "type": "string" }, - "prompt": { - "default": "You are an extremely helpful router", - "minLength": 1, - "title": "Prompt", + "name": { + "title": "Node Name", "type": "string", - "ui:widget": "expandable_text" + "ui:widget": "node_name" }, "num_outputs": { "default": 2, @@ -63,11 +61,19 @@ "title": "Keywords", "type": "array", "ui:widget": "keywords" + }, + "prompt": { + "default": "You are an extremely helpful router", + "minLength": 1, + "title": "Prompt", + "type": "string", + "ui:widget": "expandable_text" } }, "required": [ "llm_provider_id", - "llm_provider_model_id" + "llm_provider_model_id", + "name" ], "title": "RouterNode", "type": "object", @@ -75,5 +81,12 @@ "ui:can_delete": true, "ui:deprecated": false, "ui:flow_node_type": "pipelineNode", - "ui:label": "Router" + "ui:label": "LLM Router", + "ui:order": [ + "llm_provider_id", + "llm_temperature", + "history_type", + "prompt", + "keywords" + ] } \ No newline at end of file diff --git a/apps/pipelines/tests/data/SendEmail.json b/apps/pipelines/tests/data/SendEmail.json index 013886489..9a182b58e 100644 --- a/apps/pipelines/tests/data/SendEmail.json +++ b/apps/pipelines/tests/data/SendEmail.json @@ -1,6 +1,11 @@ { "description": "Send the input to the node to the list of addresses provided", "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, "recipient_list": { "description": "A comma-separated list of email addresses", "title": "Recipient List", @@ -12,6 +17,7 @@ } }, "required": [ + "name", "recipient_list", "subject" ], diff --git a/apps/pipelines/tests/data/StartNode.json b/apps/pipelines/tests/data/StartNode.json index b4a3c1a8d..7b65186fa 100644 --- a/apps/pipelines/tests/data/StartNode.json +++ b/apps/pipelines/tests/data/StartNode.json @@ -1,6 +1,12 @@ { "description": "The start of the pipeline", - "properties": {}, + "properties": { + "name": { + "default": "start", + "title": "Name", + "type": "string" + 
} + }, "title": "StartNode", "type": "object", "ui:can_add": false, diff --git a/apps/pipelines/tests/data/StaticRouterNode.json b/apps/pipelines/tests/data/StaticRouterNode.json new file mode 100644 index 000000000..04bab0955 --- /dev/null +++ b/apps/pipelines/tests/data/StaticRouterNode.json @@ -0,0 +1,59 @@ +{ + "description": "Routes the input to a linked node using the shared state of the pipeline", + "properties": { + "name": { + "title": "Node Name", + "type": "string", + "ui:widget": "node_name" + }, + "num_outputs": { + "default": 2, + "title": "Num Outputs", + "type": "integer", + "ui:widget": "none" + }, + "keywords": { + "items": { + "type": "string" + }, + "title": "Keywords", + "type": "array", + "ui:widget": "keywords" + }, + "data_source": { + "enum": [ + "participant_data", + "shared_state" + ], + "title": "DataSource", + "type": "string", + "default": "participant_data", + "description": "The source of the data to use for routing", + "ui:enumLabels": [ + "Participant Data", + "Shared State" + ] + }, + "route_key": { + "description": "The key in the data to use for routing", + "title": "Route Key", + "type": "string" + } + }, + "required": [ + "name", + "route_key" + ], + "title": "StaticRouterNode", + "type": "object", + "ui:can_add": true, + "ui:can_delete": true, + "ui:deprecated": false, + "ui:flow_node_type": "pipelineNode", + "ui:label": "Static Router", + "ui:order": [ + "data_source", + "route_key", + "keywords" + ] +} \ No newline at end of file diff --git a/apps/pipelines/tests/test_code_node.py b/apps/pipelines/tests/test_code_node.py index 0167fa8c3..17f0b073d 100644 --- a/apps/pipelines/tests/test_code_node.py +++ b/apps/pipelines/tests/test_code_node.py @@ -10,6 +10,8 @@ code_node, create_runnable, end_node, + passthrough_node, + render_template_node, start_node, ) from apps.utils.factories.experiment import ExperimentSessionFactory @@ -141,7 +143,7 @@ def main(input, **kwargs): end_node(), ] assert ( - create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=[input]))[ + create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=["hi"]))[ "messages" ][-1] == "robot" @@ -172,10 +174,107 @@ def main(input, **kwargs): end_node(), ] assert ( - create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=[input]))[ + create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=["hi"]))[ "messages" ][-1] == output ) participant_data.refresh_from_db() assert participant_data.data["fun_facts"]["personality"] == output + + +@django_db_with_data(available_apps=("apps.service_providers",)) +@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock()) +def test_shared_state(pipeline, experiment_session): + output = "['fun loving', 'likes puppies']" + code_set = f""" +def main(input, **kwargs): + return set_state_key("fun_facts", {output}) +""" + code_get = """ +def main(input, **kwargs): + return str(get_state_key("fun_facts")) +""" + nodes = [ + start_node(), + code_node(code_set), + code_node(code_get), + end_node(), + ] + assert ( + create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=["hi"]))[ + "messages" + ][-1] + == output + ) + + +@django_db_with_data(available_apps=("apps.service_providers",)) +@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock()) +def test_shared_state_get_outputs(pipeline, experiment_session): + 
# Shared state contains the outputs of the previous nodes + + input = "hello" + code_get = """ +def main(input, **kwargs): + return str(get_state_key("outputs")) +""" + nodes = [ + start_node(), + passthrough_node(), + render_template_node("The input is: {{ input }}"), + code_node(code_get), + end_node(), + ] + assert create_runnable(pipeline, nodes).invoke( + PipelineState(experiment_session=experiment_session, messages=[input]) + )["messages"][-1] == str( + { + "start": input, + "passthrough": input, + "render template": f"The input is: {input}", + } + ) + + +@django_db_with_data(available_apps=("apps.service_providers",)) +@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock()) +def test_shared_state_set_outputs(pipeline, experiment_session): + input = "hello" + code_set = """ +def main(input, **kwargs): + set_state_key("outputs", "foobar") + return input +""" + nodes = [ + start_node(), + code_node(code_set), + end_node(), + ] + with pytest.raises(PipelineNodeRunError, match="Cannot set the 'outputs' key of the shared state"): + create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=[input]))[ + "messages" + ][-1] + + +@django_db_with_data(available_apps=("apps.service_providers",)) +@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock()) +def test_shared_state_user_input(pipeline, experiment_session): + # Shared state contains the user input + + input = "hello" + code_get = """ +def main(input, **kwargs): + return str(get_state_key("user_input")) +""" + nodes = [ + start_node(), + code_node(code_get), + end_node(), + ] + assert ( + create_runnable(pipeline, nodes).invoke(PipelineState(experiment_session=experiment_session, messages=[input]))[ + "messages" + ][-1] + == input + ) diff --git a/apps/pipelines/tests/test_nodes.py b/apps/pipelines/tests/test_nodes.py index 2c70c9ad1..819b80bd9 100644 --- a/apps/pipelines/tests/test_nodes.py +++ b/apps/pipelines/tests/test_nodes.py @@ -30,7 +30,7 @@ class TestSendEmailInputValidation: ], ) def test_valid_recipient_list(self, recipient_list): - model = SendEmail(recipient_list=recipient_list, subject="Test Subject") + model = SendEmail(name="email", recipient_list=recipient_list, subject="Test Subject") assert model.recipient_list == recipient_list @pytest.mark.parametrize( @@ -44,4 +44,4 @@ def test_valid_recipient_list(self, recipient_list): ) def test_invalid_recipient_list(self, recipient_list): with pytest.raises(ValidationError, match="Invalid list of emails addresses"): - SendEmail(recipient_list=recipient_list, subject="Test Subject") + SendEmail(name="email", recipient_list=recipient_list, subject="Test Subject") diff --git a/apps/pipelines/tests/test_pipeline_runs.py b/apps/pipelines/tests/test_pipeline_runs.py index 2d74f562d..9321d960c 100644 --- a/apps/pipelines/tests/test_pipeline_runs.py +++ b/apps/pipelines/tests/test_pipeline_runs.py @@ -44,6 +44,7 @@ def test_running_pipeline_creates_run(pipeline: Pipeline, session: ExperimentSes pipeline.node_ids[0]: {"message": "foo"}, pipeline.node_ids[1]: {"message": "foo"}, }, + shared_state={"outputs": {"end": "foo", "start": "foo"}, "user_input": "foo"}, ) assert len(run.log["entries"]) == 8 @@ -107,6 +108,8 @@ def test_running_failed_pipeline_logs_error(pipeline: Pipeline, session: Experim error_message = "Bad things are afoot" class FailingNode(PipelineNode): + name: str = "failure" + def process(self, *args, **kwargs) -> RunnableLambda: raise Exception(error_message) diff --git 
a/apps/pipelines/tests/test_runnable_builder.py b/apps/pipelines/tests/test_runnable_builder.py index a1f65b7ff..3e1f562b5 100644 --- a/apps/pipelines/tests/test_runnable_builder.py +++ b/apps/pipelines/tests/test_runnable_builder.py @@ -9,10 +9,12 @@ from apps.experiments.models import ParticipantData from apps.pipelines.exceptions import PipelineBuildError, PipelineNodeBuildError from apps.pipelines.nodes.base import PipelineState -from apps.pipelines.nodes.nodes import EndNode, StartNode +from apps.pipelines.nodes.helpers import ParticipantDataProxy +from apps.pipelines.nodes.nodes import EndNode, StartNode, StaticRouterNode from apps.pipelines.tests.utils import ( assistant_node, boolean_node, + code_node, create_runnable, email_node, end_node, @@ -24,6 +26,7 @@ render_template_node, router_node, start_node, + state_key_router_node, ) from apps.service_providers.llm_service.runnables import ChainOutput from apps.utils.factories.assistants import OpenAiAssistantFactory @@ -368,6 +371,102 @@ def test_router_node(get_llm_service, provider, provider_model, pipeline, experi assert output["messages"][-1] == "A z" +@django_db_with_data(available_apps=("apps.service_providers",)) +@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock()) +def test_static_router_shared_state(pipeline, experiment_session): + # The static router will switch based on a state key, and pass its input through + + code_set = """ +def main(input, **kwargs): + if "go to first" in input.lower(): + set_state_key("route_to", "first") + elif "go to second" in input.lower(): + set_state_key("route_to", "second") + return input +""" + start = start_node() + code = code_node(code_set) + router = state_key_router_node( + "route_to", ["first", "second"], data_source=StaticRouterNode.DataSource.shared_state + ) + template_a = render_template_node("A {{ input }}") + template_b = render_template_node("B {{ input }}") + end = end_node() + nodes = [start, code, router, template_a, template_b, end] + edges = [ + {"id": "start -> code", "source": start["id"], "target": code["id"]}, + {"id": "code -> router", "source": code["id"], "target": router["id"]}, + { + "id": "router -> A", + "source": router["id"], + "target": template_a["id"], + "sourceHandle": "output_0", + }, + { + "id": "router -> B", + "source": router["id"], + "target": template_b["id"], + "sourceHandle": "output_1", + }, + {"id": "A -> end", "source": template_a["id"], "target": end["id"]}, + {"id": "B -> end", "source": template_b["id"], "target": end["id"]}, + ] + runnable = create_runnable(pipeline, nodes, edges) + output = runnable.invoke(PipelineState(messages=["Go to FIRST"], experiment_session=experiment_session)) + assert output["messages"][-1] == "A Go to FIRST" + + output = runnable.invoke(PipelineState(messages=["Go to Second"], experiment_session=experiment_session)) + assert output["messages"][-1] == "B Go to Second" + + # default route + output = runnable.invoke(PipelineState(messages=["Go to Third"], experiment_session=experiment_session)) + assert output["messages"][-1] == "A Go to Third" + + +@django_db_with_data(available_apps=("apps.service_providers",)) +@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock()) +def test_static_router_participant_data(pipeline, experiment_session): + start = start_node() + router = state_key_router_node( + "route_to", ["first", "second"], data_source=StaticRouterNode.DataSource.participant_data + ) + template_a = render_template_node("A {{ input }}") + template_b = 
render_template_node("B {{ input }}") + end = end_node() + nodes = [start, router, template_a, template_b, end] + edges = [ + {"id": "start -> router", "source": start["id"], "target": router["id"]}, + { + "id": "router -> A", + "source": router["id"], + "target": template_a["id"], + "sourceHandle": "output_0", + }, + { + "id": "router -> B", + "source": router["id"], + "target": template_b["id"], + "sourceHandle": "output_1", + }, + {"id": "A -> end", "source": template_a["id"], "target": end["id"]}, + {"id": "B -> end", "source": template_b["id"], "target": end["id"]}, + ] + runnable = create_runnable(pipeline, nodes, edges) + + ParticipantDataProxy(experiment_session).set({"route_to": "first"}) + output = runnable.invoke(PipelineState(messages=["Hi"], experiment_session=experiment_session)) + assert output["messages"][-1] == "A Hi" + + ParticipantDataProxy(experiment_session).set({"route_to": "second"}) + output = runnable.invoke(PipelineState(messages=["Hi"], experiment_session=experiment_session)) + assert output["messages"][-1] == "B Hi" + + # default route + ParticipantDataProxy(experiment_session).set({}) + output = runnable.invoke(PipelineState(messages=["Hi"], experiment_session=experiment_session)) + assert output["messages"][-1] == "A Hi" + + @contextmanager def extract_structured_data_pipeline(provider, provider_model, pipeline, llm=None): service = build_fake_llm_service(responses=[{"name": "John"}], token_counts=[0], fake_llm=llm) diff --git a/apps/pipelines/tests/utils.py b/apps/pipelines/tests/utils.py index f4fc8a221..20311d29b 100644 --- a/apps/pipelines/tests/utils.py +++ b/apps/pipelines/tests/utils.py @@ -37,11 +37,11 @@ def create_runnable( def start_node(): - return {"id": str(uuid4()), "type": nodes.StartNode.__name__} + return {"id": str(uuid4()), "type": nodes.StartNode.__name__, "params": {"name": "start"}} def end_node(): - return {"id": str(uuid4()), "type": nodes.EndNode.__name__} + return {"id": str(uuid4()), "type": nodes.EndNode.__name__, "params": {"name": "end"}} def email_node(): @@ -50,6 +50,7 @@ def email_node(): "label": "Send an email", "type": "SendEmail", "params": { + "name": "email", "recipient_list": "test@example.com", "subject": "This is an interesting email", }, @@ -71,6 +72,7 @@ def llm_response_with_prompt_node( ) params = { + "name": "llm response with prompt", "llm_provider_id": provider_id, "llm_provider_model_id": provider_model_id, "prompt": prompt, @@ -96,6 +98,7 @@ def llm_response_node(provider_id: str, provider_model_id: str): "id": str(uuid4()), "type": nodes.LLMResponse.__name__, "params": { + "name": "llm response", "llm_provider_id": provider_id, "llm_provider_model_id": provider_model_id, }, @@ -109,6 +112,7 @@ def render_template_node(template_string: str | None = None): "id": str(uuid4()), "type": nodes.RenderTemplate.__name__, "params": { + "name": "render template", "template_string": template_string, }, } @@ -118,6 +122,7 @@ def passthrough_node(): return { "id": str(uuid4()), "type": nodes.Passthrough.__name__, + "params": {"name": "passthrough"}, } @@ -125,7 +130,7 @@ def boolean_node(): return { "id": str(uuid4()), "type": nodes.BooleanNode.__name__, - "params": {"input_equals": "hello"}, + "params": {"name": "boolean", "input_equals": "hello"}, } @@ -134,6 +139,7 @@ def router_node(provider_id: str, provider_model_id: str, keywords: list[str]): "id": str(uuid4()), "type": nodes.RouterNode.__name__, "params": { + "name": "router", "prompt": "You are a router", "keywords": keywords, "num_outputs": len(keywords), @@ 
-143,11 +149,26 @@ def router_node(provider_id: str, provider_model_id: str, keywords: list[str]): } +def state_key_router_node(route_key: str, keywords: list[str], data_source="shared_state"): + return { + "id": str(uuid4()), + "type": nodes.StaticRouterNode.__name__, + "params": { + "name": "static router", + "data_source": data_source, + "route_key": route_key, + "keywords": keywords, + "num_outputs": len(keywords), + }, + } + + def assistant_node(assistant_id: str): return { "id": str(uuid4()), "type": nodes.AssistantNode.__name__, "params": { + "name": "assistant", "assistant_id": assistant_id, "citations_enabled": True, "input_formatter": "", @@ -160,6 +181,7 @@ def extract_participant_data_node(provider_id: str, provider_model_id: str, data "id": str(uuid4()), "type": nodes.ExtractParticipantData.__name__, "params": { + "name": "extract participant data", "llm_provider_id": provider_id, "llm_provider_model_id": provider_model_id, "data_schema": data_schema, @@ -173,6 +195,7 @@ def extract_structured_data_node(provider_id: str, provider_model_id: str, data_ "id": str(uuid4()), "type": nodes.ExtractStructuredData.__name__, "params": { + "name": "extract structured data", "llm_provider_id": provider_id, "llm_provider_model_id": provider_model_id, "data_schema": data_schema, @@ -187,6 +210,7 @@ def code_node(code: str | None = None): "id": str(uuid4()), "type": nodes.CodeNode.__name__, "params": { + "name": "code node", "code": code, }, } diff --git a/apps/service_providers/llm_service/prompt_context.py b/apps/service_providers/llm_service/prompt_context.py index be3bce755..47b8a59a3 100644 --- a/apps/service_providers/llm_service/prompt_context.py +++ b/apps/service_providers/llm_service/prompt_context.py @@ -83,8 +83,8 @@ class SafeAccessWrapper(dict): """ def __init__(self, data: Any): - super().__init__(self, __data=data) self.__data = data + super().__init__(self, __data=data) def __getitem__(self, key): if isinstance(self.__data, list | str): diff --git a/assets/javascript/apps/pipeline/Pipeline.tsx b/assets/javascript/apps/pipeline/Pipeline.tsx index 601d6641d..2c413b1ce 100644 --- a/assets/javascript/apps/pipeline/Pipeline.tsx +++ b/assets/javascript/apps/pipeline/Pipeline.tsx @@ -90,9 +90,10 @@ export default function Pipeline() { const data: NodeData = JSON.parse( event.dataTransfer.getData("nodedata") ); + const newId = getNodeId(data.type); + data.params["name"] = newId; const flowType = data.flowType; delete data.flowType; - const newId = getNodeId(data.type); const newNode = { id: newId, diff --git a/assets/javascript/apps/pipeline/PipelineNode.tsx b/assets/javascript/apps/pipeline/PipelineNode.tsx index 2dd688035..23f6181c3 100644 --- a/assets/javascript/apps/pipeline/PipelineNode.tsx +++ b/assets/javascript/apps/pipeline/PipelineNode.tsx @@ -1,11 +1,11 @@ import {Node, NodeProps, NodeToolbar, Position} from "reactflow"; import React, {ChangeEvent} from "react"; -import {getCachedData, nodeBorderClass} from "./utils"; +import {concatenate, getCachedData, nodeBorderClass} from "./utils"; import usePipelineStore from "./stores/pipelineStore"; import usePipelineManagerStore from "./stores/pipelineManagerStore"; import useEditorStore from "./stores/editorStore"; import {JsonSchema, NodeData} from "./types/nodeParams"; -import {getNodeInputWidget, showAdvancedButton} from "./nodes/GetInputWidget"; +import {getWidgetsForNode} from "./nodes/GetInputWidget"; import NodeInput from "./nodes/NodeInput"; import NodeOutputs from "./nodes/NodeOutputs"; import {HelpContent} from 
"./panel/ComponentHelp"; @@ -19,10 +19,7 @@ export function PipelineNode(nodeProps: NodeProps) { const deleteNode = usePipelineStore((state) => state.deleteNode); const hasErrors = usePipelineManagerStore((state) => state.nodeHasErrors(id)); const nodeError = usePipelineManagerStore((state) => state.getNodeFieldError(id, "root")); - const {nodeSchemas} = getCachedData(); - const nodeSchema = nodeSchemas.get(data.type)!; - const schemaProperties = Object.getOwnPropertyNames(nodeSchema.properties); - const requiredProperties = nodeSchema.required || []; + const nodeSchema = getCachedData().nodeSchemas.get(data.type)!; const updateParamValue = ( event: ChangeEvent, @@ -77,33 +74,19 @@ export function PipelineNode(nodeProps: NodeProps) {
- +
- {schemaProperties.map((name) => ( - - {getNodeInputWidget({ - id: id, - name: name, - schema: nodeSchema.properties[name], - params: data.params, - updateParamValue: updateParamValue, - nodeType: data.type, - required: requiredProperties.includes(name), - })} - - ))} + {getWidgetsForNode({schema: nodeSchema, nodeId: id, nodeData: data, updateParamValue: updateParamValue})} +
+
+
- {showAdvancedButton(data.type) && ( -
- -
- )}
@@ -111,11 +94,16 @@ export function PipelineNode(nodeProps: NodeProps) { ); } -function NodeHeader({nodeSchema}: {nodeSchema: JsonSchema}) { +function NodeHeader({nodeSchema, nodeName}: {nodeSchema: JsonSchema, nodeName: string}) { + const defaultNodeNameRegex = /^[A-Za-z]+-[a-zA-Z0-9]{5}$/; + const hasCustomName = !defaultNodeNameRegex.test(nodeName); + const header = hasCustomName ? nodeName : nodeSchema["ui:label"]; + const subheader = hasCustomName ? nodeSchema["ui:label"] : ""; return (
- {nodeSchema["ui:label"]} + {header} + {subheader &&
{subheader}
}
); } diff --git a/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx b/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx index 258cc49d4..cff478463 100644 --- a/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx +++ b/assets/javascript/apps/pipeline/nodes/GetInputWidget.tsx @@ -1,8 +1,19 @@ import React from "react"; -import {NodeParams, PropertySchema} from "../types/nodeParams"; +import {JsonSchema, NodeParams, PropertySchema} from "../types/nodeParams"; import usePipelineManagerStore from "../stores/pipelineManagerStore"; import {getWidget} from "./widgets"; +type GetWidgetsParams = { + schema: JsonSchema; + nodeId: string; + nodeData: any; + updateParamValue: (event: React.ChangeEvent) => any; +} + +type GetWidgetParamsGeneric = GetWidgetsParams & { + widgetGenerator: (params: InputWidgetParams) => React.ReactElement; +} + type InputWidgetParams = { id: string; @@ -16,6 +27,7 @@ type InputWidgetParams = { const nodeTypeToInputParamsMap: Record = { "RouterNode": ["llm_model", "history_type", "prompt"], + "StaticRouterNode": ["data_source", "route_key"], "ExtractParticipantData": ["llm_model", "history_type", "data_schema"], "ExtractStructuredData": ["llm_model", "history_type", "data_schema"], "LLMResponseWithPrompt": ["llm_model", "history_type", "prompt"], @@ -23,17 +35,74 @@ const nodeTypeToInputParamsMap: Record = { "AssistantNode": ["assistant_id", "citations_enabled"], }; -export const showAdvancedButton = (nodeType: string) => { - return nodeTypeToInputParamsMap[nodeType] !== undefined; +/** + * Retrieves the full list of widgets for the given schema + */ +export const getWidgets = ( + {schema, nodeId, nodeData, updateParamValue}: GetWidgetsParams +) => { + return getWidgetsGeneric({schema, nodeId, nodeData, updateParamValue, widgetGenerator: getInputWidget}); } +/** + * Retrieves the list of widgets for the given schema which should be displayed ona node + */ +export const getWidgetsForNode = ( + {schema, nodeId, nodeData, updateParamValue}: GetWidgetsParams +) => { + return getWidgetsGeneric({schema, nodeId, nodeData, updateParamValue, widgetGenerator: getNodeInputWidget}); +} + +const getWidgetsGeneric = ( + {schema, nodeId, nodeData, updateParamValue, widgetGenerator}: GetWidgetParamsGeneric +) => { + const schemaProperties = Object.getOwnPropertyNames(schema.properties); + const requiredProperties = schema.required || []; + if (schema["ui:order"]) { + schemaProperties.sort((a, b) => { + // 'name' should always be first + if (a === "name") return -1; + if (b === "name") return 1; + + const indexA = schema["ui:order"]!.indexOf(a); + const indexB = schema["ui:order"]!.indexOf(b); + // If 'a' is not in the order list, it should be at the end + if (indexA === -1) return 1; + if (indexB === -1) return -1; + return indexA - indexB; + }); + } + return schemaProperties.map((name) => ( + + {widgetGenerator({ + id: nodeId, + name: name, + schema: schema.properties[name], + params: nodeData.params, + updateParamValue: updateParamValue, + nodeType: nodeData.type, + required: requiredProperties.includes(name), + })} + + )); +} + +/** + * Retrieves the appropriate input widget for the specified node type and parameter. + * + * This calls `getInputWidget` under the hood but also filters the parameters to only those which + * should be shown on the node. + * + * @returns The input widget for the specified node type and parameter. 
+ */ export const getNodeInputWidget = (param: InputWidgetParams) => { if (!param.nodeType) { return <>; } const allowedInNode = nodeTypeToInputParamsMap[param.nodeType]; - if (allowedInNode && !allowedInNode.includes(param.name)) { + if (param.name == "name" || (allowedInNode && !allowedInNode.includes(param.name))) { + /* name param is always in the advanced box */ return <>; } return getInputWidget(param); @@ -41,11 +110,6 @@ export const getNodeInputWidget = (param: InputWidgetParams) => { /** * Generates the appropriate input widget based on the input parameter type. - * @param id - The node ID - * @param inputParam - The input parameter to generate the widget for. - * @param params - The parameters for the node. - * @param setParams - The function to update the node parameters. - * @param updateParamValue - The function to update the value of the input parameter. * @returns The input widget for the specified parameter type. */ export const getInputWidget = (params: InputWidgetParams) => { @@ -55,16 +119,16 @@ export const getInputWidget = (params: InputWidgetParams) => { During the migration, we kept the data in llm_model as a safeguard. This check can safely be deleted once a second migration to delete all instances of llm_model has been run. TODO: Remove this check once there are no instances of llm_model or max_token_limit in the node definitions. */ - return + return <> } - const getNodeFieldError = usePipelineManagerStore((state) => state.getNodeFieldError); const widgetOrType = params.schema["ui:widget"] || params.schema.type; if (widgetOrType == 'none') { return <>; } - const Widget = getWidget(widgetOrType) + const getNodeFieldError = usePipelineManagerStore((state) => state.getNodeFieldError); + const Widget = getWidget(widgetOrType, params.schema) let fieldError = getNodeFieldError(params.id, params.name); const paramValue = params.params[params.name]; if (params.required && (paramValue === null || paramValue === undefined)) { @@ -76,7 +140,7 @@ export const getInputWidget = (params: InputWidgetParams) => { name={params.name} label={params.schema.title || params.name.replace(/_/g, " ")} helpText={params.schema.description || ""} - paramValue={paramValue || ""} + paramValue={paramValue ?? ""} inputError={fieldError} updateParamValue={params.updateParamValue} schema={params.schema} diff --git a/assets/javascript/apps/pipeline/nodes/LabeledHandle.tsx b/assets/javascript/apps/pipeline/nodes/LabeledHandle.tsx index 90ae26e00..25ffc4696 100644 --- a/assets/javascript/apps/pipeline/nodes/LabeledHandle.tsx +++ b/assets/javascript/apps/pipeline/nodes/LabeledHandle.tsx @@ -25,14 +25,13 @@ const LabeledHandle = React.forwardRef< HTMLDivElement, HandleProps & React.HTMLAttributes & { - title: string; + label: string | React.ReactElement; handleClassName?: string; labelClassName?: string; } ->(({ className, labelClassName, title, position, ...props }, ref) => ( +>(({ className, labelClassName, label, position, ...props }, ref) => (
-
+
 ));
diff --git a/assets/javascript/apps/pipeline/nodes/NodeInput.tsx b/assets/javascript/apps/pipeline/nodes/NodeInput.tsx
index b70a468b7..1d0c5c473 100644
--- a/assets/javascript/apps/pipeline/nodes/NodeInput.tsx
+++ b/assets/javascript/apps/pipeline/nodes/NodeInput.tsx
@@ -3,5 +3,5 @@ import {Position} from "reactflow";
 import {LabeledHandle} from "./LabeledHandle";
 
 export default function NodeInput() {
-  return
+  return
 }
diff --git a/assets/javascript/apps/pipeline/nodes/NodeOutputs.tsx b/assets/javascript/apps/pipeline/nodes/NodeOutputs.tsx
index 3a4621792..c4f66c27d 100644
--- a/assets/javascript/apps/pipeline/nodes/NodeOutputs.tsx
+++ b/assets/javascript/apps/pipeline/nodes/NodeOutputs.tsx
@@ -7,11 +7,21 @@ import {LabeledHandle} from "./LabeledHandle";
 
 export default function NodeOutputs({data}: {
   data: NodeData,
 }) {
-  const multipleOutputs = data.type === "RouterNode" || data.type === "BooleanNode";
+  const multipleOutputs = data.type === "RouterNode" || data.type === "BooleanNode" || data.type == "StaticRouterNode";
   const outputNames = getOutputNames(data.type, data.params);
   const generateOutputHandle = (outputIndex: number) => {
     return multipleOutputs ? `output_${outputIndex}` : "output";
   };
+  const generateOutputLabel = (outputIndex: number, output_label:string) => {
+    if (multipleOutputs && outputIndex === 0) {
+      return (
+
+        {output_label}
+
+      );
+    }
+    return <>{output_label}</>;
+  }
   return (
     <>
       {multipleOutputs &&
        Outputs
      }
@@ -20,7 +30,7 @@ export default function NodeOutputs({data}: {
       {
       if (params.keywords?.[i]) {
diff --git a/assets/javascript/apps/pipeline/nodes/widgets.tsx b/assets/javascript/apps/pipeline/nodes/widgets.tsx
index 6adc81958..5159f0147 100644
--- a/assets/javascript/apps/pipeline/nodes/widgets.tsx
+++ b/assets/javascript/apps/pipeline/nodes/widgets.tsx
@@ -1,23 +1,16 @@
-import React, {
-  ChangeEvent,
-  ChangeEventHandler,
-  ReactNode,
-  useId,
-  useEffect,
-} from "react";
-import { useState } from "react";
+import React, {ChangeEvent, ChangeEventHandler, ReactNode, useEffect, useId, useState,} from "react";
 import CodeMirror from '@uiw/react-codemirror';
-import { python } from "@codemirror/lang-python";
-import { githubLight, githubDark } from "@uiw/codemirror-theme-github";
-import { CompletionContext, snippetCompletion as snip } from '@codemirror/autocomplete'
-import { TypedOption } from "../types/nodeParameterValues";
+import {python} from "@codemirror/lang-python";
+import {githubDark, githubLight} from "@uiw/codemirror-theme-github";
+import {CompletionContext, snippetCompletion as snip} from '@codemirror/autocomplete'
+import {TypedOption} from "../types/nodeParameterValues";
 import usePipelineStore from "../stores/pipelineStore";
-import { classNames, concatenate, getCachedData, getSelectOptions } from "../utils";
-import { NodeParams, PropertySchema } from "../types/nodeParams";
-import { Node, useUpdateNodeInternals } from "reactflow";
+import {classNames, concatenate, getCachedData, getSelectOptions} from "../utils";
+import {NodeParams, PropertySchema} from "../types/nodeParams";
+import {Node, useUpdateNodeInternals} from "reactflow";
 import DOMPurify from 'dompurify';
 
-export function getWidget(name: string) {
+export function getWidget(name: string, params: PropertySchema) {
   switch (name) {
     case "toggle":
       return ToggleWidget
@@ -39,7 +32,12 @@ export function getWidget(name: string) {
       return HistoryTypeWidget
     case "keywords":
       return KeywordsWidget
+    case "node_name":
+      return NodeNameWidget
     default:
+      if (params.enum) {
+        return SelectWidget
+      }
       return DefaultWidget
   }
 }
@@ -73,6 +71,37 @@ function DefaultWidget(props: WidgetParams) {
   );
 }
 
+/**
+ * A widget component for displaying and editing the name of a node.
+ *
+ * Will display a blank input field if the current value matches the node ID.
+ */
+function NodeNameWidget(props: WidgetParams) {
+  const value = concatenate(props.paramValue);
+  const [inputValue, setInputValue] = React.useState(value === props.nodeId ? "" : value);
+
+  const handleInputChange = (event: React.ChangeEvent) => {
+    setInputValue(event.target.value);
+    if (!event.target.value) {
+      event.target.value = props.nodeId;
+    }
+    props.updateParamValue(event);
+  };
+
+  return (
+
+
+
+  );
+}
+
 function FloatWidget(props: WidgetParams) {
   return
@@ -520,6 +566,11 @@
 export function KeywordsWidget(props: WidgetParams) {
   const length = parseInt(concatenate(props.nodeParams.num_outputs)) || 1;
   const keywords = Array.isArray(props.nodeParams.keywords) ? props.nodeParams["keywords"] : []
   const canDelete = length > 1;
+  const defaultMarker = (
+
+
+  )
   return (
     <>
@@ -540,7 +591,7 @@ export function KeywordsWidget(props: WidgetParams) {
           return (
-
+
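
Not part of the patch: a minimal sketch of how the widened getWidget(name, paramSchema) dispatch above behaves, assuming a hand-written schema literal (the values below are illustrative only, not taken from the repository).

    // Sketch only; the schema object is made up for illustration.
    import {getWidget} from "./widgets";

    const schema: any = {type: "string", enum: ["low", "medium", "high"]};
    // Mirrors getInputWidget: prefer an explicit "ui:widget", otherwise fall back to the JSON type.
    const widgetOrType = schema["ui:widget"] || schema.type;  // -> "string"
    const Widget = getWidget(widgetOrType, schema);           // enum present -> SelectWidget via the new fallback
    const NameWidget = getWidget("node_name", schema);        // -> NodeNameWidget
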
diff --git a/assets/javascript/apps/pipeline/stores/pipelineStore.ts b/assets/javascript/apps/pipeline/stores/pipelineStore.ts
index 816d66e76..51ed1ec06 100644
--- a/assets/javascript/apps/pipeline/stores/pipelineStore.ts
+++ b/assets/javascript/apps/pipeline/stores/pipelineStore.ts
@@ -185,14 +185,15 @@ const usePipelineStore = create((set, get) => ({
       }
 
       const newId = getNodeId(node.data.type);
-
+      const data = cloneDeep(node.data);
+      data.params["name"] = newId;
       // Create a new node object
       const newNode = {
         id: newId,
         type: node.type,
         position: actualPosition,
         data: {
-          ...cloneDeep(node.data),
+          ...data,
           id: newId,
         },
       };
diff --git a/assets/javascript/apps/pipeline/types/nodeParams.ts b/assets/javascript/apps/pipeline/types/nodeParams.ts
index 0ab993226..e7a787dfe 100644
--- a/assets/javascript/apps/pipeline/types/nodeParams.ts
+++ b/assets/javascript/apps/pipeline/types/nodeParams.ts
@@ -1,4 +1,7 @@
-export type NodeParams = Record;
+export type NodeParams = {
+  name: string;
+  [key: string]: any;
+}
 
 export type PropertySchema = {
   type: string;
@@ -18,7 +21,13 @@ export type JsonSchema = {
   title: string;
   description?: string | undefined;
   required?: string[] | undefined;
+  "ui:flow_node_type": string;
   "ui:label": string;
+  "ui:can_add": boolean;
+  "ui:can_delete": boolean;
+  "ui:deprecated": boolean;
+  "ui:deprecation_message"?: string;
+  "ui:order"?: string[];
   properties: Record;
   [k: string]: any;
 }
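
For reference, a small sketch (not from the repository) of the new NodeParams and JsonSchema types in use, together with the "name first, then ui:order, unknown keys last" ordering that getWidgetsGeneric applies. The literals are illustrative, and the cast papers over JsonSchema fields not shown here.

    // Sketch only: illustrative values, not taken from the codebase.
    import {JsonSchema, NodeParams} from "./nodeParams";

    const params: NodeParams = {name: "my node", prompt: "Say hello", llm_model: "some-model"};
    const schema = {
      "ui:order": ["llm_model", "prompt"],
      properties: {prompt: {type: "string"}, name: {type: "string"}, llm_model: {type: "string"}},
    } as unknown as JsonSchema;

    const keys = Object.getOwnPropertyNames(schema.properties);
    keys.sort((a, b) => {
      if (a === "name") return -1;               // 'name' always renders first
      if (b === "name") return 1;
      const ia = schema["ui:order"]!.indexOf(a);
      const ib = schema["ui:order"]!.indexOf(b);
      if (ia === -1) return 1;                   // keys missing from ui:order sink to the end
      if (ib === -1) return -1;
      return ia - ib;
    });
    console.log(params.name, keys);              // -> "my node" ["name", "llm_model", "prompt"]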