-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathmain.py
101 lines (81 loc) · 3.84 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import argparse, os, sys, datetime, glob, importlib
from torch.utils.data import random_split, DataLoader, Dataset
import lightning as L
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from lightning import seed_everything
from torch.utils.data.dataloader import default_collate as custom_collate
import torch
# Process-wide PyTorch backend settings, applied once when this module is imported.
# Select the "high" float32 matmul precision mode (see the
# torch.set_float32_matmul_precision documentation for what each mode permits).
torch.set_float32_matmul_precision("high")
# NOTE(review): the next two flags look chosen for run-to-run reproducibility
# (together with seed_everything) rather than for speed - confirm intent.
torch.backends.cudnn.deterministic = True  # ask cuDNN for deterministic algorithms
torch.backends.cudnn.benchmark = False  # leave cuDNN's benchmark mode switched off
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    Args:
        string: Dotted import path. Everything before the last dot is imported
            as a module; the final component is looked up on that module.
        reload: When true, reload the module before the attribute lookup.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: If ``string`` contains no dot, so it cannot be split into
            a module and an attribute.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_name, dot, attr_name = string.rpartition(".")
    if not dot:
        # Without this guard the failure surfaces as an opaque tuple-unpacking
        # error that never mentions the offending value.
        raise ValueError(
            f"Expected a dotted path like 'package.module.Name', got {string!r}."
        )
    module = importlib.import_module(module_name)
    if reload:
        # reload() returns the module object that is current after reloading.
        module = importlib.reload(module)
    return getattr(module, attr_name)
def instantiate_from_config(config):
    """Build the object described by a ``target`` / ``params`` config mapping.

    Args:
        config: Mapping with a ``"target"`` key holding a dotted import path
            and an optional ``"params"`` key holding keyword arguments for it.

    Returns:
        The result of calling the resolved target with ``**params``.

    Raises:
        KeyError: If ``config`` has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    params = config.get("params")
    if params is None:
        # A key that is present but null (e.g. a bare `params:` line in YAML)
        # means "no arguments"; unpacking None would raise TypeError instead.
        params = {}
    return get_obj_from_str(config["target"])(**params)
class WrappedDataset(Dataset):
    """Adapter exposing any object with ``__len__``/``__getitem__`` as a PyTorch ``Dataset``."""

    def __init__(self, dataset):
        # Only a reference is kept; the wrapped object is neither copied nor converted.
        self.data = dataset

    def __getitem__(self, idx):
        """Return item ``idx`` of the wrapped object."""
        return self.data[idx]

    def __len__(self):
        """Return the length reported by the wrapped object."""
        return len(self.data)
class DataModuleFromConfig(L.LightningDataModule):
    """Lightning data module whose datasets are described by ``target``/``params`` configs.

    Only splits that are given a config receive a dataloader hook: ``__init__``
    binds ``train_dataloader`` / ``val_dataloader`` / ``test_dataloader`` on the
    instance for exactly those splits.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None,
                 wrap=False, num_workers=None):
        """Store the loader settings and register every configured split.

        Args:
            batch_size: Batch size passed to every ``DataLoader``.
            train: Config mapping for the training split, or ``None`` to omit it.
            validation: Config mapping for the validation split, or ``None``.
            test: Config mapping for the test split, or ``None``.
            wrap: When true, ``setup`` wraps each dataset in ``WrappedDataset``.
            num_workers: Worker count for every ``DataLoader``; ``None`` selects
                ``batch_size * 2``.
        """
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = {}
        if num_workers is None:
            num_workers = batch_size * 2
        self.num_workers = num_workers

        # (split key, config, hook name to bind on the instance, loader factory)
        split_table = (
            ("train", train, "train_dataloader", self._train_dataloader),
            ("validation", validation, "val_dataloader", self._val_dataloader),
            ("test", test, "test_dataloader", self._test_dataloader),
        )
        for split, cfg, hook_name, factory in split_table:
            if cfg is None:
                continue
            self.dataset_configs[split] = cfg
            setattr(self, hook_name, factory)
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result."""
        for split in self.dataset_configs:
            instantiate_from_config(self.dataset_configs[split])

    def setup(self, stage=None):
        """Build ``self.datasets`` from the registered configs.

        Targets whose dotted path contains ``"pretrain"`` are instantiated and
        then asked for ``.create_dataset()`` (the original comment notes these
        are LAION data served through webdataset); all other targets are used
        as instantiated. ``stage`` is accepted but not used.
        """
        self.datasets = {}
        for split, cfg in self.dataset_configs.items():
            needs_factory_call = "pretrain" in cfg["target"]
            dataset = instantiate_from_config(cfg)
            if needs_factory_call:
                dataset = dataset.create_dataset()
            self.datasets[split] = dataset
        if self.wrap:
            self.datasets = {
                split: WrappedDataset(dataset)
                for split, dataset in self.datasets.items()
            }

    def _make_dataloader(self, split, **extra):
        """Create a ``DataLoader`` for ``split`` with the settings shared by all loaders."""
        return DataLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            **extra,
        )

    def _train_dataloader(self):
        """Return the training loader.

        "pretrain" targets get neither ``shuffle`` nor an explicit
        ``collate_fn`` (the original comment says webdataset needs no
        shuffling); every other target is shuffled and collated with
        ``custom_collate``.
        """
        if "pretrain" in self.dataset_configs["train"]["target"]:
            return self._make_dataloader("train")
        return self._make_dataloader(
            "train", shuffle=True, collate_fn=custom_collate
        )

    def _val_dataloader(self):
        """Return the unshuffled validation loader."""
        return self._make_dataloader(
            "validation", collate_fn=custom_collate, shuffle=False
        )

    def _test_dataloader(self):
        """Return the unshuffled test loader."""
        return self._make_dataloader(
            "test", collate_fn=custom_collate, shuffle=False
        )
def main():
    """Construct the Lightning command-line interface for this script.

    The only option set here is ``overwrite=True`` inside
    ``save_config_kwargs``; everything else is left to ``LightningCLI``.
    """
    LightningCLI(save_config_kwargs=dict(overwrite=True))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()