diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 6d86640..f660b34 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -18,9 +18,9 @@ Steps to reproduce the behavior: 4. See error **Operating environment(运行环境):** - - python version [e.g. 3.4, 3.6] - - tensorflow version [e.g. 1.4.0, 1.12.0] - - deepmatch version [e.g. 0.1.1,] + - python version [e.g. 3.6, 3.7] + - tensorflow version [e.g. 1.4.0, 1.14.0, 2.3.0] + - deepmatch version [e.g. 0.2.0,] **Additional context** Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/question.md b/.github/ISSUE_TEMPLATE/question.md index 53e8451..f6d65c0 100644 --- a/.github/ISSUE_TEMPLATE/question.md +++ b/.github/ISSUE_TEMPLATE/question.md @@ -17,4 +17,4 @@ Add any other context about the problem here. **Operating environment(运行环境):** - python version [e.g. 3.6] - tensorflow version [e.g. 1.4.0,] - - deepmatch version [e.g. 0.1.1,] + - deepmatch version [e.g. 0.2.0,] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ee1b5bc..275b40d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: strategy: matrix: python-version: [3.5,3.6,3.7] - tf-version: [1.4.0,1.14.0,2.1.0,2.2.0] + tf-version: [1.4.0,1.14.0,2.1.0,2.2.0,2.3.0] exclude: - python-version: 3.7 diff --git a/README.md b/README.md index 74a346b..fc6d721 100644 --- a/README.md +++ b/README.md @@ -67,4 +67,4 @@ Please follow our wechat to join group: - 公众号:**浅梦的学习笔记** - wechat ID: **deepctrbot** - ![wechat](./docs/pics/weichennote.png) + ![wechat](./docs/pics/code.png) diff --git a/deepmatch/__init__.py b/deepmatch/__init__.py index b1af8c4..0fa3e17 100644 --- a/deepmatch/__init__.py +++ b/deepmatch/__init__.py @@ -1,4 +1,4 @@ from .utils import check_version -__version__ = '0.1.3' +__version__ = '0.2.0' check_version(__version__) diff --git a/deepmatch/inputs.py b/deepmatch/inputs.py index b37357a..fbb2d79 100644 --- a/deepmatch/inputs.py +++ b/deepmatch/inputs.py @@ -1,14 +1,17 @@ from itertools import chain -from deepctr.inputs import SparseFeat,VarLenSparseFeat,create_embedding_matrix,embedding_lookup,get_dense_input,varlen_embedding_lookup,get_varlen_pooling_list,mergeDict -def input_from_feature_columns(features, feature_columns, l2_reg, init_std, seed, prefix='', seq_mask_zero=True, - support_dense=True, support_group=False,embedding_matrix_dict=None): +from deepctr.feature_column import SparseFeat, VarLenSparseFeat, create_embedding_matrix, embedding_lookup, \ + get_dense_input, varlen_embedding_lookup, get_varlen_pooling_list, mergeDict + + +def input_from_feature_columns(features, feature_columns, l2_reg, seed, prefix='', seq_mask_zero=True, + support_dense=True, support_group=False, embedding_matrix_dict=None): sparse_feature_columns = list( filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else [] varlen_sparse_feature_columns = list( filter(lambda x: isinstance(x, VarLenSparseFeat), feature_columns)) if feature_columns else [] - if embedding_matrix_dict is None: - embedding_matrix_dict = create_embedding_matrix(feature_columns, l2_reg, init_std, seed, prefix=prefix, + if embedding_matrix_dict is None: + embedding_matrix_dict = create_embedding_matrix(feature_columns, l2_reg, seed, prefix=prefix, seq_mask_zero=seq_mask_zero) group_sparse_embedding_dict = embedding_lookup(embedding_matrix_dict, features, sparse_feature_columns) @@ -22,4 +25,4 @@ def 
input_from_feature_columns(features, feature_columns, l2_reg, init_std, seed group_embedding_dict = mergeDict(group_sparse_embedding_dict, group_varlen_sparse_embedding_dict) if not support_group: group_embedding_dict = list(chain.from_iterable(group_embedding_dict.values())) - return group_embedding_dict, dense_value_list \ No newline at end of file + return group_embedding_dict, dense_value_list diff --git a/deepmatch/layers/__init__.py b/deepmatch/layers/__init__.py index bd94afe..854c596 100644 --- a/deepmatch/layers/__init__.py +++ b/deepmatch/layers/__init__.py @@ -1,28 +1,29 @@ from deepctr.layers import custom_objects from deepctr.layers.utils import reduce_sum -from .core import PoolingLayer, Similarity, LabelAwareAttention, CapsuleLayer,SampledSoftmaxLayer,EmbeddingIndex -from ..utils import sampledsoftmaxloss -from .interaction import DotAttention, ConcatAttention, SoftmaxWeightedSum, AttentionSequencePoolingLayer, SelfAttention,\ +from .core import PoolingLayer, Similarity, LabelAwareAttention, CapsuleLayer, SampledSoftmaxLayer, EmbeddingIndex +from .interaction import DotAttention, ConcatAttention, SoftmaxWeightedSum, AttentionSequencePoolingLayer, \ + SelfAttention, \ SelfMultiHeadAttention, UserAttention from .sequence import DynamicMultiRNN +from ..utils import sampledsoftmaxloss _custom_objects = {'PoolingLayer': PoolingLayer, 'Similarity': Similarity, 'LabelAwareAttention': LabelAwareAttention, 'CapsuleLayer': CapsuleLayer, - 'reduce_sum':reduce_sum, - 'SampledSoftmaxLayer':SampledSoftmaxLayer, - 'sampledsoftmaxloss':sampledsoftmaxloss, - 'EmbeddingIndex':EmbeddingIndex, - 'DotAttention':DotAttention, - 'ConcatAttention':ConcatAttention, - 'SoftmaxWeightedSum':SoftmaxWeightedSum, - 'AttentionSequencePoolingLayer':AttentionSequencePoolingLayer, - 'SelfAttention':SelfAttention, - 'SelfMultiHeadAttention':SelfMultiHeadAttention, - 'UserAttention':UserAttention, - 'DynamicMultiRNN':DynamicMultiRNN + 'reduce_sum': reduce_sum, + 'SampledSoftmaxLayer': SampledSoftmaxLayer, + 'sampledsoftmaxloss': sampledsoftmaxloss, + 'EmbeddingIndex': EmbeddingIndex, + 'DotAttention': DotAttention, + 'ConcatAttention': ConcatAttention, + 'SoftmaxWeightedSum': SoftmaxWeightedSum, + 'AttentionSequencePoolingLayer': AttentionSequencePoolingLayer, + 'SelfAttention': SelfAttention, + 'SelfMultiHeadAttention': SelfMultiHeadAttention, + 'UserAttention': UserAttention, + 'DynamicMultiRNN': DynamicMultiRNN } custom_objects = dict(custom_objects, **_custom_objects) diff --git a/deepmatch/layers/core.py b/deepmatch/layers/core.py index 293b885..2bdd7c2 100644 --- a/deepmatch/layers/core.py +++ b/deepmatch/layers/core.py @@ -1,7 +1,9 @@ import tensorflow as tf +from deepctr.layers.activation import activation_layer from deepctr.layers.utils import reduce_max, reduce_mean, reduce_sum, concat_func, div, softmax -from tensorflow.python.keras.initializers import RandomNormal, Zeros +from tensorflow.python.keras.initializers import RandomNormal, Zeros, glorot_normal from tensorflow.python.keras.layers import Layer +from tensorflow.python.keras.regularizers import l2 class PoolingLayer(Layer): @@ -216,21 +218,20 @@ def squash(inputs): return vec_squashed - - class EmbeddingIndex(Layer): - def __init__(self, index,**kwargs): - self.index =index + def __init__(self, index, **kwargs): + self.index = index super(EmbeddingIndex, self).__init__(**kwargs) def build(self, input_shape): - super(EmbeddingIndex, self).build( input_shape) # Be sure to call this somewhere! 
+ def call(self, x, **kwargs): - return tf.constant(self.index) + return tf.constant(self.index) + def get_config(self, ): config = {'index': self.index, } base_config = super(EmbeddingIndex, self).get_config() - return dict(list(base_config.items()) + list(config.items())) \ No newline at end of file + return dict(list(base_config.items()) + list(config.items())) diff --git a/deepmatch/layers/sequence.py b/deepmatch/layers/sequence.py index a3bffc7..10e12b9 100644 --- a/deepmatch/layers/sequence.py +++ b/deepmatch/layers/sequence.py @@ -24,19 +24,19 @@ def build(self, input_shape): if self.rnn_type == "LSTM": try: single_cell = tf.nn.rnn_cell.BasicLSTMCell(self.num_units, forget_bias=self.forget_bias) - except: + except AttributeError: single_cell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell(self.num_units, forget_bias=self.forget_bias) elif self.rnn_type == "GRU": try: single_cell = tf.nn.rnn_cell.GRUCell(self.num_units, forget_bias=self.forget_bias) - except: + except AttributeError: single_cell = tf.compat.v1.nn.rnn_cell.GRUCell(self.num_units, forget_bias=self.forget_bias) else: raise ValueError("Unknown unit type %s!" % self.rnn_type) dropout = self.dropout if tf.keras.backend.learning_phase() == 1 else 0 try: single_cell = tf.nn.rnn_cell.DropoutWrapper(cell=single_cell, input_keep_prob=(1.0 - dropout)) - except: + except AttributeError: single_cell = tf.compat.v1.nn.rnn_cell.DropoutWrapper(cell=single_cell, input_keep_prob=(1.0 - dropout)) cell_list = [] for i in range(self.num_layers): @@ -44,7 +44,7 @@ def build(self, input_shape): if residual: try: single_cell_residual = tf.nn.rnn_cell.ResidualWrapper(single_cell) - except: + except AttributeError: single_cell_residual = tf.compat.v1.nn.rnn_cell.ResidualWrapper(single_cell) cell_list.append(single_cell_residual) else: @@ -54,7 +54,7 @@ else: try: self.final_cell = tf.nn.rnn_cell.MultiRNNCell(cell_list) - except: + except AttributeError: self.final_cell = tf.compat.v1.nn.rnn_cell.MultiRNNCell(cell_list) super(DynamicMultiRNN, self).build(input_shape) @@ -66,7 +66,7 @@ def call(self, input_list, mask=None, training=None): rnn_output, hidden_state = tf.nn.dynamic_rnn(self.final_cell, inputs=rnn_input, sequence_length=tf.squeeze(sequence_length), dtype=tf.float32, scope=self.name) - except: + except AttributeError: with tf.name_scope("rnn"), tf.compat.v1.variable_scope("rnn", reuse=tf.compat.v1.AUTO_REUSE): rnn_output, hidden_state = tf.compat.v1.nn.dynamic_rnn(self.final_cell, inputs=rnn_input, sequence_length=tf.squeeze(sequence_length), @@ -86,6 +86,6 @@ def compute_output_shape(self, input_shape): def get_config(self, ): config = {'num_units': self.num_units, 'rnn_type': self.rnn_type, 'return_sequence': self.return_sequence, 'num_layers': self.num_layers, - 'num_residual_layers': self.num_residual_layers, 'dropout_rate': self.dropout} + 'num_residual_layers': self.num_residual_layers, 'dropout_rate': self.dropout, 'forget_bias': self.forget_bias} base_config = super(DynamicMultiRNN, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/deepmatch/models/dssm.py b/deepmatch/models/dssm.py index 74ac3b1..d1bd455 100755 --- a/deepmatch/models/dssm.py +++ b/deepmatch/models/dssm.py @@ -5,8 +5,8 @@ Huang P S , He X , Gao J , et al. Learning deep structured semantic models for web search using clickthrough data[C]// Acm International Conference on Conference on Information & Knowledge Management. ACM, 2013.
""" -from deepctr.inputs import build_input_features, combined_dnn_input, create_embedding_matrix -from deepctr.layers.core import PredictionLayer, DNN +from deepctr.feature_column import build_input_features, create_embedding_matrix +from deepctr.layers import PredictionLayer, DNN, combined_dnn_input from tensorflow.python.keras.models import Model from ..inputs import input_from_feature_columns @@ -16,7 +16,7 @@ def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, 32), item_dnn_hidden_units=(64, 32), dnn_activation='tanh', dnn_use_bn=False, - l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, metric='cos'): + l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, seed=1024, metric='cos'): """Instantiates the Deep Structured Semantic Model architecture. :param user_feature_columns: An iterable containing user's features used by the model. @@ -28,7 +28,6 @@ def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, :param l2_reg_dnn: float. L2 regularizer strength applied to DNN :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. - :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. :param metric: str, ``"cos"`` for cosine or ``"ip"`` for inner product :return: A Keras model instance. @@ -36,14 +35,14 @@ def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, """ embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, - init_std, seed, + seed=seed, seq_mask_zero=True) user_features = build_input_features(user_feature_columns) user_inputs_list = list(user_features.values()) user_sparse_embedding_list, user_dense_value_list = input_from_feature_columns(user_features, user_feature_columns, - l2_reg_embedding, init_std, seed, + l2_reg_embedding, seed=seed, embedding_matrix_dict=embedding_matrix_dict) user_dnn_input = combined_dnn_input(user_sparse_embedding_list, user_dense_value_list) @@ -51,15 +50,15 @@ def DSSM(user_feature_columns, item_feature_columns, user_dnn_hidden_units=(64, item_inputs_list = list(item_features.values()) item_sparse_embedding_list, item_dense_value_list = input_from_feature_columns(item_features, item_feature_columns, - l2_reg_embedding, init_std, seed, + l2_reg_embedding, seed=seed, embedding_matrix_dict=embedding_matrix_dict) item_dnn_input = combined_dnn_input(item_sparse_embedding_list, item_dense_value_list) user_dnn_out = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, - dnn_use_bn, seed, )(user_dnn_input) + dnn_use_bn, seed=seed)(user_dnn_input) item_dnn_out = DNN(item_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, - dnn_use_bn, seed)(item_dnn_input) + dnn_use_bn, seed=seed)(item_dnn_input) score = Similarity(type=metric)([user_dnn_out, item_dnn_out]) diff --git a/deepmatch/models/fm.py b/deepmatch/models/fm.py index 3b5fa91..97bf339 100644 --- a/deepmatch/models/fm.py +++ b/deepmatch/models/fm.py @@ -1,4 +1,4 @@ -from deepctr.inputs import build_input_features +from deepctr.feature_column import build_input_features from deepctr.layers.core import PredictionLayer from deepctr.layers.utils import concat_func, reduce_sum from tensorflow.python.keras.layers import Lambda @@ -8,13 +8,12 @@ from ..layers.core import Similarity -def FM(user_feature_columns, item_feature_columns, 
l2_reg_embedding=1e-6, init_std=0.0001, seed=1024, metric='cos'): +def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, seed=1024, metric='cos'): """Instantiates the FM architecture. :param user_feature_columns: An iterable containing user's features used by the model. :param item_feature_columns: An iterable containing item's features used by the model. :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector - :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. :param metric: str, ``"cos"`` for cosine or ``"ip"`` for inner product :return: A Keras model instance. @@ -22,14 +21,14 @@ def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, init_s """ embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, - init_std, seed, + seed=seed, seq_mask_zero=True) user_features = build_input_features(user_feature_columns) user_inputs_list = list(user_features.values()) user_sparse_embedding_list, user_dense_value_list = input_from_feature_columns(user_features, user_feature_columns, - l2_reg_embedding, init_std, seed, + l2_reg_embedding, seed=seed, support_dense=False, embedding_matrix_dict=embedding_matrix_dict) @@ -37,7 +36,7 @@ def FM(user_feature_columns, item_feature_columns, l2_reg_embedding=1e-6, init_s item_inputs_list = list(item_features.values()) item_sparse_embedding_list, item_dense_value_list = input_from_feature_columns(item_features, item_feature_columns, - l2_reg_embedding, init_std, seed, + l2_reg_embedding, seed=seed, support_dense=False, embedding_matrix_dict=embedding_matrix_dict) diff --git a/deepmatch/models/mind.py b/deepmatch/models/mind.py index ddd1c2c..3053dd1 100755 --- a/deepmatch/models/mind.py +++ b/deepmatch/models/mind.py @@ -7,11 +7,10 @@ """ import tensorflow as tf -from deepctr.inputs import SparseFeat, VarLenSparseFeat, DenseFeat, \ - embedding_lookup, varlen_embedding_lookup, get_varlen_pooling_list, get_dense_input, build_input_features, \ - combined_dnn_input -from deepctr.layers.core import DNN -from deepctr.layers.utils import NoMask +from deepctr.feature_column import SparseFeat, VarLenSparseFeat, DenseFeat, \ + embedding_lookup, varlen_embedding_lookup, get_varlen_pooling_list, get_dense_input, build_input_features +from deepctr.layers import DNN +from deepctr.layers.utils import NoMask, combined_dnn_input from tensorflow.python.keras.layers import Concatenate from tensorflow.python.keras.models import Model @@ -30,8 +29,7 @@ def tile_user_otherfeat(user_other_feature, k_max): def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1.0, dynamic_k=False, user_dnn_hidden_units=(64, 32), dnn_activation='relu', dnn_use_bn=False, l2_reg_dnn=0, l2_reg_embedding=1e-6, - dnn_dropout=0, - init_std=0.0001, seed=1024): + dnn_dropout=0, output_activation='linear', seed=1024): """Instantiates the MIND Model architecture. :param user_feature_columns: An iterable containing user's features used by the model. @@ -47,8 +45,8 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1 :param l2_reg_dnn: L2 regularizer strength applied to DNN :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. - :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. 
+ :param output_activation: Activation function to use in output layer :return: A Keras model instance. """ @@ -83,8 +81,7 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1 inputs_list = list(features.values()) embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, - init_std, - seed, prefix="") + seed=seed, prefix="") item_features = build_input_features(item_feature_columns) @@ -127,7 +124,9 @@ def MIND(user_feature_columns, item_feature_columns, num_sampled=5, k_max=2, p=1 user_deep_input = high_capsule user_embeddings = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn, - dnn_dropout, dnn_use_bn, seed, name="user_embedding")(user_deep_input) + dnn_dropout, dnn_use_bn, output_activation=output_activation, seed=seed, + name="user_embedding")( + user_deep_input) item_inputs_list = list(item_features.values()) item_embedding_matrix = embedding_matrix_dict[item_feature_name] diff --git a/deepmatch/models/ncf.py b/deepmatch/models/ncf.py index 2de4c5f..95be769 100644 --- a/deepmatch/models/ncf.py +++ b/deepmatch/models/ncf.py @@ -8,8 +8,8 @@ import math -from deepctr.inputs import input_from_feature_columns, build_input_features, combined_dnn_input, SparseFeat -from deepctr.layers.core import DNN +from deepctr.feature_column import input_from_feature_columns, build_input_features, SparseFeat +from deepctr.layers import DNN, combined_dnn_input from tensorflow.python.keras.layers import Lambda, Concatenate, Multiply from tensorflow.python.keras.models import Model @@ -17,7 +17,7 @@ def NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, item_gmf_embedding_dim=20, user_mlp_embedding_dim=20, item_mlp_embedding_dim=20, dnn_use_bn=False, dnn_hidden_units=(64, 32), dnn_activation='relu', l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, - init_std=0.0001, seed=1024): + seed=1024): """Instantiates the NCF Model architecture. :param user_feature_columns: A dict containing user's features and features'dim. @@ -32,7 +32,6 @@ def NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, i :param l2_reg_dnn: float. L2 regularizer strength applied to DNN :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. - :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. :return: A Keras model instance. 
@@ -51,8 +50,8 @@ def NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, i user_inputs_list = list(user_features.values()) user_gmf_sparse_embedding_list, user_gmf_dense_value_list = input_from_feature_columns(user_features, user_gmf_feature_columns, - l2_reg_embedding, init_std, - seed, prefix='gmf_') + l2_reg_embedding, seed=seed, + prefix='gmf_') user_gmf_input = combined_dnn_input(user_gmf_sparse_embedding_list, []) user_gmf_out = Lambda(lambda x: x, name="user_gmf_embedding")(user_gmf_input) @@ -62,8 +61,8 @@ def NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, i item_inputs_list = list(item_features.values()) item_gmf_sparse_embedding_list, item_gmf_dense_value_list = input_from_feature_columns(item_features, item_gmf_feature_columns, - l2_reg_embedding, init_std, - seed, prefix='gmf_') + l2_reg_embedding, seed=seed, + prefix='gmf_') item_gmf_input = combined_dnn_input(item_gmf_sparse_embedding_list, []) item_gmf_out = Lambda(lambda x: x, name="item_gmf_embedding")(item_gmf_input) @@ -74,8 +73,8 @@ def NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, i for feat, size in user_feature_columns.items()] user_mlp_sparse_embedding_list, user_mlp_dense_value_list = input_from_feature_columns(user_features, user_mlp_feature_columns, - l2_reg_embedding, init_std, - seed, prefix='mlp_') + l2_reg_embedding, seed=seed, + prefix='mlp_') user_mlp_input = combined_dnn_input( user_mlp_sparse_embedding_list, user_mlp_dense_value_list) user_mlp_out = Lambda(lambda x: x, name="user_mlp_embedding")(user_mlp_input) @@ -85,19 +84,19 @@ def NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, i item_mlp_sparse_embedding_list, item_mlp_dense_value_list = input_from_feature_columns(item_features, item_mlp_feature_columns, - l2_reg_embedding, init_std, - seed, prefix='mlp_') + l2_reg_embedding, seed=seed, + prefix='mlp_') item_mlp_input = combined_dnn_input( item_mlp_sparse_embedding_list, item_mlp_dense_value_list) item_mlp_out = Lambda(lambda x: x, name="item_mlp_embedding")(item_mlp_input) mlp_input = Concatenate(axis=1)([user_mlp_out, item_mlp_out]) mlp_out = DNN(dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, - dnn_use_bn, seed, name="mlp_embedding")(mlp_input) + dnn_use_bn, seed=seed, name="mlp_embedding")(mlp_input) # Fusion of GMF and MLP neumf_input = Concatenate(axis=1)([gmf_out, mlp_out]) - neumf_out = DNN(hidden_units=[1], activation='sigmoid')(neumf_input) + neumf_out = DNN(hidden_units=[1], activation='sigmoid', seed=seed)(neumf_input) output = Lambda(lambda x: x, name='neumf_out')(neumf_out) # output = PredictionLayer(task, False)(neumf_out) diff --git a/deepmatch/models/sdm.py b/deepmatch/models/sdm.py index 8119fcf..6d6ca96 100644 --- a/deepmatch/models/sdm.py +++ b/deepmatch/models/sdm.py @@ -8,10 +8,11 @@ """ import tensorflow as tf -from deepctr.inputs import build_input_features, SparseFeat, DenseFeat, get_varlen_pooling_list, VarLenSparseFeat, \ +from deepctr.feature_column import build_input_features, SparseFeat, DenseFeat, get_varlen_pooling_list, \ + VarLenSparseFeat, \ create_embedding_matrix, embedding_lookup, varlen_embedding_lookup, concat_func from deepctr.layers.utils import NoMask -from tensorflow.python.keras.layers import Dense, Input, Lambda +from tensorflow.python.keras.layers import Dense, Lambda from tensorflow.python.keras.models import Model from deepmatch.utils import get_item_embedding @@ -20,9 +21,10 @@ from ..layers.sequence import DynamicMultiRNN -def
SDM(user_feature_columns, item_feature_columns, history_feature_list, num_sampled=5, units=64, rnn_layers=2, dropout_rate=0.2, +def SDM(user_feature_columns, item_feature_columns, history_feature_list, num_sampled=5, units=64, rnn_layers=2, + dropout_rate=0.2, rnn_num_res=1, - num_head=4, l2_reg_embedding=1e-6, dnn_activation='tanh', init_std=0.0001, seed=1024): + num_head=4, l2_reg_embedding=1e-6, dnn_activation='tanh', seed=1024): """Instantiates the Sequential Deep Matching Model architecture. :param user_feature_columns: An iterable containing user's features used by the model. @@ -36,14 +38,13 @@ def SDM(user_feature_columns, item_feature_columns, history_feature_list, num_sa :param num_head: int int, the number of attention head :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector :param dnn_activation: Activation function to use in deep net - :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. :return: A Keras model instance. """ if len(item_feature_columns) > 1: - raise ValueError("Now MIND only support 1 item feature like item_id") + raise ValueError("Now SDM only support 1 item feature like item_id") item_feature_column = item_feature_columns[0] item_feature_name = item_feature_column.name item_vocabulary_size = item_feature_columns[0].vocabulary_size @@ -78,8 +79,7 @@ def SDM(user_feature_columns, item_feature_columns, history_feature_list, num_sa sparse_varlen_feature_columns.append(fc) embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, - init_std, seed, - prefix="") + seed=seed) item_features = build_input_features(item_feature_columns) item_inputs_list = list(item_features.values()) @@ -142,11 +142,22 @@ def SDM(user_feature_columns, item_feature_columns, history_feature_list, num_sa pooling_item_embedding_weight, gate_output_reshape, item_features[item_feature_name]]) model = Model(inputs=user_inputs_list + item_inputs_list, outputs=output) + # model.user_input = user_inputs_list + # model.user_embedding = gate_output_reshape + model.__setattr__("user_input", user_inputs_list) model.__setattr__("user_embedding", gate_output_reshape) + # model.item_input = item_inputs_list + # model.item_embedding = get_item_embedding(pooling_item_embedding_weight, item_features[item_feature_name]) + model.__setattr__("item_input", item_inputs_list) model.__setattr__("item_embedding", get_item_embedding(pooling_item_embedding_weight, item_features[item_feature_name])) return model + # , Model(inputs=user_inputs_list, outputs=gate_output_reshape), Model(inputs=item_inputs_list, + # outputs=get_item_embedding( + # pooling_item_embedding_weight, + # item_features[ + # item_feature_name])) diff --git a/deepmatch/models/youtubednn.py b/deepmatch/models/youtubednn.py index f17087f..3b96a70 100644 --- a/deepmatch/models/youtubednn.py +++ b/deepmatch/models/youtubednn.py @@ -4,21 +4,21 @@ Reference: Covington P, Adams J, Sargin E. Deep neural networks for youtube recommendations[C]//Proceedings of the 10th ACM conference on recommender systems. 2016: 191-198. 
""" -from deepctr.inputs import input_from_feature_columns, build_input_features, combined_dnn_input, create_embedding_matrix -from deepctr.layers.core import DNN -from deepctr.layers.utils import NoMask +from deepctr.feature_column import build_input_features +from deepctr.layers import DNN +from deepctr.layers.utils import NoMask, combined_dnn_input from tensorflow.python.keras.models import Model from deepmatch.layers import PoolingLayer from deepmatch.utils import get_item_embedding -from ..inputs import input_from_feature_columns +from ..inputs import input_from_feature_columns, create_embedding_matrix from ..layers.core import SampledSoftmaxLayer, EmbeddingIndex def YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=5, user_dnn_hidden_units=(64, 32), dnn_activation='relu', dnn_use_bn=False, - l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, init_std=0.0001, seed=1024, ): + l2_reg_dnn=0, l2_reg_embedding=1e-6, dnn_dropout=0, output_activation='linear', seed=1024, ): """Instantiates the YoutubeDNN Model architecture. :param user_feature_columns: An iterable containing user's features used by the model. @@ -30,8 +30,8 @@ def YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=5, :param l2_reg_dnn: float. L2 regularizer strength applied to DNN :param l2_reg_embedding: float. L2 regularizer strength applied to embedding vector :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate. - :param init_std: float,to use as the initialize std of embedding vector :param seed: integer ,to use as random seed. + :param output_activation: Activation function to use in output layer :return: A Keras model instance. """ @@ -42,20 +42,19 @@ def YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=5, item_vocabulary_size = item_feature_columns[0].vocabulary_size embedding_matrix_dict = create_embedding_matrix(user_feature_columns + item_feature_columns, l2_reg_embedding, - init_std, seed, prefix="") + seed=seed) user_features = build_input_features(user_feature_columns) user_inputs_list = list(user_features.values()) - user_sparse_embedding_list, user_dense_value_list = input_from_feature_columns(user_features, - user_feature_columns, - l2_reg_embedding, init_std, seed, + user_sparse_embedding_list, user_dense_value_list = input_from_feature_columns(user_features, user_feature_columns, + l2_reg_embedding, seed=seed, embedding_matrix_dict=embedding_matrix_dict) user_dnn_input = combined_dnn_input(user_sparse_embedding_list, user_dense_value_list) item_features = build_input_features(item_feature_columns) item_inputs_list = list(item_features.values()) user_dnn_out = DNN(user_dnn_hidden_units, dnn_activation, l2_reg_dnn, dnn_dropout, - dnn_use_bn, seed, )(user_dnn_input) + dnn_use_bn, output_activation=output_activation, seed=seed)(user_dnn_input) item_index = EmbeddingIndex(list(range(item_vocabulary_size)))(item_features[item_feature_name]) diff --git a/docs/pics/code.png b/docs/pics/code.png new file mode 100644 index 0000000..aa53cbf Binary files /dev/null and b/docs/pics/code.png differ diff --git a/docs/pics/weichennote.png b/docs/pics/weichennote.png index fec7b11..0b60a2f 100644 Binary files a/docs/pics/weichennote.png and b/docs/pics/weichennote.png differ diff --git a/docs/source/Examples.md b/docs/source/Examples.md index dc2a7f4..2e652e2 100644 --- a/docs/source/Examples.md +++ b/docs/source/Examples.md @@ -20,7 +20,7 @@ This example shows how to use ``YoutubeDNN`` to solve a matching task. 
You can g ```python import pandas as pd -from deepctr.inputs import SparseFeat, VarLenSparseFeat +from deepctr.feature_column import SparseFeat, VarLenSparseFeat from preprocess import gen_data_set, gen_model_input from sklearn.preprocessing import LabelEncoder from tensorflow.python.keras import backend as K @@ -292,7 +292,7 @@ This example shows how to use ``DSSM`` to solve a matching task. You can get the ```python import pandas as pd -from deepctr.inputs import SparseFeat, VarLenSparseFeat +from deepctr.feature_column import SparseFeat, VarLenSparseFeat from preprocess import gen_data_set, gen_model_input from sklearn.preprocessing import LabelEncoder from tensorflow.python.keras.models import Model diff --git a/docs/source/Features.md b/docs/source/Features.md index d7acce9..ed3831a 100644 --- a/docs/source/Features.md +++ b/docs/source/Features.md @@ -2,15 +2,17 @@ ## Feature Columns ### SparseFeat -``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, dtype,embedding_name, group_name)`` +``SparseFeat`` is a namedtuple with signature ``SparseFeat(name, vocabulary_size, embedding_dim, use_hash, dtype, embeddings_initializer, embedding_name, group_name, trainable)`` - name : feature name - vocabulary_size : number of unique feature values for sprase feature or hashing space when `use_hash=True` - embedding_dim : embedding dimension - use_hash : defualt `False`.If `True` the input will be hashed to space of size `vocabulary_size`. -- dtype : default `float32`.dtype of input tensor. +- dtype : default `int32`.dtype of input tensor. +- embeddings_initializer : initializer for the `embeddings` matrix. - embedding_name : default `None`. If None, the embedding_name will be same as `name`. - group_name : feature group of this feature. +- trainable: default `True`.Whether or not the embedding is trainable. ### DenseFeat ``DenseFeat`` is a namedtuple with signature ``DenseFeat(name, dimension, dtype)`` @@ -30,6 +32,7 @@ - weight_name : default `None`. If not None, the sequence feature will be multiplyed by the feature whose name is `weight_name`. - weight_norm : default `True`. Whether normalize the weight score or not. + ## Models diff --git a/docs/source/History.md b/docs/source/History.md index 1645b66..d0c9f61 100644 --- a/docs/source/History.md +++ b/docs/source/History.md @@ -1,4 +1,5 @@ # History +- 10/12/2020 : [v0.2.0](https://github.com/shenweichen/DeepMatch/releases/tag/v0.2.0) released.Support different initializers for different embedding weights and loading pretrained embeddings. - 05/17/2020 : [v0.1.3](https://github.com/shenweichen/DeepMatch/releases/tag/v0.1.3) released.Add `SDM` model . - 04/10/2020 : [v0.1.2](https://github.com/shenweichen/DeepMatch/releases/tag/v0.1.2) released.Support [saving and loading model](./FAQ.html#save-or-load-weights-models). 
- 04/06/2020 : DeepMatch first version is released on [PyPi](https://pypi.org/project/deepmatch/) \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index 0a25c61..4e882a8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,7 +26,7 @@ # The short X.Y version version = '' # The full version, including alpha/beta/rc tags -release = '0.1.3' +release = '0.2.0' # -- General configuration --------------------------------------------------- diff --git a/docs/source/index.rst b/docs/source/index.rst index a774fe7..f5b25cc 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -18,18 +18,18 @@ You can read the latest code at https://github.com/shenweichen/DeepMatch News ----- +10/12/2020 : Support different initializers for different embedding weights and loading pretrained embeddings. `Changelog `_ + 05/17/2020 : Add ``SDM`` model. `Changelog `_ 04/10/2020 : Support `saving and loading model <./FAQ.html#save-or-load-weights-models>`_ . `Changelog `_ -04/06/2020 : DeepMatch first version . - DisscussionGroup ----------------------- 公众号:**浅梦的学习笔记** wechat ID: **deepctrbot** -.. image:: ../pics/weichennote.png +.. image:: ../pics/code.png .. toctree:: :maxdepth: 2 diff --git a/examples/colab_MovieLen1M_YoutubeDNN.ipynb b/examples/colab_MovieLen1M_YoutubeDNN.ipynb index 6f78370..a5bd90e 100644 --- a/examples/colab_MovieLen1M_YoutubeDNN.ipynb +++ b/examples/colab_MovieLen1M_YoutubeDNN.ipynb @@ -159,7 +159,7 @@ ], "source": [ "import pandas as pd\n", - "from deepctr.inputs import SparseFeat, VarLenSparseFeat\n", + "from deepctr.feature_column import SparseFeat, VarLenSparseFeat\n", "from preprocess import gen_data_set, gen_model_input\n", "from sklearn.preprocessing import LabelEncoder\n", "from tensorflow.python.keras import backend as K\n", @@ -386,7 +386,7 @@ " tf.compat.v1.disable_eager_execution()\n", "\n", "model = YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=100, user_dnn_hidden_units=(128,64, embedding_dim))\n", - "# model = MIND(user_feature_columns,item_feature_columns,dynamic_k=False,p=1,k_max=2,num_sampled=100,user_dnn_hidden_units=(128,64, embedding_dim),init_std=0.001)\n", + "# model = MIND(user_feature_columns,item_feature_columns,dynamic_k=False,p=1,k_max=2,num_sampled=100,user_dnn_hidden_units=(128,64, embedding_dim))\n", "\n", "model.compile(optimizer=\"adam\", loss=sampledsoftmaxloss) # \"binary_crossentropy\")\n", "\n", diff --git a/examples/preprocess.py b/examples/preprocess.py index 876bc12..4803749 100644 --- a/examples/preprocess.py +++ b/examples/preprocess.py @@ -104,8 +104,8 @@ def gen_model_input_sdm(train_set, user_profile, seq_short_len, seq_prefer_len): value=0) train_model_input = {"user_id": train_uid, "movie_id": train_iid, "short_movie_id": train_short_item_pad, - "prefer_movie_id": train_prefer_item_pad, "prefer_sess_length": train_short_len, "short_sess_length": - train_prefer_len, 'short_genres': train_short_genres_pad, 'prefer_genres': train_prefer_genres_pad} + "prefer_movie_id": train_prefer_item_pad, "prefer_sess_length": train_prefer_len, "short_sess_length": + train_short_len, 'short_genres': train_short_genres_pad, 'prefer_genres': train_prefer_genres_pad} for key in ["gender", "age", "occupation", "zip"]: train_model_input[key] = user_profile.loc[train_model_input['user_id']][key].values diff --git a/examples/run_dssm_negsampling.py b/examples/run_dssm_negsampling.py index e40a0c1..2ed6d83 100644 --- a/examples/run_dssm_negsampling.py +++ b/examples/run_dssm_negsampling.py @@ 
-1,5 +1,5 @@ import pandas as pd -from deepctr.inputs import SparseFeat, VarLenSparseFeat +from deepctr.feature_column import SparseFeat, VarLenSparseFeat from preprocess import gen_data_set, gen_model_input from sklearn.preprocessing import LabelEncoder from tensorflow.python.keras.models import Model diff --git a/examples/run_ncf.py b/examples/run_ncf.py new file mode 100644 index 0000000..29d0358 --- /dev/null +++ b/examples/run_ncf.py @@ -0,0 +1,56 @@ +import pandas as pd +from preprocess import gen_data_set, gen_model_input +from sklearn.preprocessing import LabelEncoder + +from deepmatch.models import NCF + +if __name__ == "__main__": + data = pd.read_csv("./movielens_sample.txt") + sparse_features = ["movie_id", "user_id", + "gender", "age", "occupation", "zip", ] + SEQ_LEN = 50 + negsample = 3 + + # 1.Label Encoding for sparse features, and process sequence features with `gen_data_set` and `gen_model_input` + + features = ['user_id', 'movie_id', 'gender', 'age', 'occupation', 'zip'] + feature_max_idx = {} + for feature in features: + lbe = LabelEncoder() + data[feature] = lbe.fit_transform(data[feature]) + 1 + feature_max_idx[feature] = data[feature].max() + 1 + + user_profile = data[["user_id", "gender", "age", "occupation", "zip"]].drop_duplicates('user_id') + + item_profile = data[["movie_id"]].drop_duplicates('movie_id') + + user_profile.set_index("user_id", inplace=True) + + user_item_list = data.groupby("user_id")['movie_id'].apply(list) + + train_set, test_set = gen_data_set(data, negsample) + + train_model_input, train_label = gen_model_input(train_set, user_profile, SEQ_LEN) + test_model_input, test_label = gen_model_input(test_set, user_profile, SEQ_LEN) + + # 2.count #unique features for each sparse field and generate feature config for sequence feature + + user_feature_columns = {"user_id": feature_max_idx['user_id'], 'gender': feature_max_idx['gender'], + "age": feature_max_idx['age'], + "occupation": feature_max_idx["occupation"], "zip": feature_max_idx["zip"]} + + item_feature_columns = {"movie_id": feature_max_idx['movie_id']} + + # 3.Define Model,train,predict and evaluate + model = NCF(user_feature_columns, item_feature_columns, user_gmf_embedding_dim=20, + item_gmf_embedding_dim=20, user_mlp_embedding_dim=32, item_mlp_embedding_dim=32, + dnn_hidden_units=[128, 64, 32], ) + model.summary() + model.compile("adam", "binary_crossentropy", + metrics=['binary_crossentropy'], ) + + history = model.fit(train_model_input, train_label, + batch_size=64, epochs=20, verbose=2, validation_split=0.2, ) + pred_ans = model.predict(test_model_input, batch_size=64) + # print("test LogLoss", round(log_loss(test_label, pred_ans), 4)) + # print("test AUC", round(roc_auc_score(test_label, pred_ans), 4)) diff --git a/examples/run_sdm.py b/examples/run_sdm.py index e409414..fdf1def 100644 --- a/examples/run_sdm.py +++ b/examples/run_sdm.py @@ -1,5 +1,5 @@ import pandas as pd -from deepctr.inputs import SparseFeat, VarLenSparseFeat +from deepctr.feature_column import SparseFeat, VarLenSparseFeat from preprocess import gen_data_set_sdm, gen_model_input_sdm from sklearn.preprocessing import LabelEncoder from tensorflow.python.keras import backend as K @@ -74,16 +74,13 @@ model = SDM(user_feature_columns, item_feature_columns, history_feature_list=['movie_id', 'genres'], units=embedding_dim, num_sampled=100, ) - optimizer = optimizers.Adam(lr=0.001, clipnorm=5.0) - - model.compile(optimizer=optimizer, loss=sampledsoftmaxloss) # "binary_crossentropy") +
model.compile(optimizer='adam', loss=sampledsoftmaxloss) # "binary_crossentropy") history = model.fit(train_model_input, train_label, # train_label, batch_size=512, epochs=1, verbose=1, validation_split=0.0, ) - # model.save_weights('SDM_weights.h5') K.set_learning_phase(False) - # 4. Generate user features for testing and full item features for retrieval + # 3.Define Model,train,predict and evaluate test_user_model_input = test_model_input all_item_model_input = {"movie_id": item_profile['movie_id'].values, } diff --git a/examples/run_youtubednn.py b/examples/run_youtubednn.py index aeb36c0..d98a472 100644 --- a/examples/run_youtubednn.py +++ b/examples/run_youtubednn.py @@ -1,5 +1,5 @@ import pandas as pd -from deepctr.inputs import SparseFeat, VarLenSparseFeat +from deepctr.feature_column import SparseFeat, VarLenSparseFeat from preprocess import gen_data_set, gen_model_input from sklearn.preprocessing import LabelEncoder from tensorflow.python.keras import backend as K @@ -60,7 +60,7 @@ tf.compat.v1.disable_eager_execution() model = YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=5, user_dnn_hidden_units=(64, embedding_dim)) - # model = MIND(user_feature_columns,item_feature_columns,dynamic_k=False,p=1,k_max=2,num_sampled=5,user_dnn_hidden_units=(64, embedding_dim),init_std=0.001) + #model = MIND(user_feature_columns,item_feature_columns,dynamic_k=False,p=1,k_max=2,num_sampled=5,user_dnn_hidden_units=(64, embedding_dim)) model.compile(optimizer="adam", loss=sampledsoftmaxloss) # "binary_crossentropy") diff --git a/setup.py b/setup.py index fa915ee..8ab2c5c 100644 --- a/setup.py +++ b/setup.py @@ -4,12 +4,12 @@ long_description = fh.read() REQUIRED_PACKAGES = [ - 'h5py','requests',"deepctr==0.7.5" + 'h5py', 'requests', "deepctr==0.8.2" ] setuptools.setup( name="deepmatch", - version="0.1.3", + version="0.2.0", author="Weichen Shen", author_email="wcshen1994@163.com", description="Deep matching model library for recommendations, advertising. 
It's easy to train models and to **export representation vectors** for user and item which can be used for **ANN search**.", @@ -45,6 +45,6 @@ 'Topic :: Software Development :: Libraries :: Python Modules', ), license="Apache-2.0", - keywords=['match', 'matching','recommendation' - 'deep learning', 'tensorflow', 'tensor', 'keras'], + keywords=['match', 'matching', 'recommendation', + 'deep learning', 'tensorflow', 'tensor', 'keras'], ) diff --git a/tests/models/DSSM_test.py b/tests/models/DSSM_test.py index 90c5e24..c4210e8 100644 --- a/tests/models/DSSM_test.py +++ b/tests/models/DSSM_test.py @@ -2,9 +2,6 @@ from ..utils import check_model, get_xy_fd -# @pytest.mark.xfail(reason="There is a bug when save model use Dice") -# @pytest.mark.skip(reason="misunderstood the API") - def test_DSSM(): model_name = "DSSM" diff --git a/tests/models/FM_test.py b/tests/models/FM_test.py index 0ae1222..baa224e 100644 --- a/tests/models/FM_test.py +++ b/tests/models/FM_test.py @@ -2,10 +2,6 @@ from ..utils import check_model, get_xy_fd -# @pytest.mark.xfail(reason="There is a bug when save model use Dice") -# @pytest.mark.skip(reason="misunderstood the API") - - def test_FM(): model_name = "FM" diff --git a/tests/models/MIND_test.py b/tests/models/MIND_test.py index 9451630..0b9509b 100644 --- a/tests/models/MIND_test.py +++ b/tests/models/MIND_test.py @@ -1,12 +1,9 @@ -from deepmatch.models import MIND -from deepmatch.utils import sampledsoftmaxloss -from tensorflow.python.keras import backend as K -from ..utils import check_model,get_xy_fd import tensorflow as tf +from tensorflow.python.keras import backend as K - -#@pytest.mark.xfail(reason="There is a bug when save model use Dice") -#@pytest.mark.skip(reason="misunderstood the API") +from deepmatch.models import MIND +from deepmatch.utils import sampledsoftmaxloss +from ..utils import check_model, get_xy_fd def test_MIND(): @@ -16,12 +13,12 @@ def test_MIND(): K.set_learning_phase(True) if tf.__version__ >= '2.0.0': - tf.compat.v1.disable_eager_execution() + tf.compat.v1.disable_eager_execution() model = MIND(user_feature_columns, item_feature_columns, num_sampled=2, user_dnn_hidden_units=(16, 4)) model.compile('adam', sampledsoftmaxloss) - check_model(model,model_name,x,y) + check_model(model, model_name, x, y) if __name__ == "__main__": diff --git a/tests/models/NCF_test.py b/tests/models/NCF_test.py new file mode 100644 index 0000000..da609c0 --- /dev/null +++ b/tests/models/NCF_test.py @@ -0,0 +1,16 @@ +from deepmatch.models import NCF +from ..utils import get_xy_fd_ncf + + +def test_NCF(): + model_name = "NCF" + + x, y, user_feature_columns, item_feature_columns = get_xy_fd_ncf(False) + model = NCF(user_feature_columns, item_feature_columns, ) + + model.compile('adam', "binary_crossentropy") + model.fit(x, y, batch_size=10, epochs=2, validation_split=0.5) + + +if __name__ == "__main__": + pass diff --git a/tests/models/SDM_test.py b/tests/models/SDM_test.py index b4caf9b..60e3e5a 100644 --- a/tests/models/SDM_test.py +++ b/tests/models/SDM_test.py @@ -1,12 +1,10 @@ -from deepmatch.models import SDM -from deepmatch.utils import sampledsoftmaxloss -from tensorflow.python.keras import backend as K -from tests.utils import check_model, get_xy_fd_sdm import tensorflow as tf +from tensorflow.python.keras import backend as K +from deepmatch.models import SDM +from deepmatch.utils import sampledsoftmaxloss +from ..utils import check_model, get_xy_fd_sdm -#@pytest.mark.xfail(reason="There is a bug when save model use Dice")
-#@pytest.mark.skip(reason="misunderstood the API") def test_SDM(): @@ -16,14 +14,14 @@ def test_SDM(): K.set_learning_phase(True) if tf.__version__ >= '2.0.0': - tf.compat.v1.disable_eager_execution() + tf.compat.v1.disable_eager_execution() model = SDM(user_feature_columns, item_feature_columns, history_feature_list, units=8) - #model.summary() + # model.summary() model.compile('adam', sampledsoftmaxloss) check_model(model, model_name, x, y) if __name__ == "__main__": - pass \ No newline at end of file + pass diff --git a/tests/models/YoutubeDNN_test.py b/tests/models/YoutubeDNN_test.py index ea7560c..46afb03 100644 --- a/tests/models/YoutubeDNN_test.py +++ b/tests/models/YoutubeDNN_test.py @@ -1,13 +1,9 @@ -from deepmatch.models import YoutubeDNN -from deepmatch.utils import sampledsoftmaxloss -from tensorflow.python.keras import backend as K import tensorflow as tf +from tensorflow.python.keras import backend as K -from ..utils import check_model,get_xy_fd - - -#@pytest.mark.xfail(reason="There is a bug when save model use Dice") -#@pytest.mark.skip(reason="misunderstood the API") +from deepmatch.models import YoutubeDNN +from deepmatch.utils import sampledsoftmaxloss +from ..utils import check_model, get_xy_fd def test_YoutubeDNN(): @@ -17,12 +13,12 @@ def test_YoutubeDNN(): K.set_learning_phase(True) if tf.__version__ >= '2.0.0': - tf.compat.v1.disable_eager_execution() + tf.compat.v1.disable_eager_execution() model = YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=2, user_dnn_hidden_units=(16, 4)) model.compile('adam', sampledsoftmaxloss) - check_model(model,model_name,x,y,check_model_io=True) + check_model(model, model_name, x, y, check_model_io=True) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index 08be5fb..b58e9db 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -6,12 +6,12 @@ import numpy as np import tensorflow as tf +from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat, DEFAULT_GROUP_NAME from numpy.testing import assert_allclose from tensorflow.python.keras import backend as K from tensorflow.python.keras.layers import Input, Masking from tensorflow.python.keras.models import Model, load_model, save_model -from deepctr.inputs import SparseFeat, DenseFeat, VarLenSparseFeat,DEFAULT_GROUP_NAME from deepmatch.layers import custom_objects SAMPLE_SIZE = 8 @@ -44,12 +44,13 @@ def get_test_data(sample_size=1000, embedding_size=4, sparse_feature_num=1, dens for i in range(sparse_feature_num): if use_group: - group_name = str(i%3) + group_name = str(i % 3) else: group_name = DEFAULT_GROUP_NAME dim = np.random.randint(1, 10) feature_columns.append( - SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32,group_name=group_name)) + SparseFeat(prefix + 'sparse_feature_' + str(i), dim, embedding_size, use_hash=hash_flag, dtype=tf.int32, + group_name=group_name)) for i in range(dense_feature_num): feature_columns.append(DenseFeat(prefix + 'dense_feature_' + str(i), 1, dtype=tf.float32)) @@ -340,11 +341,8 @@ def check_model(model, model_name, x, y, check_model_io=True): model.fit(x, y, batch_size=10, epochs=2, validation_split=0.5) - - print(model_name + " test train valid pass!") - user_embedding_model = Model(inputs=model.user_input, outputs=model.user_embedding) item_embedding_model = Model(inputs=model.item_input, outputs=model.item_embedding) @@ -366,47 +364,79 @@ def check_model(model, model_name, x, y, check_model_io=True): print(model_name + " test save load model 
pass!") print(model_name + " test pass!") + # print(1) + # + # save_model(item_embedding_model, model_name + '.user.h5') + # print(2) + # + # item_embedding_model = load_model(model_name + '.user.h5', custom_objects) + # print(3) + # + # item_embs = item_embedding_model.predict(x, batch_size=2 ** 12) + # print(item_embs) + # print("go") def get_xy_fd(hash_flag=False): + user_feature_columns = [SparseFeat('user', 3), SparseFeat( + 'gender', 2), VarLenSparseFeat( + SparseFeat('hist_item', vocabulary_size=3 + 1, embedding_dim=4, embedding_name='item'), maxlen=4, + length_name="hist_len")] + item_feature_columns = [SparseFeat('item', 3 + 1, embedding_dim=4, )] + + uid = np.array([0, 1, 2, 1]) + ugender = np.array([0, 1, 0, 1]) + iid = np.array([1, 2, 3, 1]) # 0 is mask value + + hist_iid = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [3, 0, 0, 0]]) + hist_len = np.array([3, 3, 2, 1]) + feature_dict = {'user': uid, 'gender': ugender, 'item': iid, + 'hist_item': hist_iid, "hist_len": hist_len} + + # feature_names = get_feature_names(feature_columns) + x = feature_dict + y = np.array([1, 1, 1, 1]) + return x, y, user_feature_columns, item_feature_columns - user_feature_columns = [SparseFeat('user',3),SparseFeat( - 'gender', 2),VarLenSparseFeat(SparseFeat('hist_item', vocabulary_size=3 + 1,embedding_dim=4,embedding_name='item'), maxlen=4,length_name="hist_len") ] - item_feature_columns = [SparseFeat('item', 3 + 1,embedding_dim=4,)] +def get_xy_fd_ncf(hash_flag=False): + user_feature_columns = {"user": 3, "gender": 2, } + item_feature_columns = {"item": 4} - uid = np.array([0, 1, 2,1]) - ugender = np.array([0, 1, 0,1]) - iid = np.array([1, 2, 3,1]) # 0 is mask value + uid = np.array([0, 1, 2, 1]) + ugender = np.array([0, 1, 0, 1]) + iid = np.array([1, 2, 3, 1]) # 0 is mask value - hist_iid = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0],[3, 0, 0, 0]]) - hist_len = np.array([3,3,2,1]) + hist_iid = np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [3, 0, 0, 0]]) + hist_len = np.array([3, 3, 2, 1]) feature_dict = {'user': uid, 'gender': ugender, 'item': iid, - 'hist_item': hist_iid, "hist_len":hist_len} + 'hist_item': hist_iid, "hist_len": hist_len} - #feature_names = get_feature_names(feature_columns) + # feature_names = get_feature_names(feature_columns) x = feature_dict - y = np.array([1, 1, 1,1]) - return x, y, user_feature_columns,item_feature_columns + y = np.array([1, 1, 1, 1]) + return x, y, user_feature_columns, item_feature_columns def get_xy_fd_sdm(hash_flag=False): - - user_feature_columns = [SparseFeat('user',3), + user_feature_columns = [SparseFeat('user', 3), SparseFeat('gender', 2), - VarLenSparseFeat(SparseFeat('prefer_item', vocabulary_size=100,embedding_dim=8, - embedding_name='item'), maxlen=6, length_name="prefer_sess_length"), + VarLenSparseFeat(SparseFeat('prefer_item', vocabulary_size=100, embedding_dim=8, + embedding_name='item'), maxlen=6, + length_name="prefer_sess_length"), VarLenSparseFeat(SparseFeat('prefer_cate', vocabulary_size=100, embedding_dim=8, - embedding_name='cate'), maxlen=6, length_name="prefer_sess_length"), - VarLenSparseFeat(SparseFeat('short_item', vocabulary_size=100,embedding_dim=8, - embedding_name='item'), maxlen=4, length_name="short_sess_length"), + embedding_name='cate'), maxlen=6, + length_name="prefer_sess_length"), + VarLenSparseFeat(SparseFeat('short_item', vocabulary_size=100, embedding_dim=8, + embedding_name='item'), maxlen=4, + length_name="short_sess_length"), VarLenSparseFeat(SparseFeat('short_cate', vocabulary_size=100, 
embedding_dim=8, - embedding_name='cate'), maxlen=4, length_name="short_sess_length"), + embedding_name='cate'), maxlen=4, + length_name="short_sess_length"), ] - item_feature_columns = [SparseFeat('item', 100, embedding_dim=8,)] - + item_feature_columns = [SparseFeat('item', 100, embedding_dim=8, )] uid = np.array([0, 1, 2, 1]) ugender = np.array([0, 1, 0, 1]) @@ -419,12 +449,13 @@ def get_xy_fd_sdm(hash_flag=False): prefer_len = np.array([6, 5, 4, 3]) short_len = np.array([3, 3, 2, 1]) - feature_dict = {'user': uid, 'gender': ugender, 'item': iid, 'prefer_item': prefer_iid, "prefer_cate":prefer_cate, - 'short_item': short_iid, 'short_cate': short_cate, 'prefer_sess_length': prefer_len, 'short_sess_length':short_len} + feature_dict = {'user': uid, 'gender': ugender, 'item': iid, 'prefer_item': prefer_iid, "prefer_cate": prefer_cate, + 'short_item': short_iid, 'short_cate': short_cate, 'prefer_sess_length': prefer_len, + 'short_sess_length': short_len} - #feature_names = get_feature_names(feature_columns) + # feature_names = get_feature_names(feature_columns) x = feature_dict y = np.array([1, 1, 1, 0]) history_feature_list = ['item', 'cate'] - return x, y, user_feature_columns, item_feature_columns, history_feature_list \ No newline at end of file + return x, y, user_feature_columns, item_feature_columns, history_feature_list
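In summary, this release moves the feature-column helpers from `deepctr.inputs` to `deepctr.feature_column` (pinning `deepctr==0.8.2`), removes the `init_std` argument from every model constructor, and adds an `output_activation` argument to `YoutubeDNN` and `MIND`. Below is a minimal usage sketch of the resulting 0.2.0 API, modeled on the `get_xy_fd` fixture in `tests/utils.py` above; the toy feature values are illustrative only, not real data:

```python
import numpy as np
import tensorflow as tf
from deepctr.feature_column import SparseFeat, VarLenSparseFeat  # moved from deepctr.inputs in deepctr 0.8.x
from tensorflow.python.keras import backend as K

from deepmatch.models import YoutubeDNN
from deepmatch.utils import sampledsoftmaxloss

# The sampled-softmax layer still runs in graph mode, as in the tests above.
K.set_learning_phase(True)
if tf.__version__ >= '2.0.0':
    tf.compat.v1.disable_eager_execution()

# Toy feature columns mirroring get_xy_fd in tests/utils.py.
user_feature_columns = [
    SparseFeat('user', 3),
    SparseFeat('gender', 2),
    VarLenSparseFeat(SparseFeat('hist_item', vocabulary_size=4, embedding_dim=4,
                                embedding_name='item'),
                     maxlen=4, length_name='hist_len'),
]
item_feature_columns = [SparseFeat('item', 4, embedding_dim=4)]

x = {'user': np.array([0, 1, 2, 1]),
     'gender': np.array([0, 1, 0, 1]),
     'item': np.array([1, 2, 3, 1]),  # 0 is the mask value
     'hist_item': np.array([[1, 2, 3, 0], [1, 2, 3, 0], [1, 2, 0, 0], [3, 0, 0, 0]]),
     'hist_len': np.array([3, 3, 2, 1])}
y = np.array([1, 1, 1, 1])

# No init_std argument anymore; output_activation is new in 0.2.0 and defaults to 'linear'.
model = YoutubeDNN(user_feature_columns, item_feature_columns, num_sampled=2,
                   user_dnn_hidden_units=(16, 4), output_activation='linear')
model.compile('adam', sampledsoftmaxloss)
model.fit(x, y, batch_size=4, epochs=1)
```

The fitted model still exposes `model.user_input`/`model.user_embedding` and `model.item_input`/`model.item_embedding`, so user and item representation vectors can be exported for ANN search the same way `check_model` in `tests/utils.py` does.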