Commit

allow hyper connections to be defined without depth connections, then wrap neural mem with hyper connections, remove the gating wrapper indirection
lucidrains committed Jan 21, 2025
1 parent f2c36c1 commit 7df45e6
Showing 2 changed files with 38 additions and 60 deletions.
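For orientation before the per-file diffs: the commit moves the Figure-2-style memory gating out of a dedicated wrapper and into the main transformer loop. Below is a minimal, self-contained sketch of the two behaviors involved, using plain torch ops and toy shapes; the real code routes both paths through hyper-connection branches.

import torch

# toy stand-ins: `retrieved` plays the role of the neural memory read,
# `attn_out` the attention branch output; shapes are arbitrary
batch, seq_len, dim = 2, 16, 512
retrieved = torch.randn(batch, seq_len, dim)
attn_out = torch.randn(batch, seq_len, dim)
residual = torch.randn(batch, seq_len, dim)

# neural_mem_gate_attn_output = True: retrieved memories gate the attention output
gated_attn_out = attn_out * retrieved.sigmoid()

# neural_mem_gate_attn_output = False: retrieved memories are added to the residual
# instead (in the new code this is what add_branch_out_to_residual = True arranges)
residual = residual + retrieved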
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "titans-pytorch"
-version = "0.1.17"
+version = "0.1.18"
 description = "Titans"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }
@@ -29,7 +29,7 @@ dependencies = [
     "axial_positional_embedding>=0.3.9",
     "einops>=0.8.0",
     "einx>=0.3.0",
-    "hyper-connections>=0.1.8",
+    "hyper-connections>=0.1.9",
     "Ninja",
     "rotary-embedding-torch",
     "tensordict",
94 changes: 36 additions & 58 deletions titans_pytorch/mac_transformer.py
@@ -217,7 +217,8 @@ def forward_flex(
         self,
         seq,
         value_residual = None,
-        flex_attn_fn: Callable | None = None
+        flex_attn_fn: Callable | None = None,
+        output_gating = None
     ):

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))
@@ -267,17 +268,21 @@ def forward_flex(

         out = self.to_out(out)

+        if exists(output_gating):
+            out = out * output_gating
+
         return out, orig_v

     def forward(
         self,
         seq,
         value_residual = None,
         flex_attn_fn: Callable | None = None,
-        disable_flex_attn = False
+        disable_flex_attn = False,
+        output_gating = None
     ):
         if seq.is_cuda and self.use_flex_attn and not disable_flex_attn:
-            return self.forward_flex(seq, value_residual, flex_attn_fn)
+            return self.forward_flex(seq, value_residual, flex_attn_fn, output_gating = output_gating)

         assert not (exists(value_residual) ^ exists(self.to_learned_v_mix))

@@ -361,50 +366,10 @@ def forward(

         out = inverse_segment(out)

-        return out, orig_v
-
-# Attention + Neural Memory gating configuration, as depicted in Figure 2
-
-class NeuralMemoryGatingWrapper(Module):
-    def __init__(
-        self,
-        dim,
-        attn: SegmentedAttention,
-        neural_mem: NeuralMemory | None = None,
-        gate_attn_output = True
-    ):
-        super().__init__()
-        self.attn = attn
-        self.neural_mem = neural_mem
-        self.gate_attn_output = gate_attn_output
-
-    def forward(
-        self,
-        seq,
-        *args,
-        **kwargs
-    ):
-        batch, seq_len = seq.shape[:2]
-        mem = self.neural_mem
-
-        if not exists(mem):
-            return self.attn(seq, *args, **kwargs), 0.
+        if exists(output_gating):
+            out = out * output_gating

-        # initial retrieve, still should store first, it doesn't make sense not to, unless if all layers share the same neural memory
-
-        retrieved, kv_aux_loss = mem(seq, return_aux_kv_loss = True)
-
-        if not self.gate_attn_output:
-            seq = seq + retrieved
-
-        # attention
-
-        attn_out, values = self.attn(seq, *args, **kwargs)
-
-        if self.gate_attn_output:
-            attn_out = attn_out * retrieved.sigmoid()
-
-        return (attn_out, values), kv_aux_loss
+        return out, orig_v

 # MAC transformer

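The hunks above thread a new output_gating argument through SegmentedAttention.forward and forward_flex; its only effect is an elementwise multiply on the attention output just before it is returned. A hedged, self-contained stand-in (not the actual SegmentedAttention class) showing that contract:

import torch
from torch import nn

# minimal stand-in for the output_gating contract added in this commit:
# when a gate tensor is passed, the module's output is multiplied by it
# elementwise right before returning

class GatedOutputAttention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.attend = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x, output_gating = None):
        out, _ = self.attend(x, x, x)

        if output_gating is not None:
            out = out * output_gating

        return out

attn = GatedOutputAttention(512)
x = torch.randn(2, 16, 512)
gates = torch.randn(2, 16, 512).sigmoid()
assert attn(x, output_gating = gates).shape == x.shape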
@@ -494,16 +459,10 @@ def __init__(
                     **neural_memory_kwargs
                 )

-            attn = NeuralMemoryGatingWrapper(
-                dim,
-                attn = attn,
-                neural_mem = mem,
-                gate_attn_output = neural_mem_gate_attn_output
-            )
-
             ff = FeedForward(dim = dim, mult = ff_mult)

             self.layers.append(ModuleList([
+                init_hyper_conn(dim = dim, branch = mem, add_branch_out_to_residual = not neural_mem_gate_attn_output) if exists(mem) else None,
                 init_hyper_conn(dim = dim, branch = attn),
                 init_hyper_conn(dim = dim, branch = ff)
             ]))
@@ -512,6 +471,10 @@

         self.to_logits = LinearNoBias(dim, num_tokens)

+        # whether to gate the attention output with the retrieved memories
+
+        self.gate_attn_output = neural_mem_gate_attn_output
+
         # auxiliary loss on kv recon

         self.has_aux_kv_recon_loss = aux_kv_recon_loss_weight > 0.
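In the __init__ hunks above, the neural memory is now wrapped directly with init_hyper_conn, passing add_branch_out_to_residual = not neural_mem_gate_attn_output — the capability the hyper-connections>=0.1.9 bump in pyproject.toml provides. A hypothetical stand-in (not the hyper-connections API) illustrating what that flag appears to be used for here: the memory branch still reads the residual stream, but its output is handed back to the caller, to become attention output gates, instead of being summed into the residual.

import torch
from torch import nn

# hypothetical illustration only - NOT the hyper-connections library API

class BranchWithoutResidualAdd(nn.Module):
    def __init__(self, branch: nn.Module):
        super().__init__()
        self.branch = branch

    def forward(self, residual):
        out = self.branch(residual)   # branch reads from the residual stream
        return out                    # nothing is added back; the caller decides what to do with it

mem_branch = BranchWithoutResidualAdd(nn.Linear(512, 512))   # toy "memory"
x = torch.randn(2, 16, 512)
gates = mem_branch(x).sigmoid()   # later multiplied onto the attention output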
@@ -652,19 +615,34 @@ def forward(

         x = self.expand_streams(x)

-        for attn, ff in self.layers:
+        for mem, attn, ff in self.layers:

+            retrieved = None
+            attn_out_gates = None
+
+            if exists(mem):
+                retrieved, mem_kv_aux_loss = mem(x, return_aux_kv_loss = True)
+                kv_recon_losses = kv_recon_losses + mem_kv_aux_loss
+
-            (x, values), maybe_mem_kv_aux_loss = attn(
+                if self.gate_attn_output:
+                    attn_out_gates = retrieved.sigmoid()
+                else:
+                    seq = retrieved
+
+            # attention
+
+            x, values = attn(
                 x,
                 value_residual = value_residual,
                 disable_flex_attn = disable_flex_attn,
-                flex_attn_fn = flex_attn_fn
+                flex_attn_fn = flex_attn_fn,
+                output_gating = attn_out_gates
             )

-            kv_recon_losses = kv_recon_losses + maybe_mem_kv_aux_loss
-
             value_residual = default(value_residual, values)

+            # feedforward
+
             x = ff(x)

         x = self.reduce_streams(x)
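Putting the forward hunk above together: each layer now reads the neural memory first, either turns the retrieved values into sigmoid gates for the attention output or lets the memory's hyper connection fold them into the residual, then runs attention (with output_gating) and the feedforward. A simplified restatement, with the hyper-connection stream bookkeeping and the auxiliary kv-reconstruction loss left out; mem, attn and ff are toy callables, not the real branches:

import torch
from torch import nn

def layer_step(x, mem, attn, ff, gate_attn_output = True):
    # mirrors the shape of the new loop body in MemoryAsContextTransformer.forward
    attn_out_gates = None

    if mem is not None:
        retrieved = mem(x)

        if gate_attn_output:
            # retrieved memories gate the attention output
            attn_out_gates = retrieved.sigmoid()
        else:
            # in the real code the memory's hyper connection adds `retrieved`
            # to the residual stream (add_branch_out_to_residual = True);
            # done inline here for the sketch
            x = x + retrieved

    x = attn(x, output_gating = attn_out_gates)
    x = ff(x)
    return x

# toy stand-ins for the wrapped branches
dim = 64
mem = nn.Linear(dim, dim)
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

def attn(x, output_gating = None):
    out = x   # placeholder "attention"
    if output_gating is not None:
        out = out * output_gating
    return out

x = torch.randn(2, 16, dim)
assert layer_step(x, mem, attn, ff).shape == x.shape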
