Skip to content

Commit

Permalink
cleaned up repo added plot function for models
Browse files Browse the repository at this point in the history
  • Loading branch information
Dingel321 committed Mar 7, 2024
1 parent b96a553 commit 0f8be74
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 4 deletions.
3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,3 @@ dependencies = [

[project.scripts]
train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving"
train_npe_model_vram = "cryo_sbi.inference.command_line_tools:cl_npe_train_from_vram"
train_npe_model_disk = "cryo_sbi.inference.command_line_tools:cl_npe_train_from_disk"
generate_training_data = "cryo_sbi.inference.command_line_tools:cl_generate_training_data"
15 changes: 14 additions & 1 deletion src/cryo_sbi/utils/estimator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,20 @@ def evaluate_log_prob(
batch_size: int = 0,
device: str = "cpu",
) -> torch.Tensor:
"""
Evaluates the log probability of a given set of images under a given estimator.
Args:
estimator (torch.nn.Module): The posterior model to use for evaluation.
images (torch.Tensor): The input images used to condition the posterior.
theta (torch.Tensor): The parameter values at which to evaluate the log probability.
batch_size (int, optional): The batch size for batching the images. Defaults to 0.
device (str, optional): The device to use for computation. Defaults to "cpu".
Returns:
torch.Tensor: The log probabilities of the images under the estimator.
"""

# batching images if necessary
if images.shape[0] > batch_size and batch_size > 0:
images = torch.split(images, split_size_or_sections=batch_size, dim=0)
Expand Down Expand Up @@ -53,7 +66,7 @@ def sample_posterior(
Args:
estimator (torch.nn.Module): The posterior to use for sampling.
images (torch.Tensor): The images used to condition the posterio.
images (torch.Tensor): The images used to condition the posterior.
num_samples (int): The number of samples to draw
batch_size (int, optional): The batch size for sampling. Defaults to 100.
device (str, optional): The device to use. Defaults to "cpu".
Expand Down
73 changes: 73 additions & 0 deletions src/cryo_sbi/utils/visualize_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import matplotlib.pyplot as plt
import numpy as np
import torch


def _scatter_plot_models(model: torch.Tensor, view_angles : tuple = (30, 45), **plot_kwargs: dict) -> None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(*view_angles)

ax.scatter(*model, **plot_kwargs)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')


def _sphere_plot_models(model: torch.Tensor, radius: float = 4, view_angles : tuple = (30, 45), **plot_kwargs: dict,) -> None:
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.view_init(30, 45)

spheres = []
for x, y, z in zip(model[0], model[1], model[2]):
spheres.append((x.item(), y.item(), z.item(), radius))

for idx, sphere in enumerate(spheres):
x, y, z, r = sphere

u = np.linspace(0, 2 * np.pi, 100)
v = np.linspace(0, np.pi, 100)
x = r * np.outer(np.cos(u), np.sin(v)) + x
y = r * np.outer(np.sin(u), np.sin(v)) + y
z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + z

ax.plot_surface(x, y, z, **plot_kwargs)

ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')


def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None:
"""
Plot a model from the tensor.
Args:
model (torch.Tensor): Model to plot, should be a 2D tensor with shape (3, num_atoms)
method (str, optional): Method to use for plotting. Defaults to "scatter". Can be "scatter" or "sphere".
"scatter" is fast and simple, "sphere" is a proper 3D representation (Take long to render).
**kwargs: Additional keyword arguments to pass to the plotting function.
Returns:
None
Raises:
AssertionError: If the model is not a 2D tensor with shape (3, num_atoms).
ValueError: If the method is not "scatter" or "sphere".
"""

assert model.ndim == 2, "Model should be 2D tensor"
assert model.shape[0] == 3, "Model should have 3 rows"

if method == "scatter":
_scatter_plot_models(model, **kwargs)

elif method == "sphere":
_sphere_plot_models(model, **kwargs)

else:
raise ValueError(f"Unknown method {method}. Use 'scatter' or 'sphere'.")

25 changes: 25 additions & 0 deletions tests/test_visualize_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import torch
import pytest
from cryo_sbi.utils.visualize_models import plot_model


def test_plot_model_scatter():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
plot_model(model, method="scatter") # No assertion, just checking if it runs without errors


def test_plot_model_sphere():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
plot_model(model, method="sphere") # No assertion, just checking if it runs without errors


def test_plot_model_invalid_model():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # Invalid shape, should have 3 rows
with pytest.raises(AssertionError):
plot_model(model, method="scatter")


def test_plot_model_invalid_method():
model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
with pytest.raises(ValueError):
plot_model(model, method="invalid_method")

0 comments on commit 0f8be74

Please sign in to comment.