diff --git a/genv/remote/utils/ssh.py b/genv/remote/utils/ssh.py index 6bb965e..0d18d3f 100644 --- a/genv/remote/utils/ssh.py +++ b/genv/remote/utils/ssh.py @@ -43,44 +43,49 @@ class Command: :param args: Genv command arguments :param sudo: Run command as root using sudo + :param shell: Run command as regular shell command """ args: Iterable[str] sudo: bool = False + shell: bool = False + + @property + def all_args(self) -> Iterable[str]: + """Returns all command arguments""" + + return self.args if self.shell else ["genv"] + self.args async def run( - config: Config, command: Command, stdins: Optional[Iterable[str]] = None + config: Config, + command: Command, + stdins: Optional[Iterable[str]] = None, ) -> Tuple[Iterable[Host], Iterable[str]]: """ - Runs a Genv command on multiple hosts over SSH. + Runs a command on multiple hosts over SSH. Waits for the command to finish on all hosts. Raises 'RuntimeError' if failed to connect to any of the hosts and 'config.throw_on_error' is True. :param config: Execution configuration - :param command: Genv command specification + :param command: Command specification :param stdins: Input to send per host :return: Returns the hosts that succeeded and their standard outputs """ - ssh_runners_and_inputs = [] - for host, stdin in zip(config.hosts, stdins or [None for _ in config.hosts]): - ssh_runner = Runner( - host.hostname, - host.timeout, - ) - ssh_runners_and_inputs.append((ssh_runner, stdin)) - ssh_outputs = await asyncio.gather( + runners = [Runner(host.hostname, host.timeout) for host in config.hosts] + + results = await asyncio.gather( *( - ssh_runner.run("genv", *command.args, stdin=stdin, sudo=command.sudo) - for ssh_runner, stdin in ssh_runners_and_inputs + runner.run(*command.all_args, stdin=stdin, sudo=command.sudo, check=False) + for runner, stdin in zip(runners, stdins or [None for _ in runners]) ) ) - processes = [runner_output.command_process for runner_output in ssh_outputs] - stdouts = [runner_output.stdout for runner_output in ssh_outputs] - stderrs = [runner_output.stderr for runner_output in ssh_outputs] + processes = [result.process for result in results] + stdouts = [result.stdout for result in results] + stderrs = [result.stderr for result in results] def filter( objs: Iterable[Any], pred: Callable[[asyncio.subprocess.Process], bool] @@ -94,7 +99,7 @@ def failed(objs: Iterable[Any]) -> Iterable[Any]: return filter(objs, lambda process: process.returncode != 0) for host, stderr in failed(zip(config.hosts, stderrs)): - message = f"Failed connecting over SSH to {host.hostname} ({stderr})" + message = f"Failed running SSH command on {host.hostname} ({stderr})" if config.throw_on_error: raise RuntimeError(message) diff --git a/genv/utils/runners/local.py b/genv/utils/runners/local.py index fb40754..f57a6e6 100644 --- a/genv/utils/runners/local.py +++ b/genv/utils/runners/local.py @@ -5,9 +5,6 @@ class Runner(Base): - def name(self) -> str: - return "local" - async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process: if sudo: args = ["sudo", *args] @@ -19,6 +16,3 @@ async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process: stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) - - def _get_error_msg(self, command: str, stderr: str): - return f"Failed to run a command on the local machine: command: '{command}' ({stderr})" diff --git a/genv/utils/runners/runner.py b/genv/utils/runners/runner.py index 58417e7..e942a6c 100644 --- a/genv/utils/runners/runner.py +++ b/genv/utils/runners/runner.py @@ -5,51 +5,48 @@ class CommandResults: - command_process: Process + process: Process stdout: str stderr: str - def __init__(self, command_process: Process, stdout: str, stderr: str): - self.command_process = command_process + def __init__(self, process: Process, stdout: str, stderr: str): + self.process = process self.stdout = stdout self.stderr = stderr class Runner(ABC): - DEFAULT_STD_ENCODING = 'utf-8' _process_env: Dict[str, str] def __init__(self, process_env: Optional[Dict] = None): self._process_env = process_env - async def run(self, *args: str, stdin: Optional[str] = None, sudo: bool = False, check: bool = False - ) -> CommandResults: + async def run( + self, + *args: str, + stdin: Optional[str] = None, + sudo: bool = False, + check: bool = False, + ) -> CommandResults: stdin_fd = asyncio.subprocess.PIPE if stdin else asyncio.subprocess.DEVNULL - process = await self._open_process(*args, - stdin_fd=stdin_fd, - sudo=sudo) + process = await self._open_process(*args, stdin_fd=stdin_fd, sudo=sudo) - stdout, stderr = await process.communicate(stdin.encode(Runner.DEFAULT_STD_ENCODING) if stdin else None) - stdout = stdout.decode(self.DEFAULT_STD_ENCODING).strip() - stderr = stderr.decode(self.DEFAULT_STD_ENCODING).strip() + stdout, stderr = await process.communicate( + stdin.encode("utf-8") if stdin else None + ) + + stdout = stdout.decode("utf-8").strip() + stderr = stderr.decode("utf-8").strip() if check and process.returncode != 0: command = " ".join(args) - stdin_str = ' with stdin ' + str(stdin) + stdin_str = " with stdin " + str(stdin) raise RuntimeError( f"Failed running '{command}' {' with sudo ' if sudo else ''} { stdin_str if stdin else '' } ({stderr})" ) return CommandResults(process, stdout, stderr) - @abstractmethod - def name(self) -> str: - raise NotImplementedError('This should be implemented in subclasses') - @abstractmethod async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process: - raise NotImplementedError('This should be implemented in subclasses') - - @abstractmethod - def _get_error_msg(self, command: str, stderr: str): - raise NotImplementedError('This should be implemented in subclasses') + raise NotImplementedError("This should be implemented in subclasses") diff --git a/genv/utils/runners/ssh.py b/genv/utils/runners/ssh.py index 5ccf611..3f18e5b 100644 --- a/genv/utils/runners/ssh.py +++ b/genv/utils/runners/ssh.py @@ -8,8 +8,6 @@ class Runner(Base): host_name: str timeout: Optional[int] - __SSH_COMMAND_PREFIX = "ssh" - __SSH_TIMEOUT_PARAMETER = "-o ConnectTimeout={0}" def __init__( self, @@ -21,15 +19,12 @@ def __init__( self.host_name = host_name self.timeout = timeout - def name(self) -> str: - return self.host_name - async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process: ssh_parameters = self.calc_ssh_params() remote_command = self.calc_command_on_remote_machine(args, sudo) return await asyncio.create_subprocess_exec( - Runner.__SSH_COMMAND_PREFIX, + "ssh", *ssh_parameters, remote_command, stdin=stdin_fd, @@ -40,7 +35,7 @@ async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process: def calc_ssh_params(self) -> List[str]: ssh_parameters = [] if self.timeout is not None: - ssh_parameters.append(self.__SSH_TIMEOUT_PARAMETER.format(self.timeout)) + ssh_parameters.append(f"-o ConnectTimeout={self.timeout}") ssh_parameters.append(self.host_name) return ssh_parameters @@ -52,9 +47,6 @@ def calc_command_on_remote_machine(self, args: Tuple[str, ...], sudo: bool) -> s command = f"sudo {command}" return command - def _get_error_msg(self, command: str, stderr: str): - return f"Failed to run a command using ssh on {self.host_name}: command: '{command}' ({stderr})" - @staticmethod def _add_environment_vars(command: str, process_env: Dict[str, str]): env_str = " ".join(