Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ssl support #712

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ or ``main.py``) or to a specific file. The ``--app-factory`` option can be used
from the app path file, if not supplied some default method names are tried
(namely `app`, `app_factory`, `get_app` and `create_app`, which can be
variables, functions, or coroutines).
The ``--ssl-context-factory`` option can be used to define method from the app path file, which returns ssl.SSLContext
for ssl support.

All ``runserver`` arguments can be set via environment variables.

Expand Down
2 changes: 2 additions & 0 deletions aiohttp_devtools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo
'or just an instance of aiohttp.Application. env variable AIO_APP_FACTORY')
port_help = 'Port to serve app from, default 8000. env variable: AIO_PORT'
aux_port_help = 'Port to serve auxiliary app (reload and static) on, default port + 1. env variable: AIO_AUX_PORT'
ssl_context_factory_help = 'name of the ssl context factory to create ssl.SSLContext with'


# defaults are all None here so default settings are defined in one place: DEV_DICT validation
Expand All @@ -83,6 +84,7 @@ def serve(path: str, livereload: bool, bind_address: str, port: int, verbose: bo
@click.option('-v', '--verbose', is_flag=True, help=verbose_help)
@click.option("--browser-cache/--no-browser-cache", envvar="AIO_BROWSER_CACHE", default=None,
help=browser_cache_help)
@click.option('--ssl-context-factory', 'ssl_context_factory_name', default=None, help=ssl_context_factory_help)
@click.argument('project_args', nargs=-1)
def runserver(**config: Any) -> None:
"""
Expand Down
49 changes: 43 additions & 6 deletions aiohttp_devtools/runserver/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Awaitable, Callable, Optional, Union

from aiohttp import web
import ssl

import __main__
from ..exceptions import AiohttpDevConfigError as AdevConfigError
Expand Down Expand Up @@ -43,9 +44,10 @@
app_factory_name: Optional[str] = None,
host: str = INFER_HOST,
bind_address: str = "localhost",
main_port: int = 8000,
main_port: Optional[int] = None,
aux_port: Optional[int] = None,
browser_cache: bool = False):
browser_cache: bool = False,
ssl_context_factory_name: Optional[str] = None):
if root_path:
self.root_path = Path(root_path).resolve()
logger.debug('Root path specified: %s', self.root_path)
Expand Down Expand Up @@ -83,9 +85,13 @@
self.host = bind_address

self.bind_address = bind_address
if main_port is None:
main_port = 8000 if ssl_context_factory_name == None else 8443
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
self.protocol = 'http'
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved
self.main_port = main_port
self.aux_port = aux_port or (main_port + 1)
self.browser_cache = browser_cache
self.ssl_context_factory_name = ssl_context_factory_name
logger.debug('config loaded:\n%s', self)

@property
Expand Down Expand Up @@ -135,15 +141,20 @@
if not path.is_dir():
raise AdevConfigError('{} is not a directory'.format(path))
return path

def import_app_factory(self) -> AppFactory:
"""Import and return attribute/class from a python module.
def import_module(self):
"""Import and return python module.

Raises:
AdevConfigError - If the import failed.
"""
rel_py_file = self.py_file.relative_to(self.python_path)
module_path = '.'.join(rel_py_file.with_suffix('').parts)
sys.path.insert(0, str(self.python_path))
module = import_module(module_path)
# Rewrite the package name, so it will appear the same as running the app.
if module.__package__:
__main__.__package__ = module.__package__
Dreamsorcerer marked this conversation as resolved.
Show resolved Hide resolved

sys.path.insert(0, str(self.python_path))
module = import_module(module_path)
Expand All @@ -153,6 +164,16 @@

logger.debug('successfully loaded "%s" from "%s"', module_path, self.python_path)

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return module

def get_app_factory(self, module) -> AppFactory:
"""Import and return attribute/class from a python module.

Raises:
AdevConfigError - If the import failed.
"""

if self.app_factory_name is None:
try:
self.app_factory_name = next(an for an in APP_FACTORY_NAMES if hasattr(module, an))
Expand All @@ -179,8 +200,24 @@
raise AdevConfigError("'{}.{}' should not have required arguments.".format(
self.py_file.name, self.app_factory_name))

self.watch_path = self.watch_path or Path(module.__file__ or ".").parent
return attr # type: ignore[no-any-return]

def get_ssl_context(self, module) -> ssl.SSLContext:
if self.ssl_context_factory_name is None:
return None
else:
try:
attr = getattr(module, self.ssl_context_factory_name)
except AttributeError:
raise AdevConfigError("Module '{}' does not define a '{}' attribute/class".format(

Check warning on line 212 in aiohttp_devtools/runserver/config.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_devtools/runserver/config.py#L209-L212

Added lines #L209 - L212 were not covered by tests
self.py_file.name, self.ssl_context_factory_name))
ssl_context = attr()

Check warning on line 214 in aiohttp_devtools/runserver/config.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_devtools/runserver/config.py#L214

Added line #L214 was not covered by tests
if isinstance(ssl_context, ssl.SSLContext):
self.protocol = 'https'
return ssl_context

Check warning on line 217 in aiohttp_devtools/runserver/config.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_devtools/runserver/config.py#L216-L217

Added lines #L216 - L217 were not covered by tests
else:
raise AdevConfigError("ssl-context-factory '{}' in module '{}' didn't return valid SSLContext".format(

Check warning on line 219 in aiohttp_devtools/runserver/config.py

View check run for this annotation

Codecov / codecov/patch

aiohttp_devtools/runserver/config.py#L219

Added line #L219 was not covered by tests
self.ssl_context_factory_name, self.py_file.name))

async def load_app(self, app_factory: AppFactory) -> web.Application:
if isinstance(app_factory, web.Application):
Expand Down
10 changes: 6 additions & 4 deletions aiohttp_devtools/runserver/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ def runserver(**config_kwargs: Any) -> RunServer:
"""
# force a full reload in sub processes so they load an updated version of code, this must be called only once
set_start_method('spawn')

config = Config(**config_kwargs)
config.import_app_factory()
module = config.import_module()
ssl_context = config.get_ssl_context(module)
# config.get_app_factory(module)
# config.get_ssl_context_factory(module)

asyncio.run(check_port_open(config.main_port, host=config.bind_address))

Expand All @@ -49,15 +51,15 @@ def runserver(**config_kwargs: Any) -> RunServer:
logger.debug('starting livereload to watch %s', config.static_path_str)
aux_app.cleanup_ctx.append(static_manager.cleanup_ctx)

url = 'http://{0.host}:{0.aux_port}'.format(config)
url = '{0.protocol}://{0.host}:{0.aux_port}'.format(config)
logger.info('Starting aux server at %s ◆', url)

if config.static_path:
rel_path = config.static_path.relative_to(os.getcwd())
logger.info('serving static files from ./%s/ at %s%s', rel_path, url, config.static_url)

return {"app": aux_app, "host": config.bind_address, "port": config.aux_port,
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger}
"shutdown_timeout": 0.01, "access_log_class": AuxAccessLogger, "ssl_context": ssl_context}


def serve_static(*, static_path: str, livereload: bool = True, bind_address: str = "localhost", port: int = 8000,
Expand Down
18 changes: 11 additions & 7 deletions aiohttp_devtools/runserver/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from .log_handlers import AccessLogger
from .utils import MutableValue

import ssl

try:
from aiohttp_jinja2 import static_root_key
except ImportError:
Expand Down Expand Up @@ -103,7 +105,7 @@ async def no_cache_middleware(request: web.Request, handler: Handler) -> web.Str
# we set the app key even in middleware to make the switch to production easier and for backwards compat.
@web.middleware
async def static_middleware(request: web.Request, handler: Handler) -> web.StreamResponse:
static_url = 'http://{}:{}/{}'.format(get_host(request), config.aux_port, static_path)
static_url = '{}://{}:{}/{}'.format(config.protocol, get_host(request), config.aux_port, static_path)
dft_logger.debug('setting app static_root_url to "%s"', static_url)
_change_static_url(request.app, static_url)
return await handler(request)
Expand All @@ -120,10 +122,10 @@ def shutdown() -> NoReturn:

path = config.path_prefix + "/shutdown"
app.router.add_route("GET", path, do_shutdown, name="_devtools.shutdown")
dft_logger.debug("Created shutdown endpoint at http://{}:{}{}".format(config.host, config.main_port, path))
dft_logger.debug("Created shutdown endpoint at {}://{}:{}{}".format(config.protocol, config.host, config.main_port, path))

if config.static_path is not None:
static_url = 'http://{}:{}/{}'.format(config.host, config.aux_port, static_path)
static_url = '{}://{}:{}/{}'.format(config.protocol, config.host, config.aux_port, static_path)
dft_logger.debug('settings app static_root_url to "%s"', static_url)
_set_static_url(app, static_url)

Expand Down Expand Up @@ -164,7 +166,9 @@ def set_tty(tty_path: Optional[str]) -> Iterator[None]:
def serve_main_app(config: Config, tty_path: Optional[str]) -> None:
with set_tty(tty_path):
setup_logging(config.verbose)
app_factory = config.import_app_factory()
module = config.import_module()
app_factory = config.get_app_factory(module)
ssl_context = config.get_ssl_context(module)
if sys.version_info >= (3, 11):
with asyncio.Runner() as runner:
app_runner = runner.run(create_main_app(config, app_factory))
Expand All @@ -180,7 +184,7 @@ def serve_main_app(config: Config, tty_path: Optional[str]) -> None:
loop = asyncio.new_event_loop()
runner = loop.run_until_complete(create_main_app(config, app_factory))
try:
loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port))
loop.run_until_complete(start_main_app(runner, config.bind_address, config.main_port, ssl_context))
loop.run_forever()
except KeyboardInterrupt: # pragma: no cover
pass
Expand All @@ -197,9 +201,9 @@ async def create_main_app(config: Config, app_factory: AppFactory) -> web.AppRun
return web.AppRunner(app, access_log_class=AccessLogger, shutdown_timeout=0.1)


async def start_main_app(runner: web.AppRunner, host: str, port: int) -> None:
async def start_main_app(runner: web.AppRunner, host: str, port: int, ssl_context: ssl.SSLContext) -> None:
await runner.setup()
site = web.TCPSite(runner, host=host, port=port)
site = web.TCPSite(runner, host=host, port=port, ssl_context=ssl_context)
await site.start()


Expand Down
4 changes: 2 additions & 2 deletions aiohttp_devtools/runserver/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def _src_reload_when_live(self, checks: int) -> None:
assert self._app is not None and self._session is not None

if self._app[WS]:
url = "http://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config)
url = "{0.protocol}://{0.host}:{0.main_port}/?_checking_alive=1".format(self._config)
logger.debug('checking app at "%s" is running before prompting reload...', url)
for i in range(checks):
await asyncio.sleep(0.1)
Expand All @@ -123,7 +123,7 @@ async def _src_reload_when_live(self, checks: int) -> None:

def _start_dev_server(self) -> None:
act = 'Start' if self._reloads == 0 else 'Restart'
logger.info('%sing dev server at http://%s:%s ●', act, self._config.host, self._config.main_port)
logger.info('%sing dev server at %s://%s:%s ●', act, self._config.protocol, self._config.host, self._config.main_port)

try:
tty_path = os.ttyname(sys.stdin.fileno())
Expand Down
12 changes: 8 additions & 4 deletions tests/test_runserver_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ async def test_create_app_wrong_name(tmpworkdir):
mktree(tmpworkdir, SIMPLE_APP)
config = Config(app_path='app.py', app_factory_name='missing')
with pytest.raises(AiohttpDevConfigError) as excinfo:
config.import_app_factory()
module = config.import_module
config.get_app_factory(module)
assert excinfo.value.args[0] == "Module 'app.py' does not define a 'missing' attribute/class"


Expand All @@ -56,7 +57,8 @@ async def app_factory():
"""
})
config = Config(app_path='app.py')
app = await config.load_app(config.import_app_factory())
module = config.import_module()
app = await config.load_app(config.get_app_factory(module))
assert isinstance(app, web.Application)


