# main.py
# Forked from LucaLaFisca/Human-Centered-xAI.
import torch
from fastai.vision.all import *
from fastai.data.all import *
from model import AAE
from utils import label_func, FreezeDiscriminator, GetLatentSpace, LossAttrMetric, distrib_regul_regression, compute_main_direction
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
### Define the Dataloader
# Fetch the dataset registered in fastai as URLs.PETS and list what was extracted.
data_path = untar_data(URLs.PETS)  # TODO: check the other available datasets
print(data_path.ls())

# Targets are already-encoded multi-category vectors over the two-class vocab.
catblock = MultiCategoryBlock(encoded=True, vocab=['cat', 'dog'])

# DataBlock recipe: image inputs, label_func targets, a fixed-seed 80/20
# random split, 128-pixel resize per item, ImageNet normalisation per batch.
datablock_spec = dict(
    blocks=(ImageBlock(), catblock),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    get_y=label_func,
    item_tfms=Resize(128),
    batch_tfms=[Normalize.from_stats(*imagenet_stats)],
)
dblock = DataBlock(**datablock_spec)

# Create the DataLoaders from the "images" sub-folder of the dataset.
images_dir = data_path/"images"
dls = dblock.dataloaders(images_dir, bs=16, drop_last=True)
# Define the model
# AAE comes from the project's `model` module. The sizes below line up with the
# data pipeline (Resize(128) images, two-entry vocab); presumably input_size is
# the image side length and classes the vocab size -- confirm against model.py.
aae_config = dict(
    input_size=128,
    input_channels=3,
    encoding_dims=128,
    classes=2,
)
model = AAE(**aae_config)
### Train Autoencoder ###
# Stage 1: fit the model with its autoencoder loss (`model.ae_loss_func`).
model_file = 'cat_dog_ae_test'
metrics = [LossAttrMetric("recons_loss"), accuracy_multi]
learn = Learner(dls, model, loss_func=model.ae_loss_func, metrics=metrics)

# Run the LR finder; its `valley` suggestion is used as the learning rate.
learning_rate = learn.lr_find()

# Track the monitored value, checkpoint under `model_file`, and stop early
# after 10 epochs without an improvement of at least 1e-4.
ae_callbacks = [
    TrackerCallback(),
    SaveModelCallback(fname=model_file),
    EarlyStoppingCallback(min_delta=1e-4, patience=10),
]
learn.fit(100, lr=learning_rate.valley, cbs=ae_callbacks)

# Read the checkpoint back from ./models and copy its weights into the model;
# strict=False tolerates missing or unexpected keys.
checkpoint_path = f'models/{model_file}.pth'
state_dict = torch.load(checkpoint_path)
model.load_state_dict(state_dict, strict=False)
### Train Adversarial ###
# Stage 2: continue training the same model with `model.aae_loss_func`.
model_file = 'cat_dog_aae_test'
metrics = [
    LossAttrMetric(loss_name)
    for loss_name in ("adv_loss", "recons_loss", "crit_loss")
] + [accuracy_multi]
learn = Learner(dls, model, loss_func=model.aae_loss_func, metrics=metrics)

# Gradients are accumulated over 16*4 samples before each optimizer step.
# FreezeDiscriminator is a project callback (see utils).
adversarial_callbacks = [
    GradientAccumulation(n_acc=16*4),
    TrackerCallback(),
    SaveModelCallback(fname=model_file),
    EarlyStoppingCallback(min_delta=1e-4, patience=10),
    FreezeDiscriminator(),
]
learn.fit(100, lr=5e-3, cbs=adversarial_callbacks)

# Read the checkpoint back from ./models and copy its weights into the model;
# strict=False tolerates missing or unexpected keys.
checkpoint_path = f'models/{model_file}.pth'
state_dict = torch.load(checkpoint_path)
model.load_state_dict(state_dict, strict=False)
### Train Classifier ###
# Stage 3: train with `model.classif_loss_func`, monitoring the validation loss
# for tracking, checkpointing and early stopping alike.
monitor_loss = 'valid_loss'
model_file = 'cat_dog_aae_classif_test'
metrics = [
    LossAttrMetric(loss_name)
    for loss_name in ("adv_loss", "recons_loss", "classif_loss", "crit_loss")
] + [accuracy_multi]
learn = Learner(dls, model, loss_func=model.classif_loss_func, metrics=metrics)

