Skip to content

Commit

Permalink
Adds support for key-based SSH connections (#534)
Browse files Browse the repository at this point in the history
* Centralised environment variables (#529)

* refactor: Restructured settings.py

* docs: Minor tweaks

* refactor: Move security and infection config to settings

* refactor: b/e & f/e/ tags now in settings (also fixed f/e tag value)

* refactor: Move Neo4j config to settings

* refactor: More variables into settings

* refactor: Moved remaining config

* docs: Adds configuration guide as comments

* docs: Variable prefix now 'stack_' not 'stack_env_'

---------

Co-authored-by: Alan Christie <[email protected]>

* feat: Adds support for private keys on SSH tunnel

* fix: Fixes key-based logic

---------

Co-authored-by: Alan Christie <[email protected]>
  • Loading branch information
alanbchristie and Alan Christie authored Feb 19, 2024
1 parent b5639af commit f3483bb
Show file tree
Hide file tree
Showing 10 changed files with 463 additions and 391 deletions.
12 changes: 6 additions & 6 deletions api/infections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# Infections are injected into the application via the environment variable
# 'INFECTIONS', a comma-separated list of infection names.

import os
from typing import Dict, Set

from django.conf import settings

from api.utils import deployment_mode_is_production

# The built-in set of infections.
Expand All @@ -20,9 +21,6 @@
INFECTION_STRUCTURE_DOWNLOAD: 'An error in the DownloadStructures view'
}

# What infection have been set?
_INFECTIONS: str = os.environ.get('INFECTIONS', '').lower()


def have_infection(name: str) -> bool:
"""Returns True if we've been given the named infection.
Expand All @@ -31,9 +29,11 @@ def have_infection(name: str) -> bool:


def _get_infections() -> Set[str]:
if _INFECTIONS == '':
if settings.INFECTIONS == '':
return set()
infections: set[str] = {
infection for infection in _INFECTIONS.split(',') if infection in _CATALOGUE
infection
for infection in settings.INFECTIONS.split(',')
if infection in _CATALOGUE
}
return infections
54 changes: 44 additions & 10 deletions api/remote_ispyb_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
remote=False,
ssh_user=None,
ssh_password=None,
ssh_private_key_filename=None,
ssh_host=None,
conn_inactivity=360,
):
Expand All @@ -45,6 +46,7 @@ def __init__(
'ssh_host': ssh_host,
'ssh_user': ssh_user,
'ssh_pass': ssh_password,
'ssh_pkey': ssh_private_key_filename,
'db_host': host,
'db_port': int(port),
'db_user': user,
Expand All @@ -53,12 +55,11 @@ def __init__(
}
self.remote_connect(**creds)
logger.debug(
"Started host=%s username=%s local_bind_port=%s",
"Started remote ssh_host=%s ssh_user=%s local_bind_port=%s",
ssh_host,
ssh_user,
self.server.local_bind_port,
)

else:
self.connect(
user=user,
Expand All @@ -68,29 +69,60 @@ def __init__(
port=port,
conn_inactivity=conn_inactivity,
)
logger.debug("Started host=%s user=%s port=%s", host, user, port)
logger.debug("Started direct host=%s user=%s port=%s", host, user, port)

def remote_connect(
self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user, db_pass, db_name
self,
ssh_host,
ssh_user,
ssh_pass,
ssh_pkey,
db_host,
db_port,
db_user,
db_pass,
db_name,
):
sshtunnel.SSH_TIMEOUT = 10.0
sshtunnel.TUNNEL_TIMEOUT = 10.0
sshtunnel.DEFAULT_LOGLEVEL = logging.CRITICAL
self.conn_inactivity = int(self.conn_inactivity)

self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port),
)
if ssh_pkey:
logger.debug(
'Creating SSHTunnelForwarder (with SSH Key) host=%s user=%s',
ssh_host,
ssh_user,
)
self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user,
ssh_pkey=ssh_pkey,
remote_bind_address=(db_host, db_port),
)
else:
logger.debug(
'Creating SSHTunnelForwarder (with password) host=%s user=%s',
ssh_host,
ssh_user,
)
self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port),
)
logger.debug('Created SSHTunnelForwarder')

# stops hanging connections in transport
self.server.daemon_forward_servers = True
self.server.daemon_transport = True

logger.debug('Starting SSH server...')
self.server.start()
logger.debug('Started SSH server')

logger.debug('Connecting to ISPyB (db_user=%s db_name=%s)...', db_user, db_name)
self.conn = pymysql.connect(
user=db_user,
password=db_pass,
Expand All @@ -100,8 +132,10 @@ def remote_connect(
)

if self.conn is not None:
logger.debug('Connected')
self.conn.autocommit = True
else:
logger.debug('Failed to connect')
self.server.stop()
raise ISPyBConnectionException
self.last_activity_ts = time.time()
Expand Down
39 changes: 20 additions & 19 deletions api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,51 +48,52 @@


def get_remote_conn() -> Optional[SSHConnector]:
ispyb_credentials: Dict[str, Any] = {
"user": os.environ.get("ISPYB_USER"),
"pw": os.environ.get("ISPYB_PASSWORD"),
"host": os.environ.get("ISPYB_HOST"),
"port": os.environ.get("ISPYB_PORT"),
credentials: Dict[str, Any] = {
"user": settings.ISPYB_USER,
"pw": settings.ISPYB_PASSWORD,
"host": settings.ISPYB_HOST,
"port": settings.ISPYB_PORT,
"db": "ispyb",
"conn_inactivity": 360,
}

ssh_credentials: Dict[str, Any] = {
'ssh_host': os.environ.get("SSH_HOST"),
'ssh_user': os.environ.get("SSH_USER"),
'ssh_password': os.environ.get("SSH_PASSWORD"),
'ssh_host': settings.SSH_HOST,
'ssh_user': settings.SSH_USER,
'ssh_password': settings.SSH_PASSWORD,
"ssh_private_key_filename": settings.SSH_PRIVATE_KEY_FILENAME,
'remote': True,
}

ispyb_credentials.update(**ssh_credentials)
credentials.update(**ssh_credentials)

# Caution: Credentials may not be set in the environment.
# Assume the credentials are invalid if there is no host.
# If a host is not defined other properties are useless.
if not ispyb_credentials["host"]:
if not credentials["host"]:
logger.debug("No ISPyB host - cannot return a connector")
return None

# Try to get an SSH connection (aware that it might fail)
conn: Optional[SSHConnector] = None
try:
conn = SSHConnector(**ispyb_credentials)
conn = SSHConnector(**credentials)
except Exception:
# Log the exception if DEBUG level or lower/finer?
# The following wil not log if the level is set to INFO for example.
# The following will not log if the level is set to INFO for example.
if logging.DEBUG >= logger.level:
logger.info("ispyb_credentials=%s", ispyb_credentials)
logger.info("credentials=%s", credentials)
logger.exception("Got the following exception creating SSHConnector...")

return conn


def get_conn() -> Optional[Connector]:
credentials: Dict[str, Any] = {
"user": os.environ.get("ISPYB_USER"),
"pw": os.environ.get("ISPYB_PASSWORD"),
"host": os.environ.get("ISPYB_HOST"),
"port": os.environ.get("ISPYB_PORT"),
"user": settings.ISPYB_USER,
"pw": settings.ISPYB_PASSWORD,
"host": settings.ISPYB_HOST,
"port": settings.ISPYB_PORT,
"db": "ispyb",
"conn_inactivity": 360,
}
Expand All @@ -108,7 +109,7 @@ def get_conn() -> Optional[Connector]:
conn = Connector(**credentials)
except Exception:
# Log the exception if DEBUG level or lower/finer?
# The following wil not log if the level is set to INFO for example.
# The following will not log if the level is set to INFO for example.
if logging.DEBUG >= logger.level:
logger.info("credentials=%s", credentials)
logger.exception("Got the following exception creating Connector...")
Expand Down Expand Up @@ -349,7 +350,7 @@ def get_proposals_for_user(self, user, restrict_to_membership=False):
assert user

proposals = set()
ispyb_user = os.environ.get("ISPYB_USER")
ispyb_user = settings.ISPYB_USER
logger.debug(
"ispyb_user=%s restrict_to_membership=%s",
ispyb_user,
Expand Down
Loading

0 comments on commit f3483bb

Please sign in to comment.