-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathEnsemble_wat.py
71 lines (37 loc) · 1.17 KB
/
Ensemble_wat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from parsing import parse_train_args
from cv import cv
from utils import create_logger
"""
source ~/conda_py.sh
source activate torch
"""
if __name__ == '__main__':
args = parse_train_args()
args.num_folds=1
args.epochs=50
args.ensemble_size=100
args.batch_size=128
args.activation='ReLU'
args.depth=3
args.dropout=0
args.ffn_num_layers=2
args.hidden_size=300
args.sumstyle=True
args.seed=0
args.gpuUSE=True
args.gpu=3
args.data_path,args.cols_to_read ='data_RE2/water_solubilityOCD.csv',[x for x in range(2)]
args.save_dir='save_test'
args.tmp_data_dir='./data_RE2/tmp/'
args.scale='normalization'
args.split_type='random'
args.diff_depth_weights=True
args.layers_per_message=1
args.attention=True
args.message_attention=False
args.global_attention=False
args.message_attention_heads=1
args.log_dir=None
print(args)
logger = create_logger(name='train_crossValidate', save_dir=args.save_dir, quiet=args.quiet)
cv(args, logger)