diff --git a/pykern/http.py b/pykern/http.py index f1bdfa67..d7ede8fb 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -1,4 +1,5 @@ -"""HTTP server +"""HTTP server & client + :copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. :license: http://www.apache.org/licenses/LICENSE-2.0.html @@ -6,6 +7,7 @@ from pykern.pkcollections import PKDict from pykern.pkdebug import pkdc, pkdlog, pkdp, pkdexc, pkdformat +import asyncio import inspect import msgpack import pykern.pkasyncio @@ -13,7 +15,9 @@ import pykern.pkconfig import pykern.quest import re +import tornado.httpclient import tornado.web +import tornado.websocket #: Http auth header name @@ -25,15 +29,15 @@ #: POSIT: Matches anything generated by `unique_key` _UNIQUE_KEY_CHARS_RE = r"\w+" +#: validates auth secret (only word chars) +_AUTH_SECRET_RE = re.compile(f"^{_UNIQUE_KEY_CHARS_RE}$") + #: Regex to test format of auth header and extract token _AUTH_HEADER_RE = re.compile( _AUTH_HEADER_SCHEME_BEARER + r"\s+(" + _UNIQUE_KEY_CHARS_RE + ")", re.IGNORECASE, ) -_CONTENT_TYPE_HEADER = "Content-Type" -_CONTENT_TYPE = "application/msgpack" - _VERSION_HEADER = "X-PyKern-HTTP-Version" _VERSION_HEADER_VALUE = "1" @@ -57,104 +61,59 @@ def server_start(api_classes, attr_classes, http_config, coros=()): l.start() -class Reply: - - def __init__(self, result=None, exc=None, api_error=None): - def _exception(exc): - if exc is None: - pkdlog("ERROR: no reply and no exception") - return 500 - if isinstance(exc, APINotFound): - return 404 - if isinstance(exc, APIForbidden): - return 403 - pkdlog("untranslated exception={}", exc) - return 500 - - if isinstance(result, Reply): - self.http_status = result.http_status - self.content = result.content - elif result is not None or api_error is not None: - self.http_status = 200 - self.content = PKDict( - api_error=api_error, - api_result=result, - ) - else: - self.http_status = _exception(exc) - self.content = None +class APICallError(pykern.quest.APIError): + """Raised 
when the called API raises an exception""" + def __init__(self, exception): + super().__init__("exception={}", exception) -class ReplyExc(Exception): - """Raised to end the request. - Args: - pk_args (dict): exception args that specific to this module - log_fmt (str): server side log data - """ - - def __init__(self, *args, **kwargs): - super().__init__() - if "pk_args" in kwargs: - self.pk_args = kwargs["pk_args"] - del kwargs["pk_args"] - else: - self.pk_args = PKDict() - if args or kwargs: - kwargs["pkdebug_frame"] = inspect.currentframe().f_back.f_back - pkdlog(*args, **kwargs) +class APIDisconnected(pykern.quest.APIError): + """Raised when remote server closed or other error""" - def __repr__(self): - a = self.pk_args - return "{}({})".format( - self.__class__.__name__, - ",".join( - ("{}={}".format(k, a[k]) for k in sorted(a.keys())), - ), - ) + def __init__(self): + super().__init__("") - def __str__(self): - return self.__repr__() - -class APIForbidden(ReplyExc): +class APIForbidden(pykern.quest.APIError): """Raised for forbidden or protocol error""" - pass + def __init__(self): + super().__init__("") -class APINotFound(ReplyExc): +class APINotFound(pykern.quest.APIError): """Raised for an object not found""" - pass + def __init__(self, api_name): + super().__init__("api_name={}", api_name) class HTTPClient: """Wrapper for `tornado.httpclient.AsyncHTTPClient` - Args: - http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret, request_config + May be called as an async context manager + + `http_config.request_config` is deprecated. 
+ Args: + http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret """ def __init__(self, http_config): - self._uri = ( - f"http://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}" - ) - self._headers = PKDict( - { - _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {_auth_secret(http_config.auth_secret)}", - _CONTENT_TYPE_HEADER: _CONTENT_TYPE, - _VERSION_HEADER: _VERSION_HEADER_VALUE, - } + # TODO(robnagler) tls with verification(?) + self.uri = ( + f"ws://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}" ) - self._request_config = http_config.get("request_config") or PKDict() + self.auth_secret = _auth_secret(http_config.auth_secret) + self._connection = None + self._destroyed = False + self._call_id = 0 + self._pending_calls = PKDict() async def call_api(self, api_name, api_args): """Make a request to the API server - `http_config.request_config` (see `__init__` and if it exists) is passed verbatim to `AsyncHTTPClient.fetch`. - Args: api_name (str): what to call on the server api_args (PKDict): passed verbatim to the API on the server. @@ -164,35 +123,144 @@ async def call_api(self, api_name, api_args): APIError: if there was an raise in the API or on a server protocol violation Exception: other exceptions that `AsyncHTTPClient.fetch` may raise, e.g. 
NotFound """ - # Need to be careful with the lifecycle of AsyncHTTPClient - # https://github.com/radiasoft/pykern/issues/529 - r = await tornado.httpclient.AsyncHTTPClient(force_instance=True).fetch( - self._uri, - body=_pack_msg(PKDict(api_name=api_name, api_args=api_args)), - headers=self._headers, - method="POST", - **self._request_config, + + def _send(): + self._call_id += 1 + c = PKDict(api_name=api_name, api_args=api_args, call_id=self._call_id) + rv = _ClientCall(c) + self._pending_calls[rv.call_id] = rv + self._connection.write_message(_pack_msg(c), binary=True) + return rv + + # TODO(robnagler) backwards compatibility + if not self._connection: + # will check destroyed + await self.connect() + if self._destroyed: + return + return await _send().reply_get() + + async def connect(self): + if self._destroyed: + raise AssertionError("destroyed") + if self._connection: + raise AssertionError("already connected") + self._connection = await tornado.websocket.websocket_connect( + tornado.httpclient.HTTPRequest( + self.uri, + headers=( + { + _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {self.auth_secret}", + _VERSION_HEADER: _VERSION_HEADER_VALUE, + } + if self.auth_secret + else None + ), + method="GET", + ), + # TODO(robnagler) accept in http_config. share defaults with sirepo.job. 
+ max_message_size=int(2e8), + ping_interval=120, + ping_timeout=240, ) - rv, e = _unpack_msg(r) - if e: - raise pykern.quest.APIError(*e) - if rv.api_error: - raise pykern.quest.APIError( - "api_error={} api_name={} api_args={}", rv.api_error, api_name, api_args + asyncio.create_task(self._read_loop()) + + def destroy(self): + """Must be called""" + if self._destroyed: + return + self._destroyed = True + if self._connection: + self._connection.close() + self._connection = None + x = self._pending_calls + self._pending_calls = None + for c in x.values(): + c.reply_q.put_nowait(None) + + async def __aenter__(self): + await self.connect() + if self._destroyed: + raise APIDisconnected() + return self + + async def __aexit__(self, *args, **kwargs): + self.destroy() + return False + + async def _read_loop(self): + def _unpack(msg): + r, e = _unpack_msg(msg) + if e: + pkdlog("unpack msg error={} {}", e, self) + return None + return r + + m = r = None + try: + if self._destroyed: + return + while m := await self._connection.read_message(): + if self._destroyed: + return + if not (r := _unpack(m)): + break + # Remove from pending + if not (c := self._pending_calls.pkdel(r.call_id)): + pkdlog("call_id not found reply={} {}", r, self) + # TODO(robnagler) possibly too harsh, but safer for now + break + c.reply_q.put_nowait(r) + m = r = None + except Exception as e: + pkdlog("exception={} reply={} stack={}", e, r, pkdexc()) + try: + if not self._destroyed: + self.destroy() + except Exception as e: + pkdlog("exception={} stack={}", e, pkdexc()) + + def __repr__(self): + def _calls(): + return ", ".join( + ( + f"{v.api_name}#{v.call_id}" + for v in sorted( + self._pending_calls.values(), key=lambda x: x.call_id + ) + ), ) - return rv.api_result + def _destroyed(): + return "DESTROYED, " if self._destroyed else "" -class _HTTPRequestHandler(tornado.web.RequestHandler): - def initialize(self, server): - self.pykern_server = server - self.pykern_context = PKDict() + return 
f"{self.__class__.__name__}({_destroyed()}call_id={self._call_id}, calls=[{_calls()}])" - async def get(self): - await self.pykern_server.dispatch(self) - async def post(self): - await self.pykern_server.dispatch(self) +class _ClientCall(PKDict): + def __init__(self, call_msg): + super().__init__(**call_msg) + # TODO(robnagler) should be one for regular replies + self.reply_q = tornado.queues.Queue() + self._destroyed = False + + async def reply_get(self): + rv = await self.reply_q.get() + if self._destroyed: + raise APIDisconnected() + self.reply_q.task_done() + self.destroy() + if rv is None: + raise APIDisconnected() + if rv.api_error: + raise pykern.quest.APIError(rv.api_error) + return rv.api_result + + def destroy(self): + if self._destroyed: + return + self._destroyed = True + self.reply_q = None class _HTTPServer: @@ -226,68 +294,62 @@ def _api_map(): self.api_map = _api_map() self.attr_classes = attr_classes self.auth_secret = _auth_secret(h.pkdel("auth_secret")) - h.uri_map = ((h.api_uri, _HTTPRequestHandler, PKDict(server=self)),) + h.uri_map = ((h.api_uri, _ServerHandler, PKDict(server=self)),) self.api_uri = h.pkdel("api_uri") h.log_function = self._log_end - self.req_id = 0 + self._ws_id = 0 loop.http_server(h) - async def dispatch(self, handler): - async def _call(api, api_args): - with pykern.quest.start(api.api_class, self.attr_classes) as qcall: - return await getattr(qcall, api.api_func_name)(api_args) + def handle_get(self, handler): + def _authenticate(headers): + if not (h := headers.get(_AUTH_HEADER)): + return "no auth token" + if not (m := _AUTH_HEADER_RE.search(h)): + return "auth token format invalid" + if m.group(1) != self.auth_secret: + return "auth token mismatch" + return None + + def _validate_version(headers): + if not (v := headers.get(_VERSION_HEADER)): + return f"missing {_VERSION_HEADER} header" + if v != _VERSION_HEADER_VALUE: + return f"invalid version {v}" + return None - m = None - r = None try: - self.req_id += 1 - 
handler.pykern_context.req_id = self.req_id - handler.pykern_context.api_name = None - handler.pykern_context.req_msg = None - try: - self._log(handler, "start") - self._authenticate(handler) - m, e = _unpack_msg(handler.request) - if e: - raise APIForbidden(*e) - handler.pykern_context.req_msg = m - self._log(handler, "call") - if not (a := self.api_map.get(m.api_name)): - raise APINotFound() - r = Reply(result=await _call(a, m.api_args)) - except pykern.quest.APIError as e: - self._log(handler, "api_error", e) - r = Reply(api_error=str(e)) - except Exception as e: - self._log(handler, "error", e) - r = Reply(exc=e) - self._send_reply(handler, r) + self._log(handler, "start") + # TODO(robnagler) special case for websockets. Need to switch + # to first message sent defines protocol and version. + if not self.auth_secret: + # _auth_secret can be false + return True + h = handler.request.headers + if e := _authenticate(h): + k = PKDict(status_code=403, reason="Forbidden") + elif e := _validate_version(h): + k = PKDict(status_code=412, reason="Precondition Failed") + else: + return True + handler.pykern_context.error = e + handler.send_error(**k) + return False except Exception as e: - self._log( - handler, - "reply_error", - e, - getattr(r, "content", None), - ) - raise + pkdlog("exception={} stack={}", e, pkdexc()) + self._log(handler, "error", "exception={}", [e]) + return False - def _authenticate(self, handler): - def _token(headers): - if not (h := headers.get(_AUTH_HEADER)): - return None - if m := _AUTH_HEADER_RE.search(h): - return m.group(1) + def handle_open(self, handler): + try: + self._ws_id += 1 + handler.pykern_context.ws_id = self._ws_id + return _ServerConnection(self, handler, ws_id=self._ws_id) + except Exception as e: + pkdlog("exception={} stack={}", e, pkdexc()) + self._log(handler, "error", "exception={}", [e]) return None - if handler.request.method != "POST": - raise APIForbidden() - if t := _token(handler.request.headers): - if t == 
self.auth_secret: - return - raise APIForbidden("token mismatch") - raise APIForbidden("no token") - - def _log(self, handler, which, exc=None, reply=None): + def _log(self, handler, which, fmt="", args=None): def _add(key, value): nonlocal f, a if value is not None: @@ -296,39 +358,157 @@ def _add(key, value): f = "" a = [] - _add("req_id", handler.pykern_context.get("req_id")) - _add("api", handler.pykern_context.get("api_name")) - if exc: - _add("exception", exc) - _add("req", handler.pykern_context.get("req_msg")) - if which != "api_error": - _add("reply", reply) - _add("stack", pkdexc()) + if x := getattr(handler, "pykern_context", None): + _add("error", x.pkdel("error")) + _add("ws_id", x.get("ws_id")) + if fmt: + f = f + " " + fmt + a.extend(args) self.loop.http_log(handler, which, f, a) def _log_end(self, handler): self._log(handler, "end") - def _send_reply(self, handler, reply): - if (c := reply.content) is None: - m = b"" - else: - m = _pack_msg(c) - handler.set_header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE) - handler.set_header(_VERSION_HEADER, _VERSION_HEADER_VALUE) - handler.set_header("Content-Length", str(len(m))) - handler.set_status(reply.http_status) - handler.write(m) + +class _ServerConnection: + + def __init__(self, server, handler, ws_id): + self.ws_id = ws_id + self.server = server + self.handler = handler + self._destroyed = False + self.remote_peer = server.loop.remote_peer(handler.request) + self._log("open") + + def destroy(self): + if self._destroyed: + return + self._destroyed = True + self.handler.close() + self.handler = None + + def handle_on_close(self): + if self._destroyed: + return + self.handler = None + self._log("on_close") + # TODO(robnagler) deal with open requests + + async def handle_on_message(self, msg): + def _api(call): + if n := call.get("api_name"): + if rv := self.server.api_map.get(n): + return rv + else: + n = "" + self._log("error", call, "api not found={}", [n]) + _reply(c, APINotFound(n)) + return None + + async 
def _call(call, api, api_args): + with pykern.quest.start(api.api_class, self.server.attr_classes) as qcall: + try: + return await getattr(qcall, api.api_func_name)(api_args) + except Exception as e: + pkdlog("exception={} call={} stack={}", call, e, pkdexc()) + return APICallError(e) + + def _reply(call, obj): + try: + if not isinstance(obj, Exception): + r = PKDict(api_result=obj, api_error=None) + elif isinstance(obj, pykern.quest.APIError): + r = PKDict(api_result=None, api_error=str(obj)) + else: + r = PKDict(api_result=None, api_error=f"unhandled_exception={obj}") + r.call_id = call.call_id + self.handler.write_message(_pack_msg(r), binary=True) + except Exception as e: + pkdlog("exception={} call={} stack={}", call, e, pkdexc()) + self.destroy() + + c = None + try: + c, e = _unpack_msg(msg) + if e: + self._log("error", None, "msg unpack error={}", [e]) + self.destroy() + return None + self._log("start", c) + if not (a := _api(c)): + return + r = await _call(c, a, c.api_args) + if self._destroyed: + return + _reply(c, r) + self._log("end", c) + c = None + except Exception as e: + pkdlog("exception={} call={} stack={}", e, c, pkdexc()) + _reply(c, e) + + def _log(self, which, call=None, fmt="", args=None): + if fmt: + fmt = " " + fmt + pkdlog( + "{} ip={} ws={}#{}" + fmt, + which, + self.remote_peer, + self.ws_id, + call and call.call_id, + *(args if args else ()), + ) + + +class _ServerHandler(tornado.websocket.WebSocketHandler): + def initialize(self, server): + self.pykern_server = server + self.pykern_context = PKDict() + self.pykern_connection = None + + async def get(self, *args, **kwargs): + # if not self.pykern_server.handle_get(self): + # return + return await super().get(*args, **kwargs) + + async def on_message(self, msg): + # WebSocketHandler only allows one on_message at a time. 
+ asyncio.create_task(self.pykern_connection.handle_on_message(msg)) + + def on_close(self): + if self.pykern_connection: + self.pykern_connection.handle_on_close() + self.pykern_connection = None + + def open(self): + self.pykern_connection = self.pykern_server.handle_open(self) def _auth_secret(value): + """Validate config value or default + + Special case: `auth_secret` can be ``False``, which means no + authentication or version checking. Temporary change to deal with + JavaScript WebSocket which does not support sending headers. + + + In dev mode, defaults to something if not set. + + Returns: + str or bool: Valid value of auth_secret; ``False`` means authentication is disabled + + """ if value: if len(value) < 16: - raise AssertionError("secret too short len={len(value)} (<16)") + raise ValueError("auth_secret too short len={len(value)} (<16)") + if not _AUTH_SECRET_RE.search(value): + raise ValueError("auth_secret contains non-word chars") + return value + if isinstance(value, bool): return value if pykern.pkconfig.in_dev_mode(): return "default_dev_secret" - raise AssertionError("must supply http_config.auth_secret") + raise ValueError("must supply http_config.auth_secret") def _pack_msg(content): @@ -345,21 +525,22 @@ def _datetime(obj): return p.bytes() -def _unpack_msg(request): - def _header(name, value): - if not (v := request.headers.get(name)): - return ("missing header={}", name) - if v != value: - return ("unexpected {}={}", name, c) - return None - - if e := ( - _header(_VERSION_HEADER, _VERSION_HEADER_VALUE) - or _header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE) - ): - return None, e - u = msgpack.Unpacker( - object_pairs_hook=pykern.pkcollections.object_pairs_hook, - ) - u.feed(request.body) - return u.unpack(), None +def _unpack_msg(content): + try: + u = msgpack.Unpacker( + object_pairs_hook=pykern.pkcollections.object_pairs_hook, + ) + u.feed(content) + rv = u.unpack() + except Exception as e: + return None, f"msgpack exception={e}" + if not isinstance(rv, PKDict): + return None, f"msg not dict 
type={type(rv)}" + if "call_id" not in rv: + return None, "msg missing call_id keys={list(rv.keys())}" + i = rv.call_id + if not isinstance(i, int): + return None, f"msg call_id non-integer type={type(i)}" + if i <= 0: + return None, f"msg call_id non-positive call_id={i}" + return rv, None diff --git a/pykern/http_unit.py b/pykern/http_unit.py new file mode 100644 index 00000000..37ade3e0 --- /dev/null +++ b/pykern/http_unit.py @@ -0,0 +1,73 @@ +"""support for `pykern.http` tests + +:copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. +:license: http://www.apache.org/licenses/LICENSE-2.0.html +""" + +# Defer imports for unit tests + + +class Setup: + + def __init__(self, api_classes, attr_classes=(), coros=()): + import os, time + from pykern.pkcollections import PKDict + + def _global_config(): + c = PKDict( + PYKERN_PKDEBUG_WANT_PID_TIME="1", + ) + os.environ.update(**c) + from pykern import pkconfig + + pkconfig.reset_state_for_testing(c) + + def _http_config(): + from pykern import pkconst, pkunit + + return PKDict( + # any uri is fine + api_uri="/http_unit", + # just needs to be >= 16 word (required by http) chars; apps should generate this randomly + auth_secret="http_unit_auth_secret", + tcp_ip=pkconst.LOCALHOST_IP, + tcp_port=pkunit.unbound_localhost_tcp_port(), + ) + + def _server(): + from pykern import pkdebug, http + + if rv := os.fork(): + return rv + try: + pkdebug.pkdlog("start server") + http.server_start( + attr_classes=attr_classes, + api_classes=api_classes, + http_config=self.http_config.copy(), + coros=coros, + ) + except Exception as e: + pkdebug.pkdlog("server exception={} stack={}", e, pkdebug.pkdexc()) + finally: + os._exit(0) + + _global_config() + self.http_config = _http_config() + self.server_pid = _server() + time.sleep(1) + from pykern import http + + self.client = http.HTTPClient(self.http_config.copy()) + + def destroy(self): + import os, signal + + os.kill(self.server_pid, signal.SIGKILL) + + def __enter__(self): + 
return self.client + + def __exit__(self, *args, **kwargs): + self.client.destroy() + return False diff --git a/pykern/pkasyncio.py b/pykern/pkasyncio.py index c47a4021..3bfbf91a 100644 --- a/pykern/pkasyncio.py +++ b/pykern/pkasyncio.py @@ -65,21 +65,9 @@ async def _do(): self.run(_do()) def http_log(self, handler, which="end", fmt="", args=None): - def _remote_peer(request): - # https://github.com/tornadoweb/tornado/issues/2967#issuecomment-757370594 - # implementation may change; Code in tornado.httputil check connection. - if c := request.connection: - # socket is not set on stream for websockets. - if getattr(c, "stream", None) and ( - s := getattr(c.stream, "socket", None) - ): - return "{}:{}".format(*s.getpeername()) - i = request.headers.get("proxy-for", request.remote_ip) - return f"{i}:0" - r = handler.request f = "{} ip={} uri={}" - a = [which, _remote_peer(r), r.uri] + a = [which, self.remote_peer(r), r.uri] if fmt: f += " " + fmt a += args @@ -99,6 +87,16 @@ def _remote_peer(request): ] pkdlog(f, *a) + def remote_peer(self, request): + # https://github.com/tornadoweb/tornado/issues/2967#issuecomment-757370594 + # implementation may change; Code in tornado.httputil check connection. + if c := request.connection: + # socket is not set on stream for websockets. 
+ if getattr(c, "stream", None) and (s := getattr(c.stream, "socket", None)): + return "{}:{}".format(*s.getpeername()) + i = request.headers.get("proxy-for", request.remote_ip) + return f"{i}:0" + def run(self, *coros): for c in coros: if not inspect.iscoroutine(c): diff --git a/pykern/pkunit.py b/pykern/pkunit.py index a40e3618..d424b392 100644 --- a/pykern/pkunit.py +++ b/pykern/pkunit.py @@ -279,12 +279,12 @@ def file_eq(expect_path, *args, **kwargs): _FileEq(expect_path, *args, **kwargs) -def unbound_localhost_tcp_port(start, stop): +def unbound_localhost_tcp_port(start=10000, stop=20000): """Looks for AF_INET SOCK_STREAM port for which bind succeeds Args: - start (int): first port - stop (int): one greater than last port (passed to range) + start (int): first port [10000] + stop (int): one greater than last port (passed to range) [20000] Returns: int: port is available or raises ValueError """ diff --git a/tests/http_test.py b/tests/http_test.py new file mode 100644 index 00000000..f63ad396 --- /dev/null +++ b/tests/http_test.py @@ -0,0 +1,30 @@ +"""test http server + +:copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. 
+:license: http://www.apache.org/licenses/LICENSE-2.0.html +""" + +import pytest + + +@pytest.mark.asyncio +async def test_basic(): + from pykern import http_unit + + with http_unit.Setup(api_classes=(_class(),)) as c: + from pykern.pkcollections import PKDict + from pykern import pkunit + + e = PKDict(a=1) + pkunit.pkeq(e, await c.call_api("echo", e)) + + +def _class(): + from pykern import quest + + class _API(quest.API): + + async def api_echo(self, api_args): + return api_args + + return _API diff --git a/tests/pkcli/projex2_data/xyzzy1.out/README.md b/tests/pkcli/projex2_data/xyzzy1.out/README.md index 5fd8c1c2..d1c88ed4 100644 --- a/tests/pkcli/projex2_data/xyzzy1.out/README.md +++ b/tests/pkcli/projex2_data/xyzzy1.out/README.md @@ -10,4 +10,4 @@ Documentation: https://xyzzy1.readthedocs.io License: https://www.apache.org/licenses/LICENSE-2.0.html -Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. +Copyright (c) 2025 RadiaSoft LLC. All Rights Reserved.