Skip to content

Commit

Permalink
Convert the dev server to WSGI.
Browse files Browse the repository at this point in the history
- serve: Add the --restart-on-change flag.
  • Loading branch information
rblank committed Nov 28, 2024
1 parent 054e824 commit 0cf6c42
Showing 1 changed file with 135 additions and 66 deletions.
201 changes: 135 additions & 66 deletions tdoc/common/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,25 @@

import argparse
import contextlib
from http import server
from http import HTTPMethod, HTTPStatus
import itertools
from importlib import metadata
import json
import mimetypes
import os
import pathlib
import posixpath
import re
import shutil
import socket
import socketserver
import stat
import subprocess
import sys
import threading
import time
from urllib import parse
from wsgiref import simple_server, util as wsgiutil

from . import __project__, __version__, util

Expand Down Expand Up @@ -85,8 +89,8 @@ def main(argv, stdin, stdout, stderr):
"(default: %(default)s).")
arg('--port', metavar='PORT', dest='port', default=8000, type=int,
help="The port to bind the server to (default: %(default)s).")
arg('--protocol', metavar='VERSION', dest='protocol', default='HTTP/1.0',
help="The HTTP protocol version to conform to (default: %(default)s).")
arg('--restart-on-change', action='store_true', dest='restart_on_change',
help="Restart the server on changes.")
arg('--watch', metavar='PATH', action='append', dest='watch', default=[],
help="Additional directories to watch for changes.")

Expand Down Expand Up @@ -126,15 +130,21 @@ def cmd_serve(cfg):

class Server(ServerBase):
address_family = family
class Handler(HandlerBase):
protocol = cfg.protocol

with Server(addr, Handler, cfg=cfg) as srv:
with Application(cfg, addr) as app, Server(addr, RequestHandler) as srv:
app.server = srv
srv.set_app(app)
try:
srv.serve_forever()
except KeyboardInterrupt:
cfg.restart_on_change = False
cfg.stderr.write("Interrupted, exiting\n")

if cfg.restart_on_change:
cfg.stdout.flush()
cfg.stderr.flush()
os.execv(sys.argv[0], sys.argv)


def cmd_version(cfg):
cfg.stdout.write(f"{__project__}-{__version__}\n")
Expand All @@ -153,9 +163,47 @@ def sphinx_build(cfg, target, *, build, tags=(), **kwargs):
stderr=cfg.stderr, **kwargs)


class ServerBase(server.ThreadingHTTPServer):
def __init__(self, *args, cfg, **kwargs):
class ServerBase(socketserver.ThreadingMixIn, simple_server.WSGIServer):
daemon_threads = True

def server_bind(self):
with contextlib.suppress(Exception):
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
return super().server_bind()


class RequestHandler(simple_server.WSGIRequestHandler):
def log_request(self, code='-', size='-'):
pass

def log_message(self, format, *args):
self.server.application.cfg.stderr.write("%s - - [%s] %s\n" % (
self.address_string(), self.log_date_time_string(),
(format % args).translate(self._control_char_table)))


def status_str(status):
return f'{status} {status.phrase}'


def error(respond, status):
respond(status_str(status), [
('Content-Type', 'text/plain;charset=utf-8'),
])
return [status.description.encode('utf-8')]


def try_stat(path):
try:
return path.stat()
except OSError:
return None


class Application:
def __init__(self, cfg, addr):
self.cfg = cfg
self.addr = addr
self.lock = threading.Condition(threading.Lock())
self.directory = self.build_dir(0) / 'html'
self.upgrade_msg = None
Expand All @@ -167,28 +215,13 @@ def __init__(self, *args, cfg, **kwargs):
self.builder.start()
self.checker = threading.Thread(target=self.check_upgrade, daemon=True)
self.checker.start()
super().__init__(*args, **kwargs)

def server_bind(self):
with contextlib.suppress(Exception):
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
return super().server_bind()

def finish_request(self, request, client_addr):
with self.lock: directory = self.directory
self.RequestHandlerClass(request, client_addr, self,
directory=directory)

IGNORED_EXCEPTIONS = (BrokenPipeError, ConnectionAbortedError)
self.apps = {'*build': self.handle_build}

def handle_error(self, request, client_addr):
if not isinstance(sys.exception(), self.IGNORED_EXCEPTIONS):
super().handle_error(request, client_addr)
def __enter__(self): return self

def server_close(self):
def __exit__(self, typ, value, tb):
with self.lock: self.stop = True
self.builder.join()
return super().server_close()

def watch_and_build(self):
interval = self.cfg.interval * 1_000_000_000
Expand All @@ -208,6 +241,11 @@ def watch_and_build(self):
prev = mtime + delay - interval
continue
if prev_mtime != 0:
if self.cfg.restart_on_change:
self.cfg.stdout.write(
"\nSource change detected, restarting\n")
self.server.shutdown()
break
self.cfg.stdout.write(
"\nSource change detected, rebuilding\n")
prev_mtime = mtime
Expand Down Expand Up @@ -266,7 +304,7 @@ def on_error(fn, path, e):
shutil.rmtree(build, onexc=on_error)

