Migrate docs to the new two-step benchmark idiom
This replaces the previous zero-state BenchmarkRunner with the new "collect-then-run" idiom
in the examples and documentation.

In most of the docs, this is only a syntax change, but some formulations had to be tweaked,
since there is no longer a single runner resource.
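
In code, the migration amounts to the following (an illustrative sketch only; the file path and parameters are placeholders):

```python
import nnbench

# Before: a stateful BenchmarkRunner collected and ran benchmarks.
# runner = nnbench.BenchmarkRunner()
# result = runner.run("benchmarks.py", params={"a": 1})

# After: collect the benchmarks first, then run the collected set.
benchmarks = nnbench.collect("benchmarks.py")
result = nnbench.run(benchmarks, params={"a": 1})
```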
nicholasjng committed Dec 21, 2024
1 parent 596dc81 commit c58978c
Showing 10 changed files with 38 additions and 46 deletions.
8 changes: 4 additions & 4 deletions docs/guides/customization.md
@@ -79,8 +79,8 @@ To supply context to your benchmarks, you can give a sequence of context providers
import nnbench

# uses the `platinfo` context provider from above to log platform metadata.
runner = nnbench.BenchmarkRunner()
result = runner.run(__name__, params={}, context=[platinfo])
benchmarks = nnbench.collect(__name__)
result = nnbench.run(benchmarks, params={}, context=[platinfo])
```

## Being type safe by using `nnbench.Parameters`
@@ -104,8 +104,8 @@ def prod(a: int, b: int) -> int:


params = MyParams(a=1, b=2)
runner = nnbench.BenchmarkRunner()
result = runner.run(__name__, params=params)
benchmarks = nnbench.collect(__name__)
result = nnbench.run(benchmarks, params=params)
```

While this does not have a concrete advantage in terms of type safety over a raw dictionary, it guards against accidental modification of parameters breaking reproducibility.
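
As a small illustration of that guard, here is a sketch assuming the `MyParams` class above is declared as a frozen dataclass (an assumption for demonstration, re-declared here to keep the snippet self-contained; it is not part of this commit):

```python
from dataclasses import FrozenInstanceError, dataclass

import nnbench


@dataclass(frozen=True)  # assumption: parameter classes are declared as frozen dataclasses
class MyParams(nnbench.Parameters):
    a: int
    b: int


params = MyParams(a=1, b=2)
try:
    params.a = 10  # accidental modification of a benchmark parameter
except FrozenInstanceError:
    # the frozen dataclass rejects the mutation, keeping the run reproducible
    pass
```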
4 changes: 2 additions & 2 deletions docs/guides/organization.md
@@ -75,8 +75,8 @@ Now, to only run data quality benchmarks marked "foo", pass the corresponding tag
```python
import nnbench

runner = nnbench.BenchmarkRunner()
foo_data_metrics = runner.run("benchmarks/data_quality.py", params=..., tags=("foo",))
benchmarks = nnbench.collect("benchmarks/data_quality.py", tags=("foo",))
foo_data_metrics = nnbench.run(benchmarks, params=...)
```

!!!tip
32 changes: 13 additions & 19 deletions docs/guides/runners.md
@@ -1,17 +1,8 @@
# Collecting and running benchmarks

nnbench provides the `BenchmarkRunner` as a compact interface to collect and run benchmarks selectively.
nnbench provides the `nnbench.collect` and `nnbench.run` APIs as a compact interface to collect and run benchmarks selectively.

## The abstract `BenchmarkRunner` class
Let's first instantiate and then walk through the base class.

```python
from nnbench import BenchmarkRunner

runner = BenchmarkRunner()
```

Use the `BenchmarkRunner.collect()` method to collect benchmarks from files or directories.
Use the `nnbench.collect()` function to collect benchmarks from files or directories.
Assume we have the following benchmark setup:
```python
# dir_a/bm1.py
@@ -46,26 +37,29 @@ def the_last_benchmark(d: int) -> int:
Now we can collect benchmarks from files:

```python
runner.collect('dir_a/bm1.py')
import nnbench


benchmarks = nnbench.collect('dir_a/bm1.py')
```
Or directories:

```python
runner.collect('dir_b')
benchmarks = nnbench.collect('dir_b')
```

This collection can happen iteratively. So, after executing the two collections our runner has all four benchmarks ready for execution.

To remove the collected benchmarks again, use the `BenchmarkRunner.clear()` method.
You can also supply tags to the runner to selectively collect only benchmarks with the appropriate tag.
For example, after clearing the runner again, you can collect all benchmarks with the `"tag"` tag as such:

```python
runner.collect('dir_b', tags=("tag",))
import nnbench


tagged_benchmarks = nnbench.collect('dir_b', tags=("tag",))
```

To run the benchmarks, call the `BenchmarkRunner.run()` method and supply the necessary parameters required by the collected benchmarks.
To run the benchmarks, call the `nnbench.run()` function and supply the necessary parameters required by the collected benchmarks.

```python
runner.run("dir_b", params={"b": 1, "c": 2, "d": 3})
result = nnbench.run(benchmarks, params={"b": 1, "c": 2, "d": 3})
```
4 changes: 2 additions & 2 deletions docs/tutorials/duckdb.md
@@ -23,8 +23,8 @@ import nnbench
from nnbench.context import GitEnvironmentInfo
from nnbench.reporter.file import FileReporter

runner = nnbench.BenchmarkRunner()
record = runner.run("benchmarks.py", params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))
benchmarks = nnbench.collect("benchmarks.py")
record = nnbench.run(benchmarks, params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))

file_reporter = FileReporter()
file_reporter.write(record, "record.json", driver="ndjson")
2 changes: 1 addition & 1 deletion docs/tutorials/huggingface.md
@@ -91,7 +91,7 @@ In the following `IndexLabelMapMemo` class, we store a dictionary mapping the labels

!!! Info
There is no need to type-hint `TokenClassificationModelMemo`s in the corresponding benchmarks -
the benchmark runner takes care of filling in the memoized values for the memos themselves.
the `nnbench.run()` function takes care of filling in the memoized values for the memos themselves.

Because we implemented our memoized values as four different memo class types, this modularizes the benchmark input parameters -
we only need to reference memos when they are actually used. Considering the recall benchmarks:
4 changes: 2 additions & 2 deletions docs/tutorials/mnist.md
@@ -16,14 +16,14 @@ To properly structure our project, we avoid mixing training pipeline code and benchmark code

This definition is short and sweet, and contains a few important details:

* Both functions are given the `@nnbench.benchmark` decorator - this enables our runner to find and collect them before starting the benchmark run.
* Both functions are given the `@nnbench.benchmark` decorator - this allows us to find and collect them before starting the benchmark run.
* The `modelsize` benchmark is given a custom name (`"Model size (MB)"`), indicating that the resulting number is the combined size of the model weights in megabytes.
This is done for display purposes, to improve interpretability when reporting results.
* The `params` argument is the same in both benchmarks, both in name and type. This is important, since it ensures that both benchmarks will be run with the same model weights.
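
For orientation, such definitions might look roughly like the following sketch, assuming the decorator can be applied both bare and with a `name` argument as the bullets above describe (function bodies and parameter types are illustrative assumptions, not the tutorial's actual code):

```python
import nnbench


@nnbench.benchmark
def accuracy(params, data) -> float:
    # illustrative placeholder: evaluate the model weights in `params` on `data`.
    ...


@nnbench.benchmark(name="Model size (MB)")
def modelsize(params) -> float:
    # illustrative placeholder: sum the sizes of all weight arrays in `params`, in megabytes.
    ...
```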

That's all - now we can shift over to our main pipeline code and see what is necessary to execute the benchmarks and visualize the results.

## Setting up a benchmark runner and parameters
## Setting up a benchmark run and parameters

After finishing the benchmark setup, we only need a few more lines to augment our pipeline.

4 changes: 2 additions & 2 deletions examples/bq/bq.py
@@ -14,8 +14,8 @@ def main():
autodetect=True, source_format=bigquery.SourceFormat.NEWLINE_DELIMITED_JSON
)

runner = nnbench.BenchmarkRunner()
res = runner.run("benchmarks.py", params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))
benchmarks = nnbench.collect("benchmarks.py")
res = nnbench.run(benchmarks, params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))

load_job = client.load_table_from_json(res.to_json(), table_id, job_config=job_config)
load_job.result()
4 changes: 2 additions & 2 deletions examples/huggingface/runner.py
@@ -2,9 +2,9 @@


def main() -> None:
runner = nnbench.BenchmarkRunner()
benchmarks = nnbench.collect("benchmark.py", tags=("per-class",))
reporter = nnbench.ConsoleReporter()
result = runner.run("benchmark.py", tags=("per-class",))
result = nnbench.run(benchmarks)
reporter.display(result)


4 changes: 2 additions & 2 deletions examples/mnist/mnist.py
@@ -216,10 +216,10 @@ def mnist_jax():
state, data = train(mnist)

# the nnbench portion.
runner = nnbench.BenchmarkRunner()
benchmarks = nnbench.collect(HERE)
reporter = nnbench.FileReporter()
params = MNISTTestParameters(params=state.params, data=data)
result = runner.run(HERE, params=params)
result = nnbench.run(benchmarks, params=params)
reporter.write(result, "result.json")


18 changes: 8 additions & 10 deletions examples/prefect/src/runner.py
@@ -31,10 +31,9 @@ async def write(
def run_metric_benchmarks(
model: base.BaseEstimator, X_test: np.ndarray, y_test: np.ndarray
) -> nnbench.types.BenchmarkRecord:
runner = nnbench.BenchmarkRunner()
results = runner.run(
os.path.join(dir_path, "benchmark.py"),
tags=("metric",),
benchmarks = nnbench.collect(os.path.join(dir_path, "benchmark.py"), tags=("metric",))
results = nnbench.run(
benchmarks,
params={"model": model, "X_test": X_test, "y_test": y_test},
)
return results
@@ -44,10 +43,9 @@ def run_metric_benchmarks(
def run_metadata_benchmarks(
model: base.BaseEstimator, X: np.ndarray
) -> nnbench.types.BenchmarkRecord:
runner = nnbench.BenchmarkRunner()
result = runner.run(
os.path.join(dir_path, "benchmark.py"),
tags=("model-meta",),
benchmarks = nnbench.collect(os.path.join(dir_path, "benchmark.py"), tags=("model-meta",))
result = nnbench.run(
benchmarks,
params={"model": model, "X": X},
)
return result
@@ -73,7 +71,7 @@ async def train_and_benchmark(
metadata_results: types.BenchmarkRecord = run_metadata_benchmarks(model=model, X=X_test)

metadata_results.context.update(data_params)
metadata_results.context.update(context.PythonInfo())
metadata_results.context.update(context.PythonInfo()())

await reporter.write(
record=metadata_results, key="model-attributes", description="Model Attributes"
@@ -84,7 +82,7 @@
)

metric_results.context.update(data_params)
metric_results.context.update(context.PythonInfo())
metric_results.context.update(context.PythonInfo()())
await reporter.write(metric_results, key="model-performance", description="Model Performance")
return metadata_results, metric_results

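One detail in this last diff worth spelling out: context providers such as `context.PythonInfo` are instantiated and then called to produce the actual context dictionary, which is why the updates now read `context.PythonInfo()()`. A minimal sketch of that pattern (the import form is an assumption; everything else mirrors the diff):

```python
from nnbench import context  # assumed import form; the example may import this differently

provider = context.PythonInfo()  # instantiate the context provider
python_info = provider()         # calling the instance yields the context dictionary

# equivalent to the one-liner used above:
# metric_results.context.update(context.PythonInfo()())
print(python_info)
```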
