# -*- coding: utf-8 -*-
"""
Created on Sun Oct 25 00:24:07 2020
@author: Jiang Yuxin
"""
import torch
import torch.nn as nn
import time
from tqdm import tqdm
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    classification_report,
)


def Metric(y_true, y_pred):
    """
    Compute and print the classification metrics: accuracy, macro-averaged
    precision, recall and F1, plus a per-class report.
    """
    accuracy = accuracy_score(y_true, y_pred)
    macro_precision = precision_score(y_true, y_pred, average='macro')
    macro_recall = recall_score(y_true, y_pred, average='macro')
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    target_names = ['class_0', 'class_1']
    report = classification_report(y_true, y_pred, target_names=target_names, digits=3)
    print('Accuracy: {:.1%}\nPrecision: {:.1%}\nRecall: {:.1%}\nF1: {:.1%}'.format(
        accuracy, macro_precision, macro_recall, macro_f1))
    print("classification_report:\n")
    print(report)


def correct_predictions(output_probabilities, targets):
    """
    Compute the number of predictions that match some target classes in the
    output of a model.

    Args:
        output_probabilities: A tensor of probabilities for different output
            classes.
        targets: The indices of the actual target classes.

    Returns:
        The number of correct predictions in 'output_probabilities'.
    """
    _, out_classes = output_probabilities.max(dim=1)
    correct = (out_classes == targets).sum()
    return correct.item()
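
# Illustrative note (not part of the original file): the prediction is the
# argmax over the class dimension, so, for example,
#   probs = torch.tensor([[0.9, 0.1], [0.2, 0.8]])
#   correct_predictions(probs, torch.tensor([0, 1]))  # -> 2
# counts both rows as correct.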


def train(model, dataloader, optimizer, epoch_number, max_gradient_norm):
    """
    Train a model for one epoch on some input data with a given optimizer.
    The loss is computed inside the model's forward pass.

    Args:
        model: A torch module that must be trained on some input data.
        dataloader: A DataLoader object to iterate over the training data.
        optimizer: A torch optimizer to use for training on the input model.
        epoch_number: The number of the epoch for which training is performed.
        max_gradient_norm: Max norm for gradient clipping.

    Returns:
        epoch_time: The total time necessary to train the epoch.
        epoch_loss: The training loss computed for the epoch.
        epoch_accuracy: The accuracy computed for the epoch.
    """
    # Switch the model to train mode.
    model.train()
    device = model.device
    epoch_start = time.time()
    batch_time_avg = 0.0
    running_loss = 0.0
    correct_preds = 0
    tqdm_batch_iterator = tqdm(dataloader)
    for batch_index, (batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels) in enumerate(tqdm_batch_iterator):
        batch_start = time.time()
        # Move input and output data to the GPU if it is used.
        seqs = batch_seqs.to(device)
        masks = batch_seq_masks.to(device)
        segments = batch_seq_segments.to(device)
        labels = batch_labels.to(device)
        optimizer.zero_grad()
        loss, logits, probabilities = model(seqs, masks, segments, labels)
        loss.backward()
        # Clip gradients to avoid exploding-gradient issues.
        nn.utils.clip_grad_norm_(model.parameters(), max_gradient_norm)
        optimizer.step()
        batch_time_avg += time.time() - batch_start
        running_loss += loss.item()
        correct_preds += correct_predictions(probabilities, labels)
        description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}".format(
            batch_time_avg / (batch_index + 1), running_loss / (batch_index + 1))
        tqdm_batch_iterator.set_description(description)
    epoch_time = time.time() - epoch_start
    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = correct_preds / len(dataloader.dataset)
    return epoch_time, epoch_loss, epoch_accuracy


def validate(model, dataloader):
    """
    Compute the loss and accuracy of a model on some validation dataset.

    Args:
        model: A torch module for which the loss and accuracy must be
            computed.
        dataloader: A DataLoader object to iterate over the validation data.

    Returns:
        epoch_time: The total time to compute the loss and accuracy on the
            entire validation set.
        epoch_loss: The loss computed on the entire validation set.
        epoch_accuracy: The accuracy computed on the entire validation set.
        auc: The ROC AUC computed on the entire validation set.
        all_prob: The probability of being classified as label 1 for every
            example in the validation set.
    """
    # Switch to evaluate mode.
    model.eval()
    device = model.device
    epoch_start = time.time()
    running_loss = 0.0
    running_accuracy = 0.0
    all_prob = []
    all_labels = []
    # Deactivate autograd for evaluation.
    with torch.no_grad():
        for (batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels) in dataloader:
            # Move input and output data to the GPU if one is used.
            seqs = batch_seqs.to(device)
            masks = batch_seq_masks.to(device)
            segments = batch_seq_segments.to(device)
            labels = batch_labels.to(device)
            loss, logits, probabilities = model(seqs, masks, segments, labels)
            running_loss += loss.item()
            running_accuracy += correct_predictions(probabilities, labels)
            all_prob.extend(probabilities[:, 1].cpu().numpy())
            all_labels.extend(batch_labels)
    epoch_time = time.time() - epoch_start
    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = running_accuracy / len(dataloader.dataset)
    return epoch_time, epoch_loss, epoch_accuracy, roc_auc_score(all_labels, all_prob), all_prob


def test(model, dataloader):
    """
    Test the accuracy of a model on some labelled test dataset.

    Args:
        model: The torch module on which testing must be performed.
        dataloader: A DataLoader object to iterate over some dataset.

    Returns:
        batch_time: The average time to predict the classes of a batch.
        total_time: The total time to process the whole dataset.
        accuracy: The accuracy of the model on the input data.
        all_prob: The probability of being classified as label 1 for every
            example in the test set.
    """
    # Switch the model to eval mode.
    model.eval()
    device = model.device
    time_start = time.time()
    batch_time = 0.0
    accuracy = 0.0
    all_prob = []
    all_labels = []
    # Deactivate autograd for evaluation.
    with torch.no_grad():
        for (batch_seqs, batch_seq_masks, batch_seq_segments, batch_labels) in dataloader:
            batch_start = time.time()
            # Move input and output data to the GPU if one is used.
            seqs = batch_seqs.to(device)
            masks = batch_seq_masks.to(device)
            segments = batch_seq_segments.to(device)
            labels = batch_labels.to(device)
            _, _, probabilities = model(seqs, masks, segments, labels)
            accuracy += correct_predictions(probabilities, labels)
            batch_time += time.time() - batch_start
            all_prob.extend(probabilities[:, 1].cpu().numpy())
            all_labels.extend(batch_labels)
    batch_time /= len(dataloader)
    total_time = time.time() - time_start
    accuracy /= len(dataloader.dataset)
    return batch_time, total_time, accuracy, all_prob
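

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original training pipeline).
# ``_DummyModel`` below is a hypothetical stand-in for the BERT classifier
# used in this repository; it only mimics the interface that ``train`` and
# ``validate`` rely on: a ``device`` attribute and a forward pass returning
# ``(loss, logits, probabilities)``.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    class _DummyModel(nn.Module):
        """Tiny binary classifier over mean-pooled token ids (illustrative only)."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.loss_fn = nn.CrossEntropyLoss()
            self.device = torch.device("cpu")

        def forward(self, seqs, masks, segments, labels):
            # Use the mean of the (float-cast) token ids as a single feature.
            feats = seqs.float().mean(dim=1, keepdim=True)
            logits = self.linear(feats)
            probabilities = nn.functional.softmax(logits, dim=-1)
            loss = self.loss_fn(logits, labels)
            return loss, logits, probabilities

    # Random "tokenised" batch: 32 sequences of length 16, balanced binary labels.
    seqs = torch.randint(0, 100, (32, 16))
    masks = torch.ones_like(seqs)
    segments = torch.zeros_like(seqs)
    labels = torch.tensor([0, 1] * 16)
    loader = DataLoader(TensorDataset(seqs, masks, segments, labels), batch_size=8)

    model = _DummyModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    train_time, train_loss, train_acc = train(model, loader, optimizer, 1, 10.0)
    print("train: {:.2f}s, loss {:.4f}, acc {:.2%}".format(train_time, train_loss, train_acc))

    val_time, val_loss, val_acc, val_auc, val_prob = validate(model, loader)
    print("validate: {:.2f}s, loss {:.4f}, acc {:.2%}, auc {:.4f}".format(
        val_time, val_loss, val_acc, val_auc))

    # Threshold the positive-class probabilities and report the full metrics.
    Metric(labels.numpy(), (torch.tensor(val_prob) > 0.5).long().numpy())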