forked from pytorch/ignite
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvis.py
100 lines (73 loc) · 3.11 KB
/
vis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from typing import Callable, Optional
import numpy as np
import torch
try:
    from image_dataset_viz import render_datapoint
except ImportError as err:
    # Re-raise with install instructions; chaining keeps the original import
    # error (e.g. a missing transitive dependency) visible in the traceback.
    raise ModuleNotFoundError(
        "Please install image-dataset-viz via pip install --upgrade git+https://github.com/vfdev-5/ImageDatasetViz.git"
    ) from err
def tensor_to_numpy(t: torch.Tensor) -> np.ndarray:
    """Convert a single channels-first image tensor to an ``uint8`` NumPy array.

    The tensor is detached from the autograd graph and copied to the CPU, so
    it may live on any device and may require grad. Its axes are reordered
    with ``transpose((1, 2, 0))``, i.e. (C, H, W) -> (H, W, C).

    Floating-point values are clipped to ``[0, 255]`` before the cast, because
    casting out-of-range floats straight to ``uint8`` wraps around or is
    undefined. Fractional values are truncated by the cast, not rounded.
    Integer inputs are cast unchanged.

    Args:
        t (torch.Tensor): 3-D image tensor, channels first.

    Returns:
        np.ndarray: array of shape (H, W, C) with dtype ``uint8``.
    """
    img = t.detach().cpu().numpy().transpose((1, 2, 0))
    if np.issubdtype(img.dtype, np.floating):
        # Denormalized images can land slightly outside the displayable range.
        img = np.clip(img, 0, 255)
    return img.astype(np.uint8)
def make_grid(
    batch_img: torch.Tensor,
    batch_preds: torch.Tensor,
    img_denormalize_fn: Callable,
    batch_gt: Optional[torch.Tensor] = None,
    *,
    text_size: int = 12,
):
    """Render a batch of images side by side, each annotated with its labels.

    The result is a single-row strip::

        i+p1+gt1 | i+p2+gt2 | i+p3+gt3 | i+p4+gt4 | ...

    Each tile is produced by ``render_datapoint``. Its text is
    ``"p=<pred>"``, plus ``" | gt=<gt>"`` when ground truth is given.

    Args:
        batch_img (torch.Tensor): batch of images. Dim 0 is the batch and
            dims 2 and 3 are taken as height and width, so a (B, C, H, W)
            layout is expected.
        batch_preds (torch.Tensor): 1-D tensor of predicted labels, one per
            image.
        img_denormalize_fn (Callable): function applied to each single image
            tensor (not to the whole batch) before it is converted to
            ``uint8``.
        batch_gt (torch.Tensor, optional): 1-D tensor of ground-truth labels,
            one per image.
        text_size (int): text size forwarded to ``render_datapoint``.
            Defaults to 12.

    Returns:
        np.ndarray: ``uint8`` array of shape (H, B * W, 3).

    Raises:
        AssertionError: when assertions are enabled, if the inputs are not
            tensors, the batch lengths differ, or a label tensor is not 1-D.
    """
    assert isinstance(batch_img, torch.Tensor) and isinstance(batch_preds, torch.Tensor)
    assert len(batch_img) == len(batch_preds), f"{len(batch_img)} vs {len(batch_preds)}"
    assert batch_preds.ndim == 1, f"{batch_preds.ndim}"
    if batch_gt is not None:
        assert isinstance(batch_gt, torch.Tensor)
        assert len(batch_preds) == len(batch_gt), f"{len(batch_preds)} vs {len(batch_gt)}"
        assert batch_gt.ndim == 1, f"{batch_gt.ndim}"

    b = batch_img.shape[0]
    h, w = batch_img.shape[2:]
    out_image = np.zeros((h, w * b, 3), dtype="uint8")

    # Move the label vectors to the host once rather than syncing per image.
    pred_labels = batch_preds.cpu().tolist()
    gt_labels = batch_gt.cpu().tolist() if batch_gt is not None else None

    for i in range(b):
        img = tensor_to_numpy(img_denormalize_fn(batch_img[i]))
        target = f"p={pred_labels[i]}"
        if gt_labels is not None:
            target += f" | gt={gt_labels[i]}"
        out_image[:, i * w : (i + 1) * w, :] = render_datapoint(img, target, text_size=text_size)
    return out_image
def predictions_gt_images_handler(img_denormalize_fn, n_images=None, another_engine=None, prefix_tag=None):
    """Build a handler that logs a predictions-vs-ground-truth image grid.

    The returned callable takes ``(engine, logger, event_name)``. It reads the
    current batch and output from ``engine.state``, builds a grid with
    ``make_grid`` and writes it through ``logger.writer.add_image`` in
    ``"HWC"`` format.

    Args:
        img_denormalize_fn (Callable): forwarded to ``make_grid``; applied to
            each image before rendering.
        n_images (int, optional): if given, only the first ``n_images`` samples
            of the batch are drawn.
        another_engine (optional): if given, its ``state`` supplies the global
            step instead of the handled engine's state.
        prefix_tag (str, optional): if given, the image tag becomes
            ``"<prefix_tag>: predictions_with_gt"``.

    Returns:
        Callable: the ``wrapper(engine, logger, event_name)`` handler.
    """

    def wrapper(engine, logger, event_name):
        batch = engine.state.batch
        output = engine.state.output
        x, y = batch
        # NOTE(review): assumes output[0] holds per-class scores with classes on
        # dim 1, e.g. (B, num_classes) — confirm against the engine's output.
        y_pred = output[0]

        if y.shape == y_pred.shape and y.ndim == 4:
            # Case of y of shape (B, C, H, W), presumably one-hot targets.
            # NOTE(review): the resulting (B, H, W) label tensors fail
            # make_grid's 1-D label assertion, so this branch cannot render.
            y = torch.argmax(y, dim=1)

        # Keep argmax's integer dtype. Casting to uint8 (``.byte()``) wraps class
        # indices >= 256 (e.g. 300 -> 44), so the displayed prediction would not
        # match the ground truth on datasets with many classes.
        y_pred = torch.argmax(y_pred, dim=1)

        if n_images is not None:
            x = x[:n_images, ...]
            y = y[:n_images, ...]
            y_pred = y_pred[:n_images, ...]

        grid_pred_gt = make_grid(x, y_pred, img_denormalize_fn, batch_gt=y)

        state = engine.state if another_engine is None else another_engine.state
        global_step = state.get_event_attrib_value(event_name)

        tag = "predictions_with_gt"
        if prefix_tag is not None:
            tag = f"{prefix_tag}: {tag}"
        logger.writer.add_image(tag=tag, img_tensor=grid_pred_gt, global_step=global_step, dataformats="HWC")

    return wrapper