Skip to content

Commit

Permalink
removed deprecated tests and added config files and models for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Oct 20, 2023
1 parent 1688086 commit c0cffdf
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 129 deletions.
26 changes: 11 additions & 15 deletions tests/config_files/image_params_testing.json
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
{"N_PIXELS": 128,
"PIXEL_SIZE": 1.5,
"SIGMA": 4.0,
"MODEL_FILE": "data/protein_models/hsp90_models.npy",
"ROTATIONS": false,
"SHIFT": true,
"CTF": true,
"NOISE": true,
"DEFOCUS": 1.5,
"SNR": 1.0,
"RADIUS_MASK": 16,
"AMP": 0.1,
"B_FACTOR": 1.0,
"ELECWAVE": 0.019866
}
{
"N_PIXELS": 128,
"PIXEL_SIZE": 2.06,
"SIGMA": [0.5, 5.0],
"MODEL_FILE": "../models/hemagglutinin_models.pt",
"SHIFT": 20.0,
"DEFOCUS": [1.5, 3.5],
"SNR": [0.05, 0.05],
"AMP": 0.1,
"B_FACTOR": [1.0, 100.0]
}
25 changes: 13 additions & 12 deletions tests/config_files/training_params_npe_testing.json
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
{"EMBEDDING": "RESNET18",
"OUT_DIM": 256,
"NUM_TRANSFORM": 5,
"NUM_HIDDEN_FLOW": 10,
"HIDDEN_DIM_FLOW": 256,
"MODEL": "NSF",
"LEARNING_RATE": 0.0003,
"CLIP_GRADIENT": 5.0,
"THETA_SHIFT": 9.5,
"THETA_SCALE": 9.5,
"BATCH_SIZE": 512
}
{
"EMBEDDING": "RESNET18",
"OUT_DIM": 256,
"NUM_TRANSFORM": 5,
"NUM_HIDDEN_FLOW": 10,
"HIDDEN_DIM_FLOW": 256,
"MODEL": "NSF",
"LEARNING_RATE": 0.0003,
"CLIP_GRADIENT": 5.0,
"THETA_SHIFT": 25,
"THETA_SCALE": 25,
"BATCH_SIZE": 256
}
Binary file added tests/models/hsp90_models.pt
Binary file not shown.
140 changes: 38 additions & 102 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,119 +3,55 @@
import numpy as np
import json

from cryo_sbi.wpa_simulator.ctf import calc_ctf, apply_ctf
from cryo_sbi.wpa_simulator.image_generation import gen_img, gen_quat
from cryo_sbi.wpa_simulator.noise import add_noise
from cryo_sbi.wpa_simulator.cryo_em_simulator import cryo_em_simulator
from cryo_sbi.wpa_simulator.ctf import apply_ctf
from cryo_sbi.wpa_simulator.image_generation import project_density, gen_quat, gen_rot_matrix
from cryo_sbi.wpa_simulator.noise import add_noise, circular_mask, get_snr
from cryo_sbi.wpa_simulator.normalization import gaussian_normalize_image
from cryo_sbi.wpa_simulator.padding import pad_image
from cryo_sbi.wpa_simulator.shift import apply_no_shift, apply_random_shift
from cryo_sbi.wpa_simulator.validate_image_config import check_params
from cryo_sbi import CryoEmSimulator
from cryo_sbi.inference.priors import get_image_priors


@pytest.fixture
def image_params():
config = json.load(open("tests/config_files/image_params_testing.json"))
check_params(config)
return config
def test_apply_ctf():
# Create a test image
image = torch.randn(1, 64, 64)

# Set test parameters
defocus = torch.tensor([1.0])
b_factor = torch.tensor([100.0])
amp = torch.tensor([0.5])
pixel_size = torch.tensor(1.0)

def test_padding(image_params):
pad_width = int(np.ceil(image_params["N_PIXELS"] * 0.1)) + 1
image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
padded_image = pad_image(image, image_params)
# Apply CTF to the test image
image_ctf = apply_ctf(image, defocus, b_factor, amp, pixel_size)

