Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document bimodal / full uncertainty regression outputs #120

Open
noahho opened this issue Jan 12, 2025 · 0 comments
Open

Document bimodal / full uncertainty regression outputs #120

noahho opened this issue Jan 12, 2025 · 0 comments
Labels
documentation Improvements or additions to documentation

Comments

@noahho
Copy link
Collaborator

noahho commented Jan 12, 2025

Example usage and visualization of the full uncertainty output mode (`output_type="full"`):

import matplotlib.pyplot as plt
import torch

# NOTE(review): this example assumes `TabPFNRegressor` is already imported
# (presumably `from tabpfn import TabPFNRegressor`), that `plot_bar_distribution`
# (defined further below) is in scope, and that `x`, `y_noisy` and `x_test` hold a
# single feature -- plot_bar_distribution needs 1-D x positions.
reg = TabPFNRegressor()
reg.fit(x, y_noisy)
# output_type="full" returns a dict; this example reads its "criterion" entry
# (whose `.borders` are the bar edges) and its "logits" entry.
preds = reg.predict(x_test, output_type="full")

fig, ax = plt.subplots(1, figsize=(12, 6))

N = 10  # number of samples to visualize

# The logits are predictions for `x_test`, so they must be placed at the
# x_test positions, not at the training inputs `x`.
plot_bar_distribution(
    ax,
    torch.as_tensor(x_test)[0:N],
    preds["criterion"].borders,
    preds["logits"][0:N],
)
ax.set_ylim(-1, 10)
import matplotlib.patches as patches
import seaborn as sns
import torch
import warnings
from matplotlib.collections import PatchCollection


def get_rect(coord, height, width):
    """Create a matplotlib ``Rectangle`` whose anchor corner is ``coord``.

    NOTE(review): the parameter names are misleading. matplotlib's signature is
    ``Rectangle(xy, width, height)``, and this helper forwards ``height`` into the
    *width* slot and ``width`` into the *height* slot. Callers in this module pass
    the horizontal extent as the second argument, so the drawn rectangles are
    correct. The names are kept for backward compatibility.
    """
    rectangle = patches.Rectangle(xy=coord, width=height, height=width)
    return rectangle


