Skip to content

Commit

Permalink
Improvements in execution over SSH
Browse files Browse the repository at this point in the history
* Supporting non-Genv shell commands

- Minor cosmetics
  • Loading branch information
razrotenberg committed Feb 5, 2024
1 parent c27500e commit 9e0af0f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 55 deletions.
39 changes: 22 additions & 17 deletions genv/remote/utils/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,44 +43,49 @@ class Command:
:param args: Genv command arguments
:param sudo: Run command as root using sudo
:param shell: Run command as regular shell command
"""

args: Iterable[str]
sudo: bool = False
shell: bool = False

@property
def all_args(self) -> Iterable[str]:
"""Returns all command arguments"""

return self.args if self.shell else ["genv"] + self.args


async def run(
config: Config, command: Command, stdins: Optional[Iterable[str]] = None
config: Config,
command: Command,
stdins: Optional[Iterable[str]] = None,
) -> Tuple[Iterable[Host], Iterable[str]]:
"""
Runs a Genv command on multiple hosts over SSH.
Runs a command on multiple hosts over SSH.
Waits for the command to finish on all hosts.
Raises 'RuntimeError' if failed to connect to any of the hosts and 'config.throw_on_error' is True.
:param config: Execution configuration
:param command: Genv command specification
:param command: Command specification
:param stdins: Input to send per host
:return: Returns the hosts that succeeded and their standard outputs
"""
ssh_runners_and_inputs = []
for host, stdin in zip(config.hosts, stdins or [None for _ in config.hosts]):
ssh_runner = Runner(
host.hostname,
host.timeout,
)
ssh_runners_and_inputs.append((ssh_runner, stdin))

ssh_outputs = await asyncio.gather(
runners = [Runner(host.hostname, host.timeout) for host in config.hosts]

results = await asyncio.gather(
*(
ssh_runner.run("genv", *command.args, stdin=stdin, sudo=command.sudo)
for ssh_runner, stdin in ssh_runners_and_inputs
runner.run(*command.all_args, stdin=stdin, sudo=command.sudo, check=False)
for runner, stdin in zip(runners, stdins or [None for _ in runners])
)
)

processes = [runner_output.command_process for runner_output in ssh_outputs]
stdouts = [runner_output.stdout for runner_output in ssh_outputs]
stderrs = [runner_output.stderr for runner_output in ssh_outputs]
processes = [result.process for result in results]
stdouts = [result.stdout for result in results]
stderrs = [result.stderr for result in results]

def filter(
objs: Iterable[Any], pred: Callable[[asyncio.subprocess.Process], bool]
Expand All @@ -94,7 +99,7 @@ def failed(objs: Iterable[Any]) -> Iterable[Any]:
return filter(objs, lambda process: process.returncode != 0)

for host, stderr in failed(zip(config.hosts, stderrs)):
message = f"Failed connecting over SSH to {host.hostname} ({stderr})"
message = f"Failed running SSH command on {host.hostname} ({stderr})"

if config.throw_on_error:
raise RuntimeError(message)
Expand Down
6 changes: 0 additions & 6 deletions genv/utils/runners/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@


class Runner(Base):
def name(self) -> str:
return "local"

async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process:
if sudo:
args = ["sudo", *args]
Expand All @@ -19,6 +16,3 @@ async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process:
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)

def _get_error_msg(self, command: str, stderr: str):
return f"Failed to run a command on the local machine: command: '{command}' ({stderr})"
41 changes: 19 additions & 22 deletions genv/utils/runners/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,51 +5,48 @@


class CommandResults:
command_process: Process
process: Process
stdout: str
stderr: str

def __init__(self, command_process: Process, stdout: str, stderr: str):
self.command_process = command_process
def __init__(self, process: Process, stdout: str, stderr: str):
self.process = process
self.stdout = stdout
self.stderr = stderr


class Runner(ABC):
DEFAULT_STD_ENCODING = 'utf-8'
_process_env: Dict[str, str]

def __init__(self, process_env: Optional[Dict] = None):
self._process_env = process_env

async def run(self, *args: str, stdin: Optional[str] = None, sudo: bool = False, check: bool = False
) -> CommandResults:
async def run(
self,
*args: str,
stdin: Optional[str] = None,
sudo: bool = False,
check: bool = False,
) -> CommandResults:
stdin_fd = asyncio.subprocess.PIPE if stdin else asyncio.subprocess.DEVNULL
process = await self._open_process(*args,
stdin_fd=stdin_fd,
sudo=sudo)
process = await self._open_process(*args, stdin_fd=stdin_fd, sudo=sudo)

stdout, stderr = await process.communicate(stdin.encode(Runner.DEFAULT_STD_ENCODING) if stdin else None)
stdout = stdout.decode(self.DEFAULT_STD_ENCODING).strip()
stderr = stderr.decode(self.DEFAULT_STD_ENCODING).strip()
stdout, stderr = await process.communicate(
stdin.encode("utf-8") if stdin else None
)

stdout = stdout.decode("utf-8").strip()
stderr = stderr.decode("utf-8").strip()

if check and process.returncode != 0:
command = " ".join(args)
stdin_str = ' with stdin ' + str(stdin)
stdin_str = " with stdin " + str(stdin)
raise RuntimeError(
f"Failed running '{command}' {' with sudo ' if sudo else ''} { stdin_str if stdin else '' } ({stderr})"
)

return CommandResults(process, stdout, stderr)

@abstractmethod
def name(self) -> str:
raise NotImplementedError('This should be implemented in subclasses')

@abstractmethod
async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process:
raise NotImplementedError('This should be implemented in subclasses')

@abstractmethod
def _get_error_msg(self, command: str, stderr: str):
raise NotImplementedError('This should be implemented in subclasses')
raise NotImplementedError("This should be implemented in subclasses")
12 changes: 2 additions & 10 deletions genv/utils/runners/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
class Runner(Base):
host_name: str
timeout: Optional[int]
__SSH_COMMAND_PREFIX = "ssh"
__SSH_TIMEOUT_PARAMETER = "-o ConnectTimeout={0}"

def __init__(
self,
Expand All @@ -21,15 +19,12 @@ def __init__(
self.host_name = host_name
self.timeout = timeout

def name(self) -> str:
return self.host_name

async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process:
ssh_parameters = self.calc_ssh_params()
remote_command = self.calc_command_on_remote_machine(args, sudo)

return await asyncio.create_subprocess_exec(
Runner.__SSH_COMMAND_PREFIX,
"ssh",
*ssh_parameters,
remote_command,
stdin=stdin_fd,
Expand All @@ -40,7 +35,7 @@ async def _open_process(self, *args: str, stdin_fd: int, sudo: bool) -> Process:
def calc_ssh_params(self) -> List[str]:
ssh_parameters = []
if self.timeout is not None:
ssh_parameters.append(self.__SSH_TIMEOUT_PARAMETER.format(self.timeout))
ssh_parameters.append(f"-o ConnectTimeout={self.timeout}")
ssh_parameters.append(self.host_name)
return ssh_parameters

Expand All @@ -52,9 +47,6 @@ def calc_command_on_remote_machine(self, args: Tuple[str, ...], sudo: bool) -> s
command = f"sudo {command}"
return command

def _get_error_msg(self, command: str, stderr: str):
return f"Failed to run a command using ssh on {self.host_name}: command: '{command}' ({stderr})"

@staticmethod
def _add_environment_vars(command: str, process_env: Dict[str, str]):
env_str = " ".join(
Expand Down

0 comments on commit 9e0af0f

Please sign in to comment.