diff --git a/diffpass/model.py b/diffpass/model.py
index 1f686e2..712fd64 100644
--- a/diffpass/model.py
+++ b/diffpass/model.py
@@ -78,10 +78,12 @@ def __init__(
         # as "batch" dimensions.
         self.nonfixed_group_sizes_ = (
             tuple(
-                s - len(fm) if fm is not None else s
-                for s, fm in zip(self.group_sizes, self.fixed_matchings)
+                s - num_efm
+                for s, num_efm in zip(
+                    self.group_sizes, self._effective_number_fixed_matchings
+                )
             )
-            if self.fixed_matchings is not None
+            if self.fixed_matchings
             else self.group_sizes
         )
         self.log_alphas = ParameterList(
@@ -105,14 +107,14 @@ def __init__(
     def _validate_fixed_matchings(
         self, fixed_matchings: Optional[Sequence[Sequence[Sequence[int]]]] = None
     ) -> None:
-        if fixed_matchings is not None:
+        if fixed_matchings:
             if len(fixed_matchings) != len(self.group_sizes):
                 raise ValueError(
                     "If `fixed_matchings` is provided, it must have the same length as "
                     "`group_sizes`."
                 )
             for s, fm in zip(self.group_sizes, fixed_matchings):
-                if fm is None:
+                if not fm:
                     continue
                 if any([len(p) != 2 for p in fm]):
                     raise ValueError(
@@ -123,16 +125,23 @@ def _validate_fixed_matchings(
                         "All fixed matchings must be within the range of the corresponding "
                         "group size."
                     )
+            self._effective_number_fixed_matchings = []
+            self._fixed_matchings_zip = []
             for idx, (s, fm) in enumerate(zip(self.group_sizes, fixed_matchings)):
-                mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)
-                if fm:
-                    for i, j in fm:
+                _fm = [] if fm is None else fm
+                num_fm = len(_fm)
+                is_not_fully_fixed = s - num_fm > 1
+                num_efm = s - (s - num_fm) * is_not_fully_fixed
+                self._effective_number_fixed_matchings.append(num_efm)
+                if not is_not_fully_fixed:
+                    mask = torch.zeros(*self.ensemble_shape, s, s, dtype=torch.bool)
+                else:
+                    mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)
+                    for i, j in _fm:
                         mask[..., j, :] = False
                         mask[..., :, i] = False
                 self.register_buffer(f"_not_fixed_masks_{idx}", mask)
-            self._fixed_matchings_zip = [
-                tuple(zip(*fm)) if fm else ((), ()) for fm in fixed_matchings
-            ]
+                self._fixed_matchings_zip.append(tuple(zip(*_fm)) if _fm else ((), ()))
 
     @property
     def _not_fixed_masks(self) -> list[torch.Tensor]:
@@ -154,7 +163,7 @@ def mode(self, value) -> None:
         _mats_fn_no_fixed = getattr(self, f"_{self._mode}_mats")
         self._mats_fn = (
             _mats_fn_no_fixed
-            if self.fixed_matchings is None
+            if not self.fixed_matchings
             else self._impl_fixed_matchings(_mats_fn_no_fixed)
         )
 
diff --git a/diffpass/train.py b/diffpass/train.py
index 3595a5b..81965a3 100644
--- a/diffpass/train.py
+++ b/diffpass/train.py
@@ -353,6 +353,7 @@ def fit(
         optimizer_kwargs: Optional[dict[str, Any]] = None,
         mean_centering: bool = True,
         similarity_gradient_bypass: bool = False,
+        show_pbar: bool = True,
     ) -> DiffPASSResults:
         # Validate inputs
         self.validate_inputs(x, y)
@@ -400,13 +401,18 @@
         )
         self.optimizer_ = optimizer_cls(self.parameters(), **optimizer_kwargs)
 
+        # Progress bar
+        pbar = range(epochs + 1)
+        if show_pbar:
+            pbar = tqdm(pbar)
+
         # ------------------------------------------------------------------------------------------
         ## Gradient descent
         # ------------------------------------------------------------------------------------------
         x_perm_hard = None
         with torch.set_grad_enabled(True):
             self.optimizer_.zero_grad()
-            for i in tqdm(range(epochs + 1)):
+            for i in pbar:
                 # Hard pass
                 self.hard_()
                 with torch.no_grad():
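Note: with the `is not None` checks replaced by truthiness checks, an empty `fixed_matchings` sequence is now treated the same as `None`. The new bookkeeping also counts a group as fully fixed when at most one slot remains free, since the last remaining matching is then forced. A minimal standalone sketch of that arithmetic, with hypothetical `group_sizes` and `fixed_matchings` values (only the formula mirrors the diff):

```python
# Hypothetical example values; a matching is an (i, j) index pair within a group.
group_sizes = (3, 2, 4)
fixed_matchings = ([(0, 1), (1, 0)], [], [(2, 2)])  # [] now means "no fixed matchings"

effective_numbers = []
for s, fm in zip(group_sizes, fixed_matchings):
    num_fm = len(fm or [])
    # With at most one free slot left, the remaining matching is forced, so the
    # whole group counts as fixed (num_efm == s) and its mask becomes all-False.
    is_not_fully_fixed = s - num_fm > 1
    num_efm = s - (s - num_fm) * is_not_fully_fixed
    effective_numbers.append(num_efm)

print(effective_numbers)  # [3, 0, 1]: the size-3 group with 2 fixed pairs is fully fixed
print([s - n for s, n in zip(group_sizes, effective_numbers)])  # nonfixed sizes: [0, 2, 3]
```

The notebook diffs below mirror these module changes line for line.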
diff --git a/nbs/model.ipynb b/nbs/model.ipynb
index 59e1996..d8f1d97 100644
--- a/nbs/model.ipynb
+++ b/nbs/model.ipynb
@@ -133,10 +133,12 @@
     "        # as \"batch\" dimensions.\n",
     "        self.nonfixed_group_sizes_ = (\n",
     "            tuple(\n",
-    "                s - len(fm) if fm is not None else s\n",
-    "                for s, fm in zip(self.group_sizes, self.fixed_matchings)\n",
+    "                s - num_efm\n",
+    "                for s, num_efm in zip(\n",
+    "                    self.group_sizes, self._effective_number_fixed_matchings\n",
+    "                )\n",
     "            )\n",
-    "            if self.fixed_matchings is not None\n",
+    "            if self.fixed_matchings\n",
     "            else self.group_sizes\n",
     "        )\n",
     "        self.log_alphas = ParameterList(\n",
@@ -160,14 +162,14 @@
     "    def _validate_fixed_matchings(\n",
     "        self, fixed_matchings: Optional[Sequence[Sequence[Sequence[int]]]] = None\n",
     "    ) -> None:\n",
-    "        if fixed_matchings is not None:\n",
+    "        if fixed_matchings:\n",
     "            if len(fixed_matchings) != len(self.group_sizes):\n",
     "                raise ValueError(\n",
     "                    \"If `fixed_matchings` is provided, it must have the same length as \"\n",
     "                    \"`group_sizes`.\"\n",
     "                )\n",
     "            for s, fm in zip(self.group_sizes, fixed_matchings):\n",
-    "                if fm is None:\n",
+    "                if not fm:\n",
     "                    continue\n",
     "                if any([len(p) != 2 for p in fm]):\n",
     "                    raise ValueError(\n",
@@ -178,16 +180,23 @@
     "                        \"All fixed matchings must be within the range of the corresponding \"\n",
     "                        \"group size.\"\n",
     "                    )\n",
+    "            self._effective_number_fixed_matchings = []\n",
+    "            self._fixed_matchings_zip = []\n",
     "            for idx, (s, fm) in enumerate(zip(self.group_sizes, fixed_matchings)):\n",
-    "                mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
-    "                if fm:\n",
-    "                    for i, j in fm:\n",
+    "                _fm = [] if fm is None else fm\n",
+    "                num_fm = len(_fm)\n",
+    "                is_not_fully_fixed = s - num_fm > 1\n",
+    "                num_efm = s - (s - num_fm) * is_not_fully_fixed\n",
+    "                self._effective_number_fixed_matchings.append(num_efm)\n",
+    "                if not is_not_fully_fixed:\n",
+    "                    mask = torch.zeros(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
+    "                else:\n",
+    "                    mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
+    "                    for i, j in _fm:\n",
     "                        mask[..., j, :] = False\n",
     "                        mask[..., :, i] = False\n",
     "                self.register_buffer(f\"_not_fixed_masks_{idx}\", mask)\n",
-    "            self._fixed_matchings_zip = [\n",
-    "                tuple(zip(*fm)) if fm else ((), ()) for fm in fixed_matchings\n",
-    "            ]\n",
+    "                self._fixed_matchings_zip.append(tuple(zip(*_fm)) if _fm else ((), ()))\n",
     "\n",
     "    @property\n",
     "    def _not_fixed_masks(self) -> list[torch.Tensor]:\n",
@@ -209,7 +218,7 @@
     "        _mats_fn_no_fixed = getattr(self, f\"_{self._mode}_mats\")\n",
     "        self._mats_fn = (\n",
     "            _mats_fn_no_fixed\n",
-    "            if self.fixed_matchings is None\n",
+    "            if not self.fixed_matchings\n",
     "            else self._impl_fixed_matchings(_mats_fn_no_fixed)\n",
     "        )\n",
     "\n",
diff --git a/nbs/train.ipynb b/nbs/train.ipynb
index 0804d16..1feed8c 100644
--- a/nbs/train.ipynb
+++ b/nbs/train.ipynb
@@ -425,6 +425,7 @@
     "        optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
     "        mean_centering: bool = True,\n",
     "        similarity_gradient_bypass: bool = False,\n",
+    "        show_pbar: bool = True,\n",
     "    ) -> DiffPASSResults:\n",
     "        # Validate inputs\n",
     "        self.validate_inputs(x, y)\n",
@@ -472,13 +473,18 @@
     "        )\n",
     "        self.optimizer_ = optimizer_cls(self.parameters(), **optimizer_kwargs)\n",
     "\n",
+    "        # Progress bar\n",
+    "        pbar = range(epochs + 1)\n",
+    "        if show_pbar:\n",
+    "            pbar = tqdm(pbar)\n",
+    "\n",
     "        # ------------------------------------------------------------------------------------------\n",
     "        ## Gradient descent\n",
     "        # ------------------------------------------------------------------------------------------\n",
     "        x_perm_hard = None\n",
     "        with torch.set_grad_enabled(True):\n",
     "            self.optimizer_.zero_grad()\n",
-    "            for i in tqdm(range(epochs + 1)):\n",
+    "            for i in pbar:\n",
     "                # Hard pass\n",
     "                self.hard_()\n",
     "                with torch.no_grad():\n",
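With `show_pbar=False`, `fit` iterates over a plain `range(epochs + 1)` instead of a `tqdm`-wrapped one, so the epoch loop itself is unchanged. A usage sketch, assuming `model` is a DiffPASS model instance exposing the patched `fit` and `x`, `y` are the paired inputs accepted by `validate_inputs`; the setup names and values are illustrative, only the keyword names come from this diff:

```python
# `model`, `x`, `y` are assumed to exist; see the note above.
results = model.fit(
    x,
    y,
    epochs=200,                     # illustrative value
    optimizer_kwargs={"lr": 1e-1},  # forwarded to optimizer_cls(self.parameters(), ...)
    show_pbar=False,                # new flag: suppress the tqdm bar (e.g. in logs/CI)
)
```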