Skip to content

Commit

Permalink
Fix dtype conversion bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ulupo committed Dec 2, 2023
1 parent 73e6239 commit b66dd59
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions diffpass/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def _fit(
)
results.hard_perms.append(
[
_dccn(perms_this_group).argmax(-1).to(torch.int16)
_dccn(perms_this_group.argmax(-1).to(torch.int16))
for perms_this_group in perms
]
)
Expand Down Expand Up @@ -653,7 +653,7 @@ def _fit(
)
results.hard_perms.append(
[
_dccn(perms_this_group).argmax(-1).to(torch.int16)
_dccn(perms_this_group.argmax(-1).to(torch.int16))
for perms_this_group in perms
]
)
Expand Down Expand Up @@ -1024,7 +1024,7 @@ def _fit(
)
results.hard_perms.append(
[
_dccn(perms_this_group).argmax(-1).to(torch.int16)
_dccn(perms_this_group.argmax(-1).to(torch.int16))
for perms_this_group in perms
]
)
Expand Down
6 changes: 3 additions & 3 deletions nbs/train.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@
" )\n",
" results.hard_perms.append(\n",
" [\n",
" _dccn(perms_this_group).argmax(-1).to(torch.int16)\n",
" _dccn(perms_this_group.argmax(-1).to(torch.int16))\n",
" for perms_this_group in perms\n",
" ]\n",
" )\n",
Expand Down Expand Up @@ -733,7 +733,7 @@
" )\n",
" results.hard_perms.append(\n",
" [\n",
" _dccn(perms_this_group).argmax(-1).to(torch.int16)\n",
" _dccn(perms_this_group.argmax(-1).to(torch.int16))\n",
" for perms_this_group in perms\n",
" ]\n",
" )\n",
Expand Down Expand Up @@ -1112,7 +1112,7 @@
" )\n",
" results.hard_perms.append(\n",
" [\n",
" _dccn(perms_this_group).argmax(-1).to(torch.int16)\n",
" _dccn(perms_this_group.argmax(-1).to(torch.int16))\n",
" for perms_this_group in perms\n",
" ]\n",
" )\n",
Expand Down

0 comments on commit b66dd59

Please sign in to comment.