From b8ce89c72da64359c06021c6a9765fbea72841b2 Mon Sep 17 00:00:00 2001 From: Takehiro Suzuki Date: Tue, 14 Jan 2025 14:30:11 +0900 Subject: [PATCH] Fix: active_models validation causing 422 errors (#684) * fix * chore: lint --- backend/app/repositories/custom_bot.py | 19 ++++++-- backend/app/repositories/models/custom_bot.py | 5 +++ backend/app/usecases/bot.py | 7 ++- backend/tests/test_usecases/test_bot.py | 45 ++++++++++++++----- 4 files changed, 59 insertions(+), 17 deletions(-) diff --git a/backend/app/repositories/custom_bot.py b/backend/app/repositories/custom_bot.py index 8909936b9..de8b83502 100644 --- a/backend/app/repositories/custom_bot.py +++ b/backend/app/repositories/custom_bot.py @@ -30,6 +30,7 @@ ConversationQuickStarterModel, GenerationParamsModel, KnowledgeModel, + default_active_models, ) from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel @@ -490,7 +491,11 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel: if "GuardrailsParams" in item else None ), - active_models=ActiveModelsModel.model_validate(item.get("ActiveModels", {})), + active_models=( + ActiveModelsModel.model_validate(item.get("ActiveModels")) + if item.get("ActiveModels") + else default_active_models # for backward compatibility + ), ) logger.info(f"Found bot: {bot}") @@ -568,7 +573,11 @@ def find_public_bot_by_id(bot_id: str) -> BotModel: if "GuardrailsParams" in item else None ), - active_models=ActiveModelsModel.model_validate(item.get("ActiveModels")), + active_models=( + ActiveModelsModel.model_validate(item.get("ActiveModels")) + if item.get("ActiveModels") + else default_active_models # for backward compatibility + ), ) logger.info(f"Found public bot: {bot}") return bot @@ -598,7 +607,11 @@ def find_alias_by_id(user_id: str, alias_id: str) -> BotAliasModel: has_knowledge=item["HasKnowledge"], has_agent=item.get("HasAgent", False), conversation_quick_starters=item.get("ConversationQuickStarters", []), - active_models=ActiveModelsModel.model_validate(item.get("ActiveModels")), + active_models=( + ActiveModelsModel.model_validate(item.get("ActiveModels")) + if item.get("ActiveModels") + else default_active_models # for backward compatibility + ), ) logger.info(f"Found alias: {bot}") diff --git a/backend/app/repositories/models/custom_bot.py b/backend/app/repositories/models/custom_bot.py index a0892c4d7..c65d5b081 100644 --- a/backend/app/repositories/models/custom_bot.py +++ b/backend/app/repositories/models/custom_bot.py @@ -20,6 +20,11 @@ def _create_model_activate_model(model_names: List[str]) -> Type[DynamicBaseMode ) +default_active_models = ActiveModelsModel.model_validate( + {field_name: True for field_name in ActiveModelsModel.model_fields.keys()} +) + + class KnowledgeModel(BaseModel): source_urls: list[str] sitemap_urls: list[str] diff --git a/backend/app/usecases/bot.py b/backend/app/usecases/bot.py index 617b1f19f..9e7a5ebd5 100644 --- a/backend/app/usecases/bot.py +++ b/backend/app/usecases/bot.py @@ -537,7 +537,8 @@ def fetch_all_bots_by_user_id( ConversationQuickStarter(**starter) for starter in item.get("ConversationQuickStarters", []) ] - or bot.active_models != item["ActiveModels"] + or bot.active_models + != ActiveModelsModel.model_validate(dict(item.get("ActiveModels", {}))) ): # Update alias to the latest original bot store_alias( @@ -719,9 +720,7 @@ def fetch_bot_summary(user_id: str, bot_id: str) -> BotSummaryOutput: ) for starter in bot.conversation_quick_starters ], - active_models=ActiveModelsOutput.model_validate( - dict(bot.active_models) - ), + active_models=bot.active_models, ), ) return BotSummaryOutput( diff --git a/backend/tests/test_usecases/test_bot.py b/backend/tests/test_usecases/test_bot.py index 76071e103..a263a0f14 100644 --- a/backend/tests/test_usecases/test_bot.py +++ b/backend/tests/test_usecases/test_bot.py @@ -3,14 +3,6 @@ sys.path.insert(0, ".") import unittest -from pydantic import BaseModel - -from tests.test_usecases.utils.bot_factory import ( - create_test_bot_alias, - create_test_private_bot, - create_test_public_bot, -) - from app.repositories.custom_bot import ( delete_alias_by_id, delete_bot_by_id, @@ -21,8 +13,17 @@ update_bot_publication, update_bot_visibility, ) - -from app.usecases.bot import fetch_all_bots_by_user_id, issue_presigned_url +from app.usecases.bot import ( + fetch_all_bots_by_user_id, + fetch_bot_summary, + issue_presigned_url, +) +from pydantic import BaseModel +from tests.test_usecases.utils.bot_factory import ( + create_test_bot_alias, + create_test_private_bot, + create_test_public_bot, +) class TestIssuePresignedUrl(unittest.TestCase): @@ -142,5 +143,29 @@ def test_order_is_descending(self): self.assertEqual(bots[5].id, self.first_bot_id) +class TestSharing(unittest.TestCase): + def setUp(self) -> None: + self.publisher_id = "test_user_pub" + self.subscriber_id = "test_user_sub" + + self.bot = create_test_public_bot("test_bot", True, self.publisher_id) + store_bot(self.publisher_id, self.bot) + + def tearDown(self) -> None: + delete_bot_by_id(self.publisher_id, self.bot.id) + try: + delete_alias_by_id(self.subscriber_id, self.bot.id) + except: + print("Alias not found") + + def test_share_and_subscribe(self): + # Share the bot to public + update_bot_visibility(self.publisher_id, self.bot.id, True) + + # Subscribe (equal to open shared URL on browser) + bot_summary = fetch_bot_summary(self.subscriber_id, self.bot.id) + self.assertEqual(bot_summary.id, self.bot.id) + + if __name__ == "__main__": unittest.main()