diff --git a/meldingen_core/actions/melding.py b/meldingen_core/actions/melding.py index df83c18..0646487 100644 --- a/meldingen_core/actions/melding.py +++ b/meldingen_core/actions/melding.py @@ -98,6 +98,11 @@ async def __call__(self, pk: int, values: dict[str, Any], token: str) -> T: class BaseStateTransitionAction(Generic[T, T_co], metaclass=ABCMeta): + """ + This action covers transitions that do not require the melding's token to be verified. + Typically these actions are performed by authenticated users. + """ + _state_machine: BaseMeldingStateMachine[T] _repository: BaseMeldingRepository[T, T_co] @@ -124,7 +129,12 @@ async def __call__(self, melding_id: int) -> T: return melding -class MeldingAnswerQuestionsAction(Generic[T, T_co]): +class BaseMeldingFormStateTransitionAction(Generic[T, T_co], metaclass=ABCMeta): + """ + This action covers transitions that require the melding's token to be verified. + This is the case for unauthenticated state transitions where a user submits a melding. + """ + _state_machine: BaseMeldingStateMachine[T] _repository: BaseMeldingRepository[T, T_co] _verify_token: TokenVerifier[T, T_co] @@ -139,37 +149,35 @@ def __init__( self._repository = repository self._verify_token = token_verifier + @property + @abstractmethod + def transition_name(self) -> str: ... + async def __call__(self, melding_id: int, token: str) -> T: melding = await self._verify_token(melding_id, token) - await self._state_machine.transition(melding, MeldingTransitions.ANSWER_QUESTIONS) + await self._state_machine.transition(melding, self.transition_name) await self._repository.save(melding) return melding -class MeldingAddAttachmentsAction(Generic[T, T_co]): - _state_machine: BaseMeldingStateMachine[T] - _repository: BaseMeldingRepository[T, T_co] - _verify_token: TokenVerifier[T, T_co] +class MeldingAnswerQuestionsAction(BaseMeldingFormStateTransitionAction[T, T_co]): + @property + def transition_name(self) -> str: + return MeldingTransitions.ANSWER_QUESTIONS - def __init__( - self, - state_machine: BaseMeldingStateMachine[T], - repository: BaseMeldingRepository[T, T_co], - token_verifier: TokenVerifier[T, T_co], - ): - self._state_machine = state_machine - self._repository = repository - self._verify_token = token_verifier - async def __call__(self, melding_id: int, token: str) -> T: - melding = await self._verify_token(melding_id, token) +class MeldingAddAttachmentsAction(BaseMeldingFormStateTransitionAction[T, T_co]): + @property + def transition_name(self) -> str: + return MeldingTransitions.ADD_ATTACHMENTS - await self._state_machine.transition(melding, MeldingTransitions.ADD_ATTACHMENTS) - await self._repository.save(melding) - return melding +class MeldingSubmitLocationAction(BaseMeldingFormStateTransitionAction[T, T_co]): + @property + def transition_name(self) -> str: + return MeldingTransitions.SUBMIT_LOCATION class MeldingProcessAction(BaseStateTransitionAction[T, T_co]): diff --git a/meldingen_core/statemachine.py b/meldingen_core/statemachine.py index 32aad92..c7914ed 100644 --- a/meldingen_core/statemachine.py +++ b/meldingen_core/statemachine.py @@ -12,6 +12,7 @@ class MeldingStates(StrEnum): CLASSIFIED = "classified" QUESTIONS_ANSWERED = "questions_answered" ATTACHMENTS_ADDED = "attachments_added" + LOCATION_SUBMITTED = "location_submitted" PROCESSING = "processing" COMPLETED = "completed" @@ -21,6 +22,7 @@ class MeldingTransitions(StrEnum): CLASSIFY = "classify" ANSWER_QUESTIONS = "answer_questions" ADD_ATTACHMENTS = "add_attachments" + SUBMIT_LOCATION = "submit_location" COMPLETE = "complete" diff --git a/tests/test_actions/test_melding_actions.py b/tests/test_actions/test_melding_actions.py index 471a45e..6e166a1 100644 --- a/tests/test_actions/test_melding_actions.py +++ b/tests/test_actions/test_melding_actions.py @@ -12,6 +12,7 @@ MeldingListAction, MeldingProcessAction, MeldingRetrieveAction, + MeldingSubmitLocationAction, MeldingUpdateAction, ) from meldingen_core.classification import ClassificationNotFoundException, Classifier @@ -137,7 +138,7 @@ async def test_process_action_not_found() -> None: @pytest.mark.anyio -async def test_add_attachments_actions() -> None: +async def test_add_attachments_action() -> None: repository = Mock(BaseMeldingRepository) repo_melding = Melding("melding text") repository.retrieve.return_value = repo_melding @@ -156,6 +157,20 @@ async def test_add_attachments_actions() -> None: repository.save.assert_called_once_with(repo_melding) +@pytest.mark.anyio +async def test_add_attachments_action_not_found() -> None: + repository = Mock(BaseMeldingRepository) + repository.retrieve.return_value = None + token_verifier: TokenVerifier[Melding, Melding] = TokenVerifier(repository) + + process: MeldingAddAttachmentsAction[Melding, Melding] = MeldingAddAttachmentsAction( + Mock(BaseMeldingStateMachine), Mock(BaseMeldingRepository), token_verifier + ) + + with pytest.raises(NotFoundException): + await process(1, "token") + + @pytest.mark.anyio async def test_complete_action() -> None: state_machine = Mock(BaseMeldingStateMachine) @@ -180,3 +195,37 @@ async def test_complete_action_not_found() -> None: with pytest.raises(NotFoundException): await process(1) + + +@pytest.mark.anyio +async def test_submit_location_action() -> None: + repository = Mock(BaseMeldingRepository) + repo_melding = Melding("melding text") + repository.retrieve.return_value = repo_melding + state_machine = Mock(BaseMeldingStateMachine) + token_verifier = AsyncMock(TokenVerifier) + token_verifier.return_value = repo_melding + + submit_location: MeldingSubmitLocationAction[Melding, Melding] = MeldingSubmitLocationAction( + state_machine, repository, token_verifier + ) + + melding = await submit_location(1, "token") + + assert melding == repo_melding + state_machine.transition.assert_called_once_with(repo_melding, MeldingTransitions.SUBMIT_LOCATION) + repository.save.assert_called_once_with(repo_melding) + + +@pytest.mark.anyio +async def test_submit_location_action_not_found() -> None: + repository = Mock(BaseMeldingRepository) + repository.retrieve.return_value = None + token_verifier: TokenVerifier[Melding, Melding] = TokenVerifier(repository) + + process: MeldingSubmitLocationAction[Melding, Melding] = MeldingSubmitLocationAction( + Mock(BaseMeldingStateMachine), Mock(BaseMeldingRepository), token_verifier + ) + + with pytest.raises(NotFoundException): + await process(1, "token")