def heatmap_with_box_sizes(
    ax,
    data: torch.Tensor,
    x_starts,
    x_ends,
    y_starts,
    y_ends,
    palette=None,
    set_lims=True,
    threshold_i=0.0,  # Threshold intensity (not probability)
    y_min=None,
    y_max=None,
    transpose=False,
    per_col_normalize=False,
):
    """Draw a heatmap whose cells can each have an individual size.

    Cell ``(row_i, col_i)`` spans ``[x_starts[col_i], x_ends[col_i]]`` horizontally
    and ``[y_starts[row_i], y_ends[row_i]]`` vertically. When the y edges are 2-D,
    the vertical span is ``[y_starts[col_i, row_i], y_ends[col_i, row_i]]``. The
    cell is coloured by ``data[row_i, col_i]`` after min-max normalisation.

    Beware: all x and y arrays should be sorted from small to large, and the data
    will appear in that same order. Small indexes map to lower x/y-axis values.

    :param ax: Matplotlib axis to draw on.
    :param data: Tensor indexed as ``data[row_i, col_i]`` (rows = y cells,
        columns = x cells).
    :param x_starts: Left edge of every column.
    :param x_ends: Right edge of every column.
    :param y_starts: Lower edge of every row. Either 1-D (shared by all columns)
        or 2-D with shape ``(num_cols, num_rows)``.
    :param y_ends: Upper edge of every row, same shape as ``y_starts``.
    :param palette: Callable mapping a normalised intensity to an RGBA colour.
        Defaults to a seaborn cubehelix colormap.
    :param set_lims: If True, set the x-limits to the column extent. The y-limits
        are set to ``(y_min, y_max)`` when both are given, otherwise to the row
        extent (which then requires 1-D y edges).
    :param threshold_i: Normalised intensities at or below this value are not
        drawn; the remaining range is rescaled so the threshold maps to 0.
        Must be < 1.
    :param y_min: Optional lower y bound; used only when ``y_max`` is also given.
    :param y_max: Optional upper y bound; used only when ``y_min`` is also given.
        Cells lying entirely outside ``[y_min, y_max]`` are skipped.
    :param transpose: Place rectangles with the roles of x and y swapped.
    :param per_col_normalize: Min-max normalise every column separately instead
        of the whole tensor at once.
    """
    if palette is None:
        palette = sns.cubehelix_palette(
            start=2.9,
            rot=0.0,
            dark=0.6,  # higher -> the darkest part of the map is lighter
            light=1,
            gamma=4.0,  # higher -> a bigger part of the spectrum is saturated
            hue=9.0,
            as_cmap=True,
        )

    # Use `is None` rather than truthiness: 0 is a perfectly valid bound and must
    # not be mistaken for "not provided".
    has_y_range = y_min is not None and y_max is not None

    if set_lims:
        ax.set_xlim(x_starts[0], x_ends[-1])
        if has_y_range:
            ax.set_ylim(y_min, y_max)
        else:
            assert (
                len(y_starts.shape) == 1
            ), "If y_min and y_max are not provided, y_starts should be 1D. Please set y_min and y_max manually."
            ax.set_ylim(y_starts[0], y_ends[-1])

    if per_col_normalize:
        col_min = data.min(0, keepdim=True).values
        col_max = data.max(0, keepdim=True).values
        data = (data - col_min) / (col_max - col_min)
    else:
        data = (data - data.min()) / (data.max() - data.min())

    assert y_ends.shape == y_starts.shape
    if len(y_starts.shape) == 1:
        # Share the same row edges across all columns: (num_rows,) -> (num_cols, num_rows).
        y_starts = y_starts.unsqueeze(0).expand(len(x_starts), -1)
        y_ends = y_ends.unsqueeze(0).expand(len(x_starts), -1)

    rects, colors = [], []
    for col_i, (col_start, col_end) in enumerate(zip(x_starts, x_ends)):
        for row_i, (row_start, row_end) in enumerate(
            zip(y_starts[col_i], y_ends[col_i])
        ):
            # Start the colour scale at the threshold value (smoother visualization).
            # A NaN intensity (constant data normalised as 0/0) becomes 0.0 here,
            # because max(0.0, nan) returns 0.0, so the cell is skipped.
            intensity = data[row_i, col_i].item()
            intensity = max(0.0, intensity - threshold_i) / (1 - threshold_i)
            if intensity <= 0:
                continue

            if has_y_range and (row_start > y_max or row_end < y_min):
                continue

            # Degenerate (zero or negative size) cells cannot be drawn.
            if row_start >= row_end or col_start >= col_end:
                continue

            color = palette(intensity)
            # Skip fully white cells (presumably they would be invisible anyway).
            if color == (1.0, 1.0, 1.0, 1.0):
                continue

            if transpose:
                rect = get_rect(
                    (row_start, col_start), row_end - row_start, col_end - col_start
                )
            else:
                rect = get_rect(
                    (col_start, row_start), col_end - col_start, row_end - row_start
                )
            rects.append(rect)
            colors.append(color)

    rect_collection = PatchCollection(
        rects, facecolors=colors, edgecolor="none", linewidth=1
    )
    ax.add_collection(rect_collection)
    ax.set_rasterized(True)


def _merge_bar_probabilities(bar_borders, predictions, merge_bars):
    """Merge every ``merge_bars`` consecutive bars into one by summing their probabilities.

    The final merged bar may contain fewer than ``merge_bars`` original bars.
    Returns the new ``(bar_borders, predictions)`` pair.
    """
    new_borders_inds = torch.arange(0, len(bar_borders), merge_bars)
    if new_borders_inds[-1] != len(bar_borders) - 1:
        # Always keep the last border so the full support stays covered.
        new_borders_inds = torch.cat(
            [new_borders_inds, torch.tensor([len(bar_borders) - 1])]
        )
    # Prepend a zero column so that the mass between borders i and j equals
    # pred_cumsum[:, j] - pred_cumsum[:, i]. `new_zeros` matches the dtype and
    # device of `predictions`; a fresh CPU tensor cannot be concatenated with
    # CUDA predictions.
    pred_cumsum = torch.cat(
        [predictions.new_zeros(len(predictions), 1), predictions.cumsum(-1)], dim=-1
    )
    merged = (
        pred_cumsum[:, new_borders_inds[1:]] - pred_cumsum[:, new_borders_inds[:-1]]
    )
    return bar_borders[new_borders_inds], merged


