From d819cc82e6383491c01fbe23e55e1b08ce59cd78 Mon Sep 17 00:00:00 2001 From: Dingel321 Date: Fri, 21 Jun 2024 16:21:52 -0400 Subject: [PATCH] updated loss --- src/cryo_sbi/inference/losses.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/cryo_sbi/inference/losses.py b/src/cryo_sbi/inference/losses.py index b09f1ce..f31ea09 100644 --- a/src/cryo_sbi/inference/losses.py +++ b/src/cryo_sbi/inference/losses.py @@ -37,20 +37,19 @@ class NPERobustStatsLoss(nn.Module): def __init__(self, estimator: nn.Module, gamma: float): super().__init__() - + print("distance loss") self.estimator = estimator - self.gamma = gamma + self.distance = nn.PairwiseDistance(p=10) def forward(self, theta: torch.Tensor, x: torch.Tensor, x_obs: torch.Tensor) -> torch.Tensor: latent_vecs_x = self.estimator.embedding(x) + self.estimator.eval() latent_vecs_x_obs = self.estimator.embedding(x_obs) + summary_stats_regularization = self.distance(latent_vecs_x, latent_vecs_x_obs).mean() + self.estimator.train() - summary_stats_regularization = self.gamma * mmd_unweighted( - latent_vecs_x, - latent_vecs_x_obs, - median_heuristic(x) - ) log_p = self.estimator.npe(self.estimator.standardize(theta), latent_vecs_x) + print(-log_p.mean().item(), summary_stats_regularization.mean().item()) - return -log_p.mean() + summary_stats_regularization \ No newline at end of file + return -log_p.mean() + 0.0 * summary_stats_regularization \ No newline at end of file