diff --git a/apps/assistants/models.py b/apps/assistants/models.py
index 7b2f895a2..e2150507a 100644
--- a/apps/assistants/models.py
+++ b/apps/assistants/models.py
@@ -3,13 +3,14 @@
 from django.contrib.postgres.fields import ArrayField
 from django.core.validators import MaxValueValidator, MinValueValidator
 from django.db import models, transaction
+from django.db.models import Q
 from django.urls import reverse
 from field_audit import audit_fields
 from field_audit.models import AuditAction, AuditingManager

 from apps.chat.agent.tools import get_assistant_tools
 from apps.custom_actions.mixins import CustomActionOperationMixin
-from apps.experiments.models import VersionsMixin, VersionsObjectManagerMixin
+from apps.experiments.models import Experiment, VersionsMixin, VersionsObjectManagerMixin
 from apps.experiments.versioning import VersionField
 from apps.teams.models import BaseTeamModel
 from apps.utils.models import BaseModel
@@ -171,29 +172,47 @@ def compare_with_model(self, new: Self, exclude_fields: list[str], early_abort=F
     def archive(self):
         from apps.assistants.tasks import delete_openai_assistant_task

-        # don't archive assistant if it's still referenced by an active experiment or pipeline
-        if self.get_related_experiments_queryset().exists():
+        # don't archive the assistant while it is still referenced by an active experiment or pipeline
+        if self.get_related_experiments_queryset().exists() or self.get_related_pipeline_node_queryset().exists():
             return False
-
-        if self.get_related_pipeline_node_queryset().exists():
-            return False
-
-        super().archive()
         if self.is_working_version:
-            # TODO: should this delete the assistant from OpenAI?
+            # check every version first so that none are archived before a possible early return
+            for version in self.versions.all():
+                if (
+                    version.get_related_experiments_queryset().exists()
+                    or version.get_related_pipeline_node_queryset().exists()
+                ):
+                    return False
+            for version in self.versions.all():
+                delete_openai_assistant_task.delay(version.id)
             self.versions.update(is_archived=True, audit_action=AuditAction.AUDIT)
-        else:
-            delete_openai_assistant_task.delay(self.id)
+        super().archive()
+        delete_openai_assistant_task.delay(self.id)
         return True

-    def get_related_experiments_queryset(self):
-        return self.experiment_set.filter(is_archived=False)
+    def get_related_experiments_queryset(self, assistant_ids: list | None = None):
+        if assistant_ids:
+            return Experiment.objects.filter(
+                Q(working_version_id=None) | Q(is_default_version=True),
+                assistant_id__in=assistant_ids,
+                is_archived=False,
+            )
+
+        return self.experiment_set.filter(Q(working_version_id=None) | Q(is_default_version=True), is_archived=False)

-    def get_related_pipeline_node_queryset(self):
+    def get_related_pipeline_node_queryset(self, assistant_ids: list | None = None):
         from apps.pipelines.models import Node

+        if assistant_ids:
+            return Node.objects.filter(type="AssistantNode").filter(
+                Q(pipeline__working_version_id=None),
+                params__assistant_id__in=assistant_ids,
+                pipeline__is_archived=False,
+            )
         return Node.objects.filter(type="AssistantNode").filter(
+            Q(pipeline__working_version_id=None),
             params__assistant_id=str(self.id),
             pipeline__is_archived=False,
         )
diff --git a/apps/assistants/tests/test_delete.py b/apps/assistants/tests/test_delete.py
index 2c8485116..1ac176567 100644
--- a/apps/assistants/tests/test_delete.py
+++ b/apps/assistants/tests/test_delete.py
@@ -1,12 +1,14 @@
 import uuid
-from unittest.mock import Mock
+from unittest.mock import Mock, patch

 import pytest

 from apps.assistants.models import ToolResources
 from apps.assistants.sync import _get_files_to_delete, delete_openai_files_for_resource
 from apps.utils.factories.assistants import OpenAiAssistantFactory
+from apps.utils.factories.experiment import ExperimentFactory
 from apps.utils.factories.files import FileFactory
+from apps.utils.factories.pipelines import NodeFactory, PipelineFactory


 @pytest.fixture()
@@ -28,36 +30,126 @@ def code_resource(assistant):


 @pytest.mark.django_db()
-def test_files_to_delete_when_only_referenced_by_one_resource(code_resource):
-    files_to_delete = list(_get_files_to_delete(code_resource.assistant.team, code_resource.id))
-    assert len(files_to_delete) == 2
-    assert {f.id for f in files_to_delete} == {f.id for f in code_resource.files.all()}
+class TestAssistantDeletion:
+    def test_files_to_delete_when_only_referenced_by_one_resource(self, code_resource):
+        files_to_delete = list(_get_files_to_delete(code_resource.assistant.team, code_resource.id))
+        assert len(files_to_delete) == 2
+        assert {f.id for f in files_to_delete} == {f.id for f in code_resource.files.all()}
+
+    def test_files_not_to_delete_when_referenced_by_multiple_resources(self, code_resource):
+        all_files = list(code_resource.files.all())
+        tool_resource = ToolResources.objects.create(tool_type="file_search", assistant=code_resource.assistant)
+        tool_resource.files.set([all_files[0]])

-@pytest.mark.django_db()
-def test_files_not_to_delete_when_referenced_by_multiple_resources(code_resource):
-    all_files = list(code_resource.files.all())
-    tool_resource = ToolResources.objects.create(tool_type="file_search", assistant=code_resource.assistant)
-    tool_resource.files.set([all_files[0]])
+        # only the second file should be deleted
+        files_to_delete = list(_get_files_to_delete(code_resource.assistant.team, code_resource.id))
+        assert len(files_to_delete) == 1
+        assert files_to_delete[0].id == all_files[1].id
+
+        files_to_delete = list(_get_files_to_delete(tool_resource.assistant.team, tool_resource.id))
+        assert len(files_to_delete) == 0
+
+    def test_delete_openai_files_for_resource(self, code_resource):
+        all_files = list(code_resource.files.all())
+        assert all(f.external_id for f in all_files)
+        assert all(f.external_source for f in all_files)
+        client = Mock()
+        delete_openai_files_for_resource(client, code_resource.assistant.team, code_resource)
+
+        assert client.files.delete.call_count == 2
+        all_files = list(code_resource.files.all())
+        assert not any(f.external_id for f in all_files)
+        assert not any(f.external_source for f in all_files)

-    # only the second file should be deleted
-    files_to_delete = list(_get_files_to_delete(code_resource.assistant.team, code_resource.id))
-    assert len(files_to_delete) == 1
-    assert files_to_delete[0].id == all_files[1].id
-
-    files_to_delete = list(_get_files_to_delete(tool_resource.assistant.team, tool_resource.id))
-    assert len(files_to_delete) == 0


 @pytest.mark.django_db()
-def test_delete_openai_files_for_resource(code_resource):
-    all_files = list(code_resource.files.all())
-    assert all(f.external_id for f in all_files)
-    assert all(f.external_source for f in all_files)
-    client = Mock()
-    delete_openai_files_for_resource(client, code_resource.assistant.team, code_resource)
-
-    assert client.files.delete.call_count == 2
-    all_files = list(code_resource.files.all())
-    assert not any(f.external_id for f in all_files)
-    assert not any(f.external_source for f in all_files)
+class TestAssistantArchival:
+    def test_archive_assistant(self):
+        assistant = OpenAiAssistantFactory()
+        assert assistant.is_archived is False
+        assistant.archive()
+        assert assistant.is_archived is True
+
+    @patch("apps.assistants.sync.push_assistant_to_openai", Mock())
+    def test_archive_assistant_succeeds_with_released_related_experiment(self):
+        exp_v1 = ExperimentFactory()
+        exp_v1.create_new_version()
+        exp_v1.save()
+        assistant = OpenAiAssistantFactory()
+        exp_v2 = exp_v1.create_new_version()
+        exp_v2.assistant = assistant
+        exp_v2.save()
+        assert exp_v2.is_default_version is False
+        assert exp_v2.is_working_version is False
+        assistant.archive()
+        assistant.refresh_from_db()
+
+        assert assistant.is_archived is True  # archiving succeeded
+
+    @patch("apps.assistants.sync.push_assistant_to_openai", Mock())
+    def test_assistant_archive_blocked_by_working_related_experiment(self):
+        assistant = OpenAiAssistantFactory()
+        experiment = ExperimentFactory(assistant=assistant)
+        experiment.save()
+
+        assert experiment.is_working_version is True
+        assistant.archive()
+        assistant.refresh_from_db()
+        assert assistant.is_archived is False  # archiving blocked
+
+    @patch("apps.assistants.sync.push_assistant_to_openai", Mock())
+    def test_assistant_archive_blocked_by_published_related_experiment(self):
+        assistant = OpenAiAssistantFactory()
+        v2_assistant = assistant.create_new_version()
+        experiment = ExperimentFactory(assistant=v2_assistant)
+        experiment.is_default_version = True
+        experiment.save()
+
+        assistant.archive()
+        assert assistant.is_archived is False  # archiving blocked
+        assert v2_assistant.is_archived is False
+
+        experiment.archive()  # archive the experiment that references v2_assistant first
+        assistant.archive()
+        v2_assistant.refresh_from_db()
+
+        assert assistant.is_archived is True  # archiving succeeded
+        assert v2_assistant.is_archived is True
+
+    @patch("apps.assistants.sync.push_assistant_to_openai", Mock())
+    def test_archive_assistant_succeeds_with_released_related_pipeline(self):
+        pipeline = PipelineFactory()
+        exp_v1 = ExperimentFactory(pipeline=pipeline)
+        exp_v2 = exp_v1.create_new_version()
+        assistant = OpenAiAssistantFactory()
+        NodeFactory(pipeline=exp_v2.pipeline, type="AssistantNode", params={"assistant_id": str(assistant.id)})
+        exp_v2.is_default_version = False
+        exp_v2.save()
+
+        assert exp_v2.pipeline.is_working_version is False
+        assistant.archive()
+        assistant.refresh_from_db()
+
+        assert assistant.is_archived is True  # archiving succeeded
+
+    @patch("apps.assistants.sync.push_assistant_to_openai", Mock())
+    def test_archive_assistant_fails_with_working_related_pipeline(self):
+        pipeline = PipelineFactory()
+        assistant = OpenAiAssistantFactory()
+        NodeFactory(pipeline=pipeline, type="AssistantNode", params={"assistant_id": str(assistant.id)})
+        exp_v1 = ExperimentFactory(pipeline=pipeline)
+        exp_v1.save()
+
+        assert pipeline.is_working_version is True
+        assistant.archive()
+        assert assistant.is_archived is False  # archiving blocked
+
+        exp_v1.archive()
+        pipeline.archive()
+        assistant.archive()
+
+        assert pipeline.is_archived is True
+        assert assistant.is_archived is True  # archiving succeeded
diff --git a/apps/assistants/views.py b/apps/assistants/views.py
index 7cb1d0674..0007c1a34 100644
--- a/apps/assistants/views.py
+++ b/apps/assistants/views.py
@@ -1,6 +1,7 @@
 from django.contrib import messages
 from django.contrib.auth.mixins import PermissionRequiredMixin
 from django.db import transaction
+from django.db.models import Q
 from django.http import HttpResponse, HttpResponseRedirect
 from django.shortcuts import get_object_or_404, render
 from django.template.loader import render_to_string
@@ -212,13 +213,23 @@ def delete(self, request, team_slug: str, pk: int):
             messages.success(request, "Assistant Archived")
             return HttpResponse()
         else:
+            version_query = None
+            if assistant.is_working_version:
+                # include the working assistant and all of its versions when looking up references
+                version_query = list(
+                    map(
+                        str,
+                        OpenAiAssistant.objects.filter(
+                            Q(id=assistant.id) | Q(working_version__id=assistant.id)
+                        ).values_list("id", flat=True),
+                    )
+                )
             experiments = [
                 Chip(label=experiment.name, url=experiment.get_absolute_url())
-                for experiment in assistant.get_related_experiments_queryset()
+                for experiment in assistant.get_related_experiments_queryset(assistant_ids=version_query)
             ]
             pipeline_nodes = [
                 Chip(label=node.pipeline.name, url=node.pipeline.get_absolute_url())
-                for node in assistant.get_related_pipeline_node_queryset().select_related("pipeline")
+                for node in assistant.get_related_pipeline_node_queryset(assistant_ids=version_query).select_related(
+                    "pipeline"
+                )
             ]
             response = render_to_string(
                 "assistants/partials/referenced_objects.html",
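Usage note (not part of the patch): a minimal sketch of how the updated helpers are intended to be combined, assuming the models.py and views.py changes above are applied as-is. The helper name related_objects_blocking_archive is illustrative only; the id list mirrors version_query in views.py.

from django.db.models import Q

from apps.assistants.models import OpenAiAssistant


def related_objects_blocking_archive(assistant: OpenAiAssistant):
    # For a working assistant, look up references against the working record and every
    # one of its versions; for a single version, fall back to the per-instance querysets.
    assistant_ids = None
    if assistant.is_working_version:
        assistant_ids = [
            str(pk)
            for pk in OpenAiAssistant.objects.filter(
                Q(id=assistant.id) | Q(working_version__id=assistant.id)
            ).values_list("id", flat=True)
        ]
    # Either queryset being non-empty means archive() will return False.
    experiments = assistant.get_related_experiments_queryset(assistant_ids=assistant_ids)
    pipeline_nodes = assistant.get_related_pipeline_node_queryset(assistant_ids=assistant_ids)
    return experiments, pipeline_nodes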