From 6626ab67939375f6e3233402f45726d166adfb54 Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Sun, 29 Dec 2024 21:57:15 +0000 Subject: [PATCH 1/7] http_test basic echo works --- pykern/http.py | 25 +++++++++++--- pykern/http_unit.py | 82 +++++++++++++++++++++++++++++++++++++++++++++ pykern/pkunit.py | 6 ++-- tests/http_test.py | 30 +++++++++++++++++ 4 files changed, 136 insertions(+), 7 deletions(-) create mode 100644 pykern/http_unit.py create mode 100644 tests/http_test.py diff --git a/pykern/http.py b/pykern/http.py index f1bdfa67..0100c985 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -1,4 +1,4 @@ -"""HTTP server +"""HTTP server & client :copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. :license: http://www.apache.org/licenses/LICENSE-2.0.html @@ -25,6 +25,9 @@ #: POSIT: Matches anything generated by `unique_key` _UNIQUE_KEY_CHARS_RE = r"\w+" +#: validates auth secret (only word chars) +_AUTH_SECRET_RE = re.compile(f"^{_UNIQUE_KEY_CHARS_RE}$") + #: Regex to test format of auth header and extract token _AUTH_HEADER_RE = re.compile( _AUTH_HEADER_SCHEME_BEARER + r"\s+(" + _UNIQUE_KEY_CHARS_RE + ")", @@ -132,8 +135,10 @@ class APINotFound(ReplyExc): class HTTPClient: """Wrapper for `tornado.httpclient.AsyncHTTPClient` + Maybe called as a context manager + Args: - http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret, request_config + http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret, request_config (deprecated) """ @@ -182,6 +187,16 @@ async def call_api(self, api_name, api_args): ) return rv.api_result + def destroy(self): + """Must be called""" + pass + + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + return False + class _HTTPRequestHandler(tornado.web.RequestHandler): def initialize(self, server): @@ -324,11 +339,13 @@ def _send_reply(self, handler, reply): def _auth_secret(value): if value: if len(value) < 16: - raise AssertionError("secret too short len={len(value)} (<16)") + raise ValueError("auth_secret too short len={len(value)} (<16)") + if not _AUTH_SECRET_RE.search(value): + raise ValueError("auth_secret contains non-word chars") return value if pykern.pkconfig.in_dev_mode(): return "default_dev_secret" - raise AssertionError("must supply http_config.auth_secret") + raise ValueError("must supply http_config.auth_secret") def _pack_msg(content): diff --git a/pykern/http_unit.py b/pykern/http_unit.py new file mode 100644 index 00000000..197c9765 --- /dev/null +++ b/pykern/http_unit.py @@ -0,0 +1,82 @@ +"""support for `pykern.http` tests + +:copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. +:license: http://www.apache.org/licenses/LICENSE-2.0.html +""" + +# Defer imports for unit tests + +# any uri is fine +_URI = "/http_unit" + +# just needs to be >= 16 word (required by http) chars; apps should generate this randomly +# +_AUTH_SECRET = "http_unit_auth_secret" + + +class Setup: + + def __init__(self, api_classes, attr_classes=(), coros=()): + import os, time + from pykern.pkcollections import PKDict + + def _client(http_config): + from pykern import pkdebug, http + + def _config(): + c = PKDict( + PYKERN_PKDEBUG_WANT_PID_TIME="1", + ) + os.environ.update(**c) + from pykern import pkconfig + + pkconfig.reset_state_for_testing(c) + + def _http_config(): + from pykern import pkconst, pkunit + + return PKDict( + api_uri=_URI, + auth_secret=_AUTH_SECRET, + tcp_ip=pkconst.LOCALHOST_IP, + tcp_port=pkunit.unbound_localhost_tcp_port(), + ) + + def _server(http_config): + from pykern import pkdebug, http + + p = os.fork() + if p != 0: + return p + try: + pkdebug.pkdlog("start server") + http.server_start( + attr_classes=attr_classes, + api_classes=api_classes, + http_config=http_config, + coros=coros, + ) + except Exception as e: + pkdebug.pkdlog("server exception={} stack={}", e, pkdebug.pkdexc()) + finally: + os._exit(0) + + _config() + h = _http_config() + self.server_pid = _server(h) + time.sleep(1) + from pykern import http + + self.client = http.HTTPClient(h) + + def destroy(self): + import os, signal + + os.kill(self.server_pid, signal.SIGKILL) + + def __enter__(self): + return self.client + + def __exit__(self, *args, **kwargs): + self.client.destroy() + return False diff --git a/pykern/pkunit.py b/pykern/pkunit.py index a40e3618..d424b392 100644 --- a/pykern/pkunit.py +++ b/pykern/pkunit.py @@ -279,12 +279,12 @@ def file_eq(expect_path, *args, **kwargs): _FileEq(expect_path, *args, **kwargs) -def unbound_localhost_tcp_port(start, stop): +def unbound_localhost_tcp_port(start=10000, stop=20000): """Looks for AF_INET SOCK_STREAM port for which bind succeeds Args: - start (int): first port - stop (int): one greater than last port (passed to range) + start (int): first port [10000] + stop (int): one greater than last port (passed to range) [20000] Returns: int: port is available or raises ValueError """ diff --git a/tests/http_test.py b/tests/http_test.py new file mode 100644 index 00000000..f63ad396 --- /dev/null +++ b/tests/http_test.py @@ -0,0 +1,30 @@ +"""test http server + +:copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. +:license: http://www.apache.org/licenses/LICENSE-2.0.html +""" + +import pytest + + +@pytest.mark.asyncio +async def test_basic(): + from pykern import http_unit + + with http_unit.Setup(api_classes=(_class(),)) as c: + from pykern.pkcollections import PKDict + from pykern import pkunit + + e = PKDict(a=1) + pkunit.pkeq(e, await c.call_api("echo", e)) + + +def _class(): + from pykern import quest + + class _API(quest.API): + + async def api_echo(self, api_args): + return api_args + + return _API From 658183547797fbeb6373e025f73e79022245e576 Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Mon, 30 Dec 2024 00:16:40 +0000 Subject: [PATCH 2/7] ckp --- pykern/http.py | 223 +++++++++++++++++++++++++++++++------------- pykern/http_unit.py | 35 +++---- pykern/pkasyncio.py | 24 +++-- 3 files changed, 184 insertions(+), 98 deletions(-) diff --git a/pykern/http.py b/pykern/http.py index 0100c985..023be115 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -14,6 +14,7 @@ import pykern.quest import re import tornado.web +import tornado.websocket #: Http auth header name @@ -154,6 +155,12 @@ def __init__(self, http_config): } ) self._request_config = http_config.get("request_config") or PKDict() + self._client = None + + async def connect(self): + if self._client: + + async def call_api(self, api_name, api_args): """Make a request to the API server @@ -198,18 +205,6 @@ def __exit__(self, *args, **kwargs): return False -class _HTTPRequestHandler(tornado.web.RequestHandler): - def initialize(self, server): - self.pykern_server = server - self.pykern_context = PKDict() - - async def get(self): - await self.pykern_server.dispatch(self) - - async def post(self): - await self.pykern_server.dispatch(self) - - class _HTTPServer: def __init__(self, loop, api_classes, attr_classes, http_config): @@ -241,13 +236,102 @@ def _api_map(): self.api_map = _api_map() self.attr_classes = attr_classes self.auth_secret = _auth_secret(h.pkdel("auth_secret")) - h.uri_map = ((h.api_uri, _HTTPRequestHandler, PKDict(server=self)),) + h.uri_map = ((h.api_uri, _WebSocketHandler, PKDict(server=self)),) self.api_uri = h.pkdel("api_uri") h.log_function = self._log_end - self.req_id = 0 + self._ws_id = 0 loop.http_server(h) - async def dispatch(self, handler): + def handle_get(self, handler): + def _authenticate(self, headers): + if not (h := headers.get(_AUTH_HEADER)): + return "no auth token" + if not (m := _AUTH_HEADER_RE.search(h)): + return "auth token format invalid" + if m.group(1) != self.auth_secret: + return "auth token mismatch" + return None + + def _validate_version(self, headers): + if not (v := headers.get(_VERSION_HEADER)): + return f"missing {_VERSION_HEADER} header" + if v != _VERSION_HEADER_VALUE: + return f"invalid version {v}" + return None + + self._log(handler, "start") + h = handler.request.headers + if e := self._authenticate(h): + k = PKDict(status_code=403, reason="Forbidden") + elif e := self._validate_version(h): + k = PKDict(status_code=412, reason="Precondition Failed") + else: + return True + handler.pykern_context.error = e + self.send_error(**k) + return False + + def handle_open(self, handler): + self._ws_id += 1 + self.handler.pykern_context.ws_id = self._ws_id + return _WebSocketConnection(self, handler, ws_id=self._ws_id) + + + def _log(self, obj, which, exc=None, reply=None): + def _add(key, value): + nonlocal f, a + if value is not None: + f += (" " if f else "") + key + "={}" + a.append(value) + + f = "" + a = [] + _add("error", obj.pykern_context.get("error")) + _add("ws_id", obj.pykern_context.get("ws_id")) + self.loop.http_log(obj, which, f, a) + + def _log_end(self, handler): + self._log(handler, "end") + + +class _WebSocketHandler(tornado.websocket.WebSocketHandler): + def initialize(self, server): + self.pykern_server = server + self.pykern_context = PKDict() + self.pykern_connection = None + + async def get(self, *args, **kwargs): + if not self.pykern_server.handle_get(self): + return + return await super().get(*args, **kwargs) + + async def on_message(self, msg): + # WebSocketHandler only allows one on_message at a time. + asyncio.create_task(self.pykern_connection.handle_on_message, msg) + + def on_close(self): + if self.pykern_connection: + self.pykern_connection.handle_on_close() + self.pykern_connection = None + + def open(self): + self.pykern_connection = self.pykern_server.handle_open(self) + + +class _WebSocketConnection: + + def __init__(self, server, handler, ws_id): + self.ws_id = ws_id + self.handler = handler + self.remote_peer = server.loop.remote_peer(handler) + self._log(None, "open") + + def handle_on_close(self): + self.handler = None + self._log(None, "on_close") + #TODO(robnagler) deal with open requests + + async def handle_on_message(self, msg): async def _call(api, api_args): with pykern.quest.start(api.api_class, self.attr_classes) as qcall: return await getattr(qcall, api.api_func_name)(api_args) @@ -256,12 +340,8 @@ async def _call(api, api_args): r = None try: self.req_id += 1 - handler.pykern_context.req_id = self.req_id - handler.pykern_context.api_name = None - handler.pykern_context.req_msg = None try: self._log(handler, "start") - self._authenticate(handler) m, e = _unpack_msg(handler.request) if e: raise APIForbidden(*e) @@ -286,54 +366,71 @@ async def _call(api, api_args): ) raise - def _authenticate(self, handler): - def _token(headers): - if not (h := headers.get(_AUTH_HEADER)): - return None - if m := _AUTH_HEADER_RE.search(h): - return m.group(1) - return None - - if handler.request.method != "POST": - raise APIForbidden() - if t := _token(handler.request.headers): - if t == self.auth_secret: - return - raise APIForbidden("token mismatch") - raise APIForbidden("no token") + def _log(self, ws_req, which, fmt="", args=None): + pkdlog( + "{} ip={} ws={}#{}" + fmt, + which, + self.remote_peer, + self.ws_id, + ws_req and ws_req.header.get("reqSeq") or 0, + *args, + ) - def _log(self, handler, which, exc=None, reply=None): - def _add(key, value): - nonlocal f, a - if value is not None: - f += (" " if f else "") + key + "={}" - a.append(value) - f = "" - a = [] - _add("req_id", handler.pykern_context.get("req_id")) - _add("api", handler.pykern_context.get("api_name")) - if exc: - _add("exception", exc) - _add("req", handler.pykern_context.get("req_msg")) - if which != "api_error": - _add("reply", reply) - _add("stack", pkdexc()) - self.loop.http_log(handler, which, f, a) +class _WebSocketMessage(): + def parse_msg(self, msg): + def _maybe_srunit_caller(): + if pkconfig.in_dev_mode() and (c := self.header.get("srunit_caller")): + return pkdformat(" srunit={}", c) + return "" - def _log_end(self, handler): - self._log(handler, "end") + if not isinstance(msg, bytes): + raise AssertionError(f"incoming msg type={type(msg)}") + u = msgpack.Unpacker( + max_buffer_size=sirepo.job.cfg().max_message_bytes, + object_pairs_hook=pkcollections.object_pairs_hook, + ) + u.feed(msg) + self.header = u.unpack() + self.handler.sr_log( + self, + "start", + fmt=" uri={}{}", + args=[self.header.get("uri"), _maybe_srunit_caller()], + ) + if sirepo.const.SCHEMA_COMMON.websocketMsg.version != self.header.get( + "version" + ): + raise AssertionError( + pkdformat("invalid header.version={}", self.header.get("version")) + ) + # Ensures protocol conforms for all requests + if ( + sirepo.const.SCHEMA_COMMON.websocketMsg.kind.httpRequest + != self.header.get("kind") + ): + raise AssertionError( + pkdformat("invalid header.kind={}", self.header.get("kind")) + ) + self.req_seq = self.header.reqSeq + self.uri = self.header.uri + if u.tell() < len(msg): + self.body_as_dict = u.unpack() + if u.tell() < len(msg): + self.attachment = u.unpack() + # content may or may not exist so defer checking + e, self.route, self.kwargs = _path_to_route(self.uri[1:]) + if e: + self.handler.sr_log( + self, + "error", + fmt=" msg={} route={} kwargs={}", + args=[e, self.route, self.kwargs], + ) + self.route = _not_found_route - def _send_reply(self, handler, reply): - if (c := reply.content) is None: - m = b"" - else: - m = _pack_msg(c) - handler.set_header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE) - handler.set_header(_VERSION_HEADER, _VERSION_HEADER_VALUE) - handler.set_header("Content-Length", str(len(m))) - handler.set_status(reply.http_status) - handler.write(m) + def set_log_user(self, log_user): + self.log_user = log_user def _auth_secret(value): diff --git a/pykern/http_unit.py b/pykern/http_unit.py index 197c9765..37ade3e0 100644 --- a/pykern/http_unit.py +++ b/pykern/http_unit.py @@ -6,13 +6,6 @@ # Defer imports for unit tests -# any uri is fine -_URI = "/http_unit" - -# just needs to be >= 16 word (required by http) chars; apps should generate this randomly -# -_AUTH_SECRET = "http_unit_auth_secret" - class Setup: @@ -20,10 +13,7 @@ def __init__(self, api_classes, attr_classes=(), coros=()): import os, time from pykern.pkcollections import PKDict - def _client(http_config): - from pykern import pkdebug, http - - def _config(): + def _global_config(): c = PKDict( PYKERN_PKDEBUG_WANT_PID_TIME="1", ) @@ -36,24 +26,25 @@ def _http_config(): from pykern import pkconst, pkunit return PKDict( - api_uri=_URI, - auth_secret=_AUTH_SECRET, + # any uri is fine + api_uri="/http_unit", + # just needs to be >= 16 word (required by http) chars; apps should generate this randomly + auth_secret="http_unit_auth_secret", tcp_ip=pkconst.LOCALHOST_IP, tcp_port=pkunit.unbound_localhost_tcp_port(), ) - def _server(http_config): + def _server(): from pykern import pkdebug, http - p = os.fork() - if p != 0: - return p + if rv := os.fork(): + return rv try: pkdebug.pkdlog("start server") http.server_start( attr_classes=attr_classes, api_classes=api_classes, - http_config=http_config, + http_config=self.http_config.copy(), coros=coros, ) except Exception as e: @@ -61,13 +52,13 @@ def _server(http_config): finally: os._exit(0) - _config() - h = _http_config() - self.server_pid = _server(h) + _global_config() + self.http_config = _http_config() + self.server_pid = _server() time.sleep(1) from pykern import http - self.client = http.HTTPClient(h) + self.client = http.HTTPClient(self.http_config.copy()) def destroy(self): import os, signal diff --git a/pykern/pkasyncio.py b/pykern/pkasyncio.py index c47a4021..3bfbf91a 100644 --- a/pykern/pkasyncio.py +++ b/pykern/pkasyncio.py @@ -65,21 +65,9 @@ async def _do(): self.run(_do()) def http_log(self, handler, which="end", fmt="", args=None): - def _remote_peer(request): - # https://github.com/tornadoweb/tornado/issues/2967#issuecomment-757370594 - # implementation may change; Code in tornado.httputil check connection. - if c := request.connection: - # socket is not set on stream for websockets. - if getattr(c, "stream", None) and ( - s := getattr(c.stream, "socket", None) - ): - return "{}:{}".format(*s.getpeername()) - i = request.headers.get("proxy-for", request.remote_ip) - return f"{i}:0" - r = handler.request f = "{} ip={} uri={}" - a = [which, _remote_peer(r), r.uri] + a = [which, self.remote_peer(r), r.uri] if fmt: f += " " + fmt a += args @@ -99,6 +87,16 @@ def _remote_peer(request): ] pkdlog(f, *a) + def remote_peer(self, request): + # https://github.com/tornadoweb/tornado/issues/2967#issuecomment-757370594 + # implementation may change; Code in tornado.httputil check connection. + if c := request.connection: + # socket is not set on stream for websockets. + if getattr(c, "stream", None) and (s := getattr(c.stream, "socket", None)): + return "{}:{}".format(*s.getpeername()) + i = request.headers.get("proxy-for", request.remote_ip) + return f"{i}:0" + def run(self, *coros): for c in coros: if not inspect.iscoroutine(c): From 88f8a097b0bd0c93004cf94256dcc4da86002094 Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Mon, 30 Dec 2024 18:47:27 +0000 Subject: [PATCH 3/7] ckp --- pykern/http.py | 249 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 162 insertions(+), 87 deletions(-) diff --git a/pykern/http.py b/pykern/http.py index 023be115..92e2c923 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -13,6 +13,7 @@ import pykern.pkconfig import pykern.quest import re +import tornado.httpclient import tornado.web import tornado.websocket @@ -35,9 +36,6 @@ re.IGNORECASE, ) -_CONTENT_TYPE_HEADER = "Content-Type" -_CONTENT_TYPE = "application/msgpack" - _VERSION_HEADER = "X-PyKern-HTTP-Version" _VERSION_HEADER_VALUE = "1" @@ -93,7 +91,7 @@ class ReplyExc(Exception): """Raised to end the request. Args: - pk_args (dict): exception args that specific to this module + pk_args (dict): exception args that are specific to this module log_fmt (str): server side log data """ @@ -132,41 +130,58 @@ class APINotFound(ReplyExc): pass +class APIDisconnected(ReplyExc): + """Raised when remote server closed or other error""" + + pass + class HTTPClient: """Wrapper for `tornado.httpclient.AsyncHTTPClient` - Maybe called as a context manager + Maybe called as an async context manager + + `http_config.request_config` is deprecated. Args: - http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret, request_config (deprecated) + http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret """ def __init__(self, http_config): - self._uri = ( + self.uri = ( f"http://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}" ) - self._headers = PKDict( - { - _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {_auth_secret(http_config.auth_secret)}", - _CONTENT_TYPE_HEADER: _CONTENT_TYPE, - _VERSION_HEADER: _VERSION_HEADER_VALUE, - } - ) - self._request_config = http_config.get("request_config") or PKDict() - self._client = None + self.auth_secret = _auth_secret(http_config.auth_secret) + self._connection = None + self._destroyed = False + self._call_id = 0 + self._reader = None + self._pending_calls = PKDict() async def connect(self): - if self._client: - - + if self._destroyed: + raise AssertionError("destroyed") + if self._connection: + raise AssertionError("already connected") + self._connection = await tornado.websocket.websocket_connect( + tornado.httpclient.HTTPRequest( + self.uri, + { + _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {self.auth_secret}", + _VERSION_HEADER: _VERSION_HEADER_VALUE, + }, + ) + #TODO(robnagler) accept in http_config. share defaults with sirepo.job. + max_message_size=int(2e8), + ping_interval=120, + ping_timeout=240, + ) + self._reader = asyncio.create_task(self._read_loop()) async def call_api(self, api_name, api_args): """Make a request to the API server - `http_config.request_config` (see `__init__` and if it exists) is passed verbatim to `AsyncHTTPClient.fetch`. - Args: api_name (str): what to call on the server api_args (PKDict): passed verbatim to the API on the server. @@ -176,33 +191,88 @@ async def call_api(self, api_name, api_args): APIError: if there was an raise in the API or on a server protocol violation Exception: other exceptions that `AsyncHTTPClient.fetch` may raise, e.g. NotFound """ - # Need to be careful with the lifecycle of AsyncHTTPClient - # https://github.com/radiasoft/pykern/issues/529 - r = await tornado.httpclient.AsyncHTTPClient(force_instance=True).fetch( - self._uri, - body=_pack_msg(PKDict(api_name=api_name, api_args=api_args)), - headers=self._headers, - method="POST", - **self._request_config, - ) - rv, e = _unpack_msg(r) + def _send(): + self._call_id += 1 + rv = _HTTPClientCall(api_name, api_args, self._call_id) + self._pending_reqs[r.call_id] = rv + self._connection.write_message(rv.msg()) + return rv + + # TODO(robnagler) backwards compatibility + if not self._connection: + # Does destroy check + await self.connect() + c = _send() + return await c.get_reply() + + def destroy(self): + """Must be called""" + if self._destroyed: + return + self._destroyed = True + if self._connection: + self._connection.close() + self._connection = None + x = self._pending_apis + self._pending_apis = None + for r in x.values(): + #TODO(robnagler) refine set error class + r.reply_q.put_nowait(None) + + async def __aenter__(self): + await self.connect() + return self + + async def __aexit__(self, *args, **kwargs): + self.destroy() + return False + + def _read_loop(self): + while m := await self._connection.read_message(): + + if not self._destroyed: + # TODO(robnagler) + self.destroy() + + + +class _HTTPClientReply(PKDict): + api_error_class + api_error_args + def reply(self, msg): + + if msg is None: + create error reply (closed connection) + i, c = _unpack_msg(r) if e: raise pykern.quest.APIError(*e) if rv.api_error: raise pykern.quest.APIError( "api_error={} api_name={} api_args={}", rv.api_error, api_name, api_args ) - return rv.api_result - def destroy(self): - """Must be called""" - pass - def __enter__(self): - return self +class _HTTPClientCall(PKDict): + def __init__(self, **kwargs): + super().__init__(**kwargs) + #TODO(robnagler) should be one for regular replies + self.reply_q = tornado.queues.Queue() + def reply_put(self, msg): + if msg is None: + create error reply (closed connection) - def __exit__(self, *args, **kwargs): - return False + async def reply_get() + reply_q.get() + c.reply_q.task_done() + c.destroy() + + if not r: + #TODO(robnagler) refine could be destroyed for other reasons + # reply_q will be destroyed + raise APIDisconnected() + if r.error_class: + raise r.error_class(r.error_args) + return r.result class _HTTPServer: @@ -307,7 +377,7 @@ async def get(self, *args, **kwargs): async def on_message(self, msg): # WebSocketHandler only allows one on_message at a time. - asyncio.create_task(self.pykern_connection.handle_on_message, msg) + asyncio.create_task(self.pykern_connection.handle_on_message(msg)) def on_close(self): if self.pykern_connection: @@ -322,57 +392,61 @@ class _WebSocketConnection: def __init__(self, server, handler, ws_id): self.ws_id = ws_id + self.server = server self.handler = handler + self._destroyed = False self.remote_peer = server.loop.remote_peer(handler) self._log(None, "open") + def destroy(self): + if self._destroyed: + return + self._destroyed = True + self.handler.close() + self.handler = None + def handle_on_close(self): + if self._destroyed: + return self.handler = None self._log(None, "on_close") #TODO(robnagler) deal with open requests async def handle_on_message(self, msg): async def _call(api, api_args): - with pykern.quest.start(api.api_class, self.attr_classes) as qcall: + with pykern.quest.start(api.api_class, self.server.attr_classes) as qcall: return await getattr(qcall, api.api_func_name)(api_args) - m = None - r = None + def _reply(call, obj): + if isinstance(obj, Reply): + pass + if isinstance(obj, sirepo.quest.APIError): + pass + if isinstance(obj, Exception): + pass + + c = None try: - self.req_id += 1 - try: - self._log(handler, "start") - m, e = _unpack_msg(handler.request) - if e: - raise APIForbidden(*e) - handler.pykern_context.req_msg = m - self._log(handler, "call") - if not (a := self.api_map.get(m.api_name)): - raise APINotFound() - r = Reply(result=await _call(a, m.api_args)) - except pykern.quest.APIError as e: - self._log(handler, "api_error", e) - r = Reply(api_error=str(e)) - except Exception as e: - self._log(handler, "error", e) - r = Reply(exc=e) - self._send_reply(handler, r) + c, e = _unpack_msg(msg) + if e: + self._log(None, "error", "msg unpack error={}", [e]) + self.destroy() + return None + self._log("start", c) + if not (a := self.server.api_map.get(m.api_name)): + self._log(c, "error", "api not found={}" m.api_name) + _reply(c, APINotFound) + _reply(c, await _call(a, m.api_args)) except Exception as e: - self._log( - handler, - "reply_error", - e, - getattr(r, "content", None), - ) - raise + _reply(c, e) - def _log(self, ws_req, which, fmt="", args=None): + def _log(self, call, which, fmt="", args=None): pkdlog( "{} ip={} ws={}#{}" + fmt, which, self.remote_peer, self.ws_id, - ws_req and ws_req.header.get("reqSeq") or 0, + call and call.call_id, *args, ) @@ -460,20 +534,21 @@ def _datetime(obj): def _unpack_msg(request): - def _header(name, value): - if not (v := request.headers.get(name)): - return ("missing header={}", name) - if v != value: - return ("unexpected {}={}", name, c) - return None - - if e := ( - _header(_VERSION_HEADER, _VERSION_HEADER_VALUE) - or _header(_CONTENT_TYPE_HEADER, _CONTENT_TYPE) - ): - return None, e - u = msgpack.Unpacker( - object_pairs_hook=pykern.pkcollections.object_pairs_hook, - ) - u.feed(request.body) - return u.unpack(), None + try: + u = msgpack.Unpacker( + object_pairs_hook=pykern.pkcollections.object_pairs_hook, + ) + u.feed(request.body) + rv = u.unpack() + except Exception as e: + return None, f"msg unpack exception={e}" + if not isinstance(rv, PKDict): + return None, f"msg not dict type={type(rv)}" + if "call_id" not in rv: + return None, "msg missing call_id keys={list(rv.keys())}" + i = rv.call_id + if not isinstance(i, int): + return None, f"msg call_id non-integer type={type(i)}" + if i <= 0: + return None, f"msg call_id non-positive call_id={i}" + return rv, None From 978142a9926330a4b64efb5490b499133aea26e5 Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Tue, 31 Dec 2024 01:34:04 +0000 Subject: [PATCH 4/7] http_test passes --- pykern/http.py | 475 +++++++++++++++++++++++-------------------------- 1 file changed, 219 insertions(+), 256 deletions(-) diff --git a/pykern/http.py b/pykern/http.py index 92e2c923..c1cb3b03 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -6,6 +6,7 @@ from pykern.pkcollections import PKDict from pykern.pkdebug import pkdc, pkdlog, pkdp, pkdexc, pkdformat +import asyncio import inspect import msgpack import pykern.pkasyncio @@ -59,81 +60,32 @@ def server_start(api_classes, attr_classes, http_config, coros=()): l.start() -class Reply: - - def __init__(self, result=None, exc=None, api_error=None): - def _exception(exc): - if exc is None: - pkdlog("ERROR: no reply and no exception") - return 500 - if isinstance(exc, APINotFound): - return 404 - if isinstance(exc, APIForbidden): - return 403 - pkdlog("untranslated exception={}", exc) - return 500 - - if isinstance(result, Reply): - self.http_status = result.http_status - self.content = result.content - elif result is not None or api_error is not None: - self.http_status = 200 - self.content = PKDict( - api_error=api_error, - api_result=result, - ) - else: - self.http_status = _exception(exc) - self.content = None - +class APICallError(pykern.quest.APIError): + """Raised for an object not found""" -class ReplyExc(Exception): - """Raised to end the request. + def __init__(self, exception): + super().__init__("exception={}", exception) - Args: - pk_args (dict): exception args that are specific to this module - log_fmt (str): server side log data - """ - def __init__(self, *args, **kwargs): - super().__init__() - if "pk_args" in kwargs: - self.pk_args = kwargs["pk_args"] - del kwargs["pk_args"] - else: - self.pk_args = PKDict() - if args or kwargs: - kwargs["pkdebug_frame"] = inspect.currentframe().f_back.f_back - pkdlog(*args, **kwargs) - - def __repr__(self): - a = self.pk_args - return "{}({})".format( - self.__class__.__name__, - ",".join( - ("{}={}".format(k, a[k]) for k in sorted(a.keys())), - ), - ) +class APIDisconnected(pykern.quest.APIError): + """Raised when remote server closed or other error""" - def __str__(self): - return self.__repr__() + def __init__(self): + super().__init__("") -class APIForbidden(ReplyExc): +class APIForbidden(pykern.quest.APIError): """Raised for forbidden or protocol error""" - pass + def __init__(self): + super().__init__("") -class APINotFound(ReplyExc): +class APINotFound(pykern.quest.APIError): """Raised for an object not found""" - pass - -class APIDisconnected(ReplyExc): - """Raised when remote server closed or other error""" - - pass + def __init__(self, api_name): + super().__init__("api_name={}", api_name) class HTTPClient: @@ -145,40 +97,19 @@ class HTTPClient: Args: http_config (PKDict): tcp_ip, tcp_port, api_uri, auth_secret - """ def __init__(self, http_config): + # TODO(robnagler) tls with verification(?) self.uri = ( - f"http://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}" + f"ws://{http_config.tcp_ip}:{http_config.tcp_port}{http_config.api_uri}" ) self.auth_secret = _auth_secret(http_config.auth_secret) self._connection = None self._destroyed = False self._call_id = 0 - self._reader = None self._pending_calls = PKDict() - async def connect(self): - if self._destroyed: - raise AssertionError("destroyed") - if self._connection: - raise AssertionError("already connected") - self._connection = await tornado.websocket.websocket_connect( - tornado.httpclient.HTTPRequest( - self.uri, - { - _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {self.auth_secret}", - _VERSION_HEADER: _VERSION_HEADER_VALUE, - }, - ) - #TODO(robnagler) accept in http_config. share defaults with sirepo.job. - max_message_size=int(2e8), - ping_interval=120, - ping_timeout=240, - ) - self._reader = asyncio.create_task(self._read_loop()) - async def call_api(self, api_name, api_args): """Make a request to the API server @@ -191,19 +122,43 @@ async def call_api(self, api_name, api_args): APIError: if there was an raise in the API or on a server protocol violation Exception: other exceptions that `AsyncHTTPClient.fetch` may raise, e.g. NotFound """ + def _send(): self._call_id += 1 - rv = _HTTPClientCall(api_name, api_args, self._call_id) - self._pending_reqs[r.call_id] = rv - self._connection.write_message(rv.msg()) + c = PKDict(api_name=api_name, api_args=api_args, call_id=self._call_id) + rv = _ClientCall(c) + self._pending_calls[rv.call_id] = rv + self._connection.write_message(_pack_msg(c), binary=True) return rv # TODO(robnagler) backwards compatibility if not self._connection: - # Does destroy check + # will check destroyed await self.connect() - c = _send() - return await c.get_reply() + if self._destroyed: + return + return await _send().reply_get() + + async def connect(self): + if self._destroyed: + raise AssertionError("destroyed") + if self._connection: + raise AssertionError("already connected") + self._connection = await tornado.websocket.websocket_connect( + tornado.httpclient.HTTPRequest( + self.uri, + headers={ + _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {self.auth_secret}", + _VERSION_HEADER: _VERSION_HEADER_VALUE, + }, + method="GET", + ), + # TODO(robnagler) accept in http_config. share defaults with sirepo.job. + max_message_size=int(2e8), + ping_interval=120, + ping_timeout=240, + ) + asyncio.create_task(self._read_loop()) def destroy(self): """Must be called""" @@ -213,66 +168,94 @@ def destroy(self): if self._connection: self._connection.close() self._connection = None - x = self._pending_apis - self._pending_apis = None - for r in x.values(): - #TODO(robnagler) refine set error class - r.reply_q.put_nowait(None) + x = self._pending_calls + self._pending_calls = None + for c in x.values(): + c.reply_q.put_nowait(None) async def __aenter__(self): await self.connect() + if self._destroyed: + raise APIDisconnected() return self async def __aexit__(self, *args, **kwargs): self.destroy() return False - def _read_loop(self): - while m := await self._connection.read_message(): - - if not self._destroyed: - # TODO(robnagler) - self.destroy() + async def _read_loop(self): + def _unpack(msg): + r, e = _unpack_msg(msg) + if e: + pkdlog("unpack msg error={} {}", e, self) + return None + return r + m = r = None + try: + if self._destroyed: + return + while m := await self._connection.read_message(): + if self._destroyed: + return + if not (r := _unpack(m)): + break + # Remove from pending + if not (c := self._pending_calls.pkdel(r.call_id)): + pkdlog("call_id not found reply={} {}", r, self) + # TODO(robnagler) possibly too harsh, but safer for now + break + c.reply_q.put_nowait(r) + m = r = None + except Exception as e: + pkdlog("exception={} reply={} stack={}", e, r, pkdexc()) + try: + if not self._destroyed: + self.destroy() + except Exception as e: + pkdlog("exception={} stack={}", e, pkdexc()) + def __repr__(self): + def _calls(): + return ", ".join( + ( + f"{v.api_name}#{v.call_id}" + for v in sorted( + self._pending_calls.values(), key=lambda x: x.call_id + ) + ), + ) -class _HTTPClientReply(PKDict): - api_error_class - api_error_args - def reply(self, msg): + def _destroyed(): + return "DESTROYED, " if self._destroyed else "" - if msg is None: - create error reply (closed connection) - i, c = _unpack_msg(r) - if e: - raise pykern.quest.APIError(*e) - if rv.api_error: - raise pykern.quest.APIError( - "api_error={} api_name={} api_args={}", rv.api_error, api_name, api_args - ) + return f"{self.__class__.__name__}({_destroyed()}call_id={self._call_id}, calls=[{_calls()}])" -class _HTTPClientCall(PKDict): - def __init__(self, **kwargs): - super().__init__(**kwargs) - #TODO(robnagler) should be one for regular replies +class _ClientCall(PKDict): + def __init__(self, call_msg): + super().__init__(**call_msg) + # TODO(robnagler) should be one for regular replies self.reply_q = tornado.queues.Queue() - def reply_put(self, msg): - if msg is None: - create error reply (closed connection) - - async def reply_get() - reply_q.get() - c.reply_q.task_done() - c.destroy() - - if not r: - #TODO(robnagler) refine could be destroyed for other reasons - # reply_q will be destroyed + self._destroyed = False + + async def reply_get(self): + rv = await self.reply_q.get() + if self._destroyed: raise APIDisconnected() - if r.error_class: - raise r.error_class(r.error_args) - return r.result + self.reply_q.task_done() + self.destroy() + if rv is None: + raise APIDisconnected() + if rv.api_error: + raise pykern.quest.APIError(rv.api_error) + return rv.result + + def destroy(self): + if self._destroyed: + return + self._destroyed = True + self.reply_q = None class _HTTPServer: @@ -306,14 +289,14 @@ def _api_map(): self.api_map = _api_map() self.attr_classes = attr_classes self.auth_secret = _auth_secret(h.pkdel("auth_secret")) - h.uri_map = ((h.api_uri, _WebSocketHandler, PKDict(server=self)),) + h.uri_map = ((h.api_uri, _ServerHandler, PKDict(server=self)),) self.api_uri = h.pkdel("api_uri") h.log_function = self._log_end self._ws_id = 0 loop.http_server(h) def handle_get(self, handler): - def _authenticate(self, headers): + def _authenticate(headers): if not (h := headers.get(_AUTH_HEADER)): return "no auth token" if not (m := _AUTH_HEADER_RE.search(h)): @@ -322,32 +305,41 @@ def _authenticate(self, headers): return "auth token mismatch" return None - def _validate_version(self, headers): + def _validate_version(headers): if not (v := headers.get(_VERSION_HEADER)): return f"missing {_VERSION_HEADER} header" if v != _VERSION_HEADER_VALUE: return f"invalid version {v}" return None - self._log(handler, "start") - h = handler.request.headers - if e := self._authenticate(h): - k = PKDict(status_code=403, reason="Forbidden") - elif e := self._validate_version(h): - k = PKDict(status_code=412, reason="Precondition Failed") - else: - return True - handler.pykern_context.error = e - self.send_error(**k) - return False + try: + self._log(handler, "start") + h = handler.request.headers + if e := _authenticate(h): + k = PKDict(status_code=403, reason="Forbidden") + elif e := _validate_version(h): + k = PKDict(status_code=412, reason="Precondition Failed") + else: + return True + handler.pykern_context.error = e + self.send_error(**k) + return False + except Exception as e: + pkdlog("exception={} stack={}", e, pkdexc()) + self._log(handler, "error", "exception={}", [e]) + return False def handle_open(self, handler): - self._ws_id += 1 - self.handler.pykern_context.ws_id = self._ws_id - return _WebSocketConnection(self, handler, ws_id=self._ws_id) - + try: + self._ws_id += 1 + handler.pykern_context.ws_id = self._ws_id + return _ServerConnection(self, handler, ws_id=self._ws_id) + except Exception as e: + pkdlog("exception={} stack={}", e, pkdexc()) + self._log(handler, "error", "exception={}", [e]) + return None - def _log(self, obj, which, exc=None, reply=None): + def _log(self, handler, which): def _add(key, value): nonlocal f, a if value is not None: @@ -356,47 +348,23 @@ def _add(key, value): f = "" a = [] - _add("error", obj.pykern_context.get("error")) - _add("ws_id", obj.pykern_context.get("ws_id")) - self.loop.http_log(obj, which, f, a) + _add("error", handler.pykern_context.pkdel("error")) + _add("ws_id", handler.pykern_context.get("ws_id")) + self.loop.http_log(handler, which, f, a) def _log_end(self, handler): self._log(handler, "end") -class _WebSocketHandler(tornado.websocket.WebSocketHandler): - def initialize(self, server): - self.pykern_server = server - self.pykern_context = PKDict() - self.pykern_connection = None - - async def get(self, *args, **kwargs): - if not self.pykern_server.handle_get(self): - return - return await super().get(*args, **kwargs) - - async def on_message(self, msg): - # WebSocketHandler only allows one on_message at a time. - asyncio.create_task(self.pykern_connection.handle_on_message(msg)) - - def on_close(self): - if self.pykern_connection: - self.pykern_connection.handle_on_close() - self.pykern_connection = None - - def open(self): - self.pykern_connection = self.pykern_server.handle_open(self) - - -class _WebSocketConnection: +class _ServerConnection: def __init__(self, server, handler, ws_id): self.ws_id = ws_id self.server = server self.handler = handler self._destroyed = False - self.remote_peer = server.loop.remote_peer(handler) - self._log(None, "open") + self.remote_peer = server.loop.remote_peer(handler.request) + self._log("open") def destroy(self): if self._destroyed: @@ -409,102 +377,97 @@ def handle_on_close(self): if self._destroyed: return self.handler = None - self._log(None, "on_close") - #TODO(robnagler) deal with open requests + self._log("on_close") + # TODO(robnagler) deal with open requests async def handle_on_message(self, msg): - async def _call(api, api_args): + def _api(call): + if n := call.get("api_name"): + if rv := self.server.api_map.get(n): + return rv + else: + n = "" + self._log("error", call, "api not found={}", [n]) + _reply(c, APINotFound(n)) + return None + + async def _call(call, api, api_args): with pykern.quest.start(api.api_class, self.server.attr_classes) as qcall: - return await getattr(qcall, api.api_func_name)(api_args) + try: + return await getattr(qcall, api.api_func_name)(api_args) + except Exception as e: + pkdlog("exception={} call={} stack={}", call, e, pkdexc()) + return APICallError(e) def _reply(call, obj): - if isinstance(obj, Reply): - pass - if isinstance(obj, sirepo.quest.APIError): - pass - if isinstance(obj, Exception): - pass + try: + if not isinstance(obj, Exception): + r = PKDict(result=obj, api_error=None) + elif isinstance(obj, pykern.quest.APIError): + r = PKDict(result=None, api_error=str(obj)) + else: + r = PKDict(result=None, api_error=f"unhandled_exception={obj}") + r.call_id = call.call_id + self.handler.write_message(_pack_msg(r), binary=True) + except Exception as e: + pkdlog("exception={} call={} stack={}", call, e, pkdexc()) + self.destroy() c = None try: c, e = _unpack_msg(msg) if e: - self._log(None, "error", "msg unpack error={}", [e]) + self._log("error", None, "msg unpack error={}", [e]) self.destroy() return None self._log("start", c) - if not (a := self.server.api_map.get(m.api_name)): - self._log(c, "error", "api not found={}" m.api_name) - _reply(c, APINotFound) - _reply(c, await _call(a, m.api_args)) + if not (a := _api(c)): + return + r = await _call(c, a, c.api_args) + if self._destroyed: + return + _reply(c, r) + self._log("end", c) + c = None except Exception as e: + pkdlog("exception={} call={} stack={}", e, c, pkdexc()) _reply(c, e) - def _log(self, call, which, fmt="", args=None): + def _log(self, which, call=None, fmt="", args=None): + if fmt: + fmt = " " + fmt pkdlog( "{} ip={} ws={}#{}" + fmt, which, self.remote_peer, self.ws_id, call and call.call_id, - *args, + *(args if args else ()), ) -class _WebSocketMessage(): - def parse_msg(self, msg): - def _maybe_srunit_caller(): - if pkconfig.in_dev_mode() and (c := self.header.get("srunit_caller")): - return pkdformat(" srunit={}", c) - return "" +class _ServerHandler(tornado.websocket.WebSocketHandler): + def initialize(self, server): + self.pykern_server = server + self.pykern_context = PKDict() + self.pykern_connection = None - if not isinstance(msg, bytes): - raise AssertionError(f"incoming msg type={type(msg)}") - u = msgpack.Unpacker( - max_buffer_size=sirepo.job.cfg().max_message_bytes, - object_pairs_hook=pkcollections.object_pairs_hook, - ) - u.feed(msg) - self.header = u.unpack() - self.handler.sr_log( - self, - "start", - fmt=" uri={}{}", - args=[self.header.get("uri"), _maybe_srunit_caller()], - ) - if sirepo.const.SCHEMA_COMMON.websocketMsg.version != self.header.get( - "version" - ): - raise AssertionError( - pkdformat("invalid header.version={}", self.header.get("version")) - ) - # Ensures protocol conforms for all requests - if ( - sirepo.const.SCHEMA_COMMON.websocketMsg.kind.httpRequest - != self.header.get("kind") - ): - raise AssertionError( - pkdformat("invalid header.kind={}", self.header.get("kind")) - ) - self.req_seq = self.header.reqSeq - self.uri = self.header.uri - if u.tell() < len(msg): - self.body_as_dict = u.unpack() - if u.tell() < len(msg): - self.attachment = u.unpack() - # content may or may not exist so defer checking - e, self.route, self.kwargs = _path_to_route(self.uri[1:]) - if e: - self.handler.sr_log( - self, - "error", - fmt=" msg={} route={} kwargs={}", - args=[e, self.route, self.kwargs], - ) - self.route = _not_found_route + async def get(self, *args, **kwargs): + if not self.pykern_server.handle_get(self): + return + return await super().get(*args, **kwargs) + + async def on_message(self, msg): + # WebSocketHandler only allows one on_message at a time. + asyncio.create_task(self.pykern_connection.handle_on_message(msg)) + + def on_close(self): + if self.pykern_connection: + self.pykern_connection.handle_on_close() + self.pykern_connection = None - def set_log_user(self, log_user): - self.log_user = log_user + def open(self): + self.pykern_connection = self.pykern_server.handle_open(self) def _auth_secret(value): @@ -533,15 +496,15 @@ def _datetime(obj): return p.bytes() -def _unpack_msg(request): +def _unpack_msg(content): try: u = msgpack.Unpacker( object_pairs_hook=pykern.pkcollections.object_pairs_hook, ) - u.feed(request.body) + u.feed(content) rv = u.unpack() except Exception as e: - return None, f"msg unpack exception={e}" + return None, f"msgpack exception={e}" if not isinstance(rv, PKDict): return None, f"msg not dict type={type(rv)}" if "call_id" not in rv: From 5a1177fe8909c0f02faf381a52af23c1c0d83b38 Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Wed, 1 Jan 2025 02:42:34 +0000 Subject: [PATCH 5/7] working --- pykern/http.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pykern/http.py b/pykern/http.py index c1cb3b03..9cf76b86 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -322,7 +322,7 @@ def _validate_version(headers): else: return True handler.pykern_context.error = e - self.send_error(**k) + handler.send_error(**k) return False except Exception as e: pkdlog("exception={} stack={}", e, pkdexc()) @@ -339,7 +339,7 @@ def handle_open(self, handler): self._log(handler, "error", "exception={}", [e]) return None - def _log(self, handler, which): + def _log(self, handler, which, fmt="", args=None): def _add(key, value): nonlocal f, a if value is not None: @@ -348,8 +348,12 @@ def _add(key, value): f = "" a = [] - _add("error", handler.pykern_context.pkdel("error")) - _add("ws_id", handler.pykern_context.get("ws_id")) + if x := getattr(handler, "pykern_context", None): + _add("error", x.pkdel("error")) + _add("ws_id", x.get("ws_id")) + if fmt: + f = f + " " + fmt + a.extend(args) self.loop.http_log(handler, which, f, a) def _log_end(self, handler): From 896810d5bd5bc93c725548f1ce5d08f8a9780f9e Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Thu, 2 Jan 2025 01:08:46 +0000 Subject: [PATCH 6/7] api_result instead of just result --- pykern/http.py | 45 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/pykern/http.py b/pykern/http.py index 9cf76b86..d7ede8fb 100644 --- a/pykern/http.py +++ b/pykern/http.py @@ -1,5 +1,6 @@ """HTTP server & client + :copyright: Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. :license: http://www.apache.org/licenses/LICENSE-2.0.html """ @@ -147,10 +148,14 @@ async def connect(self): self._connection = await tornado.websocket.websocket_connect( tornado.httpclient.HTTPRequest( self.uri, - headers={ - _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {self.auth_secret}", - _VERSION_HEADER: _VERSION_HEADER_VALUE, - }, + headers=( + { + _AUTH_HEADER: f"{_AUTH_HEADER_SCHEME_BEARER} {self.auth_secret}", + _VERSION_HEADER: _VERSION_HEADER_VALUE, + } + if self.auth_secret + else None + ), method="GET", ), # TODO(robnagler) accept in http_config. share defaults with sirepo.job. @@ -249,7 +254,7 @@ async def reply_get(self): raise APIDisconnected() if rv.api_error: raise pykern.quest.APIError(rv.api_error) - return rv.result + return rv.api_result def destroy(self): if self._destroyed: @@ -314,6 +319,11 @@ def _validate_version(headers): try: self._log(handler, "start") + # TODO(robnagler) special case for websockets. Need to switch + # to first message sent defines protocol and version. + if not self.auth_secret: + # _auth_secret can be false + return True h = handler.request.headers if e := _authenticate(h): k = PKDict(status_code=403, reason="Forbidden") @@ -406,11 +416,11 @@ async def _call(call, api, api_args): def _reply(call, obj): try: if not isinstance(obj, Exception): - r = PKDict(result=obj, api_error=None) + r = PKDict(api_result=obj, api_error=None) elif isinstance(obj, pykern.quest.APIError): - r = PKDict(result=None, api_error=str(obj)) + r = PKDict(api_result=None, api_error=str(obj)) else: - r = PKDict(result=None, api_error=f"unhandled_exception={obj}") + r = PKDict(api_result=None, api_error=f"unhandled_exception={obj}") r.call_id = call.call_id self.handler.write_message(_pack_msg(r), binary=True) except Exception as e: @@ -457,8 +467,8 @@ def initialize(self, server): self.pykern_connection = None async def get(self, *args, **kwargs): - if not self.pykern_server.handle_get(self): - return + # if not self.pykern_server.handle_get(self): + # return return await super().get(*args, **kwargs) async def on_message(self, msg): @@ -475,12 +485,27 @@ def open(self): def _auth_secret(value): + """Validate config value or default + + Special case: `auth_secret` can be ``False``, which means no + authentication or version checking. Temporary change to deal with + JavaScript WebSocket which does not support sending headers. + + + In dev mode, defaults to something if not set. + + Returns: + bool: Valid value of auth_secret + + """ if value: if len(value) < 16: raise ValueError("auth_secret too short len={len(value)} (<16)") if not _AUTH_SECRET_RE.search(value): raise ValueError("auth_secret contains non-word chars") return value + if isinstance(value, bool): + return value if pykern.pkconfig.in_dev_mode(): return "default_dev_secret" raise ValueError("must supply http_config.auth_secret") From 8a09ec37e5d0b16d008a9444c08ea369321e8ade Mon Sep 17 00:00:00 2001 From: Rob Nagler <5495179+robnagler@users.noreply.github.com> Date: Thu, 2 Jan 2025 01:20:19 +0000 Subject: [PATCH 7/7] update timestamp in test --- tests/pkcli/projex2_data/xyzzy1.out/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pkcli/projex2_data/xyzzy1.out/README.md b/tests/pkcli/projex2_data/xyzzy1.out/README.md index 5fd8c1c2..d1c88ed4 100644 --- a/tests/pkcli/projex2_data/xyzzy1.out/README.md +++ b/tests/pkcli/projex2_data/xyzzy1.out/README.md @@ -10,4 +10,4 @@ Documentation: https://xyzzy1.readthedocs.io License: https://www.apache.org/licenses/LICENSE-2.0.html -Copyright (c) 2024 RadiaSoft LLC. All Rights Reserved. +Copyright (c) 2025 RadiaSoft LLC. All Rights Reserved.