Skip to content

Commit

Permalink
eval tmp fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Mar 5, 2024
2 parents 8f4d969 + b8bfb32 commit eb737b6
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions apax/train/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=Fal
0, test_steps_per_epoch, desc="Batches", ncols=100, disable=False, leave=True
)
for batch_idx in range(test_steps_per_epoch):
inputs, labels = next(batch_test_ds)
batch = next(batch_test_ds)

batch_loss, test_metrics = test_step_fn(params, inputs, labels, test_metrics)
batch_loss, test_metrics = test_step_fn(params, batch, test_metrics)

epoch_loss["test_loss"] += batch_loss
batch_pbar.set_postfix(test_loss=epoch_loss["test_loss"] / batch_idx)
Expand Down Expand Up @@ -123,7 +123,8 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"):

raw_ds = load_test_data(config, model_version_path, eval_path, n_test)

test_ds = initialize_dataset(config, raw_ds, read_labels=False, calc_stats=False)
test_ds = initialize_dataset(config, raw_ds, read_labels=True, calc_stats=False)
test_ds.set_batch_size(1) # TODO temporary

_, init_box = test_ds.init_input()

Expand Down

0 comments on commit eb737b6

Please sign in to comment.