Batch predictions when test set is large #125

Open
LeoGrin opened this issue Jan 13, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@LeoGrin
Collaborator

LeoGrin commented Jan 13, 2025

Copying a discussion from Discord:

David Holzmüller
For me, on some (larger?) datasets (but also <= 10K train samples), I get the attached error (TabPFNClassifier), tested on torch 2.4 and torch 2.5.
 File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/classifier.py", line 533, in predict_proba
    for output, config in self.executor_.iter_outputs(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/inference.py", line 192, in iter_outputs
    output = self.model(
             ^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 413, in forward
    return self._forward(x, y, style=style, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 625, in _forward
    encoder_out = self.transformer_encoder(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 74, in forward
    x = layer(x, **kwargs)
        ^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 449, in forward
    state = sublayer(state)
            ^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 334, in attn_between_features
    return self.self_attn_between_features(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 355, in forward
    output: torch.Tensor = self._compute(
                           ^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/memory.py", line 100, in method_
    return x + method(self, x, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 504, in _compute
    attention_head_outputs = MultiHeadAttention.compute_attention_heads(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 710, in compute_attention_heads
    attention_head_outputs = torch.nn.functional.scaled_dot_product_attention(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid configuration argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Jingang
Hi David, the CUDA error "invalid configuration argument" may occur when the batch size is too large with flash attention or memory-efficient SDPA enabled.
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

David Holzmüller
I was hoping that TabPFN's interface would automatically batch the predict() step if there are too many samples, but apparently that is not the case

It seems to make sense to batch when the test set is too big, and what "too big" means may depend on several things, including whether flash attention and memory-efficient SDPA are being used.
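As a side note (not from the original discussion, just a hypothetical check): if the fused SDPA kernels are indeed the trigger, forcing PyTorch's plain math backend around the predict call should make the error disappear, at the cost of speed and memory. This is a minimal sketch assuming torch >= 2.3 (for torch.nn.attention.sdpa_kernel) and the standard scikit-learn-style fit/predict_proba interface; the data is synthetic and only there for illustration:

from sklearn.datasets import make_classification
from torch.nn.attention import SDPBackend, sdpa_kernel

from tabpfn import TabPFNClassifier

# Synthetic data; the large test set is what would trigger the error.
X, y = make_classification(n_samples=30_000, n_features=20, random_state=0)
X_train, y_train = X[:5_000], y[:5_000]
X_test = X[5_000:]

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Restrict scaled_dot_product_attention to the plain math kernel while predicting.
# If the error goes away here, the fused flash / memory-efficient kernels are the culprit.
with sdpa_kernel(SDPBackend.MATH):
    proba = clf.predict_proba(X_test)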

@LeoGrin added the enhancement (New feature or request) and good first issue (Good for newcomers) labels on Jan 13, 2025
@LennartPurucker
Collaborator

I agree it would be nice to support batching the predictions, but I removed this feature during the refactoring because the user should handle it.

Batching the test predictions depends heavily on the deployment hardware, input data, and inference requirements. In other words, finding a catch-all rule for choosing the size of each batch that perfectly trades off time against memory is tough.

IMO, we could instead add an example for this to tabpfn-extensions and avoid supporting it in the interface here. Moreover, I am not aware of a scikit-learn convention for this.

Here is the very unsophisticated code I used before. The open problems are how to set test_batch_threshold and how to split/batch X.

import warnings

import numpy as np
import torch

test_batch_threshold = 20_000
batch_predict_call = X.shape[0] > test_batch_threshold

if batch_predict_call:
    warnings.warn(
        f"Using batch_predict_call! This is a workaround to avoid VRAM memory issues that splits predictions "
        f"with more than {test_batch_threshold} test instances into three predict-wise batches.",
        stacklevel=2,
    )
    # Split the test data (and optional additional targets) into three roughly equal batches.
    split_point_a = len(X) // 3
    split_point_b = split_point_a * 2
    input_batches = [
        (
            X[:split_point_a],
            None if additional_y is None else additional_y[:split_point_a],
        ),
        (
            X[split_point_a:split_point_b],
            None
            if additional_y is None
            else additional_y[split_point_a:split_point_b],
        ),
        (
            X[split_point_b:],
            None if additional_y is None else additional_y[split_point_b:],
        ),
    ]
else:
    # Single batch: predict on the full test set at once.
    input_batches = [(X, additional_y)]

# Handle predictions per batch
output_per_batch = []
for batch_X, batch_additional_y in input_batches:
    # Predict here...

    output_per_batch.append(
        (
            batch_prediction,
            batch_additional_outputs,
        )
    )

if not batch_predict_call:
    prediction, additional_outputs = output_per_batch[0]
else:
    # Stack the per-batch predictions and concatenate any additional outputs.
    prediction = np.vstack([batch[0] for batch in output_per_batch])
    additional_outputs = {}
    for _, _batch_additional_outputs in output_per_batch:
        for k in _batch_additional_outputs:
            if k not in additional_outputs:
                additional_outputs[k] = _batch_additional_outputs[k]
            else:
                additional_outputs[k] = torch.cat(
                    [additional_outputs[k], _batch_additional_outputs[k]], dim=1
                )
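
A possible shape for such a tabpfn-extensions example (a minimal sketch, not tested code from this thread; predict_proba_in_batches and max_batch_size are hypothetical names) is a small wrapper that chunks the test set with np.array_split and stacks the per-chunk probabilities:

import numpy as np
from sklearn.datasets import make_classification

from tabpfn import TabPFNClassifier


def predict_proba_in_batches(clf, X_test, max_batch_size=20_000):
    """Hypothetical helper: call predict_proba on chunks of at most max_batch_size rows."""
    n_batches = int(np.ceil(len(X_test) / max_batch_size))
    chunks = np.array_split(X_test, n_batches)
    return np.vstack([clf.predict_proba(chunk) for chunk in chunks])


# Usage sketch with synthetic data.
X, y = make_classification(n_samples=60_000, n_features=20, random_state=0)
X_train, y_train, X_test = X[:5_000], y[:5_000], X[5_000:]

clf = TabPFNClassifier()
clf.fit(X_train, y_train)
proba = predict_proba_in_batches(clf, X_test)  # shape: (len(X_test), n_classes)

np.array_split handles the uneven final chunk, and choosing max_batch_size stays with the user, which matches the point above about hardware-dependent trade-offs.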

@LeoGrin removed the good first issue (Good for newcomers) label on Jan 16, 2025