Skip to content

Commit

Permalink
Fix: active_models validation causing 422 errors (#684)
Browse files Browse the repository at this point in the history
* fix

* chore: lint
  • Loading branch information
statefb authored Jan 14, 2025
1 parent 27658bc commit b8ce89c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 17 deletions.
19 changes: 16 additions & 3 deletions backend/app/repositories/custom_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
ConversationQuickStarterModel,
GenerationParamsModel,
KnowledgeModel,
default_active_models,
)
from app.repositories.models.custom_bot_guardrails import BedrockGuardrailsModel
from app.repositories.models.custom_bot_kb import BedrockKnowledgeBaseModel
Expand Down Expand Up @@ -490,7 +491,11 @@ def find_private_bot_by_id(user_id: str, bot_id: str) -> BotModel:
if "GuardrailsParams" in item
else None
),
active_models=ActiveModelsModel.model_validate(item.get("ActiveModels", {})),
active_models=(
ActiveModelsModel.model_validate(item.get("ActiveModels"))
if item.get("ActiveModels")
else default_active_models # for backward compatibility
),
)

logger.info(f"Found bot: {bot}")
Expand Down Expand Up @@ -568,7 +573,11 @@ def find_public_bot_by_id(bot_id: str) -> BotModel:
if "GuardrailsParams" in item
else None
),
active_models=ActiveModelsModel.model_validate(item.get("ActiveModels")),
active_models=(
ActiveModelsModel.model_validate(item.get("ActiveModels"))
if item.get("ActiveModels")
else default_active_models # for backward compatibility
),
)
logger.info(f"Found public bot: {bot}")
return bot
Expand Down Expand Up @@ -598,7 +607,11 @@ def find_alias_by_id(user_id: str, alias_id: str) -> BotAliasModel:
has_knowledge=item["HasKnowledge"],
has_agent=item.get("HasAgent", False),
conversation_quick_starters=item.get("ConversationQuickStarters", []),
active_models=ActiveModelsModel.model_validate(item.get("ActiveModels")),
active_models=(
ActiveModelsModel.model_validate(item.get("ActiveModels"))
if item.get("ActiveModels")
else default_active_models # for backward compatibility
),
)

logger.info(f"Found alias: {bot}")
Expand Down
5 changes: 5 additions & 0 deletions backend/app/repositories/models/custom_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def _create_model_activate_model(model_names: List[str]) -> Type[DynamicBaseMode
)


default_active_models = ActiveModelsModel.model_validate(
{field_name: True for field_name in ActiveModelsModel.model_fields.keys()}
)


class KnowledgeModel(BaseModel):
source_urls: list[str]
sitemap_urls: list[str]
Expand Down
7 changes: 3 additions & 4 deletions backend/app/usecases/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,8 @@ def fetch_all_bots_by_user_id(
ConversationQuickStarter(**starter)
for starter in item.get("ConversationQuickStarters", [])
]
or bot.active_models != item["ActiveModels"]
or bot.active_models
!= ActiveModelsModel.model_validate(dict(item.get("ActiveModels", {})))
):
# Update alias to the latest original bot
store_alias(
Expand Down Expand Up @@ -719,9 +720,7 @@ def fetch_bot_summary(user_id: str, bot_id: str) -> BotSummaryOutput:
)
for starter in bot.conversation_quick_starters
],
active_models=ActiveModelsOutput.model_validate(
dict(bot.active_models)
),
active_models=bot.active_models,
),
)
return BotSummaryOutput(
Expand Down
45 changes: 35 additions & 10 deletions backend/tests/test_usecases/test_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@
sys.path.insert(0, ".")
import unittest

from pydantic import BaseModel

from tests.test_usecases.utils.bot_factory import (
create_test_bot_alias,
create_test_private_bot,
create_test_public_bot,
)

from app.repositories.custom_bot import (
delete_alias_by_id,
delete_bot_by_id,
Expand All @@ -21,8 +13,17 @@
update_bot_publication,
update_bot_visibility,
)

from app.usecases.bot import fetch_all_bots_by_user_id, issue_presigned_url
from app.usecases.bot import (
fetch_all_bots_by_user_id,
fetch_bot_summary,
issue_presigned_url,
)
from pydantic import BaseModel
from tests.test_usecases.utils.bot_factory import (
create_test_bot_alias,
create_test_private_bot,
create_test_public_bot,
)


class TestIssuePresignedUrl(unittest.TestCase):
Expand Down Expand Up @@ -142,5 +143,29 @@ def test_order_is_descending(self):
self.assertEqual(bots[5].id, self.first_bot_id)


class TestSharing(unittest.TestCase):
def setUp(self) -> None:
self.publisher_id = "test_user_pub"
self.subscriber_id = "test_user_sub"

self.bot = create_test_public_bot("test_bot", True, self.publisher_id)
store_bot(self.publisher_id, self.bot)

def tearDown(self) -> None:
delete_bot_by_id(self.publisher_id, self.bot.id)
try:
delete_alias_by_id(self.subscriber_id, self.bot.id)
except:
print("Alias not found")

def test_share_and_subscribe(self):
# Share the bot to public
update_bot_visibility(self.publisher_id, self.bot.id, True)

# Subscribe (equal to open shared URL on browser)
bot_summary = fetch_bot_summary(self.subscriber_id, self.bot.id)
self.assertEqual(bot_summary.id, self.bot.id)


if __name__ == "__main__":
unittest.main()

0 comments on commit b8ce89c

Please sign in to comment.