Skip to content

Commit

Permalink
all: Replace strings with RequestType flags
Browse files Browse the repository at this point in the history
Signed-off-by:  Eric Callahan <[email protected]>
  • Loading branch information
Arksine committed Nov 20, 2023
1 parent 612a5d8 commit f81e340
Show file tree
Hide file tree
Showing 23 changed files with 313 additions and 229 deletions.
17 changes: 9 additions & 8 deletions moonraker/components/announcements.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import logging
import email.utils
import xml.etree.ElementTree as etree
from ..common import RequestType
from typing import (
TYPE_CHECKING,
Awaitable,
Expand Down Expand Up @@ -57,23 +58,23 @@ def __init__(self, config: ConfigHelper) -> None:
)

self.server.register_endpoint(
"/server/announcements/list", ["GET"],
"/server/announcements/list", RequestType.GET,
self._list_announcements
)
self.server.register_endpoint(
"/server/announcements/dismiss", ["POST"],
"/server/announcements/dismiss", RequestType.POST,
self._handle_dismiss_request
)
self.server.register_endpoint(
"/server/announcements/update", ["POST"],
"/server/announcements/update", RequestType.POST,
self._handle_update_request
)
self.server.register_endpoint(
"/server/announcements/feed", ["POST", "DELETE"],
"/server/announcements/feed", RequestType.POST | RequestType.DELETE,
self._handle_feed_request
)
self.server.register_endpoint(
"/server/announcements/feeds", ["GET"],
"/server/announcements/feeds", RequestType.GET,
self._handle_list_feeds
)
self.server.register_notification(
Expand Down Expand Up @@ -170,13 +171,13 @@ async def _handle_list_feeds(
async def _handle_feed_request(
self, web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
req_type = web_request.get_request_type()
name: str = web_request.get("name")
name = name.lower()
changed: bool = False
db: MoonrakerDatabase = self.server.lookup_component("database")
result = "skipped"
if action == "POST":
if req_type == RequestType.POST:
if name not in self.subscriptions:
feed = RssFeed(name, self.entry_mgr, self.dev_mode)
self.subscriptions[name] = feed
Expand All @@ -187,7 +188,7 @@ async def _handle_feed_request(
"moonraker", "announcements.stored_feeds", self.stored_feeds
)
result = "added"
elif action == "DELETE":
elif req_type == RequestType.DELETE:
if name not in self.stored_feeds:
raise self.server.error(f"Feed '{name}' not stored")
if name in self.configured_feeds:
Expand Down
64 changes: 37 additions & 27 deletions moonraker/components/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tornado.web import HTTPError
from libnacl.sign import Signer, Verifier
from ..utils import json_wrapper as jsonw
from ..common import RequestType, TransportType

# Annotation imports
from typing import (
Expand Down Expand Up @@ -226,32 +227,42 @@ def __init__(self, config: ConfigHelper) -> None:
self.permitted_paths.add("/access/refresh_jwt")
self.permitted_paths.add("/access/info")
self.server.register_endpoint(
"/access/login", ['POST'], self._handle_login,
transports=['http', 'websocket'])
"/access/login", RequestType.POST, self._handle_login,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/logout", ['POST'], self._handle_logout,
transports=['http', 'websocket'])
"/access/logout", RequestType.POST, self._handle_logout,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/refresh_jwt", ['POST'], self._handle_refresh_jwt,
transports=['http', 'websocket'])
"/access/refresh_jwt", RequestType.POST, self._handle_refresh_jwt,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/user", ['GET', 'POST', 'DELETE'],
self._handle_user_request, transports=['http', 'websocket'])
"/access/user", RequestType.all(), self._handle_user_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/users/list", ['GET'], self._handle_list_request,
transports=['http', 'websocket'])
"/access/users/list", RequestType.GET, self._handle_list_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/user/password", ['POST'], self._handle_password_reset,
transports=['http', 'websocket'])
"/access/user/password", RequestType.POST, self._handle_password_reset,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/api_key", ['GET', 'POST'],
self._handle_apikey_request, transports=['http', 'websocket'])
"/access/api_key", RequestType.GET | RequestType.POST,
self._handle_apikey_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/oneshot_token", ['GET'],
self._handle_oneshot_request, transports=['http', 'websocket'])
"/access/oneshot_token", RequestType.GET, self._handle_oneshot_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/access/info", ['GET'],
self._handle_info_request, transports=['http', 'websocket'])
"/access/info", RequestType.GET, self._handle_info_request,
transports=TransportType.HTTP | TransportType.WEBSOCKET
)
wsm: WebsocketManager = self.server.lookup_component("websockets")
wsm.register_notification("authorization:user_created")
wsm.register_notification(
Expand All @@ -274,8 +285,7 @@ async def component_init(self) -> None:
self.prune_timer.start(delay=PRUNE_CHECK_TIME)

async def _handle_apikey_request(self, web_request: WebRequest) -> str:
action = web_request.get_action()
if action.upper() == 'POST':
if web_request.get_request_type() == RequestType.POST:
self.api_key = uuid.uuid4().hex
self.users[API_USER]['api_key'] = self.api_key
self._sync_user(API_USER)
Expand Down Expand Up @@ -360,11 +370,11 @@ async def _handle_refresh_jwt(self,
'action': 'user_jwt_refresh'
}

async def _handle_user_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
if action == "GET":
async def _handle_user_request(
self, web_request: WebRequest
) -> Dict[str, Any]:
req_type = web_request.get_request_type()
if req_type == RequestType.GET:
user = web_request.get_current_user()
if user is None:
return {
Expand All @@ -378,10 +388,10 @@ async def _handle_user_request(self,
'source': user.get("source", "moonraker"),
'created_on': user.get('created_on')
}
elif action == "POST":
elif req_type == RequestType.POST:
# Create User
return await self._login_jwt_user(web_request, create=True)
elif action == "DELETE":
elif req_type == RequestType.DELETE:
# Delete User
return self._delete_jwt_user(web_request)
raise self.server.error("Invalid Request Method")
Expand Down
11 changes: 7 additions & 4 deletions moonraker/components/data_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import logging
import time
from collections import deque
from ..common import RequestType

# Annotation imports
from typing import (
Expand Down Expand Up @@ -59,11 +60,13 @@ def __init__(self, config: ConfigHelper) -> None:

# Register endpoints
self.server.register_endpoint(
"/server/temperature_store", ['GET'],
self._handle_temp_store_request)
"/server/temperature_store", RequestType.GET,
self._handle_temp_store_request
)
self.server.register_endpoint(
"/server/gcode_store", ['GET'],
self._handle_gcode_store_request)
"/server/gcode_store", RequestType.GET,
self._handle_gcode_store_request
)

async def _init_sensors(self) -> None:
klippy_apis: APIComp = self.server.lookup_component('klippy_apis')
Expand Down
28 changes: 16 additions & 12 deletions moonraker/components/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import lmdb
from ..utils import Sentinel, ServerError
from ..utils import json_wrapper as jsonw
from ..common import RequestType

# Annotation imports
from typing import (
Expand Down Expand Up @@ -174,15 +175,17 @@ def __init__(self, config: ConfigHelper) -> None:
self.insert_item("moonraker", "database.unsafe_shutdowns",
unsafe_shutdowns + 1)
self.server.register_endpoint(
"/server/database/list", ['GET'], self._handle_list_request)
"/server/database/list", RequestType.GET, self._handle_list_request
)
self.server.register_endpoint(
"/server/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request)
"/server/database/item", RequestType.all(), self._handle_item_request
)
self.server.register_debug_endpoint(
"/debug/database/list", ['GET'], self._handle_list_request)
"/debug/database/list", RequestType.GET, self._handle_list_request
)
self.server.register_debug_endpoint(
"/debug/database/item", ["GET", "POST", "DELETE"],
self._handle_item_request)
"/debug/database/item", RequestType.all(), self._handle_item_request
)

