-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathembedding_layers.py
61 lines (45 loc) · 2.39 KB
/
embedding_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) 2021 Graphcore Ltd. All rights reserved.
import tensorflow as tf
from absl import flags
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
from tensorflow.keras import backend
# Command-line switch: when set, MultiFeatureEncoder builds OneHotEmbedding
# layers (matmul-based lookup) instead of the default gather-based
# tf.keras.layers.Embedding.
flags.DEFINE_boolean("one_hot_embedding", False, "Use a one-hot formulation of the embedding lookup")
# Module-level handle to the parsed absl flags; read at layer-construction time.
FLAGS = flags.FLAGS
class OneHotEmbedding(tf.keras.layers.Embedding):
    """Embedding lookup expressed as a one-hot encoding followed by a matmul.

    Mathematically equivalent to the gather performed by the parent
    ``tf.keras.layers.Embedding``, but formulated as
    ``one_hot(inputs) @ embeddings``.
    """

    def call(self, inputs):
        # tf.one_hot needs integer indices; cast anything else to int32.
        if backend.dtype(inputs) not in ("int32", "int64"):
            inputs = tf.cast(inputs, "int32")
        encoded = tf.one_hot(inputs, self.input_dim, dtype=self.compute_dtype)
        result = encoded @ self.embeddings
        policy = self._dtype_policy
        if policy.compute_dtype != policy.variable_dtype:
            # Under mixed precision, cast the output rather than the variable:
            # mathematically equivalent but faster.
            result = tf.cast(result, policy.compute_dtype)
        return result
class MultiFeatureEncoder(tf.keras.layers.Layer):
    """Embeds a multi-categorical input by summing one embedding per feature.

    The last axis of the input indexes the categorical features; each feature
    gets its own embedding table, and the looked-up vectors are summed into a
    single ``emb_dim``-sized output.
    """

    def __init__(self, emb_dim, n_feature_dims, name=""):
        """
        For something like an atom, which has several categorical features,
        we build a learnable embedding by summing a looked-up embedding from each
        of these several features.

        :param emb_dim: number of output dimensions from the embedding
        :param n_feature_dims: list of feature dimensions
        :param name: name for the keras layer
        """
        super().__init__(name=name)
        self.emb_dim = emb_dim
        make_embedding = (
            OneHotEmbedding if FLAGS.one_hot_embedding else tf.keras.layers.Embedding
        )
        # One embedding table per categorical feature dimension.
        self.embeddings = [
            make_embedding(cardinality, emb_dim) for cardinality in n_feature_dims
        ]

    def call(self, inputs, training=True):
        # Accumulate left-to-right starting from zeros, preserving the
        # sequential addition order of the tables.
        total = tf.zeros([*inputs.shape[:-1], self.emb_dim], dtype=self.compute_dtype)
        for feature_index, table in enumerate(self.embeddings):
            total = total + table(inputs[..., feature_index])
        return total
class AtomEncoder(MultiFeatureEncoder):
    """Sums learned embeddings over the OGB atom feature dimensions."""

    def __init__(self, emb_dim, name="AtomEncoder"):
        # Feature cardinalities come from the OGB atom featurisation.
        super().__init__(
            emb_dim=emb_dim,
            n_feature_dims=get_atom_feature_dims(),
            name=name,
        )
class BondEncoder(MultiFeatureEncoder):
    """Sums learned embeddings over the OGB bond feature dimensions."""

    def __init__(self, emb_dim, name="BondEncoder"):
        # Feature cardinalities come from the OGB bond featurisation.
        super().__init__(
            emb_dim=emb_dim,
            n_feature_dims=get_bond_feature_dims(),
            name=name,
        )