MasterThesis/UNet/data.py
2024-04-15 13:46:41 +02:00

31 lines
1.0 KiB
Python

import os
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor()
])
#use VOC2007 Dataset
class MyDataset(Dataset):
def __init__(self, path):
self.path = path
self.name = os.listdir(os.path.join(path, 'SegmentationClass'))
def __len__(self):
return len(self.name)
def __getitem__(self, index):
segment_name = self.name[index] #xx.png
segment_path = os.path.join(self.path, 'SegmentationClass',segment_name)
image_path = os.path.join(self.path,'JPEGImages', segment_name.replace('png','jpg'))
segment_image = keep_image_size_open(segment_path)
image = keep_image_size_open(image_path)
return transform(image), transform(segment_image)
if __name__ == '__main__':
data = MyDataset('/Users/hanzhangma/Document/DataSet/VOC2007')
print(data[0][0].shape) # print the size of image(0,0)
print(data[0][1].shape) # print the size of image(0,1)