def print_serving(self):
host, port = self.socket.getsockname()[:2]
host, port = self.addr[:2]
if ':' in host: host = f'[{host}]'
self.cfg.stdout.write(self.cfg.ansi("Serving at <@{LBLUE}%s@{NORM}>\n")
% f"http://{host}:{port}/")
Expand All @@ -290,37 +328,68 @@ def check_upgrade(self):
except Exception:
if self.cfg.debug: raise

def __call__(self, env, respond):
script_name, path_info = env['SCRIPT_NAME'], env['PATH_INFO']
name = wsgiutil.shift_path_info(env)
if (handler := self.apps.get(name)) is not None:
return handler(env, respond)
env['SCRIPT_NAME'], env['PATH_INFO'] = script_name, path_info
return self.handle_default(env, respond)

def handle_default(self, env, respond):
env['wsgi.multithread'] = True
if (method := env['REQUEST_METHOD']) not in (HTTPMethod.HEAD,
HTTPMethod.GET):
return error(respond, HTTPStatus.NOT_IMPLEMENTED)
path = self.file_path(env['PATH_INFO'])
if (st := try_stat(path)) is None:
return error(respond, HTTPStatus.NOT_FOUND)

if stat.S_ISDIR(st.st_mode):
parts = parse.urlsplit(env['PATH_INFO'])
if not parts.path.endswith('/'):
location = parse.urlunsplit(
(parts[:2] + (parts[2] + '/',) + parts[3:]))
respond(status_str(HTTPStatus.MOVED_PERMANENTLY), [
('Location', location),
('Content-Length', '0'),
])
return []
path = path / 'index.html'
if (st := try_stat(path)) is None:
return error(respond, HTTPStatus.NOT_FOUND)

if not stat.S_ISREG(st.st_mode):
return error(respond, HTTPStatus.NOT_FOUND)
mime_type = mimetypes.guess_type(path)[0]
if not mime_type: mime_type = 'application/octet-stream'
respond(status_str(HTTPStatus.OK), [
('Content-Type', mime_type),
('Content-Length', str(st.st_size)),
])
if method == HTTPMethod.HEAD: return []
wrapper = env.get('wsgi.file_wrapper', wsgiutil.FileWrapper)
return wrapper(open(path, 'rb'))

def file_path(self, path):
trailing = path.rstrip().endswith('/')
try:
path = parse.unquote(path, errors='surrogatepass')
except UnicodeDecodeError:
path = parse.unquote(path)
with self.lock: res = self.directory
for part in filter(None, posixpath.normpath(path).split('/')):
if pathlib.Path(part).parent.name or part in (os.curdir, os.pardir):
continue
res = res / part
return res / '' if trailing else res

class HandlerBase(server.SimpleHTTPRequestHandler):
def log_request(self, code='-', size='-'):
pass

def log_message(self, format, *args):
self.server.cfg.stderr.write("%s - - [%s] %s\n" % (
self.address_string(), self.log_date_time_string(),
(format % args).translate(self._control_char_table)))

def do_GET(self):
if not self.dispatch_star_handler(True):
super().do_GET()

def do_HEAD(self):
if not self.dispatch_star_handler(False):
super().do_HEAD()

def dispatch_star_handler(self, write_content):
url = parse.urlparse(self.path)
if not url.path.startswith('/*'): return
if handler := getattr(self, f'handle_star_{url.path[2:]}', None):
content = handler(url, write_content)
if write_content and content: self.wfile.write(content)
else:
self.send_error(server.HTTPStatus.NOT_FOUND)
return True

def handle_star_build(self, url, write_content):
def handle_build(self, env, respond):
if (method := env['REQUEST_METHOD']) not in (HTTPMethod.HEAD,
HTTPMethod.GET):
yield from error(respond, HTTPStatus.NOT_IMPLEMENTED)
t = None
for k, v in parse.parse_qsl(url.query):
for k, v in parse.parse_qsl(env.get('QUERY_STRING', '')):
if k == 't':
t = v
break
Expand All @@ -329,20 +398,20 @@ def handle_star_build(self, url, write_content):
# content length is needed upfront, we return a fixed size, and
# terminate the request if the padding exceeds the available space.
size = 600
self.send_response(server.HTTPStatus.OK)
self.send_header('Content-type', 'text/plain')
self.send_header('Content-Length', str(size))
self.end_headers()
if not write_content: return
with self.server.lock:
while ((mtime := self.server.build_mtime) is None
respond(status_str(HTTPStatus.OK), [
('Content-Type', 'text/plain;charset=utf-8'),
('Content-Length', str(size)),
])
if method == HTTPMethod.HEAD: return
with self.lock:
while ((mtime := self.build_mtime) is None
or t == build_tag(mtime)) and size > 0:
if self.server.lock.wait(timeout=1): continue
self.wfile.write(b' ')
if self.lock.wait(timeout=1): continue
yield b' '
size -= 1
tag = build_tag(mtime).encode('utf-8')
if len(tag) > size: tag = b'' # Not enough remaining capacity
return b' ' * (size - len(tag)) + tag
yield b' ' * (size - len(tag)) + tag


def build_tag(mtime):
Expand Down

0 comments on commit 0cf6c42

Please sign in to comment.