Skip to content

Commit

Permalink
Merge pull request #11 from lucidrains/fix-tests-yet-again
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains authored Jan 22, 2025
2 parents 503cac2 + 698a523 commit 0c7eafc
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions tests/test_titans.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,29 @@
from contextlib import contextmanager

import torch
from torch import nn

import pytest
from titans_pytorch import NeuralMemory
from titans_pytorch.mac_transformer import flex_attention, SegmentedAttention, MemoryAsContextTransformer

# functions

def exists(v):
return v is not None

def diff(x, y):
return (x - y).abs().amax()

@contextmanager
def torch_default_dtype(dtype):
prev_dtype = torch.get_default_dtype()
torch.set_default_dtype(dtype)
yield
torch.set_default_dtype(prev_dtype)

# main test

@pytest.mark.parametrize('seq_len', (32, 1024, 77))
@pytest.mark.parametrize('silu', (False, True))
@pytest.mark.parametrize('learned_mem_model_weights', (False, True))
Expand Down Expand Up @@ -137,6 +150,7 @@ def test_mac_sampling(sliding):
assert torch.allclose(sampled, sampled_with_cache)

@pytest.mark.parametrize('seq_len', (2, 64))
@torch_default_dtype(torch.float64)
def test_neural_mem_inference(
seq_len
):
Expand Down Expand Up @@ -167,7 +181,7 @@ def test_neural_mem_inference(

sequential_retrieved = torch.cat(sequential_retrieved, dim = -2)

assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-5)
assert torch.allclose(parallel_retrieved, sequential_retrieved, atol = 1e-6)

@pytest.mark.parametrize('seq_len', (1023, 17))
@pytest.mark.parametrize('sliding', (True, False))
Expand All @@ -194,9 +208,10 @@ def test_flex(

assert torch.allclose(out_flex, out_non_flex, atol = 1e-5)

@torch_default_dtype(torch.float64)
def test_assoc_scan():
from titans_pytorch.titans import AssocScan
import torch.nn.functional as F
torch.set_default_dtype(torch.float64)

scan = AssocScan()

Expand Down

0 comments on commit 0c7eafc

Please sign in to comment.