def _restrict_bars_to_range(bar_borders, predictions, min_y, max_y):
    """Keep only the bars that overlap ``[min_y, max_y]`` (both ends inclusive).

    Returns the new ``(bar_borders, predictions)`` pair.
    Raises ``ValueError`` if no bar overlaps the range.
    """
    # Bar i spans [bar_borders[i], bar_borders[i + 1]]. It is kept unless it lies
    # entirely below or entirely above the range. Unlike masking the borders
    # themselves, this also keeps the single bar containing a range narrower than
    # one bar.
    bar_mask = (bar_borders[:-1] <= max_y) & (bar_borders[1:] >= min_y)
    if not bool(bar_mask.any()):
        raise ValueError(
            f"restrict_to_range=({min_y}, {max_y}) does not overlap the bar "
            f"distribution support [{bar_borders[0].item()}, {bar_borders[-1].item()}]."
        )
    # A border is kept if it belongs to any kept bar.
    border_mask = torch.zeros(
        len(bar_borders), dtype=torch.bool, device=bar_borders.device
    )
    border_mask[:-1] = bar_mask
    border_mask[1:] |= bar_mask
    return bar_borders[border_mask], predictions[:, bar_mask]


def plot_bar_distribution(
    ax,
    x: torch.Tensor,
    bar_borders: torch.Tensor,
    logits: torch.Tensor,
    merge_bars=None,
    restrict_to_range=None,
    plot_log_probs=False,
    **kwargs,
):
    """
    Plot per-example bar-distribution predictions as a heatmap over (x, y).

    Each example becomes one column at its x position. Each bar becomes one cell,
    coloured by its probability density (bar probability divided by bar width).

    :param ax: A matplotlib axis, you can get one with: `fig, ax = pyplot.subplots()`
    :param x: The positions to plot on the x-axis, this is your x, but it has to be 1d with shape (num_examples,)
    :param bar_borders: The borders of your bar distribution, they can be obtained at transformer_model.criterion.borders
    :param logits: A tensor of shape (num_examples, len(bar_borders)-1) that comes straight out of the model
    :param merge_bars: Number of bars to merge into one. If None, no merging is done. This speeds up the plotting.
    :param restrict_to_range: A tuple of (min_y, max_y). Only bars overlapping this range are drawn. If None, no restriction is done.
    :param plot_log_probs: If True, the log probabilities are plotted instead of the probabilities. This is useful if some probabilities are really high.
    :param kwargs: Forwarded to `heatmap_with_box_sizes` (e.g. `palette`, `threshold_i`, `y_min`, `y_max`).
    :raises ValueError: If `restrict_to_range` does not overlap any bar.
    :return: None; the heatmap is drawn onto `ax`.
    """
    # Check the type before calling any tensor method on x.
    assert isinstance(x, torch.Tensor)
    x = x.squeeze()
    predictions = logits.squeeze().softmax(-1)
    assert len(x.shape) == 1
    assert len(predictions.shape) == 2
    assert len(predictions) == len(x)
    assert len(bar_borders.shape) == 1
    assert len(bar_borders) - 1 == predictions.shape[1]

    if merge_bars and merge_bars > 1:
        bar_borders, predictions = _merge_bar_probabilities(
            bar_borders, predictions, merge_bars
        )
        assert len(bar_borders) - 1 == predictions.shape[-1]

    if restrict_to_range is not None:
        min_y, max_y = restrict_to_range
        bar_borders, predictions = _restrict_bars_to_range(
            bar_borders, predictions, min_y, max_y
        )

    y_starts = bar_borders[:-1]
    y_ends = bar_borders[1:]
    bar_widths = y_ends - y_starts

    # heatmap_with_box_sizes expects its columns sorted by x.
    x, order = x.sort(0)

    # Turn bar probabilities into densities, so wide bars do not look more likely.
    predictions = predictions[order] / bar_widths
    predictions[torch.isinf(predictions)] = 0.0
    # (Near-)zero-width bars produce inf or nan (0/0); blank them out entirely.
    predictions[:, bar_widths < 1e-10] = 0.0

    if plot_log_probs:
        predictions = predictions.log()
        # log(0) = -inf: replace it with the smallest finite value so the
        # heatmap's min-max normalisation stays finite.
        predictions[predictions.isinf()] = torch.min(predictions[~predictions.isinf()])

    # Column edges are the midpoints between neighbouring (sorted) x values. The
    # outermost columns start/end exactly at the first/last x.
    midpoints = (x[1:] + x[:-1]) / 2
    x_starts = torch.cat([x[:1], midpoints])
    x_ends = torch.cat([midpoints, x[-1:]])

    heatmap_with_box_sizes(
        ax, predictions.T, x_starts, x_ends, y_starts, y_ends, **kwargs
    )

We need to document this usage and add the visualization code to our repository (tabpfn-extensions?).

@noahho noahho added the documentation Improvements or additions to documentation label Jan 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation
Projects
None yet
Development

No branches or pull requests

1 participant