Skip to content

Commit

Permalink
renamed to for
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Mar 27, 2024
1 parent a6e9f17 commit 14ea227
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/traceax/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ class HutchinsonEstimator(AbstractTraceEstimator):
$\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$,
where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
!!! info
"""

sampler: AbstractSampler = RademacherSampler()
Expand Down Expand Up @@ -94,8 +91,6 @@ class HutchPlusPlusEstimator(AbstractTraceEstimator):
As with the Girard-Hutchinson estimator, it requires
$\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$.
!!! info
"""

sampler: AbstractSampler = RademacherSampler()
Expand Down Expand Up @@ -133,10 +128,14 @@ def compute(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -


class XTraceEstimator(AbstractTraceEstimator):
r""" """
r"""XTrace Trace Estimator:
TBD.
"""

sampler: AbstractSampler = SphereSampler()
rescale: bool = True
improved: bool = True

def compute(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]:
n = _get_shape(operator)
Expand Down Expand Up @@ -168,7 +167,7 @@ def compute(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -
term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0)

re_vals = n - jnp.linalg.norm(W, axis=0) ** 2 + jnp.abs(SW_d * jnp.linalg.norm(S, axis=0)) ** 2
scale = jnp.where(self.rescale, (n - m + 1) / re_vals, 1.0)
scale = jnp.where(self.improved, (n - m + 1) / re_vals, 1.0)

estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale
trace_est = jnp.mean(estimates)
Expand All @@ -180,5 +179,6 @@ def compute(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -
XTraceEstimator.__init__.__doc__ = r"""**Arguments:**
- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][].
- `rescale`: Whether to rescale samples for _improved_ XTrace estimator (see Notes).
- `improved`: Whether to use the _improved_ XTrace estimator, which rescales predicted samples.
Default is `True` (see Notes).
"""

0 comments on commit 14ea227

Please sign in to comment.