You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Example usage and visualization of the full uncertainty mode (`output_type="full"`):
# Example: fit a regressor, request the full predictive distribution
# (output_type="full") and plot it for the first N test points.
# Assumes x / x_test hold a single feature each (plot_bar_distribution needs
# 1D positions after squeezing) -- adapt for multi-feature inputs.
reg = TabPFNRegressor()
reg.fit(x, y_noisy)
preds = reg.predict(x_test, output_type="full")
fig, ax = plt.subplots(1, figsize=(12, 6))
N = 10  # number of samples to visualize
# The logits are rows of the *test* predictions, so the x positions must come
# from x_test as well (plotting them at the training inputs x would place every
# distribution at the wrong location).
plot_bar_distribution(
    ax,
    torch.as_tensor(x_test)[0:N],
    preds["criterion"].borders,
    preds["logits"][0:N],
)
ax.set_ylim(-1, 10)
import matplotlib.patches as patches
import seaborn as sns
import torch
import warnings
from matplotlib.collections import PatchCollection


def get_rect(coord, height, width):
    """Build an (unstyled) matplotlib ``Rectangle`` anchored at ``coord``.

    NOTE: the arguments are forwarded positionally to
    ``patches.Rectangle(xy, width, height)``. Despite the parameter names,
    ``height`` therefore becomes the rectangle's extent along the x-axis and
    ``width`` its extent along the y-axis. The callers below pass
    ``(x_extent, y_extent)`` in that order.
    """
    return patches.Rectangle(coord, height, width)


