Added primitives for speculative decoding and tests #598

Merged · 12 commits · Jul 24, 2024
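
For context, the end-to-end flow these primitives enable looks roughly like the following: a minimal sketch based on the test added in this PR, where the checkpoint names and the local draft model are placeholders and not part of the diff.

import torch
import transformers
from petals import AutoDistributedSpeculativeModel

# Small local "draft" model that cheaply proposes candidate tokens
# (any compatible Llama checkpoint; the name below is a placeholder).
draft_model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float32
)

# Large distributed model that validates the draft's proposals over the swarm.
model = AutoDistributedSpeculativeModel.from_pretrained(
    "meta-llama/Llama-2-70b-hf", torch_dtype=torch.float32, small_model=draft_model
)

tokenizer = transformers.AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf", use_fast=False)
inputs = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)  # greedy decoding only for now
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
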
36 changes: 21 additions & 15 deletions src/petals/client/inference_session.py
@@ -83,14 +83,24 @@ async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[
if not next_input_message.uid and not next_input_message.tensors:
break # this message means "done sending"

@property
def position(self):
return self._position

@position.setter
def position(self, start_from_position: int):
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

def step(
self,
inputs: torch.Tensor,
prompts: torch.Tensor,
hypo_ids: torch.LongTensor,
*,
step_id: str,
start_from_position: int,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
@@ -100,12 +110,6 @@ def step(
if self.closed:
raise Exception("Session is closed, cannot perform step")

if start_from_position is not None:
assert start_from_position <= self._position
self._position = start_from_position
if self.history is not None and self.history.shape[1] >= start_from_position:
self.history = self.history[:, :start_from_position, :] if start_from_position > 0 else None

n_input_tokens = inputs.shape[1]
if self.history is None:
self.history = inputs
@@ -127,8 +131,8 @@
request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
if start_from_position is not None:
request_metadata["start_from_position"] = start_from_position
if self._position is not None:
request_metadata["start_from_position"] = self._position
elif self.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
@@ -235,6 +239,13 @@ def num_blocks(self) -> int:
def position(self) -> int:
return self._position

@position.setter
def position(self, start_from_position: int) -> None:
self._position = start_from_position
for session in self._server_sessions:
assert isinstance(session, _ServerInferenceSession)
session.position = start_from_position

def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_ServerInferenceSession]:
server_sessions = []
try:
@@ -275,12 +286,7 @@ def step(
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
start_from_position: Optional[int] = None,
) -> torch.Tensor:

if start_from_position is not None:
self._position = start_from_position

assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
@@ -324,12 +330,12 @@ def step(
self._update_sequence(server_idx, block_idx, attempt_no)

server_session = self._server_sessions[server_idx]
assert server_session.position == self.position, f"{server_session.position} and {self.position}"
inputs = server_session.step(
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids,
step_id=step_id,
start_from_position=start_from_position,
)

server_idx += 1
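
In short, the inlined start_from_position handling in step() is replaced by a position setter on both _ServerInferenceSession and the client-facing session: setting it truncates the cached history, and the current position is forwarded to servers via the start_from_position request metadata. A minimal sketch of the resulting client-side usage, assuming the remote_block, inputs, and short_inputs objects from the updated test further below:

with torch.inference_mode():
    with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
        first_outputs = sess.step(inputs)                    # run the full prefix once
        sess.position = 2                                    # rewind: keep only the first 2 cached positions
        second_outputs = sess.step(short_inputs[:, 2:, :])   # recompute everything from position 2 onward
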
2 changes: 2 additions & 0 deletions src/petals/models/llama/__init__.py
@@ -5,11 +5,13 @@
DistributedLlamaForSequenceClassification,
DistributedLlamaModel,
)
from petals.models.llama.speculative_model import DistributedLlamaForSpeculativeGeneration
from petals.utils.auto_config import register_model_classes

register_model_classes(
config=DistributedLlamaConfig,
model=DistributedLlamaModel,
model_for_causal_lm=DistributedLlamaForCausalLM,
model_for_speculative=DistributedLlamaForSpeculativeGeneration,
model_for_sequence_classification=DistributedLlamaForSequenceClassification,
)
111 changes: 111 additions & 0 deletions src/petals/models/llama/speculative_model.py
@@ -0,0 +1,111 @@
from typing import Optional, Union

import torch
from transformers.generation import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GenerateNonBeamOutput, GenerationMixin
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama import LlamaForCausalLM

from petals.models.llama.config import DistributedLlamaConfig
from petals.models.llama.model import DistributedLlamaForCausalLM


class DistributedLlamaForSpeculativeGeneration(DistributedLlamaForCausalLM, GenerationMixin):
def __init__(self, config: DistributedLlamaConfig, small_model: LlamaForCausalLM):
DistributedLlamaForCausalLM.__init__(self, config)
self.small_model = small_model

def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool,
streamer: Optional["BaseStreamer"],
logits_warper: Optional[LogitsProcessorList],
speculative_inference_iteration_size: int = 10,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
assert not generation_config.do_sample, "sample is not working for speculative generation now"
assert not synced_gpus, "synced_gpus is not working for speculative generation now"
assert (
not generation_config.return_dict_in_generate
), "return_dict_in_generate is not working for speculative generation now"

has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)

# keep track of which sequences are already finished
batch_size = input_ids.shape[0]
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
finished = False
firsts = True

while not finished:
speculative_inference_iteration_size = min(
speculative_inference_iteration_size, self.active_session._max_length - input_ids.shape[1]
)
with torch.no_grad():
speculative_outputs = self.small_model.generate(
input_ids,
max_new_tokens=speculative_inference_iteration_size,
do_sample=False,
)
speculative_tokens = speculative_outputs[:, -speculative_inference_iteration_size:]

full_sequence = torch.cat([input_ids, speculative_tokens], dim=-1)
assert input_ids.shape[1] + speculative_inference_iteration_size == full_sequence.shape[1]

input_for_validation = full_sequence
if not firsts:
self.active_session.position = input_ids.shape[1] - 1
input_for_validation = input_for_validation[:, -speculative_inference_iteration_size - 1 :]
else:
firsts = False
input_for_validation = input_for_validation[:, :-1]
with torch.no_grad():
precise_model_outputs = self(input_for_validation)
full_token_logits = precise_model_outputs.logits[:, -speculative_inference_iteration_size:, :].clone()

all_valid_tokens = []
first_token = None
for i in range(speculative_inference_iteration_size):
token_logits = full_token_logits[:, i, :]
token_scores = logits_processor(
input_for_validation[:, : -speculative_inference_iteration_size + 1 + i], token_logits
)
valid_token = torch.argmax(token_scores, dim=-1)

if first_token is None:
first_token = valid_token

if valid_token.item() == speculative_tokens[:, i].item():
all_valid_tokens.append(valid_token.unsqueeze(-1))
else:
break

if not all_valid_tokens and first_token is not None:
all_valid_tokens.append(first_token.unsqueeze(-1))
all_valid_tokens = torch.cat(all_valid_tokens, dim=-1)

# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
all_valid_tokens = all_valid_tokens * unfinished_sequences + generation_config.pad_token_id * (
1 - unfinished_sequences
)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, all_valid_tokens], dim=-1)

if streamer is not None:
streamer.put(all_valid_tokens.cpu())

unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, None)
finished = unfinished_sequences.max() == 0

del precise_model_outputs

if streamer is not None:
streamer.end()

return input_ids
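
The core of _sample above is the acceptance loop: the small model greedily proposes speculative_inference_iteration_size tokens, the distributed model scores the whole proposed block in a single forward pass, and proposals are accepted left to right while they match the large model's greedy choice; on the first mismatch (or if nothing matches), the large model's own first token is kept instead. A self-contained sketch of that rule on toy tensors, with logits_processor and the session bookkeeping omitted:

import torch

speculative_tokens = torch.tensor([[5, 7, 9]])   # tokens proposed by the draft model (batch size 1)
full_token_logits = torch.randn(1, 3, 32000)     # large-model logits for the same three positions

all_valid_tokens = []
first_token = None
for i in range(speculative_tokens.shape[1]):
    valid_token = torch.argmax(full_token_logits[:, i, :], dim=-1)  # large model's greedy choice
    if first_token is None:
        first_token = valid_token
    if valid_token.item() == speculative_tokens[:, i].item():
        all_valid_tokens.append(valid_token.unsqueeze(-1))          # proposal confirmed, keep it
    else:
        break                                                       # first mismatch ends the accepted prefix

if not all_valid_tokens:
    all_valid_tokens.append(first_token.unsqueeze(-1))              # fall back to the large model's first token
new_tokens = torch.cat(all_valid_tokens, dim=-1)                    # appended to input_ids for the next iteration
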
1 change: 1 addition & 0 deletions src/petals/utils/__init__.py
@@ -3,5 +3,6 @@
AutoDistributedModel,
AutoDistributedModelForCausalLM,
AutoDistributedModelForSequenceClassification,
AutoDistributedSpeculativeModel,
)
from petals.utils.dht import declare_active_modules, get_remote_module_infos
5 changes: 5 additions & 0 deletions src/petals/utils/auto_config.py
@@ -15,6 +15,7 @@ class _ModelClasses:
config: Type[PretrainedConfig]
model: Optional[Type[PreTrainedModel]] = None
model_for_causal_lm: Optional[Type[PreTrainedModel]] = None
model_for_speculative: Optional[Type[PreTrainedModel]] = None
model_for_sequence_classification: Optional[Type[PreTrainedModel]] = None


@@ -90,5 +91,9 @@ class AutoDistributedModelForCausalLM(DefaultRevisionMixin, _AutoDistributedBase
_mapping_field = "model_for_causal_lm"


class AutoDistributedSpeculativeModel(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_speculative"


class AutoDistributedModelForSequenceClassification(DefaultRevisionMixin, _AutoDistributedBase):
_mapping_field = "model_for_sequence_classification"
54 changes: 52 additions & 2 deletions tests/test_speculative_generation.py
@@ -2,8 +2,14 @@

import pytest
import torch
import transformers

from petals import AutoDistributedConfig, RemoteSequential
from petals import (
AutoDistributedConfig,
AutoDistributedSpeculativeModel,
DistributedLlamaForSpeculativeGeneration,
RemoteSequential,
)
from petals.server.block_functions import MAX_SHORT_INFERENCE_TOKENS
from petals.server.from_pretrained import load_pretrained_block
from test_utils import *
@@ -26,10 +32,54 @@ def test_remote_block_with_cache_invalidation_exact_match(atol_forward=1e-4, ato
with torch.inference_mode():
with remote_block.inference_session(max_length=inputs.shape[1]) as sess:
initial_outputs_inference = sess.step(inputs)
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :], start_from_position=2)
sess.position = 2
secondary_outputs_inference = sess.step(short_inputs[:, 2:, :])
result = torch.cat([initial_outputs_inference[:, :2, :], secondary_outputs_inference], dim=1)

ref_block = load_pretrained_block(MODEL_NAME, block_index, torch_dtype=torch.float32)
(outputs_local,) = ref_block(short_inputs)

assert torch.allclose(outputs_local, result, rtol=0, atol=atol_inference)


@pytest.fixture
def noisy_model():
noisy_model = transformers.AutoModelForCausalLM.from_pretrained(
REF_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)
lm_head = noisy_model.get_output_embeddings()
assert isinstance(lm_head, torch.nn.Linear)
with torch.no_grad():
lm_head.weight += torch.randn_like(lm_head.weight) * 0.02
return noisy_model


@pytest.fixture
def model():
return transformers.AutoModelForCausalLM.from_pretrained(
MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32
)


@pytest.fixture
def tokenizer():
# We set use_fast=False since LlamaTokenizerFast is slow on load
return transformers.AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)


@pytest.mark.forked
@pytest.mark.skipif(
"llama" not in MODEL_NAME.lower(),
reason="Speculative generation now works only for llama models",
)
def test_remote_speculative_generation(tokenizer, model, noisy_model, atol_inference=1e-3):
speculated_distributed_model = AutoDistributedSpeculativeModel.from_pretrained(
MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype=torch.float32, small_model=noisy_model
)

inputs_single = tokenizer("A cat sat on a mat", return_tensors="pt")["input_ids"]

generated_spec = speculated_distributed_model.generate(inputs_single, max_new_tokens=100, do_sample=False)
generated_local = model.generate(inputs_single, max_new_tokens=100, do_sample=False)

assert torch.allclose(generated_spec, generated_local, rtol=0, atol=atol_inference)