complete neural memory sequential inference with transformer
lucidrains committed Jan 23, 2025
1 parent 886e24c commit 54d5966
Showing 3 changed files with 49 additions and 16 deletions.
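
With this change, sequential inference carries two caches side by side: the usual per-layer attention KV cache and a per-layer neural memory state, returned together by `forward` as `(logits, (kv_caches, neural_mem_caches))`. A minimal greedy-decode sketch of how that paired cache might be threaded through repeated calls — the keyword arguments beyond `cache`, the return convention on every step, and feeding the full running sequence each time are assumptions for illustration, not something this diff pins down:

import torch

# hedged sketch, not repo code: thread the paired cache through step-by-step decoding.
# `model` is assumed to be a MemoryAsContextTransformer built with
# num_longterm_mem_tokens = 0, since forward() asserts that cached inference is not
# combined with long-term memory tokens.
@torch.no_grad()
def greedy_decode(model, prompt, steps = 32):
    out = prompt    # (batch, seq) token ids
    cache = None    # becomes (kv_caches, neural_mem_caches) after the first call

    for _ in range(steps):
        logits, cache = model(out, cache = cache)                     # assumed call / return shape
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)   # greedy pick at the last position
        out = torch.cat((out, next_token), dim = -1)

    return out
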
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.24"
+version = "0.1.26"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
8 changes: 6 additions & 2 deletions tests/test_titans.py
@@ -124,7 +124,11 @@ def test_mac(
     assert logits.shape == (1, seq_len, 256)
 
 @pytest.mark.parametrize('sliding', (False, True))
-def test_mac_sampling(sliding):
+@pytest.mark.parametrize('mem_layers', ((), None, (4,)))
+def test_mac_sampling(
+    sliding,
+    mem_layers
+):
     transformer = MemoryAsContextTransformer(
         num_tokens = 256,
         dim = 256,
@@ -133,7 +137,7 @@ def test_mac_sampling(sliding):
         num_persist_mem_tokens = 4,
         num_longterm_mem_tokens = 0,
         sliding_window_attn = sliding,
-        neural_memory_layers = (),
+        neural_memory_layers = mem_layers,
         neural_mem_gate_attn_output = False
     )
 
55 changes: 42 additions & 13 deletions titans_pytorch/mac_transformer.py
@@ -510,10 +510,7 @@ def __init__(
 
         layers = tuple(range(1, depth + 1))
 
-        if not exists(neural_memory_layers):
-            neural_memory_layers = layers if has_longterm_mems else ()
-
-        assert not (num_longterm_mem_tokens > 0 and len(neural_memory_layers) == 0), 'empty `neural_memory_layers` when longterm memory tokens are present'
+        neural_memory_layers = default(neural_memory_layers, layers)
 
         # mem, attn, and feedforward layers
 
@@ -535,20 +532,23 @@
             )
 
             mem = None
+            mem_hyper_conn = None
 
             if layer in neural_memory_layers:
-                assert has_longterm_mems, '`num_longterm_mem_tokens` must be greater than 0'
+                mem_hyper_conn = init_hyper_conn(dim = dim, add_branch_out_to_residual = not neural_mem_gate_attn_output)
+
                 mem = NeuralMemory(
                     dim = dim,
                     chunk_size = self.neural_memory_segment_len,
                     **neural_memory_kwargs
                 )
 
 
             ff = FeedForward(dim = dim, mult = ff_mult)
 
             self.layers.append(ModuleList([
-                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
+                mem_hyper_conn,
+                mem,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
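
Net effect of the two constructor hunks above: `neural_memory_layers = None` now simply defaults to every layer, an explicit `()` disables neural memory outright, and neither choice is tied to `num_longterm_mem_tokens` any longer (both of the old asserts requiring long-term memory tokens are gone). Each chosen layer also gets its hyper-connection built up front and stored alongside its `NeuralMemory` module, so the forward pass can route into memory explicitly. A tiny illustration of the resolution rule, assuming the usual `exists`/`default` helpers defined elsewhere in the file:

# illustrative only -- mirrors how the constructor now resolves neural_memory_layers
def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

depth = 8
layers = tuple(range(1, depth + 1))

print(default(None, layers))   # (1, 2, 3, 4, 5, 6, 7, 8) -> neural memory on every layer
print(default((), layers))     # ()                       -> no neural memory
print(default((4,), layers))   # (4,)                     -> neural memory on layer 4 only
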
@@ -691,8 +691,18 @@ def forward(
         # kv caching
 
         is_inferencing = exists(cache)
-        cache = iter(default(cache, []))
         assert not (is_inferencing and self.num_longterm_mem_tokens > 0)
+
+        if not exists(cache):
+            cache = (None, None)
+
+        kv_caches, neural_mem_caches = cache
+
+        kv_caches = iter(default(kv_caches, []))
+        neural_mem_caches = iter(default(neural_mem_caches, []))
+
+        next_kv_caches = []
+        next_neural_mem_caches = []
 
         # value residual
 
@@ -711,21 +721,37 @@
 
         x = self.expand_streams(x)
 
-        for mem, attn, ff in self.layers:
+        for mem_hyper_conn, mem, attn, ff in self.layers:
 
             retrieved = None
             attn_out_gates = None
+            next_neural_mem_cache = None
 
             # maybe neural memory
 
             if exists(mem):
-                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
-                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
 
+                mem_input, add_residual = mem_hyper_conn(x)
+
+                if not is_inferencing:
+                    retrieved, mem_kv_aux_loss = mem(
+                        mem_input,
+                        return_aux_kv_loss = True
+                    )
+
+                    kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
+                else:
+                    retrieved, next_neural_mem_cache = mem.forward_inference(
+                        mem_input,
+                        seq_index = seq_len - 1,
+                        state = next(neural_mem_caches, None)
+                    )
+
                 if self.gate_attn_output:
                     attn_out_gates = retrieved.sigmoid()
                 else:
-                    seq = retrieved
+                    x = add_residual(retrieved)
 
             # attention
 
@@ -735,12 +761,15 @@
                 disable_flex_attn = disable_flex_attn,
                 flex_attn_fn = flex_attn_fn,
                 output_gating = attn_out_gates,
-                cache = next(cache, None)
+                cache = next(kv_caches, None)
             )
 
             value_residual = default(value_residual, values)
 
+            # caches
+
             next_kv_caches.append(next_kv_cache)
+            next_neural_mem_caches.append(next_neural_mem_cache)
 
             # feedforward
 
@@ -775,7 +804,7 @@ def forward(
             if not self.sliding_window_attn and divisible_by(seq_len_with_mem, attn_window_size):
                 next_kv_caches = next_kv_caches[..., 0:0, :]
 
-            return logits, next_kv_caches
+            return logits, (next_kv_caches, next_neural_mem_caches)
 
         ar_loss = F.cross_entropy(rearrange(logits, 'b n l -> b l n'), labels)
 
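One detail worth noting in the forward changes above: both halves of the incoming cache are drained with `iter(...)` and `next(..., None)`, and a `next_*` entry is appended for every layer — even `None` for layers without a neural memory — so the cache lists stay index-aligned with `self.layers` from one decoding step to the next. A small standalone illustration of that bookkeeping pattern, with toy values in place of real memory states:

# toy illustration of the per-layer cache threading pattern used in forward()
layers_have_mem = [False, True, False, True]

prev_mem_caches = iter([None, 'mem-state-1', None, 'mem-state-3'])
next_mem_caches = []

for has_mem in layers_have_mem:
    state = next(prev_mem_caches, None)                    # None for layers without memory
    new_state = f'updated({state})' if has_mem else None
    next_mem_caches.append(new_state)                      # append even when None

print(next_mem_caches)
# [None, 'updated(mem-state-1)', None, 'updated(mem-state-3)']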