def heatmap_with_box_sizes(
    ax,
    data: torch.Tensor,
    x_starts,
    x_ends,
    y_starts,
    y_ends,
    palette=None,
    set_lims=True,
    threshold_i=0.0,  # Threshold intensity (not probability)
    y_min=None,
    y_max=None,
    transpose=False,
    per_col_normalize=False,
):
    """Draw a heatmap whose cells may have arbitrary, non-uniform sizes.

    Beware: all x and y arrays should be sorted from small to large, and the
    data will appear in that same order. Small indexes map to lower
    x/y-axis values.

    :param ax: matplotlib axis to draw on.
    :param data: 2D tensor indexed as ``data[row_i, col_i]``. Rows map to y
        cells and columns to x cells.
    :param x_starts: left edge of every column (one entry per column).
    :param x_ends: right edge of every column.
    :param y_starts: lower edge of every row. Either 1D (the same rows are
        used for every column) or 2D indexed as ``y_starts[col_i]``.
    :param y_ends: upper edge of every row; must have the same shape as
        ``y_starts``.
    :param palette: callable mapping an intensity in [0, 1] to an RGBA tuple.
        Defaults to a light-to-dark cubehelix colormap.
    :param set_lims: if True, set the x-limits to ``[x_starts[0], x_ends[-1]]``.
        Set the y-limits to ``[y_min, y_max]`` when both are given, otherwise
        to ``[y_starts[0], y_ends[-1]]``, which requires 1D ``y_starts``. The
        limits are not swapped when ``transpose`` is True.
    :param threshold_i: normalized intensity at or below which cells are not
        drawn. Intensities above it are rescaled to start from 0.
    :param y_min: optional lower y bound. When both ``y_min`` and ``y_max``
        are given, cells entirely outside the range are skipped.
    :param y_max: optional upper y bound (see ``y_min``).
    :param transpose: if True, draw each cell with its x and y extents swapped.
    :param per_col_normalize: if True, min-max normalize every column of
        ``data`` separately instead of the whole tensor at once.
    """
    if palette is None:
        palette = sns.cubehelix_palette(
            start=2.9,
            rot=0.0,
            # use dark to control how dark the darkest part is, a higher value
            # will make the darkest part lighter
            dark=0.6,
            light=1,
            # use gamma to control how much of the spectrum is saturated,
            # more gamma -> bigger part that is saturated
            gamma=4.0,
            hue=9.0,
            as_cmap=True,
        )

    if set_lims:
        ax.set_xlim(x_starts[0], x_ends[-1])
        # `is None` (not truthiness) so that a bound of 0.0 counts as provided.
        if y_min is None or y_max is None:
            assert (
                len(y_starts.shape) == 1
            ), "If y_min and y_max are not provided, y_starts should be 1D. Please set y_min and y_max manually."
            ax.set_ylim(y_starts[0], y_ends[-1])
        else:
            ax.set_ylim(y_min, y_max)

    # Min-max normalize to [0, 1] so intensities can be fed to the palette.
    if per_col_normalize:
        col_min = data.min(0, keepdim=True).values
        col_max = data.max(0, keepdim=True).values
        data = (data - col_min) / (col_max - col_min)
    else:
        data = (data - data.min()) / (data.max() - data.min())

    rects, colors = [], []
    assert y_ends.shape == y_starts.shape
    if len(y_starts.shape) == 1:
        # The same row edges apply to every column: broadcast them to one row
        # set per column (expand does not copy).
        y_starts = y_starts.unsqueeze(0).expand(len(x_starts), -1)
        y_ends = y_ends.unsqueeze(0).expand(len(x_starts), -1)

    clip_to_y_range = y_min is not None and y_max is not None
    for col_i, (col_start, col_end) in enumerate(zip(x_starts, x_ends)):
        for row_i, (row_start, row_end) in enumerate(
            zip(y_starts[col_i], y_ends[col_i])
        ):
            intensity = data[row_i, col_i].item()
            # Start with intensity at the threshold value (smoother visualization)
            intensity = max(0.0, (intensity - threshold_i)) / (1 - threshold_i)
            if intensity <= 0:
                continue
            if clip_to_y_range and (row_start > y_max or row_end < y_min):
                continue
            if row_start >= row_end or col_start >= col_end:
                continue  # degenerate (zero or negative size) cell
            color = palette(intensity)
            if color == (1.0, 1.0, 1.0, 1.0):
                continue  # pure white is invisible; skip the patch
            if transpose:
                rects.append(
                    get_rect(
                        (row_start, col_start), row_end - row_start, col_end - col_start
                    )
                )
            else:
                rects.append(
                    get_rect(
                        (col_start, row_start), col_end - col_start, row_end - row_start
                    )
                )
            colors.append(color)

    # One collection instead of one artist per cell keeps drawing fast.
    rect_collection = PatchCollection(
        rects, facecolors=colors, edgecolor="none", linewidth=1
    )
    ax.add_collection(rect_collection)
    ax.set_rasterized(True)
