I adapted your time_series_classification example for market data prediction. It seems to be working, but training is exceptionally slow on a P100 GPU, which normally finishes similar tasks in 30 minutes. After 4 hours it had completed only the first 2 epochs. Is this normal with CDEs, or did I do something wrong? The training loss is also diverging, but that might be due to the learning rate; I haven't checked that yet.
Here is the data preparation function I added, as well as some minor adaptations to the model.
import pandas as pd
import torch
import torchcde

# normalize, create_sequences and arr_to_tensor are helpers that are not shown in this snippet.
def get_data():
    btc_df = pd.read_csv('example/btc_data.csv', parse_dates=['open_time'])
    btc_df_n_t = normalize(btc_df)
    # Split training/testing
    train_size = int(len(btc_df_n_t) * .8)
    train_df, test_df = btc_df_n_t[:train_size], btc_df_n_t[train_size + 1:]
    # Create sequences
    SEQUENCE_LENGTH = 120
    train_X, train_y = create_sequences(train_df, 'close', SEQUENCE_LENGTH)
    test_X, test_y = create_sequences(test_df, 'close', SEQUENCE_LENGTH)
    # Create tensor arrays
    train_X, train_y = arr_to_tensor(train_X), arr_to_tensor(train_y)
    test_X, test_y = arr_to_tensor(test_X), arr_to_tensor(test_y)
    return train_X, train_y, test_X, test_y

def main(num_epochs=30):
    train_X, train_y, test_X, test_y = get_data()

    ######################
    # input_channels=3 because we have both the horizontal and vertical position of a point in the spiral, and time.
    # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
    # output_channels=1 because we're doing binary classification.
    ######################
    model = NeuralCDE(input_channels=8, hidden_channels=8, output_channels=1)
    optimizer = torch.optim.Adam(model.parameters())

    ######################
    # Now we turn our dataset into a continuous path. We do this here via Hermite cubic spline interpolation.
    # The resulting `train_coeffs` is a tensor describing the path.
    # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
    ######################
    train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y.unsqueeze(1), batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('Epoch: {} Training loss: {}'.format(epoch, loss.item()))

    test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
    pred_y = model(test_coeffs).squeeze(-1)
    # TODO: Modify evaluation for non-binary prediction
    binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
    prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
    proportion_correct = prediction_matches.sum() / test_y.size(0)
    print('Test Accuracy: {}'.format(proportion_correct))

It doesn't look like you're specifying the solver or the tolerances/step sizes at all. The defaults may not be suitable for your data.
You may like to try Diffrax instead, which is built for JAX. It builds on a lot of the lessons we learnt building torchcde/torchsde/torchdiffeq and, in particular, has much better default behaviour: it demands that you make an explicit choice about this kind of thing.
The complete code with corresponding data CSV: time_series_prediction example