Implement "effectively fixed pairs" and add progress bar option
ulupo committed Nov 6, 2023
1 parent b50bc55 commit 4b550be
Showing 4 changed files with 56 additions and 26 deletions.
33 changes: 21 additions & 12 deletions diffpass/model.py
@@ -78,10 +78,12 @@ def __init__(
# as "batch" dimensions.
self.nonfixed_group_sizes_ = (
tuple(
s - len(fm) if fm is not None else s
for s, fm in zip(self.group_sizes, self.fixed_matchings)
s - num_efm
for s, num_efm in zip(
self.group_sizes, self._effective_number_fixed_matchings
)
)
if self.fixed_matchings is not None
if self.fixed_matchings
else self.group_sizes
)
self.log_alphas = ParameterList(
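Note: a minimal, self-contained sketch (not part of the commit) of what the new generator evaluates to, assuming made-up group sizes and effective fixed-matching counts; in the class these counts come from `_effective_number_fixed_matchings`, which `_validate_fixed_matchings` populates below.

# Hypothetical values for illustration only.
group_sizes = (4, 3, 5)
effective_number_fixed_matchings = [2, 3, 0]  # the 3-sequence group is fully fixed

nonfixed_group_sizes = tuple(
    s - num_efm
    for s, num_efm in zip(group_sizes, effective_number_fixed_matchings)
)
print(nonfixed_group_sizes)  # (2, 0, 5)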
@@ -105,14 +107,14 @@ def __init__(
def _validate_fixed_matchings(
self, fixed_matchings: Optional[Sequence[Sequence[Sequence[int]]]] = None
) -> None:
if fixed_matchings is not None:
if fixed_matchings:
if len(fixed_matchings) != len(self.group_sizes):
raise ValueError(
"If `fixed_matchings` is provided, it must have the same length as "
"`group_sizes`."
)
for s, fm in zip(self.group_sizes, fixed_matchings):
if fm is None:
if not fm:
continue
if any([len(p) != 2 for p in fm]):
raise ValueError(
@@ -123,16 +125,23 @@ def __init__(
"All fixed matchings must be within the range of the corresponding "
"group size."
)
self._effective_number_fixed_matchings = []
self._fixed_matchings_zip = []
for idx, (s, fm) in enumerate(zip(self.group_sizes, fixed_matchings)):
mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)
if fm:
for i, j in fm:
_fm = [] if fm is None else fm
num_fm = len(_fm)
is_not_fully_fixed = s - num_fm > 1
num_efm = s - (s - num_fm) * is_not_fully_fixed
self._effective_number_fixed_matchings.append(num_efm)
if not is_not_fully_fixed:
mask = torch.zeros(*self.ensemble_shape, s, s, dtype=torch.bool)
else:
mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)
for i, j in _fm:
mask[..., j, :] = False
mask[..., :, i] = False
self.register_buffer(f"_not_fixed_masks_{idx}", mask)
self._fixed_matchings_zip = [
tuple(zip(*fm)) if fm else ((), ()) for fm in fixed_matchings
]
self._fixed_matchings_zip.append(tuple(zip(*_fm)) if _fm else ((), ()))

@property
def _not_fixed_masks(self) -> list[torch.Tensor]:
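Note: the hunk above is where the "effectively fixed pairs" from the commit title are computed: if at most one sequence in a group is left unfixed, it has only one possible partner, so the whole group counts as fixed and its mask is set to all False. A standalone sketch of that arithmetic (helper name and values are illustrative, not library API):

def effective_number_fixed(group_size: int, num_fixed: int) -> int:
    # Mirrors `num_efm = s - (s - num_fm) * is_not_fully_fixed` from the diff.
    is_not_fully_fixed = group_size - num_fixed > 1
    return group_size - (group_size - num_fixed) * is_not_fully_fixed

print(effective_number_fixed(5, 2))  # 2 -> three sequences remain free
print(effective_number_fixed(5, 4))  # 5 -> the one remaining sequence is effectively fixed
print(effective_number_fixed(5, 5))  # 5 -> explicitly fully fixed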
@@ -154,7 +163,7 @@ def mode(self, value) -> None:
_mats_fn_no_fixed = getattr(self, f"_{self._mode}_mats")
self._mats_fn = (
_mats_fn_no_fixed
if self.fixed_matchings is None
if not self.fixed_matchings
else self._impl_fixed_matchings(_mats_fn_no_fixed)
)

8 changes: 7 additions & 1 deletion diffpass/train.py
@@ -353,6 +353,7 @@ def fit(
optimizer_kwargs: Optional[dict[str, Any]] = None,
mean_centering: bool = True,
similarity_gradient_bypass: bool = False,
show_pbar: bool = True,
) -> DiffPASSResults:
# Validate inputs
self.validate_inputs(x, y)
@@ -400,13 +401,18 @@
)
self.optimizer_ = optimizer_cls(self.parameters(), **optimizer_kwargs)

# Progress bar
pbar = range(epochs + 1)
if show_pbar:
pbar = tqdm(pbar)

# ------------------------------------------------------------------------------------------
## Gradient descent
# ------------------------------------------------------------------------------------------
x_perm_hard = None
with torch.set_grad_enabled(True):
self.optimizer_.zero_grad()
for i in tqdm(range(epochs + 1)):
for i in pbar:
# Hard pass
self.hard_()
with torch.no_grad():
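Note: `show_pbar` makes the tqdm progress bar optional, which is useful in scripts or tests where the bar is just noise. A minimal standalone sketch (assumed names, not the library API) of the same toggle pattern used in `fit`:

from tqdm import tqdm

def run_epochs(epochs: int, show_pbar: bool = True) -> None:
    pbar = range(epochs + 1)
    if show_pbar:
        pbar = tqdm(pbar)  # wrap only when a progress bar is wanted
    for i in pbar:
        pass  # one optimization step per epoch in the real training loop

With `show_pbar=False`, the loop iterates over a plain `range` and stays silent.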
33 changes: 21 additions & 12 deletions nbs/model.ipynb
@@ -133,10 +133,12 @@
" # as \"batch\" dimensions.\n",
" self.nonfixed_group_sizes_ = (\n",
" tuple(\n",
" s - len(fm) if fm is not None else s\n",
" for s, fm in zip(self.group_sizes, self.fixed_matchings)\n",
" s - num_efm\n",
" for s, num_efm in zip(\n",
" self.group_sizes, self._effective_number_fixed_matchings\n",
" )\n",
" )\n",
" if self.fixed_matchings is not None\n",
" if self.fixed_matchings\n",
" else self.group_sizes\n",
" )\n",
" self.log_alphas = ParameterList(\n",
@@ -160,14 +162,14 @@
" def _validate_fixed_matchings(\n",
" self, fixed_matchings: Optional[Sequence[Sequence[Sequence[int]]]] = None\n",
" ) -> None:\n",
" if fixed_matchings is not None:\n",
" if fixed_matchings:\n",
" if len(fixed_matchings) != len(self.group_sizes):\n",
" raise ValueError(\n",
" \"If `fixed_matchings` is provided, it must have the same length as \"\n",
" \"`group_sizes`.\"\n",
" )\n",
" for s, fm in zip(self.group_sizes, fixed_matchings):\n",
" if fm is None:\n",
" if not fm:\n",
" continue\n",
" if any([len(p) != 2 for p in fm]):\n",
" raise ValueError(\n",
@@ -178,16 +180,23 @@
" \"All fixed matchings must be within the range of the corresponding \"\n",
" \"group size.\"\n",
" )\n",
" self._effective_number_fixed_matchings = []\n",
" self._fixed_matchings_zip = []\n",
" for idx, (s, fm) in enumerate(zip(self.group_sizes, fixed_matchings)):\n",
" mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
" if fm:\n",
" for i, j in fm:\n",
" _fm = [] if fm is None else fm\n",
" num_fm = len(_fm)\n",
" is_not_fully_fixed = s - num_fm > 1\n",
" num_efm = s - (s - num_fm) * is_not_fully_fixed\n",
" self._effective_number_fixed_matchings.append(num_efm)\n",
" if not is_not_fully_fixed:\n",
" mask = torch.zeros(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
" else:\n",
" mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
" for i, j in _fm:\n",
" mask[..., j, :] = False\n",
" mask[..., :, i] = False\n",
" self.register_buffer(f\"_not_fixed_masks_{idx}\", mask)\n",
" self._fixed_matchings_zip = [\n",
" tuple(zip(*fm)) if fm else ((), ()) for fm in fixed_matchings\n",
" ]\n",
" self._fixed_matchings_zip.append(tuple(zip(*_fm)) if _fm else ((), ()))\n",
"\n",
" @property\n",
" def _not_fixed_masks(self) -> list[torch.Tensor]:\n",
@@ -209,7 +218,7 @@
" _mats_fn_no_fixed = getattr(self, f\"_{self._mode}_mats\")\n",
" self._mats_fn = (\n",
" _mats_fn_no_fixed\n",
" if self.fixed_matchings is None\n",
" if not self.fixed_matchings\n",
" else self._impl_fixed_matchings(_mats_fn_no_fixed)\n",
" )\n",
"\n",
8 changes: 7 additions & 1 deletion nbs/train.ipynb
@@ -425,6 +425,7 @@
" optimizer_kwargs: Optional[dict[str, Any]] = None,\n",
" mean_centering: bool = True,\n",
" similarity_gradient_bypass: bool = False,\n",
" show_pbar: bool = True,\n",
" ) -> DiffPASSResults:\n",
" # Validate inputs\n",
" self.validate_inputs(x, y)\n",
@@ -472,13 +473,18 @@
" )\n",
" self.optimizer_ = optimizer_cls(self.parameters(), **optimizer_kwargs)\n",
"\n",
" # Progress bar\n",
" pbar = range(epochs + 1)\n",
" if show_pbar:\n",
" pbar = tqdm(pbar)\n",
"\n",
" # ------------------------------------------------------------------------------------------\n",
" ## Gradient descent\n",
" # ------------------------------------------------------------------------------------------\n",
" x_perm_hard = None\n",
" with torch.set_grad_enabled(True):\n",
" self.optimizer_.zero_grad()\n",
" for i in tqdm(range(epochs + 1)):\n",
" for i in pbar:\n",
" # Hard pass\n",
" self.hard_()\n",
" with torch.no_grad():\n",
