diff --git a/python/requirements.txt b/python/requirements.txt index 0e61b10193e..0b863e88178 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -8,3 +8,4 @@ typing_extensions>=4.7.1 construct-classes>=0.1.2 appdirs>=1.4.4 cryptography >=43.0.3 +platformdirs >=2 diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 867cec4081c..3ba76a67d9c 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -29,7 +29,7 @@ from ..client import TrezorClient from ..messages import Capability from ..transport import Transport -from ..transport.thp import channel_database +from ..transport.thp.channel_database import get_channel_db LOG = logging.getLogger(__name__) @@ -102,7 +102,7 @@ def get_passphrase( def get_client(transport: Transport) -> TrezorClient: - stored_channels = channel_database.load_stored_channels() + stored_channels = get_channel_db().load_stored_channels() stored_transport_paths = [ch.transport_path for ch in stored_channels] path = transport.get_path() if path in stored_transport_paths: @@ -115,7 +115,7 @@ def get_client(transport: Transport) -> TrezorClient: ) except Exception: LOG.debug("Failed to resume a channel. Replacing by a new one.") - channel_database.remove_channel(path) + get_channel_db().remove_channel(path) client = TrezorClient(transport) else: client = TrezorClient(transport) @@ -355,7 +355,7 @@ def trezorctl_command_with_client( try: return func(client, *args, **kwargs) finally: - channel_database.save_channel(client.protocol) + get_channel_db().save_channel(client.protocol) # if not session_was_resumed: # try: # client.end_session() diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index bdf5206e9fd..183ed8026fb 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -29,6 +29,7 @@ from ..transport import DeviceIsBusy, enumerate_devices from ..transport.session import Session from ..transport.thp import channel_database +from ..transport.thp.channel_database import get_channel_db from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -196,6 +197,13 @@ def configure_logging(verbose: int) -> None: "--record", help="Record screen changes into a specified directory.", ) +@click.option( + "-n", + "--no-store", + is_flag=True, + help="Do not store channels data between commands.", + default=False, +) @click.version_option(version=__version__) @click.pass_context def cli_main( @@ -207,9 +215,10 @@ def cli_main( script: bool, session_id: Optional[str], record: Optional[str], + no_store: bool, ) -> None: configure_logging(verbose) - + channel_database.set_channel_database(should_not_store=no_store) bytes_session_id: Optional[bytes] = None if session_id is not None: try: @@ -296,10 +305,7 @@ def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: try: client = get_client(transport) description = format_device_name(client.features) - # json_string = channel_database.channel_to_str(client.protocol) - # print(json_string) - channel_database.save_channel(client.protocol) - # client.end_session() + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" except Exception: @@ -376,9 +382,14 @@ def clear_session(session: "Session") -> None: @cli.command() -def new_clear_session() -> None: - """New Clear session (remove cached channels from trezorlib).""" - channel_database.clear_stored_channels() +def delete_channels() -> None: + """ + Delete cached channels. + + Do not use together with the `-n` (`--no-store`) flag, + as the JSON database will not be deleted. + """ + get_channel_db().clear_stored_channels() @cli.command() diff --git a/python/src/trezorlib/transport/thp/channel_database.py b/python/src/trezorlib/transport/thp/channel_database.py index 100bf150b60..143430069fb 100644 --- a/python/src/trezorlib/transport/thp/channel_database.py +++ b/python/src/trezorlib/transport/thp/channel_database.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import os @@ -8,39 +10,104 @@ LOG = logging.getLogger(__name__) -if True: - from platformdirs import user_cache_dir, user_config_dir +db: "ChannelDatabase | None" = None + + +def get_channel_db() -> ChannelDatabase: + if db is None: + set_channel_database(should_not_store=True) + assert db is not None + return db + + +class ChannelDatabase: - APP_NAME = "@trezor" # TODO - DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") - CONFIG_PATH = os.path.join(user_config_dir(appname=APP_NAME), "config.json") -else: - DATA_PATH = os.path.join("./channel_data.json") - CONFIG_PATH = os.path.join("./config.json") + def load_stored_channels(self) -> t.List[ChannelData]: ... + def clear_stored_channels(self) -> None: ... + def read_all_channels(self) -> t.List: ... + def save_all_channels(self, channels: t.List[t.Dict]) -> None: ... + def save_channel(self, new_channel: ProtocolAndChannel): ... + def remove_channel(self, transport_path: str) -> None: ... -class ChannelDatabase: # TODO not finished - should_store: bool = False +class DummyChannelDatabase(ChannelDatabase): - def __init__( - self, config_path: str = CONFIG_PATH, data_path: str = DATA_PATH - ) -> None: - if not os.path.exists(CONFIG_PATH): - with open(CONFIG_PATH, "w") as f: - json.dump([], f) + def load_stored_channels(self) -> t.List[ChannelData]: + return [] + def clear_stored_channels(self) -> None: + pass -def load_stored_channels() -> t.List[ChannelData]: - dicts = read_all_channels() - return [dict_to_channel_data(d) for d in dicts] + def read_all_channels(self) -> t.List: + return [] + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + return -def channel_to_str(channel: ProtocolAndChannel) -> str: - return json.dumps(channel.get_channel_data().to_dict()) + def save_channel(self, new_channel: ProtocolAndChannel): + pass + def remove_channel(self, transport_path: str) -> None: + pass -def str_to_channel_data(channel_data: str) -> ChannelData: - return dict_to_channel_data(json.loads(channel_data)) + +class JsonChannelDatabase(ChannelDatabase): + def __init__(self, data_path: str) -> None: + self.data_path = data_path + super().__init__() + + def load_stored_channels(self) -> t.List[ChannelData]: + dicts = self.read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + def clear_stored_channels(self) -> None: + LOG.debug("Clearing contents of %s", self.data_path) + with open(self.data_path, "w") as f: + json.dump([], f) + try: + os.remove(self.data_path) + except Exception as e: + LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e))) + + def read_all_channels(self) -> t.List: + ensure_file_exists(self.data_path) + with open(self.data_path, "r") as f: + return json.load(f) + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(self.data_path, "w") as f: + json.dump(channels, f, indent=4) + + def save_channel(self, new_channel: ProtocolAndChannel): + + LOG.debug("save channel") + channels = self.read_all_channels() + transport_path = new_channel.transport.get_path() + + # If the channel is found in database: replace the old entry by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + self.save_all_channels(channels) + return + + # Channel was not found: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + self.save_all_channels(channels) + + def remove_channel(self, transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = self.read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + self.save_all_channels(remaining_channels) def dict_to_channel_data(dict: t.Dict) -> ChannelData: @@ -57,63 +124,23 @@ def dict_to_channel_data(dict: t.Dict) -> ChannelData: ) -def ensure_file_exists() -> None: - LOG.debug("checking if file %s exists", DATA_PATH) - if not os.path.exists(DATA_PATH): - os.makedirs(os.path.dirname(DATA_PATH), exist_ok=True) - LOG.debug("File %s does not exist. Creating a new one.", DATA_PATH) - with open(DATA_PATH, "w") as f: +def ensure_file_exists(file_path: str) -> None: + LOG.debug("checking if file %s exists", file_path) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + LOG.debug("File %s does not exist. Creating a new one.", file_path) + with open(file_path, "w") as f: json.dump([], f) -def clear_stored_channels() -> None: - LOG.debug("Clearing contents of %s", DATA_PATH) - with open(DATA_PATH, "w") as f: - json.dump([], f) - try: - os.remove(DATA_PATH) - except Exception as e: - LOG.exception("Failed to delete %s (%s)", DATA_PATH, str(type(e))) +def set_channel_database(should_not_store: bool): + global db + if should_not_store: + db = DummyChannelDatabase() + else: + from platformdirs import user_cache_dir + APP_NAME = "@trezor" # TODO + DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") -def read_all_channels() -> t.List: - ensure_file_exists() - with open(DATA_PATH, "r") as f: - return json.load(f) - - -def save_all_channels(channels: t.List[t.Dict]) -> None: - LOG.debug("saving all channels") - with open(DATA_PATH, "w") as f: - json.dump(channels, f, indent=4) - - -def save_channel(new_channel: ProtocolAndChannel): - LOG.debug("save channel") - channels = read_all_channels() - transport_path = new_channel.transport.get_path() - - # If the channel is found in database: replace the old entry by the new - for i, channel in enumerate(channels): - if channel["transport_path"] == transport_path: - LOG.debug("Modified channel entry for %s", transport_path) - channels[i] = new_channel.get_channel_data().to_dict() - save_all_channels(channels) - return - - # Channel was not found: add a new channel entry - LOG.debug("Created a new channel entry on path %s", transport_path) - channels.append(new_channel.get_channel_data().to_dict()) - save_all_channels(channels) - - -def remove_channel(transport_path: str) -> None: - LOG.debug( - "Removing channel with path %s from the channel database.", - transport_path, - ) - channels = read_all_channels() - remaining_channels = [ - ch for ch in channels if ch["transport_path"] != transport_path - ] - save_all_channels(remaining_channels) + db = JsonChannelDatabase(DATA_PATH) diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 3824a2a43cf..f99021a29cd 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -18,7 +18,8 @@ from ..thp.channel_data import ChannelData from ..thp.checksum import CHECKSUM_LENGTH from ..thp.message_header import MessageHeader -from . import channel_database, control_byte +from . import control_byte +from .channel_database import ChannelDatabase, get_channel_db from .protocol_and_channel import ProtocolAndChannel LOG = logging.getLogger(__name__) @@ -76,6 +77,7 @@ def __init__( self.sync_bit_receive = channel_data.sync_bit_receive self.sync_bit_send = channel_data.sync_bit_send self._has_valid_channel = True + self.channel_database: ChannelDatabase = get_channel_db() def get_channel(self) -> ProtocolV2: if not self._has_valid_channel: @@ -99,13 +101,13 @@ def read(self, session_id: int) -> t.Any: sid, msg_type, msg_data = self.read_and_decrypt() if sid != session_id: raise Exception("Received messsage on a different session.") - channel_database.save_channel(self) + self.channel_database.save_channel(self) return self.mapping.decode(msg_type, msg_data) def write(self, session_id: int, msg: t.Any) -> None: msg_type, msg_data = self.mapping.encode(msg) self._encrypt_and_write(session_id, msg_type, msg_data) - channel_database.save_channel(self) + self.channel_database.save_channel(self) def get_features(self) -> messages.Features: if not self._has_valid_channel: diff --git a/tests/conftest.py b/tests/conftest.py index 7e03aca60a8..9356d8b38ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,9 +322,9 @@ def client( # Get a new client _raw_client = _get_raw_client(request) - from trezorlib.transport.thp import channel_database + from trezorlib.transport.thp.channel_database import get_channel_db - channel_database.clear_stored_channels() + get_channel_db().clear_stored_channels() _raw_client.protocol = None _raw_client.__init__( transport=_raw_client.transport,