From d70a477760f9722616727b1e0738a31160c67be1 Mon Sep 17 00:00:00 2001 From: Ryan Lewis Date: Mon, 30 Oct 2023 14:08:49 -0500 Subject: [PATCH] Refactor: use StepResponse for json/header validation --- examples/example_nodes/example_rest_node.py | 20 ++++++----- examples/example_nodes/webcam_rest_node.py | 18 +++++----- wei/core/data_classes.py | 40 +++++++++++++++++++++ wei/core/interfaces/rest_interface.py | 22 ++++++------ 4 files changed, 72 insertions(+), 28 deletions(-) diff --git a/examples/example_nodes/example_rest_node.py b/examples/example_nodes/example_rest_node.py index 1d5b23e0..c2096ef9 100644 --- a/examples/example_nodes/example_rest_node.py +++ b/examples/example_nodes/example_rest_node.py @@ -4,9 +4,14 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from fastapi.responses import FileResponse, JSONResponse +from fastapi.responses import JSONResponse -from wei.core.data_classes import ModuleStatus, StepResponse, StepStatus +from wei.core.data_classes import ( + ModuleStatus, + StepFileResponse, + StepResponse, + StepStatus, +) global state, module_resources @@ -95,13 +100,10 @@ def do_action( state = ModuleStatus.IDLE # Use the FileResponse class to return files file_name = json.loads(action_vars)["file_name"] - return FileResponse( - path=file_name, - headers=StepResponse( - action_response=StepStatus.SUCCEEDED, - action_msg=file_name, - action_log="", - ).model_dump(mode="json"), + return StepFileResponse( + action_response=StepStatus.SUCCEEDED, + path=file_name, # The path to the file to be returned + action_log="", ) else: # Handle Unsupported actions diff --git a/examples/example_nodes/webcam_rest_node.py b/examples/example_nodes/webcam_rest_node.py index 712d0dcd..9a099d34 100644 --- a/examples/example_nodes/webcam_rest_node.py +++ b/examples/example_nodes/webcam_rest_node.py @@ -7,9 +7,14 @@ import cv2 from fastapi import FastAPI -from fastapi.responses import FileResponse, JSONResponse +from fastapi.responses import JSONResponse -from wei.core.data_classes import ModuleStatus, StepResponse, StepStatus +from wei.core.data_classes import ( + ModuleStatus, + StepFileResponse, + StepResponse, + StepStatus, +) global state @@ -85,13 +90,10 @@ def do_action( state = ModuleStatus.IDLE print("success") - return FileResponse( + return StepFileResponse( + action_response=StepStatus.SUCCEEDED, path=image_name, - headers=StepResponse( - action_response=StepStatus.SUCCEEDED, - action_msg=image_name, - action_log="", - ).model_dump(mode="json"), + action_log="", ) else: state = ModuleStatus.IDLE diff --git a/wei/core/data_classes.py b/wei/core/data_classes.py index 64e0662d..f0a61f9d 100644 --- a/wei/core/data_classes.py +++ b/wei/core/data_classes.py @@ -7,6 +7,7 @@ import ulid import yaml +from fastapi.responses import FileResponse from pydantic import BaseModel as _BaseModel from pydantic import Field, validator @@ -315,6 +316,45 @@ class StepResponse(BaseModel): action_log: str = "" """Error or log messages resulting from the action""" + def to_headers(self) -> Dict[str, str]: + """Converts the response to a dictionary of headers""" + return { + "X-WEI-action-response": str(self.action_response), + "X-WEI-action-msg": self.action_msg, + "X-WEI-action-log": self.action_log, + } + + @classmethod + def from_headers(cls, response: FileResponse): + """Creates a StepResponse from the headers of a file response""" + return cls( + action_response=StepStatus(response.headers["X-WEI-action-response"]), + action_msg=response.headers["X-WEI-action-msg"], + action_log=response.headers["X-WEI-action-log"], + ) + + +class StepFileResponse(FileResponse): + """ + Convenience wrapper for FastAPI's FileResponse class + If not using FastAPI, return a response with + - The file object as the response content + - The StepResponse parameters as custom headers, prefixed with "wei_" + """ + + def __init__(self, action_response: StepStatus, action_log: str, path: PathLike): + """ + Returns a FileResponse with the given path as the response content + """ + return super().__init__( + path=path, + headers=StepResponse( + action_response=action_response, + action_msg=str(path), + action_log=action_log, + ).to_headers(), + ) + class ExperimentStatus(str, Enum): """Status for an experiment""" diff --git a/wei/core/interfaces/rest_interface.py b/wei/core/interfaces/rest_interface.py index ba3dcd53..f116ca8f 100644 --- a/wei/core/interfaces/rest_interface.py +++ b/wei/core/interfaces/rest_interface.py @@ -5,7 +5,7 @@ import requests -from wei.core.data_classes import Interface, Module, Step +from wei.core.data_classes import Interface, Module, Step, StepResponse class RestInterface(Interface): @@ -43,25 +43,25 @@ def send_action(step: Step, **kwargs) -> Tuple[str, str, str]: headers=headers, params={"action_handle": step.action, "action_vars": json.dumps(step.args)}, ) - if "action_response" in rest_response.headers: + if "X-WEI-action_response" in rest_response.headers: + response = StepResponse.from_headers(rest_response.headers) if "exp_path" in kwargs.keys(): path = Path( kwargs["exp_path"], "results", - step.id + "_" + rest_response.headers["action_msg"], + step.id + "_" + response.action_msg, ) else: - path = Path(step.id + rest_response.headers["action_msg"]) + path = Path(step.id + response.action_msg) with open(str(path), "wb") as f: f.write(rest_response.content) - rest_response = rest_response.headers - rest_response["action_msg"] = str(path.name) + response.action_msg = str(path) else: - rest_response = rest_response.json() - print(rest_response) - action_response = rest_response["action_response"] - action_msg = rest_response["action_msg"] - action_log = rest_response["action_log"] + response = StepResponse.model_validate(rest_response.json()) + print(response) + action_response = response.action_response + action_msg = response.action_msg + action_log = response.action_log return action_response, action_msg, action_log