Skip to content

Commit

Permalink
Fix #3 (#4)
Browse files Browse the repository at this point in the history
Use .item() instead of dccn for losses. Change type annotations accordingly.
  • Loading branch information
ulupo authored May 14, 2024
1 parent dcab707 commit c26b106
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
15 changes: 6 additions & 9 deletions diffpass/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ class DiffPaSSResults:
]
# Hard losses
hard_losses: Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
GradientDescentList[GroupByGroupList[float]],
BootstrapList[GradientDescentList[GroupByGroupList[float]]],
]
# Soft losses
soft_losses: Optional[
Union[
GradientDescentList[GroupByGroupList[np.ndarray]],
BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],
GradientDescentList[GroupByGroupList[float]],
BootstrapList[GradientDescentList[GroupByGroupList[float]]],
]
]

Expand Down Expand Up @@ -318,7 +318,7 @@ def _hard_pass(
for perms_this_group in perms
]
)
results.hard_losses.append(dccn(loss))
results.hard_losses.append(loss.item())

def _soft_pass(
self,
Expand All @@ -338,10 +338,7 @@ def _soft_pass(
[dccn(perms_this_group) for perms_this_group in perms]
)
if record_soft_losses:
results.soft_losses.append(dccn(loss))

# Compute total loss
loss = loss.sum()
results.soft_losses.append(loss.item())

return loss

Expand Down
15 changes: 6 additions & 9 deletions nbs/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@
" ]\n",
" # Hard losses\n",
" hard_losses: Union[\n",
" GradientDescentList[GroupByGroupList[np.ndarray]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],\n",
" GradientDescentList[GroupByGroupList[float]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[float]]],\n",
" ]\n",
" # Soft losses\n",
" soft_losses: Optional[\n",
" Union[\n",
" GradientDescentList[GroupByGroupList[np.ndarray]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[np.ndarray]]],\n",
" GradientDescentList[GroupByGroupList[float]],\n",
" BootstrapList[GradientDescentList[GroupByGroupList[float]]],\n",
" ]\n",
" ]\n",
"\n",
Expand Down Expand Up @@ -386,7 +386,7 @@
" for perms_this_group in perms\n",
" ]\n",
" )\n",
" results.hard_losses.append(dccn(loss))\n",
" results.hard_losses.append(loss.item())\n",
"\n",
" def _soft_pass(\n",
" self,\n",
Expand All @@ -406,10 +406,7 @@
" [dccn(perms_this_group) for perms_this_group in perms]\n",
" )\n",
" if record_soft_losses:\n",
" results.soft_losses.append(dccn(loss))\n",
"\n",
" # Compute total loss\n",
" loss = loss.sum()\n",
" results.soft_losses.append(loss.item())\n",
"\n",
" return loss\n",
"\n",
Expand Down

0 comments on commit c26b106

Please sign in to comment.