Skip to content

Commit

Permalink
Merge pull request #73 from Mohityadav797693/main
Browse files Browse the repository at this point in the history
segmentation: add a function for loading base model weights for transfer learning #21
  • Loading branch information
ahiliitb authored Oct 31, 2023
2 parents e20b255 + 3e5f9a7 commit 5c86b74
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion UniTrain/dataset/DCGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ def __getitem__(self, idx):
if self.transform is not None:
image = self.transform(image)

return image, target
return image, target
22 changes: 17 additions & 5 deletions UniTrain/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,26 @@
from torch.utils.data import Dataset
import cv2
import torchvision.transforms as transforms
import torchvision.models as models

class SegmentationDataset(Dataset):
def __init__(self, image_paths: list, mask_paths: list, transform=None):
def __init__(self, image_paths: list, mask_paths: list, transform=None, base_model_weights_path=None):
self.image_paths = image_paths
self.mask_paths = mask_paths
self.transform = transform

# Load base model for transfer learning
self.base_model = self.load_base_model(base_model_weights_path)

def load_base_model(self, weights_path):
if weights_path is not None:
base_model = models.segmentation.deeplabv3_resnet50(pretrained=False)
base_model.load_state_dict(torch.load(weights_path))
else:
# Load a default model if weights_path is not provided
base_model = models.segmentation.deeplabv3_resnet50(pretrained=True)
return base_model

def __len__(self):
return len(self.image_paths)

Expand All @@ -28,11 +41,10 @@ def __getitem__(self, index):

if self.transform is not None:
image = self.transform(image)

mask_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
mask = mask_transform(mask)
mask = mask.to(torch.long)

# You may need to further preprocess the mask if required
# Example: Convert mask to tensor and perform class mapping


return image, mask

0 comments on commit 5c86b74

Please sign in to comment.