Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pipeline Shared State and Static Router Node #1019

Open
wants to merge 40 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0af762b
Pretty print data field in django admin
proteusvacuum Dec 31, 2024
2e4e82c
Add name field to all nodes
proteusvacuum Dec 31, 2024
de71913
Ensure unique names in frontend
proteusvacuum Jan 1, 2025
8bed6e1
Nicer looking syntax errors
proteusvacuum Jan 1, 2025
20ba980
Add shared_state on PipelineState
proteusvacuum Jan 1, 2025
7420c72
Add StateKeyRouterNode
proteusvacuum Jan 1, 2025
8adc53e
Add get and set state key to autocomplete
proteusvacuum Jan 1, 2025
3cd5203
Add 4 spaces as tabs to codemirror
proteusvacuum Jan 1, 2025
634e591
Add shared state tests
proteusvacuum Jan 2, 2025
fd0247f
Merge branch 'main' into fr/pipelines-shared-state
snopoke Jan 16, 2025
60cb12b
rename state key router
snopoke Jan 16, 2025
8df7204
rename field + add help text
snopoke Jan 16, 2025
a11e5d6
resolve migration conflict
snopoke Jan 16, 2025
6969b46
Fix merge issues
proteusvacuum Jan 16, 2025
72e07c5
allow nested access in static router
snopoke Jan 16, 2025
9fc29aa
Update statekeyrouternode schema
proteusvacuum Jan 16, 2025
eefcb45
Add node name to node header if it has been changed
proteusvacuum Jan 16, 2025
faf9f1d
Merge branch 'main' into fr/pipelines-shared-state
proteusvacuum Jan 16, 2025
1942196
add option to route on participant data
snopoke Jan 16, 2025
90e839a
Merge remote-tracking branch 'origin/fr/pipelines-shared-state' into …
snopoke Jan 16, 2025
0b300f8
default routes for static router
snopoke Jan 16, 2025
3dfaca9
display name as title with node type as subtitle
snopoke Jan 16, 2025
91866bf
fix type hint
snopoke Jan 16, 2025
482ca36
custom widget for node name
snopoke Jan 16, 2025
84b3307
fix tests
snopoke Jan 17, 2025
381741b
prevent overwriting of 'user_input' shared state key
snopoke Jan 17, 2025
1e0e6b9
Merge branch 'main' into fr/pipelines-shared-state
snopoke Jan 17, 2025
2474f81
make router default output clearer
snopoke Jan 17, 2025
d19c07f
Customize field order
snopoke Jan 17, 2025
010f388
update test regex
snopoke Jan 17, 2025
147c55a
Merge branch 'main' into fr/pipelines-shared-state
snopoke Jan 17, 2025
b142f26
post merge fix
snopoke Jan 17, 2025
5f7524c
coderabbit updates
snopoke Jan 17, 2025
39728de
fix auto-complete
snopoke Jan 17, 2025
3e3f5a3
fix typing
snopoke Jan 17, 2025
f0c011a
update JsonSchema type
snopoke Jan 17, 2025
172028d
fix type checks
snopoke Jan 17, 2025
bc88520
attempt to fix transient build failure
snopoke Jan 17, 2025
16b0e96
remove print
snopoke Jan 17, 2025
0a53d96
fix indent
snopoke Jan 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 34 additions & 26 deletions apps/audit/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from unittest import mock

from field_audit.models import USER_TYPE_REQUEST
Expand All @@ -19,26 +20,26 @@ def test_change_context_returns_none_without_request():


def test_change_context_returns_none_without_request_with_team():
with current_team(AuthedRequest.Team()):
with current_team(Team()):
context = AuditContextProvider().change_context(None)
assert context["user_type"] != USER_TYPE_REQUEST
assert context["team"] == 17


def test_change_context_returns_value_for_unauthorized_req():
request = AuthedRequest(auth=False)
assert AuditContextProvider().change_context(request) == {}
assert "user_type" not in AuditContextProvider().change_context(request)


def test_change_context_returns_value_for_unauthorized_team_req():
request = AuthedRequest(auth=False)
with current_team(AuthedRequest.Team()):
with current_team(Team()):
assert AuditContextProvider().change_context(request) == {"team": 17}


def test_change_context_returns_value_for_authorized_team_req():
request = AuthedRequest(auth=True)
with current_team(AuthedRequest.Team()):
with current_team(Team()):
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
Expand All @@ -49,40 +50,47 @@ def test_change_context_returns_value_for_authorized_team_req():
@mock.patch("apps.audit.auditors._get_hijack_username", return_value="[email protected]")
def test_change_context_hijacked_request(_):
request = AuthedRequest(session={"hijack_history": [1]})
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
"as_username": request.user.username,
}
with current_team(None):
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
"as_username": request.user.username,
}


@mock.patch("apps.audit.auditors._get_hijack_username", return_value=None)
def test_change_context_hijacked_request__no_hijacked_user(_):
request = AuthedRequest(session={"hijack_history": [1]})
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
}
with current_team(None):
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
}


