feat(fal): don't scale by default on deploy (#338)
* feat(fal): don't scale by default on deploy

* add test
efiop authored Oct 22, 2024
1 parent 0007191 commit 05beae0
Showing 5 changed files with 61 additions and 5 deletions.
2 changes: 1 addition & 1 deletion projects/fal/pyproject.toml
@@ -23,7 +23,7 @@ readme = "README.md"
requires-python = ">=3.8"
dependencies = [
"isolate[build]>=0.13.0,<1.14.0",
"isolate-proto==0.5.4",
"isolate-proto==0.5.5",
"grpcio==1.64.0",
"dill==0.3.7",
"cloudpickle==3.0.0",
10 changes: 6 additions & 4 deletions projects/fal/src/fal/api.py
@@ -171,6 +171,7 @@ def register(
application_name: str | None = None,
application_auth_mode: Literal["public", "shared", "private"] | None = None,
metadata: dict[str, Any] | None = None,
scale: bool = True,
) -> str | None:
"""Register the given function on the host for API call execution."""
raise NotImplementedError
@@ -430,6 +431,7 @@ def register(
application_auth_mode: Literal["public", "shared", "private"] | None = None,
metadata: dict[str, Any] | None = None,
deployment_strategy: Literal["recreate", "rolling"] = "recreate",
scale: bool = True,
) -> str | None:
environment_options = options.environment.copy()
environment_options.setdefault("python_version", active_python())
@@ -439,15 +441,14 @@
"machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
)
keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
max_concurrency = options.host.get("max_concurrency")
min_concurrency = options.host.get("min_concurrency")
max_multiplexing = options.host.get("max_multiplexing")
base_image = options.host.get("_base_image", None)
scheduler = options.host.get("_scheduler", None)
scheduler_options = options.host.get("_scheduler_options", None)
max_concurrency = options.host.get("max_concurrency")
min_concurrency = options.host.get("min_concurrency")
max_multiplexing = options.host.get("max_multiplexing")
exposed_port = options.get_exposed_port()
request_timeout = options.host.get("request_timeout")

machine_requirements = MachineRequirements(
machine_types=machine_type, # type: ignore
num_gpus=options.host.get("num_gpus"),
@@ -486,6 +487,7 @@ def register(
machine_requirements=machine_requirements,
metadata=metadata,
deployment_strategy=deployment_strategy,
scale=scale,
):
for log in partial_result.logs:
self._log_printer.print(log)
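As a usage sketch of the new keyword: scale defaults to True, so existing callers of FalServerlessHost.register keep the current behavior, while scale=False asks the server to retain the scaling settings of the previous deployment under the same name. The snippet below is not part of the commit; addition_app stands in for any fal app object (it is the fixture used in the tests further down), and the alias name is made up.

from fal import api

# addition_app: any fal app object, e.g. the fixture from tests/test_apps.py below.
host: api.FalServerlessHost = addition_app.host  # type: ignore

# Default (scale=True): min/max_concurrency and max_multiplexing are taken from
# the application code, as before this change.
host.register(
    func=addition_app.func,
    options=addition_app.options,
    application_name="my-alias",  # illustrative alias name
    application_auth_mode="private",
)

# scale=False: keep the scaling settings of the previous deployment registered
# under "my-alias", if one exists.
host.register(
    func=addition_app.func,
    options=addition_app.options,
    application_name="my-alias",
    application_auth_mode="private",
    scale=False,
)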
10 changes: 10 additions & 0 deletions projects/fal/src/fal/cli/deploy.py
@@ -106,6 +106,7 @@ def _deploy_from_reference(
application_auth_mode=app_auth,
metadata=isolated_function.options.host.get("metadata", {}),
deployment_strategy=deployment_strategy,
scale=not args.no_scale,
)

if app_id:
@@ -219,5 +220,14 @@ def valid_auth_option(option):
help="Deployment strategy.",
default="recreate",
)
parser.add_argument(
"--no-scale",
action="store_true",
help=(
"Use min_concurrency/max_concurrency/max_multiplexing from the previous "
"deployment of the application with this name, if one exists. Otherwise, "
"use the values from the application code."
),
)

parser.set_defaults(func=_deploy)
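
For context, --no-scale is a plain store_true flag that the deploy command inverts before calling register (scale=not args.no_scale), so a CLI invocation along the lines of "fal deploy app.py::MyApp --no-scale" (target syntax shown only as an illustration) preserves the previous deployment's scaling settings. A minimal standalone sketch of that parsing step, with everything fal-specific stubbed out:

import argparse

parser = argparse.ArgumentParser(prog="fal deploy")
parser.add_argument(
    "--no-scale",
    action="store_true",
    help="Keep the previous deployment's scaling settings, if any.",
)

args = parser.parse_args(["--no-scale"])
scale = not args.no_scale  # mirrors deploy.py: scale=not args.no_scale
print(scale)  # False: reuse previous min/max_concurrency and max_multiplexing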
2 changes: 2 additions & 0 deletions projects/fal/src/fal/sdk.py
@@ -499,6 +499,7 @@ def register(
machine_requirements: MachineRequirements | None = None,
metadata: dict[str, Any] | None = None,
deployment_strategy: Literal["recreate", "rolling"] = "recreate",
scale: bool = True,
) -> Iterator[isolate_proto.RegisterApplicationResult]:
wrapped_function = to_serialized_object(function, serialization_method)
if machine_requirements:
@@ -546,6 +547,7 @@
auth_mode=auth_mode,
metadata=struct_metadata,
deployment_strategy=deployment_strategy_proto,
scale=scale,
)
for partial_result in self.stub.RegisterApplication(request):
yield from_grpc(partial_result)
42 changes: 42 additions & 0 deletions projects/fal/tests/test_apps.py
@@ -516,6 +516,48 @@ def test_404_response(test_app: str, request: pytest.FixtureRequest):
apps.run(test_app, path="/other", arguments={"lhs": 1, "rhs": 2})


def test_app_deploy_scale(aliased_app: tuple[str, str]):
import uuid
from dataclasses import replace

app_alias = str(uuid.uuid4()) + "-alias"
app_revision = addition_app.host.register(
func=addition_app.func,
options=addition_app.options,
application_name=app_alias,
application_auth_mode="private",
)

host: api.FalServerlessHost = addition_app.host # type: ignore
options = replace(
addition_app.options, host={**addition_app.options.host, "max_multiplexing": 30}
)
kwargs = dict(
func=addition_app.func,
options=options,
application_name=app_alias,
application_auth_mode="private",
)

app_revision = addition_app.host.register(**kwargs, scale=False)

with host._connection as client:
res = client.list_aliases()
found = next(filter(lambda alias: alias.alias == app_alias, res), None)
assert found, f"Could not find app {app_alias} in {res}"
assert found.revision == app_revision
assert found.max_multiplexing == 1

app_revision = addition_app.host.register(**kwargs, scale=True)

with host._connection as client:
res = client.list_aliases()
found = next(filter(lambda alias: alias.alias == app_alias, res), None)
assert found, f"Could not find app {app_alias} in {res}"
assert found.revision == app_revision
assert found.max_multiplexing == 30


def test_app_update_app(aliased_app: tuple[str, str]):
app_revision, app_alias = aliased_app

