85 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			85 lines
		
	
	
		
			3.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import print_function
 | |
| import numpy as np
 | |
| from PIL import Image
 | |
| import pickle as pkl
 | |
| import os, cv2, csv, glob
 | |
| import torch
 | |
| import torch.utils.data as data
 | |
| 
 | |
| 
 | |
class TieredImageNet(data.Dataset):
  """Dataset over one or more tiered-ImageNet splits stored under ``root_dir``.

  ``split`` is a single split name or several names joined by ``-``
  (e.g. ``'train-val'``).  For every name the constructor reads
  ``<root_dir>/<name>_images.npz`` (array stored under the key ``images``) and
  ``<root_dir>/<name>_labels.pkl`` (mapping with a ``label_specific`` entry).
  Labels of later splits are shifted so that class ids of different splits
  never collide.

  Attributes set by the constructor:
    images           -- all images, concatenated along the first axis
    labels           -- all (shifted) labels, concatenated
    n_classes        -- ``max(labels) + 1``
    dict_index_label -- class id -> array of sample indices with that label
    length           -- number of samples
  """

  def __init__(self, root_dir, split, transform=None):
    """Load every requested split into memory.

    Args:
      root_dir: directory holding the ``*_images.npz`` / ``*_labels.pkl`` files.
      split: split name, or several names joined by ``-``.
      transform: optional callable applied to the PIL image in ``__getitem__``.

    Raises:
      ValueError: the npz archive is missing and there is no ``*_png.pkl``
        file to build it from.
      AssertionError: the image archive or the label file does not exist.
      KeyError: the label pickle has no ``label_specific`` entry.
    """
    self.split = split
    self.root_dir = root_dir
    self.transform = transform

    images, labels, last = [], [], 0
    # ``part`` (not ``split``) so the constructor argument is not shadowed.
    for part in split.split('-'):
      labels_name = '{:}/{:}_labels.pkl'.format(self.root_dir, part)
      images_name = '{:}/{:}_images.npz'.format(self.root_dir, part)
      # Build the npz archive from the pickled PNG buffers if it does not exist yet.
      if not os.path.exists(images_name):
        png_pkl = images_name[:-4] + '_png.pkl'
        if os.path.exists(png_pkl):
          decompress(images_name, png_pkl)
        else:
          raise ValueError('png_pkl {:} not exits'.format( png_pkl ))
      # Explicit raise instead of ``assert`` so the check survives ``python -O``
      # while keeping the exception type callers have seen so far.
      if not (os.path.exists(images_name) and os.path.exists(labels_name)):
        raise AssertionError('{:} & {:}'.format(images_name, labels_name))
      print("Prepare {:} done".format(images_name))

      label_specific = self._load_label_specific(labels_name)
      # ``archive`` (not ``data``) so the imported ``data`` module is not shadowed.
      with np.load(images_name, mmap_mode="r", encoding='latin1') as archive:
        image_data = archive["images"]
      images.append(image_data)

      # Shift this split's labels past every class id used so far.
      label_specific = label_specific + last
      labels.append(label_specific)
      last = np.max(label_specific) + 1
      print("Load {:} done, with image shape = {:}, label shape = {:}, [{:} ~ {:}]".format(images_name, image_data.shape, label_specific.shape, np.min(label_specific), np.max(label_specific)))
    images, labels = np.concatenate(images), np.concatenate(labels)

    self.images = images
    self.labels = labels
    self.n_classes = int(np.max(labels) + 1)
    # class id -> indices of all samples carrying that label
    self.dict_index_label = {}
    for cls in range(self.n_classes):
      self.dict_index_label[cls] = np.where(labels == cls)[0]
    self.length = len(labels)
    print("There are {:} images, {:} labels [{:} ~ {:}]".format(images.shape, labels.shape, np.min(labels), np.max(labels)))

  @staticmethod
  def _load_label_specific(labels_name):
    """Return the ``label_specific`` entry of the label pickle as an ndarray.

    The file is always opened in binary mode.  ``encoding='bytes'`` lets
    Python 3 read pickles written by Python 2; in that case the dictionary
    keys come back as ``bytes``, so both key spellings are accepted.
    """
    # NOTE: pickle can execute arbitrary code while loading -- only point this
    # at dataset files you trust.
    with open(labels_name, 'rb') as f:
      try:
        content = pkl.load(f, encoding='bytes')
      except TypeError:
        # Python 2's pickle.load has no ``encoding`` keyword.
        f.seek(0)
        content = pkl.load(f)
    for key in (b'label_specific', 'label_specific'):
      if key in content:
        return np.asarray(content[key])
    raise KeyError('label_specific not found in {:}'.format(labels_name))

  def __repr__(self):
    return '{name}(length={length}, classes={n_classes})'.format(
      name=self.__class__.__name__, length=self.length, n_classes=self.n_classes)

  def __len__(self):
    return self.length

  def __getitem__(self, index):
    """Return ``(image, label)`` for sample ``index``.

    The image is a PIL ``RGB`` image (passed through ``transform`` when one
    was given); the label is a plain ``int``.

    Raises:
      IndexError: ``index`` is outside ``[0, len(self))``.  IndexError (rather
        than an assertion) is what the sequence protocol expects, so a plain
        ``for sample in dataset`` loop terminates cleanly.
    """
    if not (0 <= index < self.length):
      raise IndexError('invalid index = {:}'.format(index))
    image = self.images[index].copy()
    label = int(self.labels[index])
    # Reverse the channel axis before building the RGB image; the stored
    # arrays are presumably BGR (decompress() fills them via cv2.imdecode).
    image = Image.fromarray(image[:,:,::-1].astype('uint8'), 'RGB')
    if self.transform is not None:
      image = self.transform(image)
    return image, label
 | |
| 
 | |
| 
 | |
| 
 | |
| 
 | |
def decompress(path, output):
  """Decode a pickled sequence of encoded image buffers into an npz archive.

  Mind the argument names: ``output`` is the pickle that is *read*, and
  ``path`` is the ``.npz`` file that is *written*.

  Args:
    path: destination archive; the decoded images are stored under the key
      ``images`` as a ``uint8`` array of shape ``(N, 84, 84, 3)``.
    output: pickle file holding a sequence of encoded image buffers, each
      decodable by ``cv2.imdecode``.

  Raises:
    ValueError: a buffer cannot be decoded, or decodes to a shape other
      than ``(84, 84, 3)``.
  """
  # NOTE: pickle can execute arbitrary code while loading -- only use this on
  # dataset files you trust.
  with open(output, 'rb') as f:
    array = pkl.load(f, encoding='bytes')
  images = np.zeros([len(array), 84, 84, 3], dtype=np.uint8)
  expected_shape = images.shape[1:]
  for ii, item in enumerate(array):
    im = cv2.imdecode(item, 1)  # flag 1: decode as a 3-channel colour image
    # imdecode signals failure by returning None; report it explicitly
    # instead of failing later on the array assignment.
    if im is None:
      raise ValueError('failed to decode image {:} in {:}'.format(ii, output))
    # Guard against silent broadcasting of a wrongly sized image.
    if im.shape != expected_shape:
      raise ValueError('image {:} in {:} has shape {:}, expected {:}'.format(ii, output, im.shape, expected_shape))
    images[ii] = im
  np.savez(path, images=images)
 |