diff --git a/neuralmonkey/decoders/autoregressive.py b/neuralmonkey/decoders/autoregressive.py index 5b39e4f43..3842e3eab 100644 --- a/neuralmonkey/decoders/autoregressive.py +++ b/neuralmonkey/decoders/autoregressive.py @@ -176,7 +176,7 @@ def embedding_size(self) -> int: "size of the reused embeddings from the " "`embeddings_source`.") - return self.embeddings_source.dimension + return self.embeddings_source.embedding_matrix.get_shape()[1].value @tensor def go_symbols(self) -> tf.Tensor: diff --git a/neuralmonkey/decoders/sequence_labeler.py b/neuralmonkey/decoders/sequence_labeler.py index 9d9931a1d..90dc16261 100644 --- a/neuralmonkey/decoders/sequence_labeler.py +++ b/neuralmonkey/decoders/sequence_labeler.py @@ -23,7 +23,7 @@ def __init__(self, name: str, encoder: TemporalStateful, data_id: str, - vocabulary: Vocabulary, + vocabulary: Vocabulary = None, embeddings_source: EmbeddedSequence = None, dropout_keep_prob: float = 1.0, reuse: ModelPart = None, @@ -40,12 +40,11 @@ def __init__(self, self.data_id = data_id self.dropout_keep_prob = dropout_keep_prob - # We provide only embedding_source when we want to input and output + # We provide only embedding_source when we want to tie input and output # projections - if self.embeddings_source is not None and self.vocabulary is not None: - warn("Both `vocabulary` and `embedding_source` was provided. 
" - "using `embedding_source.vocabulary` instead of provided " - "`vocabulary`") + if (self.embeddings_source is None) == (self.vocabulary is None): + raise ValueError("You must specify either `vocabulary` or " + "`embeddings_source`, not both") self.vocabulary = self.embeddings_source.vocabulary # pylint: enable=too-many-arguments diff --git a/neuralmonkey/readers/string_vector_reader.py b/neuralmonkey/readers/string_vector_reader.py index d6545b2a3..439a23838 100644 --- a/neuralmonkey/readers/string_vector_reader.py +++ b/neuralmonkey/readers/string_vector_reader.py @@ -13,7 +13,7 @@ def process_line(line: str, lineno: int, path: str) -> np.ndarray: return np.array(numbers, dtype=dtype) - def reader(files: List[str])-> Iterable[List[np.ndarray]]: + def reader(files: List[str]) -> Iterable[List[np.ndarray]]: for path in files: current_line = 0 diff --git a/tests/bert.ini b/tests/bert.ini index 2b3e0b7cc..7c3a20c46 100644 --- a/tests/bert.ini +++ b/tests/bert.ini @@ -66,7 +66,6 @@ dropout_keep_prob=0.9 class=decoders.sequence_labeler.SequenceLabeler name="labeler_bert" encoder= -vocabulary= data_id="source_masked" dropout_keep_prob=0.5 embeddings_source=