diff --git a/tdoc/common/cli.py b/tdoc/common/cli.py index 4baaafa..196279f 100644 --- a/tdoc/common/cli.py +++ b/tdoc/common/cli.py @@ -3,21 +3,25 @@ import argparse import contextlib -from http import server +from http import HTTPMethod, HTTPStatus import itertools from importlib import metadata import json +import mimetypes import os import pathlib +import posixpath import re import shutil import socket +import socketserver import stat import subprocess import sys import threading import time from urllib import parse +from wsgiref import simple_server, util as wsgiutil from . import __project__, __version__, util @@ -85,8 +89,8 @@ def main(argv, stdin, stdout, stderr): "(default: %(default)s).") arg('--port', metavar='PORT', dest='port', default=8000, type=int, help="The port to bind the server to (default: %(default)s).") - arg('--protocol', metavar='VERSION', dest='protocol', default='HTTP/1.0', - help="The HTTP protocol version to conform to (default: %(default)s).") + arg('--restart-on-change', action='store_true', dest='restart_on_change', + help="Restart the server on changes.") arg('--watch', metavar='PATH', action='append', dest='watch', default=[], help="Additional directories to watch for changes.") @@ -126,15 +130,21 @@ def cmd_serve(cfg): class Server(ServerBase): address_family = family - class Handler(HandlerBase): - protocol = cfg.protocol - with Server(addr, Handler, cfg=cfg) as srv: + with Application(cfg, addr) as app, Server(addr, RequestHandler) as srv: + app.server = srv + srv.set_app(app) try: srv.serve_forever() except KeyboardInterrupt: + cfg.restart_on_change = False cfg.stderr.write("Interrupted, exiting\n") + if cfg.restart_on_change: + cfg.stdout.flush() + cfg.stderr.flush() + os.execv(sys.argv[0], sys.argv) + def cmd_version(cfg): cfg.stdout.write(f"{__project__}-{__version__}\n") @@ -153,9 +163,47 @@ def sphinx_build(cfg, target, *, build, tags=(), **kwargs): stderr=cfg.stderr, **kwargs) -class ServerBase(server.ThreadingHTTPServer): - def __init__(self, *args, cfg, **kwargs): +class ServerBase(socketserver.ThreadingMixIn, simple_server.WSGIServer): + daemon_threads = True + + def server_bind(self): + with contextlib.suppress(Exception): + self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + return super().server_bind() + + +class RequestHandler(simple_server.WSGIRequestHandler): + def log_request(self, code='-', size='-'): + pass + + def log_message(self, format, *args): + self.server.application.cfg.stderr.write("%s - - [%s] %s\n" % ( + self.address_string(), self.log_date_time_string(), + (format % args).translate(self._control_char_table))) + + +def status_str(status): + return f'{status} {status.phrase}' + + +def error(respond, status): + respond(status_str(status), [ + ('Content-Type', 'text/plain;charset=utf-8'), + ]) + return [status.description.encode('utf-8')] + + +def try_stat(path): + try: + return path.stat() + except OSError: + return None + + +class Application: + def __init__(self, cfg, addr): self.cfg = cfg + self.addr = addr self.lock = threading.Condition(threading.Lock()) self.directory = self.build_dir(0) / 'html' self.upgrade_msg = None @@ -167,28 +215,13 @@ def __init__(self, *args, cfg, **kwargs): self.builder.start() self.checker = threading.Thread(target=self.check_upgrade, daemon=True) self.checker.start() - super().__init__(*args, **kwargs) - - def server_bind(self): - with contextlib.suppress(Exception): - self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) - return super().server_bind() - - def finish_request(self, request, client_addr): - with self.lock: directory = self.directory - self.RequestHandlerClass(request, client_addr, self, - directory=directory) - - IGNORED_EXCEPTIONS = (BrokenPipeError, ConnectionAbortedError) + self.apps = {'*build': self.handle_build} - def handle_error(self, request, client_addr): - if not isinstance(sys.exception(), self.IGNORED_EXCEPTIONS): - super().handle_error(request, client_addr) + def __enter__(self): return self - def server_close(self): + def __exit__(self, typ, value, tb): with self.lock: self.stop = True self.builder.join() - return super().server_close() def watch_and_build(self): interval = self.cfg.interval * 1_000_000_000 @@ -208,6 +241,11 @@ def watch_and_build(self): prev = mtime + delay - interval continue if prev_mtime != 0: + if self.cfg.restart_on_change: + self.cfg.stdout.write( + "\nSource change detected, restarting\n") + self.server.shutdown() + break self.cfg.stdout.write( "\nSource change detected, rebuilding\n") prev_mtime = mtime @@ -266,7 +304,7 @@ def on_error(fn, path, e): shutil.rmtree(build, onexc=on_error) def print_serving(self): - host, port = self.socket.getsockname()[:2] + host, port = self.addr[:2] if ':' in host: host = f'[{host}]' self.cfg.stdout.write(self.cfg.ansi("Serving at <@{LBLUE}%s@{NORM}>\n") % f"http://{host}:{port}/") @@ -290,37 +328,68 @@ def check_upgrade(self): except Exception: if self.cfg.debug: raise + def __call__(self, env, respond): + script_name, path_info = env['SCRIPT_NAME'], env['PATH_INFO'] + name = wsgiutil.shift_path_info(env) + if (handler := self.apps.get(name)) is not None: + return handler(env, respond) + env['SCRIPT_NAME'], env['PATH_INFO'] = script_name, path_info + return self.handle_default(env, respond) + + def handle_default(self, env, respond): + env['wsgi.multithread'] = True + if (method := env['REQUEST_METHOD']) not in (HTTPMethod.HEAD, + HTTPMethod.GET): + return error(respond, HTTPStatus.NOT_IMPLEMENTED) + path = self.file_path(env['PATH_INFO']) + if (st := try_stat(path)) is None: + return error(respond, HTTPStatus.NOT_FOUND) + + if stat.S_ISDIR(st.st_mode): + parts = parse.urlsplit(env['PATH_INFO']) + if not parts.path.endswith('/'): + location = parse.urlunsplit( + (parts[:2] + (parts[2] + '/',) + parts[3:])) + respond(status_str(HTTPStatus.MOVED_PERMANENTLY), [ + ('Location', location), + ('Content-Length', '0'), + ]) + return [] + path = path / 'index.html' + if (st := try_stat(path)) is None: + return error(respond, HTTPStatus.NOT_FOUND) + + if not stat.S_ISREG(st.st_mode): + return error(respond, HTTPStatus.NOT_FOUND) + mime_type = mimetypes.guess_type(path)[0] + if not mime_type: mime_type = 'application/octet-stream' + respond(status_str(HTTPStatus.OK), [ + ('Content-Type', mime_type), + ('Content-Length', str(st.st_size)), + ]) + if method == HTTPMethod.HEAD: return [] + wrapper = env.get('wsgi.file_wrapper', wsgiutil.FileWrapper) + return wrapper(open(path, 'rb')) + + def file_path(self, path): + trailing = path.rstrip().endswith('/') + try: + path = parse.unquote(path, errors='surrogatepass') + except UnicodeDecodeError: + path = parse.unquote(path) + with self.lock: res = self.directory + for part in filter(None, posixpath.normpath(path).split('/')): + if pathlib.Path(part).parent.name or part in (os.curdir, os.pardir): + continue + res = res / part + return res / '' if trailing else res -class HandlerBase(server.SimpleHTTPRequestHandler): - def log_request(self, code='-', size='-'): - pass - - def log_message(self, format, *args): - self.server.cfg.stderr.write("%s - - [%s] %s\n" % ( - self.address_string(), self.log_date_time_string(), - (format % args).translate(self._control_char_table))) - - def do_GET(self): - if not self.dispatch_star_handler(True): - super().do_GET() - - def do_HEAD(self): - if not self.dispatch_star_handler(False): - super().do_HEAD() - - def dispatch_star_handler(self, write_content): - url = parse.urlparse(self.path) - if not url.path.startswith('/*'): return - if handler := getattr(self, f'handle_star_{url.path[2:]}', None): - content = handler(url, write_content) - if write_content and content: self.wfile.write(content) - else: - self.send_error(server.HTTPStatus.NOT_FOUND) - return True - - def handle_star_build(self, url, write_content): + def handle_build(self, env, respond): + if (method := env['REQUEST_METHOD']) not in (HTTPMethod.HEAD, + HTTPMethod.GET): + yield from error(respond, HTTPStatus.NOT_IMPLEMENTED) t = None - for k, v in parse.parse_qsl(url.query): + for k, v in parse.parse_qsl(env.get('QUERY_STRING', '')): if k == 't': t = v break @@ -329,20 +398,20 @@ def handle_star_build(self, url, write_content): # content length is needed upfront, we return a fixed size, and # terminate the request if the padding exceeds the available space. size = 600 - self.send_response(server.HTTPStatus.OK) - self.send_header('Content-type', 'text/plain') - self.send_header('Content-Length', str(size)) - self.end_headers() - if not write_content: return - with self.server.lock: - while ((mtime := self.server.build_mtime) is None + respond(status_str(HTTPStatus.OK), [ + ('Content-Type', 'text/plain;charset=utf-8'), + ('Content-Length', str(size)), + ]) + if method == HTTPMethod.HEAD: return + with self.lock: + while ((mtime := self.build_mtime) is None or t == build_tag(mtime)) and size > 0: - if self.server.lock.wait(timeout=1): continue - self.wfile.write(b' ') + if self.lock.wait(timeout=1): continue + yield b' ' size -= 1 tag = build_tag(mtime).encode('utf-8') if len(tag) > size: tag = b'' # Not enough remaining capacity - return b' ' * (size - len(tag)) + tag + yield b' ' * (size - len(tag)) + tag def build_tag(mtime):