# args_test.py
import os
import torch
import torch.nn.functional as F
import numpy as np
from transformers import GPT2Config

from transformers_quantized import TransformerQuantized
from cichy_data import (CichyData, CichyContData, CichyQuantized,
                        CichyQuantizedGauss, CichyQuantizedAR)


class Args:
    gpu = '1'  # cuda gpu index
    func = {'train': True}  # dict of functions to run from training.py

    def __init__(self):
        n = 1  # can be used to do multiple runs, e.g. over subjects

        # experiment arguments
        self.name = 'args.py'  # name of this file, don't change
        self.fix_seed = True
        self.common_dataset = False
        self.load_dataset = True  # whether to load self.dataset
        self.learning_rate = 0.0001  # learning rate for Adam
        self.max_trials = 1.0  # ratio of training data to use (1 = all)
        self.val_max_trials = False
        self.batch_size = 2  # batch size for training and validation data
        self.epochs = 1000  # number of loops over training data
        self.val_freq = 10  # how often to validate (in epochs)
        self.print_freq = 2  # how often to print metrics (in epochs)
        self.anneal_lr = False  # whether to anneal the learning rate
        self.save_curves = True  # whether to save loss curves to file
        self.load_model = False
        self.result_dir = [os.path.join(  # path(s) to save model and others
            '/',
            'well',
            'woolrich',
            'users',
            'yaq921',
            'MEG-transfer-decoding',
            'results',
            'cichy_epoched',
            'subj1',
            'cont_quantized',
            'gpt2_50hz100hz',
            'concat_output')]
        self.model = TransformerQuantized  # class of model to use
        self.dataset = CichyQuantized  # dataset class for loading and handling data

        # wavenet arguments
        self.activation = torch.nn.Identity()  # activation function for models
        self.subjects = 0  # number of subjects used for training
        self.embedding_dim = 0  # subject embedding size
        self.p_drop = 0.0  # dropout probability
        self.ch_mult = 2  # channel multiplier for hidden channels in wavenet
        self.groups = 306
        self.kernel_size = 2  # convolutional kernel size
        self.timesteps = 1  # how many timesteps in the future to forecast
        self.sample_rate = [0, 256]  # start and end of timesteps within trials
        self.rf = 128  # receptive field of wavenet: 2*rf - 1 samples
        self.example_shift = 128

        rf = 128
        ks = self.kernel_size
        nl = int(np.log(rf) / np.log(ks))  # number of layers per stack
        dilations = [ks**i for i in range(nl)]  # dilation doubles per layer
        self.dilations = dilations + dilations  # two stacked dilation blocks
        # self.dilations = [1] + [2] + [4] * 7  # custom dilations
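        # With ks = 2 and rf = 128, nl = 7, so each stack uses dilations
        # [1, 2, 4, ..., 64]; one stack covers 1 + (1+2+...+64) = 128
        # samples, and the two stacks together give the 2*rf - 1 = 255
        # sample receptive field noted at self.rf above.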
        # classifier arguments
        self.wavenet_class = None  # class of wavenet model
        self.load_conv = False  # where to load neural network
        # dimensionality reduction from
        self.pred = False  # whether to use wavenet in prediction mode
        self.init_model = True  # whether to reinitialize classifier
        self.reg_semb = True  # whether to regularize subject embedding
        self.fixed_wavenet = False  # whether to fix weights of wavenet
        self.alpha_norm = 0.0  # regularization multiplier on weights
        self.num_classes = 119  # number of classes for classification
        self.units = [800, 300]  # hidden layer sizes of fully-connected block
        self.dim_red = 16  # number of pca components for channel reduction
        self.stft_freq = 0  # STFT frequency index for LDA_wavelet_freq model
        self.decode_peak = 0.1

        # GPT2 arguments
        n_embd = 768
        self.gpt2_config = GPT2Config(
            vocab_size=50257,
            n_positions=1024,
            n_embd=n_embd,
            n_layer=12,
            n_head=12,
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            use_cache=False
        )
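        # These are the stock GPT-2 (small) hyperparameters: 12 layers,
        # 12 heads, 768-dim embeddings, a 50257-token vocabulary and a
        # 1024-position context window.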
        # quantized wavenet arguments
        self.skips_shift = 1
        self.mu = 255  # companding parameter, i.e. mu + 1 = 256 levels
        self.residual_channels = 128
        self.dilation_channels = 128
        self.skip_channels = 512
        self.channel_emb = n_embd
        self.class_emb = n_embd
        self.quant_emb = n_embd
        self.pos_emb = n_embd
        self.cond_channels = self.class_emb + self.embedding_dim
        self.head_channels = 256
        self.conv_bias = False
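        # For reference, mu = 255 is the value conventionally used for
        # mu-law companding, which maps x in [-1, 1] to 256 bins via
        #     f(x) = sign(x) * log(1 + mu*|x|) / log(1 + mu)
        # before uniform quantization; the actual encoding is assumed to
        # live in the CichyQuantized dataset class, not in this file.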
        # dataset arguments
        data_path = os.path.join('/', 'gpfs2', 'well', 'woolrich', 'projects',
                                 'cichy118_cont', 'preproc_data_osl', 'subj1')
        self.data_path = [[os.path.join(data_path, 'subj1_50hz.npy')]]  # path(s) to data file(s)
        self.num_channels = list(range(614))  # channel indices
        self.numpy = True  # whether data is saved in numpy format
        self.crop = 1  # cropping ratio for trials
        self.whiten = False  # pca components used in whitening
        self.group_whiten = False  # whether to perform whitening at the group level
        self.split = np.array([0, 0.1])  # validation split (start, end)
        self.sr_data = 100  # sampling rate used for downsampling
        self.original_sr = 1000
        self.save_data = True  # whether to save the created data
        self.bypass = False
        self.subjects_data = False  # list of subject inds to use in group data
        self.save_whiten = False
        self.num_clip = 4
        self.dump_data = [os.path.join(data_path, '50hz100hz_quantized_clamp4')]  # path(s) for dumping data
        self.load_data = self.dump_data  # path(s) for loading data files

        # analysis arguments
        self.closest_chs = 20  # channel neighbourhood size for spatial PFI
        self.PFI_inverse = False  # invert which channels/timesteps to shuffle
        self.pfich_timesteps = [0, 256]  # time window for spatiotemporal PFI
        self.PFI_perms = 20  # number of PFI permutations
        self.halfwin = 5  # half window size for temporal PFI
        self.halfwin_uneven = False  # whether to use an uneven window
        self.generate_noise = 1  # noise used for wavenet generation
        self.generate_length = self.sr_data * 1000  # generated time-series length (1000 s at sr_data)
        self.generate_mode = 'recursive'  # IIR or FIR mode for wavenet generation
        self.generate_input = 'data'  # input type for generation
        self.generate_sampling = 'top-p'
        self.generate_shift = 1
        self.top_p = 0.8
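        # top-p (nucleus) sampling keeps the smallest set of tokens whose
        # cumulative probability reaches top_p (0.8 here) and renormalises
        # before drawing; an illustrative sketch follows the class below.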
        self.individual = True  # whether to analyse individual kernels
        self.anal_lr = 0.001  # learning rate for input backpropagation
        self.anal_epochs = 200  # number of epochs for input backpropagation
        self.norm_coeff = 0.0001  # L2 of input for input backpropagation
        self.kernel_limit = 300  # max number of kernels to analyse

        # simulation arguments
        self.nonlinear_prenoise = True
        self.nonlinear_data = True
        self.seconds = 3000
        self.events = 8
        self.sim_num_channels = 1
        self.sim_ar_order = 2
        self.gamma_shape = 14
        self.gamma_scale = 14
        self.noise_std = 2.5
        self.lambda_exp = 0.005
        self.ar_shrink = 1.0
        self.freqs = []
        self.ar_noise_std = np.random.rand(self.events) / 5 + 0.8  # uniform in [0.8, 1.0)
        self.max_len = 1000

        # AR model arguments
        self.order = 20  # autoregressive model order
        self.uni = True  # univariate (rather than multivariate) AR
        self.save_AR = True
        self.do_anal = False
        self.AR_load_path = [os.path.join(  # path(s) to save model and others
            'results',
            'cichy_epoched',
            'subj1',
            'cont_quantized',
            'AR_uni')]

        # unused
        self.num_plot = 1
        self.plot_ch = 1
        self.linear = False
        self.num_samples_CPC = 20
        self.dropout2d_bad = False
        self.k_CPC = 1
        self.conv1x1_groups = 1
        self.pos_enc_type = 'cat'
        self.pos_enc_d = 128
        self.l1_loss = False
        self.norm_alpha = self.alpha_norm
        self.num_components = 0
        self.resample = 7
        self.save_norm = True
        self.norm_path = os.path.join(data_path, 'norm_coeff')
        self.pca_path = os.path.join(data_path, 'pca128_model')
        self.load_pca = False
        self.compare_model = False
        self.channel_idx = 0
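

def _sample_top_p(logits, p=0.8):
    """Hedged sketch of top-p (nucleus) sampling as configured by
    Args.generate_sampling and Args.top_p. The project's actual sampler
    is not part of this file; this is an assumed reference
    implementation, not the repository's own method."""
    probs = F.softmax(logits, dim=-1)  # logits: 1-D tensor over tokens
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # keep the smallest prefix of tokens whose mass reaches p
    keep = cumulative - sorted_probs < p
    sorted_probs = sorted_probs * keep
    sorted_probs = sorted_probs / sorted_probs.sum()
    return sorted_idx[torch.multinomial(sorted_probs, 1)]


if __name__ == '__main__':
    # Minimal smoke test (illustrative addition, not part of the original
    # configuration): build the settings object and draw one token id
    # from random logits with the sketch above.
    args = Args()
    token = _sample_top_p(torch.randn(args.gpt2_config.vocab_size),
                          p=args.top_p)
    print('sampled token id:', token.item())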