Skip to content

Commit

Permalink
Fix handling of fixed matchings
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Oct 31, 2023
1 parent a31e4a3 commit b50bc55
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions diffpass/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def _validate_fixed_matchings(
)
for idx, (s, fm) in enumerate(zip(self.group_sizes, fixed_matchings)):
mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)
for i, j in fm:
mask[..., j, :] = False
mask[..., :, i] = False
if fm:
for i, j in fm:
mask[..., j, :] = False
mask[..., :, i] = False
self.register_buffer(f"_not_fixed_masks_{idx}", mask)
self._fixed_matchings_zip = [
tuple(zip(*fm)) if fm is not None else ((), ())
for fm in fixed_matchings
tuple(zip(*fm)) if fm else ((), ()) for fm in fixed_matchings
]

@property
Expand Down
10 changes: 5 additions & 5 deletions nbs/model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@
" )\n",
" for idx, (s, fm) in enumerate(zip(self.group_sizes, fixed_matchings)):\n",
" mask = torch.ones(*self.ensemble_shape, s, s, dtype=torch.bool)\n",
" for i, j in fm:\n",
" mask[..., j, :] = False\n",
" mask[..., :, i] = False\n",
" if fm:\n",
" for i, j in fm:\n",
" mask[..., j, :] = False\n",
" mask[..., :, i] = False\n",
" self.register_buffer(f\"_not_fixed_masks_{idx}\", mask)\n",
" self._fixed_matchings_zip = [\n",
" tuple(zip(*fm)) if fm is not None else ((), ())\n",
" for fm in fixed_matchings\n",
" tuple(zip(*fm)) if fm else ((), ()) for fm in fixed_matchings\n",
" ]\n",
"\n",
" @property\n",
Expand Down

0 comments on commit b50bc55

Please sign in to comment.