David Holzmüller
For me, on some (larger?) datasets (but also <= 10K train samples), I get the attached error (TabPFNClassifier), tested on torch 2.4 and torch 2.5.
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/classifier.py", line 533, in predict_proba
for output, config in self.executor_.iter_outputs(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/inference.py", line 192, in iter_outputs
output = self.model(
^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 413, in forward
return self._forward(x, y, style=style, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 625, in _forward
encoder_out = self.transformer_encoder(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/transformer.py", line 74, in forward
x = layer(x, **kwargs)
^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 449, in forward
state = sublayer(state)
^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/layer.py", line 334, in attn_between_features
return self.self_attn_between_features(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 355, in forward
output: torch.Tensor = self._compute(
^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/memory.py", line 100, in method_
return x + method(self, x, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 504, in _compute
attention_head_outputs = MultiHeadAttention.compute_attention_heads(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/storage/holzmudd/mamba_envs/envs/tabpfn_test/lib/python3.12/site-packages/tabpfn/model/multi_head_attention.py", line 710, in compute_attention_heads
attention_head_outputs = torch.nn.functional.scaled_dot_product_attention(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid configuration argument
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Jingang
Hi David, the CUDA error "invalid configuration argument" may occur if the batch size is too large with flash attention and memory-efficient SDPA enabled.
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
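One way to check whether those fused kernels are the culprit is to force the plain math backend for scaled_dot_product_attention, available in recent torch versions. This is a diagnostic sketch, not an official TabPFN or PyTorch recommendation, and it assumes clf is an already-fitted TabPFNClassifier and X_test is the failing test set:

from torch.nn.attention import SDPBackend, sdpa_kernel

# Restrict SDPA to the math backend (no flash / memory-efficient kernels) while
# running the failing prediction. This is slower and uses more memory, but it
# tells you whether the fused kernels are the source of the CUDA error.
with sdpa_kernel(SDPBackend.MATH):
    proba = clf.predict_proba(X_test)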
David Holzmüller
I was hoping that TabPFN's interface would automatically batch the predict() step if there are too many samples, but apparently that is not the case.
It seems to make sense to batch when the test set is too big, and what "too big" means may depend on several things, including whether flash attention and memory-efficient SDPA are being used.
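For illustration, a user-side workaround can be as simple as chunking the test rows before calling predict_proba and stacking the results. This is only a sketch; predict_proba_batched and the 10,000-row default are made-up names and values, not part of TabPFN:

import numpy as np
from tabpfn import TabPFNClassifier

def predict_proba_batched(clf: TabPFNClassifier, X_test, batch_size: int = 10_000):
    # Predict each slice of X_test separately and stack the class probabilities
    # row-wise, so peak memory scales with batch_size rather than len(X_test).
    chunks = [
        clf.predict_proba(X_test[start : start + batch_size])
        for start in range(0, len(X_test), batch_size)
    ]
    return np.vstack(chunks)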
I agree it would be nice to support batching the predictions, but I removed this feature during the refactoring because the user should handle it.
Batching the test predictions depends heavily on the deployment hardware, input data, and inference requirements. In other words, finding a catch-all solution for determining the size of each batch that perfectly trades off time against memory is tough.
IMO, we could instead add an example for this to tabpfn-extensions and avoid supporting it in the interface here. Moreover, I am not aware of a scikit-learn standard for this.
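As a rough illustration of why a single default is hard to pick (the function name, the per-row cost, and the safety factor below are all made-up assumptions, not measured values), one could guess a batch size from the free VRAM reported by torch:

import torch

def guess_test_batch_size(bytes_per_row: int = 4 * 1024 * 1024, safety: float = 0.5) -> int:
    # Guess how many test rows fit in memory from the currently free VRAM.
    if not torch.cuda.is_available():
        return 100_000  # on CPU, memory is rarely the binding constraint here
    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return max(1, int(free_bytes * safety) // bytes_per_row)

Any such constant still depends on the model configuration, the number of features, and the attention backend, which is exactly the difficulty described above.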
Here is the very unsophisticated code I used before. The open problems are how to set test_batch_threshold and how to split/batch X.
import warnings

import numpy as np
import torch

# Snippet from inside predict(); X, additional_y, batch_prediction and
# batch_additional_outputs come from the surrounding implementation.
test_batch_threshold = 20_000
batch_predict_call = X.shape[0] > test_batch_threshold

if batch_predict_call:
    warnings.warn(
        f"Using batch_predict_call! This is a workaround to avoid VRAM memory issues that splits predictions "
        f"with more than {test_batch_threshold} test instances into three predict-wise batches.",
        stacklevel=2,
    )
    # Split the test data (and optional additional targets) into three chunks.
    split_point_a = len(X) // 3
    split_point_b = split_point_a * 2
    input_batches = [
        (
            X[:split_point_a],
            None if additional_y is None else additional_y[:split_point_a],
        ),
        (
            X[split_point_a:split_point_b],
            None if additional_y is None else additional_y[split_point_a:split_point_b],
        ),
        (
            X[split_point_b:],
            None if additional_y is None else additional_y[split_point_b:],
        ),
    ]
else:
    # Unbatched case: a single batch covering all of X.
    input_batches = [(X, additional_y)]

# Handle predictions per batch
output_per_batch = []
for batch_X, batch_additional_y in input_batches:
    # Predict here...
    output_per_batch.append(
        (
            batch_prediction,
            batch_additional_outputs,
        )
    )

# Stitch the per-batch outputs back together.
if not batch_predict_call:
    prediction, additional_outputs = output_per_batch[0]
else:
    prediction = np.vstack([batch[0] for batch in output_per_batch])
    additional_outputs = {}
    for _, _batch_additional_outputs in output_per_batch:
        for k in _batch_additional_outputs:
            if k not in additional_outputs:
                additional_outputs[k] = _batch_additional_outputs[k]
            else:
                additional_outputs[k] = torch.cat(
                    [additional_outputs[k], _batch_additional_outputs[k]], dim=1
                )
Copying a discussion from discord:
It seems to make sense to batch when the test set is too big, and what "too big" means may depend on several things, including whether flash attention and memory-efficient SDPA are being used.