Skip to content

Commit

Permalink
integrate TSDataProcessor, revise sample method
Browse files Browse the repository at this point in the history
Auto regressive timeseries sampling method

revert TS data processor integration
  • Loading branch information
Francisco Santos committed Dec 15, 2021
1 parent 2c2b720 commit a3c9b34
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
6 changes: 3 additions & 3 deletions examples/timeseries/tscwgan_example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from numpy import squeeze
from numpy import reshape

from ydata_synthetic.preprocessing.timeseries import processed_stock
from ydata_synthetic.synthesizers.timeseries import TSCWGAN
Expand Down Expand Up @@ -51,9 +51,9 @@
#Sampling the data
#Note that the data returned is not inverse processed.
cond_index = 100 # Arbitrary sequence for conditioning
cond_array = squeeze(processed_data[cond_index][:cond_dim], axis=1)
cond_array = reshape(processed_data[cond_index][:cond_dim], (1,-1))

data_sample = synth.sample(cond_array, 1000)
data_sample = synth.sample(cond_array, 1000, 100)

# Inverting the scaling of the synthetic samples
data_sample = inverse_transform(data_sample, scaler)
41 changes: 27 additions & 14 deletions src/ydata_synthetic/synthesizers/timeseries/tscwgan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
And on: https://github.com/CasperHogenboom/WGAN_financial_time-series
"""
from tqdm import trange
from numpy import array, vstack
from numpy import array, vstack, hstack
from numpy.random import normal

from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, make_ndarray, make_tensor_proto, tile, expand_dims
from tensorflow import concat, float32, convert_to_tensor, reshape, GradientTape, reduce_mean, tile
from tensorflow import data as tfdata
from tensorflow.keras import Model, Sequential
from tensorflow.keras.optimizers import Adam
Expand Down Expand Up @@ -142,20 +142,33 @@ def get_batch_data(self, data, n_windows= None):
.shuffle(buffer_size=n_windows)
.batch(self.batch_size).repeat())

def sample(self, cond_array, n_samples):
"""Provided that cond_array is passed, produce n_samples for each condition vector in cond_array."""
assert len(cond_array.shape) == 1, "Condition array should be one-dimensional."
assert cond_array.shape[0] == self.cond_dim, \
f"The condition sequence should have a {self.cond_dim} length."
steps = n_samples // self.batch_size + 1
def sample(self, condition: array, n_samples: int = 100, seq_len: int = 24):
"""For a given condition, produce n_samples of length seq_len.
Args:
condition (numpy.array): Condition for the generated samples, must have the same length.
n_samples (int): Minimum number of generated samples (returns always a multiple of batch_size).
seq_len (int): Length of the generated samples.
Returns:
data (numpy.array): An array of data of shape [n_samples, seq_len]"""
assert len(condition.shape) == 2, "Condition array should be two-dimensional."
assert condition.shape[1] == self.cond_dim, \
f"The condition sequence should have {self.cond_dim} length."
batches = n_samples // self.batch_size + 1
ar_steps = seq_len // self.data_dim + 1
data = []
z_dist = self.get_batch_noise()
cond_seq = expand_dims(convert_to_tensor(cond_array, float32), axis=0)
cond_seq = tile(cond_seq, multiples=[self.batch_size, 1])
for step in trange(steps, desc=f'Synthetic data generation'):
gen_input = concat([cond_seq, next(z_dist)], axis=1)
records = make_ndarray(make_tensor_proto(self.generator(gen_input, training=False)))
data.append(records)
for batch in trange(batches, desc=f'Synthetic data generation'):
data_ = []
cond_seq = convert_to_tensor(condition, float32)
gen_input = concat([tile(cond_seq, multiples=[self.batch_size, 1]), next(z_dist)], axis=1)
for step in range(ar_steps):
records = self.generator(gen_input, training=False)
gen_input = concat([records[:, -self.cond_dim:], next(z_dist)], axis=1)
data_.append(records)
data_ = hstack(data_)[:, :seq_len]
data.append(data_)
return array(vstack(data))


Expand Down

0 comments on commit a3c9b34

Please sign in to comment.