From 34e0d854a80912a058a3872146cb98597c9cff57 Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Tue, 2 Nov 2021 18:43:04 +0000 Subject: [PATCH 1/6] tscwgan mvp --- .../preprocessing/timeseries/__init__.py | 2 + .../preprocessing/timeseries/stock.py | 4 +- .../timeseries/stock_univariate.py | 18 ++ .../preprocessing/timeseries/utils.py | 8 +- src/ydata_synthetic/synthesizers/gan.py | 4 +- .../synthesizers/regular/cramergan/model.py | 4 +- .../synthesizers/timeseries/__init__.py | 2 + .../timeseries/tscwgan/__init__.py | 0 .../synthesizers/timeseries/tscwgan/model.py | 260 ++++++++++++++++++ 9 files changed, 292 insertions(+), 10 deletions(-) create mode 100644 src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py create mode 100644 src/ydata_synthetic/synthesizers/timeseries/tscwgan/__init__.py create mode 100644 src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py diff --git a/src/ydata_synthetic/preprocessing/timeseries/__init__.py b/src/ydata_synthetic/preprocessing/timeseries/__init__.py index e8eff6c2..f79d23b4 100644 --- a/src/ydata_synthetic/preprocessing/timeseries/__init__.py +++ b/src/ydata_synthetic/preprocessing/timeseries/__init__.py @@ -1,5 +1,7 @@ from ydata_synthetic.preprocessing.timeseries.stock import transformations as processed_stock +from ydata_synthetic.preprocessing.timeseries.stock_univariate import transformations as processed_stock_univariate __all__ = [ "processed_stock", + "processed_stock_univariate" ] diff --git a/src/ydata_synthetic/preprocessing/timeseries/stock.py b/src/ydata_synthetic/preprocessing/timeseries/stock.py index f10367cc..21b4f5a6 100644 --- a/src/ydata_synthetic/preprocessing/timeseries/stock.py +++ b/src/ydata_synthetic/preprocessing/timeseries/stock.py @@ -13,6 +13,6 @@ def transformations(path, seq_len: int): except: stock_df=stock_df #Data transformations to be applied prior to be used with the synthesizer model - processed_data = real_data_loading(stock_df.values, seq_len=seq_len) + data, processed_data, scaler = real_data_loading(stock_df.values, seq_len=seq_len) - return processed_data + return data, processed_data, scaler diff --git a/src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py b/src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py new file mode 100644 index 00000000..6485b532 --- /dev/null +++ b/src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py @@ -0,0 +1,18 @@ +""" + Get the stock data from Yahoo finance data + Data from the period 01 January 2017 - 24 January 2021 +""" +import pandas as pd + +from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading + +def transformations(path, seq_len: int, col='Open'): + stock_df = pd.DataFrame(pd.read_csv(path)[col]) + try: + stock_df = stock_df.set_index('Date').sort_index() + except: + stock_df=stock_df + #Data transformations to be applied prior to be used with the synthesizer model + data, processed_data, scaler = real_data_loading(stock_df.values, seq_len=seq_len) + + return data, processed_data, scaler diff --git a/src/ydata_synthetic/preprocessing/timeseries/utils.py b/src/ydata_synthetic/preprocessing/timeseries/utils.py index c77c67b2..c8404899 100644 --- a/src/ydata_synthetic/preprocessing/timeseries/utils.py +++ b/src/ydata_synthetic/preprocessing/timeseries/utils.py @@ -4,7 +4,7 @@ import numpy as np from sklearn.preprocessing import MinMaxScaler -# Method implemented here: https://github.com/jsyoon0823/TimeGAN/blob/master/data_loading.py +# Method adapted from here: 
https://github.com/jsyoon0823/TimeGAN/blob/master/data_loading.py # Originally used in TimeGAN research def real_data_loading(data: np.array, seq_len): """Load and preprocess real-world datasets. @@ -30,7 +30,7 @@ def real_data_loading(data: np.array, seq_len): # Mix the datasets (to make it similar to i.i.d) idx = np.random.permutation(len(temp_data)) - data = [] + processed_data = [] for i in range(len(temp_data)): - data.append(temp_data[idx[i]]) - return data + processed_data.append(temp_data[idx[i]]) + return data, processed_data, scaler diff --git a/src/ydata_synthetic/synthesizers/gan.py b/src/ydata_synthetic/synthesizers/gan.py index 6f7d5684..d1f658cc 100644 --- a/src/ydata_synthetic/synthesizers/gan.py +++ b/src/ydata_synthetic/synthesizers/gan.py @@ -21,10 +21,10 @@ _model_parameters_df = [128, 1e-4, (None, None), 128, 264, None, None, None, 1, None] -_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval', 'labels'] +_train_parameters = ['cache_prefix', 'label_dim', 'epochs', 'sample_interval', 'labels', 'critic_iter'] ModelParameters = namedtuple('ModelParameters', _model_parameters, defaults=_model_parameters_df) -TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None)) +TrainParameters = namedtuple('TrainParameters', _train_parameters, defaults=('', None, 300, 50, None, None)) # pylint: disable=R0902 diff --git a/src/ydata_synthetic/synthesizers/regular/cramergan/model.py b/src/ydata_synthetic/synthesizers/regular/cramergan/model.py index ca953a6f..f6e4d4f1 100644 --- a/src/ydata_synthetic/synthesizers/regular/cramergan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/cramergan/model.py @@ -187,7 +187,7 @@ def save(self, path): super().save(path) -class Generator(tf.keras.Model): +class Generator(Model): def __init__(self, batch_size): """Simple generator with dense feedforward layers.""" self.batch_size = batch_size @@ -202,7 +202,7 @@ def build_model(self, input_shape, dim, data_dim, activation_info: Optional[Name x = GumbelSoftmaxActivation(activation_info)(x) return Model(inputs=input_, outputs=x) -class Critic(tf.keras.Model): +class Critic(Model): def __init__(self, batch_size): """Simple critic with dense feedforward and dropout layers.""" self.batch_size = batch_size diff --git a/src/ydata_synthetic/synthesizers/timeseries/__init__.py b/src/ydata_synthetic/synthesizers/timeseries/__init__.py index a3523536..3984a68b 100644 --- a/src/ydata_synthetic/synthesizers/timeseries/__init__.py +++ b/src/ydata_synthetic/synthesizers/timeseries/__init__.py @@ -1,5 +1,7 @@ from ydata_synthetic.synthesizers.timeseries.timegan.model import TimeGAN +from ydata_synthetic.synthesizers.timeseries.tscwgan.model import TSCWGAN __all__ = [ 'TimeGAN', + 'TSCWGAN', ] diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/__init__.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py new file mode 100644 index 00000000..14351c6c --- /dev/null +++ b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py @@ -0,0 +1,260 @@ +""" +Conditional time-series Wasserstein GAN. 
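Before the model body, a standalone sketch of what the new `critic_iter` train parameter (added to TrainParameters above) drives: the usual WGAN schedule of several critic updates per generator update. This is an illustration of the schedule, not the repository's code; the batch iterators and the two update callables are assumed to be supplied by the caller.

def wgan_training_schedule(real_batches, noise_batches, update_critic,
                           update_generator, epochs, critic_iter):
    for epoch in range(epochs):
        for _ in range(critic_iter):  # keep the critic near-optimal...
            c_loss = update_critic(next(real_batches), next(noise_batches))
        g_loss = update_generator(next(real_batches), next(noise_batches))  # ...then one generator step
        print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}")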
+Based on: https://www.naun.org/main/NAUN/neural/2020/a082016-004(2020).pdf +And on: https://github.com/CasperHogenboom/WGAN_financial_time-series +""" +from tqdm import trange +from numpy.random import normal +from pandas import DataFrame + +from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims +from tensorflow import data as tfdata +from tensorflow.keras import Model, Sequential +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.layers import Input, Conv1D, Dense, LeakyReLU, Flatten, Add + + +from ydata_synthetic.synthesizers.gan import BaseModel +from ydata_synthetic.synthesizers import TrainParameters +from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty + +class TSCWGAN(BaseModel): + + __MODEL__='TSCWGAN' + + def __init__(self, model_parameters, gradient_penalty_weight=10): + """Create a base TSCWGAN.""" + self.gradient_penalty_weight = gradient_penalty_weight + super().__init__(model_parameters) + + def define_gan(self): + self.generator = Generator(self.batch_size). \ + build_model(input_shape=(self.noise_dim + self.cond_dim, 1), dim=self.layers_dim, data_dim=self.data_dim) + self.critic = Critic(self.batch_size). \ + build_model(input_shape=(self.data_dim + self.cond_dim, 1), dim=self.layers_dim) + + self.g_optimizer = Adam(self.g_lr, beta_1=self.beta_1, beta_2=self.beta_2) + self.c_optimizer = Adam(self.d_lr, beta_1=self.beta_1, beta_2=self.beta_2) + + # The generator takes noise as input and generates records + noise = Input(shape=self.noise_dim, batch_size=self.batch_size) + cond = Input(shape=self.cond_dim, batch_size=self.batch_size) + gen = concat([cond, noise], axis=1) + gen = self.generator(gen) + score = concat([cond, gen], axis=1) + score = self.critic(score) + + def train(self, data, train_arguments: TrainParameters): + real_batches = self.get_batch_data(data) + noise_batches = self.get_batch_noise() + + for epoch in trange(train_arguments.epochs): + for i in range(train_arguments.critic_iter): + real_batch = next(real_batches) + noise_batch = next(noise_batches)[:len(real_batch)] # Truncate the noise tensor in the shape of the real data tensor + + c_loss = self.update_critic(real_batch, noise_batch) + + real_batch = next(real_batches) + noise_batch = next(noise_batches)[:len(real_batch)] + + g_loss = self.update_generator(real_batch, noise_batch) + + print( + "Epoch: {} | critic_loss: {} | gen_loss: {}".format( + epoch, c_loss, g_loss + )) + + self.g_optimizer = self.g_optimizer.get_config() + self.c_optimizer = self.c_optimizer.get_config() + + def update_critic(self, real_batch, noise_batch): + with GradientTape() as c_tape: + fake_batch, cond_batch = self._make_fake_batch(real_batch, noise_batch) + + # Real and fake records with conditions + real_batch_ = concat([cond_batch, real_batch], axis=1) + fake_batch_ = concat([cond_batch, fake_batch], axis=1) + + c_loss = self.c_lossfn(real_batch_, fake_batch_) + + c_gradient = c_tape.gradient(c_loss, self.critic.trainable_variables) + + # Update the weights of the critic using the optimizer + self.c_optimizer.apply_gradients( + zip(c_gradient, self.critic.trainable_variables) + ) + return c_loss + + def update_generator(self, real_batch, noise_batch): + with GradientTape() as g_tape: + fake_batch, cond_batch = self._make_fake_batch(real_batch, noise_batch) + + # Fake records with conditions + fake_batch_ = concat([cond_batch, fake_batch], axis=1) + + g_loss = self.g_lossfn(fake_batch_) + + g_gradient 
= g_tape.gradient(g_loss, self.generator.trainable_variables) + + # Update the weights of the generator using the optimizer + self.g_optimizer.apply_gradients( + zip(g_gradient, self.generator.trainable_variables) + ) + return g_loss + + def c_lossfn(self, real_batch_, fake_batch_): + score_fake = self.critic(fake_batch_) + score_real = self.critic(real_batch_) + grad_penalty = self.gradient_penalty(real_batch_, fake_batch_) + c_loss = reduce_mean(score_fake) - reduce_mean(score_real) + grad_penalty + return c_loss + + def g_lossfn(self, fake_batch_): + score_fake = self.critic(fake_batch_) + g_loss = - reduce_mean(score_fake) + return g_loss + + def _make_fake_batch(self, real_batch, noise_batch): + """Generate a batch of fake records and return it with the batch of used conditions. + Conditions are the first elements of records in the real batch.""" + cond_batch = real_batch[:, :self.cond_dim] + gen_input = concat([cond_batch, noise_batch], axis=1) + return self.generator(gen_input, training=True), cond_batch + + def gradient_penalty(self, real, fake): + gp = gradient_penalty(self.critic, real, fake, mode=Mode.DRAGAN) + return gp + + def _generate_noise(self): + "Gaussian noise for the generator input." + while True: + yield normal(size=self.noise_dim) + + def get_batch_noise(self): + "Create a batch iterator for the generator gaussian noise input." + return iter(tfdata.Dataset.from_generator(self._generate_noise, output_types=float32) + .batch(self.batch_size) + .repeat()) + + def get_batch_data(self, data, n_windows= None): + if not n_windows: + n_windows = len(data) + data = reshape(convert_to_tensor(data, dtype=float32), shape=(-1, self.data_dim)) + return iter(tfdata.Dataset.from_tensor_slices(data) + .shuffle(buffer_size=n_windows) + .batch(self.batch_size).repeat()) + + def sample(self, cond_array, n_samples): + """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.""" + assert len(cond_array.shape) == 2, "Condition array should have 2 dimensions." + assert cond_array.shape[1] == self.cond_dim, \ + f"Each sequence in the condition array should have a {self.cond_dim} length." 
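The penalty above delegates to the library's gradient_penalty helper with Mode.DRAGAN. For reference, a minimal standalone re-derivation, assuming the canonical DRAGAN formulation (perturb real samples, then penalize critic gradient norms that stray from 1) and 2-D (batch, features) records as used in this model:

import tensorflow as tf

def dragan_gradient_penalty(critic, real):
    batch = tf.shape(real)[0]
    # Interpolate between real samples and noise-perturbed copies of them.
    alpha = tf.random.uniform([batch, 1], 0.0, 1.0)
    perturbed = real + 0.5 * tf.math.reduce_std(real) * tf.random.uniform(tf.shape(real))
    x_hat = alpha * real + (1.0 - alpha) * perturbed
    with tf.GradientTape() as tape:
        tape.watch(x_hat)
        scores = critic(x_hat)
    grads = tape.gradient(scores, x_hat)
    norms = tf.norm(tf.reshape(grads, [batch, -1]), axis=1)
    return tf.reduce_mean((norms - 1.0) ** 2)  # two-sided penalty towards unit gradient norm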
+ n_conds = cond_array.shape[0] + steps = n_samples // self.batch_size + 1 + data = [] + z_dist = self.get_batch_noise() + for seq in range(n_conds): + cond_seq = expand_dims(convert_to_tensor(cond_array.iloc[seq], float32), axis=0) + cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) + for step in trange(steps, desc=f'Synthetic data generation - Condition {seq+1}/{n_conds}'): + gen_input = concat([cond_seq, next(z_dist)], axis=1) + records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) + data.append(records) + return DataFrame(concat(data, axis=0)) + + +class Generator(Model): + """Conditional generator with skip connections.""" + def __init__(self, batch_size): + self.batch_size = batch_size + + def build_model(self, input_shape, dim, data_dim): + # Define blocks + input_to_latent = Sequential(layers=[ + Conv1D(filters=dim, kernel_size=1, input_shape = input_shape), + LeakyReLU(), + Conv1D(dim, kernel_size=5, dilation_rate=2, padding="same"), + LeakyReLU() + ], name='input_to_latent') + block_cnn = Sequential(layers=[ + Conv1D(filters=dim, kernel_size=3, dilation_rate=2, padding="same"), + LeakyReLU() + ], name='block_cnn') + block_shift = Sequential(layers=[ + Conv1D(filters=10, kernel_size=3, dilation_rate=2, padding="same"), + LeakyReLU(), + Flatten(), + Dense(dim*2), + LeakyReLU() + ], name='block_shift') + block = Sequential(layers=[ + Dense(dim*2), + LeakyReLU() + ], name='block') + latent_to_output = Sequential([ + Dense(data_dim) + ], name='latent_to_ouput') + + # Define input - Expected input shape is (batch_size, seq_len, noise_dim). noise_dim = Z + cond + noise_input = Input(shape = input_shape, batch_size = self.batch_size) + + # Compose model + x = input_to_latent(noise_input) + x_block = block_cnn(x) + x = Add()([x_block, x]) + x_block = block_cnn(x) + x = Add()([x_block, x]) + x_block = block_cnn(x) + x = Add()([x_block, x]) + x = block_shift(x) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x = latent_to_output(x) + # Output - Expected shape is (batch_size, seq_len, data_dim). 
data_dim does not include conditions + return Model(inputs=noise_input, outputs=x, name='SkipConnectionGenerator') + +class Critic(Model): + """Conditional Wasserstein Critic with skip connections.""" + def __init__(self, batch_size): + self.batch_size = batch_size + + def build_model(self, input_shape, dim): + # Define blocks + ts_to_latent = Sequential(layers=[ + Dense(dim*2,), + LeakyReLU() + ], name='ts_to_latent') + block = Sequential(layers=[ + Dense(dim*2), + LeakyReLU() + ], name='block') + latent_to_score = Sequential(layers=[ + Dense(1) + ], name='latent_to_score') + + # Define input - Expected input shape is X + condition + record_input = Input(shape = input_shape, batch_size = self.batch_size) + + # Compose model + x = ts_to_latent(record_input) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x_block = block(x) + x = Add()([x_block, x]) + x = latent_to_score(x) + return Model(inputs=record_input, outputs=x, name='SkipConnectionCritic') From 84545595edc7d3e589d7077db038c53b0408ce91 Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Thu, 4 Nov 2021 12:42:19 +0000 Subject: [PATCH 2/6] pr review --- .gitignore | 4 +- .../regular/inverse_preprocesser.py | 25 ++--- .../timeseries/inverse_preprocesser.py | 17 +++ .../preprocessing/timeseries/__init__.py | 2 - .../preprocessing/timeseries/stock.py | 17 ++- .../timeseries/stock_univariate.py | 18 ---- .../synthesizers/timeseries/tscwgan/model.py | 102 ++++++++---------- 7 files changed, 93 insertions(+), 92 deletions(-) create mode 100644 src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py delete mode 100644 src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py diff --git a/.gitignore b/.gitignore index 166fdb12..86a2da71 100644 --- a/.gitignore +++ b/.gitignore @@ -373,4 +373,6 @@ DerivedData/ # User created VERSION -version.py \ No newline at end of file +version.py +local_test_*.py +local_test_*.ipynb diff --git a/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py b/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py index 9b9a0b50..b99f4bc3 100644 --- a/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py +++ b/src/ydata_synthetic/postprocessing/regular/inverse_preprocesser.py @@ -1,45 +1,46 @@ # Inverts all preprocessing pipelines provided in the preprocessing examples from typing import Union -import pandas as pd +from pandas import DataFrame, concat from sklearn.pipeline import Pipeline from sklearn.compose import ColumnTransformer -from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler +from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler -def inverse_transform(data: pd.DataFrame, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, OneHotEncoder, StandardScaler]) -> pd.DataFrame: +def inverse_transform(data: DataFrame, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, + OneHotEncoder, StandardScaler, MinMaxScaler]) -> DataFrame: """Inverts data transformations taking place in a standard sklearn processor. Supported processes are sklearn pipelines, column transformers or base estimators like standard scalers. 
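A usage sketch for the MinMaxScaler support added here, with illustrative data (the import path is the module this diff touches):

from numpy.random import random
from pandas import DataFrame
from sklearn.preprocessing import MinMaxScaler

from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform

original = DataFrame(random((100, 1)) * 50 + 10, columns=['Open'])
scaler = MinMaxScaler().fit(original)
scaled = DataFrame(scaler.transform(original), columns=['Open'])
recovered = inverse_transform(scaled, scaler)  # values back on the original scale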
Args: - data (pd.DataFrame): The data object that needs inversion of preprocessing + data (DataFrame): The data object that needs inversion of preprocessing processor (Union[Pipeline, ColumnTransformer, BaseEstimator]): The processor applied on the original data Returns: - inv_data (pd.DataFrame): The data object after inverting preprocessing""" + inv_data (DataFrame): The data object after inverting preprocessing""" inv_data = data.copy() - if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, Pipeline)): - inv_data = pd.DataFrame(processor.inverse_transform(data), columns=processor.feature_names_in_) + if isinstance(processor, (PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler, Pipeline)): + inv_data = DataFrame(processor.inverse_transform(data), columns=processor.feature_names_in_ if hasattr(processor, "feature_names_in") else None) elif isinstance(processor, ColumnTransformer): output_indices = processor.output_indices_ - assert isinstance(data, pd.DataFrame), "The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame." + assert isinstance(data, DataFrame), "The data to be inverted from a ColumnTransformer has to be a Pandas DataFrame." for t_name, t, t_cols in processor.transformers_[::-1]: slice_ = output_indices[t_name] t_indices = list(range(slice_.start, slice_.stop, 1 if slice_.step is None else slice_.step)) if t == 'drop': continue elif t == 'passthrough': - inv_cols = pd.DataFrame(data.iloc[:,t_indices].values, columns = t_cols, index = data.index) + inv_cols = DataFrame(data.iloc[:,t_indices].values, columns = t_cols, index = data.index) inv_col_names = inv_cols.columns else: - inv_cols = pd.DataFrame(t.inverse_transform(data.iloc[:,t_indices].values), columns = t_cols, index = data.index) + inv_cols = DataFrame(t.inverse_transform(data.iloc[:,t_indices].values), columns = t_cols, index = data.index) inv_col_names = inv_cols.columns if set(inv_col_names).issubset(set(inv_data.columns)): inv_data[inv_col_names] = inv_cols[inv_col_names] else: - inv_data = pd.concat([inv_data, inv_cols], axis=1) + inv_data = concat([inv_data, inv_cols], axis=1) else: print('The provided data processor is not supported and cannot be inverted with this method.') return None - return inv_data[processor.feature_names_in_] + return inv_data[processor.feature_names_in_] if hasattr(processor, "feature_names_in") else inv_data diff --git a/src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py b/src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py new file mode 100644 index 00000000..36be1dc1 --- /dev/null +++ b/src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py @@ -0,0 +1,17 @@ +from typing import Union, List + +from ydata_synthetic.postprocessing.regular import inverse_preprocesser + +from sklearn.pipeline import Pipeline +from sklearn.compose import ColumnTransformer +from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler + +from pandas import DataFrame + +def inverse_transform(data: List, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, OneHotEncoder, + StandardScaler, MinMaxScaler]): + if isinstance(data, list): + data = DataFrame(data) + return inverse_preprocesser.inverse_transform(data, processor).tolist() + else: + return inverse_preprocesser.inverse_transform(data, processor) diff --git a/src/ydata_synthetic/preprocessing/timeseries/__init__.py b/src/ydata_synthetic/preprocessing/timeseries/__init__.py index f79d23b4..e8eff6c2 100644 
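For context on the preprocessing files this patch touches, a minimal sketch of the TimeGAN-style loader behind transformations, with the new (data, processed_data, scaler) return contract. It assumes the upstream data_loading implementation the utils module credits (reverse, scale, roll windows, shuffle); data is a 2-D (n_rows, n_features) array:

import numpy as np
from sklearn.preprocessing import MinMaxScaler

def real_data_loading_sketch(data: np.ndarray, seq_len: int):
    ori_data = data[::-1]                        # reverse to chronological order
    scaler = MinMaxScaler().fit(ori_data)
    scaled = scaler.transform(ori_data)
    temp_data = [scaled[i:i + seq_len] for i in range(len(scaled) - seq_len)]
    idx = np.random.permutation(len(temp_data))  # mix the windows (~i.i.d.)
    processed_data = [temp_data[i] for i in idx]
    return data, processed_data, scaler          # scaler kept for later inversion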
--- a/src/ydata_synthetic/preprocessing/timeseries/__init__.py +++ b/src/ydata_synthetic/preprocessing/timeseries/__init__.py @@ -1,7 +1,5 @@ from ydata_synthetic.preprocessing.timeseries.stock import transformations as processed_stock -from ydata_synthetic.preprocessing.timeseries.stock_univariate import transformations as processed_stock_univariate __all__ = [ "processed_stock", - "processed_stock_univariate" ] diff --git a/src/ydata_synthetic/preprocessing/timeseries/stock.py b/src/ydata_synthetic/preprocessing/timeseries/stock.py index 21b4f5a6..a6aba6f3 100644 --- a/src/ydata_synthetic/preprocessing/timeseries/stock.py +++ b/src/ydata_synthetic/preprocessing/timeseries/stock.py @@ -2,12 +2,25 @@ Get the stock data from Yahoo finance data Data from the period 01 January 2017 - 24 January 2021 """ +from typing import Union, List + import pandas as pd from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading -def transformations(path, seq_len: int): - stock_df = pd.read_csv(path) +def transformations(path, seq_len: int, cols: Union[str, List] = None): + """Apply min max scaling and roll windows of a temporal dataset. + + Args: + path(str): path to a csv temporal dataframe + seq_len(int): length of the rolled sequences + cols (Union[str, List]): Column or list of columns to be used""" + if isinstance(cols, str): + cols = [cols] + if isinstance(cols, list): + stock_df = pd.read_csv(path)[cols] + else: + stock_df = pd.read_csv(path) try: stock_df = stock_df.set_index('Date').sort_index() except: diff --git a/src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py b/src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py deleted file mode 100644 index 6485b532..00000000 --- a/src/ydata_synthetic/preprocessing/timeseries/stock_univariate.py +++ /dev/null @@ -1,18 +0,0 @@ -""" - Get the stock data from Yahoo finance data - Data from the period 01 January 2017 - 24 January 2021 -""" -import pandas as pd - -from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading - -def transformations(path, seq_len: int, col='Open'): - stock_df = pd.DataFrame(pd.read_csv(path)[col]) - try: - stock_df = stock_df.set_index('Date').sort_index() - except: - stock_df=stock_df - #Data transformations to be applied prior to be used with the synthesizer model - data, processed_data, scaler = real_data_loading(stock_df.values, seq_len=seq_len) - - return data, processed_data, scaler diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py index 14351c6c..4390862b 100644 --- a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py +++ b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py @@ -25,6 +25,7 @@ class TSCWGAN(BaseModel): def __init__(self, model_parameters, gradient_penalty_weight=10): """Create a base TSCWGAN.""" self.gradient_penalty_weight = gradient_penalty_weight + self.cond_dim = model_parameters.condition super().__init__(model_parameters) def define_gan(self): @@ -170,53 +171,51 @@ def __init__(self, batch_size): self.batch_size = batch_size def build_model(self, input_shape, dim, data_dim): - # Define blocks - input_to_latent = Sequential(layers=[ + # Define input - Expected input shape is (batch_size, seq_len, noise_dim). 
noise_dim = Z + cond + noise_input = Input(shape = input_shape, batch_size = self.batch_size) + + # Compose model + proc_input = Sequential(layers=[ Conv1D(filters=dim, kernel_size=1, input_shape = input_shape), LeakyReLU(), Conv1D(dim, kernel_size=5, dilation_rate=2, padding="same"), LeakyReLU() - ], name='input_to_latent') + ], name='input_to_latent')(noise_input) + block_cnn = Sequential(layers=[ Conv1D(filters=dim, kernel_size=3, dilation_rate=2, padding="same"), LeakyReLU() ], name='block_cnn') - block_shift = Sequential(layers=[ + for i in range(3): + if i == 0: + cnn_block_i = proc_input + cnn_block_o = block_cnn(proc_input) + else: + cnn_block_o = block_cnn(cnn_block_i) + cnn_block_i = Add()([cnn_block_i, cnn_block_o]) + + shift = Sequential(layers=[ Conv1D(filters=10, kernel_size=3, dilation_rate=2, padding="same"), LeakyReLU(), Flatten(), Dense(dim*2), LeakyReLU() - ], name='block_shift') + ], name='block_shift')(cnn_block_i) + block = Sequential(layers=[ Dense(dim*2), LeakyReLU() ], name='block') - latent_to_output = Sequential([ - Dense(data_dim) - ], name='latent_to_ouput') + for i in range(3): + if i == 0: + block_i = shift + block_o = block(shift) + else: + block_o = block(block_i) + block_i = Add()([block_i, block_o]) - # Define input - Expected input shape is (batch_size, seq_len, noise_dim). noise_dim = Z + cond - noise_input = Input(shape = input_shape, batch_size = self.batch_size) - - # Compose model - x = input_to_latent(noise_input) - x_block = block_cnn(x) - x = Add()([x_block, x]) - x_block = block_cnn(x) - x = Add()([x_block, x]) - x_block = block_cnn(x) - x = Add()([x_block, x]) - x = block_shift(x) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x = latent_to_output(x) - # Output - Expected shape is (batch_size, seq_len, data_dim). 
data_dim does not include conditions - return Model(inputs=noise_input, outputs=x, name='SkipConnectionGenerator') + output = Dense(data_dim, name='latent_to_ouput')(block_i) + return Model(inputs = noise_input, outputs = output, name='SkipConnectionGenerator') class Critic(Model): """Conditional Wasserstein Critic with skip connections.""" @@ -224,37 +223,26 @@ def __init__(self, batch_size): self.batch_size = batch_size def build_model(self, input_shape, dim): - # Define blocks - ts_to_latent = Sequential(layers=[ + # Define input - Expected input shape is X + condition + record_input = Input(shape = input_shape, batch_size = self.batch_size) + + # Compose model + proc_record = Sequential(layers=[ Dense(dim*2,), LeakyReLU() - ], name='ts_to_latent') + ], name='ts_to_latent')(record_input) + block = Sequential(layers=[ Dense(dim*2), LeakyReLU() ], name='block') - latent_to_score = Sequential(layers=[ - Dense(1) - ], name='latent_to_score') - - # Define input - Expected input shape is X + condition - record_input = Input(shape = input_shape, batch_size = self.batch_size) - - # Compose model - x = ts_to_latent(record_input) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x_block = block(x) - x = Add()([x_block, x]) - x = latent_to_score(x) - return Model(inputs=record_input, outputs=x, name='SkipConnectionCritic') + for i in range(7): + if i == 0: + block_i = proc_record + block_o = block(proc_record) + else: + block_o = block(block_i) + block_i = Add()([block_i, block_o]) + + output = Dense(1, name = 'latent_to_score')(block_i) + return Model(inputs=record_input, outputs=output, name='SkipConnectionCritic') From bf8656fce675fe72fa1bbe3783cadacf14bb15b7 Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Thu, 4 Nov 2021 12:55:02 +0000 Subject: [PATCH 3/6] add example, remove added attribute of basemodel remove changes on gitignore removed unused n_feats argument --- .gitignore | 2 - examples/timeseries/tscwgan_example.py | 60 ++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) create mode 100644 examples/timeseries/tscwgan_example.py diff --git a/.gitignore b/.gitignore index 86a2da71..c6a2d897 100644 --- a/.gitignore +++ b/.gitignore @@ -374,5 +374,3 @@ DerivedData/ # User created VERSION version.py -local_test_*.py -local_test_*.ipynb diff --git a/examples/timeseries/tscwgan_example.py b/examples/timeseries/tscwgan_example.py new file mode 100644 index 00000000..b686ab82 --- /dev/null +++ b/examples/timeseries/tscwgan_example.py @@ -0,0 +1,60 @@ +from pandas import DataFrame +from numpy import squeeze + +from ydata_synthetic.postprocessing.timeseries.inverse_preprocesser import inverse_transform +from ydata_synthetic.preprocessing.timeseries import processed_stock +from ydata_synthetic.synthesizers.timeseries import TSCWGAN +from ydata_synthetic.synthesizers import ModelParameters, TrainParameters + +model = TSCWGAN + +#Define the GAN and training parameters +noise_dim = 32 +dim = 128 +seq_len = 48 +cond_dim = 24 +batch_size = 128 + +log_step = 100 +epochs = 300+1 +learning_rate = 5e-4 +beta_1 = 0.5 +beta_2 = 0.9 +models_dir = './cache' +critic_iter = 5 + +# Get transformed data stock - Univariate +data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = 'Open') +data_sample = 
processed_data[0] + +model_parameters = ModelParameters(batch_size=batch_size, + lr=learning_rate, + betas=(beta_1, beta_2), + noise_dim=noise_dim, + n_cols=seq_len, + layers_dim=dim, + condition = cond_dim) + +train_args = TrainParameters(epochs=epochs, + sample_interval=log_step, + critic_iter=critic_iter) + +#Training the TSCWGAN model +synthesizer = model(model_parameters, gradient_penalty_weight=10) +synthesizer.train(processed_data, train_args) + +#Saving the synthesizer to later generate new events +synthesizer.save(path='./tscwgan_stock.pkl') + +#Loading the synthesizer +synth = model.load(path='./tscwgan_stock.pkl') + +#Sampling the data +#Note that the data returned is not inverse processed. +step = int(len(processed_data)/(5-1)) +cond_array = DataFrame(data=[squeeze(processed_data[i][:cond_dim], axis=1) for i in range(0, len(processed_data), step)]) + +data_sample = synth.sample(cond_array, 200) + +# Inverting the scaling of the synthetic samples +data_sample = inverse_transform(data_sample, scaler) From 2c2b7200410cb71c0d8ae9b9db8ab7393b8fac19 Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Fri, 5 Nov 2021 17:19:21 +0000 Subject: [PATCH 4/6] apply revisions + add typeguard apply revisions --- examples/timeseries/tscwgan_example.py | 11 ++++--- .../timeseries/inverse_preprocesser.py | 17 ----------- .../preprocessing/timeseries/stock.py | 8 ++--- .../synthesizers/regular/cramergan/model.py | 4 +-- .../synthesizers/timeseries/tscwgan/model.py | 30 ++++++++----------- 5 files changed, 23 insertions(+), 47 deletions(-) delete mode 100644 src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py diff --git a/examples/timeseries/tscwgan_example.py b/examples/timeseries/tscwgan_example.py index b686ab82..68a54087 100644 --- a/examples/timeseries/tscwgan_example.py +++ b/examples/timeseries/tscwgan_example.py @@ -1,10 +1,9 @@ -from pandas import DataFrame from numpy import squeeze -from ydata_synthetic.postprocessing.timeseries.inverse_preprocesser import inverse_transform from ydata_synthetic.preprocessing.timeseries import processed_stock from ydata_synthetic.synthesizers.timeseries import TSCWGAN from ydata_synthetic.synthesizers import ModelParameters, TrainParameters +from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform model = TSCWGAN @@ -24,7 +23,7 @@ critic_iter = 5 # Get transformed data stock - Univariate -data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = 'Open') +data, processed_data, scaler = processed_stock(path='./data/stock_data.csv', seq_len=seq_len, cols = ['Open']) data_sample = processed_data[0] model_parameters = ModelParameters(batch_size=batch_size, @@ -51,10 +50,10 @@ #Sampling the data #Note that the data returned is not inverse processed. 
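For reference, the shapes the example's parameters imply through define_gan (a reading of the code, with data_dim equal to n_cols here):

batch_size, noise_dim, cond_dim, seq_len = 128, 32, 24, 48
gen_input_shape    = (batch_size, cond_dim + noise_dim)  # [condition | z]      -> (128, 56)
gen_output_shape   = (batch_size, seq_len)               # full scaled window   -> (128, 48)
critic_input_shape = (batch_size, cond_dim + seq_len)    # [condition | window] -> (128, 72)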
-step = int(len(processed_data)/(5-1)) -cond_array = DataFrame(data=[squeeze(processed_data[i][:cond_dim], axis=1) for i in range(0, len(processed_data), step)]) +cond_index = 100 # Arbitrary sequence for conditioning +cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1) -data_sample = synth.sample(cond_array, 200) +data_sample = synth.sample(cond_array, 1000) # Inverting the scaling of the synthetic samples data_sample = inverse_transform(data_sample, scaler) diff --git a/src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py b/src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py deleted file mode 100644 index 36be1dc1..00000000 --- a/src/ydata_synthetic/postprocessing/timeseries/inverse_preprocesser.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Union, List - -from ydata_synthetic.postprocessing.regular import inverse_preprocesser - -from sklearn.pipeline import Pipeline -from sklearn.compose import ColumnTransformer -from sklearn.preprocessing import PowerTransformer, OneHotEncoder, StandardScaler, MinMaxScaler - -from pandas import DataFrame - -def inverse_transform(data: List, processor: Union[Pipeline, ColumnTransformer, PowerTransformer, OneHotEncoder, - StandardScaler, MinMaxScaler]): - if isinstance(data, list): - data = DataFrame(data) - return inverse_preprocesser.inverse_transform(data, processor).tolist() - else: - return inverse_preprocesser.inverse_transform(data, processor) diff --git a/src/ydata_synthetic/preprocessing/timeseries/stock.py b/src/ydata_synthetic/preprocessing/timeseries/stock.py index a6aba6f3..26866adb 100644 --- a/src/ydata_synthetic/preprocessing/timeseries/stock.py +++ b/src/ydata_synthetic/preprocessing/timeseries/stock.py @@ -2,21 +2,21 @@ Get the stock data from Yahoo finance data Data from the period 01 January 2017 - 24 January 2021 """ -from typing import Union, List +from typing import Optional, List import pandas as pd +from typeguard import typechecked from ydata_synthetic.preprocessing.timeseries.utils import real_data_loading -def transformations(path, seq_len: int, cols: Union[str, List] = None): +@typechecked +def transformations(path, seq_len: int, cols: Optional[List] = None): """Apply min max scaling and roll windows of a temporal dataset. 
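With @typechecked on transformations, a wrong cols type now fails fast instead of propagating. A quick illustration, assuming the function above is in scope and a hypothetical CSV path:

transformations('./data/stock_data.csv', seq_len=48, cols=['Open'])  # OK
transformations('./data/stock_data.csv', seq_len=48, cols='Open')    # TypeError raised by typeguard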
Args: path(str): path to a csv temporal dataframe seq_len(int): length of the rolled sequences cols (Union[str, List]): Column or list of columns to be used""" - if isinstance(cols, str): - cols = [cols] if isinstance(cols, list): stock_df = pd.read_csv(path)[cols] else: diff --git a/src/ydata_synthetic/synthesizers/regular/cramergan/model.py b/src/ydata_synthetic/synthesizers/regular/cramergan/model.py index f6e4d4f1..ca953a6f 100644 --- a/src/ydata_synthetic/synthesizers/regular/cramergan/model.py +++ b/src/ydata_synthetic/synthesizers/regular/cramergan/model.py @@ -187,7 +187,7 @@ def save(self, path): super().save(path) -class Generator(Model): +class Generator(tf.keras.Model): def __init__(self, batch_size): """Simple generator with dense feedforward layers.""" self.batch_size = batch_size @@ -202,7 +202,7 @@ def build_model(self, input_shape, dim, data_dim, activation_info: Optional[Name x = GumbelSoftmaxActivation(activation_info)(x) return Model(inputs=input_, outputs=x) -class Critic(Model): +class Critic(tf.keras.Model): def __init__(self, batch_size): """Simple critic with dense feedforward and dropout layers.""" self.batch_size = batch_size diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py index 4390862b..4b571965 100644 --- a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py +++ b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py @@ -4,8 +4,8 @@ And on: https://github.com/CasperHogenboom/WGAN_financial_time-series """ from tqdm import trange +from numpy import array, vstack from numpy.random import normal -from pandas import DataFrame from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims from tensorflow import data as tfdata @@ -13,7 +13,6 @@ from tensorflow.keras.optimizers import Adam from tensorflow.keras.layers import Input, Conv1D, Dense, LeakyReLU, Flatten, Add - from ydata_synthetic.synthesizers.gan import BaseModel from ydata_synthetic.synthesizers import TrainParameters from ydata_synthetic.synthesizers.loss import Mode, gradient_penalty @@ -61,10 +60,7 @@ def train(self, data, train_arguments: TrainParameters): g_loss = self.update_generator(real_batch, noise_batch) - print( - "Epoch: {} | critic_loss: {} | gen_loss: {}".format( - epoch, c_loss, g_loss - )) + print(f"Epoch: {epoch} | critic_loss: {c_loss} | gen_loss: {g_loss}") self.g_optimizer = self.g_optimizer.get_config() self.c_optimizer = self.c_optimizer.get_config() @@ -148,21 +144,19 @@ def get_batch_data(self, data, n_windows= None): def sample(self, cond_array, n_samples): """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.""" - assert len(cond_array.shape) == 2, "Condition array should have 2 dimensions." - assert cond_array.shape[1] == self.cond_dim, \ - f"Each sequence in the condition array should have a {self.cond_dim} length." - n_conds = cond_array.shape[0] + assert len(cond_array.shape) == 1, "Condition array should be one-dimensional." + assert cond_array.shape[0] == self.cond_dim, \ + f"The condition sequence should have a {self.cond_dim} length." 
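A standalone sketch of the batched conditional generation this revised sample method implements: one condition vector tiled across the batch and paired with fresh noise on every pass. Names are assumptions; generator is any trained Keras model taking [condition | z]:

import numpy as np
import tensorflow as tf

def sample_for_condition(generator, condition, batch_size, noise_dim, n_samples):
    cond = tf.tile(tf.expand_dims(tf.convert_to_tensor(condition, tf.float32), 0),
                   [batch_size, 1])  # (cond_dim,) -> (batch_size, cond_dim)
    out = []
    for _ in range(n_samples // batch_size + 1):
        z = tf.random.normal([batch_size, noise_dim])
        out.append(generator(tf.concat([cond, z], axis=1), training=False).numpy())
    return np.vstack(out)[:n_samples]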
steps = n_samples // self.batch_size + 1 data = [] z_dist = self.get_batch_noise() - for seq in range(n_conds): - cond_seq = expand_dims(convert_to_tensor(cond_array.iloc[seq], float32), axis=0) - cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) - for step in trange(steps, desc=f'Synthetic data generation - Condition {seq+1}/{n_conds}'): - gen_input = concat([cond_seq, next(z_dist)], axis=1) - records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) - data.append(records) - return DataFrame(concat(data, axis=0)) + cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0) + cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) + for step in trange(steps, desc=f'Synthetic data generation'): + gen_input = concat([cond_seq, next(z_dist)], axis=1) + records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) + data.append(records) + return array(vstack(data)) class Generator(Model): From a3c9b3443599cd4ce11cd6fb5426460cacae03d6 Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Tue, 30 Nov 2021 12:56:43 +0000 Subject: [PATCH 5/6] integrate TSDataProcessor, revise sample method Auto regressive timeseries sampling method revert TS data processor integration --- examples/timeseries/tscwgan_example.py | 6 +-- .../synthesizers/timeseries/tscwgan/model.py | 41 ++++++++++++------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/examples/timeseries/tscwgan_example.py b/examples/timeseries/tscwgan_example.py index 68a54087..8863e750 100644 --- a/examples/timeseries/tscwgan_example.py +++ b/examples/timeseries/tscwgan_example.py @@ -1,4 +1,4 @@ -from numpy import squeeze +from numpy import reshape from ydata_synthetic.preprocessing.timeseries import processed_stock from ydata_synthetic.synthesizers.timeseries import TSCWGAN @@ -51,9 +51,9 @@ #Sampling the data #Note that the data returned is not inverse processed. 
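Stepping back to the networks from the review pass: Generator and Critic both repeat one residual pattern in a loop, with a single Sequential reused on every hop, so all hops share weights. A generic sketch of that pattern; it assumes x already has width dim*2 so the skip addition shapes match:

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Add, Dense, LeakyReLU

def residual_stack(x, dim, n_blocks):
    block = Sequential([Dense(dim * 2), LeakyReLU()])  # instantiated once -> shared weights
    for _ in range(n_blocks):
        x = Add()([x, block(x)])
    return x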
cond_index = 100 # Arbitrary sequence for conditioning -cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1) +cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1)) -data_sample = synth.sample(cond_array, 1000) +data_sample = synth.sample(cond_array, 1000, 100) # Inverting the scaling of the synthetic samples data_sample = inverse_transform(data_sample, scaler) diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py index 4b571965..0a9ac319 100644 --- a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py +++ b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py @@ -4,10 +4,10 @@ And on: https://github.com/CasperHogenboom/WGAN_financial_time-series """ from tqdm import trange -from numpy import array, vstack +from numpy import array, vstack, hstack from numpy.random import normal -from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims +from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile from tensorflow import data as tfdata from tensorflow.keras import Model, Sequential from tensorflow.keras.optimizers import Adam @@ -142,20 +142,33 @@ def get_batch_data(self, data, n_windows= None): .shuffle(buffer_size=n_windows) .batch(self.batch_size).repeat()) - def sample(self, cond_array, n_samples): - """Provided that cond_array is passed, produce n_samples for each condition vector in cond_array.""" - assert len(cond_array.shape) == 1, "Condition array should be one-dimensional." - assert cond_array.shape[0] == self.cond_dim, \ - f"The condition sequence should have a {self.cond_dim} length." - steps = n_samples // self.batch_size + 1 + def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24): + """For a given condition, produce n_samples of length seq_len. + + Args: + condition (numpy.array): Condition for the generated samples, must have the same length. + n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size). + seq_len (int): Length of the generated samples. + + Returns: + data (numpy.array): An array of data of shape [n_samples, seq_len]""" + assert len(condition.shape) == 2, "Condition array should be two-dimensional." + assert condition.shape[1] == self.cond_dim, \ + f"The condition sequence should have {self.cond_dim} length." 
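A standalone sketch of the autoregressive loop this revision introduces: after each generated window, its last cond_dim points become the next condition, and the assembled sequence is cropped to seq_len. Names are assumptions; noise is an iterator of z batches and generate wraps the trained generator:

import numpy as np

def autoregressive_windows(generate, condition, noise, cond_dim, data_dim, seq_len):
    chunks, cond = [], condition
    for _ in range(seq_len // data_dim + 1):
        records = generate(cond, next(noise))  # (batch, data_dim)
        cond = records[:, -cond_dim:]          # feed the tail back as the new condition
        chunks.append(records)
    return np.hstack(chunks)[:, :seq_len]      # (batch, seq_len)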
+ batches = n_samples // self.batch_size + 1 + ar_steps = seq_len // self.data_dim + 1 data = [] z_dist = self.get_batch_noise() - cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0) - cond_seq = tile(cond_seq, multiples=[self.batch_size, 1]) - for step in trange(steps, desc=f'Synthetic data generation'): - gen_input = concat([cond_seq, next(z_dist)], axis=1) - records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False))) - data.append(records) + for batch in trange(batches, desc=f'Synthetic data generation'): + data_ = [] + cond_seq = convert_to_tensor(condition, float32) + gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1) + for step in range(ar_steps): + records = self.generator(gen_input, training=False) + gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1) + data_.append(records) + data_ = hstack(data_)[:, :seq_len] + data.append(data_) return array(vstack(data)) From 10c101ae7dfd67ff1fec7d7a2695274e5580642e Mon Sep 17 00:00:00 2001 From: Francisco Santos Date: Tue, 18 Jan 2022 17:30:58 +0000 Subject: [PATCH 6/6] flow fixes --- examples/timeseries/tscwgan_example.py | 2 +- .../synthesizers/timeseries/tscwgan/model.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/timeseries/tscwgan_example.py b/examples/timeseries/tscwgan_example.py index 8863e750..cb1bf1a7 100644 --- a/examples/timeseries/tscwgan_example.py +++ b/examples/timeseries/tscwgan_example.py @@ -56,4 +56,4 @@ data_sample = synth.sample(cond_array, 1000, 100) # Inverting the scaling of the synthetic samples -data_sample = inverse_transform(data_sample, scaler) +inv_data_sample = inverse_transform(data_sample, scaler) diff --git a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py index 0a9ac319..ee90c927 100644 --- a/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py +++ b/src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py @@ -7,7 +7,7 @@ from numpy import array, vstack, hstack from numpy.random import normal -from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile +from tensorflow import concat, float32, convert_to_tensor, GradientTape, reduce_mean, tile, squeeze from tensorflow import data as tfdata from tensorflow.keras import Model, Sequential from tensorflow.keras.optimizers import Adam @@ -26,6 +26,7 @@ def __init__(self, model_parameters, gradient_penalty_weight=10): self.gradient_penalty_weight = gradient_penalty_weight self.cond_dim = model_parameters.condition super().__init__(model_parameters) + self.data_dim = model_parameters.n_cols def define_gan(self): self.generator = Generator(self.batch_size). \ @@ -45,6 +46,7 @@ def define_gan(self): score = self.critic(score) def train(self, data, train_arguments: TrainParameters): + self.define_gan() real_batches = self.get_batch_data(data) noise_batches = self.get_batch_noise() @@ -137,7 +139,7 @@ def get_batch_noise(self): def get_batch_data(self, data, n_windows= None): if not n_windows: n_windows = len(data) - data = reshape(convert_to_tensor(data, dtype=float32), shape=(-1, self.data_dim)) + data = squeeze(convert_to_tensor(data, dtype=float32)) return iter(tfdata.Dataset.from_tensor_slices(data) .shuffle(buffer_size=n_windows) .batch(self.batch_size).repeat())
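Pulling the six patches together, a condensed sketch of the resulting end-to-end flow, mirroring the example file (paths and parameter values taken from it):

from numpy import reshape

from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers import ModelParameters, TrainParameters
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
from ydata_synthetic.postprocessing.regular.inverse_preprocesser import inverse_transform

data, processed, scaler = processed_stock(path='./data/stock_data.csv', seq_len=48, cols=['Open'])
model_params = ModelParameters(batch_size=128, lr=5e-4, betas=(0.5, 0.9), noise_dim=32,
                               n_cols=48, layers_dim=128, condition=24)
synth = TSCWGAN(model_params, gradient_penalty_weight=10)
synth.train(processed, TrainParameters(epochs=301, sample_interval=100, critic_iter=5))
cond = reshape(processed[100][:24], (1, -1))   # a (1, cond_dim) conditioning window
samples = synth.sample(cond, n_samples=1000, seq_len=100)
inv_samples = inverse_transform(samples, scaler)  # back to the original scale, as in the example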