diff --git a/src/petals/client/inference_session.py b/src/petals/client/inference_session.py index c52df709..4d94e7a7 100644 --- a/src/petals/client/inference_session.py +++ b/src/petals/client/inference_session.py @@ -84,8 +84,13 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[ break # this message means "done sending" def step( - self, inputs: torch.Tensor, prompts: torch.Tensor, hypo_ids: torch.LongTensor, *, - step_id: str, start_from_position: int + self, + inputs: torch.Tensor, + prompts: torch.Tensor, + hypo_ids: torch.LongTensor, + *, + step_id: str, + start_from_position: int, ) -> torch.Tensor: """ Inference step: send a chunk of input tensors and receive a chunk of outputs @@ -266,8 +271,11 @@ def __enter__(self) -> "InferenceSession": return self def step( - self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, - hypo_ids: Optional[torch.Tensor] = None, start_from_position: Optional[int] = None + self, + inputs: torch.Tensor, + prompts: Optional[torch.Tensor] = None, + hypo_ids: Optional[torch.Tensor] = None, + start_from_position: Optional[int] = None, ) -> torch.Tensor: if start_from_position is not None: @@ -317,8 +325,11 @@ def step( server_session = self._server_sessions[server_idx] inputs = server_session.step( - inputs, prompts[server_session.span.start : server_session.span.end], hypo_ids, - step_id=step_id, start_from_position=start_from_position + inputs, + prompts[server_session.span.start : server_session.span.end], + hypo_ids, + step_id=step_id, + start_from_position=start_from_position, ) server_idx += 1 diff --git a/src/petals/server/block_functions.py b/src/petals/server/block_functions.py index 3127d681..121ec8a0 100644 --- a/src/petals/server/block_functions.py +++ b/src/petals/server/block_functions.py @@ -162,7 +162,9 @@ async def iterate_rpc_inference( async for request, step_metadata in input_iterator: if "start_from_position" in step_metadata: start_from_position = step_metadata["start_from_position"] - assert prefix_length >= start_from_position, f"prefix_length={prefix_length}, start_from_position={start_from_position}" + assert ( + prefix_length >= start_from_position, + ), f"prefix_length={prefix_length}, start_from_position={start_from_position}" prefix_length = start_from_position flat_tensors = tuple(deserialize_torch_tensor(tensor) for tensor in request.tensors)