Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kv cache #120

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions examples/chatglm/chat_sat.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ def chat(model, tokenizer,
[inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0
)
# ---------------
strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=0, end_tokens=[tokenizer.eos_token_id])
strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=0, end_tokens=[tokenizer.eos_token_id], num_beams=num_beams, consider_end=True)
strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id])
output = filling_sequence(
model, seq,
batch_size=1,
Expand Down Expand Up @@ -136,7 +135,7 @@ def chat(model, tokenizer,
use_gpu_initialization=True,
))
model = model.eval()
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
model.add_mixin('auto-regressive', CachedAutoregressiveMixin(model_args.num_layers, model_args.num_attention_heads, model_args.hidden_size, 10000))

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
history = None
Expand Down
2 changes: 0 additions & 2 deletions sat/generation/autoregressive_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def filling_sequence(
log_attention_weights=log_attention_weights_part,
**kw_args
)
mem_kv = [o['mem_kv'] for o in output_per_layers]
mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
counter += 1
index = counter
# sampling
Expand Down
82 changes: 68 additions & 14 deletions sat/model/cached_autoregressive_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,81 @@
from .base_model import BaseModel, BaseMixin, non_conflict
from sat.model.transformer import standard_attention, split_tensor_along_last_dim


class VectorKvCache():
def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, factor=2):
"""
1. capacity: the size of the storage space currently allocated for the cache
2. max_len: the max length of tokens
3. self.mem_size: the number of elements in the cache, like size in c++ vector
"""
self.factor = factor
if self.factor <= 1.0 :
raise ValueError("factor should be greater than 1.")
self.max_len = max_len
self.mem_size = 0
self.capacity = capacity
self.mems_kv = None # [2, num_layers, batch_size, head_nums, seq_len, size_per_head]
self.num_layers = num_layers
self.head_nums = head_nums
self.hidden_units = hidden_units
self.size_per_head = int(hidden_units / head_nums)

def append_kv(self, k, v, layer_id):
b, nh, seq_len, hidden_size = k.shape
mem_len = self.mem_size
self.mems_kv[0, layer_id, :, :, mem_len:mem_len+seq_len, :] = k
self.mems_kv[1, layer_id, :, :, mem_len:mem_len+seq_len, :] = v


def get_kv(self, layer_id, seq_len):
# return key value for attention forward
seq_len = self.mem_size + seq_len
k = self.mems_kv[0, layer_id, :, :, :seq_len, :]
v = self.mems_kv[1, layer_id, :, :, :seq_len, :]
return k, v

def get_mem_size(self):
return self.mem_size

def update_mem_size(self, seq_len):
self.mem_size += seq_len

def reMalloc(self, seq_len, batch_size, dtype, device):
new_capacity = seq_len + self.mem_size
if new_capacity > self.capacity:
new_mems_size = [2, self.num_layers, batch_size, self.head_nums, 0, self.size_per_head] # [num_layers, batch_size, head_num, seq_len, size_per_head]
if int(new_capacity * self.factor) <= self.max_len:
new_mems_size[4] = int(new_capacity * self.factor)
self.capacity = int(new_capacity * self.factor)
else:
new_mems_size[4] = self.max_len
self.capacity = self.max_len
new_mems_kv = torch.empty(*new_mems_size, dtype=dtype, device=device)
if self.mems_kv is not None :
new_mems_kv[:, :, :, :, :self.mem_size, :] = self.mems_kv
self.mems_kv = new_mems_kv

class CachedAutoregressiveMixin(BaseMixin):
def __init__(self):
super().__init__()
def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, factor=2):
super().__init__()
self.num_layers = num_layers
self.mems = VectorKvCache(num_layers, head_nums, hidden_units, max_len, capacity=capacity, factor=factor)

@non_conflict
def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
def attention_fn(self, q, k, v, mask, dropout_fn, cross_attention=False, old_impl=standard_attention,
**kw_args):
if not cross_attention:
mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
layer_id = kw_args['layer_id']
b, nh, seq_len, hidden_size = k.shape
if layer_id == 0 :
self.mems.reMalloc(seq_len, b, k.dtype, k.device)
self.mems.append_kv(k, v, layer_id)
k, v = self.mems.get_kv(layer_id, seq_len)
if layer_id == self.num_layers - 1 :
self.mems.update_mem_size(seq_len)

cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
kw_args['output_this_layer']['mem_kv'] = cache_kv

if mem is not None: # the first time, mem is None
# might change batch_size
mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
memk, memv = mem[0], mem[1]
k = torch.cat((memk, k), dim=2)
v = torch.cat((memv, v), dim=2)
return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, **kw_args)


class CachedAutoregressiveModel(BaseModel):
Expand Down