for size in padded_image.shape:
assert size == pad_width * 2 + image_params["N_PIXELS"]
return
assert image_ctf.shape == image.shape
assert isinstance(image_ctf, torch.Tensor)
assert not torch.allclose(image_ctf, image)


def test_shift_size(image_params):
image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
padded_image = pad_image(image, image_params)
shifted_image = apply_random_shift(padded_image, image_params)
def test_gen_rot_matrix():
# Create a test quaternion
quat = torch.tensor([[1.0, 0.0, 0.0, 0.0]])

for size in shifted_image.shape:
assert size == image_params["N_PIXELS"]
return
# Generate a rotation matrix from the quaternion
rot_matrix = gen_rot_matrix(quat)

assert rot_matrix.shape == torch.Size([1, 3, 3])
assert isinstance(rot_matrix, torch.Tensor)
assert torch.allclose(rot_matrix, torch.eye(3).unsqueeze(0))

def test_shift_bias(image_params):
x_0 = image_params["N_PIXELS"] // 2
y_0 = image_params["N_PIXELS"] // 2

image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
image[x_0, y_0] = 1
image[x_0 - 1, y_0] = 1
image[x_0, y_0 - 1] = 1
image[x_0 - 1, y_0 - 1] = 1
def test_gen_rot_matrix_batched():
# Create a test quaternions with batche size 3
quat = torch.tensor([
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0],
[1.0, 0.0, 0.0, 0.0]
])

padded_image = pad_image(image, image_params)
shifted_image = torch.zeros_like(image)
# Generate a rotation matrix from the quaternion
rot_matrix = gen_rot_matrix(quat)

for _ in range(10000):
shifted_image = shifted_image + apply_random_shift(padded_image, image_params)

indices_x, indices_y = np.where(shifted_image >= 1)

assert np.mean(indices_x) == image_params["N_PIXELS"] / 2 - 0.5
assert np.mean(indices_y) == image_params["N_PIXELS"] / 2 - 0.5

return


def test_no_shift(image_params):
image = torch.zeros((image_params["N_PIXELS"], image_params["N_PIXELS"]))
padded_image = pad_image(image, image_params)
shifted_image = apply_no_shift(padded_image, image_params)

for size in shifted_image.shape:
assert size == image_params["N_PIXELS"]

assert torch.allclose(image, shifted_image)
return


def test_normalization(image_params):
img_shape = (image_params["N_PIXELS"], image_params["N_PIXELS"])
image = torch.distributions.normal.Normal(23, 1.30432).sample(img_shape)
gnormed_image = gaussian_normalize_image(image)

assert torch.allclose(torch.mean(gnormed_image), torch.tensor(0.0), atol=1e-3)
assert torch.allclose(torch.std(gnormed_image), torch.tensor(1.0), atol=1e-3)
return


def test_noise(image_params):
N = 10000
stds = torch.zeros(N)

for i in range(N):
image = torch.ones((image_params["N_PIXELS"], image_params["N_PIXELS"]))
image_noise = add_noise(image, image_params)

stds[i] = torch.std(image_noise)

assert torch.allclose(torch.mean(stds), torch.tensor(1.0), atol=1e-3)
return


# def test_ctf():


def test_simulation(image_params):
simul = CryoEmSimulator("tests/config_files/image_params_testing.json")
image_sim = simul.simulator(index=torch.tensor(0.0), seed=0)

model = np.load(image_params["MODEL_FILE"])[0, 0]
image = gen_img(model, image_params)
image = pad_image(image, image_params)
ctf = calc_ctf(image_params)
image = apply_ctf(image, ctf)
image = add_noise(image, image_params, seed=0)
image = apply_random_shift(image, image_params, seed=0)
image = gaussian_normalize_image(image)
image = image.to(dtype=torch.float32)

assert torch.allclose(image, image_sim)
return
assert rot_matrix.shape == torch.Size([3, 3, 3])
assert isinstance(rot_matrix, torch.Tensor)
assert torch.allclose(rot_matrix, torch.eye(3).repeat(3, 1, 1))

0 comments on commit c0cffdf

Please sign in to comment.