2024-04-15 13:46:41 +02:00
|
|
|
import os
|
|
|
|
|
|
|
|
from torch.utils.data import Dataset
|
|
|
|
from utils import *
|
|
|
|
from torchvision import transforms
|
|
|
|
# Shared preprocessing applied to both the input image and its segmentation
# mask: a single torchvision ToTensor step (PIL image / ndarray -> tensor).
transform = transforms.Compose([transforms.ToTensor()])
|
|
|
|
|
|
|
|
|
|
|
|
# Uses the PASCAL VOC2007 dataset (SegmentationClass masks + JPEGImages).
|
|
|
|
class MyDataset(Dataset):
    """PASCAL VOC2007-style segmentation dataset.

    Expects ``path`` to contain a ``SegmentationClass`` directory of ``.png``
    masks and a ``JPEGImages`` directory holding a ``.jpg`` image with the
    same base name for every mask.  Each item is an ``(image, mask)`` pair,
    both opened with ``keep_image_size_open`` (from ``utils``) and passed
    through the module-level ``transform``.
    """

    def __init__(self, path):
        """Index the segmentation masks found under ``path``.

        Args:
            path: Root directory of the VOC2007 dataset.
        """
        self.path = path
        seg_dir = os.path.join(path, 'SegmentationClass')
        # Keep only visible .png files: OS metadata such as ``.DS_Store`` or
        # macOS ``._*`` resource-fork files would otherwise be indexed and
        # make __getitem__ fail.  Sort so the index -> file mapping is
        # reproducible, since os.listdir returns entries in arbitrary order.
        self.name = sorted(
            fname for fname in os.listdir(seg_dir)
            if not fname.startswith('.') and fname.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed mask files."""
        return len(self.name)

    def _image_path(self, segment_name):
        """Return the JPEG path paired with the mask file ``segment_name``.

        Only the file extension is swapped for ``.jpg``, so a base name that
        itself contains the substring ``png`` is left intact.
        """
        stem = os.path.splitext(segment_name)[0]
        return os.path.join(self.path, 'JPEGImages', stem + '.jpg')

    def __getitem__(self, index):
        """Load the ``index``-th (image, mask) pair.

        Returns:
            ``(transform(image), transform(mask))``.  The exact size/format
            depends on ``keep_image_size_open`` in ``utils`` (presumably it
            resizes/pads to a fixed size — confirm there).
        """
        segment_name = self.name[index]  # mask file name, e.g. 'xx.png'
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        image_path = self._image_path(segment_name)
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open(image_path)
        return transform(image), transform(segment_image)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Quick smoke test: load the first sample and print its tensor shapes.
    # The dataset root may be given as the first command-line argument;
    # otherwise the original hard-coded local path is used.
    import sys

    default_root = '/Users/hanzhangma/Document/DataSet/VOC2007'
    root = sys.argv[1] if len(sys.argv) > 1 else default_root
    data = MyDataset(root)
    # Fetch the sample once; indexing data[0] twice would load and
    # transform both images twice.
    image, segment_image = data[0]
    print(image.shape)          # shape of the transformed input image
    print(segment_image.shape)  # shape of the transformed segmentation mask
|