Very slow training with market data, normal? #46

Open
Karlheinzniebuhr opened this issue Apr 30, 2022 · 1 comment

Karlheinzniebuhr commented Apr 30, 2022

I adapted your time_series_classification example for market data prediction. It seems to be working, but training is exceptionally slow on a P100 GPU that normally finishes similar tasks in about 30 minutes: after 4 hours it had completed only the first 2 epochs. Is this normal with CDEs, or did I do something wrong? The training loss is also diverging, but that might be due to the learning rate; I haven't checked that yet.
Below is the data-prep function I added, as well as some minor adaptations to the model.

The complete code with corresponding data CSV: time_series_prediction example

import pandas as pd
import torch
import torchcde

# NeuralCDE, normalize, create_sequences and arr_to_tensor are defined elsewhere
# in the linked script (adapted from the time_series_classification example).

def get_data():
    btc_df = pd.read_csv('example/btc_data.csv', parse_dates=['open_time'])
    btc_df_n_t = normalize(btc_df)
    
    # Split training/testing
    train_size = int(len(btc_df_n_t) * .8)
    train_df, test_df = btc_df_n_t[:train_size], btc_df_n_t[train_size:]
    
    # Create sequences
    SEQUENCE_LENGTH = 120
    train_X, train_y = create_sequences(train_df, 'close', SEQUENCE_LENGTH)
    test_X, test_y = create_sequences(test_df, 'close', SEQUENCE_LENGTH)
    
    # Create tensor arrays
    train_X, train_y  = arr_to_tensor(train_X), arr_to_tensor(train_y)
    test_X, test_y  = arr_to_tensor(test_X), arr_to_tensor(test_y)

    return train_X, train_y, test_X, test_y

def main(num_epochs=30):
    train_X, train_y, test_X, test_y = get_data()

    ######################
    # input_channels=8 to match the number of channels per observation in the market data.
    # hidden_channels=8 is the number of hidden channels for the evolving z_t, which we get to choose.
    # output_channels=1 because the model emits a single logit per sequence (the original binary-classification head).
    ######################
    model = NeuralCDE(input_channels=8, hidden_channels=8, output_channels=1)
    optimizer = torch.optim.Adam(model.parameters())

    ######################
    # Now we turn our dataset into a continuous path. We do this here via Hermite cubic spline interpolation.
    # The resulting `train_coeffs` is a tensor describing the path.
    # For most problems, it's probably easiest to save this tensor and treat it as the dataset.
    ######################
    train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)

    train_dataset = torch.utils.data.TensorDataset(train_coeffs, train_y)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=32)
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch_coeffs, batch_y = batch
            pred_y = model(batch_coeffs).squeeze(-1)
            loss = torch.nn.functional.binary_cross_entropy_with_logits(pred_y.unsqueeze(1), batch_y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print('Epoch: {}   Training loss: {}'.format(epoch, loss.item()))

    test_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
    pred_y = model(test_coeffs).squeeze(-1)
    
    # TODO: Modify evaluation for non-binary prediction
    binary_prediction = (torch.sigmoid(pred_y) > 0.5).to(test_y.dtype)
    prediction_matches = (binary_prediction == test_y).to(test_y.dtype)
    proportion_correct = prediction_matches.sum() / test_y.size(0)
    print('Test Accuracy: {}'.format(proportion_correct))
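
For context, normalize, create_sequences and arr_to_tensor are not shown in the snippet above; they live in the linked script. A minimal sketch of what they might look like, assuming min-max scaling and a simple sliding window over the close column; the implementations below are assumptions for illustration, not the code from the linked example.

import numpy as np
import pandas as pd
import torch

def normalize(df):
    # Hypothetical: convert the timestamp to a number, then min-max scale every column to [0, 1].
    df = df.copy()
    df['open_time'] = df['open_time'].astype('int64')
    return (df - df.min()) / (df.max() - df.min())

def create_sequences(df, target_column, sequence_length):
    # Hypothetical: sliding window of `sequence_length` rows as input,
    # with the next value of `target_column` as the label.
    values = df.to_numpy(dtype='float32')
    target = df[target_column].to_numpy(dtype='float32')
    xs, ys = [], []
    for i in range(len(df) - sequence_length):
        xs.append(values[i:i + sequence_length])
        ys.append(target[i + sequence_length])
    return np.stack(xs), np.array(ys)

def arr_to_tensor(arr):
    # Hypothetical: plain conversion to a float32 tensor.
    return torch.tensor(arr, dtype=torch.float32)
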
patrick-kidger (Owner) commented

It doesn't look like you're specifying the solver, or the tolerances/step-sizes at all. The defaults may not be suitable for your data.
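
For illustration: extra keyword arguments to torchcde.cdeint are forwarded to the underlying torchdiffeq solver, so the solver, tolerances and step size can be set where the model calls cdeint. A sketch, roughly following the CDEFunc/NeuralCDE structure of the time_series_classification example; the particular solver settings are placeholders, not recommendations.

import torch
import torchcde

class CDEFunc(torch.nn.Module):
    # Vector field f_theta(z), as in the time_series_classification example.
    def __init__(self, input_channels, hidden_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.linear1 = torch.nn.Linear(hidden_channels, 128)
        self.linear2 = torch.nn.Linear(128, input_channels * hidden_channels)

    def forward(self, t, z):
        z = self.linear1(z).relu()
        z = self.linear2(z).tanh()
        return z.view(z.size(0), self.hidden_channels, self.input_channels)

class NeuralCDE(torch.nn.Module):
    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        self.func = CDEFunc(input_channels, hidden_channels)
        self.initial = torch.nn.Linear(input_channels, hidden_channels)
        self.readout = torch.nn.Linear(hidden_channels, output_channels)

    def forward(self, coeffs):
        X = torchcde.CubicSpline(coeffs)
        z0 = self.initial(X.evaluate(X.interval[0]))
        # Extra kwargs are passed through to the torchdiffeq solver:
        # either an adaptive solver with explicit tolerances...
        z_T = torchcde.cdeint(X=X, z0=z0, func=self.func, t=X.interval,
                              method='dopri5', rtol=1e-3, atol=1e-5)
        # ...or a fixed-step solver with an explicit step size, e.g.
        # method='rk4', options=dict(step_size=1.0)
        return self.readout(z_T[:, 1])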

You may like to try Diffrax, which is built for JAX instead. It builds on a lot of the lessons we learnt building torchcde/torchsde/torchdiffeq and, in particular, has much better default behaviour that demands you make an explicit choice about this kind of thing.
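
For illustration, a minimal Diffrax sketch of solving a CDE along an interpolated path, showing the explicit solver and step-size-controller choices it asks for. The toy data and the placeholder linear vector field are assumptions made only for this example.

import diffrax
import jax.numpy as jnp
import jax.random as jr

# Toy data: time plus a few feature channels, standing in for the real market data.
key = jr.PRNGKey(0)
ts = jnp.linspace(0.0, 1.0, 120)
ys = jnp.concatenate([ts[:, None], jr.normal(key, (120, 7))], axis=1)

hidden_size, input_size = 8, ys.shape[1]
W = 0.1 * jr.normal(key, (hidden_size, input_size, hidden_size))

def vector_field(t, z, args):
    # Placeholder linear vector field returning a (hidden_size, input_size) matrix;
    # a real model would use a small neural network here.
    return jnp.tanh(jnp.einsum('hij,j->hi', W, z))

# Interpolate the discrete observations into a continuous control path.
coeffs = diffrax.backward_hermite_coefficients(ts, ys)
control = diffrax.CubicInterpolation(ts, coeffs)

# dz = f(z) dX(t), solved with an explicitly chosen solver and step-size controller.
term = diffrax.ControlTerm(vector_field, control).to_ode()
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=ts[0], t1=ts[-1], dt0=None,
    y0=jnp.zeros(hidden_size),
    stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
)
z_T = sol.ys[-1]  # hidden state at the final time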