Expand All @@ -69,9 +71,10 @@ def app_factory():
"""
})
config = Config(app_path='app.py')
module = config.import_module()
with pytest.raises(AiohttpDevConfigError,
match=r"'app_factory' returned 'int' not an aiohttp\.web\.Application"):
await config.load_app(config.import_app_factory())
await config.load_app(config.get_app_factory(module))


@forked
Expand All @@ -83,6 +86,7 @@ def app_factory(foo):
"""
})
config = Config(app_path='app.py')
module = config.import_module()
with pytest.raises(AiohttpDevConfigError,
match=r"'app\.py\.app_factory' should not have required arguments"):
await config.load_app(config.import_app_factory())
await config.load_app(config.get_app_factory(module))
19 changes: 12 additions & 7 deletions tests/test_runserver_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ async def create_app():

set_start_method("spawn")
config = Config(app_path="app.py", root_path=tmpworkdir, main_port=0, app_factory_name="create_app")
config.import_app_factory()
module = config.import_module()
config.get_app_factory(module)
app_task = AppTask(config)

app_task._start_dev_server()
Expand All @@ -162,7 +163,8 @@ async def create_app():
async def test_run_app_aiohttp_client(tmpworkdir, aiohttp_client):
mktree(tmpworkdir, SIMPLE_APP)
config = Config(app_path='app.py')
app_factory = config.import_app_factory()
module = config.import_module()
app_factory = config.get_app_factory(module)
app = await config.load_app(app_factory)
modify_main_app(app, config)
assert isinstance(app, aiohttp.web.Application)
Expand All @@ -178,7 +180,8 @@ async def test_run_app_aiohttp_client(tmpworkdir, aiohttp_client):
async def test_run_app_browser_cache(tmpworkdir, aiohttp_client):
mktree(tmpworkdir, SIMPLE_APP)
config = Config(app_path="app.py", browser_cache=True)
app_factory = config.import_app_factory()
module = config.import_module()
app_factory = config.get_app_factory(module)
app = await config.load_app(app_factory)
modify_main_app(app, config)
cli = await aiohttp_client(app)
Expand Down Expand Up @@ -208,8 +211,9 @@ async def test_serve_main_app(tmpworkdir, mocker):
loop.call_later(0.5, loop.stop)

