From 3c8c81a1b305ed8a7ab8593b4a6110e04d4819ad Mon Sep 17 00:00:00 2001
From: zhaolei
Date: Sun, 9 Jul 2023 22:51:11 +0800
Subject: [PATCH 1/3] add vector-kvcache

---
 examples/chatglm/chat_sat.py              |  2 +-
 sat/generation/autoregressive_sampling.py |  2 -
 sat/model/cached_autoregressive_model.py  | 87 +++++++++++++++++++----
 3 files changed, 74 insertions(+), 17 deletions(-)

diff --git a/examples/chatglm/chat_sat.py b/examples/chatglm/chat_sat.py
index e2cf471d..2ba426c7 100644
--- a/examples/chatglm/chat_sat.py
+++ b/examples/chatglm/chat_sat.py
@@ -136,7 +136,7 @@ def chat(model, tokenizer,
         use_gpu_initialization=True,
     ))
     model = model.eval()
-    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
+    model.add_mixin('auto-regressive', CachedAutoregressiveMixin(args.num_layers, args.head_nums, args.hidden_units, args.max_length))

     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
     history = None
diff --git a/sat/generation/autoregressive_sampling.py b/sat/generation/autoregressive_sampling.py
index a5cbe4da..cf32f8a5 100755
--- a/sat/generation/autoregressive_sampling.py
+++ b/sat/generation/autoregressive_sampling.py
@@ -113,8 +113,6 @@ def filling_sequence(
             log_attention_weights=log_attention_weights_part,
             **kw_args
         )
-        mem_kv = [o['mem_kv'] for o in output_per_layers]
-        mems = update_mems(mem_kv, mems, max_memory_length=max_memory_length)
         counter += 1
         index = counter
         # sampling
diff --git a/sat/model/cached_autoregressive_model.py b/sat/model/cached_autoregressive_model.py
index 03b94281..6b657f74 100755
--- a/sat/model/cached_autoregressive_model.py
+++ b/sat/model/cached_autoregressive_model.py
@@ -16,27 +16,86 @@
 from .base_model import BaseModel, BaseMixin, non_conflict
 from sat.model.transformer import standard_attention, split_tensor_along_last_dim

+
+class VectorKvCache():
+    def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, factor=2):
+        """
+        1. capacity: the size of the storage space currently allocated for the cache
+        2. max_len: the max length of tokens
+        3. mem_size: the number of elements currently in the cache, like size in a C++ std::vector
+        """
+        self.factor = factor
+        if self.factor <= 1.0 :
+            raise ValueError("factor should be greater than 1.")
+        self.max_len = max_len
+        self.mem_size = 0
+        self.capacity = capacity
+        self.mems_k = None
+        self.mems_v = None
+        self.num_layers = num_layers
+        self.head_nums = head_nums
+        self.hidden_units = hidden_units
+
+    def append_kv(self, k, v, layer_id):
+        b, nh, seq_len, hidden_size = k.shape
+        mem_len = self.mem_size
+        self.mems_k[layer_id][:, :, mem_len:mem_len+seq_len, :] = k
+        self.mems_v[layer_id][:, :, mem_len:mem_len+seq_len, :] = v
+
+    def get_kv(self, layer_id, seq_len):
+        # return key value for attention forward
+        mem_k = self.mems_k[layer_id]
+        mem_v = self.mems_v[layer_id]
+        seq_len = self.mem_size + seq_len
+        k = mem_k[:, :, :seq_len, :]
+        v = mem_v[:, :, :seq_len, :]
+        return k, v
+
+    def get_mem_size(self):
+        return self.mem_size
+
+    def update_mem_size(self, seq_len):
+        self.mem_size += seq_len
+
+    def reMalloc(self, seq_len, batch_size, dtype, device):
+        new_capacity = seq_len + self.mem_size
+        if new_capacity > self.capacity:
+            new_mems_size = [self.num_layers, batch_size, self.head_nums, 0, self.hidden_units] # [num_layers, batch_size, head_num, seq_len, size_per_head]
+            if int(new_capacity * self.factor) <= self.max_len:
+                new_mems_size[3] = int(new_capacity * self.factor)
+                self.capacity = int(new_capacity * self.factor)
+            else:
+                new_mems_size[3] = self.max_len
+                self.capacity = self.max_len
+            new_mems_k = torch.empty(*new_mems_size, dtype=dtype, device=device)
+            new_mems_v = torch.empty(*new_mems_size, dtype=dtype, device=device)
+            if self.mems_k is not None and self.mems_v is not None :
+                new_mems_k[:, :, :, :self.mem_size, :] = self.mems_k
+                new_mems_v[:, :, :, :self.mem_size, :] = self.mems_v
+            self.mems_k = new_mems_k
+            self.mems_v = new_mems_v
+
 class CachedAutoregressiveMixin(BaseMixin):
-    def __init__(self):
-        super().__init__()
+    def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, factor=2):
+        super().__init__()
+        self.num_layers = num_layers
+        self.mems = vector_kv_cache = VectorKvCache(num_layers, head_nums, hidden_units, max_len, capacity=capacity, factor=factor)

     @non_conflict
-    def attention_fn(self, q, k, v, mask, dropout_fn, mems=None, cross_attention=False, old_impl=standard_attention,
+    def attention_fn(self, q, k, v, mask, dropout_fn, cross_attention=False, old_impl=standard_attention,
                      **kw_args):
         if not cross_attention:
-            mem = mems[kw_args['layer_id']] if mems is not None else None # 2, batch, head, seqlen, hidden_size
+            layer_id = kw_args['layer_id']
             b, nh, seq_len, hidden_size = k.shape
+            if layer_id == 0 :
+                self.mems.reMalloc(seq_len, b, k.dtype, k.device)
+            self.mems.append_kv(k, v, layer_id)
+            k, v = self.mems.get_kv(layer_id)
+            if layer_id == self.num_layers - 1 :
+                self.mems.update_mems_size(seq_len)

-            cache_kv = torch.stack((k, v)).permute(1, 3, 0, 2, 4).detach().contiguous().view(b, seq_len, nh * hidden_size * 2)
-            kw_args['output_this_layer']['mem_kv'] = cache_kv
-
-            if mem is not None: # the first time, mem is None
-                # might change batch_size
-                mem = mem.expand(b, -1, -1).reshape(b, mem.shape[1], 2, nh, hidden_size).permute(2, 0, 3, 1, 4)
-                memk, memv = mem[0], mem[1]
-                k = torch.cat((memk, k), dim=2)
-                v = torch.cat((memv, v), dim=2)
-        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, mems=mems, **kw_args)
+        return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, **kw_args)

 class CachedAutoregressiveModel(BaseModel):
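Note: the VectorKvCache introduced in PATCH 1/3 grows like a C++ std::vector. reMalloc() over-allocates by `factor` whenever the requested length would exceed the current capacity (clamped to max_len), then copies the live prefix into the fresh buffer, so the total bytes copied over an n-token generation stay amortized O(n), versus the O(n^2) copying of the torch.cat-per-step scheme that the patch removes. Below is a minimal standalone sketch of the same growth policy; GrowableBuffer and its method names are illustrative and do not exist in sat:

    import torch

    class GrowableBuffer:
        def __init__(self, max_len, factor=2):
            if factor <= 1.0:
                raise ValueError("factor should be greater than 1.")
            self.max_len = max_len    # hard cap on capacity, like VectorKvCache.max_len
            self.factor = factor      # geometric over-allocation multiplier
            self.size = 0             # elements in use, like mem_size
            self.capacity = 0         # elements allocated
            self.buf = None

        def reserve(self, extra):
            needed = self.size + extra
            if needed <= self.capacity:
                return                                    # enough room: no copy
            new_cap = min(int(needed * self.factor), self.max_len)
            new_buf = torch.empty(new_cap)
            if self.buf is not None:                      # move the live prefix, like reMalloc
                new_buf[:self.size] = self.buf[:self.size]
            self.buf, self.capacity = new_buf, new_cap

        def append(self, x):
            self.reserve(x.numel())
            self.buf[self.size:self.size + x.numel()] = x
            self.size += x.numel()

    buf = GrowableBuffer(max_len=10000)
    for step in range(5):
        buf.append(torch.randn(3))
        print(step, buf.size, buf.capacity)  # capacity grows geometrically: 6, 6, 18, 18, 18
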
From 6c051014cfe02da96c4415dd020ab9b64ca72ecf Mon Sep 17 00:00:00 2001
From: zhaolei
Date: Mon, 10 Jul 2023 02:05:26 +0800
Subject: [PATCH 2/3] add vector kvcache

---
 examples/chatglm/chat_sat.py             | 2 +-
 sat/model/cached_autoregressive_model.py | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/examples/chatglm/chat_sat.py b/examples/chatglm/chat_sat.py
index 2ba426c7..b63779cf 100644
--- a/examples/chatglm/chat_sat.py
+++ b/examples/chatglm/chat_sat.py
@@ -136,7 +136,7 @@ def chat(model, tokenizer,
         use_gpu_initialization=True,
     ))
     model = model.eval()
-    model.add_mixin('auto-regressive', CachedAutoregressiveMixin(args.num_layers, args.head_nums, args.hidden_units, args.max_length))
+    model.add_mixin('auto-regressive', CachedAutoregressiveMixin(model_args.num_layers, model_args.num_attention_heads, model_args.hidden_size, args.max_length))

     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
     history = None
diff --git a/sat/model/cached_autoregressive_model.py b/sat/model/cached_autoregressive_model.py
index 6b657f74..3719f992 100755
--- a/sat/model/cached_autoregressive_model.py
+++ b/sat/model/cached_autoregressive_model.py
@@ -35,6 +35,7 @@ def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, fac
         self.num_layers = num_layers
         self.head_nums = head_nums
         self.hidden_units = hidden_units
+        self.size_per_head = int(hidden_units / head_nums)

     def append_kv(self, k, v, layer_id):
         b, nh, seq_len, hidden_size = k.shape
@@ -61,7 +62,7 @@ def update_mem_size(self, seq_len):
     def reMalloc(self, seq_len, batch_size, dtype, device):
         new_capacity = seq_len + self.mem_size
         if new_capacity > self.capacity:
-            new_mems_size = [self.num_layers, batch_size, self.head_nums, 0, self.hidden_units] # [num_layers, batch_size, head_num, seq_len, size_per_head]
+            new_mems_size = [self.num_layers, batch_size, self.head_nums, 0, self.size_per_head] # [num_layers, batch_size, head_num, seq_len, size_per_head]
             if int(new_capacity * self.factor) <= self.max_len:
                 new_mems_size[3] = int(new_capacity * self.factor)
                 self.capacity = int(new_capacity * self.factor)
@@ -80,7 +81,7 @@ class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, factor=2):
         super().__init__()
         self.num_layers = num_layers
-        self.mems = vector_kv_cache = VectorKvCache(num_layers, head_nums, hidden_units, max_len, capacity=capacity, factor=factor)
+        self.mems = VectorKvCache(num_layers, head_nums, hidden_units, max_len, capacity=capacity, factor=factor)

     @non_conflict
     def attention_fn(self, q, k, v, mask, dropout_fn, cross_attention=False, old_impl=standard_attention,
@@ -91,9 +92,9 @@ def attention_fn(self, q, k, v, mask, dropout_fn, cross_attention=False, old_imp
             if layer_id == 0 :
                 self.mems.reMalloc(seq_len, b, k.dtype, k.device)
             self.mems.append_kv(k, v, layer_id)
-            k, v = self.mems.get_kv(layer_id)
+            k, v = self.mems.get_kv(layer_id, seq_len)
             if layer_id == self.num_layers - 1 :
-                self.mems.update_mems_size(seq_len)
+                self.mems.update_mem_size(seq_len)

         return old_impl(q, k, v, mask, dropout_fn, cross_attention=cross_attention, **kw_args)
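Note: PATCH 2/3 fixes the last cache dimension from hidden_units down to size_per_head = hidden_units / head_nums, shrinking each allocation by a factor of head_nums, and threads seq_len through get_kv(). As a rough sanity check of the fixed shape's footprint, here is a back-of-the-envelope calculation assuming ChatGLM-6B-like dimensions (28 layers, 32 heads, hidden size 4096) and fp16; the concrete numbers are illustrative, not taken from the patches:

    num_layers, head_nums, hidden_units = 28, 32, 4096     # ChatGLM-6B-like, for illustration
    size_per_head = hidden_units // head_nums              # 128, as computed in patch 2
    batch_size, capacity, bytes_per_elem = 1, 2048, 2      # fp16

    # k and v buffers together: 2 x [num_layers, batch, head_nums, capacity, size_per_head]
    total_bytes = 2 * num_layers * batch_size * head_nums * capacity * size_per_head * bytes_per_elem
    print(f"{total_bytes / 2**30:.2f} GiB")                # 0.88 GiB at capacity 2048
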
From cb33e5544de2cecf37e6d9c1431c1422966e7724 Mon Sep 17 00:00:00 2001
From: zhaolei
Date: Tue, 11 Jul 2023 10:03:02 +0800
Subject: [PATCH 3/3] add vector kvcache

---
 examples/chatglm/chat_sat.py             |  5 ++--
 sat/model/cached_autoregressive_model.py | 30 ++++++++++--------------
 2 files changed, 14 insertions(+), 21 deletions(-)

diff --git a/examples/chatglm/chat_sat.py b/examples/chatglm/chat_sat.py
index b63779cf..a226ecf9 100644
--- a/examples/chatglm/chat_sat.py
+++ b/examples/chatglm/chat_sat.py
@@ -89,8 +89,7 @@ def chat(model, tokenizer,
         [inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0
     )
     # ---------------
-    strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=0, end_tokens=[tokenizer.eos_token_id])
-    strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=0, end_tokens=[tokenizer.eos_token_id], num_beams=num_beams, consider_end=True)
+    strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id])
     output = filling_sequence(
         model, seq,
         batch_size=1,
@@ -136,7 +135,7 @@ def chat(model, tokenizer,
         use_gpu_initialization=True,
     ))
     model = model.eval()
-    model.add_mixin('auto-regressive', CachedAutoregressiveMixin(model_args.num_layers, model_args.num_attention_heads, model_args.hidden_size, args.max_length))
+    model.add_mixin('auto-regressive', CachedAutoregressiveMixin(model_args.num_layers, model_args.num_attention_heads, model_args.hidden_size, 10000))

     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
     history = None
diff --git a/sat/model/cached_autoregressive_model.py b/sat/model/cached_autoregressive_model.py
index 3719f992..973b6bad 100755
--- a/sat/model/cached_autoregressive_model.py
+++ b/sat/model/cached_autoregressive_model.py
@@ -30,8 +30,7 @@ def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, fac
         self.max_len = max_len
         self.mem_size = 0
         self.capacity = capacity
-        self.mems_k = None
-        self.mems_v = None
+        self.mems_kv = None # [2, num_layers, batch_size, head_nums, seq_len, size_per_head]
         self.num_layers = num_layers
         self.head_nums = head_nums
         self.hidden_units = hidden_units
@@ -40,17 +39,15 @@ def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, fac
     def append_kv(self, k, v, layer_id):
         b, nh, seq_len, hidden_size = k.shape
         mem_len = self.mem_size
-        self.mems_k[layer_id][:, :, mem_len:mem_len+seq_len, :] = k
-        self.mems_v[layer_id][:, :, mem_len:mem_len+seq_len, :] = v
+        self.mems_kv[0, layer_id, :, :, mem_len:mem_len+seq_len, :] = k
+        self.mems_kv[1, layer_id, :, :, mem_len:mem_len+seq_len, :] = v

     def get_kv(self, layer_id, seq_len):
         # return key value for attention forward
-        mem_k = self.mems_k[layer_id]
-        mem_v = self.mems_v[layer_id]
         seq_len = self.mem_size + seq_len
-        k = mem_k[:, :, :seq_len, :]
-        v = mem_v[:, :, :seq_len, :]
+        k = self.mems_kv[0, layer_id, :, :, :seq_len, :]
+        v = self.mems_kv[1, layer_id, :, :, :seq_len, :]
         return k, v

     def get_mem_size(self):
@@ -62,20 +59,17 @@ def update_mem_size(self, seq_len):
     def reMalloc(self, seq_len, batch_size, dtype, device):
         new_capacity = seq_len + self.mem_size
         if new_capacity > self.capacity:
-            new_mems_size = [self.num_layers, batch_size, self.head_nums, 0, self.size_per_head] # [num_layers, batch_size, head_num, seq_len, size_per_head]
+            new_mems_size = [2, self.num_layers, batch_size, self.head_nums, 0, self.size_per_head] # [2, num_layers, batch_size, head_num, seq_len, size_per_head]
             if int(new_capacity * self.factor) <= self.max_len:
-                new_mems_size[3] = int(new_capacity * self.factor)
+                new_mems_size[4] = int(new_capacity * self.factor)
                 self.capacity = int(new_capacity * self.factor)
             else:
-                new_mems_size[3] = self.max_len
+                new_mems_size[4] = self.max_len
                 self.capacity = self.max_len
-            new_mems_k = torch.empty(*new_mems_size, dtype=dtype, device=device)
-            new_mems_v = torch.empty(*new_mems_size, dtype=dtype, device=device)
-            if self.mems_k is not None and self.mems_v is not None :
-                new_mems_k[:, :, :, :self.mem_size, :] = self.mems_k
-                new_mems_v[:, :, :, :self.mem_size, :] = self.mems_v
-            self.mems_k = new_mems_k
-            self.mems_v = new_mems_v
+            new_mems_kv = torch.empty(*new_mems_size, dtype=dtype, device=device)
+            if self.mems_kv is not None :
+                new_mems_kv[:, :, :, :, :self.mem_size, :] = self.mems_kv
+            self.mems_kv = new_mems_kv

 class CachedAutoregressiveMixin(BaseMixin):
     def __init__(self, num_layers, head_nums, hidden_units, max_len, capacity=0, factor=2):
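Note: after PATCH 3/3 the keys and values share a single [2, num_layers, batch_size, head_nums, seq_len, size_per_head] tensor, so each growth event performs one allocation and one copy instead of two. A minimal sketch of the per-step call sequence that attention_fn drives, run directly against VectorKvCache with toy dimensions; it assumes the patched sat package is importable, and reuses one k/v tensor across layers purely for brevity:

    import torch
    from sat.model.cached_autoregressive_model import VectorKvCache

    num_layers, head_nums, hidden_units, max_len = 2, 4, 64, 512
    cache = VectorKvCache(num_layers, head_nums, hidden_units, max_len)

    for seq_len in (10, 1, 1):          # prefill with 10 tokens, then two decode steps
        b = 1
        k = torch.randn(b, head_nums, seq_len, hidden_units // head_nums)
        v = torch.randn_like(k)
        cache.reMalloc(seq_len, b, k.dtype, k.device)    # grown once per step (layer_id == 0)
        for layer_id in range(num_layers):
            cache.append_kv(k, v, layer_id)
            full_k, full_v = cache.get_kv(layer_id, seq_len)
        cache.update_mem_size(seq_len)                   # advanced after the last layer
        print(cache.get_mem_size(), tuple(full_k.shape)) # 10, 11, 12 cached tokens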