From 1a54fb7f71aff7f0f89bbf0afa08b6bd1edbbac6 Mon Sep 17 00:00:00 2001
From: kwotsin
Date: Tue, 11 Aug 2020 01:29:51 +0800
Subject: [PATCH] formatting

---
 examples/sngan_example.py  | 75 ++++++++++++++++++--------------------
 examples/ssgan_tutorial.py | 48 ++++++++++++------------
 2 files changed, 59 insertions(+), 64 deletions(-)

diff --git a/examples/sngan_example.py b/examples/sngan_example.py
index d07ad8c..6f38f07 100644
--- a/examples/sngan_example.py
+++ b/examples/sngan_example.py
@@ -3,13 +3,14 @@
 import torch_mimicry as mmc
 from torch_mimicry.nets import sngan
 
-
 if __name__ == "__main__":
     # Data handling objects
     device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
     dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
-    dataloader = torch.utils.data.DataLoader(
-        dataset, batch_size=64, shuffle=True, num_workers=4)
+    dataloader = torch.utils.data.DataLoader(dataset,
+                                             batch_size=64,
+                                             shuffle=True,
+                                             num_workers=4)
 
     # Define models and optimizers
     netG = sngan.SNGANGenerator32().to(device)
@@ -18,46 +19,42 @@
     optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
 
     # Start training
-    trainer = mmc.training.Trainer(
-        netD=netD,
-        netG=netG,
-        optD=optD,
-        optG=optG,
-        n_dis=5,
-        num_steps=100000,
-        lr_decay='linear',
-        dataloader=dataloader,
-        log_dir='./log/example',
-        device=device)
+    trainer = mmc.training.Trainer(netD=netD,
+                                   netG=netG,
+                                   optD=optD,
+                                   optG=optG,
+                                   n_dis=5,
+                                   num_steps=100000,
+                                   lr_decay='linear',
+                                   dataloader=dataloader,
+                                   log_dir='./log/example',
+                                   device=device)
     trainer.train()
 
     # Evaluate fid
-    mmc.metrics.evaluate(
-        metric='fid',
-        log_dir='./log/example',
-        netG=netG,
-        dataset_name='cifar10',
-        num_real_samples=50000,
-        num_fake_samples=50000,
-        evaluate_step=100000,
-        device=device)
+    mmc.metrics.evaluate(metric='fid',
+                         log_dir='./log/example',
+                         netG=netG,
+                         dataset_name='cifar10',
+                         num_real_samples=50000,
+                         num_fake_samples=50000,
+                         evaluate_step=100000,
+                         device=device)
 
     # Evaluate kid
-    mmc.metrics.evaluate(
-        metric='kid',
-        log_dir='./log/example',
-        netG=netG,
-        dataset_name='cifar10',
-        num_subsets=50,
-        subset_size=1000,
-        evaluate_step=100000,
-        device=device)
+    mmc.metrics.evaluate(metric='kid',
+                         log_dir='./log/example',
+                         netG=netG,
+                         dataset_name='cifar10',
+                         num_subsets=50,
+                         subset_size=1000,
+                         evaluate_step=100000,
+                         device=device)
 
     # Evaluate inception score
-    mmc.metrics.evaluate(
-        metric='inception_score',
-        log_dir='./log/example',
-        netG=netG,
-        num_samples=50000,
-        evaluate_step=100000,
-        device=device)
\ No newline at end of file
+    mmc.metrics.evaluate(metric='inception_score',
+                         log_dir='./log/example',
+                         netG=netG,
+                         num_samples=50000,
+                         evaluate_step=100000,
+                         device=device)
diff --git a/examples/ssgan_tutorial.py b/examples/ssgan_tutorial.py
index 3cf75ed..ff2dae7 100644
--- a/examples/ssgan_tutorial.py
+++ b/examples/ssgan_tutorial.py
@@ -13,6 +13,7 @@
 from torch_mimicry.modules import SNLinear
 from torch_mimicry.modules import GBlock, DBlock, DBlockOptimized
 
+
 #######################
 # Models
 #######################
@@ -118,8 +119,6 @@ def __init__(self, ndf=128, loss_type='hinge', **kwargs):
         self.l_y = SNLinear(self.ndf, self.num_classes)
         nn.init.xavier_uniform_(self.l_y.weight.data, 1.0)
 
-
-
     def forward(self, x):
         """
         Feedforwards a batch of real/fake images and produces a batch of GAN logits,
@@ -141,7 +140,6 @@ def forward(self, x):
 
         return output, output_classes
 
-
     def _rot_tensor(self, image, deg):
         """
         Rotation for pytorch tensors using rotation matrix. Takes in a tensor of (C, H, W shape).
@@ -216,7 +214,7 @@ def train_step(self,
                    netG,
                    optD,
                    log_data,
-                   device=None, 
+                   device=None,
                    global_step=None,
                    **kwargs):
         """
@@ -272,8 +270,10 @@ def train_step(self,
 # Data handling objects
 device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
 dataset = mmc.datasets.load_dataset(root='./datasets', name='cifar10')
-dataloader = torch.utils.data.DataLoader(
-    dataset, batch_size=64, shuffle=True, num_workers=4)
+dataloader = torch.utils.data.DataLoader(dataset,
+                                         batch_size=64,
+                                         shuffle=True,
+                                         num_workers=4)
 
 # Define models and optimizers
 netG = SSGANGenerator().to(device)
@@ -282,28 +282,26 @@ def train_step(self,
 optG = optim.Adam(netG.parameters(), 2e-4, betas=(0.0, 0.9))
 
 # Start training
-trainer = mmc.training.Trainer(
-    netD=netD,
-    netG=netG,
-    optD=optD,
-    optG=optG,
-    n_dis=2,
-    num_steps=100000,
-    dataloader=dataloader,
-    log_dir=log_dir,
-    device=device)
+trainer = mmc.training.Trainer(netD=netD,
+                               netG=netG,
+                               optD=optD,
+                               optG=optG,
+                               n_dis=2,
+                               num_steps=100000,
+                               dataloader=dataloader,
+                               log_dir=log_dir,
+                               device=device)
 trainer.train()
 
 ##########################
 # Evaluation
 ##########################
 # Evaluate fid
-mmc.metrics.evaluate(
-    metric='fid',
-    log_dir=log_dir,
-    netG=netG,
-    dataset_name='cifar10',
-    num_real_samples=10000,
-    num_fake_samples=10000,
-    evaluate_step=100000,
-    device=device)
\ No newline at end of file
+mmc.metrics.evaluate(metric='fid',
+                     log_dir=log_dir,
+                     netG=netG,
+                     dataset_name='cifar10',
+                     num_real_samples=10000,
+                     num_fake_samples=10000,
+                     evaluate_step=100000,
+                     device=device)