Skip to content

Commit

Permalink
Added Greek letters to valid einsum symbols, and a test that ensures …
Browse files Browse the repository at this point in the history
…all einsum symbols get parsed properly
  • Loading branch information
jwjeffr committed Jan 25, 2025
1 parent 2bca00c commit ee4ac11
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sparse/numba_backend/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
normalize_axis,
)

_EINSUM_SYMBOLS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
_EINSUM_SYMBOLS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZαβγδϵζηθικλμνξπρστυϕχψωΓΔΕΘΛΠΡΣΦΨΩ"
_EINSUM_SYMBOLS_SET = set(_EINSUM_SYMBOLS)


Expand Down
10 changes: 10 additions & 0 deletions sparse/numba_backend/tests/test_einsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@
]


@pytest.mark.parametrize("symbols", sparse.numba_backend._common._EINSUM_SYMBOLS_SET)
def test_symbols(symbols):

arr = sparse.random(shape=10, density=1.0)

# ensure we can use any of the defined symbols
for symbol in symbols:
sparse.einsum(f"{symbol}->{symbol}", arr)


@pytest.mark.parametrize("subscripts", einsum_cases)
@pytest.mark.parametrize("density", [0.1, 1.0])
def test_einsum(subscripts, density):
Expand Down

0 comments on commit ee4ac11

Please sign in to comment.