diff --git a/src/cultionet/models/field_of_junctions.py b/src/cultionet/models/field_of_junctions.py index ab5d87ea..42058167 100644 --- a/src/cultionet/models/field_of_junctions.py +++ b/src/cultionet/models/field_of_junctions.py @@ -69,13 +69,11 @@ def __init__( # self.x0_y0_range = torch.linspace(-3.0, 3.0, self.nvals) # Create pytorch variables for angles and vertex position for each patch - angles = torch.ones( - 1, 3, self.h_patches, self.w_patches, dtype=torch.float32 - ) - x0y0 = torch.ones( - 1, 2, self.h_patches, self.w_patches, dtype=torch.float32 + self.params = nn.Parameter( + torch.ones( + 1, 5, self.h_patches, self.w_patches, dtype=torch.float32 + ) ) - self.params = nn.Parameter(torch.cat([angles, x0y0], dim=1)) def forward(self, x: torch.Tensor) -> T.Dict[str, torch.Tensor]: batch_size, in_channels, in_height, in_width = x.shape