From fa47ca07bc87c41e15bd847be8938fde6a4aaa2b Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Wed, 21 Mar 2018 10:29:27 -0700 Subject: [PATCH 1/2] removed source_reverse from wmt16_gnmt_8_layer.json, since this hparam has been removed from NMT code, and it caused training and inference failures when use this .json file --- nmt/standard_hparams/wmt16_gnmt_8_layer.json | 1 - 1 file changed, 1 deletion(-) diff --git a/nmt/standard_hparams/wmt16_gnmt_8_layer.json b/nmt/standard_hparams/wmt16_gnmt_8_layer.json index 438ddcf5..da2034ca 100644 --- a/nmt/standard_hparams/wmt16_gnmt_8_layer.json +++ b/nmt/standard_hparams/wmt16_gnmt_8_layer.json @@ -22,7 +22,6 @@ "share_vocab": false, "subword_option": "bpe", "sos": "", - "source_reverse": false, "src_max_len": 50, "src_max_len_infer": null, "steps_per_external_eval": null, From 8c3a240f3ff3ef707637d026dc52d63a2f4ae744 Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Tue, 3 Apr 2018 00:36:36 -0700 Subject: [PATCH 2/2] add command line option to control the num_inter_threads and num_intra_threads for inference session --- nmt/inference.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nmt/inference.py b/nmt/inference.py index 6f589337..cf7924b5 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -131,7 +131,10 @@ def single_worker_inference(infer_model, infer_data = load_data(inference_input_file, hparams) with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads + )) as sess: loaded_infer_model = model_helper.load_model( infer_model.model, ckpt, sess, "infer") sess.run( @@ -190,7 +193,10 @@ def multi_worker_inference(infer_model, infer_data = infer_data[start_position:end_position] with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads + )) as sess: loaded_infer_model = model_helper.load_model( infer_model.model, ckpt, sess, "infer") sess.run(infer_model.iterator.initializer,