Skip to content

Commit

Permalink
formatted code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Jan 30, 2024
1 parent 7caf5bf commit 6739b78
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 25 deletions.
5 changes: 2 additions & 3 deletions src/cryo_sbi/inference/models/embedding_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,13 +395,13 @@ def forward(self, x):
x = self.avgpool(x).flatten(start_dim=1)
x = self.feedforward(x)
return x


@add_embedding("ConvEncoder_Tutorial")
class ConvEncoder(nn.Module):
def __init__(self, output_dimension: int):
super(ConvEncoder, self).__init__()
ndf = 16 # fixed for the tutorial
ndf = 16 # fixed for the tutorial
self.main = nn.Sequential(
# input is 1 x 64 x 64
nn.Conv2d(1, ndf, 4, 2, 1, bias=False),
Expand Down Expand Up @@ -429,6 +429,5 @@ def forward(self, x):
return x.view(x.size(0), -1) # flatten



if __name__ == "__main__":
pass
4 changes: 2 additions & 2 deletions src/cryo_sbi/inference/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def gen_quat() -> torch.Tensor:
count = 0
while count < 1:
quat = 2 * torch.rand(size=(4,)) - 1
norm = torch.sqrt(torch.sum(quat ** 2))
norm = torch.sqrt(torch.sum(quat**2))
if 0.2 <= norm <= 1.0:
quat /= norm
count += 1
Expand Down Expand Up @@ -207,7 +207,7 @@ class PriorLoader(DataLoader):
def __init__(
self,
prior: Distribution,
batch_size: int = 2 ** 8, # 256
batch_size: int = 2**8, # 256
**kwargs,
):
super().__init__(
Expand Down
6 changes: 3 additions & 3 deletions src/cryo_sbi/utils/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def circular_mask(n_pixels: int, radius: int, inside: bool = True) -> torch.Tens
r_2d = grid[None, :] ** 2 + grid[:, None] ** 2

if inside is True:
mask = r_2d < radius ** 2
mask = r_2d < radius**2
else:
mask = r_2d > radius ** 2
mask = r_2d > radius**2

return mask

Expand Down Expand Up @@ -183,7 +183,7 @@ def __init__(self, image_size: int, sigma: int):
-0.5 * (image_size - 1), 0.5 * (image_size - 1), image_size
)
self._r_2d = self._grid[None, :] ** 2 + self._grid[:, None] ** 2
self._mask = torch.exp(-self._r_2d / (2 * sigma ** 2))
self._mask = torch.exp(-self._r_2d / (2 * sigma**2))

def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Expand Down
4 changes: 2 additions & 2 deletions src/cryo_sbi/wpa_simulator/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch.
freq_pix_1d = torch.fft.fftfreq(num_pixels, d=pixel_size, device=image.device)
x, y = torch.meshgrid(freq_pix_1d, freq_pix_1d, indexing="ij")

freq2_2d = x ** 2 + y ** 2
freq2_2d = x**2 + y**2
freq2_2d = freq2_2d.expand(num_batch, -1, -1)
imag = torch.zeros_like(freq2_2d, device=image.device) * 1j

Expand All @@ -30,7 +30,7 @@ def apply_ctf(image: torch.Tensor, defocus, b_factor, amp, pixel_size) -> torch.

ctf = (
-amp * torch.cos(phase * freq2_2d * 0.5)
- torch.sqrt(1 - amp ** 2) * torch.sin(phase * freq2_2d * 0.5)
- torch.sqrt(1 - amp**2) * torch.sin(phase * freq2_2d * 0.5)
+ imag
)
ctf = ctf * env / amp
Expand Down
4 changes: 2 additions & 2 deletions src/cryo_sbi/wpa_simulator/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def gen_quat() -> torch.Tensor:
count = 0
while count < 1:
quat = 2 * torch.rand(size=(4,)) - 1
norm = torch.sqrt(torch.sum(quat ** 2))
norm = torch.sqrt(torch.sum(quat**2))
if 0.2 <= norm <= 1.0:
quat /= norm
count += 1
Expand Down Expand Up @@ -72,7 +72,7 @@ def project_density(
"""

num_batch, _, num_atoms = coords.shape
norm = 1 / (2 * torch.pi * sigma ** 2 * num_atoms)
norm = 1 / (2 * torch.pi * sigma**2 * num_atoms)

grid_min = -pixel_size * num_pixels * 0.5
grid_max = pixel_size * num_pixels * 0.5
Expand Down
6 changes: 4 additions & 2 deletions src/cryo_sbi/wpa_simulator/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def circular_mask(n_pixels: int, radius: int, device: str = "cpu") -> torch.Tens
-0.5 * (n_pixels - 1), 0.5 * (n_pixels - 1), n_pixels, device=device
)
r_2d = grid[None, :] ** 2 + grid[:, None] ** 2
mask = r_2d < radius ** 2
mask = r_2d < radius**2

return mask

Expand All @@ -37,7 +37,9 @@ def get_snr(images, snr):
images[:, mask], dim=[-1]
) # images are not centered at 0, so std is not the same as power
assert signal_power.shape[0] == images.shape[0]
noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(torch.pow(torch.tensor(10), snr))
noise_power = signal_power.reshape(-1, 1, 1) / torch.sqrt(
torch.pow(torch.tensor(10), snr)
)

return noise_power

Expand Down
6 changes: 3 additions & 3 deletions tests/test_image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def test_circular_mask():

assert inside_mask.shape == (n_pixels, n_pixels)
assert outside_mask.shape == (n_pixels, n_pixels)
assert inside_mask.sum().item() == pytest.approx(radius ** 2 * 3.14159, abs=10)
assert inside_mask.sum().item() == pytest.approx(radius**2 * 3.14159, abs=10)
assert outside_mask.sum().item() == pytest.approx(
n_pixels ** 2 - radius ** 2 * 3.14159, abs=10
n_pixels**2 - radius**2 * 3.14159, abs=10
)


Expand All @@ -27,7 +27,7 @@ def test_mask_class():
masked_image = mask(image)
assert masked_image.shape == (image_size, image_size)
assert masked_image[inside].sum().item() == pytest.approx(
image_size ** 2 - radius ** 2 * 3.14159, abs=10
image_size**2 - radius**2 * 3.14159, abs=10
)


Expand Down
4 changes: 3 additions & 1 deletion tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,6 @@ def test_get_snr(noise_std, num_images):

assert snr.shape == torch.Size([images.shape[0], 1, 1]), "SNR has wrong shape"
assert isinstance(snr, torch.Tensor)
assert torch.allclose(snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01), "SNR is not correct"
assert torch.allclose(
snr.flatten(), noise_std * torch.ones(images.shape[0]), atol=1e-01
), "SNR is not correct"
16 changes: 9 additions & 7 deletions tutorials/tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@
"metadata": {},
"outputs": [],
"source": [
"models = torch.tensor([[[0., 0.],[-dist, dist],[0., 0.]] for dist in distance_centers])\n",
"torch.save(models, 'models.pt')"
"models = torch.tensor(\n",
" [[[0.0, 0.0], [-dist, dist], [0.0, 0.0]] for dist in distance_centers]\n",
")\n",
"torch.save(models, \"models.pt\")"
]
},
{
Expand Down Expand Up @@ -131,7 +133,7 @@
"for idx, ax in enumerate(axes.flatten()):\n",
" ax.imshow(images[idx], vmin=-3, vmax=3, cmap=\"gray\")\n",
" ax.set_title(f\"Distance: {distance_centers[dist[idx].round().long()].item():.2f}\")\n",
" ax.axis('off')"
" ax.axis(\"off\")"
]
},
{
Expand Down Expand Up @@ -254,7 +256,7 @@
" \"train_config.json\",\n",
" \"tutorial_estimator.pt\",\n",
" device=\"cuda\",\n",
")\n"
")"
]
},
{
Expand Down Expand Up @@ -370,12 +372,12 @@
" c=dist,\n",
" s=0.5,\n",
" cmap=\"rainbow\",\n",
" #vmin=0,\n",
" #vmax=70,\n",
" # vmin=0,\n",
" # vmax=70,\n",
")\n",
"\n",
"cbar = plt.colorbar(mappable=scatter)\n",
"cbar.set_label('Posterior mean', rotation=270, fontsize=15, labelpad=20)\n",
"cbar.set_label(\"Posterior mean\", rotation=270, fontsize=15, labelpad=20)\n",
"\n",
"\n",
"axes.set_xlabel(\"UMAP 1\", fontsize=15)\n",
Expand Down

0 comments on commit 6739b78

Please sign in to comment.