Skip to content

Commit

Permalink
updated docs for random sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Mar 28, 2024
1 parent 14ea227 commit 618d8a5
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 6 deletions.
14 changes: 8 additions & 6 deletions docs/api/samplers.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
# Stochastic Samplers

TBD
`traceax` uses a flexible approach to define how random samples are generated within
[`traceax.AbstractTraceEstimator`][] instances. While this typically wraps a single
jax random call, the varied interfaces for each randomization procedure may differ,
which makes uniformly interfacing with it a bit annoying. As such, we provide a
simple abstract class definition, [`traceax.AbstractSampler`][] using that subclasses
[`Equinox`](https://docs.kidger.site/equinox/) modules.

??? abstract "`traceax.AbstractSampler`"
::: traceax.AbstractSampler
Expand All @@ -9,7 +14,7 @@ TBD
members:
- __call__

# Floating-point Samplers
## Floating-point Samplers

::: traceax.NormalSampler

Expand All @@ -21,13 +26,10 @@ TBD

::: traceax.RademacherSampler

---

# Complex-value Samplers
## Complex-value Samplers
::: traceax.ComplexNormalSampler

---

::: traceax.ComplexSphereSampler

---
53 changes: 53 additions & 0 deletions src/traceax/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,34 +22,87 @@


class AbstractSampler(eqx.Module, strict=True):
"""Abstract base class for all samplers."""

@abstractmethod
def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
r"""Sample random variates from the underlying distribution as an $n \times k$
matrix.
!!! Example
```python
sampler = tr.RademacherSampler()
samples = sampler(key, n, k)
```
**Arguments:**
- `key`: a jax PRNG key used as the random key.
- `n`: the size of the leading dimension.
- `k`: the size of the trailing dimension.
**Returns**:
An Array of random samples.
"""
...


class NormalSampler(AbstractSampler, strict=True):
r"""Standard normal distribution sampler.
Generates samples $X_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
return rdm.normal(key, (n, k))


class SphereSampler(AbstractSampler, strict=True):
r"""Sphere distribution sampler.
Generates samples $X_1, \dotsc, X_n$ uniformly distributed on the surface of a
$k$ dimensional sphere (i.e. $k-1$-sphere) with radius $\sqrt{n}$. Internally,
this operates by sampling standard normal variates, and then rescaling such that
each $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k))
return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0))


class ComplexNormalSampler(AbstractSampler, strict=True):
r"""Standard complex normal distribution sampler.
Generates complex-valued samples $X_{ij} = A_{ij} + i B_{ij}$ where
$A_{ij} \sim N(0, 1)$ and $B_{ij} \sim N(0, 1)$ for $i \in [n]$ and $j \in [k]$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k))
return samples / jnp.sqrt(2)


class ComplexSphereSampler(AbstractSampler, strict=True):
r"""Complex sphere distribution sampler.
Generates complex-valued samples $X_1, \dotsc, X_n$ uniformly distributed on the
surface of a $k$ dimensional complex-valued sphere with radius $\sqrt{n}$. Internally,
this operates by sampling standard complex normal variates, and then rescaling such
that each complex-valued $k$-vector $X_i$ has $\lVert X_i \rVert = \sqrt{n}$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
samples = rdm.normal(key, (n, k)) + 1j * rdm.normal(key, (n, k))
return jnp.sqrt(n) * (samples / jnp.linalg.norm(samples, axis=0))


class RademacherSampler(AbstractSampler, strict=True):
r"""Rademacher distribution sampler.
Generates samples $X_{ij} \sim \mathcal{U}(-1, +1)$ for $i \in [n]$ and $j \in [k]$.
"""

def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Inexact[Array, "n k"]:
return rdm.rademacher(key, (n, k))

0 comments on commit 618d8a5

Please sign in to comment.