Commit
refactor: test.py is to be replaced by test_test.py, where result is bad
HanX committed Jun 10, 2023
1 parent 3697d0e commit 8fbc388
Showing 156 changed files with 3,016 additions and 68 deletions.
11 changes: 10 additions & 1 deletion .gitignore
@@ -1 +1,10 @@
-EXP
+# Ignore logs and data files
+EXP/
+data/
+segment_anything/sam_vit_h_4b8939.pth
+
+# Ignore compiled binaries and cache
+.ipynb_checkpoints
+__pycache__
+
+# Ignore editor files
24 changes: 15 additions & 9 deletions .ipynb_checkpoints/model-checkpoint.py
@@ -123,9 +123,15 @@ def weights_init(self, m):
         if isinstance(m, nn.Conv2d):
             m.weight.data.normal_(0, 0.02)
             m.bias.data.zero_()
+            print("111")
 
-        if isinstance(m, nn.BatchNorm2d):
+        elif isinstance(m, nn.BatchNorm2d):
             m.weight.data.normal_(1., 0.02)
+            print("222")
+
+        else:
+            print("333")
+
 
     def forward(self, input, sem):
@@ -160,12 +166,12 @@ def __init__(self, weights):
         self.enhance = EnhanceNetwork(layers=1, channels=3)
         self._criterion = LossFunction()
 
-        base_weights = torch.load(weights)
-        pretrained_dict = base_weights
-        model_dict = self.state_dict()
-        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
-        model_dict.update(pretrained_dict)
-        self.load_state_dict(model_dict)
+        # base_weights = torch.load(weights)
+        # pretrained_dict = base_weights
+        # model_dict = self.state_dict()
+        # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+        # model_dict.update(pretrained_dict)
+        # self.load_state_dict(model_dict)
 
     def weights_init(self, m):
         if isinstance(m, nn.Conv2d):
@@ -175,8 +181,8 @@ def weights_init(self, m):
         if isinstance(m, nn.BatchNorm2d):
             m.weight.data.normal_(1., 0.02)
 
-    def forward(self, input):
-        i = self.enhance(input)
+    def forward(self, input, sem):
+        i = self.enhance(input, sem)
         r = input / i
         r = torch.clamp(r, 0, 1)
         return i, r
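
The six commented-out lines above are the standard partial-checkpoint-loading pattern: filter the saved state_dict down to keys the current model still has, then merge and reload. A minimal standalone sketch of the same idea (the helper name load_partial_weights is ours, not the repository's):

    import torch

    def load_partial_weights(model, weights_path):
        # keep only checkpoint entries whose keys also exist in the current
        # model, merge them into the model's state_dict, and reload it; keys
        # unique to the new architecture keep their fresh initialization
        pretrained_dict = torch.load(weights_path)
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
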
2 changes: 1 addition & 1 deletion .ipynb_checkpoints/test-checkpoint.py
@@ -15,7 +15,7 @@
 parser.add_argument('--data_path', type=str, default='./data/medium',
                     help='location of the data corpus')
 parser.add_argument('--save_path', type=str, default='./results/medium', help='location of the data corpus')
-parser.add_argument('--model', type=str, default='./weights/medium.pt', help='location of the data corpus')
+parser.add_argument('--model', type=str, default='./EXP/Train-20230610-19:32:31/model_epochs/weights_24.pt', help='location of the data corpus')
 parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
 parser.add_argument('--seed', type=int, default=2, help='random seed')
33 changes: 22 additions & 11 deletions .ipynb_checkpoints/test_train-checkpoint.py
@@ -17,7 +17,7 @@
 from multi_read_data import MemoryFriendlyLoader
 
 # command-line arguments for this script (optional)
-parser = argparse.ArgumentParser("SCI")
+parser = argparse.ArgumentParser("enlighten-anything")
 parser.add_argument('--batch_size', type=int, default=1, help='batch size')
 parser.add_argument('--cuda', type=bool, default=True, help='Use CUDA to train model')
 parser.add_argument('--gpu', type=str, default='0', help='gpu device id')
@@ -27,6 +27,8 @@
 parser.add_argument('--stage', type=int, default=3, help='epochs')
 parser.add_argument('--save', type=str, default='EXP/', help='location of the data corpus')
 parser.add_argument('--pretrain', type=str, default=None, help='pretrained weights directory')
+parser.add_argument('--train_dir', type=str, default='data/LOL/train480/low', help='training data directory')
+parser.add_argument('--val_dir', type=str, default='data/LOL/val5/low', help='validation data directory')
 args = parser.parse_args()
 
 # configure according to the command-line arguments
@@ -54,12 +56,18 @@ def save_images(tensor, path):
 
 def model_init(model):
     if(args.pretrain==None):
-        model.enhance.in_conv.apply(model.weights_init)
-        model.enhance.conv.apply(model.weights_init)
-        model.enhance.out_conv.apply(model.weights_init)
-        model.calibrate.in_conv.apply(model.weights_init)
-        model.calibrate.convs.apply(model.weights_init)
-        model.calibrate.out_conv.apply(model.weights_init)
+        # model.enhance.in_conv.apply(model.weights_init)
+        # model.enhance.conv.apply(model.weights_init)
+        # model.enhance.out_conv.apply(model.weights_init)
+        # model.calibrate.in_conv.apply(model.weights_init)
+        # model.calibrate.convs.apply(model.weights_init)
+        # model.calibrate.out_conv.apply(model.weights_init)
+
+        # model.enhance.apply(model.weights_init)
+        # model.calibrate.apply(model.weights_init)
+
+        model.apply(model.weights_init)
+
     else:
         pretrained_dict = torch.load(args.pretrain)
         model_dict = model.state_dict()
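
Why the single model.apply(model.weights_init) call above can stand in for the six per-submodule calls: nn.Module.apply(fn) visits every submodule recursively. A minimal sketch (ours, not the repository's) under that assumption:

    import torch.nn as nn

    def weights_init(m):
        # same pattern as the repository's initializer
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.normal_(1., 0.02)

    net = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.BatchNorm2d(3))
    net.apply(weights_init)  # one call reaches both layers
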
@@ -92,6 +100,7 @@ def main():
     # model
     model = Network(stage=args.stage)
     model_init(model)
+    return
     # GPU training prep 2: move the model to the GPU
     model = model.cuda()
     # log the model size
@@ -102,9 +111,11 @@
     optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                                   lr=args.lr*100, betas=(0.9, 0.999), weight_decay=3e-4)
 
-    # dataset
-    TrainDataset = MemoryFriendlyLoader(img_dir='../LOL/train480/low', sem_dir = '../LOL/train480/high_semantic')
-    ValDataset = MemoryFriendlyLoader(img_dir='../LOL/val5/low', sem_dir = '../LOL/val5/high_semantic')
+    # dataset
+    TrainDataset = MemoryFriendlyLoader(img_dir=args.train_dir,
+                                        sem_dir = os.path.join(os.path.split(args.train_dir)[0], 'high_semantic'))
+    ValDataset = MemoryFriendlyLoader(img_dir=args.val_dir,
+                                      sem_dir = os.path.join(os.path.split(args.val_dir)[0], 'high_semantic'))
     # from torch.utils.data import RandomSampler
     train_queue = torch.utils.data.DataLoader(
         TrainDataset, batch_size=args.batch_size,
@@ -143,7 +154,7 @@ def main():
 
         logging.info('train: epoch %3d: average_loss %f', epoch, np.average(losses))
         logging.info('----------')
-        utils.save(model, os.path.join(model_path, 'weights_%d.pt' % epoch))
+        utils.save(model, os.path.join(model_path, f'weights_{epoch}.pt'))
 
 
     model.eval()
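
A quick check (ours) of what the new sem_dir expressions above evaluate to: os.path.split drops the last path component, so each low-light image directory is paired with its sibling high_semantic directory.

    import os

    train_dir = 'data/LOL/train480/low'
    sem_dir = os.path.join(os.path.split(train_dir)[0], 'high_semantic')
    print(sem_dir)  # data/LOL/train480/high_semantic
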
18 changes: 9 additions & 9 deletions .ipynb_checkpoints/test_val-checkpoint.ipynb
@@ -24,14 +24,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 21,
     "id": "1d2c42f3-aeaf-4340-ab92-4916c90625c6",
     "metadata": {},
     "outputs": [],
     "source": [
      "ValDataset = MemoryFriendlyLoader(img_dir='../LOL/val5/low', sem_dir = '../LOL/val5/high_semantic')\n",
      "val_queue = torch.utils.data.DataLoader(\n",
-     "    ValDataset, batch_size=1,\n",
+     "    ValDataset, batch_size=2,\n",
      "    pin_memory=True, num_workers=0, shuffle=True\n",
      ")"
     ]
@@ -48,26 +48,26 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 7,
+    "execution_count": 24,
     "id": "41bba849-5c16-4b41-8758-730df5017151",
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-       "0 102\n",
-       "1 101\n",
-       "2 100\n",
-       "3 104\n",
-       "4 103\n"
+       "torch.Size([2, 3, 400, 600])\n"
       ]
      }
     ],
     "source": [
      "epoch = 10\n",
      "with torch.no_grad():\n",
      "    for batch_idx, (in_, sem_, imgname_, semname_ ) in enumerate(val_queue):\n",
-     "        print(batch_idx, os.path.splitext(imgname_[0])[0])"
+     "        # iaa = os.path.splitext(imgname_[0])[0]\n",
+     "        # print(batch_idx, f'{iaa}_{epoch}')\n",
+     "        print(model(in_, sem_)[0][0].shape)\n",
+     "        break"
     ]
    }
   ],
Binary file removed Figs/.ipynb_checkpoints/Det_1-checkpoint.png
Binary file removed Figs/Det_1.png
Binary file removed Figs/Firstfig.png
Binary file removed Figs/Flowchart.png
Binary file removed Figs/LLIE_1.png
Binary file removed Figs/LLIE_2.png
Binary file removed Figs/LLIE_3.png
Binary file removed Figs/LLIE_4.png
Binary file removed Figs/Seg_1.png
Binary file removed Figs/Seg_2.png
1 change: 0 additions & 1 deletion Figs/hello.md

This file was deleted.

Binary file modified __pycache__/model.cpython-38.pyc
Binary file modified __pycache__/multi_read_data.cpython-38.pyc
Binary file removed data/difficult/1482.png
Binary file removed data/difficult/2015_02057.jpg
Binary file removed data/difficult/2015_02154.jpg
Binary file removed data/difficult/2015_02243.jpg
Binary file removed data/difficult/2015_05885.jpg
Binary file removed data/difficult/2015_05886.jpg
Binary file removed data/difficult/234.png
Binary file removed data/difficult/252.png
Binary file removed data/difficult/2850.png
Binary file removed data/difficult/372.png
Binary file removed data/difficult/644.png
Binary file removed data/difficult/645.png
Binary file removed data/difficult/NPE_89.png
Binary file removed data/easy/0016.png
Binary file removed data/easy/0018.png
Binary file removed data/easy/0224.png
Binary file removed data/easy/0245.png
Binary file removed data/easy/0528.png
Binary file removed data/easy/0534.png
Binary file removed data/easy/0908.png
Binary file removed data/easy/0934.png
Binary file removed data/easy/1191.png
Binary file removed data/easy/1217.png
Binary file removed data/finetune/LIME_9.bmp
Binary file removed data/medium/.ipynb_checkpoints/00001-checkpoint.png
Binary file removed data/medium/.ipynb_checkpoints/00051-checkpoint.png
Binary file removed data/medium/00001.png
Binary file removed data/medium/00051.png
Binary file removed data/medium/00079.png
Binary file removed data/medium/00091.png
Binary file removed data/medium/2062.jpg
Binary file removed data/medium/2064.jpg
Binary file removed data/medium/3008.jpg
Binary file removed data/medium/3018.jpg
Binary file removed data/medium/3020.jpg
Binary file removed data/medium/NPE_71.png
24 changes: 15 additions & 9 deletions model.py
@@ -123,9 +123,15 @@ def weights_init(self, m):
         if isinstance(m, nn.Conv2d):
             m.weight.data.normal_(0, 0.02)
             m.bias.data.zero_()
+            print("111")
 
-        if isinstance(m, nn.BatchNorm2d):
+        elif isinstance(m, nn.BatchNorm2d):
             m.weight.data.normal_(1., 0.02)
+            print("222")
+
+        else:
+            print("333")
+
 
     def forward(self, input, sem):
@@ -160,12 +166,12 @@ def __init__(self, weights):
         self.enhance = EnhanceNetwork(layers=1, channels=3)
         self._criterion = LossFunction()
 
-        base_weights = torch.load(weights)
-        pretrained_dict = base_weights
-        model_dict = self.state_dict()
-        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
-        model_dict.update(pretrained_dict)
-        self.load_state_dict(model_dict)
+        # base_weights = torch.load(weights)
+        # pretrained_dict = base_weights
+        # model_dict = self.state_dict()
+        # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+        # model_dict.update(pretrained_dict)
+        # self.load_state_dict(model_dict)
 
     def weights_init(self, m):
         if isinstance(m, nn.Conv2d):
@@ -175,8 +181,8 @@ def weights_init(self, m):
         if isinstance(m, nn.BatchNorm2d):
             m.weight.data.normal_(1., 0.02)
 
-    def forward(self, input):
-        i = self.enhance(input)
+    def forward(self, input, sem):
+        i = self.enhance(input, sem)
         r = input / i
         r = torch.clamp(r, 0, 1)
         return i, r
Binary file removed results/difficult/1482.png
Binary file removed results/difficult/2015_02057.png
Binary file removed results/difficult/2015_02154.png
Binary file removed results/difficult/2015_02243.png
Binary file removed results/difficult/2015_05885.png
Binary file removed results/difficult/2015_05886.png
Binary file removed results/difficult/234.png
Binary file removed results/difficult/252.png
Binary file removed results/difficult/2850.png
Binary file removed results/difficult/372.png
Binary file removed results/difficult/644.png
Binary file removed results/difficult/645.png
Binary file removed results/difficult/NPE_89.png
Binary file removed results/easy/0016.png
Binary file removed results/easy/0018.png
Binary file removed results/easy/0224.png
Binary file removed results/easy/0245.png
Binary file removed results/easy/0528.png
Binary file removed results/easy/0534.png
Binary file removed results/easy/0908.png
Binary file removed results/easy/0934.png
Binary file removed results/easy/1191.png
Binary file removed results/easy/1217.png
Binary file removed results/finetune/LIME_9_100_ref_.png
Binary file removed results/finetune/LIME_9_10_ref_.png
Binary file removed results/finetune/LIME_9_20_ref_.png
Binary file removed results/finetune/LIME_9_30_ref_.png
Binary file removed results/finetune/LIME_9_40_ref_.png
Binary file removed results/finetune/LIME_9_50_ref_.png
Binary file removed results/finetune/LIME_9_60_ref_.png
Binary file removed results/finetune/LIME_9_70_ref_.png
Binary file removed results/finetune/LIME_9_80_ref_.png
Binary file removed results/finetune/LIME_9_90_ref_.png
Binary file removed results/medium/00001.png
Binary file removed results/medium/00051.png
Binary file removed results/medium/00079.png
Binary file removed results/medium/00091.png
Binary file removed results/medium/2062.png
Binary file removed results/medium/2064.png
Binary file removed results/medium/3008.png
Binary file removed results/medium/3018.png
Binary file removed results/medium/3020.png
Binary file removed results/medium/NPE_71.png
67 changes: 67 additions & 0 deletions sam.py
@@ -0,0 +1,67 @@
# This script batch-segments the pictures in the <source> directory with SAM (Segment Anything),
# saving colored semantic maps to <source>_semantic and binary mask images to <source>_semanticB.

# Author: Qihan Zhao

# Prerequisites:
# 1. place this script in the `SAM` directory
# 2. download the SAM weights to ./segment_anything/sam_vit_h_4b8939.pth
# 3. put the <source> directory inside `SAM`

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

import argparse
parser = argparse.ArgumentParser("enlighten-anything")
parser.add_argument('--source_dir', type=str, default='data/LOL/test15/low', help='directory of data to be segmented')
args = parser.parse_args()

sam_checkpoint = "./segment_anything/sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)


import os
sourcedir = args.source_dir
for i, filename in enumerate(os.listdir(sourcedir)):
    if filename.endswith('.jpg') or filename.endswith('.png'):
        print(f'{i}th pic: {filename}')
        # read image
        img_path = os.path.join(sourcedir, filename)
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # segment
        masks = mask_generator.generate(image)

        # prepare an RGBA canvas for the colored overlay (alpha starts fully transparent)
        img = np.ones((masks[0]['segmentation'].shape[0], masks[0]['segmentation'].shape[1], 4))
        img[:,:,3] = 0

        # binary masks: one PNG per mask
        os.makedirs(f'{sourcedir}_semanticB', exist_ok=True)
        for j, mask in enumerate(masks):
            save_path = os.path.join(f'{sourcedir}_semanticB', f'{os.path.splitext(filename)[0]}_semanticB_{j}.png')
            cv2.imwrite(save_path, np.uint8(mask["segmentation"]) * 255)

        # colored overlay: paint each mask with a random translucent color
        os.makedirs(f'{sourcedir}_semantic', exist_ok=True)
        save_path = os.path.join(f'{sourcedir}_semantic', f'{os.path.splitext(filename)[0]}_semantic.png')
        print(save_path)
        for mask in masks:
            mask_bool = mask['segmentation']
            color_mask = np.concatenate([np.random.random(3), [0.35]])
            img[mask_bool] = color_mask
        cv2.imwrite(save_path, np.uint8(img * 255))
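
With the checkpoint in place, the script would be invoked along the lines of python sam.py --source_dir data/LOL/test15/low, producing data/LOL/test15/low_semanticB (one binary PNG per mask) and data/LOL/test15/low_semantic (one colored overlay per image). A minimal sketch (ours) of the record structure the loops above consume; generate() returns a list of dicts whose 'segmentation' entries are boolean (H, W) arrays:

    import numpy as np

    # stand-in for one record from mask_generator.generate(image); the real
    # dicts also carry fields such as 'area' and 'bbox', which sam.py ignores
    mask = {'segmentation': np.zeros((400, 600), dtype=bool)}
    binary = np.uint8(mask['segmentation']) * 255  # what sam.py writes per mask
    print(binary.shape, binary.max())  # (400, 600) 0
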

15 changes: 15 additions & 0 deletions segment_anything/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .build_sam import (
build_sam,
build_sam_vit_h,
build_sam_vit_l,
build_sam_vit_b,
sam_model_registry,
)
from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator