diff --git a/setup.py b/setup.py index 65037813..1bf9ce78 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,9 @@ dependencies = [ - "torch>=2.0.1", "numpy", + "sympy", + "torch>=2.0.1", "tqdm", ] diff --git a/src/simulated_bifurcation/polynomial/base_multivariate_polynomial.py b/src/simulated_bifurcation/polynomial/base_multivariate_polynomial.py index 35ab6b67..f19754f3 100644 --- a/src/simulated_bifurcation/polynomial/base_multivariate_polynomial.py +++ b/src/simulated_bifurcation/polynomial/base_multivariate_polynomial.py @@ -17,7 +17,7 @@ from sympy import Poly from ..ising_core import IsingCore -from .polynomial_compiler import Order2MultivariatePolynomialCompiler as O2MPC +from .expression_compiler import ExpressionCompiler class BaseMultivariateQuadraticPolynomial(ABC): @@ -617,13 +617,13 @@ def maximize( ) @classmethod - def from_polynomial( + def from_expression( cls, - polynomial: Poly, + expression: Poly, dtype: torch.dtype = torch.float32, device: Union[str, torch.device] = "cpu", ): - constant, vector, matrix = O2MPC(polynomial).compile() + constant, vector, matrix = ExpressionCompiler(expression).compile() return cls(matrix, vector, constant, dtype, device) diff --git a/src/simulated_bifurcation/polynomial/polynomial_compiler.py b/src/simulated_bifurcation/polynomial/expression_compiler.py similarity index 94% rename from src/simulated_bifurcation/polynomial/polynomial_compiler.py rename to src/simulated_bifurcation/polynomial/expression_compiler.py index 14dbe26b..cfd180d9 100644 --- a/src/simulated_bifurcation/polynomial/polynomial_compiler.py +++ b/src/simulated_bifurcation/polynomial/expression_compiler.py @@ -5,9 +5,9 @@ from sympy import Poly -class Order2MultivariatePolynomialCompiler: +class ExpressionCompiler: def __init__(self, polynomial: Poly) -> None: - Order2MultivariatePolynomialCompiler.__check_polynomial_degree(polynomial) + ExpressionCompiler.__check_polynomial_degree(polynomial) self.polynomial = polynomial self.variables = len(self.polynomial.gens) self.order_2_tensor = torch.zeros(self.variables, self.variables) diff --git a/src/simulated_bifurcation/polynomial/integer_polynomial.py b/src/simulated_bifurcation/polynomial/integer_polynomial.py index 2181b507..a08ed44f 100644 --- a/src/simulated_bifurcation/polynomial/integer_polynomial.py +++ b/src/simulated_bifurcation/polynomial/integer_polynomial.py @@ -43,7 +43,7 @@ from ..ising_core import IsingCore from .base_multivariate_polynomial import BaseMultivariateQuadraticPolynomial -from .polynomial_compiler import Order2MultivariatePolynomialCompiler as O2MPC +from .expression_compiler import ExpressionCompiler class IntegerQuadraticPolynomial(BaseMultivariateQuadraticPolynomial): @@ -191,14 +191,14 @@ def convert_spins(self, ising: IsingCore) -> Optional[torch.Tensor]: return int_vars @classmethod - def from_polynomial( + def from_expression( cls, - polynomial: Poly, + expression: Poly, number_of_bits: int = 1, dtype: torch.dtype = torch.float32, device: Union[str, torch.device] = "cpu", ): - constant, vector, matrix = O2MPC(polynomial).compile() + constant, vector, matrix = ExpressionCompiler(expression).compile() return IntegerQuadraticPolynomial( matrix, vector, constant, number_of_bits, dtype, device ) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/__init__.py b/tests/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/test_optimization_variables.py b/tests/optimizer/test_optimization_variables.py similarity index 100% rename from tests/test_optimization_variables.py rename to tests/optimizer/test_optimization_variables.py diff --git a/tests/test_stop_window.py b/tests/optimizer/test_stop_window.py similarity index 100% rename from tests/test_stop_window.py rename to tests/optimizer/test_stop_window.py diff --git a/tests/test_symplectic_integrator.py b/tests/optimizer/test_symplectic_integrator.py similarity index 100% rename from tests/test_symplectic_integrator.py rename to tests/optimizer/test_symplectic_integrator.py diff --git a/tests/test_binary_polynomial.py b/tests/polynomial/test_binary_polynomial.py similarity index 80% rename from tests/test_binary_polynomial.py rename to tests/polynomial/test_binary_polynomial.py index 64129f14..63eb7996 100644 --- a/tests/test_binary_polynomial.py +++ b/tests/polynomial/test_binary_polynomial.py @@ -1,5 +1,6 @@ import pytest import torch +from sympy import poly, symbols from src.simulated_bifurcation.polynomial import ( BinaryPolynomial, @@ -72,3 +73,13 @@ def test_optimize_binary_polynomial(): def test_deprecation_warning(): with pytest.warns(DeprecationWarning): BinaryPolynomial(matrix, vector, constant) + + +def test_from_expression(): + x, y = symbols("x y") + expression = poly((x + y) ** 2 + x - y + 2) + polynomial = BinaryQuadraticPolynomial.from_expression(expression) + assert isinstance(polynomial, BinaryQuadraticPolynomial) + assert torch.equal(torch.tensor([[1.0, 2.0], [0.0, 1.0]]), polynomial.matrix) + assert torch.equal(torch.tensor([1.0, -1.0]), polynomial.vector) + assert torch.equal(torch.tensor(2.0), polynomial.constant) diff --git a/tests/test_polynomial_compiler.py b/tests/polynomial/test_expression_compiler.py similarity index 69% rename from tests/test_polynomial_compiler.py rename to tests/polynomial/test_expression_compiler.py index 25677f91..43d66707 100644 --- a/tests/test_polynomial_compiler.py +++ b/tests/polynomial/test_expression_compiler.py @@ -2,15 +2,13 @@ import torch from sympy import poly, symbols -from src.simulated_bifurcation.polynomial.polynomial_compiler import ( - Order2MultivariatePolynomialCompiler as O2MPC, -) +from src.simulated_bifurcation.polynomial.expression_compiler import ExpressionCompiler def test_polynomial_compiler(): x, y, z = symbols("x y z") polynomial = poly(x * y - 2 * y * z + 3 * x**2 - z + 2) - constant, vector, matrix = O2MPC(polynomial).compile() + constant, vector, matrix = ExpressionCompiler(polynomial).compile() assert torch.equal(torch.tensor([2.0]), constant) assert torch.equal(torch.tensor([0.0, 0.0, -1.0]), vector) assert torch.equal( @@ -21,6 +19,6 @@ def test_polynomial_compiler(): def test_wrong_degree(): x, y = symbols("x y") with pytest.raises(ValueError, match="Expected degree 2 polynomial, got 1."): - O2MPC(poly(x + 4 * y)) + ExpressionCompiler(poly(x + 4 * y)) with pytest.raises(ValueError, match="Expected degree 2 polynomial, got 3."): - O2MPC(poly(x**2 * y - 5 * x * y + 2 * y + 3)) + ExpressionCompiler(poly(x**2 * y - 5 * x * y + 2 * y + 3)) diff --git a/tests/test_integer_polynomial.py b/tests/polynomial/test_integer_polynomial.py similarity index 82% rename from tests/test_integer_polynomial.py rename to tests/polynomial/test_integer_polynomial.py index 8554d306..bd385dac 100644 --- a/tests/test_integer_polynomial.py +++ b/tests/polynomial/test_integer_polynomial.py @@ -1,5 +1,6 @@ import pytest import torch +from sympy import poly, symbols from src.simulated_bifurcation.polynomial import ( IntegerPolynomial, @@ -85,3 +86,14 @@ def test_optimize_integer_polynomial(): def test_deprecation_warning(): with pytest.warns(DeprecationWarning): IntegerPolynomial(matrix, vector, constant) + + +def test_from_expression(): + x, y = symbols("x y") + expression = poly((x + y) ** 2 + x - y + 2) + polynomial = IntegerQuadraticPolynomial.from_expression(expression, 3) + assert isinstance(polynomial, IntegerQuadraticPolynomial) + assert torch.equal(torch.tensor([[1.0, 2.0], [0.0, 1.0]]), polynomial.matrix) + assert torch.equal(torch.tensor([1.0, -1.0]), polynomial.vector) + assert torch.equal(torch.tensor(2.0), polynomial.constant) + assert 3 == polynomial.number_of_bits diff --git a/tests/test_polynomial.py b/tests/polynomial/test_polynomial.py similarity index 100% rename from tests/test_polynomial.py rename to tests/polynomial/test_polynomial.py diff --git a/tests/test_spin_polynomial.py b/tests/polynomial/test_spin_polynomial.py similarity index 85% rename from tests/test_spin_polynomial.py rename to tests/polynomial/test_spin_polynomial.py index 9012d73e..8cfb79c5 100644 --- a/tests/test_spin_polynomial.py +++ b/tests/polynomial/test_spin_polynomial.py @@ -1,5 +1,6 @@ import pytest import torch +from sympy import poly, symbols from src.simulated_bifurcation.polynomial import SpinPolynomial, SpinQuadraticPolynomial @@ -96,3 +97,13 @@ def check_device_and_dtype(dtype: torch.dtype): def test_deprecation_warning(): with pytest.warns(DeprecationWarning): SpinPolynomial(matrix, vector, constant) + + +def test_from_expression(): + x, y = symbols("x y") + expression = poly((x + y) ** 2 + x - y + 2) + polynomial = SpinQuadraticPolynomial.from_expression(expression) + assert isinstance(polynomial, SpinQuadraticPolynomial) + assert torch.equal(torch.tensor([[1.0, 2.0], [0.0, 1.0]]), polynomial.matrix) + assert torch.equal(torch.tensor([1.0, -1.0]), polynomial.vector) + assert torch.equal(torch.tensor(2.0), polynomial.constant)