From 0f8be74e597abc181f3e5ee73688c2ae0d70942e Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Thu, 7 Mar 2024 14:27:03 +0100 Subject: [PATCH] cleaned up repo added plot function for models --- pyproject.toml | 3 -- src/cryo_sbi/utils/estimator_utils.py | 15 +++++- src/cryo_sbi/utils/visualize_models.py | 73 ++++++++++++++++++++++++++ tests/test_visualize_models.py | 25 +++++++++ 4 files changed, 112 insertions(+), 4 deletions(-) create mode 100644 src/cryo_sbi/utils/visualize_models.py create mode 100644 tests/test_visualize_models.py diff --git a/pyproject.toml b/pyproject.toml index f068847..bfae890 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,3 @@ dependencies = [ [project.scripts] train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving" -train_npe_model_vram = "cryo_sbi.inference.command_line_tools:cl_npe_train_from_vram" -train_npe_model_disk = "cryo_sbi.inference.command_line_tools:cl_npe_train_from_disk" -generate_training_data = "cryo_sbi.inference.command_line_tools:cl_generate_training_data" diff --git a/src/cryo_sbi/utils/estimator_utils.py b/src/cryo_sbi/utils/estimator_utils.py index 6828fb8..acefe40 100644 --- a/src/cryo_sbi/utils/estimator_utils.py +++ b/src/cryo_sbi/utils/estimator_utils.py @@ -11,7 +11,20 @@ def evaluate_log_prob( batch_size: int = 0, device: str = "cpu", ) -> torch.Tensor: + """ + Evaluates the log probability of a given set of images under a given estimator. + + Args: + estimator (torch.nn.Module): The posterior model to use for evaluation. + images (torch.Tensor): The input images used to condition the posterior. + theta (torch.Tensor): The parameter values at which to evaluate the log probability. + batch_size (int, optional): The batch size for batching the images. Defaults to 0. + device (str, optional): The device to use for computation. Defaults to "cpu". + Returns: + torch.Tensor: The log probabilities of the images under the estimator. + """ + # batching images if necessary if images.shape[0] > batch_size and batch_size > 0: images = torch.split(images, split_size_or_sections=batch_size, dim=0) @@ -53,7 +66,7 @@ def sample_posterior( Args: estimator (torch.nn.Module): The posterior to use for sampling. - images (torch.Tensor): The images used to condition the posterio. + images (torch.Tensor): The images used to condition the posterior. num_samples (int): The number of samples to draw batch_size (int, optional): The batch size for sampling. Defaults to 100. device (str, optional): The device to use. Defaults to "cpu". diff --git a/src/cryo_sbi/utils/visualize_models.py b/src/cryo_sbi/utils/visualize_models.py new file mode 100644 index 0000000..c4144c6 --- /dev/null +++ b/src/cryo_sbi/utils/visualize_models.py @@ -0,0 +1,73 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + + +def _scatter_plot_models(model: torch.Tensor, view_angles : tuple = (30, 45), **plot_kwargs: dict) -> None: + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.view_init(*view_angles) + + ax.scatter(*model, **plot_kwargs) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + + +def _sphere_plot_models(model: torch.Tensor, radius: float = 4, view_angles : tuple = (30, 45), **plot_kwargs: dict,) -> None: + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.view_init(30, 45) + + spheres = [] + for x, y, z in zip(model[0], model[1], model[2]): + spheres.append((x.item(), y.item(), z.item(), radius)) + + for idx, sphere in enumerate(spheres): + x, y, z, r = sphere + + u = np.linspace(0, 2 * np.pi, 100) + v = np.linspace(0, np.pi, 100) + x = r * np.outer(np.cos(u), np.sin(v)) + x + y = r * np.outer(np.sin(u), np.sin(v)) + y + z = r * np.outer(np.ones(np.size(u)), np.cos(v)) + z + + ax.plot_surface(x, y, z, **plot_kwargs) + + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_zlabel('Z') + + +def plot_model(model: torch.Tensor, method: str = "scatter", **kwargs) -> None: + """ + Plot a model from the tensor. + + Args: + model (torch.Tensor): Model to plot, should be a 2D tensor with shape (3, num_atoms) + method (str, optional): Method to use for plotting. Defaults to "scatter". Can be "scatter" or "sphere". + "scatter" is fast and simple, "sphere" is a proper 3D representation (Take long to render). + **kwargs: Additional keyword arguments to pass to the plotting function. + + Returns: + None + + Raises: + AssertionError: If the model is not a 2D tensor with shape (3, num_atoms). + ValueError: If the method is not "scatter" or "sphere". + + """ + + assert model.ndim == 2, "Model should be 2D tensor" + assert model.shape[0] == 3, "Model should have 3 rows" + + if method == "scatter": + _scatter_plot_models(model, **kwargs) + + elif method == "sphere": + _sphere_plot_models(model, **kwargs) + + else: + raise ValueError(f"Unknown method {method}. Use 'scatter' or 'sphere'.") + diff --git a/tests/test_visualize_models.py b/tests/test_visualize_models.py new file mode 100644 index 0000000..108d433 --- /dev/null +++ b/tests/test_visualize_models.py @@ -0,0 +1,25 @@ +import torch +import pytest +from cryo_sbi.utils.visualize_models import plot_model + + +def test_plot_model_scatter(): + model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + plot_model(model, method="scatter") # No assertion, just checking if it runs without errors + + +def test_plot_model_sphere(): + model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + plot_model(model, method="sphere") # No assertion, just checking if it runs without errors + + +def test_plot_model_invalid_model(): + model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) # Invalid shape, should have 3 rows + with pytest.raises(AssertionError): + plot_model(model, method="scatter") + + +def test_plot_model_invalid_method(): + model = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]) + with pytest.raises(ValueError): + plot_model(model, method="invalid_method") \ No newline at end of file