Skip to content

Commit

Permalink
add foj loss
Browse files Browse the repository at this point in the history
  • Loading branch information
jgrss committed May 2, 2024
1 parent 9e640e6 commit dd08d4c
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions src/cultionet/models/field_of_junctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,11 @@ def __init__(
# self.x0_y0_range = torch.linspace(-3.0, 3.0, self.nvals)

# Create pytorch variables for angles and vertex position for each patch
angles = torch.ones(
1, 3, self.h_patches, self.w_patches, dtype=torch.float32
)
x0y0 = torch.ones(
1, 2, self.h_patches, self.w_patches, dtype=torch.float32
self.params = nn.Parameter(
torch.ones(
1, 5, self.h_patches, self.w_patches, dtype=torch.float32
)
)
self.params = nn.Parameter(torch.cat([angles, x0y0], dim=1))

def forward(self, x: torch.Tensor) -> T.Dict[str, torch.Tensor]:
batch_size, in_channels, in_height, in_width = x.shape
Expand Down

0 comments on commit dd08d4c

Please sign in to comment.