-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmasking.py
41 lines (35 loc) · 1.15 KB
/
masking.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
def create_mask(path, color_threshold):
    """
    Create a binary mask of an image using a per-channel color threshold.

    A pixel is selected only when each of its R, G and B values is strictly
    greater than the corresponding threshold component.

    args:
    - path [str]: path to image file
    - color_threshold [array]: 1x3 array of RGB threshold values (rt, gt, bt)
    returns:
    - img [array]: HxWx3 RGB image array loaded from `path`
    - mask [array]: HxW boolean mask array
    """
    # Open the image in a context manager so the underlying file handle is
    # released deterministically instead of being left to garbage collection.
    # convert('RGB') yields an in-memory copy, so it stays valid after close.
    with Image.open(path) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    R, G, B = img[..., 0], img[..., 1], img[..., 2]
    rt, gt, bt = color_threshold
    # Strict comparison: a channel value equal to its threshold is excluded.
    mask = (R > rt) & (G > gt) & (B > bt)
    return img, mask
def mask_and_display(img, mask):
    """
    Display 3 plots next to each other: image, mask and masked image.

    args:
    - img [array]: HxWxC image array (any channel count C), or an HxW
      single-channel image
    - mask [array]: HxW mask array; presumably boolean -- the image is
      multiplied element-wise by it, so False pixels become 0
    """
    img = np.asarray(img)
    mask = np.asarray(mask)
    # Broadcast the HxW mask across the channel axis instead of stacking
    # exactly 3 copies, so RGBA (C=4) and grayscale (no channel axis) images
    # work too. For C=3 this equals img * np.stack([mask]*3, axis=2).
    if img.ndim == mask.ndim + 1:
        channel_mask = mask[..., np.newaxis]
    else:
        channel_mask = mask
    masked_image = img * channel_mask
    _, ax = plt.subplots(1, 3, figsize=(15, 10))
    panels = (
        (img, 'image'),
        (mask, 'mask'),
        (masked_image, 'masked image'),
    )
    for axis, (data, title) in zip(ax, panels):
        axis.imshow(data)
        axis.set_title(title)
    plt.show()
if __name__ == '__main__':
    import sys

    # Default sample image; a different image path may be supplied as the
    # first command-line argument.
    default_path = 'data/images/segment-1231623110026745648_480_000_500_000_with_camera_labels_38.png'
    path = sys.argv[1] if len(sys.argv) > 1 else default_path
    color_threshold = [128, 128, 128]
    img, mask = create_mask(path, color_threshold)
    mask_and_display(img, mask)