Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize space usage of ExplorationReport before saving #279

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dpgen2/exploration/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@
"""
pass

@abstractmethod
def no_candidate(self) -> bool:
r"""If no candidate configuration is found"""
return all([len(ii) == 0 for ii in self.get_candidate_ids()])
pass

Check warning on line 66 in dpgen2/exploration/report/report.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/exploration/report/report.py#L66

Added line #L66 was not covered by tests

@abstractmethod
def get_candidate_ids(
Expand Down
21 changes: 18 additions & 3 deletions dpgen2/exploration/report/report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def doc() -> str:
Expand Down Expand Up @@ -274,6 +278,10 @@
# the candidate set is subtracted from the accurate set
self.accur = self.accur - self.candi
self.model_devi = model_devi
self._no_candidate = len(self.candi) == 0
self._failed_ratio = float(len(self.failed)) / float(self.nframes)
self._accurate_ratio = float(len(self.accur)) / float(self.nframes)
self._candidate_ratio = float(len(self.candi)) / float(self.nframes)

def _record_one_traj(
self,
Expand Down Expand Up @@ -346,29 +354,36 @@
self,
tag=None,
):
return float(len(self.failed)) / float(self.nframes)
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
return float(len(self.accur)) / float(self.nframes)
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
return float(len(self.candi)) / float(self.nframes)
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

Check warning on line 372 in dpgen2/exploration/report/report_adaptive_lower.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/exploration/report/report_adaptive_lower.py#L372

Added line #L372 was not covered by tests

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = self.ntraj
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
26 changes: 20 additions & 6 deletions dpgen2/exploration/report/report_trust_levels_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def args() -> List[Argument]:
Expand Down Expand Up @@ -133,6 +137,16 @@ def record(
assert len(self.traj_accu) == ntraj
assert len(self.traj_fail) == ntraj
self.model_devi = model_devi
self._no_candidate = sum([len(ii) for ii in self.traj_cand]) == 0
self._failed_ratio = float(sum([len(ii) for ii in self.traj_fail])) / float(
sum(self.traj_nframes)
)
self._accurate_ratio = float(sum([len(ii) for ii in self.traj_accu])) / float(
sum(self.traj_nframes)
)
self._candidate_ratio = float(sum([len(ii) for ii in self.traj_cand])) / float(
sum(self.traj_nframes)
)

def _get_indexes(
self,
Expand Down Expand Up @@ -205,22 +219,22 @@ def failed_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_fail]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_accu]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_cand]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

@abstractmethod
def get_candidate_ids(
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/report/report_trust_levels_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,23 @@
converged bool
If the exploration is converged.
"""
return self.accurate_ratio() >= self.conv_accuracy
accurate_ratio = self.accurate_ratio()
assert isinstance(accurate_ratio, float)
return accurate_ratio >= self.conv_accuracy

Check warning on line 46 in dpgen2/exploration/report/report_trust_levels_max.py

View check run for this annotation

Codecov / codecov/patch

dpgen2/exploration/report/report_trust_levels_max.py#L44-L46

Added lines #L44 - L46 were not covered by tests

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 7 additions & 1 deletion dpgen2/exploration/report/report_trust_levels_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,23 @@ def converged(
converged bool
If the exploration is converged.
"""
return self.accurate_ratio() >= self.conv_accuracy
accurate_ratio = self.accurate_ratio()
assert isinstance(accurate_ratio, float)
return accurate_ratio >= self.conv_accuracy

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 4 additions & 4 deletions tests/exploration/test_report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class MockedReport:
self.assertFalse(ter.converged([mr, mr1, mr]))
self.assertTrue(ter.converged([mr1, mr, mr]))

picked = ter.get_candidate_ids(2)
picked = ter.get_candidate_ids(2, clear=False)
npicked = 0
self.assertEqual(len(picked), 2)
for ii in range(2):
Expand Down Expand Up @@ -218,12 +218,12 @@ def faked_choices(
return ret

ter.record(model_devi)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(ter.candi, expected_cand)
self.assertEqual(ter.accur, expected_accu)
self.assertEqual(set(ter.failed), expected_fail)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(len(picked), 2)
self.assertEqual(sorted(picked[0]), [1, 3])
self.assertEqual(sorted(picked[1]), [1, 5, 7])
Expand Down
Loading