def plot_bar_distribution(
    ax,
    x: torch.Tensor,
    bar_borders: torch.Tensor,
    logits: torch.Tensor,
    merge_bars=None,
    restrict_to_range=None,
    plot_log_probs=False,
    **kwargs,
):
    """Plot a batch of bar (histogram) predictive distributions as a heatmap.

    Every example becomes one column at its x position. Columns span the
    midpoints between neighbouring sorted x values; the first and last columns
    start and end exactly at the smallest and largest x. Every bar of the
    example's distribution becomes a cell spanning
    ``[bar_borders[k], bar_borders[k + 1]]`` on the y-axis, shaded by the
    bar's probability density (probability divided by bar width).

    :param ax: A matplotlib axis, you can get one with: `fig, ax = pyplot.subplots()`
    :param x: The positions to plot on the x-axis. This is your x, but it has to
        be 1d with shape (num_examples,) after squeezing.
    :param bar_borders: The borders of your bar distribution. They can be
        obtained at transformer_model.criterion.borders.
    :param logits: A tensor of shape (num_examples, len(bar_borders)-1) that
        comes straight out of the model.
    :param merge_bars: Number of bars to merge into one. If None, no merging is
        done. This speeds up the plotting.
    :param restrict_to_range: A tuple of (min_y, max_y) that restricts the bars
        to this range, extended by one border on each side. If None, no
        restriction is done.
    :param plot_log_probs: If True, the log densities are plotted instead of the
        densities. This is useful if some probabilities are really high.
    :param kwargs: Forwarded to ``heatmap_with_box_sizes`` (e.g. ``palette``,
        ``threshold_i``, ``y_min``, ``y_max``, ``transpose``).
    :return: None. The plot is drawn onto ``ax``.
    """
    # Check the type before the first tensor method is called on x.
    assert isinstance(x, torch.Tensor)
    x = x.squeeze()
    predictions = logits.squeeze().softmax(-1)
    assert len(x.shape) == 1
    assert len(predictions.shape) == 2
    assert len(predictions) == len(x)
    assert len(bar_borders.shape) == 1
    assert len(bar_borders) - 1 == predictions.shape[1]

    if merge_bars and merge_bars > 1:
        # Keep every `merge_bars`-th border, always including the last one.
        new_borders_inds = torch.arange(0, len(bar_borders), merge_bars)
        if new_borders_inds[-1] != len(bar_borders) - 1:
            new_borders_inds = torch.cat(
                [new_borders_inds, torch.tensor([len(bar_borders) - 1])]
            )
        bar_borders = bar_borders[new_borders_inds]
        # A merged bar's mass is the difference of the CDF at its two borders.
        # new_zeros keeps the padding on the same device/dtype as predictions,
        # so torch.cat also works for GPU or float64 logits.
        pred_cumsum = torch.cat(
            [predictions.new_zeros(len(predictions), 1), predictions.cumsum(-1)],
            dim=-1,
        )
        predictions = (
            pred_cumsum[:, new_borders_inds[1:]] - pred_cumsum[:, new_borders_inds[:-1]]
        )
        assert len(bar_borders) - 1 == predictions.shape[-1]

    if restrict_to_range is not None:
        min_y, max_y = restrict_to_range
        border_mask = (min_y <= bar_borders) & (bar_borders <= max_y)
        # make the mask itself one border broader (one step on each side)
        border_mask[:-1] = border_mask[1:] | border_mask[:-1]
        border_mask[1:] = border_mask[1:] | border_mask[:-1]
        # A bar survives only if both of its borders survive.
        logit_mask = border_mask[:-1] & border_mask[1:]
        bar_borders = bar_borders[border_mask]
        predictions = predictions[:, logit_mask]

    y_starts = bar_borders[:-1]
    y_ends = bar_borders[1:]

    x, order = x.sort(0)
    bar_widths = bar_borders[1:] - bar_borders[:-1]
    # Probability -> density, so bars of different widths are comparable.
    predictions = predictions[order] / bar_widths
    predictions[torch.isinf(predictions)] = 0.0
    # Zero-width bars yield 0/0 = nan or p/0 = inf; blank them out.
    predictions[:, bar_widths < 1e-10] = 0.0

    if plot_log_probs:
        predictions = predictions.log()
        is_inf = predictions.isinf()
        finite = predictions[~is_inf]
        # log(0) = -inf: replace it by the smallest finite value. If nothing is
        # finite, fall back to 0 instead of calling torch.min on an empty tensor.
        predictions[is_inf] = finite.min() if finite.numel() > 0 else 0.0

    # x is sorted now: each column spans the midpoints to its neighbours.
    midpoints = (x[1:] + x[:-1]) / 2
    x_starts = torch.cat([x[0].unsqueeze(0), midpoints])
    x_ends = torch.cat([midpoints, x[-1].unsqueeze(0)])

    heatmap_with_box_sizes(
        ax, predictions.T, x_starts, x_ends, y_starts, y_ends, **kwargs
    )
We need to document this usage and add the visualization code to one of our repositories (possibly tabpfn-extensions).
The text was updated successfully, but these errors were encountered:
Example usage and visualization of the full uncertainty mode (`output_type="full"`):
We need to document this usage and add the visualization code to one of our repositories (possibly tabpfn-extensions).
The text was updated successfully, but these errors were encountered: