Skip to content

Commit

Permalink
improved docs to show defaults. took some tweaking of giffe
Browse files Browse the repository at this point in the history
  • Loading branch information
quattro committed Apr 12, 2024
1 parent 01cdd48 commit 3c41585
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 13 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tr
import traceax as tx

# simulate simple symmetric matrix with exponential eigenvalue decay
seed = 0
Expand All @@ -68,15 +68,15 @@ k = 10
key, key1, key2, key3 = rdm.split(key, 4)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tr.HutchinsonEstimator()
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k)) # (Array(3.7297516, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tr.HutchPlusPlusEstimator()
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k)) # (Array(3.9572973, dtype=float32), {})

# XTrace estimator; default samples uniformly on n-Sphere
xt = tr.XTraceEstimator()
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k)) # (Array(3.1775048, dtype=float32), {'std.err': Array(0.24185811, dtype=float32)})
```

Expand Down
8 changes: 4 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import jax.numpy as jnp
import jax.random as rdm
import lineax as lx

import traceax as tr
import traceax as tx

# simulate simple symmetric matrix with exponential eigenvalue decay
seed = 0
Expand All @@ -68,15 +68,15 @@ k = 10
key, key1, key2, key3 = rdm.split(key, 4)

# Hutchinson estimator; default samples Rademacher {-1,+1}
hutch = tr.HutchinsonEstimator()
hutch = tx.HutchinsonEstimator()
print(hutch.estimate(key1, operator, k)) # (Array(3.7297516, dtype=float32), {})

# Hutch++ estimator; default samples Rademacher {-1,+1}
hpp = tr.HutchPlusPlusEstimator()
hpp = tx.HutchPlusPlusEstimator()
print(hpp.estimate(key2, operator, k)) # (Array(3.9572973, dtype=float32), {})

# XTrace estimator; default samples uniformly on n-Sphere
xt = tr.XTraceEstimator()
xt = tx.XTraceEstimator()
print(xt.estimate(key3, operator, k)) # (Array(3.1775048, dtype=float32), {'std.err': Array(0.24185811, dtype=float32)})
```

Expand Down
25 changes: 20 additions & 5 deletions docs/scripts/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from griffe import Class, Docstring, dynamic_import, Extension, Function, get_logger, Object, ObjectNode
from griffe.dataclasses import Parameter
from griffe.expressions import ExprCall


logger = get_logger(__name__)
Expand Down Expand Up @@ -41,11 +42,25 @@ def on_class_members(self, *, node: ast.AST | ObjectNode, cls: Class) -> None:
return # skip objects that were not selected

# pull class attributes as parameters for the __init__ function...
parameters = [
Parameter(name=attr.name, annotation=attr.annotation, kind=attr.kind)
for attr in cls.members.values()
if attr.is_attribute
]
parameters = []
for attr in cls.members.values():
if attr.is_attribute:
if attr.value is not None:
# import pdb; pdb.set_trace()
if type(attr.value) is ExprCall and len(attr.value.arguments) > 0:
for arg in attr.value.arguments:
if arg.name == "default":
param = Parameter(
name=attr.name, default=arg.value.name, annotation=attr.annotation, kind=attr.kind
)
else:
param = Parameter(
name=attr.name, default=attr.value, annotation=attr.annotation, kind=attr.kind
)
else:
param = Parameter(name=attr.name, annotation=attr.annotation, kind=attr.kind)
parameters.append(param)

# such a huge hack to pull in inherited attributes
cls.members["__init__"] = Function(
name="__init__", parameters=parameters, docstring=_get_dynamic_docstring(cls, "__init__")
Expand Down
1 change: 1 addition & 0 deletions src/traceax/_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int)
```python
key = jax.random.PRNGKey(...)
operator = lx.MatrixLinearOperator(...)
hutch = tx.HutchinsonEstimator()
result = hutch.compute(key, operator, k=10)
# or
result = hutch(key, operator, k=10)
Expand Down
12 changes: 12 additions & 0 deletions src/traceax/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def __call__(self, key: PRNGKeyArray, n: int, k: int) -> Num[Array, "n k"]:
sampler = tr.RademacherSampler()
samples = sampler(key, n, k)
```
Each sampler accepts a `dtype` (i.e. `float`, `complex`, `int`) argument upon initialization,
with sensible default values. This makes it possible to sample from more general spaces (e.g.,
complex Normal test-vectors).
!!! Example
```python
sampler = tr.NormalSampler(complex)
samples = sampler(key, n, k)
```
**Arguments:**
- `key`: a jax PRNG key used as the random key.
Expand Down

0 comments on commit 3c41585

Please sign in to comment.