# Gradients are accumulated over 16*4 samples before each optimizer step.
# FreezeDiscriminator is a project callback (see utils).
classifier_callbacks = [
    GradientAccumulation(n_acc=16*4),
    TrackerCallback(monitor=monitor_loss),
    SaveModelCallback(fname=model_file, monitor=monitor_loss),
    EarlyStoppingCallback(min_delta=1e-4, patience=10, monitor=monitor_loss),
    FreezeDiscriminator(),
]
learn.fit(100, lr=1e-2, cbs=classifier_callbacks)
### Display the latent space ###
# Reload the classifier checkpoint. Learner.load resolves the name against the
# learner's own model directory and appends ".pth", so only the bare file name
# is passed; a "models/" prefix would be looked up one directory too deep.
learn.load(model_file, strict=False)

# Compute the latent space of the training set (ds_idx=0), then of the
# validation set (ds_idx=1). The buffer `learn.zi_valid` is reset before each
# pass; GetLatentSpace (project callback, see utils) is expected to fill it.
# The buffer lives on the DataLoaders' device (the original referenced an
# undefined name `dev` here).
latent_device = dls.device
learn.zi_valid = torch.tensor([]).to(latent_device)
train_targets = learn.get_preds(ds_idx=0, cbs=[GetLatentSpace()])[1]
new_zi = learn.zi_valid
learn.zi_valid = torch.tensor([]).to(latent_device)
valid_targets = learn.get_preds(ds_idx=1, cbs=[GetLatentSpace()])[1]
new_zi = torch.vstack((new_zi, learn.zi_valid))
torch.save(new_zi, 'z_aae.pt')
print(new_zi.shape)

# NOTE(review): `lab_gather` and `category` were used below without ever being
# defined in this script. They are rebuilt here from the targets returned by
# get_preds, in the same train-then-valid order as `new_zi`. This assumes
# label_func yields a one-hot vector ordered like the vocab ['cat', 'dog'], so
# column 1 is the "dog" indicator -- confirm against utils.label_func.
all_targets = torch.cat((train_targets, valid_targets)).cpu()
lab_gather = all_targets[:, 1].float()
category = np.array(['cat', 'dog'])[lab_gather.long().numpy()]

# Project the latent codes to 2D with t-SNE (fixed seed for reproducibility).
tsne = TSNE(random_state=42)
# NOTE(review): 512 is hard-coded and must equal the flattened size of one
# latent code, otherwise rows stop matching samples (the model is built with
# encoding_dims=128) -- confirm.
z = new_zi.view(-1, 512)
predictions_embedded = tsne.fit_transform(z.cpu().detach().numpy())

# Compute linear regression from 2D space
y_pred_embed = distrib_regul_regression(predictions_embedded, lab_gather)

# Diverging colour scale centred on 0.5; plain floats are passed to matplotlib
# rather than 0-d tensors.
diverging_norm = mcolors.TwoSlopeNorm(vmin=lab_gather.min().item(),
                                      vcenter=0.5,
                                      vmax=lab_gather.max().item())
mapper = plt.cm.ScalarMappable(norm=diverging_norm)
colors = mapper.to_rgba(lab_gather.numpy())

fig, ax = plt.subplots()
sns.scatterplot(x=predictions_embedded[:, 0], y=predictions_embedded[:, 1],
                hue=category, s=55)

# Draw an arrow between the two points returned by compute_main_direction
# (project helper, see utils).
start, end = compute_main_direction(predictions_embedded, y_pred_embed)
ax.arrow(start[0], start[1], end[0]-start[0], end[1]-start[1], linewidth=3,
         head_width=10, head_length=10, fc='#8B0000', ec='#8B0000',
         length_includes_head=True)

# Define x,y limits: symmetric around the origin with a margin of 5.
maxabs = np.max(np.abs(predictions_embedded)) + 5
plt.xlim([-maxabs, maxabs])
plt.ylim([-maxabs, maxabs])

# Remove xticks and yticks
ax.set_xticks([])
ax.set_yticks([])

# Remove the legend
ax.get_legend().remove()
# NOTE(review): the figure is neither shown nor saved in this script; add
# plt.show() or fig.savefig(...) if it is run outside a notebook.