Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bqth29 committed Oct 18, 2023
1 parent 3a7fe7a commit c54c5cc
Show file tree
Hide file tree
Showing 14 changed files with 50 additions and 17 deletions.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@


dependencies = [
"torch>=2.0.1",
"numpy",
"sympy",
"torch>=2.0.1",
"tqdm",
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from sympy import Poly

from ..ising_core import IsingCore
from .polynomial_compiler import Order2MultivariatePolynomialCompiler as O2MPC
from .expression_compiler import ExpressionCompiler


class BaseMultivariateQuadraticPolynomial(ABC):
Expand Down Expand Up @@ -617,13 +617,13 @@ def maximize(
)

@classmethod
def from_polynomial(
def from_expression(
cls,
polynomial: Poly,
expression: Poly,
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
):
constant, vector, matrix = O2MPC(polynomial).compile()
constant, vector, matrix = ExpressionCompiler(expression).compile()
return cls(matrix, vector, constant, dtype, device)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from sympy import Poly


class Order2MultivariatePolynomialCompiler:
class ExpressionCompiler:
def __init__(self, polynomial: Poly) -> None:
Order2MultivariatePolynomialCompiler.__check_polynomial_degree(polynomial)
ExpressionCompiler.__check_polynomial_degree(polynomial)
self.polynomial = polynomial
self.variables = len(self.polynomial.gens)
self.order_2_tensor = torch.zeros(self.variables, self.variables)
Expand Down
8 changes: 4 additions & 4 deletions src/simulated_bifurcation/polynomial/integer_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@

from ..ising_core import IsingCore
from .base_multivariate_polynomial import BaseMultivariateQuadraticPolynomial
from .polynomial_compiler import Order2MultivariatePolynomialCompiler as O2MPC
from .expression_compiler import ExpressionCompiler


class IntegerQuadraticPolynomial(BaseMultivariateQuadraticPolynomial):
Expand Down Expand Up @@ -191,14 +191,14 @@ def convert_spins(self, ising: IsingCore) -> Optional[torch.Tensor]:
return int_vars

@classmethod
def from_polynomial(
def from_expression(
cls,
polynomial: Poly,
expression: Poly,
number_of_bits: int = 1,
dtype: torch.dtype = torch.float32,
device: Union[str, torch.device] = "cpu",
):
constant, vector, matrix = O2MPC(polynomial).compile()
constant, vector, matrix = ExpressionCompiler(expression).compile()
return IntegerQuadraticPolynomial(
matrix, vector, constant, number_of_bits, dtype, device
)
Expand Down
Empty file removed tests/__init__.py
Empty file.
Empty file removed tests/models/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
from sympy import poly, symbols

from src.simulated_bifurcation.polynomial import (
BinaryPolynomial,
Expand Down Expand Up @@ -72,3 +73,13 @@ def test_optimize_binary_polynomial():
def test_deprecation_warning():
with pytest.warns(DeprecationWarning):
BinaryPolynomial(matrix, vector, constant)


def test_from_expression():
x, y = symbols("x y")
expression = poly((x + y) ** 2 + x - y + 2)
polynomial = BinaryQuadraticPolynomial.from_expression(expression)
assert isinstance(polynomial, BinaryQuadraticPolynomial)
assert torch.equal(torch.tensor([[1.0, 2.0], [0.0, 1.0]]), polynomial.matrix)
assert torch.equal(torch.tensor([1.0, -1.0]), polynomial.vector)
assert torch.equal(torch.tensor(2.0), polynomial.constant)
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import torch
from sympy import poly, symbols

from src.simulated_bifurcation.polynomial.polynomial_compiler import (
Order2MultivariatePolynomialCompiler as O2MPC,
)
from src.simulated_bifurcation.polynomial.expression_compiler import ExpressionCompiler


def test_polynomial_compiler():
x, y, z = symbols("x y z")
polynomial = poly(x * y - 2 * y * z + 3 * x**2 - z + 2)
constant, vector, matrix = O2MPC(polynomial).compile()
constant, vector, matrix = ExpressionCompiler(polynomial).compile()
assert torch.equal(torch.tensor([2.0]), constant)
assert torch.equal(torch.tensor([0.0, 0.0, -1.0]), vector)
assert torch.equal(
Expand All @@ -21,6 +19,6 @@ def test_polynomial_compiler():
def test_wrong_degree():
x, y = symbols("x y")
with pytest.raises(ValueError, match="Expected degree 2 polynomial, got 1."):
O2MPC(poly(x + 4 * y))
ExpressionCompiler(poly(x + 4 * y))
with pytest.raises(ValueError, match="Expected degree 2 polynomial, got 3."):
O2MPC(poly(x**2 * y - 5 * x * y + 2 * y + 3))
ExpressionCompiler(poly(x**2 * y - 5 * x * y + 2 * y + 3))
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
from sympy import poly, symbols

from src.simulated_bifurcation.polynomial import (
IntegerPolynomial,
Expand Down Expand Up @@ -85,3 +86,14 @@ def test_optimize_integer_polynomial():
def test_deprecation_warning():
with pytest.warns(DeprecationWarning):
IntegerPolynomial(matrix, vector, constant)


def test_from_expression():
x, y = symbols("x y")
expression = poly((x + y) ** 2 + x - y + 2)
polynomial = IntegerQuadraticPolynomial.from_expression(expression, 3)
assert isinstance(polynomial, IntegerQuadraticPolynomial)
assert torch.equal(torch.tensor([[1.0, 2.0], [0.0, 1.0]]), polynomial.matrix)
assert torch.equal(torch.tensor([1.0, -1.0]), polynomial.vector)
assert torch.equal(torch.tensor(2.0), polynomial.constant)
assert 3 == polynomial.number_of_bits
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
import torch
from sympy import poly, symbols

from src.simulated_bifurcation.polynomial import SpinPolynomial, SpinQuadraticPolynomial

Expand Down Expand Up @@ -96,3 +97,13 @@ def check_device_and_dtype(dtype: torch.dtype):
def test_deprecation_warning():
with pytest.warns(DeprecationWarning):
SpinPolynomial(matrix, vector, constant)


def test_from_expression():
x, y = symbols("x y")
expression = poly((x + y) ** 2 + x - y + 2)
polynomial = SpinQuadraticPolynomial.from_expression(expression)
assert isinstance(polynomial, SpinQuadraticPolynomial)
assert torch.equal(torch.tensor([[1.0, 2.0], [0.0, 1.0]]), polynomial.matrix)
assert torch.equal(torch.tensor([1.0, -1.0]), polynomial.vector)
assert torch.equal(torch.tensor(2.0), polynomial.constant)

0 comments on commit c54c5cc

Please sign in to comment.