Skip to content

Commit

Permalink
Optimize space usage of ExplorationReport before saving
Browse files Browse the repository at this point in the history
Signed-off-by: zjgemi <[email protected]>
  • Loading branch information
zjgemi committed Jan 7, 2025
1 parent fc72c85 commit 2ef45ab
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 14 deletions.
3 changes: 2 additions & 1 deletion dpgen2/exploration/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ def converged(
"""
pass

@abstractmethod
def no_candidate(self) -> bool:
r"""If no candidate configuration is found"""
return all([len(ii) == 0 for ii in self.get_candidate_ids()])
pass

@abstractmethod
def get_candidate_ids(
Expand Down
21 changes: 18 additions & 3 deletions dpgen2/exploration/report/report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def doc() -> str:
Expand Down Expand Up @@ -274,6 +278,10 @@ def record(
# accurate set is substracted by the candidate set
self.accur = self.accur - self.candi
self.model_devi = model_devi
self._no_candidate = (len(self.candi) == 0)
self._failed_ratio = float(len(self.failed)) / float(self.nframes)
self._accurate_ratio = float(len(self.accur)) / float(self.nframes)
self._candidate_ratio = float(len(self.candi)) / float(self.nframes)

def _record_one_traj(
self,
Expand Down Expand Up @@ -346,29 +354,36 @@ def failed_ratio(
self,
tag=None,
):
return float(len(self.failed)) / float(self.nframes)
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
return float(len(self.accur)) / float(self.nframes)
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
return float(len(self.candi)) / float(self.nframes)
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = self.ntraj
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
20 changes: 14 additions & 6 deletions dpgen2/exploration/report/report_trust_levels_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def __init__(
self.fmt_str = " ".join([f"%{ii}s" for ii in spaces])
self.fmt_flt = "%.4f"
self.header_str = "#" + self.fmt_str % print_tuple
self._no_candidate = False
self._failed_ratio = None
self._accurate_ratio = None
self._candidate_ratio = None

@staticmethod
def args() -> List[Argument]:
Expand Down Expand Up @@ -133,6 +137,10 @@ def record(
assert len(self.traj_accu) == ntraj
assert len(self.traj_fail) == ntraj
self.model_devi = model_devi
self._no_candidate = (sum([len(ii) for ii in self.traj_cand]) == 0)
self._failed_ratio = float(sum([len(ii) for ii in self.traj_fail])) / float(sum(self.traj_nframes))
self._accurate_ratio = float(sum([len(ii) for ii in self.traj_accu])) / float(sum(self.traj_nframes))
self._candidate_ratio = float(sum([len(ii) for ii in self.traj_cand])) / float(sum(self.traj_nframes))

def _get_indexes(
self,
Expand Down Expand Up @@ -205,22 +213,22 @@ def failed_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_fail]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._failed_ratio

def accurate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_accu]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._accurate_ratio

def candidate_ratio(
self,
tag=None,
):
traj_nf = [len(ii) for ii in self.traj_cand]
return float(sum(traj_nf)) / float(sum(self.traj_nframes))
return self._candidate_ratio

def no_candidate(self) -> bool:
return self._no_candidate

@abstractmethod
def get_candidate_ids(
Expand Down
4 changes: 4 additions & 0 deletions dpgen2/exploration/report/report_trust_levels_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ def converged(
def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
4 changes: 4 additions & 0 deletions dpgen2/exploration/report/report_trust_levels_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,16 @@ def converged(
def get_candidate_ids(
self,
max_nframes: Optional[int] = None,
clear: bool = True,
) -> List[List[int]]:
ntraj = len(self.traj_nframes)
id_cand = self._get_candidates(max_nframes)
id_cand_list = [[] for ii in range(ntraj)]
for ii in id_cand:
id_cand_list[ii[0]].append(ii[1])
# free the memory, this method should only be called once
if clear:
self.clear()
return id_cand_list

def _get_candidates(
Expand Down
8 changes: 4 additions & 4 deletions tests/exploration/test_report_adaptive_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class MockedReport:
self.assertFalse(ter.converged([mr, mr1, mr]))
self.assertTrue(ter.converged([mr1, mr, mr]))

picked = ter.get_candidate_ids(2)
picked = ter.get_candidate_ids(2, clear=False)
npicked = 0
self.assertEqual(len(picked), 2)
for ii in range(2):
Expand Down Expand Up @@ -218,12 +218,12 @@ def faked_choices(
return ret

ter.record(model_devi)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(ter.candi, expected_cand)
self.assertEqual(ter.accur, expected_accu)
self.assertEqual(set(ter.failed), expected_fail)
with mock.patch("random.choices", faked_choices):
picked = ter.get_candidate_ids(11)
self.assertFalse(ter.converged([]))
self.assertEqual(len(picked), 2)
self.assertEqual(sorted(picked[0]), [1, 3])
self.assertEqual(sorted(picked[1]), [1, 5, 7])
Expand Down

0 comments on commit 2ef45ab

Please sign in to comment.