config = Config(app_path="app.py", main_port=0)
runner = await create_main_app(config, config.import_app_factory())
await start_main_app(runner, config.bind_address, config.main_port)
module = config.import_module()
runner = await create_main_app(config, config.get_app_factory(module))
await start_main_app(runner, config.bind_address, config.main_port, None)

mock_modify_main_app.assert_called_with(mock.ANY, config)

Expand All @@ -232,8 +236,9 @@ async def hello(request):
mock_modify_main_app = mocker.patch('aiohttp_devtools.runserver.serve.modify_main_app')

config = Config(app_path="app.py", main_port=0)
runner = await create_main_app(config, config.import_app_factory())
await start_main_app(runner, config.bind_address, config.main_port)
module = config.import_module()
runner = await create_main_app(config, config.get_app_factory(module))
await start_main_app(runner, config.bind_address, config.main_port, None)

mock_modify_main_app.assert_called_with(mock.ANY, config)

Expand Down
2 changes: 2 additions & 0 deletions tests/test_runserver_watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ async def test_python_no_server(mocker):

config = MagicMock()
config.main_port = 8000
config.protocol = 'http'
app_task = AppTask(config)
start_mock = mocker.patch.object(app_task, "_start_dev_server", autospec=True)
stop_mock = mocker.patch.object(app_task, "_stop_dev_server", autospec=True)
Expand Down Expand Up @@ -109,6 +110,7 @@ async def test_reload_server_running(aiohttp_client, mocker):
config = MagicMock()
config.host = "localhost"
config.main_port = cli.server.port
config.protocol = 'http'

app_task = AppTask(config)
app_task._app = app
Expand Down
Loading