Skip to content

Commit

Permalink
feat: fal.App for multiple endpoints (#27)
Browse files Browse the repository at this point in the history
* wip: feat: `fal.App` for multiple endpoints

* fix: build the metadata

* add tests
  • Loading branch information
isidentical authored Jan 10, 2024
1 parent 41339af commit 74afc08
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 7 deletions.
1 change: 1 addition & 0 deletions projects/fal/src/fal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fal.api import FalServerlessHost, LocalHost, cached
from fal.api import function
from fal.api import function as isolated
from fal.app import App, endpoint, wrap_app
from fal.sdk import FalServerlessKeyCredentials
from fal.sync import sync_dir

Expand Down
162 changes: 162 additions & 0 deletions projects/fal/src/fal/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from __future__ import annotations

import inspect
import os
import fal.api
from fal.toolkit import mainify
from fastapi import FastAPI
from typing import Any, NamedTuple, Callable, TypeVar, ClassVar
from fal.logging import get_logger

EndpointT = TypeVar("EndpointT", bound=Callable[..., Any])
logger = get_logger(__name__)


def wrap_app(cls: type[App], **kwargs) -> fal.api.IsolatedFunction:
def initialize_and_serve():
app = cls()
app.serve()

try:
app = cls(_allow_init=True)
metadata = app.openapi()
except Exception as exc:
logger.warning("Failed to build OpenAPI specification for %s", cls.__name__)
metadata = {}

wrapper = fal.api.function(
"virtualenv",
requirements=cls.requirements,
machine_type=cls.machine_type,
**cls.host_kwargs,
**kwargs,
metadata=metadata,
serve=True,
)
return wrapper(initialize_and_serve).on(
serve=False,
exposed_port=8080,
)


@mainify
class RouteSignature(NamedTuple):
path: str


@mainify
class App:
requirements: ClassVar[list[str]] = []
machine_type: ClassVar[str] = "S"
host_kwargs: ClassVar[dict[str, Any]] = {}

def __init_subclass__(cls, **kwargs):
cls.host_kwargs = kwargs

if cls.__init__ is not App.__init__:
raise ValueError(
"App classes should not override __init__ directly. "
"Use setup() instead."
)

def __init__(self, *, _allow_init: bool = False):
if not _allow_init and not os.getenv("IS_ISOLATE_AGENT"):
raise NotImplementedError(
"Running apps through SDK is not implemented yet."
)

def setup(self):
"""Setup the application before serving."""

def serve(self) -> None:
import uvicorn

app = self._build_app()
self.setup()
uvicorn.run(app, host="0.0.0.0", port=8080)

def _build_app(self) -> FastAPI:
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

_app = FastAPI()

_app.add_middleware(
CORSMiddleware,
allow_credentials=True,
allow_headers=("*"),
allow_methods=("*"),
allow_origins=("*"),
)

routes: dict[RouteSignature, Callable[..., Any]] = {
signature: endpoint
for _, endpoint in inspect.getmembers(self, inspect.ismethod)
if (signature := getattr(endpoint, "route_signature", None))
}
if not routes:
raise ValueError("An application must have at least one route!")

for signature, endpoint in routes.items():
_app.add_api_route(
signature.path,
endpoint,
name=endpoint.__name__,
methods=["POST"],
)

return _app

def openapi(self) -> dict[str, Any]:
"""
Build the OpenAPI specification for the served function.
Attach needed metadata for a better integration to fal.
"""
app = self._build_app()
spec = app.openapi()
self._mark_order_openapi(spec)
return spec

def _mark_order_openapi(self, spec: dict[str, Any]):
"""
Add x-fal-order-* keys to the OpenAPI specification to help the rendering of UI.
NOTE: We rely on the fact that fastapi and Python dicts keep the order of properties.
"""

def mark_order(obj: dict[str, Any], key: str):
obj[f"x-fal-order-{key}"] = list(obj[key].keys())

mark_order(spec, "paths")

def order_schema_object(schema: dict[str, Any]):
"""
Mark the order of properties in the schema object.
They can have 'allOf', 'properties' or '$ref' key.
"""
if "allOf" in schema:
for sub_schema in schema["allOf"]:
order_schema_object(sub_schema)
if "properties" in schema:
mark_order(schema, "properties")

for key in spec["components"].get("schemas") or {}:
order_schema_object(spec["components"]["schemas"][key])

return spec


@mainify
def endpoint(path: str) -> Callable[[EndpointT], EndpointT]:
"""Designate the decorated function as an application endpoint."""

def marker_fn(callable: EndpointT) -> EndpointT:
if hasattr(callable, "route_signature"):
raise ValueError(
f"Can't set multiple routes for the same function: {callable.__name__}"
)

callable.route_signature = RouteSignature(path=path) # type: ignore
return callable

return marker_fn
40 changes: 33 additions & 7 deletions projects/fal/src/fal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import click
import fal.auth as auth
import grpc
import fal
from fal import api, sdk
from fal.console import console
from fal.exceptions import ApplicationExceptionHandler
Expand Down Expand Up @@ -244,6 +245,28 @@ def function_cli(ctx, host: str, port: str):
ctx.obj = api.FalServerlessHost(f"{host}:{port}")


def load_function_from(
host: api.FalServerlessHost,
file_path: str,
function_name: str,
) -> api.IsolatedFunction:
import runpy

module = runpy.run_path(file_path)
if function_name not in module:
raise api.FalServerlessError(f"Function '{function_name}' not found in module")

target = module[function_name]
if issubclass(target, fal.App):
target = fal.wrap_app(target, host=host)

if not isinstance(target, api.IsolatedFunction):
raise api.FalServerlessError(
f"Function '{function_name}' is not a fal.function or a fal.App"
)
return target


@function_cli.command("serve")
@click.option("--alias", default=None)
@click.option(
Expand All @@ -262,15 +285,9 @@ def register_application(
alias: str | None,
auth_mode: ALIAS_AUTH_TYPE,
):
import runpy

user_id = _get_user_id()

module = runpy.run_path(file_path)
if function_name not in module:
raise api.FalServerlessError(f"Function '{function_name}' not found in module")

isolated_function: api.IsolatedFunction = module[function_name]
isolated_function = load_function_from(host, file_path, function_name)
gateway_options = isolated_function.options.gateway
if "serve" not in gateway_options and "exposed_port" not in gateway_options:
raise api.FalServerlessError(
Expand Down Expand Up @@ -307,6 +324,15 @@ def register_application(
console.print(f"URL: https://{user_id}-{id}.{gateway_host}")


@function_cli.command("run")
@click.argument("file_path", required=True)
@click.argument("function_name", required=True)
@click.pass_obj
def run(host: api.FalServerlessHost, file_path: str, function_name: str):
isolated_function = load_function_from(host, file_path, function_name)
isolated_function()


@function_cli.command("logs")
@click.option("--lines", default=100)
@click.option("--url", default=None)
Expand Down
58 changes: 58 additions & 0 deletions projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class Input(BaseModel):
wait_time: int = 0


class StatefulInput(BaseModel):
value: int


class Output(BaseModel):
result: int

Expand Down Expand Up @@ -77,6 +81,28 @@ def subtract(input: Input) -> Output:
run(app, host="0.0.0.0", port=8080)


class StatefulAdditionApp(fal.App, keep_alive=300, max_concurrency=1):
machine_type = "S"

def setup(self):
self.counter = 0

@fal.endpoint("/reset")
def reset(self) -> Output:
self.counter = 0
return Output(result=self.counter)

@fal.endpoint("/increment")
def increment(self, input: StatefulInput) -> Output:
self.counter += input.value
return Output(result=self.counter)

@fal.endpoint("/decrement")
def decrement(self, input: StatefulInput) -> Output:
self.counter -= input.value
return Output(result=self.counter)


@pytest.fixture(scope="module")
def aliased_app() -> Generator[tuple[str, str], None, None]:
# Create a temporary app, register it, and return the ID of it.
Expand Down Expand Up @@ -122,6 +148,21 @@ def test_fastapi_app():
yield f"{user_id}-{app_revision}"


@pytest.fixture(scope="module")
def test_stateful_app():
# Create a temporary app, register it, and return the ID of it.

from fal.cli import _get_user_id

app = fal.wrap_app(StatefulAdditionApp)
app_revision = app.host.register(
func=app.func,
options=app.options,
)
user_id = _get_user_id()
yield f"{user_id}-{app_revision}"


def test_app_client(test_app: str):
response = apps.run(test_app, arguments={"lhs": 1, "rhs": 2})
assert response["result"] == 3
Expand All @@ -130,6 +171,23 @@ def test_app_client(test_app: str):
assert response["result"] == 5


def test_stateful_app_client(test_stateful_app: str):
response = apps.run(test_stateful_app, arguments={}, path="/reset")
assert response["result"] == 0

response = apps.run(test_stateful_app, arguments={"value": 1}, path="/increment")
assert response["result"] == 1

response = apps.run(test_stateful_app, arguments={"value": 2}, path="/increment")
assert response["result"] == 3

response = apps.run(test_stateful_app, arguments={"value": 1}, path="/decrement")
assert response["result"] == 2

response = apps.run(test_stateful_app, arguments={"value": 2}, path="/decrement")
assert response["result"] == 0


def test_app_client_async(test_app: str):
request_handle = apps.submit(test_app, arguments={"lhs": 1, "rhs": 2})
assert request_handle.get() == {"result": 3}
Expand Down
5 changes: 5 additions & 0 deletions projects/fal/tests/test_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@


def test_missing_dependencies_nested_server_error(isolated_client):
from fal import _serialization

_serialization._PACKAGES.clear()
_serialization._MODULES.clear()

@isolated_client()
def test1():
return "hello"
Expand Down

0 comments on commit 74afc08

Please sign in to comment.