diff --git a/src/cryo_sbi/inference/models/embedding_nets.py b/src/cryo_sbi/inference/models/embedding_nets.py index b8b5f41..6293819 100644 --- a/src/cryo_sbi/inference/models/embedding_nets.py +++ b/src/cryo_sbi/inference/models/embedding_nets.py @@ -395,13 +395,13 @@ def forward(self, x): x = self.avgpool(x).flatten(start_dim=1) x = self.feedforward(x) return x - + @add_embedding("ConvEncoder_Tutorial") class ConvEncoder(nn.Module): def __init__(self, output_dimension: int): super(ConvEncoder, self).__init__() - ndf = 16 # fixed for the tutorial + ndf = 16 # fixed for the tutorial self.main = nn.Sequential( # input is 1 x 64 x 64 nn.Conv2d(1, ndf, 4, 2, 1, bias=False), @@ -429,6 +429,5 @@ def forward(self, x): return x.view(x.size(0), -1) # flatten - if __name__ == "__main__": pass diff --git a/src/cryo_sbi/inference/priors.py b/src/cryo_sbi/inference/priors.py index 6342b1f..f8fbb29 100644 --- a/src/cryo_sbi/inference/priors.py +++ b/src/cryo_sbi/inference/priors.py @@ -15,7 +15,7 @@ def gen_quat() -> torch.Tensor: count = 0 while count < 1: quat = 2 * torch.rand(size=(4,)) - 1 - norm = torch.sqrt(torch.sum(quat ** 2)) + norm = torch.sqrt(torch.sum(quat**2)) if 0.2 <= norm <= 1.0: quat /= norm count += 1 @@ -207,7 +207,7 @@ class PriorLoader(DataLoader): def __init__( self, prior: Distribution, - batch_size: int = 2 ** 8, # 256 + batch_size: int = 2**8, # 256 **kwargs, ): super().__init__( diff --git a/src/cryo_sbi/utils/image_utils.py b/src/cryo_sbi/utils/image_utils.py index 6080fde..0699b59 100644 --- a/src/cryo_sbi/utils/image_utils.py +++ b/src/cryo_sbi/utils/image_utils.py @@ -25,9 +25,9 @@ def circular_mask(n_pixels: int, radius: int, inside: bool = True) -> torch.Tens r_2d = grid[None, :] ** 2 + grid[:, None] ** 2 if inside is True: - mask = r_2d < radius ** 2 + mask = r_2d < radius**2 else: - mask = r_2d > radius ** 2 + mask = r_2d > radius**2 return mask @@ -183,7 +183,7 @@ def __init__(self, image_size: int, sigma: int): -0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size ) self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2 - self._mask = torch.exp(-self._r_2d / (2 * sigma ** 2)) + self._mask = torch.exp(-self._r_2d / (2 * sigma**2)) def __call__(self, image: torch.Tensor) -> torch.Tensor: """ diff --git a/src/cryo_sbi/wpa_simulator/ctf.py b/src/cryo_sbi/wpa_simulator/ctf.py index 9717e67..35a4a81 100644 --- a/src/cryo_sbi/wpa_simulator/ctf.py +++ b/src/cryo_sbi/wpa_simulator/ctf.py @@ -21,7 +21,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch. freq_pix_1d = torch.fft.fftfreq(num_pixels, d=pixel_size, device=image.device) x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij") - freq2_2d = x ** 2 + y ** 2 + freq2_2d = x**2 + y**2 freq2_2d = freq2_2d.expand(num_batch, -1, -1) imag = torch.zeros_like(freq2_2d, device=image.device) * 1j @@ -30,7 +30,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch. ctf = ( -amp * torch.cos(phase * freq2_2d * 0.5) - - torch.sqrt(1 - amp ** 2) * torch.sin(phase * freq2_2d * 0.5) + - torch.sqrt(1 - amp**2) * torch.sin(phase * freq2_2d * 0.5) + imag ) ctf = ctf * env / amp diff --git a/src/cryo_sbi/wpa_simulator/image_generation.py b/src/cryo_sbi/wpa_simulator/image_generation.py index c19a18b..3e9bc00 100644 --- a/src/cryo_sbi/wpa_simulator/image_generation.py +++ b/src/cryo_sbi/wpa_simulator/image_generation.py @@ -13,7 +13,7 @@ def gen_quat() -> torch.Tensor: count = 0 while count < 1: quat = 2 * torch.rand(size=(4,)) - 1 - norm = torch.sqrt(torch.sum(quat ** 2)) + norm = torch.sqrt(torch.sum(quat**2)) if 0.2 <= norm <= 1.0: quat /= norm count += 1 @@ -72,7 +72,7 @@ def project_density( """ num_batch, _, num_atoms = coords.shape - norm = 1 / (2 * torch.pi * sigma ** 2 * num_atoms) + norm = 1 / (2 * torch.pi * sigma**2 * num_atoms) grid_min = -pixel_size * num_pixels * 0.5 grid_max = pixel_size * num_pixels * 0.5 diff --git a/src/cryo_sbi/wpa_simulator/noise.py b/src/cryo_sbi/wpa_simulator/noise.py index 717d320..a763bb2 100644 --- a/src/cryo_sbi/wpa_simulator/noise.py +++ b/src/cryo_sbi/wpa_simulator/noise.py @@ -19,7 +19,7 @@ def circular_mask(n_pixels: int, radius: int, device: str = "cpu") -> torch.Tens -0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels, device=device ) r_2d = grid[None, :] ** 2 + grid[:, None] ** 2 - mask = r_2d < radius ** 2 + mask = r_2d < radius**2 return mask @@ -37,7 +37,9 @@ def get_snr(images, snr): images[:, mask], dim=[-1] ) # images are not centered at 0, so std is not the same as power assert signal_power.shape[0] == images.shape[0] - noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(torch.pow(torch.tensor(10), snr)) + noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt( + torch.pow(torch.tensor(10), snr) + ) return noise_power diff --git a/tests/test_image_utils.py b/tests/test_image_utils.py index dadef69..ff21436 100644 --- a/tests/test_image_utils.py +++ b/tests/test_image_utils.py @@ -11,9 +11,9 @@ def test_circular_mask(): assert inside_mask.shape == (n_pixels, n_pixels) assert outside_mask.shape == (n_pixels, n_pixels) - assert inside_mask.sum().item() == pytest.approx(radius ** 2 * 3.14159, abs=10) + assert inside_mask.sum().item() == pytest.approx(radius**2 * 3.14159, abs=10) assert outside_mask.sum().item() == pytest.approx( - n_pixels ** 2 - radius ** 2 * 3.14159, abs=10 + n_pixels**2 - radius**2 * 3.14159, abs=10 ) @@ -27,7 +27,7 @@ def test_mask_class(): masked_image = mask(image) assert masked_image.shape == (image_size, image_size) assert masked_image[inside].sum().item() == pytest.approx( - image_size ** 2 - radius ** 2 * 3.14159, abs=10 + image_size**2 - radius**2 * 3.14159, abs=10 ) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index 0d23e96..0c78e63 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -76,4 +76,6 @@ def test_get_snr(noise_std, num_images): assert snr.shape == torch.Size([images.shape[0], 1, 1]), "SNR has wrong shape" assert isinstance(snr, torch.Tensor) - assert torch.allclose(snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01), "SNR is not correct" + assert torch.allclose( + snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01 + ), "SNR is not correct" diff --git a/tutorials/tutorial.ipynb b/tutorials/tutorial.ipynb index 6d82d3f..0d0a828 100644 --- a/tutorials/tutorial.ipynb +++ b/tutorials/tutorial.ipynb @@ -41,8 +41,10 @@ "metadata": {}, "outputs": [], "source": [ - "models = torch.tensor([[[0., 0.],[-dist, dist],[0., 0.]] for dist in distance_centers])\n", - "torch.save(models, 'models.pt')" + "models = torch.tensor(\n", + " [[[0.0, 0.0], [-dist, dist], [0.0, 0.0]] for dist in distance_centers]\n", + ")\n", + "torch.save(models, \"models.pt\")" ] }, { @@ -131,7 +133,7 @@ "for idx, ax in enumerate(axes.flatten()):\n", " ax.imshow(images[idx], vmin=-3, vmax=3, cmap=\"gray\")\n", " ax.set_title(f\"Distance: {distance_centers[dist[idx].round().long()].item():.2f}\")\n", - " ax.axis('off')" + " ax.axis(\"off\")" ] }, { @@ -254,7 +256,7 @@ " \"train_config.json\",\n", " \"tutorial_estimator.pt\",\n", " device=\"cuda\",\n", - ")\n" + ")" ] }, { @@ -370,12 +372,12 @@ " c=dist,\n", " s=0.5,\n", " cmap=\"rainbow\",\n", - " #vmin=0,\n", - " #vmax=70,\n", + " # vmin=0,\n", + " # vmax=70,\n", ")\n", "\n", "cbar = plt.colorbar(mappable=scatter)\n", - "cbar.set_label('Posterior mean', rotation=270, fontsize=15, labelpad=20)\n", + "cbar.set_label(\"Posterior mean\", rotation=270, fontsize=15, labelpad=20)\n", "\n", "\n", "axes.set_xlabel(\"UMAP 1\", fontsize=15)\n",