def test_change_context_hijacked_request__bad_hijack_history():
request = AuthedRequest(session={"hijack_history": ["not a number"]})
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
}
with current_team(None):
assert AuditContextProvider().change_context(request) == {
"user_type": USER_TYPE_REQUEST,
"username": "[email protected]",
}


class AuthedRequest:
class User:
username = "[email protected]"
is_authenticated = True

class Team:
id = 17
slug = "seventeen"

def __init__(self, auth=True, session=None):
self.user = self.User()
self.user = User()
self.session = session or {}
self.user.is_authenticated = auth


@dataclass
class User:
username: str = "[email protected]"
is_authenticated: str = True


@dataclass
class Team:
id: int = 17
slug: str = "seventeen"
1 change: 1 addition & 0 deletions apps/events/tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def test_end_conversation_runs_pipeline(session, pipeline):
"end": {"message": f"human: {input}"},
},
"experiment_session": session.id,
"shared_state": {"user_input": output_message, "outputs": {"start": output_message, "end": output_message}},
}
)
assert pipeline.runs.count() == 1
Expand Down
13 changes: 13 additions & 0 deletions apps/pipelines/admin.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import json

from django import forms
from django.contrib import admin

from .models import Node, Pipeline, PipelineChatHistory, PipelineChatMessages, PipelineRun
Expand All @@ -13,8 +16,18 @@ class PipelineNodeInline(admin.TabularInline):
extra = 0


class PrettyJSONEncoder(json.JSONEncoder):
def __init__(self, *args, indent, sort_keys, **kwargs):
super().__init__(*args, indent=4, sort_keys=True, **kwargs)
snopoke marked this conversation as resolved.
Show resolved Hide resolved


class PipelineAdminForm(forms.ModelForm):
data = forms.JSONField(encoder=PrettyJSONEncoder)


@admin.register(Pipeline)
class PipelineAdmin(admin.ModelAdmin):
form = PipelineAdminForm
inlines = [PipelineNodeInline, PipelineRunInline]


Expand Down
49 changes: 49 additions & 0 deletions apps/pipelines/migrations/0012_auto_20250116_1508.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Generated by Django 5.1.2 on 2025-01-16 15:08

from django.db import migrations

def _add_default_name_to_nodes(apps, schema_editor):
Pipeline = apps.get_model("pipelines", "Pipeline")
Node = apps.get_model("pipelines", "Node")

for pipeline in Pipeline.objects.all():
nodes = pipeline.node_set.all()
data = pipeline.data
for node in nodes:
if "name" in node.params:
continue

name = node.flow_id

if node.type == "StartNode":
name = "start"
if node.type == "EndNode":
name = "end"

node.params["name"] = name
node.save()


def _remove_node_names(apps, schema_editor):
Pipeline = apps.get_model("pipelines", "Pipeline")
Node = apps.get_model("pipelines", "Node")

for pipeline in Pipeline.objects.all():
nodes = pipeline.node_set.all()
data = pipeline.data

for node in nodes:
if "name" in node.params:
del node.params["name"]
node.save()


class Migration(migrations.Migration):

dependencies = [
('pipelines', '0011_migrate_assistant_id'),
]

