Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ran black on the b/e repo #452

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,16 @@ repos:
# - types-pytz
# - types-requests

# Black (uncompromising) Python code formatter
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:
- id: black
args:
- --skip-string-normalization
- --target-version
- py311

# Pylint
# To check import errors we need to install every package
# used by the DM. This is often impractical on the client,
Expand Down
88 changes: 53 additions & 35 deletions api/remote_ispyb_connector.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,72 @@
import threading
import mysql.connector
from ispyb.connector.mysqlsp.main import ISPyBMySQLSPConnector as Connector
from ispyb.exception import (ISPyBConnectionException, ISPyBNoResultException,
ISPyBRetrieveFailed, ISPyBWriteFailed)
from ispyb.exception import (
ISPyBConnectionException,
ISPyBNoResultException,
ISPyBRetrieveFailed,
ISPyBWriteFailed,
)
import sshtunnel
import time
import pymysql


class SSHConnector(Connector):
def __init__(self,
user=None,
pw=None,
host="localhost",
db=None,
port=3306,
reconn_attempts=6,
reconn_delay=1,
remote=False,
ssh_user=None,
ssh_password=None,
ssh_host=None,
conn_inactivity=360,
):
def __init__(
self,
user=None,
pw=None,
host="localhost",
db=None,
port=3306,
reconn_attempts=6,
reconn_delay=1,
remote=False,
ssh_user=None,
ssh_password=None,
ssh_host=None,
conn_inactivity=360,
):
self.conn_inactivity = conn_inactivity
self.lock = threading.Lock()
self.server = None

if remote:
creds = {'ssh_host': ssh_host,
'ssh_user': ssh_user,
'ssh_pass': ssh_password,
'db_host': host,
'db_port': int(port),
'db_user': user,
'db_pass': pw,
'db_name': db}
creds = {
'ssh_host': ssh_host,
'ssh_user': ssh_user,
'ssh_pass': ssh_password,
'db_host': host,
'db_port': int(port),
'db_user': user,
'db_pass': pw,
'db_name': db,
}
self.remote_connect(**creds)

else:
self.connect(user=user, pw=pw, host=host, db=db, port=port, conn_inactivity=conn_inactivity)

def remote_connect(self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user, db_pass, db_name):

self.connect(
user=user,
pw=pw,
host=host,
db=db,
port=port,
conn_inactivity=conn_inactivity,
)

def remote_connect(
self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user, db_pass, db_name
):
sshtunnel.SSH_TIMEOUT = 10.0
sshtunnel.TUNNEL_TIMEOUT = 10.0
self.conn_inactivity = int(self.conn_inactivity)

self.server = sshtunnel.SSHTunnelForwarder(
(ssh_host),
ssh_username=ssh_user, ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port)
ssh_username=ssh_user,
ssh_password=ssh_pass,
remote_bind_address=(db_host, db_port),
)

