-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_adj.py
75 lines (55 loc) · 2.41 KB
/
test_adj.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Script for test loop using Hierarchical CADNet (Adj)."""
import time
import tensorflow as tf
import numpy as np
from src.network_adj import HierarchicalGCNN as HierGCNN
from src.helper import dataloader_adj as dataloader
from src.analysis import analysis_report_mfcadplus
def test_step(x, y):
    """Evaluate one test batch and fold it into the running test metrics.

    Reads the module-level ``model``, ``loss_fn``, ``test_loss_metric``,
    ``test_acc_metric``, ``test_precision_metric`` and ``test_recall_metric``,
    which are created in the ``__main__`` section of this script.

    Args:
        x: One batch of model inputs, as yielded by the test dataloader.
        y: One-hot encoded targets for the batch (an eager tensor; ``.numpy()``
            is called on it).

    Returns:
        Tuple ``(y_true, y_pred)`` of NumPy arrays: the argmax class index of
        each row of ``y`` and of the model output, respectively.
    """
    # NOTE(review): loss_fn is CategoricalCrossentropy with the default
    # from_logits=False, so the model presumably emits probabilities rather
    # than raw logits despite the variable name used elsewhere — confirm.
    predictions = model(x, training=False)
    batch_loss = loss_fn(y, predictions)

    # Reduce the per-class columns to a single class index per row.
    true_labels = np.argmax(y.numpy(), axis=1)
    predicted_labels = np.argmax(predictions.numpy(), axis=1)

    # Accumulate streaming metrics; they are read out once after the loop.
    test_loss_metric.update_state(batch_loss)
    for metric in (test_acc_metric, test_precision_metric, test_recall_metric):
        metric.update_state(y, predictions)

    return true_labels, predicted_labels
def concatenate_labels(chunks):
    """Join per-batch label arrays into one flat integer array.

    Replaces repeated ``np.append`` calls, which copy the whole accumulated
    array on every batch and, when seeded with ``[]``, upcast labels to float.

    Args:
        chunks: Sequence of array-likes of class indices, one per batch.
            Each chunk is flattened before joining.

    Returns:
        A 1-D ``np.ndarray`` of dtype int64 holding all labels in batch order,
        or an empty int64 array when ``chunks`` is empty.
    """
    if not chunks:
        return np.empty(0, dtype=np.int64)
    return np.concatenate([np.ravel(chunk) for chunk in chunks]).astype(np.int64, copy=False)


if __name__ == '__main__':
    # User defined parameters (must match the ones used for training so the
    # checkpoint path and network shape line up).
    num_classes = 25
    num_layers = 6
    units = 512
    num_epochs = 100
    learning_rate = 1e-2  # not used at test time; kept to mirror the training config
    dropout_rate = 0.3
    date_str = "2023-11-17"
    checkpoint_path = f'checkpoint/adj_lvl_{num_layers}_units_{units}_epochs_{num_epochs}_date_{date_str}.ckpt'
    test_set_path = "data/test_MFCAD++.h5"

    # These names must stay module-level: test_step() reads them as globals.
    model = HierGCNN(units=units, rate=dropout_rate, num_classes=num_classes, num_layers=num_layers)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()
    test_loss_metric = tf.keras.metrics.Mean()
    test_acc_metric = tf.keras.metrics.CategoricalAccuracy()
    # NOTE(review): Keras Precision/Recall score each one-hot column
    # independently (default threshold 0.5), i.e. they are micro-averaged
    # over classes rather than per-class.
    test_precision_metric = tf.keras.metrics.Precision()
    test_recall_metric = tf.keras.metrics.Recall()

    model.load_weights(checkpoint_path)
    test_dataloader = dataloader(test_set_path)

    # Collect per-batch arrays and join them once after the loop.
    y_true_batches = []
    y_pred_batches = []

    start_time = time.perf_counter()
    for x_batch_test, y_batch_test in test_dataloader:
        one_hot_y = tf.one_hot(y_batch_test, depth=num_classes)
        y_true, y_pred = test_step(x_batch_test, one_hot_y)
        y_true_batches.append(y_true)
        y_pred_batches.append(y_pred)
    print("Time taken: %.2fs" % (time.perf_counter() - start_time))

    y_true_total = concatenate_labels(y_true_batches)
    y_pred_total = concatenate_labels(y_pred_batches)
    analysis_report_mfcadplus(y_true_total, y_pred_total)

    # Metric.result() returns scalar tensors; convert to Python floats so the
    # summary prints plain numbers instead of tensor reprs. The metrics are not
    # reset afterwards because the script exits right after printing.
    test_loss = float(test_loss_metric.result())
    test_acc = float(test_acc_metric.result())
    test_precision = float(test_precision_metric.result())
    test_recall = float(test_recall_metric.result())

    print(f"Test loss={test_loss}, Test acc={test_acc}, Precision={test_precision}, Recall={test_recall}")