Skip to content

Commit

Permalink
segmentation_updated
Browse files Browse the repository at this point in the history
  • Loading branch information
mohit committed Oct 30, 2023
1 parent 262213c commit 3e5f9a7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
18 changes: 1 addition & 17 deletions UniTrain/dataset/DCGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from PIL import Image
import torch
from torch.utils.data import Dataset
import wandb

class DCGANdataset:
def __init__(self, data_dir, transform=None):
Expand Down Expand Up @@ -35,19 +34,4 @@ def __getitem__(self, idx):
if self.transform is not None:
image = self.transform(image)

return image, target

# Initialize wandb
wandb.init(project="your_project_name", entity="your_entity_name")

# Your training code here
# Log the model architecture
wandb.watch(model)

for epoch in range(num_epochs):
# Your training code here
wandb.log({"model_weights": model.state_dict()})
# Log images or other metrics as needed

# Finish the wandb run
wandb.finish()
return image, target
22 changes: 17 additions & 5 deletions UniTrain/dataset/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,26 @@
from torch.utils.data import Dataset
import cv2
import torchvision.transforms as transforms
import torchvision.models as models

class SegmentationDataset(Dataset):
def __init__(self, image_paths: list, mask_paths: list, transform=None):
def __init__(self, image_paths: list, mask_paths: list, transform=None, base_model_weights_path=None):
self.image_paths = image_paths
self.mask_paths = mask_paths
self.transform = transform

# Load base model for transfer learning
self.base_model = self.load_base_model(base_model_weights_path)

def load_base_model(self, weights_path):
if weights_path is not None:
base_model = models.segmentation.deeplabv3_resnet50(pretrained=False)
base_model.load_state_dict(torch.load(weights_path))
else:
# Load a default model if weights_path is not provided
base_model = models.segmentation.deeplabv3_resnet50(pretrained=True)
return base_model

def __len__(self):
return len(self.image_paths)

Expand All @@ -28,11 +41,10 @@ def __getitem__(self, index):

if self.transform is not None:
image = self.transform(image)

mask_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
mask = mask_transform(mask)
mask = mask.to(torch.long)

# You may need to further preprocess the mask if required
# Example: Convert mask to tensor and perform class mapping


return image, mask

0 comments on commit 3e5f9a7

Please sign in to comment.