Skip to content

Commit

Permalink
feat: Network_woCalibrate for finetuning
Browse files Browse the repository at this point in the history
  • Loading branch information
HanX committed Jun 12, 2023
1 parent 665dbcc commit da5b5ab
Show file tree
Hide file tree
Showing 87 changed files with 405 additions and 24,953 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,6 @@ segment_anything/sam_vit_h_4b8939.pth
__pycache__

# Ignore editor files

# Ignore temporary files
tmp/
32 changes: 15 additions & 17 deletions .ipynb_checkpoints/model-checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, channels):
super(SemanticFusionUnit, self).__init__()

self.conv = nn.Sequential(
nn.Conv2d(in_channels=2*channels, out_channels=channels, kernel_size=3, stride=1, padding=1),
nn.Conv2d(in_channels=channels+3, out_channels=channels, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(channels),
nn.ReLU()
)
Expand Down Expand Up @@ -58,6 +58,7 @@ def __init__(self, layers, channels):

def forward(self, input, sem):
fea = self.in_conv(input)

fea = fea + self.fusion(fea, sem)
for conv in self.blocks:
fea = fea + conv(fea)
Expand All @@ -82,6 +83,8 @@ def __init__(self, layers, channels):
nn.BatchNorm2d(channels),
nn.ReLU()
)

# self.fusion = SemanticFusionUnit(channels)

self.convs = nn.Sequential(
nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, stride=1, padding=padding),
Expand All @@ -100,8 +103,10 @@ def __init__(self, layers, channels):
nn.Sigmoid()
)

def forward(self, input):
def forward(self, input, sem):
fea = self.in_conv(input)

# fea = fea + self.fusion(fea, sem)
for conv in self.blocks:
fea = fea + conv(fea)

Expand Down Expand Up @@ -144,7 +149,7 @@ def forward(self, input, sem):
i = self.enhance(input_op, sem)
r = input / i
r = torch.clamp(r, 0, 1)
att = self.calibrate(r)
att = self.calibrate(r, sem)
input_op = input + att
ilist.append(i)
rlist.append(r)
Expand All @@ -160,27 +165,19 @@ def _loss(self, input, sem):
return loss


class Network_woCalibrate(nn.Module):

class Finetunemodel(nn.Module):

def __init__(self, weights):
super(Finetunemodel, self).__init__()
def __init__(self):
super().__init__()
self.enhance = EnhanceNetwork(layers=1, channels=3)
self._criterion = LossFunction()

# base_weights = torch.load(weights)
# pretrained_dict = base_weights
# model_dict = self.state_dict()
# pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# model_dict.update(pretrained_dict)
# self.load_state_dict(model_dict)

def weights_init(self, m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.02)
m.bias.data.zero_()

if isinstance(m, nn.BatchNorm2d):
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1., 0.02)

def forward(self, input, sem):
Expand All @@ -190,8 +187,9 @@ def forward(self, input, sem):
return i, r


def _loss(self, input):
i, r = self(input)
def _loss(self, input, sem):
i, r = self(input, sem)
loss = self._criterion(input, i)
return loss


61 changes: 38 additions & 23 deletions .ipynb_checkpoints/test-checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,56 +7,71 @@
import torch.backends.cudnn as cudnn
from PIL import Image
from torch.autograd import Variable
from model import Finetunemodel
from model import Network_woCalibrate

from multi_read_data import MemoryFriendlyLoader

parser = argparse.ArgumentParser("SCI")
parser.add_argument('--data_path', type=str, default='./data/medium',
help='location of the data corpus')
parser.add_argument('--save_path', type=str, default='./results/medium', help='location of the data corpus')
parser.add_argument('--model', type=str, default='./EXP/Train-20230610-19:32:31/model_epochs/weights_24.pt', help='location of the data corpus')
parser = argparse.ArgumentParser("enlighten-anything")
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--seed', type=int, default=2, help='random seed')

parser.add_argument('--weights', type=str, default="weights/weights_finetune.pt", help='weights after training with semantic')
parser.add_argument('--test_dir', type=str, default='data/LOL/test15/low', help='testing data directory')
parser.add_argument('--test_output_dir', type=str, default='test_output', help='testing output directory')
args = parser.parse_args()
save_path = args.save_path
os.makedirs(save_path, exist_ok=True)

TestDataset = MemoryFriendlyLoader(img_dir=args.data_path, task='test')

test_queue = torch.utils.data.DataLoader(
TestDataset, batch_size=1,
pin_memory=True, num_workers=0)
save_path = args.test_output_dir
os.makedirs(save_path, exist_ok=True)

import subprocess
print("sam is working...")
# subprocess.call(['python', 'sam.py', '--source_dir', args.test_dir])
print("sam is done...")

def save_images(tensor, path):
image_numpy = tensor[0].cpu().float().numpy() # test_queue's batch_size==1
image_numpy = tensor[0].cpu().float().numpy()
image_numpy = (np.transpose(image_numpy, (1, 2, 0)))
im = Image.fromarray(np.clip(image_numpy * 255.0, 0, 255.0).astype('uint8'))
im.save(path, 'png')

def model_init(model):
weights_dict = torch.load(args.weights)
model_dict = model.state_dict()
weights_dict = {k: v for k, v in weights_dict.items() if k in model_dict}
model_dict.update(weights_dict)
model.load_state_dict(model_dict)

def main():
if not torch.cuda.is_available():
print('no gpu device available')
sys.exit(1)

model = Finetunemodel(args.model)

#TODO:SAM预处理
#weights的保存
# 推理网络

TestDataset = MemoryFriendlyLoader(img_dir = args.test_dir,
sem_dir = os.path.join(os.path.split(args.test_dir)[0], 'low_semantic'))
test_queue = torch.utils.data.DataLoader(
TestDataset, batch_size=1, shuffle = False,
pin_memory=True
)

model = Network_woCalibrate(args.weights)
model_init(model)
model = model.cuda()

model.eval()
with torch.no_grad():
for _, (input, image_name) in enumerate(test_queue):
input = Variable(input, volatile=True).cuda()
image_name = image_name[0].split('\\')[-1].split('.')[0]
i, r = model(input)
for batch_idx, (in_, sem_, imgname_, semname_ ) in enumerate(test_queue):
in_ = in_.cuda()
sem_ = sem_.cuda()
image_name = os.path.splitext(imgname_[0])[0]
i, r = model(in_, sem_)
u_name = '%s.png' % (image_name)
print('processing {}'.format(u_name))
print('test processing {}'.format(u_name))
u_path = save_path + '/' + u_name
save_images(r, u_path)



if __name__ == '__main__':
main()
Loading

0 comments on commit da5b5ab

Please sign in to comment.