-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
104 lines (84 loc) · 2.51 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import torch
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import cv2
import datetime
import pyfirmata
import time
# Serial port of the Arduino driven through pyfirmata. Defaults to 'COM3' (a Windows
# port name); set the ARDUINO_PORT environment variable to use another port
# (e.g. '/dev/ttyACM0' on Linux) without editing the script.
port = os.environ.get('ARDUINO_PORT', 'COM3')
# Opens the serial connection at import time; the board is used by the servo logic below.
board = pyfirmata.Arduino(port)
def get_default_device():
    """Return the device to run inference on: CUDA when usable, otherwise CPU."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)
device = get_default_device()
# Load the serialized TorchScript model for inference (no optimizer is involved)
# and map its tensors onto the selected device.
model = torch.jit.load('model_new.pt', map_location=device)
# Evaluation mode, so layers such as dropout / batch-norm (if the model has any)
# use their inference behaviour.
model.eval()

data_dir = './dataset-resized'
# NOTE(review): `classes` is not referenced elsewhere in this file. It is the raw,
# unsorted directory listing and may differ from `dataset.classes`, which is the
# mapping predictions actually use.
classes = os.listdir(data_dir)
# Preprocessing applied to every image before it reaches the model.
transformations = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])
# In this file the ImageFolder is only used for `dataset.classes`, the
# index -> class-name mapping that predict_image relies on.
dataset = ImageFolder(data_dir, transform=transformations)
# Toggled with the 'e' key in the capture loop; enables periodic classification.
allow_photo = False
def to_device(data, device):
    """Transfer a tensor, or a (possibly nested) list/tuple of tensors, to *device*.

    Sequences come back as lists, each element moved recursively; any other
    object is moved with ``.to(device, non_blocking=True)``.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved
def predict_image(img, model):
    """Classify a single preprocessed image tensor and return its class name.

    Args:
        img: one image tensor as produced by ``transformations`` (presumably
            shaped (C, H, W) — a batch dimension of 1 is added here).
        model: the classifier; its output is reduced with a max over dim 1.

    Returns:
        The entry of ``dataset.classes`` at the highest-scoring index.
    """
    # Add a leading batch dimension of size 1 and move to the inference device.
    xb = to_device(img.unsqueeze(0), device)
    # Inference only: skip autograd graph construction for the forward pass.
    with torch.no_grad():
        yb = model(xb)
    # Index of the highest score for the single image in the batch.
    _, preds = torch.max(yb, dim=1)
    return dataset.classes[preds[0].item()]
# Servo pin driven for each recognised material; any other label moves no servo.
_CLASS_TO_SERVO_PIN = {
    "glass": 3,
    "metal": 4,
    "paper": 2,
    "plastic": 5,
}


def _to_rgb_image(source):
    """Return *source* (PIL image, NumPy/OpenCV array, or path/file) as an RGB PIL image."""
    if isinstance(source, Image.Image):
        return source.convert('RGB')
    if hasattr(source, '__array_interface__'):
        array = source
        if getattr(array, 'ndim', 0) == 3 and array.shape[2] == 3:
            # OpenCV (cv2.VideoCapture.read) delivers colour frames in BGR order,
            # whereas PIL and the ImageFolder dataset work in RGB.
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        return Image.fromarray(array).convert('RGB')
    # Otherwise treat it as anything Image.open accepts (a path or a file object).
    with Image.open(source) as img:
        # convert() also normalises RGBA / grayscale files to the 3 channels the
        # model is fed from the ImageFolder pipeline.
        return img.convert('RGB')


def _sweep_servo(pin):
    """Put *pin* in servo mode and step it from 0 to 89 degrees, 15 ms per step."""
    servo = board.digital[pin]
    servo.mode = pyfirmata.SERVO
    # NOTE(review): the servo is left at 89 degrees and is not returned to 0 here;
    # confirm that is intended for the sorting mechanism.
    for angle in range(0, 90):
        servo.write(angle)
        time.sleep(0.015)


def predict_external_image(image_name):
    """Classify an image, print the result, and drive the servo for that material.

    Args:
        image_name: an OpenCV frame (BGR NumPy array), a PIL image, or a path /
            file object readable by ``PIL.Image.open``.

    Returns:
        The predicted class label (a name from ``dataset.classes``).
    """
    image = _to_rgb_image(image_name)
    example_image = transformations(image)
    answer = predict_image(example_image, model)
    print("The image resembles", answer + ".")
    pin = _CLASS_TO_SERVO_PIN.get(answer)
    if pin is not None:
        _sweep_servo(pin)
    return answer
# Frames shown between automatic classifications while capture is enabled
# (equivalent to the former counter that added 100 per frame and fired at 6000).
CAPTURE_INTERVAL_FRAMES = 60

cap = cv2.VideoCapture(0)
wait = 0  # frames elapsed since the last classification while capture is enabled
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # Camera unavailable or stream ended: frame is None and cannot be processed.
            print("Could not read a frame from the camera; exiting.")
            break
        frame = cv2.resize(frame, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_AREA)
        cv2.imshow('Result', frame)
        # Keep only the low byte: on some platforms waitKey sets extra high bits.
        c = cv2.waitKey(1) & 0xFF
        if c == ord("q"):
            break
        if c == ord("e"):
            # 'e' toggles periodic classification on and off.
            allow_photo = not allow_photo
        if allow_photo:
            wait += 1
            if wait >= CAPTURE_INTERVAL_FRAMES:
                predict_external_image(frame)
                wait = 0
finally:
    # Always release the camera, close windows and the serial link, even on error.
    cap.release()
    cv2.destroyAllWindows()
    board.exit()