operations = [
migrations.RunPython(_add_default_name_to_nodes, reverse_code=_remove_node_names)
]
21 changes: 16 additions & 5 deletions apps/pipelines/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Iterator
from datetime import datetime
from functools import cached_property
Expand Down Expand Up @@ -88,14 +89,14 @@ def create_default(cls, team):
"x": -200,
"y": 200,
},
data=FlowNodeData(id=start_id, type=StartNode.__name__),
data=FlowNodeData(id=start_id, type=StartNode.__name__, params={"name": "start"}),
)
end_id = str(uuid4())
end_node = FlowNode(
id=end_id,
type="endNode",
position={"x": 1000, "y": 200},
data=FlowNodeData(id=end_id, type=EndNode.__name__),
data=FlowNodeData(id=end_id, type=EndNode.__name__, params={"name": "end"}),
)
default_nodes = [start_node.model_dump(), end_node.model_dump()]
new_pipeline = cls.objects.create(
Expand Down Expand Up @@ -136,14 +137,24 @@ def validate(self) -> dict:
from apps.pipelines.graph import PipelineGraph
from apps.pipelines.nodes import nodes as pipeline_nodes

errors = {}

for node in self.node_set.all():
errors = defaultdict(dict)
nodes = self.node_set.all()
for node in nodes:
node_class = getattr(pipeline_nodes, node.type)
try:
node_class.model_validate(node.params)
except pydantic.ValidationError as e:
errors[node.flow_id] = {err["loc"][0]: err["msg"] for err in e.errors()}

name_to_flow_id = defaultdict(list)
for node in nodes:
name_to_flow_id[node.params.get("name")].append(node.flow_id)

for name, flow_ids in name_to_flow_id.items():
if len(flow_ids) > 1:
for flow_id in flow_ids:
errors[flow_id].update({"name": "All node names must be unique"})

if errors:
return {"node": errors}

Expand Down
33 changes: 30 additions & 3 deletions apps/pipelines/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Annotated, Any, Literal, Self

from langchain_core.runnables import RunnableConfig
from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic.config import JsonDict

from apps.experiments.models import ExperimentSession
Expand All @@ -27,11 +27,25 @@ def add_messages(left: dict, right: dict):
return output


def add_shared_state_messages(left: dict, right: dict):
output = {**left}
try:
output["outputs"].update(right["outputs"])
except KeyError:
output["outputs"] = right.get("outputs", {})
for key, value in right.items():
if key != "outputs":
output[key] = value

return output


class PipelineState(dict):
messages: Annotated[Sequence[Any], operator.add]
outputs: Annotated[dict, add_messages]
experiment_session: ExperimentSession
pipeline_version: int
shared_state: Annotated[dict, add_shared_state_messages]
ai_message_id: int | None = None
message_metadata: dict | None = None
attachments: list | None = None
Expand All @@ -44,8 +58,9 @@ def json_safe(self):
return copy

@classmethod
def from_node_output(cls, node_id: str, output: Any = None, **kwargs) -> Self:
def from_node_output(cls, node_name: str, node_id: str, output: Any = None, **kwargs) -> Self:
kwargs["outputs"] = {node_id: {"message": output}}
kwargs["shared_state"] = {"outputs": {node_name: output}}
if output is not None:
kwargs["messages"] = [output]
return cls(**kwargs)
Expand Down Expand Up @@ -77,6 +92,7 @@ def _process(self, state: PipelineState) -> PipelineState:
model_config = ConfigDict(arbitrary_types_allowed=True)

_config: RunnableConfig | None = None
name: str = Field(title="Node Name", json_schema_extra={"ui:widget": "node_name"})

def process(self, node_id: str, incoming_edges: list, state: PipelineState, config) -> PipelineState:
self._config = config
Expand All @@ -92,6 +108,7 @@ def process(self, node_id: str, incoming_edges: list, state: PipelineState, conf
break
else: # This is the first node in the graph
input = state["messages"][-1]
state["shared_state"]["user_input"] = input
snopoke marked this conversation as resolved.
Show resolved Hide resolved
return self._process(input=input, state=state, node_id=node_id)

def process_conditional(self, state: PipelineState, node_id: str | None = None) -> str:
Expand All @@ -101,7 +118,7 @@ def process_conditional(self, state: PipelineState, node_id: str | None = None)
state["outputs"][node_id]["output_handle"] = output_handle
return conditional_branch

def _process(self, input: str, state: PipelineState, node_id: str) -> str:
def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
"""The method that executes node specific functionality"""
raise NotImplementedError

Expand Down Expand Up @@ -169,6 +186,14 @@ class NodeSchema(BaseModel):
can_add: bool = None
deprecated: bool = False
deprecation_message: str = None
field_order: list[str] = Field(
None,
description=(
"The order of the fields in the UI. "
"Any field not in this list will be appended to the end. "
"The 'name' field is always displayed first regardless of its position in this list."
),
)

@model_validator(mode="after")
def update_metadata_fields(self) -> Self:
Expand All @@ -190,6 +215,8 @@ def __call__(self, schema: JsonDict):
schema["ui:deprecated"] = self.deprecated
if self.deprecated and self.deprecation_message:
schema["ui:deprecation_message"] = self.deprecation_message
if self.field_order:
schema["ui:order"] = self.field_order


def deprecated_node(cls=None, *, message=None):
Expand Down
36 changes: 35 additions & 1 deletion apps/pipelines/nodes/helpers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from contextlib import contextmanager
from typing import Self

from django.contrib.contenttypes.models import ContentType
from django.db import transaction

from apps.channels.models import ChannelPlatform, ExperimentChannel
from apps.chat.models import Chat
from apps.experiments.models import ConsentForm, Experiment, ExperimentSession, Participant
from apps.experiments.models import ConsentForm, Experiment, ExperimentSession, Participant, ParticipantData
from apps.teams.models import Team
from apps.teams.utils import current_team
from apps.users.models import CustomUser
Expand All @@ -31,3 +33,35 @@ def temporary_session(team: Team, user_id: int):
)
yield experiment_session
transaction.set_rollback(True)


class ParticipantDataProxy:
"""Allows multiple access without needing to re-fetch from the DB"""

@classmethod
def from_state(cls, pipeline_state) -> Self:
# using `.get` here for the sake of tests. In practice the session should always be present
return cls(pipeline_state.get("experiment_session"))

def __init__(self, experiment_session):
self.session = experiment_session
self._participant_data = None

def _get_db_object(self):
if not self._participant_data:
content_type = ContentType.objects.get_for_model(Experiment)
self._participant_data, _ = ParticipantData.objects.get_or_create(
participant_id=self.session.participant_id,
content_type=content_type,
object_id=self.session.experiment_id,
team_id=self.session.experiment.team_id,
)
return self._participant_data
snopoke marked this conversation as resolved.
Show resolved Hide resolved

def get(self):
return self._get_db_object().data

def set(self, data):
participant_data = self._get_db_object()
participant_data.data = data
participant_data.save(update_fields=["data"])
Loading
Loading