diff --git a/UniTrain/dataset/DCGAN.py b/UniTrain/dataset/DCGAN.py index 65312c6..2f60eaf 100644 --- a/UniTrain/dataset/DCGAN.py +++ b/UniTrain/dataset/DCGAN.py @@ -3,7 +3,6 @@ from PIL import Image import torch from torch.utils.data import Dataset -import wandb class DCGANdataset: def __init__(self, data_dir, transform=None): @@ -35,19 +34,4 @@ def __getitem__(self, idx): if self.transform is not None: image = self.transform(image) - return image, target - -# Initialize wandb -wandb.init(project="your_project_name", entity="your_entity_name") - -# Your training code here -# Log the model architecture -wandb.watch(model) - -for epoch in range(num_epochs): - # Your training code here - wandb.log({"model_weights": model.state_dict()}) - # Log images or other metrics as needed - -# Finish the wandb run -wandb.finish() + return image, target \ No newline at end of file diff --git a/UniTrain/dataset/segmentation.py b/UniTrain/dataset/segmentation.py index d932161..af23fc5 100644 --- a/UniTrain/dataset/segmentation.py +++ b/UniTrain/dataset/segmentation.py @@ -4,13 +4,26 @@ from torch.utils.data import Dataset import cv2 import torchvision.transforms as transforms +import torchvision.models as models class SegmentationDataset(Dataset): - def __init__(self, image_paths: list, mask_paths: list, transform=None): + def __init__(self, image_paths: list, mask_paths: list, transform=None, base_model_weights_path=None): self.image_paths = image_paths self.mask_paths = mask_paths self.transform = transform + # Load base model for transfer learning + self.base_model = self.load_base_model(base_model_weights_path) + + def load_base_model(self, weights_path): + if weights_path is not None: + base_model = models.segmentation.deeplabv3_resnet50(pretrained=False) + base_model.load_state_dict(torch.load(weights_path)) + else: + # Load a default model if weights_path is not provided + base_model = models.segmentation.deeplabv3_resnet50(pretrained=True) + return base_model + def __len__(self): return len(self.image_paths) @@ -28,11 +41,10 @@ def __getitem__(self, index): if self.transform is not None: image = self.transform(image) - + mask_transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]) mask = mask_transform(mask) mask = mask.to(torch.long) - - # You may need to further preprocess the mask if required - # Example: Convert mask to tensor and perform class mapping + + return image, mask