Skip to content

Commit

Permalink
Add @property n_params (#94)
Browse files Browse the repository at this point in the history
* CHGNet add property n_params

* add version and n_params properties to CHGNetCalculator and StructOptimizer

* test version and n_params on CHGNet, CHGNetCalculator, StructOptimizer
  • Loading branch information
janosh authored Oct 31, 2023
1 parent daf3a4b commit 8948863
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 5 deletions.
20 changes: 20 additions & 0 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def __init__(
self.stress_weight = stress_weight
print(f"CHGNet will run on {self.device}")

@property
def version(self) -> str:
"""The version of CHGNet."""
return self.model.version

@property
def n_params(self) -> int:
"""The number of parameters in CHGNet."""
return self.model.n_params

def calculate(
self,
atoms: Atoms | None = None,
Expand Down Expand Up @@ -185,6 +195,16 @@ def __init__(
on_isolated_atoms=on_isolated_atoms,
)

@property
def version(self) -> str:
"""The version of CHGNet."""
return self.calculator.model.version

@property
def n_params(self) -> int:
"""The number of parameters in CHGNet."""
return self.calculator.model.n_params

def relax(
self,
atoms: Structure | Atoms,
Expand Down
8 changes: 6 additions & 2 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,15 +310,19 @@ def __init__(
nn.Linear(in_features=mlp_hidden_dims[-1], out_features=1),
)

n_params = sum(p.numel() for p in self.parameters())
version_str = f" v{version}" if version else ""
print(f"CHGNet{version_str} initialized with {n_params:,} parameters")
print(f"CHGNet{version_str} initialized with {self.n_params:,} parameters")

@property
def version(self) -> str | None:
"""Return the version of the loaded checkpoint."""
return self.model_args.get("version")

@property
def n_params(self) -> int:
"""Return the number of parameters in the model."""
return sum(p.numel() for p in self.parameters())

def forward(
self,
graphs: Sequence[CrystalGraph],
Expand Down
7 changes: 7 additions & 0 deletions tests/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@
chgnet = CHGNet.load()


def test_version_and_params():
calculator = relaxer.calculator
model = calculator.model
assert relaxer.version == calculator.version == model.version
assert relaxer.n_params == calculator.n_params == model.n_params


def test_eos():
eos = EquationOfState()
eos.fit(atoms=structure)
Expand Down
14 changes: 11 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,25 @@ def test_as_to_from_dict() -> None:
assert model_3.todict() == to_dict


def test_model_load(capsys: pytest.CaptureFixture) -> None:
def test_model_load_version_params(capsys: pytest.CaptureFixture) -> None:
model = CHGNet.load()
assert model.version == "0.3.0"
assert model.n_params == 412_525
stdout, stderr = capsys.readouterr()
assert stdout == f"CHGNet v{model.version} initialized with 412,525 parameters\n"
assert (
stdout
== f"CHGNet v{model.version} initialized with {model.n_params:,} parameters\n"
)
assert stderr == ""

model = CHGNet.load(model_name="0.2.0")
assert model.version == "0.2.0"
assert model.n_params == 400_438
stdout, stderr = capsys.readouterr()
assert stdout == f"CHGNet v{model.version} initialized with 400,438 parameters\n"
assert (
stdout
== f"CHGNet v{model.version} initialized with {model.n_params:,} parameters\n"
)
assert stderr == ""

model_name = "0.1.0" # invalid
Expand Down

0 comments on commit 8948863

Please sign in to comment.