Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typing of task decorator for retry_condition_fn argument #16621

Merged

Conversation

peterbygrave
Copy link
Contributor

In migrating to prefect v3, we are seeing issues with mypy checking of the task decorators. For example when making the decorator a call e.g. @task() or @task(name="foo") you end up with:

error: Argument 1 has incompatible type "Callable[[int, int], int]"; expected "Callable[[VarArg(Never), KwArg(Never)], int]"  [arg-type]

I found that it was the typing of retry_condition_fn that flakes out because the input typing should connected to the return type, but nothing can be inferred because nothing was given.

In this PR I add an overload to differentiate between None and an actual function is passed.

Checklist

  • This pull request references any related issue by including "closes <link to issue>"
    • If no issue exists and your change is not a small fix, please create an issue first.
  • If this pull request adds new functionality, it includes unit tests that cover the changes
  • If this pull request removes docs files, it includes redirect settings in mint.json.
  • If this pull request adds functions or classes, it includes helpful docstrings.

refresh_cache: Optional[bool] = None,
on_completion: Optional[list[StateHookCallable]] = None,
on_failure: Optional[list[StateHookCallable]] = None,
retry_condition_fn: Optional[Callable[[Task[P, R], TaskRun, State], bool]] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
retry_condition_fn: Optional[Callable[[Task[P, R], TaskRun, State], bool]] = None,
retry_condition_fn: Optional[Callable[[Task[P, R], TaskRun, State], bool]] = None,

Note this previously was Task[P, Any] but now has R because it should exactly match. The overload above now matches first when None is given (or default is used).

