Skip to content

Commit

Permalink
Merge pull request #786 from ACEsuit/develop
Browse files Browse the repository at this point in the history
make cueq optional dep and add special test
  • Loading branch information
ilyes319 authored Jan 15, 2025
2 parents fca3022 + 140d250 commit 0bcbdb6
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 9 deletions.
33 changes: 27 additions & 6 deletions .github/workflows/unittest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,46 @@ on:
branches: [main]

jobs:
pytest-container:
pytest-general:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"

- name: Install requirements
- name: Install requirements (general tests)
run: |
pip install -U pip
pip install .[dev]
- name: Log installed environment
- name: Log installed environment (general tests)
run: |
python3 -m pip freeze
- name: Run general unit tests
run: |
pytest tests --ignore=tests/test_cueq.py
pytest-cueq:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"

- name: Install requirements (with cueq)
run: |
pip install -U pip
pip install ".[dev, cueq]"
- name: Log installed environment (with cueq)
run: |
python3 -m pip freeze
- name: Run unit tests
- name: Run cueq-specific tests
run: |
pytest tests
pytest tests/test_cueq.py tests/test_calculator.py
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ install_requires =
GitPython
pyYAML
tqdm
cuequivariance-torch
# for plotting:
matplotlib
pandas
Expand Down Expand Up @@ -60,5 +59,6 @@ dev =
pytest-benchmark
pylint
schedulefree = schedulefree
cueq = cuequivariance-torch
cueq-cuda-11 = cuequivariance-ops-torch-cu11
cueq-cuda-12 = cuequivariance-ops-torch-cu12
1 change: 1 addition & 0 deletions tests/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ def test_calculator_descriptor(fitting_configs, trained_equivariant_model):
assert not np.allclose(desc, desc_rotated, atol=1e-6)


@pytest.mark.skipif(not CUET_AVAILABLE, reason="cuequivariance not installed")
def test_calculator_descriptor_cueq(fitting_configs, trained_equivariant_model_cueq):
at = fitting_configs[2].copy()
at_rotated = fitting_configs[2].copy()
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cueq.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def test_bidirectional_conversion(
loss_e3nn_back.backward()

# Compare gradients for all conversions
tol = 1e-4 if default_dtype == torch.float32 else 1e-8
tol = 1e-4 if default_dtype == torch.float32 else 1e-7

def print_gradient_diff(name1, p1, name2, p2, conv_type):
if p1.grad is not None and p1.grad.shape == p2.grad.shape:
Expand All @@ -152,7 +152,7 @@ def print_gradient_diff(name1, p1, name2, p2, conv_type):
print(
f"{conv_type} - Parameter {name1}/{name2}, Max error: {error.max()}"
)
torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=1e-10)
torch.testing.assert_close(p1.grad, p2.grad, atol=tol, rtol=tol)

# E3nn to CuEq gradients
for (name_e3nn, p_e3nn), (name_cueq, p_cueq) in zip(
Expand Down

0 comments on commit 0bcbdb6

Please sign in to comment.