From 0f5649cc076083584595000a03c4a8694cab8bfa Mon Sep 17 00:00:00 2001 From: Dusan Varis Date: Wed, 9 Jan 2019 16:29:51 +0100 Subject: [PATCH] workaround for train_set batching during inference time --- neuralmonkey/learning_utils.py | 5 ++++- tests/hier-multiattention.ini | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/neuralmonkey/learning_utils.py b/neuralmonkey/learning_utils.py index 50e0e0711..102eafb42 100644 --- a/neuralmonkey/learning_utils.py +++ b/neuralmonkey/learning_utils.py @@ -13,7 +13,7 @@ from termcolor import colored from neuralmonkey.logging import log, log_print, warn -from neuralmonkey.dataset import Dataset +from neuralmonkey.dataset import Dataset, BatchingScheme from neuralmonkey.tf_manager import TensorFlowManager from neuralmonkey.runners.base_runner import ( BaseRunner, ExecutionResult, GraphExecutor, OutputSeries) @@ -85,6 +85,9 @@ def training_loop(cfg: Namespace) -> None: trainer_result = cfg.tf_manager.execute( batch, feedables, cfg.trainers, train=True, summaries=True) + # workaround: we need to use validation batching scheme + # during evaluation + batch.batching = BatchingScheme(batch_size=cfg.batch_size) train_results, train_outputs, f_batch = run_on_dataset( cfg.tf_manager, cfg.runners, cfg.dataset_runner, batch, cfg.postprocess, write_out=False) diff --git a/tests/hier-multiattention.ini b/tests/hier-multiattention.ini index f4a4b5c68..205ba281e 100644 --- a/tests/hier-multiattention.ini +++ b/tests/hier-multiattention.ini @@ -10,6 +10,7 @@ trainer= runners=[, , , ] postprocess=None evaluation=[("target_hier_noshare_nosentinel", "target", evaluators.BLEU)] +batch_size=1 logging_period=1 validation_period=5 test_datasets=[]