# stops hanging connections in transport
Expand All @@ -60,8 +76,10 @@ def remote_connect(self, ssh_host, ssh_user, ssh_pass, db_host, db_port, db_user
self.server.start()

self.conn = pymysql.connect(
user=db_user, password=db_pass,
host='127.0.0.1', port=self.server.local_bind_port,
user=db_user,
password=db_pass,
host='127.0.0.1',
port=self.server.local_bind_port,
database=db_name,
)

Expand Down Expand Up @@ -91,13 +109,13 @@ def call_sp_retrieve(self, procname, args):
try:
cursor.callproc(procname=procname, args=args)
except DataError as e:
raise ISPyBRetrieveFailed("DataError({0}): {1}".format(e.errno, traceback.format_exc()))
raise ISPyBRetrieveFailed(
"DataError({0}): {1}".format(e.errno, traceback.format_exc())
)

result = cursor.fetchall()

cursor.close()
if result == []:
raise ISPyBNoResultException
return result


68 changes: 42 additions & 26 deletions api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@


def get_remote_conn():

ispyb_credentials = {
"user": os.environ.get("ISPYB_USER"),
"pw": os.environ.get("ISPYB_PASSWORD"),
Expand All @@ -52,7 +51,7 @@ def get_remote_conn():
'ssh_host': os.environ.get("SSH_HOST"),
'ssh_user': os.environ.get("SSH_USER"),
'ssh_password': os.environ.get("SSH_PASSWORD"),
'remote': True
'remote': True,
}

ispyb_credentials.update(**ssh_credentials)
Expand Down Expand Up @@ -102,7 +101,6 @@ def get_conn():


class ISpyBSafeQuerySet(viewsets.ReadOnlyModelViewSet):

def get_queryset(self):
"""
Optionally restricts the returned purchases to a given proposals
Expand All @@ -115,8 +113,11 @@ def get_queryset(self):
if open_proposal not in proposal_list:
proposal_list.append(open_proposal)

logger.debug('is_authenticated=%s, proposal_list=%s',
self.request.user.is_authenticated, proposal_list)
logger.debug(
'is_authenticated=%s, proposal_list=%s',
self.request.user.is_authenticated,
proposal_list,
)

# Must have a foreign key to a Project for this filter to work.
# get_q_filter() returns a Q expression for filtering
Expand All @@ -130,7 +131,7 @@ def get_open_proposals(self):
if os.environ.get("TEST_SECURITY_FLAG", False):
return ["lb00000"]
else:
# All of well-known (built-in) public Projects (Proposals/Visits)
# All of well-known (built-in) public Projects (Proposals/Visits)
return ["lb27156"]

def get_proposals_for_user_from_django(self, user):
Expand All @@ -142,8 +143,12 @@ def get_proposals_for_user_from_django(self, user):
prop_ids = list(
Project.objects.filter(user_id=user.pk).values_list("title", flat=True)
)
logger.debug("Got %s proposals for user %s: %s",
len(prop_ids), user.username, prop_ids)
logger.debug(
"Got %s proposals for user %s: %s",
len(prop_ids),
user.username,
prop_ids,
)
return prop_ids

def needs_updating(self, user):
Expand All @@ -163,7 +168,6 @@ def needs_updating(self, user):
return False

def run_query_with_connector(self, conn, user):

core = conn.core
try:
rs = core.retrieve_sessions_for_person_login(user.username)
Expand Down Expand Up @@ -242,22 +246,30 @@ def get_proposals_for_user_from_ispyb(self, user):

# Always display the collected results for the user.
# These will be cached.
logger.info("Got %s proposals from %s records for user %s: %s",
len(prop_id_set), len(rs), user.username, prop_id_set)
logger.info(
"Got %s proposals from %s records for user %s: %s",
len(prop_id_set),
len(rs),
user.username,
prop_id_set,
)

# Cache the result and return the result for the user
USER_LIST_DICT[user.username]["RESULTS"] = list(prop_id_set)
return USER_LIST_DICT[user.username]["RESULTS"]
else:
# Return the previous query (cached for an hour)
cached_prop_ids = USER_LIST_DICT[user.username]["RESULTS"]
logger.info("Got %s cached proposals for user %s: %s",
len(cached_prop_ids), user.username, cached_prop_ids)
logger.info(
"Got %s cached proposals for user %s: %s",
len(cached_prop_ids),
user.username,
cached_prop_ids,
)
return cached_prop_ids

def get_proposals_for_user(self, user):
"""Returns a list of proposals (public and private) that the user has access to.
"""
"""Returns a list of proposals (public and private) that the user has access to."""
assert user

ispyb_user = os.environ.get("ISPYB_USER")
Expand All @@ -267,21 +279,23 @@ def get_proposals_for_user(self, user):
logger.info("Getting proposals from ISPyB...")
return self.get_proposals_for_user_from_ispyb(user)
else:
logger.info("No proposals (user %s is not authenticated)", user.username)
logger.info(
"No proposals (user %s is not authenticated)", user.username
)
return []
else:
logger.info("Getting proposals from Django...")
return self.get_proposals_for_user_from_django(user)

def get_q_filter(self, proposal_list):
"""Returns a Q expression representing a (potentially complex) table filter.
"""
"""Returns a Q expression representing a (potentially complex) table filter."""
if self.filter_permissions:
# Q-filter is based on the filter_permissions string
# whether the resultant Project title in the proposal list
# whether the resultant Project title in the proposal list
# OR where the Project is 'open_to_public'
return Q(**{self.filter_permissions + "__title__in": proposal_list}) |\
Q(**{self.filter_permissions + "__open_to_public": True})
return Q(**{self.filter_permissions + "__title__in": proposal_list}) | Q(
**{self.filter_permissions + "__open_to_public": True}
)
else:
# No filter permission?
# Assume this QuerySet is used for the Project model.
Expand All @@ -293,7 +307,6 @@ def get_q_filter(self, proposal_list):


class ISpyBSafeStaticFiles:

def get_queryset(self):
query = ISpyBSafeQuerySet()
query.request = self.request
Expand All @@ -318,12 +331,16 @@ def get_response(self):
logger.info("Path to pass to nginx: %s", self.prefix + file_name)

if hasattr(self, 'file_format'):
if self.file_format=='raw':
if self.file_format == 'raw':
file_field = getattr(object, self.field_name)
filepath = file_field.path
zip_file = open(filepath, 'rb')
response = HttpResponse(FileWrapper(zip_file), content_type='application/zip')
response['Content-Disposition'] = 'attachment; filename="%s"' % file_name
response = HttpResponse(
FileWrapper(zip_file), content_type='application/zip'
)
response['Content-Disposition'] = (
'attachment; filename="%s"' % file_name
)

else:
response = HttpResponse()
Expand All @@ -338,7 +355,6 @@ def get_response(self):


class ISpyBSafeStaticFiles2(ISpyBSafeStaticFiles):

def get_response(self):
logger.info("+ get_response called with: %s", self.input_string)
# it wasn't working because found two objects with test file name
Expand Down
Loading