def get_database_path(self) -> str:
return self.database_path
Expand Down Expand Up @@ -735,7 +738,7 @@ async def _handle_list_request(self,
async def _handle_item_request(self,
web_request: WebRequest
) -> Dict[str, Any]:
action = web_request.get_action()
req_type = web_request.get_request_type()
is_debug = web_request.get_endpoint().startswith("/debug/")
namespace = web_request.get_str("namespace")
if namespace in self.forbidden_namespaces and not is_debug:
Expand All @@ -744,7 +747,7 @@ async def _handle_item_request(self,
" is forbidden", 403)
key: Any
valid_types: Tuple[type, ...]
if action != "GET":
if req_type != RequestType.GET:
if namespace in self.protected_namespaces and not is_debug:
raise self.server.error(
f"Write access to namespace '{namespace}'"
Expand All @@ -758,16 +761,17 @@ async def _handle_item_request(self,
raise self.server.error(
"Value for argument 'key' is an invalid type: "
f"{type(key).__name__}")
if action == "GET":
if req_type == RequestType.GET:
val = await self.get_item(namespace, key)
elif action == "POST":
elif req_type == RequestType.POST:
val = web_request.get("value")
await self.insert_item(namespace, key, val)
elif action == "DELETE":
elif req_type == RequestType.DELETE:
val = await self.delete_item(namespace, key, drop_empty_db=True)

if is_debug:
self.debug_counter[action.lower()] += 1
name = req_type.name or str(req_type).split(".", 1)[-1]
self.debug_counter[name.lower()] += 1
await self.insert_item(
"moonraker", "database.debug_counter", self.debug_counter
)
Expand Down
14 changes: 7 additions & 7 deletions moonraker/components/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import asyncio
import pathlib
import logging
from ..common import BaseRemoteConnection
from ..common import BaseRemoteConnection, RequestType, TransportType
from ..utils import get_unix_peer_credentials

# Annotation imports
Expand Down Expand Up @@ -35,19 +35,19 @@ def __init__(self, config: ConfigHelper) -> None:
self.agent_methods: Dict[int, List[str]] = {}
self.uds_server: Optional[asyncio.AbstractServer] = None
self.server.register_endpoint(
"/connection/register_remote_method", ["POST"],
"/connection/register_remote_method", RequestType.POST,
self._register_agent_method,
transports=["websocket"]
transports=TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/connection/send_event", ["POST"], self._handle_agent_event,
transports=["websocket"]
"/connection/send_event", RequestType.POST, self._handle_agent_event,
transports=TransportType.WEBSOCKET
)
self.server.register_endpoint(
"/server/extensions/list", ["GET"], self._handle_list_extensions
"/server/extensions/list", RequestType.GET, self._handle_list_extensions
)
self.server.register_endpoint(
"/server/extensions/request", ["POST"], self._handle_call_agent
"/server/extensions/request", RequestType.POST, self._handle_call_agent
)

def register_agent(self, connection: BaseRemoteConnection) -> None:
Expand Down
43 changes: 27 additions & 16 deletions moonraker/components/file_manager/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from inotify_simple import flags as iFlags
from ...utils import source_info
from ...utils import json_wrapper as jsonw
from ...common import RequestType, TransportType

# Annotation imports
from typing import (
Expand Down Expand Up @@ -107,27 +108,37 @@ def __init__(self, config: ConfigHelper) -> None:

# Register file management endpoints
self.server.register_endpoint(
"/server/files/list", ['GET'], self._handle_filelist_request)
"/server/files/list", RequestType.GET, self._handle_filelist_request
)
self.server.register_endpoint(
"/server/files/metadata", ['GET'], self._handle_metadata_request)
"/server/files/metadata", RequestType.GET, self._handle_metadata_request
)
self.server.register_endpoint(
"/server/files/metascan", ['POST'], self._handle_metascan_request)
"/server/files/metascan", RequestType.POST, self._handle_metascan_request
)
self.server.register_endpoint(
"/server/files/thumbnails", ['GET'], self._handle_list_thumbs)
"/server/files/thumbnails", RequestType.GET, self._handle_list_thumbs
)
self.server.register_endpoint(
"/server/files/roots", ['GET'], self._handle_list_roots)
"/server/files/roots", RequestType.GET, self._handle_list_roots
)
self.server.register_endpoint(
"/server/files/directory", ['GET', 'POST', 'DELETE'],
self._handle_directory_request)
"/server/files/directory", RequestType.all(),
self._handle_directory_request
)
self.server.register_endpoint(
"/server/files/move", ['POST'], self._handle_file_move_copy)
"/server/files/move", RequestType.POST, self._handle_file_move_copy
)
self.server.register_endpoint(
"/server/files/copy", ['POST'], self._handle_file_move_copy)
"/server/files/copy", RequestType.POST, self._handle_file_move_copy
)
self.server.register_endpoint(
"/server/files/zip", ['POST'], self._handle_zip_files)
"/server/files/zip", RequestType.POST, self._handle_zip_files
)
self.server.register_endpoint(
"/server/files/delete_file", ['DELETE'], self._handle_file_delete,
transports=["websocket"])
"/server/files/delete_file", RequestType.DELETE, self._handle_file_delete,
transports=TransportType.WEBSOCKET
)
# register client notificaitons
self.server.register_notification("file_manager:filelist_changed")

Expand Down Expand Up @@ -458,24 +469,24 @@ async def _handle_directory_request(self,
) -> Dict[str, Any]:
directory = web_request.get_str('path', "gcodes")
root, dir_path = self._convert_request_path(directory)
method = web_request.get_action()
if method == 'GET':
req_type = web_request.get_request_type()
if req_type == RequestType.GET:
is_extended = web_request.get_boolean('extended', False)
# Get list of files and subdirectories for this target
dir_info = self._list_directory(dir_path, root, is_extended)
return dir_info
async with self.sync_lock:
self.check_reserved_path(dir_path, True)
action = "create_dir"
if method == 'POST' and root in self.full_access_roots:
if req_type == RequestType.POST and root in self.full_access_roots:
# Create a new directory
self.sync_lock.setup("create_dir", dir_path)
try:
os.mkdir(dir_path)
except Exception as e:
raise self.server.error(str(e))
self.fs_observer.on_item_create(root, dir_path, is_dir=True)
elif method == 'DELETE' and root in self.full_access_roots:
elif req_type == RequestType.DELETE and root in self.full_access_roots:
# Remove a directory
action = "delete_dir"
if directory.strip("/") == root:
Expand Down
Loading

0 comments on commit f81e340

Please sign in to comment.