-
Notifications
You must be signed in to change notification settings - Fork 103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize sequence labeler and allow re-use embeddings for labeling #798
Changes from 2 commits
2a4aea2
76c6f11
86bca58
e51a98e
1bdbfaf
1a9d54a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,29 +1,34 @@ | ||
from typing import Dict, Union | ||
from typing import List, Dict, Callable | ||
|
||
import tensorflow as tf | ||
from typeguard import check_argument_types | ||
|
||
from neuralmonkey.dataset import Dataset | ||
from neuralmonkey.decorators import tensor | ||
from neuralmonkey.encoders.recurrent import RecurrentEncoder | ||
from neuralmonkey.encoders.facebook_conv import SentenceEncoder | ||
from neuralmonkey.model.stateful import TemporalStateful | ||
from neuralmonkey.model.feedable import FeedDict | ||
from neuralmonkey.model.parameterized import InitializerSpecs | ||
from neuralmonkey.model.model_part import ModelPart | ||
from neuralmonkey.tf_utils import get_variable | ||
from neuralmonkey.model.sequence import EmbeddedSequence | ||
from neuralmonkey.nn.utils import dropout | ||
from neuralmonkey.vocabulary import Vocabulary, pad_batch, sentence_mask | ||
|
||
|
||
class SequenceLabeler(ModelPart): | ||
"""Classifier assing a label to each encoder's state.""" | ||
|
||
# pylint: disable=too-many-arguments | ||
# pylint: disable=too-many-arguments,too-many-locals | ||
def __init__(self, | ||
name: str, | ||
encoder: Union[RecurrentEncoder, SentenceEncoder], | ||
encoders: List[TemporalStateful], | ||
vocabulary: Vocabulary, | ||
data_id: str, | ||
max_output_len: int = None, | ||
hidden_dim: int = None, | ||
activation: Callable = tf.nn.relu, | ||
dropout_keep_prob: float = 1.0, | ||
add_start_symbol: bool = False, | ||
add_end_symbol: bool = False, | ||
reuse: ModelPart = None, | ||
save_checkpoint: str = None, | ||
load_checkpoint: str = None, | ||
|
@@ -32,11 +37,16 @@ def __init__(self, | |
ModelPart.__init__(self, name, reuse, save_checkpoint, load_checkpoint, | ||
initializers) | ||
|
||
self.encoder = encoder | ||
self.encoders = encoders | ||
self.vocabulary = vocabulary | ||
self.data_id = data_id | ||
self.max_output_len = max_output_len | ||
self.hidden_dim = hidden_dim | ||
self.activation = activation | ||
self.dropout_keep_prob = dropout_keep_prob | ||
# pylint: enable=too-many-arguments | ||
self.add_start_symbol = add_start_symbol | ||
self.add_end_symbol = add_end_symbol | ||
# pylint: enable=too-many-arguments,too-many-locals | ||
|
||
@property | ||
def input_types(self) -> Dict[str, tf.DType]: | ||
|
@@ -46,70 +56,58 @@ def input_types(self) -> Dict[str, tf.DType]: | |
def input_shapes(self) -> Dict[str, tf.TensorShape]: | ||
return {self.data_id: tf.TensorShape([None, None])} | ||
|
||
@tensor | ||
def input_mask(self) -> tf.Tensor: | ||
mask_main = self.encoders[0].temporal_mask | ||
|
||
asserts = [ | ||
tf.assert_equal( | ||
mask_main, enc.temporal_mask, | ||
message=("Encoders '{}' and '{}' does not have equal temporal " | ||
"masks.".format(self.encoders[0].name, enc.name))) | ||
for enc in self.encoders[1:]] | ||
|
||
with tf.control_dependencies(asserts): | ||
return mask_main | ||
|
||
@tensor | ||
def target_tokens(self) -> tf.Tensor: | ||
return self.dataset[self.data_id] | ||
|
||
@tensor | ||
def train_targets(self) -> tf.Tensor: | ||
return self.vocabulary.strings_to_indices(self.target_tokens) | ||
return self.vocabulary.strings_to_indices( | ||
self.dataset[self.data_id]) | ||
|
||
@tensor | ||
def train_mask(self) -> tf.Tensor: | ||
return sentence_mask(self.train_targets) | ||
|
||
@property | ||
def rnn_size(self) -> int: | ||
return int(self.encoder.temporal_states.get_shape()[-1]) | ||
|
||
@tensor | ||
def decoding_w(self) -> tf.Variable: | ||
return get_variable( | ||
name="state_to_word_W", | ||
shape=[self.rnn_size, len(self.vocabulary)]) | ||
|
||
@tensor | ||
def decoding_b(self) -> tf.Variable: | ||
return get_variable( | ||
name="state_to_word_b", | ||
shape=[len(self.vocabulary)], | ||
initializer=tf.zeros_initializer()) | ||
def concatenated_inputs(self) -> tf.Tensor: | ||
# Validate shapes first | ||
with tf.control_dependencies(self.input_mask): | ||
return tf.concat( | ||
[inp.temporal_states for inp in self.encoders], axis=2) | ||
|
||
@tensor | ||
def decoding_residual_w(self) -> tf.Variable: | ||
input_dim = self.encoder.input_sequence.dimension | ||
return get_variable( | ||
name="emb_to_word_W", | ||
shape=[input_dim, len(self.vocabulary)]) | ||
def states(self) -> tf.Tensor: | ||
states = dropout( | ||
self.concatenated_inputs, self.dropout_keep_prob, self.train_mode) | ||
|
||
jindrahelcl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self.hidden_dim is not None: | ||
hidden = tf.layers.dense( | ||
states, self.hidden_dim, self.activation, | ||
name="hidden_layer") | ||
# pylint: disable=redefined-variable-type | ||
states = dropout(hidden, self.dropout_keep_prob, self.train_mode) | ||
# pylint: enable=redefined-variable-type | ||
return states | ||
|
||
@tensor | ||
def logits(self) -> tf.Tensor: | ||
# To multiply 3-D matrix (encoder hidden states) by a 2-D matrix | ||
# (weights), we use 1-by-1 convolution (similar trick can be found in | ||
# attention computation) | ||
|
||
# TODO dropout needs to be revisited | ||
|
||
encoder_states = tf.expand_dims(self.encoder.temporal_states, 2) | ||
weights_4d = tf.expand_dims(tf.expand_dims(self.decoding_w, 0), 0) | ||
|
||
multiplication = tf.nn.conv2d( | ||
encoder_states, weights_4d, [1, 1, 1, 1], "SAME") | ||
multiplication_3d = tf.squeeze(multiplication, axis=[2]) | ||
|
||
biases_3d = tf.expand_dims(tf.expand_dims(self.decoding_b, 0), 0) | ||
|
||
embedded_inputs = tf.expand_dims( | ||
self.encoder.input_sequence.temporal_states, 2) | ||
dweights_4d = tf.expand_dims( | ||
tf.expand_dims(self.decoding_residual_w, 0), 0) | ||
|
||
dmultiplication = tf.nn.conv2d( | ||
embedded_inputs, dweights_4d, [1, 1, 1, 1], "SAME") | ||
dmultiplication_3d = tf.squeeze(dmultiplication, axis=[2]) | ||
|
||
logits = multiplication_3d + dmultiplication_3d + biases_3d | ||
return logits | ||
return tf.layers.dense( | ||
jindrahelcl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.states, len(self.vocabulary), name="logits") | ||
|
||
@tensor | ||
def logprobs(self) -> tf.Tensor: | ||
|
@@ -120,14 +118,19 @@ def decoded(self) -> tf.Tensor: | |
return tf.argmax(self.logits, 2) | ||
|
||
@tensor | ||
def cost(self) -> tf.Tensor: | ||
def train_xents(self) -> tf.Tensor: | ||
loss = tf.nn.sparse_softmax_cross_entropy_with_logits( | ||
labels=self.train_targets, logits=self.logits) | ||
|
||
# loss is now of shape [batch, time]. Need to mask it now by | ||
# element-wise multiplication with weights placeholder | ||
weighted_loss = loss * self.train_mask | ||
return tf.reduce_sum(weighted_loss) | ||
return loss * self.train_mask | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Jo, jeste se mi tady nezda format train_xents. Bylo by teda fajn se dohodnout, co se bude vracet (bud tady vracet prumer-per-sequence; nebo v autoregressive by melo stacit vypnout average_across_timesteps v self.train_xents) In the long run by samozrejme mel byt jeden spolecny predek "dekoder" s abstraktnima metodama, jako loss, xents apod., ale to bych klidne nechal do jineho PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. V mariánovi to jde nastavovat ("ce-mean" vs "ce-mean-words"), s tím, že v nových modelech se používá loss zprůměrovanej ze všech slov v batchi, ale default má průměr po větách. U labeleru dává víc smysl mít ten loss pro každej label zvlášť, kdežto u dekodéru asi spíš po větách, ale souhlasim, že se to má sjednotit. Je tim pádem asi lepší vracet (batch, time) kterej si pak zprůměruješ jak chceš. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Taky jsem za druhou moznost, to ale znamena tedka jeste vypnout switch v AutoregressiveDecoder. Nevim, jak to ovlivni zabehle trenovani (predpokladam, ze minimalne; ale nemeril jsem to). Pokud udelas porovnani, tak to klidne muzes zamergovat timhle zpusobem. Na druhou stranu uz mam vyzkousene ze udelat prumer pres vety a pak pres batch v seq. labeleru funguje, takze bych v tomto PR radej udelal tohle (hodil issue na poradne doreseni). Kazdopadne by bylo fajn to uz ted mit v masteru sjednocene, nez to nekam zapadne. Klicove je, ze to vyrazne snizi uroven odrbavani v jinych komponentach, kde pak musis mit divne workaroundy tipu kontroly shapu, supported decoder apod. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. autoregressive dekodér má nějakej switch? Vidim jen, že to dělá There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mam na mysli toto: seq2seq.sequence_loss ma prepinac average_across_timesteps, ktery je tedka True. Kdyby se vypnul, tak by to odpovidalo formatu train_xents v labeleru There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. už to vidim... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hotovo. musel se ještě změnit perplexity runner, kterej počítá s průměrama přes čas. |
||
|
||
@tensor | ||
def cost(self) -> tf.Tensor: | ||
# Cross entropy mean over all words in the batch | ||
# (could also be done as a mean over sentences) | ||
return tf.reduce_sum(self.train_xents) / tf.reduce_sum(self.train_mask) | ||
|
||
@property | ||
def train_loss(self) -> tf.Tensor: | ||
|
@@ -142,6 +145,67 @@ def feed_dict(self, dataset: Dataset, train: bool = False) -> FeedDict: | |
|
||
sentences = dataset.maybe_get_series(self.data_id) | ||
if sentences is not None: | ||
fd[self.target_tokens] = pad_batch(list(sentences)) | ||
fd[self.target_tokens] = pad_batch( | ||
list(sentences), self.max_output_len, self.add_start_symbol, | ||
self.add_end_symbol) | ||
|
||
return fd | ||
|
||
|
||
class EmbeddingsLabeler(SequenceLabeler): | ||
"""SequenceLabeler that uses an embedding matrix for output projection.""" | ||
|
||
# pylint: disable=too-many-arguments,too-many-locals | ||
def __init__(self, | ||
name: str, | ||
encoders: List[TemporalStateful], | ||
embedded_sequence: EmbeddedSequence, | ||
data_id: str, | ||
max_output_len: int = None, | ||
hidden_dim: int = None, | ||
activation: Callable = tf.nn.relu, | ||
train_embeddings: bool = True, | ||
dropout_keep_prob: float = 1.0, | ||
add_start_symbol: bool = False, | ||
add_end_symbol: bool = False, | ||
reuse: ModelPart = None, | ||
save_checkpoint: str = None, | ||
load_checkpoint: str = None, | ||
initializers: InitializerSpecs = None) -> None: | ||
|
||
check_argument_types() | ||
SequenceLabeler.__init__( | ||
self, name, encoders, embedded_sequence.vocabulary, data_id, | ||
max_output_len, hidden_dim=hidden_dim, activation=activation, | ||
dropout_keep_prob=dropout_keep_prob, | ||
add_start_symbol=add_start_symbol, add_end_symbol=add_end_symbol, | ||
reuse=reuse, save_checkpoint=save_checkpoint, | ||
load_checkpoint=load_checkpoint, initializers=initializers) | ||
|
||
self.embedded_sequence = embedded_sequence | ||
self.train_embeddings = train_embeddings | ||
# pylint: enable=too-many-arguments,too-many-locals | ||
|
||
@tensor | ||
def logits(self) -> tf.Tensor: | ||
embeddings = self.embedded_sequence.embedding_matrix | ||
if not self.train_embeddings: | ||
embeddings = tf.stop_gradient(embeddings) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fakt muzes tohle udelat? Vzhledem k tomu, ze se jedna o fakticky posledni vrstvu, nestane se to, ze ti pri backpropu neprotece zadna informace do zbytku site (a tedy se nic nenaucis)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. neproteče tam přes tuhle lokální proměnnou There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Jasne, uz to vidim |
||
|
||
states = self.states | ||
# pylint: disable=no-member | ||
states_dim = self.states.get_shape()[-1].value | ||
# pylint: enable=no-member | ||
embedding_dim = self.embedded_sequence.embedding_sizes[0] | ||
# pylint: disable=redefined-variable-type | ||
if states_dim != embedding_dim: | ||
states = tf.layers.dense( | ||
jindrahelcl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
states, embedding_dim, name="project_for_embeddings") | ||
states = dropout(states, self.dropout_keep_prob, self.train_mode) | ||
# pylint: enable=redefined-variable-type | ||
|
||
reshaped_states = tf.reshape(states, [-1, embedding_dim]) | ||
reshaped_logits = tf.matmul( | ||
reshaped_states, embeddings, transpose_b=True, name="logits") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Biasy necheme? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. neděláme je ani jinde při There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Jo, nevsiml jsem si, ze EmbeddingsLabeler vzdycky vaze embeddingy. |
||
return tf.reshape( | ||
reshaped_logits, [self.batch_size, -1, len(self.vocabulary)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kdyz budu mit vice enkoderu, co pracuji nad ruzne dlouhymi sekvencemi, tak to na tom concatu spadne, ne?
Nemela by se takova situace resit spise pres FactoredSequence/FactoredEncoder?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to je pravda