prepare for value residual learning for deeper neural memories, but first validate value residual learning applies to linear attention as well in another repo
lucidrains committed Jan 22, 2025
1 parent 31d17ad commit 713d599
Showing 2 changed files with 42 additions and 11 deletions.
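
In value residual learning, every layer past the first mixes its freshly projected values with the values from the first layer, using a learned per-token, per-head gate. The diff below adds exactly that gate to NeuralMemory (a bias-free linear projection to one scalar per head followed by a sigmoid) and the lerp that applies it inside store_memories. A minimal standalone sketch of the mixing, with illustrative tensor names and shapes rather than code from this commit:

import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, heads, dim_head = 64, 4, 16
batch, seq_len = 2, 8

# learned gate in [0, 1] per token and per head, mirroring self.learned_value_residual below
to_mix = nn.Sequential(
    nn.Linear(dim, heads, bias = False),   # stands in for LinearNoBias(dim, heads)
    Rearrange('b n h -> b h n 1'),
    nn.Sigmoid()
)

seq          = torch.randn(batch, seq_len, dim)
values       = torch.randn(batch, heads, seq_len, dim_head)  # this layer's values
first_values = torch.randn(batch, heads, seq_len, dim_head)  # values from the first layer

mix   = to_mix(seq)                        # (batch, heads, seq_len, 1)
mixed = values.lerp(first_values, mix)     # values * (1 - mix) + first_values * mix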
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "titans-pytorch"
version = "0.1.21"
version = "0.1.22"
description = "Titans"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
51 changes: 41 additions & 10 deletions titans_pytorch/titans.py
@@ -41,6 +41,9 @@ def exists(v):
def default(v, d):
return v if exists(v) else d

def xnor(x, y):
return not (x ^ y)

def identity(t):
return t

@@ -366,6 +369,7 @@ def __init__(
pre_rmsnorm = True,
post_rmsnorm = True,
qk_rmsnorm = False,
accept_value_residual = False,
learned_mem_model_weights = True,
max_grad_norm: float | None = None,
use_accelerated_scan = False,
@@ -399,7 +403,7 @@ def __init__(

self.heads = heads

self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
self.merge_heads = Rearrange('b h n d -> b n (h d)')
self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()

@@ -448,6 +452,14 @@ def forward_and_loss(params, inputs, loss_weights, target):
self.to_keys_values = Sequential(LinearNoBias(dim, dim_inner * 2), activation)
self.store_memory_loss_fn = store_memory_loss_fn

# value residual learning

self.learned_value_residual = Sequential(
LinearNoBias(dim, heads),
Rearrange('b n h -> b h n 1'),
nn.Sigmoid()
) if accept_value_residual else None

# empty memory embed

self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
@@ -529,8 +541,11 @@ def store_memories(
seq,
past_state: tuple[dict[str, Tensor], dict[str, Tensor]],
return_aux_kv_loss = False,
chunk_size = None
chunk_size = None,
value_residual = None
):
assert xnor(exists(value_residual), exists(self.learned_value_residual))

seq_len, chunk_size = seq.shape[-2], default(chunk_size, self.store_chunk_size)

# handle edge case
@@ -585,9 +600,17 @@ def store_memories(

keys = self.k_norm(keys)

# maybe value residual learning

orig_values = values

if exists(self.learned_value_residual):
mix = self.learned_value_residual(seq)
values = values.lerp(value_residual, mix)

# take care of chunking

keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = chunk_size) for t in (keys, values))
keys, values = tuple(rearrange(t, 'b h (n c) d -> (b h n) c d', c = chunk_size) for t in (keys, values))

adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = chunk_size)

@@ -645,10 +668,12 @@ def store_memories(

last_update = updates.apply(lambda t: t[:, -1])

output = (updates, orig_values)

if not return_aux_kv_loss:
return updates
return output

return updates, aux_kv_recon_loss.mean()
return output, aux_kv_recon_loss.mean()

def retrieve_memories(
self,
@@ -698,7 +723,7 @@ def retrieve_memories(
# fetch values from memory model

curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
queries = rearrange(queries, 'b h (n c) d -> (b h n) c d', c = chunk_size)

# forward functional call

@@ -735,7 +760,8 @@ def forward(
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
return_aux_kv_loss = False,
chunk_size = None,
store_chunk_size = None
store_chunk_size = None,
return_values = False
):
batch, seq_len = seq.shape[:2]

@@ -756,13 +782,18 @@
store_seq = default(store_seq, seq)
store_chunk_size = default(store_chunk_size, chunk_size)

updates, aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)
(updates, values), aux_kv_recon_loss = self.store_memories(store_seq, past_state, chunk_size = store_chunk_size, return_aux_kv_loss = True)

past_weights, _ = past_state

retrieved = self.retrieve_memories(seq, past_weights + updates, chunk_size = chunk_size)

output = retrieved

if return_values:
output = (retrieved, values)

if not return_aux_kv_loss:
return retrieved
return output

return retrieved, aux_kv_recon_loss
return output, aux_kv_recon_loss
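
For reference, the storing-side API this commit prepares can already be exercised. A hypothetical usage sketch follows: the class and import path are assumed from titans_pytorch/titans.py, and the two-memory wiring is illustrative only, since threading value_residual through forward is left for a later commit.

import torch
from titans_pytorch.titans import NeuralMemory

seq = torch.randn(2, 128, 384)

first_mem = NeuralMemory(dim = 384, chunk_size = 64)

# the first memory can now hand back its original (pre-mix) values alongside the retrieval
retrieved, values = first_mem(seq, return_values = True)

# a deeper memory built with accept_value_residual = True must be fed those values:
# its store_memories asserts that value_residual is supplied (the xnor check above)
deeper_mem = NeuralMemory(dim = 384, chunk_size = 64, accept_value_residual = True)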