@@ -1634,7 +1634,43 @@ def task(
refresh_cache: Optional[bool] = None,
on_completion: Optional[list[StateHookCallable]] = None,
on_failure: Optional[list[StateHookCallable]] = None,
retry_condition_fn: Optional[Callable[[Task[P, Any], TaskRun, State], bool]] = None,
retry_condition_fn: Literal[None] = None,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the new thing that matches first when it is None and doesn't interact with returned P and R.

Comment on lines +2 to +11
- case: prefect_task_decorator_no_args
main: |
from prefect import task
@task
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""
Copy link
Contributor Author

@peterbygrave peterbygrave Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works on main branch

Comment on lines +13 to +44
- case: prefect_task_decorator_call_with_no_args
main: |
from prefect import task
@task()
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

- case: prefect_task_decorator_with_name_arg
main: |
from prefect import task
@task(name="bar")
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""

- case: prefect_task_decorator_with_retry_condition_fn_as_none_arg
main: |
from prefect.tasks import task
@task(retry_condition_fn=None)
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:5: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These 3 all fail on main branch

Copy link

codspeed-hq bot commented Jan 6, 2025

CodSpeed Performance Report

Merging #16621 will not alter performance

Comparing peterbygrave:fix-tasks-overload-retry-condition-fn (cab5183) with main (16e85ce)

Summary

✅ 2 untouched benchmarks

Comment on lines +46 to +59
- case: prefect_task_decorator_with_retry_condition_fn_arg
main: |
from prefect.tasks import P, R, Task, task
from prefect.client.schemas import TaskRun
from prefect.states import State
def retry_condition_fn(task: Task[P, R], task_run: TaskRun, state: State) -> bool:
return False
@task(retry_condition_fn=retry_condition_fn)
def foo(bar: str) -> int:
return 42
reveal_type(foo)
out: "main:9: note: Revealed type is \"\
prefect.tasks.Task[[bar: builtins.str], builtins.int]\
\""
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works on main branch

@zzstoatzz
Copy link
Collaborator

hi @peterbygrave - thanks for the PR! can you provide your version of prefect and a specific code example?

for example, I am not seeing pyright have any issues with this on main, though I can see mypy warns as you mention

from typing import Any, reveal_type

from prefect import Task, task
from prefect.cache_policies import NONE
from prefect.client.schemas.objects import State, TaskRun


def state_hook(task: Task[Any, Any], task_run: TaskRun, state: State):
    print(f"Task run {task_run.name!r} from {task.name!r} entered state {state.name!r}")


@task(retries=1, persist_result=False, on_completion=[state_hook])
def identity(x: int) -> int:
    return x


@task(persist_result=False)
def identity_no_cache(x: int) -> int:
    return x


@task(cache_policy=NONE)
def returns_something(x: int) -> int:
    return x


@task(
    name="conditionally_retries",
    retry_condition_fn=lambda task, task_run, state: True,
    retry_delay_seconds=[1, 2],
)
def conditionally_retries(x: int) -> int:
    return x


@task
def also_returns_something(x: int) -> int:
    return x


if __name__ == "__main__":
    value = returns_something(1)
    something = identity(value)
    something_else = identity_no_cache(value)
    other_value = also_returns_something(1)
    result = conditionally_retries(1)

    reveal_type(value)
    reveal_type(something)
    reveal_type(something_else)
    reveal_type(other_value)
    reveal_type(result)

when I checkout your branch, this result (from the task with a retry_condition_fn) becomes Unknown

image

@peterbygrave
Copy link
Contributor Author

hi @peterbygrave - thanks for the PR! can you provide your version of prefect and a specific code example?

We are just doing our v2 to v3 migration so first saw this on v3.1.11, but has issues on main branch.

The code example I have are in the tests. Its all isolated to retry_condition_fn because it is capable of narrowing the P and R.

for example, I am not seeing pyright have any issues with this on main, though I can see mypy warns as you mention

from typing import Any, reveal_type

from prefect import Task, task
from prefect.cache_policies import NONE
from prefect.client.schemas.objects import State, TaskRun


def state_hook(task: Task[Any, Any], task_run: TaskRun, state: State):
    print(f"Task run {task_run.name!r} from {task.name!r} entered state {state.name!r}")


@task(retries=1, persist_result=False, on_completion=[state_hook])
def identity(x: int) -> int:
    return x


@task(persist_result=False)
def identity_no_cache(x: int) -> int:
    return x


@task(cache_policy=NONE)
def returns_something(x: int) -> int:
    return x


@task(
    name="conditionally_retries",
    retry_condition_fn=lambda task, task_run, state: True,
    retry_delay_seconds=[1, 2],
)
def conditionally_retries(x: int) -> int:
    return x


@task
def also_returns_something(x: int) -> int:
    return x


if __name__ == "__main__":
    value = returns_something(1)
    something = identity(value)
    something_else = identity_no_cache(value)
    other_value = also_returns_something(1)
    result = conditionally_retries(1)

    reveal_type(value)
    reveal_type(something)
    reveal_type(something_else)
    reveal_type(other_value)
    reveal_type(result)

when I checkout your branch, this result (from the task with a retry_condition_fn) becomes Unknown

image

So the on_completion arg type does not try to infer the P and R generics (its just Task, no generics) for the returned Task type. But retry_condition_fn does (because it has Task[P, R]. So in your example you pass a lambda in which has no typing at all, which on first pass through the decorator destroys the typing in the returned partial. So when the function is passed through the decorator/partial again it can't infer the P and R of the return Task anymore. My guess would be if you want to support lambdas, then you'd need to modify the typing to something like:

retry_condition_fn: Optional[Callable[[Task, TaskRun, State], bool]]

I'm not sure if that would need to be in additional or replacement of. If replacement it would probably exclude the ability to have stronger type checking - in that one wouldn't be able to check you are handling a task with the right args and return type (if you wanted to depend on them).

Copy link
Collaborator

@zzstoatzz zzstoatzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry for the delay here @peterbygrave - I've verified that this still passes mypy after recent typing changes on main so this LGTM!

thank you!

@zzstoatzz zzstoatzz merged commit 07b1cab into PrefectHQ:main Jan 17, 2025
44 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants