Skip to content

Commit

Permalink
Merge pull request #13 from MentenAI/dtype
Browse files Browse the repository at this point in the history
Dtype and bump to 0.2.1
  • Loading branch information
JackMaguire authored Feb 5, 2021
2 parents 3536dd9 + 10a932e commit 998d30f
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.2.0
current_version = 0.2.1
commit = True
tag = True

Expand Down
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
year = '2020'
author = 'Menten AI, Inc.'
copyright = '{0}, {1}'.format(year, author)
version = release = '0.2.0'
version = release = '0.2.1'

pygments_style = 'trac'
templates_path = ['_templates']
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def read(*names, **kwargs):

setup(
name='menten-gcn',
version='0.2.0',
version='0.2.1',
license='MIT',
description='This package decorates graph tensors with data from protein models',
long_description='%s' % (
Expand Down
27 changes: 17 additions & 10 deletions src/menten_gcn/data_maker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ class DataMaker:
nbr_distance_cutoff_A: float
A node will be included in the graph if it is within this distance (Angstroms) of any focus node.
A value of None will set this equal to edge_distance_cutoff_A
dtype: np.dtype
What numpy data type should we use to represent your data?
"""

def __init__(self, decorators: List[Decorator], edge_distance_cutoff_A: float, max_residues: int,
exclude_bbdec: bool = False, nbr_distance_cutoff_A: float = None):
exclude_bbdec: bool = False, nbr_distance_cutoff_A: float = None,
dtype: np.dtype = np.float32):

self.bare_bones_decorator = BareBonesDecorator()
self.exclude_bbdec = exclude_bbdec
Expand All @@ -53,6 +56,8 @@ def __init__(self, decorators: List[Decorator], edge_distance_cutoff_A: float, m
else:
self.nbr_distance_cutoff_A = nbr_distance_cutoff_A

self.dtype = dtype

def get_N_F_S(self) -> Tuple[int, int, int]:
"""
Returns
Expand Down Expand Up @@ -179,17 +184,17 @@ def _get_edge_data_for_pair(self, wrapped_pose: WrappedPose, resid_i: int, resid
assert len(f_ij) == self.all_decs.n_edge_features()
assert len(f_ji) == self.all_decs.n_edge_features()

f_ij = np.asarray(f_ij)
f_ji = np.asarray(f_ji)
f_ij = np.asarray(f_ij, dtype=self.dtype)
f_ji = np.asarray(f_ji, dtype=self.dtype)
if data_cache.edge_cache is not None:
data_cache.edge_cache[resid_i][resid_j] = f_ij
data_cache.edge_cache[resid_j][resid_i] = f_ji
return f_ij, f_ji

def _calc_adjacency_matrix_and_edge_data(self, wrapped_pose: WrappedPose, all_resids: List[int], data_cache):
N, F, S = self.get_N_F_S()
A_dense = np.zeros(shape=[N, N])
E_dense = np.zeros(shape=[N, N, S])
A_dense = np.zeros(shape=[N, N], dtype=self.dtype)
E_dense = np.zeros(shape=[N, N, S], dtype=self.dtype)

for i in range(0, len(all_resids) - 1):
resid_i = all_resids[i]
Expand All @@ -210,7 +215,7 @@ def _calc_adjacency_matrix_and_edge_data(self, wrapped_pose: WrappedPose, all_re

def _get_node_data(self, wrapped_pose: WrappedPose, resids: List[int], data_cache):
N, F, S = self.get_N_F_S()
X = np.zeros(shape=[N, F])
X = np.zeros(shape=[N, F], dtype=self.dtype)
index = -1
for resid in resids:
index += 1
Expand All @@ -226,7 +231,7 @@ def _get_node_data(self, wrapped_pose: WrappedPose, resids: List[int], data_cach

n = self.all_decs.calc_node_features(wrapped_pose, resid)

n = np.asarray(n)
n = np.asarray(n, dtype=self.dtype)
if data_cache.node_cache is not None:
data_cache.node_cache[resid] = n
X[index] = n
Expand All @@ -253,10 +258,12 @@ def generate_XAE_input_tensors(self) -> Tuple[Layer, Layer, Layer]:
Edge Feature Input
"""

dtype_str = str(self.dtype).split('.')[-1].split('\'')[0]

N, F, S = self.get_N_F_S()
X_in = Input(shape=(N, F), name='X_in')
A_in = Input(shape=(N, N), sparse=False, name='A_in')
E_in = Input(shape=(N, N, S), name='E_in')
X_in = Input(shape=(N, F), name='X_in', dtype=dtype_str)
A_in = Input(shape=(N, N), sparse=False, name='A_in', dtype=dtype_str)
E_in = Input(shape=(N, N, S), name='E_in', dtype=dtype_str)
return X_in, A_in, E_in

def generate_input(self, wrapped_pose: WrappedPose, focus_resids: List[int], data_cache: DecoratorDataCache = None,
Expand Down
28 changes: 15 additions & 13 deletions src/menten_gcn/data_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,19 @@ class DataHolder:
There are descriptions for each method below but perhaps the best way to grasp
the DataHolder's usage is to see the example at the bottom.
Parameters
----------
dtype: np.dtype
What NumPy dtype should we use to represent your data?
"""

def __init__(self):
def __init__(self, dtype: np.dtype = np.float32):
self.Xs = []
self.As = []
self.Es = []
self.outs = []
self.dtype = dtype

def assert_mode(self, mode=spektral.layers.ops.modes.BATCH):
"""
Expand Down Expand Up @@ -85,10 +91,10 @@ def append(self, X: np.ndarray, A: np.ndarray, E: np.ndarray, out: np.ndarray):
"""

# TODO assert shape
self.Xs.append(np.asarray(X))
self.As.append(np.asarray(A))
self.Es.append(np.asarray(E))
self.outs.append(np.asarray(out))
self.Xs.append(np.asarray(X, dtype=self.dtype))
self.As.append(np.asarray(A, dtype=self.dtype))
self.Es.append(np.asarray(E, dtype=self.dtype))
self.outs.append(np.asarray(out, dtype=self.dtype))

def size(self) -> int:
return len(self.Xs)
Expand Down Expand Up @@ -139,14 +145,10 @@ def save_to_file(self, fileprefix: str):
"""
np.savez_compressed(
fileprefix + '.npz',
x=np.asarray(
self.Xs),
a=np.asarray(
self.As),
e=np.asarray(
self.Es),
o=np.asarray(
self.outs))
x=np.asarray(self.Xs, dtype=self.dtype),
a=np.asarray(self.As, dtype=self.dtype),
e=np.asarray(self.Es, dtype=self.dtype),
o=np.asarray(self.outs, dtype=self.dtype))

def load_from_file(self, fileprefix: str = None, filename: str = None):
"""
Expand Down
14 changes: 7 additions & 7 deletions tests/test_menten_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def observe_NEE3(Xval, Aval, Eval, N, F, S):
foo(all_outs[11], [300., -12., 12., 200., -13., 13., 100., -23., 23., ])


def test_NEE3():
def extra_test_NEE3():
'''
2
/ \
Expand Down Expand Up @@ -387,7 +387,7 @@ def Xval(generator, i, j):
log = np.zeros(10)
for i in range(0, 6):
for j in range(0, 3):
log[Xval(generator, i, j)] += 1
log[int(Xval(generator, i, j))] += 1
assert(max(log) == 2)
assert(min(log) <= 1)
assert(sum(log) == 18)
Expand Down Expand Up @@ -758,7 +758,7 @@ def test_expected_md_traj_results():
assert_equal(resids, [20, 21, 19, 8, 61], 2)


def test_model_sizes():
def extra_test_model_sizes():
N = 5
F = 4
S = 3
Expand Down Expand Up @@ -931,7 +931,7 @@ def test_clustering():
121, 122]]


def test_sanity_check_flat_nbody():
def extra_test_sanity_check_flat_nbody():
N = 3
#F = 3
#S = 2
Expand Down Expand Up @@ -1049,7 +1049,7 @@ def test_sanity_check_flat_nbody():
equal(stitch1_np, target, decimal=3)


def test_sanity_check_flat_nbody2():
def extra_test_sanity_check_flat_nbody2():
N = 3
#F = 3
#S = 2
Expand Down Expand Up @@ -1175,7 +1175,7 @@ def test_sanity_check_flat_nbody2():
equal(stitch1_np, target, decimal=3)


def test_flat_nbody_layer():
def extra_test_flat_nbody_layer():

class TestFlat(tf.keras.layers.Layer):
def __init__(self):
Expand Down Expand Up @@ -1348,7 +1348,7 @@ def call(self, inputs):
assert_almost_equal(output, target, decimal=3)


def test_flat_2body_feed():
def extra_test_flat_2body_feed():
testX = [[[0., 1., 0.],
[1., 2., 0.],
[0., 3., 3.]]]
Expand Down

0 comments on commit 998d30f

Please sign in to comment.