Updates
This commit is contained in:
		
							
								
								
									
										8
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -48,5 +48,13 @@ jobs: | ||||
|           ls | ||||
|           python --version | ||||
|           python -m pytest ./tests/test_basic_space.py -s | ||||
|         shell: bash | ||||
|  | ||||
|       - name: Test Synthetic Data | ||||
|         run: | | ||||
|           python -m pip install pytest numpy | ||||
|           python -m pip install parameterized | ||||
|           python -m pip install torch | ||||
|           python --version | ||||
|           python -m pytest ./tests/test_synthetic.py -s | ||||
|         shell: bash | ||||
|   | ||||
 Submodule .latent-data/NATS-Bench updated: 3a8794322f...33bfb2eb13
									
								
							| @@ -5,118 +5,133 @@ import os, sys, hashlib, torch | ||||
| import numpy as np | ||||
| from PIL import Image | ||||
| import torch.utils.data as data | ||||
|  | ||||
| if sys.version_info[0] == 2: | ||||
|   import cPickle as pickle | ||||
|     import cPickle as pickle | ||||
| else: | ||||
|   import pickle | ||||
|     import pickle | ||||
|  | ||||
|  | ||||
| def calculate_md5(fpath, chunk_size=1024 * 1024): | ||||
|   md5 = hashlib.md5() | ||||
|   with open(fpath, 'rb') as f: | ||||
|     for chunk in iter(lambda: f.read(chunk_size), b''): | ||||
|       md5.update(chunk) | ||||
|   return md5.hexdigest() | ||||
|     md5 = hashlib.md5() | ||||
|     with open(fpath, "rb") as f: | ||||
|         for chunk in iter(lambda: f.read(chunk_size), b""): | ||||
|             md5.update(chunk) | ||||
|     return md5.hexdigest() | ||||
|  | ||||
|  | ||||
| def check_md5(fpath, md5, **kwargs): | ||||
|   return md5 == calculate_md5(fpath, **kwargs) | ||||
|     return md5 == calculate_md5(fpath, **kwargs) | ||||
|  | ||||
|  | ||||
| def check_integrity(fpath, md5=None): | ||||
|   if not os.path.isfile(fpath): return False | ||||
|   if md5 is None: return True | ||||
|   else          : return check_md5(fpath, md5) | ||||
|     if not os.path.isfile(fpath): | ||||
|         return False | ||||
|     if md5 is None: | ||||
|         return True | ||||
|     else: | ||||
|         return check_md5(fpath, md5) | ||||
|  | ||||
|  | ||||
| class ImageNet16(data.Dataset): | ||||
|   # http://image-net.org/download-images | ||||
|   # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets | ||||
|   # https://arxiv.org/pdf/1707.08819.pdf | ||||
|     # http://image-net.org/download-images | ||||
|     # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets | ||||
|     # https://arxiv.org/pdf/1707.08819.pdf | ||||
|  | ||||
|   train_list = [ | ||||
|         ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], | ||||
|         ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], | ||||
|         ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], | ||||
|         ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], | ||||
|         ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], | ||||
|         ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], | ||||
|         ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], | ||||
|         ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], | ||||
|         ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], | ||||
|         ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], | ||||
|     train_list = [ | ||||
|         ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"], | ||||
|         ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"], | ||||
|         ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"], | ||||
|         ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"], | ||||
|         ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"], | ||||
|         ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"], | ||||
|         ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"], | ||||
|         ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"], | ||||
|         ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"], | ||||
|         ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"], | ||||
|     ] | ||||
|   valid_list = [ | ||||
|         ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], | ||||
|     valid_list = [ | ||||
|         ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"], | ||||
|     ] | ||||
|  | ||||
|   def __init__(self, root, train, transform, use_num_of_class_only=None): | ||||
|     self.root      = root | ||||
|     self.transform = transform | ||||
|     self.train     = train  # training set or valid set | ||||
|     if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.') | ||||
|     def __init__(self, root, train, transform, use_num_of_class_only=None): | ||||
|         self.root = root | ||||
|         self.transform = transform | ||||
|         self.train = train  # training set or valid set | ||||
|         if not self._check_integrity(): | ||||
|             raise RuntimeError("Dataset not found or corrupted.") | ||||
|  | ||||
|     if self.train: downloaded_list = self.train_list | ||||
|     else         : downloaded_list = self.valid_list | ||||
|     self.data    = [] | ||||
|     self.targets = [] | ||||
|    | ||||
|     # now load the picked numpy arrays | ||||
|     for i, (file_name, checksum) in enumerate(downloaded_list): | ||||
|       file_path = os.path.join(self.root, file_name) | ||||
|       #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) | ||||
|       with open(file_path, 'rb') as f: | ||||
|         if sys.version_info[0] == 2: | ||||
|           entry = pickle.load(f) | ||||
|         if self.train: | ||||
|             downloaded_list = self.train_list | ||||
|         else: | ||||
|           entry = pickle.load(f, encoding='latin1') | ||||
|         self.data.append(entry['data']) | ||||
|         self.targets.extend(entry['labels']) | ||||
|     self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) | ||||
|     self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC | ||||
|     if use_num_of_class_only is not None: | ||||
|       assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) | ||||
|       new_data, new_targets = [], [] | ||||
|       for I, L in zip(self.data, self.targets): | ||||
|         if 1 <= L <= use_num_of_class_only: | ||||
|           new_data.append( I ) | ||||
|           new_targets.append( L ) | ||||
|       self.data    = new_data | ||||
|       self.targets = new_targets | ||||
|     #    self.mean.append(entry['mean']) | ||||
|     #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) | ||||
|     #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) | ||||
|     #print ('Mean : {:}'.format(self.mean)) | ||||
|     #temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3)) | ||||
|     #std_data  = np.std(temp, axis=0) | ||||
|     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||
|     #print ('Std  : {:}'.format(std_data)) | ||||
|             downloaded_list = self.valid_list | ||||
|         self.data = [] | ||||
|         self.targets = [] | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets)))) | ||||
|         # now load the picked numpy arrays | ||||
|         for i, (file_name, checksum) in enumerate(downloaded_list): | ||||
|             file_path = os.path.join(self.root, file_name) | ||||
|             # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) | ||||
|             with open(file_path, "rb") as f: | ||||
|                 if sys.version_info[0] == 2: | ||||
|                     entry = pickle.load(f) | ||||
|                 else: | ||||
|                     entry = pickle.load(f, encoding="latin1") | ||||
|                 self.data.append(entry["data"]) | ||||
|                 self.targets.extend(entry["labels"]) | ||||
|         self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) | ||||
|         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC | ||||
|         if use_num_of_class_only is not None: | ||||
|             assert ( | ||||
|                 isinstance(use_num_of_class_only, int) | ||||
|                 and use_num_of_class_only > 0 | ||||
|                 and use_num_of_class_only < 1000 | ||||
|             ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only) | ||||
|             new_data, new_targets = [], [] | ||||
|             for I, L in zip(self.data, self.targets): | ||||
|                 if 1 <= L <= use_num_of_class_only: | ||||
|                     new_data.append(I) | ||||
|                     new_targets.append(L) | ||||
|             self.data = new_data | ||||
|             self.targets = new_targets | ||||
|         #    self.mean.append(entry['mean']) | ||||
|         # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) | ||||
|         # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) | ||||
|         # print ('Mean : {:}'.format(self.mean)) | ||||
|         # temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3)) | ||||
|         # std_data  = np.std(temp, axis=0) | ||||
|         # std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||
|         # print ('Std  : {:}'.format(std_data)) | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}({num} images, {classes} classes)".format( | ||||
|             name=self.__class__.__name__, | ||||
|             num=len(self.data), | ||||
|             classes=len(set(self.targets)), | ||||
|         ) | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     img, target = self.data[index], self.targets[index] - 1 | ||||
|     def __getitem__(self, index): | ||||
|         img, target = self.data[index], self.targets[index] - 1 | ||||
|  | ||||
|     img = Image.fromarray(img) | ||||
|         img = Image.fromarray(img) | ||||
|  | ||||
|     if self.transform is not None: | ||||
|       img = self.transform(img) | ||||
|         if self.transform is not None: | ||||
|             img = self.transform(img) | ||||
|  | ||||
|     return img, target | ||||
|         return img, target | ||||
|  | ||||
|   def __len__(self): | ||||
|     return len(self.data) | ||||
|     def __len__(self): | ||||
|         return len(self.data) | ||||
|  | ||||
|     def _check_integrity(self): | ||||
|         root = self.root | ||||
|         for fentry in self.train_list + self.valid_list: | ||||
|             filename, md5 = fentry[0], fentry[1] | ||||
|             fpath = os.path.join(root, filename) | ||||
|             if not check_integrity(fpath, md5): | ||||
|                 return False | ||||
|         return True | ||||
|  | ||||
|   def _check_integrity(self): | ||||
|     root = self.root | ||||
|     for fentry in (self.train_list + self.valid_list): | ||||
|       filename, md5 = fentry[0], fentry[1] | ||||
|       fpath = os.path.join(root, filename) | ||||
|       if not check_integrity(fpath, md5): | ||||
|         return False | ||||
|     return True | ||||
|  | ||||
| """ | ||||
| if __name__ == '__main__': | ||||
|   | ||||
| @@ -20,172 +20,282 @@ import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class LandmarkDataset(data.Dataset): | ||||
|     def __init__( | ||||
|         self, | ||||
|         transform, | ||||
|         sigma, | ||||
|         downsample, | ||||
|         heatmap_type, | ||||
|         shape, | ||||
|         use_gray, | ||||
|         mean_file, | ||||
|         data_indicator, | ||||
|         cache_images=None, | ||||
|     ): | ||||
|  | ||||
|   def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None): | ||||
|  | ||||
|     self.transform    = transform | ||||
|     self.sigma        = sigma | ||||
|     self.downsample   = downsample | ||||
|     self.heatmap_type = heatmap_type | ||||
|     self.dataset_name = data_indicator | ||||
|     self.shape        = shape # [H,W] | ||||
|     self.use_gray     = use_gray | ||||
|     assert transform is not None, 'transform : {:}'.format(transform) | ||||
|     self.mean_file    = mean_file | ||||
|     if mean_file is None: | ||||
|       self.mean_data  = None | ||||
|       warnings.warn('LandmarkDataset initialized with mean_data = None') | ||||
|     else: | ||||
|       assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file) | ||||
|       self.mean_data  = torch.load(mean_file) | ||||
|     self.reset() | ||||
|     self.cutout       = None | ||||
|     self.cache_images = cache_images | ||||
|     print ('The general dataset initialization done : {:}'.format(self)) | ||||
|     warnings.simplefilter( 'once' ) | ||||
|  | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|  | ||||
|  | ||||
|   def set_cutout(self, length): | ||||
|     if length is not None and length >= 1: | ||||
|       self.cutout = CutOut( int(length) ) | ||||
|     else: self.cutout = None | ||||
|  | ||||
|  | ||||
|   def reset(self, num_pts=-1, boxid='default', only_pts=False): | ||||
|     self.NUM_PTS = num_pts | ||||
|     if only_pts: return | ||||
|     self.length  = 0 | ||||
|     self.datas   = [] | ||||
|     self.labels  = [] | ||||
|     self.NormDistances = [] | ||||
|     self.BOXID = boxid | ||||
|     if self.mean_data is None: | ||||
|       self.mean_face = None | ||||
|     else: | ||||
|       self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T) | ||||
|       assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face) | ||||
|     #assert self.dataset_name is not None, 'The dataset name is None' | ||||
|  | ||||
|  | ||||
|   def __len__(self): | ||||
|     assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length) | ||||
|     return self.length | ||||
|  | ||||
|  | ||||
|   def append(self, data, label, distance): | ||||
|     assert osp.isfile(data), 'The image path is not a file : {:}'.format(data) | ||||
|     self.datas.append( data )             ;  self.labels.append( label ) | ||||
|     self.NormDistances.append( distance ) | ||||
|     self.length = self.length + 1 | ||||
|  | ||||
|  | ||||
|   def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset): | ||||
|     if reset: self.reset(num_pts, boxindicator) | ||||
|     else    : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts) | ||||
|     if isinstance(file_lists, str): file_lists = [file_lists] | ||||
|     samples = [] | ||||
|     for idx, file_path in enumerate(file_lists): | ||||
|       print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path)) | ||||
|       xdata = torch.load(file_path) | ||||
|       if isinstance(xdata, list)  : data = xdata          # image or video dataset list | ||||
|       elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list | ||||
|       else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) )) | ||||
|       samples = samples + data | ||||
|     # samples is a dict, where the key is the image-path and the value is the annotation | ||||
|     # each annotation is a dict, contains 'points' (3,num_pts), and various box | ||||
|     print ('GeneralDataset-V2 : {:} samples'.format(len(samples))) | ||||
|  | ||||
|     #for index, annotation in enumerate(samples): | ||||
|     for index in tqdm( range( len(samples) ) ): | ||||
|       annotation = samples[index] | ||||
|       image_path  = annotation['current_frame'] | ||||
|       points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)] | ||||
|       label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name) | ||||
|       if normalizeL is None: normDistance = None | ||||
|       else                 : normDistance = annotation['normalizeL-{:}'.format(normalizeL)] | ||||
|       self.append(image_path, label, normDistance) | ||||
|  | ||||
|     assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas)) | ||||
|     assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels)) | ||||
|     assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistance)) | ||||
|     print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length)) | ||||
|  | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index) | ||||
|     if self.cache_images is not None and self.datas[index] in self.cache_images: | ||||
|       image = self.cache_images[ self.datas[index] ].clone() | ||||
|     else: | ||||
|       image = pil_loader(self.datas[index], self.use_gray) | ||||
|     target = self.labels[index].copy() | ||||
|     return self._process_(image, target, index) | ||||
|  | ||||
|  | ||||
|   def _process_(self, image, target, index): | ||||
|  | ||||
|     # transform the image and points | ||||
|     image, target, theta = self.transform(image, target) | ||||
|     (C, H, W), (height, width) = image.size(), self.shape | ||||
|  | ||||
|     # obtain the visiable indicator vector | ||||
|     if target.is_none(): nopoints = True | ||||
|     else               : nopoints = False | ||||
|     if index == -1: __path = None | ||||
|     else          : __path = self.datas[index] | ||||
|     if isinstance(theta, list) or isinstance(theta, tuple): | ||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], [] | ||||
|       for _theta in theta: | ||||
|         _affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \ | ||||
|           = self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path)) | ||||
|         affineImage.append(_affineImage) | ||||
|         heatmaps.append(_heatmaps) | ||||
|         mask.append(_mask) | ||||
|         norm_trans_points.append(_norm_trans_points) | ||||
|         THETA.append(_theta) | ||||
|         transpose_theta.append(_transpose_theta) | ||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \ | ||||
|           torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta) | ||||
|     else: | ||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path)) | ||||
|  | ||||
|     torch_index = torch.IntTensor([index]) | ||||
|     torch_nopoints = torch.ByteTensor( [ nopoints ] ) | ||||
|     torch_shape = torch.IntTensor([H,W]) | ||||
|  | ||||
|     return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape | ||||
|  | ||||
|    | ||||
|   def __process_affine(self, image, target, theta, nopoints, aux_info=None): | ||||
|     image, target, theta = image.clone(), target.copy(), theta.clone() | ||||
|     (C, H, W), (height, width) = image.size(), self.shape | ||||
|     if nopoints: # do not have label | ||||
|       norm_trans_points = torch.zeros((3, self.NUM_PTS)) | ||||
|       heatmaps          = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample)) | ||||
|       mask              = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8) | ||||
|       transpose_theta   = identity2affine(False) | ||||
|     else: | ||||
|       norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W)) | ||||
|       norm_trans_points = apply_boundary(norm_trans_points) | ||||
|       real_trans_points = norm_trans_points.clone() | ||||
|       real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:]) | ||||
|       heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C | ||||
|       heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor) | ||||
|       mask     = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) | ||||
|       if self.mean_face is None: | ||||
|         #warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.') | ||||
|         transpose_theta = identity2affine(False) | ||||
|       else: | ||||
|         if torch.sum(norm_trans_points[2,:] == 1) < 3: | ||||
|           warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info)) | ||||
|           transpose_theta = identity2affine(False) | ||||
|         self.transform = transform | ||||
|         self.sigma = sigma | ||||
|         self.downsample = downsample | ||||
|         self.heatmap_type = heatmap_type | ||||
|         self.dataset_name = data_indicator | ||||
|         self.shape = shape  # [H,W] | ||||
|         self.use_gray = use_gray | ||||
|         assert transform is not None, "transform : {:}".format(transform) | ||||
|         self.mean_file = mean_file | ||||
|         if mean_file is None: | ||||
|             self.mean_data = None | ||||
|             warnings.warn("LandmarkDataset initialized with mean_data = None") | ||||
|         else: | ||||
|           transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone()) | ||||
|             assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file) | ||||
|             self.mean_data = torch.load(mean_file) | ||||
|         self.reset() | ||||
|         self.cutout = None | ||||
|         self.cache_images = cache_images | ||||
|         print("The general dataset initialization done : {:}".format(self)) | ||||
|         warnings.simplefilter("once") | ||||
|  | ||||
|     affineImage = affine2image(image, theta, self.shape) | ||||
|     if self.cutout is not None: affineImage = self.cutout( affineImage ) | ||||
|     def __repr__(self): | ||||
|         return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|     return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta | ||||
|     def set_cutout(self, length): | ||||
|         if length is not None and length >= 1: | ||||
|             self.cutout = CutOut(int(length)) | ||||
|         else: | ||||
|             self.cutout = None | ||||
|  | ||||
|     def reset(self, num_pts=-1, boxid="default", only_pts=False): | ||||
|         self.NUM_PTS = num_pts | ||||
|         if only_pts: | ||||
|             return | ||||
|         self.length = 0 | ||||
|         self.datas = [] | ||||
|         self.labels = [] | ||||
|         self.NormDistances = [] | ||||
|         self.BOXID = boxid | ||||
|         if self.mean_data is None: | ||||
|             self.mean_face = None | ||||
|         else: | ||||
|             self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T) | ||||
|             assert (self.mean_face >= -1).all() and ( | ||||
|                 self.mean_face <= 1 | ||||
|             ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face) | ||||
|         # assert self.dataset_name is not None, 'The dataset name is None' | ||||
|  | ||||
|     def __len__(self): | ||||
|         assert len(self.datas) == self.length, "The length is not correct : {}".format( | ||||
|             self.length | ||||
|         ) | ||||
|         return self.length | ||||
|  | ||||
|     def append(self, data, label, distance): | ||||
|         assert osp.isfile(data), "The image path is not a file : {:}".format(data) | ||||
|         self.datas.append(data) | ||||
|         self.labels.append(label) | ||||
|         self.NormDistances.append(distance) | ||||
|         self.length = self.length + 1 | ||||
|  | ||||
|     def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset): | ||||
|         if reset: | ||||
|             self.reset(num_pts, boxindicator) | ||||
|         else: | ||||
|             assert ( | ||||
|                 self.NUM_PTS == num_pts and self.BOXID == boxindicator | ||||
|             ), "The number of point is inconsistance : {:} vs {:}".format( | ||||
|                 self.NUM_PTS, num_pts | ||||
|             ) | ||||
|         if isinstance(file_lists, str): | ||||
|             file_lists = [file_lists] | ||||
|         samples = [] | ||||
|         for idx, file_path in enumerate(file_lists): | ||||
|             print( | ||||
|                 ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path) | ||||
|             ) | ||||
|             xdata = torch.load(file_path) | ||||
|             if isinstance(xdata, list): | ||||
|                 data = xdata  # image or video dataset list | ||||
|             elif isinstance(xdata, dict): | ||||
|                 data = xdata["datas"]  # multi-view dataset list | ||||
|             else: | ||||
|                 raise ValueError("Invalid Type Error : {:}".format(type(xdata))) | ||||
|             samples = samples + data | ||||
|         # samples is a dict, where the key is the image-path and the value is the annotation | ||||
|         # each annotation is a dict, contains 'points' (3,num_pts), and various box | ||||
|         print("GeneralDataset-V2 : {:} samples".format(len(samples))) | ||||
|  | ||||
|         # for index, annotation in enumerate(samples): | ||||
|         for index in tqdm(range(len(samples))): | ||||
|             annotation = samples[index] | ||||
|             image_path = annotation["current_frame"] | ||||
|             points, box = ( | ||||
|                 annotation["points"], | ||||
|                 annotation["box-{:}".format(boxindicator)], | ||||
|             ) | ||||
|             label = PointMeta2V( | ||||
|                 self.NUM_PTS, points, box, image_path, self.dataset_name | ||||
|             ) | ||||
|             if normalizeL is None: | ||||
|                 normDistance = None | ||||
|             else: | ||||
|                 normDistance = annotation["normalizeL-{:}".format(normalizeL)] | ||||
|             self.append(image_path, label, normDistance) | ||||
|  | ||||
|         assert ( | ||||
|             len(self.datas) == self.length | ||||
|         ), "The length and the data is not right {} vs {}".format( | ||||
|             self.length, len(self.datas) | ||||
|         ) | ||||
|         assert ( | ||||
|             len(self.labels) == self.length | ||||
|         ), "The length and the labels is not right {} vs {}".format( | ||||
|             self.length, len(self.labels) | ||||
|         ) | ||||
|         assert ( | ||||
|             len(self.NormDistances) == self.length | ||||
|         ), "The length and the NormDistances is not right {} vs {}".format( | ||||
|             self.length, len(self.NormDistance) | ||||
|         ) | ||||
|         print( | ||||
|             "Load data done for LandmarkDataset, which has {:} images.".format( | ||||
|                 self.length | ||||
|             ) | ||||
|         ) | ||||
|  | ||||
|     def __getitem__(self, index): | ||||
|         assert index >= 0 and index < self.length, "Invalid index : {:}".format(index) | ||||
|         if self.cache_images is not None and self.datas[index] in self.cache_images: | ||||
|             image = self.cache_images[self.datas[index]].clone() | ||||
|         else: | ||||
|             image = pil_loader(self.datas[index], self.use_gray) | ||||
|         target = self.labels[index].copy() | ||||
|         return self._process_(image, target, index) | ||||
|  | ||||
|     def _process_(self, image, target, index): | ||||
|  | ||||
|         # transform the image and points | ||||
|         image, target, theta = self.transform(image, target) | ||||
|         (C, H, W), (height, width) = image.size(), self.shape | ||||
|  | ||||
|         # obtain the visiable indicator vector | ||||
|         if target.is_none(): | ||||
|             nopoints = True | ||||
|         else: | ||||
|             nopoints = False | ||||
|         if index == -1: | ||||
|             __path = None | ||||
|         else: | ||||
|             __path = self.datas[index] | ||||
|         if isinstance(theta, list) or isinstance(theta, tuple): | ||||
|             affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = ( | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|                 [], | ||||
|             ) | ||||
|             for _theta in theta: | ||||
|                 ( | ||||
|                     _affineImage, | ||||
|                     _heatmaps, | ||||
|                     _mask, | ||||
|                     _norm_trans_points, | ||||
|                     _theta, | ||||
|                     _transpose_theta, | ||||
|                 ) = self.__process_affine( | ||||
|                     image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path) | ||||
|                 ) | ||||
|                 affineImage.append(_affineImage) | ||||
|                 heatmaps.append(_heatmaps) | ||||
|                 mask.append(_mask) | ||||
|                 norm_trans_points.append(_norm_trans_points) | ||||
|                 THETA.append(_theta) | ||||
|                 transpose_theta.append(_transpose_theta) | ||||
|             affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = ( | ||||
|                 torch.stack(affineImage), | ||||
|                 torch.stack(heatmaps), | ||||
|                 torch.stack(mask), | ||||
|                 torch.stack(norm_trans_points), | ||||
|                 torch.stack(THETA), | ||||
|                 torch.stack(transpose_theta), | ||||
|             ) | ||||
|         else: | ||||
|             ( | ||||
|                 affineImage, | ||||
|                 heatmaps, | ||||
|                 mask, | ||||
|                 norm_trans_points, | ||||
|                 THETA, | ||||
|                 transpose_theta, | ||||
|             ) = self.__process_affine( | ||||
|                 image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path) | ||||
|             ) | ||||
|  | ||||
|         torch_index = torch.IntTensor([index]) | ||||
|         torch_nopoints = torch.ByteTensor([nopoints]) | ||||
|         torch_shape = torch.IntTensor([H, W]) | ||||
|  | ||||
|         return ( | ||||
|             affineImage, | ||||
|             heatmaps, | ||||
|             mask, | ||||
|             norm_trans_points, | ||||
|             THETA, | ||||
|             transpose_theta, | ||||
|             torch_index, | ||||
|             torch_nopoints, | ||||
|             torch_shape, | ||||
|         ) | ||||
|  | ||||
|     def __process_affine(self, image, target, theta, nopoints, aux_info=None): | ||||
|         image, target, theta = image.clone(), target.copy(), theta.clone() | ||||
|         (C, H, W), (height, width) = image.size(), self.shape | ||||
|         if nopoints:  # do not have label | ||||
|             norm_trans_points = torch.zeros((3, self.NUM_PTS)) | ||||
|             heatmaps = torch.zeros( | ||||
|                 (self.NUM_PTS + 1, height // self.downsample, width // self.downsample) | ||||
|             ) | ||||
|             mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8) | ||||
|             transpose_theta = identity2affine(False) | ||||
|         else: | ||||
|             norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W)) | ||||
|             norm_trans_points = apply_boundary(norm_trans_points) | ||||
|             real_trans_points = norm_trans_points.clone() | ||||
|             real_trans_points[:2, :] = denormalize_points( | ||||
|                 self.shape, real_trans_points[:2, :] | ||||
|             ) | ||||
|             heatmaps, mask = generate_label_map( | ||||
|                 real_trans_points.numpy(), | ||||
|                 height // self.downsample, | ||||
|                 width // self.downsample, | ||||
|                 self.sigma, | ||||
|                 self.downsample, | ||||
|                 nopoints, | ||||
|                 self.heatmap_type, | ||||
|             )  # H*W*C | ||||
|             heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type( | ||||
|                 torch.FloatTensor | ||||
|             ) | ||||
|             mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) | ||||
|             if self.mean_face is None: | ||||
|                 # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.') | ||||
|                 transpose_theta = identity2affine(False) | ||||
|             else: | ||||
|                 if torch.sum(norm_trans_points[2, :] == 1) < 3: | ||||
|                     warnings.warn( | ||||
|                         "In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format( | ||||
|                             aux_info | ||||
|                         ) | ||||
|                     ) | ||||
|                     transpose_theta = identity2affine(False) | ||||
|                 else: | ||||
|                     transpose_theta = solve2theta( | ||||
|                         norm_trans_points, self.mean_face.clone() | ||||
|                     ) | ||||
|  | ||||
|         affineImage = affine2image(image, theta, self.shape) | ||||
|         if self.cutout is not None: | ||||
|             affineImage = self.cutout(affineImage) | ||||
|  | ||||
|         return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta | ||||
|   | ||||
| @@ -6,41 +6,49 @@ import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class SearchDataset(data.Dataset): | ||||
|     def __init__(self, name, data, train_split, valid_split, check=True): | ||||
|         self.datasetname = name | ||||
|         if isinstance(data, (list, tuple)):  # new type of SearchDataset | ||||
|             assert len(data) == 2, "invalid length: {:}".format(len(data)) | ||||
|             self.train_data = data[0] | ||||
|             self.valid_data = data[1] | ||||
|             self.train_split = train_split.copy() | ||||
|             self.valid_split = valid_split.copy() | ||||
|             self.mode_str = "V2"  # new mode | ||||
|         else: | ||||
|             self.mode_str = "V1"  # old mode | ||||
|             self.data = data | ||||
|             self.train_split = train_split.copy() | ||||
|             self.valid_split = valid_split.copy() | ||||
|             if check: | ||||
|                 intersection = set(train_split).intersection(set(valid_split)) | ||||
|                 assert ( | ||||
|                     len(intersection) == 0 | ||||
|                 ), "the splitted train and validation sets should have no intersection" | ||||
|         self.length = len(self.train_split) | ||||
|  | ||||
|   def __init__(self, name, data, train_split, valid_split, check=True): | ||||
|     self.datasetname = name | ||||
|     if isinstance(data, (list, tuple)): # new type of SearchDataset | ||||
|       assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) | ||||
|       self.train_data  = data[0] | ||||
|       self.valid_data  = data[1] | ||||
|       self.train_split = train_split.copy() | ||||
|       self.valid_split = valid_split.copy() | ||||
|       self.mode_str    = 'V2' # new mode  | ||||
|     else: | ||||
|       self.mode_str    = 'V1' # old mode  | ||||
|       self.data        = data | ||||
|       self.train_split = train_split.copy() | ||||
|       self.valid_split = valid_split.copy() | ||||
|       if check: | ||||
|         intersection = set(train_split).intersection(set(valid_split)) | ||||
|         assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' | ||||
|     self.length      = len(self.train_split) | ||||
|     def __repr__(self): | ||||
|         return "{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             datasetname=self.datasetname, | ||||
|             tr_L=len(self.train_split), | ||||
|             val_L=len(self.valid_split), | ||||
|             ver=self.mode_str, | ||||
|         ) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) | ||||
|     def __len__(self): | ||||
|         return self.length | ||||
|  | ||||
|   def __len__(self): | ||||
|     return self.length | ||||
|  | ||||
|   def __getitem__(self, index): | ||||
|     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) | ||||
|     train_index = self.train_split[index] | ||||
|     valid_index = random.choice( self.valid_split ) | ||||
|     if self.mode_str == 'V1': | ||||
|       train_image, train_label = self.data[train_index] | ||||
|       valid_image, valid_label = self.data[valid_index] | ||||
|     elif self.mode_str == 'V2': | ||||
|       train_image, train_label = self.train_data[train_index] | ||||
|       valid_image, valid_label = self.valid_data[valid_index] | ||||
|     else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) | ||||
|     return train_image, train_label, valid_image, valid_label | ||||
|     def __getitem__(self, index): | ||||
|         assert index >= 0 and index < self.length, "invalid index = {:}".format(index) | ||||
|         train_index = self.train_split[index] | ||||
|         valid_index = random.choice(self.valid_split) | ||||
|         if self.mode_str == "V1": | ||||
|             train_image, train_label = self.data[train_index] | ||||
|             valid_image, valid_label = self.data[valid_index] | ||||
|         elif self.mode_str == "V2": | ||||
|             train_image, train_label = self.train_data[train_index] | ||||
|             valid_image, valid_label = self.valid_data[valid_index] | ||||
|         else: | ||||
|             raise ValueError("invalid mode : {:}".format(self.mode_str)) | ||||
|         return train_image, train_label, valid_image, valid_label | ||||
|   | ||||
| @@ -4,4 +4,5 @@ | ||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||
| from .SearchDatasetWrap import SearchDataset | ||||
|  | ||||
| from .synthetic_adaptive_environment import QuadraticFunction | ||||
| from .synthetic_adaptive_environment import SynAdaptiveEnv | ||||
|   | ||||
| @@ -14,214 +14,349 @@ from .SearchDatasetWrap import SearchDataset | ||||
| from config_utils import load_config | ||||
|  | ||||
|  | ||||
| Dataset2Class = {'cifar10' : 10, | ||||
|                  'cifar100': 100, | ||||
|                  'imagenet-1k-s':1000, | ||||
|                  'imagenet-1k' : 1000, | ||||
|                  'ImageNet16'  : 1000, | ||||
|                  'ImageNet16-150': 150, | ||||
|                  'ImageNet16-120': 120, | ||||
|                  'ImageNet16-200': 200} | ||||
| Dataset2Class = { | ||||
|     "cifar10": 10, | ||||
|     "cifar100": 100, | ||||
|     "imagenet-1k-s": 1000, | ||||
|     "imagenet-1k": 1000, | ||||
|     "ImageNet16": 1000, | ||||
|     "ImageNet16-150": 150, | ||||
|     "ImageNet16-120": 120, | ||||
|     "ImageNet16-200": 200, | ||||
| } | ||||
|  | ||||
|  | ||||
| class CUTOUT(object): | ||||
|     def __init__(self, length): | ||||
|         self.length = length | ||||
|  | ||||
|   def __init__(self, length): | ||||
|     self.length = length | ||||
|     def __repr__(self): | ||||
|         return "{name}(length={length})".format( | ||||
|             name=self.__class__.__name__, **self.__dict__ | ||||
|         ) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) | ||||
|     def __call__(self, img): | ||||
|         h, w = img.size(1), img.size(2) | ||||
|         mask = np.ones((h, w), np.float32) | ||||
|         y = np.random.randint(h) | ||||
|         x = np.random.randint(w) | ||||
|  | ||||
|   def __call__(self, img): | ||||
|     h, w = img.size(1), img.size(2) | ||||
|     mask = np.ones((h, w), np.float32) | ||||
|     y = np.random.randint(h) | ||||
|     x = np.random.randint(w) | ||||
|         y1 = np.clip(y - self.length // 2, 0, h) | ||||
|         y2 = np.clip(y + self.length // 2, 0, h) | ||||
|         x1 = np.clip(x - self.length // 2, 0, w) | ||||
|         x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|     y1 = np.clip(y - self.length // 2, 0, h) | ||||
|     y2 = np.clip(y + self.length // 2, 0, h) | ||||
|     x1 = np.clip(x - self.length // 2, 0, w) | ||||
|     x2 = np.clip(x + self.length // 2, 0, w) | ||||
|  | ||||
|     mask[y1: y2, x1: x2] = 0. | ||||
|     mask = torch.from_numpy(mask) | ||||
|     mask = mask.expand_as(img) | ||||
|     img *= mask | ||||
|     return img | ||||
|         mask[y1:y2, x1:x2] = 0.0 | ||||
|         mask = torch.from_numpy(mask) | ||||
|         mask = mask.expand_as(img) | ||||
|         img *= mask | ||||
|         return img | ||||
|  | ||||
|  | ||||
| imagenet_pca = { | ||||
|     'eigval': np.asarray([0.2175, 0.0188, 0.0045]), | ||||
|     'eigvec': np.asarray([ | ||||
|         [-0.5675, 0.7192, 0.4009], | ||||
|         [-0.5808, -0.0045, -0.8140], | ||||
|         [-0.5836, -0.6948, 0.4203], | ||||
|     ]) | ||||
|     "eigval": np.asarray([0.2175, 0.0188, 0.0045]), | ||||
|     "eigvec": np.asarray( | ||||
|         [ | ||||
|             [-0.5675, 0.7192, 0.4009], | ||||
|             [-0.5808, -0.0045, -0.8140], | ||||
|             [-0.5836, -0.6948, 0.4203], | ||||
|         ] | ||||
|     ), | ||||
| } | ||||
|  | ||||
|  | ||||
| class Lighting(object): | ||||
|   def __init__(self, alphastd, | ||||
|          eigval=imagenet_pca['eigval'], | ||||
|          eigvec=imagenet_pca['eigvec']): | ||||
|     self.alphastd = alphastd | ||||
|     assert eigval.shape == (3,) | ||||
|     assert eigvec.shape == (3, 3) | ||||
|     self.eigval = eigval | ||||
|     self.eigvec = eigvec | ||||
|     def __init__( | ||||
|         self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"] | ||||
|     ): | ||||
|         self.alphastd = alphastd | ||||
|         assert eigval.shape == (3,) | ||||
|         assert eigvec.shape == (3, 3) | ||||
|         self.eigval = eigval | ||||
|         self.eigvec = eigvec | ||||
|  | ||||
|   def __call__(self, img): | ||||
|     if self.alphastd == 0.: | ||||
|       return img | ||||
|     rnd = np.random.randn(3) * self.alphastd | ||||
|     rnd = rnd.astype('float32') | ||||
|     v = rnd | ||||
|     old_dtype = np.asarray(img).dtype | ||||
|     v = v * self.eigval | ||||
|     v = v.reshape((3, 1)) | ||||
|     inc = np.dot(self.eigvec, v).reshape((3,)) | ||||
|     img = np.add(img, inc) | ||||
|     if old_dtype == np.uint8: | ||||
|       img = np.clip(img, 0, 255) | ||||
|     img = Image.fromarray(img.astype(old_dtype), 'RGB') | ||||
|     return img | ||||
|     def __call__(self, img): | ||||
|         if self.alphastd == 0.0: | ||||
|             return img | ||||
|         rnd = np.random.randn(3) * self.alphastd | ||||
|         rnd = rnd.astype("float32") | ||||
|         v = rnd | ||||
|         old_dtype = np.asarray(img).dtype | ||||
|         v = v * self.eigval | ||||
|         v = v.reshape((3, 1)) | ||||
|         inc = np.dot(self.eigvec, v).reshape((3,)) | ||||
|         img = np.add(img, inc) | ||||
|         if old_dtype == np.uint8: | ||||
|             img = np.clip(img, 0, 255) | ||||
|         img = Image.fromarray(img.astype(old_dtype), "RGB") | ||||
|         return img | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return self.__class__.__name__ + '()' | ||||
|     def __repr__(self): | ||||
|         return self.__class__.__name__ + "()" | ||||
|  | ||||
|  | ||||
| def get_datasets(name, root, cutout): | ||||
|  | ||||
|   if name == 'cifar10': | ||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|     std  = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|   elif name == 'cifar100': | ||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|     std  = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|   elif name.startswith('imagenet-1k'): | ||||
|     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|   elif name.startswith('ImageNet16'): | ||||
|     mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|     std  = [x / 255 for x in [63.22,  61.26 , 65.09]] | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|     if name == "cifar10": | ||||
|         mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||
|         std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||
|     elif name == "cifar100": | ||||
|         mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||
|         std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||
|     elif name.startswith("ImageNet16"): | ||||
|         mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||
|         std = [x / 255 for x in [63.22, 61.26, 65.09]] | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|   # Data Argumentation | ||||
|   if name == 'cifar10' or name == 'cifar100': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     xshape = (1, 3, 32, 32) | ||||
|   elif name.startswith('ImageNet16'): | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     xshape = (1, 3, 16, 16) | ||||
|   elif name == 'tiered': | ||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] | ||||
|     train_transform = transforms.Compose(lists) | ||||
|     test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) | ||||
|     xshape = (1, 3, 32, 32) | ||||
|   elif name.startswith('imagenet-1k'): | ||||
|     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | ||||
|     if name == 'imagenet-1k': | ||||
|       xlists    = [transforms.RandomResizedCrop(224)] | ||||
|       xlists.append( | ||||
|         transforms.ColorJitter( | ||||
|         brightness=0.4, | ||||
|         contrast=0.4, | ||||
|         saturation=0.4, | ||||
|         hue=0.2)) | ||||
|       xlists.append( Lighting(0.1)) | ||||
|     elif name == 'imagenet-1k-s': | ||||
|       xlists    = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] | ||||
|     else: raise ValueError('invalid name : {:}'.format(name)) | ||||
|     xlists.append( transforms.RandomHorizontalFlip(p=0.5) ) | ||||
|     xlists.append( transforms.ToTensor() ) | ||||
|     xlists.append( normalize ) | ||||
|     train_transform = transforms.Compose(xlists) | ||||
|     test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]) | ||||
|     xshape = (1, 3, 224, 224) | ||||
|   else: | ||||
|     raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|     # Data Argumentation | ||||
|     if name == "cifar10" or name == "cifar100": | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(32, padding=4), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         ) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith("ImageNet16"): | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(16, padding=2), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||
|         ) | ||||
|         xshape = (1, 3, 16, 16) | ||||
|     elif name == "tiered": | ||||
|         lists = [ | ||||
|             transforms.RandomHorizontalFlip(), | ||||
|             transforms.RandomCrop(80, padding=4), | ||||
|             transforms.ToTensor(), | ||||
|             transforms.Normalize(mean, std), | ||||
|         ] | ||||
|         if cutout > 0: | ||||
|             lists += [CUTOUT(cutout)] | ||||
|         train_transform = transforms.Compose(lists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [ | ||||
|                 transforms.CenterCrop(80), | ||||
|                 transforms.ToTensor(), | ||||
|                 transforms.Normalize(mean, std), | ||||
|             ] | ||||
|         ) | ||||
|         xshape = (1, 3, 32, 32) | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         normalize = transforms.Normalize( | ||||
|             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | ||||
|         ) | ||||
|         if name == "imagenet-1k": | ||||
|             xlists = [transforms.RandomResizedCrop(224)] | ||||
|             xlists.append( | ||||
|                 transforms.ColorJitter( | ||||
|                     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 | ||||
|                 ) | ||||
|             ) | ||||
|             xlists.append(Lighting(0.1)) | ||||
|         elif name == "imagenet-1k-s": | ||||
|             xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] | ||||
|         else: | ||||
|             raise ValueError("invalid name : {:}".format(name)) | ||||
|         xlists.append(transforms.RandomHorizontalFlip(p=0.5)) | ||||
|         xlists.append(transforms.ToTensor()) | ||||
|         xlists.append(normalize) | ||||
|         train_transform = transforms.Compose(xlists) | ||||
|         test_transform = transforms.Compose( | ||||
|             [ | ||||
|                 transforms.Resize(256), | ||||
|                 transforms.CenterCrop(224), | ||||
|                 transforms.ToTensor(), | ||||
|                 normalize, | ||||
|             ] | ||||
|         ) | ||||
|         xshape = (1, 3, 224, 224) | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|   if name == 'cifar10': | ||||
|     train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True) | ||||
|     assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|   elif name == 'cifar100': | ||||
|     train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) | ||||
|     test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) | ||||
|     assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|   elif name.startswith('imagenet-1k'): | ||||
|     train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) | ||||
|     test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) | ||||
|     assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000) | ||||
|   elif name == 'ImageNet16': | ||||
|     train_data = ImageNet16(root, True , train_transform) | ||||
|     test_data  = ImageNet16(root, False, test_transform) | ||||
|     assert len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|   elif name == 'ImageNet16-120': | ||||
|     train_data = ImageNet16(root, True , train_transform, 120) | ||||
|     test_data  = ImageNet16(root, False, test_transform , 120) | ||||
|     assert len(train_data) == 151700 and len(test_data) == 6000 | ||||
|   elif name == 'ImageNet16-150': | ||||
|     train_data = ImageNet16(root, True , train_transform, 150) | ||||
|     test_data  = ImageNet16(root, False, test_transform , 150) | ||||
|     assert len(train_data) == 190272 and len(test_data) == 7500 | ||||
|   elif name == 'ImageNet16-200': | ||||
|     train_data = ImageNet16(root, True , train_transform, 200) | ||||
|     test_data  = ImageNet16(root, False, test_transform , 200) | ||||
|     assert len(train_data) == 254775 and len(test_data) == 10000 | ||||
|   else: raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|     if name == "cifar10": | ||||
|         train_data = dset.CIFAR10( | ||||
|             root, train=True, transform=train_transform, download=True | ||||
|         ) | ||||
|         test_data = dset.CIFAR10( | ||||
|             root, train=False, transform=test_transform, download=True | ||||
|         ) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name == "cifar100": | ||||
|         train_data = dset.CIFAR100( | ||||
|             root, train=True, transform=train_transform, download=True | ||||
|         ) | ||||
|         test_data = dset.CIFAR100( | ||||
|             root, train=False, transform=test_transform, download=True | ||||
|         ) | ||||
|         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||
|     elif name.startswith("imagenet-1k"): | ||||
|         train_data = dset.ImageFolder(osp.join(root, "train"), train_transform) | ||||
|         test_data = dset.ImageFolder(osp.join(root, "val"), test_transform) | ||||
|         assert ( | ||||
|             len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|         ), "invalid number of images : {:} & {:} vs {:} & {:}".format( | ||||
|             len(train_data), len(test_data), 1281167, 50000 | ||||
|         ) | ||||
|     elif name == "ImageNet16": | ||||
|         train_data = ImageNet16(root, True, train_transform) | ||||
|         test_data = ImageNet16(root, False, test_transform) | ||||
|         assert len(train_data) == 1281167 and len(test_data) == 50000 | ||||
|     elif name == "ImageNet16-120": | ||||
|         train_data = ImageNet16(root, True, train_transform, 120) | ||||
|         test_data = ImageNet16(root, False, test_transform, 120) | ||||
|         assert len(train_data) == 151700 and len(test_data) == 6000 | ||||
|     elif name == "ImageNet16-150": | ||||
|         train_data = ImageNet16(root, True, train_transform, 150) | ||||
|         test_data = ImageNet16(root, False, test_transform, 150) | ||||
|         assert len(train_data) == 190272 and len(test_data) == 7500 | ||||
|     elif name == "ImageNet16-200": | ||||
|         train_data = ImageNet16(root, True, train_transform, 200) | ||||
|         test_data = ImageNet16(root, False, test_transform, 200) | ||||
|         assert len(train_data) == 254775 and len(test_data) == 10000 | ||||
|     else: | ||||
|         raise TypeError("Unknow dataset : {:}".format(name)) | ||||
|  | ||||
|   class_num = Dataset2Class[name] | ||||
|   return train_data, test_data, xshape, class_num | ||||
|     class_num = Dataset2Class[name] | ||||
|     return train_data, test_data, xshape, class_num | ||||
|  | ||||
|  | ||||
| def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers): | ||||
|   if isinstance(batch_size, (list,tuple)): | ||||
|     batch, test_batch = batch_size | ||||
|   else: | ||||
|     batch, test_batch = batch_size, batch_size | ||||
|   if dataset == 'cifar10': | ||||
|     #split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|     cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None) | ||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set | ||||
|     #logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||
|     # To split data | ||||
|     xvalid_data  = deepcopy(train_data) | ||||
|     if hasattr(xvalid_data, 'transforms'): # to avoid a print issue | ||||
|       xvalid_data.transforms = valid_data.transform | ||||
|     xvalid_data.transform  = deepcopy( valid_data.transform ) | ||||
|     search_data   = SearchDataset(dataset, train_data, train_split, valid_split) | ||||
|     # data loader | ||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True) | ||||
|   elif dataset == 'cifar100': | ||||
|     cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None) | ||||
|     search_train_data = train_data | ||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||
|     search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid) | ||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True) | ||||
|   elif dataset == 'ImageNet16-120': | ||||
|     imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None) | ||||
|     search_train_data = train_data | ||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform | ||||
|     search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) | ||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) | ||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True) | ||||
|   else: | ||||
|     raise ValueError('invalid dataset : {:}'.format(dataset)) | ||||
|   return search_loader, train_loader, valid_loader | ||||
| def get_nas_search_loaders( | ||||
|     train_data, valid_data, dataset, config_root, batch_size, workers | ||||
| ): | ||||
|     if isinstance(batch_size, (list, tuple)): | ||||
|         batch, test_batch = batch_size | ||||
|     else: | ||||
|         batch, test_batch = batch_size, batch_size | ||||
|     if dataset == "cifar10": | ||||
|         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||
|         cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None) | ||||
|         train_split, valid_split = ( | ||||
|             cifar_split.train, | ||||
|             cifar_split.valid, | ||||
|         )  # search over the proposed training and validation set | ||||
|         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||
|         # To split data | ||||
|         xvalid_data = deepcopy(train_data) | ||||
|         if hasattr(xvalid_data, "transforms"):  # to avoid a print issue | ||||
|             xvalid_data.transforms = valid_data.transform | ||||
|         xvalid_data.transform = deepcopy(valid_data.transform) | ||||
|         search_data = SearchDataset(dataset, train_data, train_split, valid_split) | ||||
|         # data loader | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             xvalid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     elif dataset == "cifar100": | ||||
|         cifar100_test_split = load_config( | ||||
|             "{:}/cifar100-test-split.txt".format(config_root), None, None | ||||
|         ) | ||||
|         search_train_data = train_data | ||||
|         search_valid_data = deepcopy(valid_data) | ||||
|         search_valid_data.transform = train_data.transform | ||||
|         search_data = SearchDataset( | ||||
|             dataset, | ||||
|             [search_train_data, search_valid_data], | ||||
|             list(range(len(search_train_data))), | ||||
|             cifar100_test_split.xvalid, | ||||
|         ) | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                 cifar100_test_split.xvalid | ||||
|             ), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     elif dataset == "ImageNet16-120": | ||||
|         imagenet_test_split = load_config( | ||||
|             "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None | ||||
|         ) | ||||
|         search_train_data = train_data | ||||
|         search_valid_data = deepcopy(valid_data) | ||||
|         search_valid_data.transform = train_data.transform | ||||
|         search_data = SearchDataset( | ||||
|             dataset, | ||||
|             [search_train_data, search_valid_data], | ||||
|             list(range(len(search_train_data))), | ||||
|             imagenet_test_split.xvalid, | ||||
|         ) | ||||
|         search_loader = torch.utils.data.DataLoader( | ||||
|             search_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         train_loader = torch.utils.data.DataLoader( | ||||
|             train_data, | ||||
|             batch_size=batch, | ||||
|             shuffle=True, | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|         valid_loader = torch.utils.data.DataLoader( | ||||
|             valid_data, | ||||
|             batch_size=test_batch, | ||||
|             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||
|                 imagenet_test_split.xvalid | ||||
|             ), | ||||
|             num_workers=workers, | ||||
|             pin_memory=True, | ||||
|         ) | ||||
|     else: | ||||
|         raise ValueError("invalid dataset : {:}".format(dataset)) | ||||
|     return search_loader, train_loader, valid_loader | ||||
|  | ||||
| #if __name__ == '__main__': | ||||
|  | ||||
| # if __name__ == '__main__': | ||||
| #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | ||||
| #  import pdb; pdb.set_trace() | ||||
|   | ||||
| @@ -9,108 +9,211 @@ from xvision import normalize_points | ||||
| from xvision import denormalize_points | ||||
|  | ||||
|  | ||||
| class PointMeta(): | ||||
|   # points    : 3 x num_pts (x, y, oculusion) | ||||
|   # image_size: original [width, height] | ||||
|   def __init__(self, num_point, points, box, image_path, dataset_name): | ||||
| class PointMeta: | ||||
|     # points    : 3 x num_pts (x, y, oculusion) | ||||
|     # image_size: original [width, height] | ||||
|     def __init__(self, num_point, points, box, image_path, dataset_name): | ||||
|  | ||||
|     self.num_point = num_point | ||||
|     if box is not None: | ||||
|       assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 | ||||
|       self.box = torch.Tensor(box) | ||||
|     else: self.box = None | ||||
|     if points is None: | ||||
|       self.points = points | ||||
|     else: | ||||
|       assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points ) | ||||
|       self.points = torch.Tensor(points.copy()) | ||||
|     self.image_path = image_path | ||||
|     self.datasets = dataset_name | ||||
|         self.num_point = num_point | ||||
|         if box is not None: | ||||
|             assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 | ||||
|             self.box = torch.Tensor(box) | ||||
|         else: | ||||
|             self.box = None | ||||
|         if points is None: | ||||
|             self.points = points | ||||
|         else: | ||||
|             assert ( | ||||
|                 len(points.shape) == 2 | ||||
|                 and points.shape[0] == 3 | ||||
|                 and points.shape[1] == self.num_point | ||||
|             ), "The shape of point is not right : {}".format(points) | ||||
|             self.points = torch.Tensor(points.copy()) | ||||
|         self.image_path = image_path | ||||
|         self.datasets = dataset_name | ||||
|  | ||||
|   def __repr__(self): | ||||
|     if self.box is None: boxstr = 'None' | ||||
|     else               : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist()) | ||||
|     return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')') | ||||
|     def __repr__(self): | ||||
|         if self.box is None: | ||||
|             boxstr = "None" | ||||
|         else: | ||||
|             boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist()) | ||||
|         return ( | ||||
|             "{name}(points={num_point}, ".format( | ||||
|                 name=self.__class__.__name__, **self.__dict__ | ||||
|             ) | ||||
|             + boxstr | ||||
|             + ")" | ||||
|         ) | ||||
|  | ||||
|   def get_box(self, return_diagonal=False): | ||||
|     if self.box is None: return None | ||||
|     if not return_diagonal: | ||||
|       return self.box.clone() | ||||
|     else: | ||||
|       W = (self.box[2]-self.box[0]).item() | ||||
|       H = (self.box[3]-self.box[1]).item() | ||||
|       return math.sqrt(H*H+W*W) | ||||
|     def get_box(self, return_diagonal=False): | ||||
|         if self.box is None: | ||||
|             return None | ||||
|         if not return_diagonal: | ||||
|             return self.box.clone() | ||||
|         else: | ||||
|             W = (self.box[2] - self.box[0]).item() | ||||
|             H = (self.box[3] - self.box[1]).item() | ||||
|             return math.sqrt(H * H + W * W) | ||||
|  | ||||
|   def get_points(self, ignore_indicator=False): | ||||
|     if ignore_indicator: last = 2 | ||||
|     else               : last = 3 | ||||
|     if self.points is not None: return self.points.clone()[:last, :] | ||||
|     else                      : return torch.zeros((last, self.num_point)) | ||||
|     def get_points(self, ignore_indicator=False): | ||||
|         if ignore_indicator: | ||||
|             last = 2 | ||||
|         else: | ||||
|             last = 3 | ||||
|         if self.points is not None: | ||||
|             return self.points.clone()[:last, :] | ||||
|         else: | ||||
|             return torch.zeros((last, self.num_point)) | ||||
|  | ||||
|   def is_none(self): | ||||
|     #assert self.box is not None, 'The box should not be None' | ||||
|     return self.points is None | ||||
|     #if self.box is None: return True | ||||
|     #else               : return self.points is None | ||||
|     def is_none(self): | ||||
|         # assert self.box is not None, 'The box should not be None' | ||||
|         return self.points is None | ||||
|         # if self.box is None: return True | ||||
|         # else               : return self.points is None | ||||
|  | ||||
|   def copy(self): | ||||
|     return copy.deepcopy(self) | ||||
|     def copy(self): | ||||
|         return copy.deepcopy(self) | ||||
|  | ||||
|   def visiable_pts_num(self): | ||||
|     with torch.no_grad(): | ||||
|       ans = self.points[2,:] > 0 | ||||
|       ans = torch.sum(ans) | ||||
|       ans = ans.item() | ||||
|     return ans | ||||
|     def visiable_pts_num(self): | ||||
|         with torch.no_grad(): | ||||
|             ans = self.points[2, :] > 0 | ||||
|             ans = torch.sum(ans) | ||||
|             ans = ans.item() | ||||
|         return ans | ||||
|  | ||||
|   def special_fun(self, indicator): | ||||
|     if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points. | ||||
|       assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point) | ||||
|       self.num_point = 49 | ||||
|       out = torch.ones((68), dtype=torch.uint8) | ||||
|       out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0 | ||||
|       if self.points is not None: self.points = self.points.clone()[:, out] | ||||
|     else: | ||||
|       raise ValueError('Invalid indicator : {:}'.format( indicator )) | ||||
|  | ||||
|   def apply_horizontal_flip(self): | ||||
|     #self.points[0, :] = width - self.points[0, :] - 1 | ||||
|     # Mugsy spefic or Synthetic | ||||
|     if self.datasets.startswith('HandsyROT'): | ||||
|       ori = np.array(list(range(0, 42))) | ||||
|       pos = np.array(list(range(21,42)) + list(range(0,21))) | ||||
|       self.points[:, pos] = self.points[:, ori] | ||||
|     elif self.datasets.startswith('face68'): | ||||
|       ori = np.array(list(range(0, 68))) | ||||
|       pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1 | ||||
|       self.points[:, ori] = self.points[:, pos] | ||||
|     else: | ||||
|       raise ValueError('Does not support {:}'.format(self.datasets)) | ||||
|     def special_fun(self, indicator): | ||||
|         if ( | ||||
|             indicator == "68to49" | ||||
|         ):  # For 300W or 300VW, convert the default 68 points to 49 points. | ||||
|             assert self.num_point == 68, "num-point must be 68 vs. {:}".format( | ||||
|                 self.num_point | ||||
|             ) | ||||
|             self.num_point = 49 | ||||
|             out = torch.ones((68), dtype=torch.uint8) | ||||
|             out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0 | ||||
|             if self.points is not None: | ||||
|                 self.points = self.points.clone()[:, out] | ||||
|         else: | ||||
|             raise ValueError("Invalid indicator : {:}".format(indicator)) | ||||
|  | ||||
|     def apply_horizontal_flip(self): | ||||
|         # self.points[0, :] = width - self.points[0, :] - 1 | ||||
|         # Mugsy spefic or Synthetic | ||||
|         if self.datasets.startswith("HandsyROT"): | ||||
|             ori = np.array(list(range(0, 42))) | ||||
|             pos = np.array(list(range(21, 42)) + list(range(0, 21))) | ||||
|             self.points[:, pos] = self.points[:, ori] | ||||
|         elif self.datasets.startswith("face68"): | ||||
|             ori = np.array(list(range(0, 68))) | ||||
|             pos = ( | ||||
|                 np.array( | ||||
|                     [ | ||||
|                         17, | ||||
|                         16, | ||||
|                         15, | ||||
|                         14, | ||||
|                         13, | ||||
|                         12, | ||||
|                         11, | ||||
|                         10, | ||||
|                         9, | ||||
|                         8, | ||||
|                         7, | ||||
|                         6, | ||||
|                         5, | ||||
|                         4, | ||||
|                         3, | ||||
|                         2, | ||||
|                         1, | ||||
|                         27, | ||||
|                         26, | ||||
|                         25, | ||||
|                         24, | ||||
|                         23, | ||||
|                         22, | ||||
|                         21, | ||||
|                         20, | ||||
|                         19, | ||||
|                         18, | ||||
|                         28, | ||||
|                         29, | ||||
|                         30, | ||||
|                         31, | ||||
|                         36, | ||||
|                         35, | ||||
|                         34, | ||||
|                         33, | ||||
|                         32, | ||||
|                         46, | ||||
|                         45, | ||||
|                         44, | ||||
|                         43, | ||||
|                         48, | ||||
|                         47, | ||||
|                         40, | ||||
|                         39, | ||||
|                         38, | ||||
|                         37, | ||||
|                         42, | ||||
|                         41, | ||||
|                         55, | ||||
|                         54, | ||||
|                         53, | ||||
|                         52, | ||||
|                         51, | ||||
|                         50, | ||||
|                         49, | ||||
|                         60, | ||||
|                         59, | ||||
|                         58, | ||||
|                         57, | ||||
|                         56, | ||||
|                         65, | ||||
|                         64, | ||||
|                         63, | ||||
|                         62, | ||||
|                         61, | ||||
|                         68, | ||||
|                         67, | ||||
|                         66, | ||||
|                     ] | ||||
|                 ) | ||||
|                 - 1 | ||||
|             ) | ||||
|             self.points[:, ori] = self.points[:, pos] | ||||
|         else: | ||||
|             raise ValueError("Does not support {:}".format(self.datasets)) | ||||
|  | ||||
|  | ||||
| # shape = (H,W) | ||||
| def apply_affine2point(points, theta, shape): | ||||
|   assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size()) | ||||
|   with torch.no_grad(): | ||||
|     ok_points = points[2,:] == 1 | ||||
|     assert torch.sum(ok_points).item() > 0, 'there is no visiable point' | ||||
|     points[:2,:] = normalize_points(shape, points[:2,:]) | ||||
|     assert points.size(0) == 3, "invalid points shape : {:}".format(points.size()) | ||||
|     with torch.no_grad(): | ||||
|         ok_points = points[2, :] == 1 | ||||
|         assert torch.sum(ok_points).item() > 0, "there is no visiable point" | ||||
|         points[:2, :] = normalize_points(shape, points[:2, :]) | ||||
|  | ||||
|     norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() | ||||
|         norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() | ||||
|  | ||||
|     trans_points, ___ = torch.gesv(points[:, ok_points], theta) | ||||
|         trans_points, ___ = torch.gesv(points[:, ok_points], theta) | ||||
|  | ||||
|     norm_trans_points[:, ok_points] = trans_points | ||||
|      | ||||
|   return norm_trans_points | ||||
|         norm_trans_points[:, ok_points] = trans_points | ||||
|  | ||||
|     return norm_trans_points | ||||
|  | ||||
|  | ||||
| def apply_boundary(norm_trans_points): | ||||
|   with torch.no_grad(): | ||||
|     norm_trans_points = norm_trans_points.clone() | ||||
|     oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0)) | ||||
|     oks = torch.sum(oks, dim=0) == 5 | ||||
|     norm_trans_points[2, :] = oks | ||||
|   return norm_trans_points | ||||
|     with torch.no_grad(): | ||||
|         norm_trans_points = norm_trans_points.clone() | ||||
|         oks = torch.stack( | ||||
|             ( | ||||
|                 norm_trans_points[0] > -1, | ||||
|                 norm_trans_points[0] < 1, | ||||
|                 norm_trans_points[1] > -1, | ||||
|                 norm_trans_points[1] < 1, | ||||
|                 norm_trans_points[2] > 0, | ||||
|             ) | ||||
|         ) | ||||
|         oks = torch.sum(oks, dim=0) == 5 | ||||
|         norm_trans_points[2, :] = oks | ||||
|     return norm_trans_points | ||||
|   | ||||
| @@ -1,39 +1,123 @@ | ||||
| ##################################################### | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||
| ##################################################### | ||||
| import math | ||||
| import numpy as np | ||||
| from typing import Optional | ||||
| import torch | ||||
| import torch.utils.data as data | ||||
|  | ||||
|  | ||||
| class QuadraticFunction: | ||||
|     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||
|  | ||||
|     def __init__(self, list_of_points=None): | ||||
|         self._params = dict(a=None, b=None, c=None) | ||||
|         if list_of_points is not None: | ||||
|             self.fit(list_of_points) | ||||
|  | ||||
|     def set(self, a, b, c): | ||||
|         self._params["a"] = a | ||||
|         self._params["b"] = b | ||||
|         self._params["c"] = c | ||||
|  | ||||
|     def check_valid(self): | ||||
|         for key, value in self._params.items(): | ||||
|             if value is None: | ||||
|                 raise ValueError("The {:} is None".format(key)) | ||||
|  | ||||
|     def __getitem__(self, x): | ||||
|         self.check_valid() | ||||
|         return self._params["a"] * x * x + self._params["b"] * x + self._params["c"] | ||||
|  | ||||
|     def fit( | ||||
|         self, | ||||
|         list_of_points, | ||||
|         transf=lambda x: x, | ||||
|         max_iter=900, | ||||
|         lr_max=1.0, | ||||
|         verbose=False, | ||||
|     ): | ||||
|         with torch.no_grad(): | ||||
|             data = torch.Tensor(list_of_points).type(torch.float32) | ||||
|             assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( | ||||
|                 data.shape | ||||
|             ) | ||||
|             x, y = data[:, 0], data[:, 1] | ||||
|         weights = torch.nn.Parameter(torch.Tensor(3)) | ||||
|         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | ||||
|         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | ||||
|         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1) | ||||
|         if verbose: | ||||
|             print("The optimizer: {:}".format(optimizer)) | ||||
|  | ||||
|         best_loss = None | ||||
|         for _iter in range(max_iter): | ||||
|             y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2]) | ||||
|             loss = torch.mean(torch.abs(y - y_hat)) | ||||
|             optimizer.zero_grad() | ||||
|             loss.backward() | ||||
|             optimizer.step() | ||||
|             lr_scheduler.step() | ||||
|             if verbose: | ||||
|                 print( | ||||
|                     "In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||
|                         _iter, max_iter, loss.item() | ||||
|                     ) | ||||
|                 ) | ||||
|             # Update the params | ||||
|             if best_loss is None or best_loss > loss.item(): | ||||
|                 best_loss = loss.item() | ||||
|                 self._params["a"] = weights[0].item() | ||||
|                 self._params["b"] = weights[1].item() | ||||
|                 self._params["c"] = weights[2].item() | ||||
|  | ||||
|     def __repr__(self): | ||||
|         return "{name}(y = {a} * x^2 + {b} * x + {c})".format( | ||||
|             name=self.__class__.__name__, | ||||
|             a=self._params["a"], | ||||
|             b=self._params["b"], | ||||
|             c=self._params["c"], | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class SynAdaptiveEnv(data.Dataset): | ||||
|     """The synethtic dataset for adaptive environment.""" | ||||
|     """The synethtic dataset for adaptive environment. | ||||
|  | ||||
|     - x in [0, 1] | ||||
|     - y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) | ||||
|     - where | ||||
|     - the amplitude scale is a quadratic function of x | ||||
|     - the period-phase-shift is another quadratic function of x | ||||
|  | ||||
|     """ | ||||
|  | ||||
|     def __init__( | ||||
|         self, | ||||
|         max_num_phase: int = 100, | ||||
|         interval: float = 0.1, | ||||
|         max_scale: float = 4, | ||||
|         offset_scale: float = 1.5, | ||||
|         num: int = 100, | ||||
|         num_sin_phase: int = 4, | ||||
|         min_amplitude: float = 1, | ||||
|         max_amplitude: float = 4, | ||||
|         phase_shift: float = 0, | ||||
|         mode: Optional[str] = None, | ||||
|     ): | ||||
|         self._amplitude_scale = QuadraticFunction( | ||||
|             [(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)] | ||||
|         ) | ||||
|  | ||||
|         self._max_num_phase = max_num_phase | ||||
|         self._interval = interval | ||||
|         self._num_sin_phase = num_sin_phase | ||||
|         self._interval = 1.0 / (float(num) - 1) | ||||
|         self._total_num = num | ||||
|  | ||||
|         self._period_phase_shift = QuadraticFunction() | ||||
|  | ||||
|         fitting_data = [] | ||||
|         temp_max_scalar = 2 ** num_sin_phase | ||||
|         for i in range(num_sin_phase): | ||||
|             value = (2 ** i) / temp_max_scalar | ||||
|             fitting_data.append((value, math.sin(value))) | ||||
|         self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x)) | ||||
|  | ||||
|         self._times = np.arange(0, np.pi * self._max_num_phase, self._interval) | ||||
|         xmin, xmax = self._times.min(), self._times.max() | ||||
|         self._inputs = [] | ||||
|         self._total_num = len(self._times) | ||||
|         for i in range(self._total_num): | ||||
|             scale = (i + 1.0) / self._total_num * max_scale | ||||
|             sin_scale = (i + 1.0) / self._total_num * 0.7 | ||||
|             sin_scale = -4 * (sin_scale - 0.5) ** 2 + 1 | ||||
|             # scale = -(self._times[i] - (xmin - xmax) / 2) + max_scale | ||||
|             self._inputs.append( | ||||
|                 np.sin(self._times[i] * sin_scale) * (offset_scale - scale) | ||||
|             ) | ||||
|         self._inputs = np.array(self._inputs) | ||||
|         # Training Set 60% | ||||
|         num_of_train = int(self._total_num * 0.6) | ||||
|         # Validation Set 20% | ||||
| @@ -70,10 +154,11 @@ class SynAdaptiveEnv(data.Dataset): | ||||
|     def __getitem__(self, index): | ||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||
|         index = self._indexes[index] | ||||
|         value = float(self._inputs[index]) | ||||
|         if self._transform is not None: | ||||
|             value = self._transform(value) | ||||
|         return index, float(self._times[index]), value | ||||
|         position = self._interval * index | ||||
|         value = self._amplitude_scale[position] * math.sin( | ||||
|             self._period_phase_shift[position] | ||||
|         ) | ||||
|         return index, position, value | ||||
|  | ||||
|     def __len__(self): | ||||
|         return len(self._indexes) | ||||
|   | ||||
| @@ -5,16 +5,20 @@ import os | ||||
|  | ||||
|  | ||||
| def test_imagenet_data(imagenet): | ||||
|   total_length = len(imagenet) | ||||
|   assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) | ||||
|   map_id = {} | ||||
|   for index in range(total_length): | ||||
|     path, target = imagenet.imgs[index] | ||||
|     folder, image_name = os.path.split(path) | ||||
|     _, folder = os.path.split(folder) | ||||
|     if folder not in map_id: | ||||
|       map_id[folder] = target | ||||
|     else: | ||||
|       assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) | ||||
|     assert image_name.find(folder) == 0, '{} is wrong.'.format(path) | ||||
|   print ('Check ImageNet Dataset OK') | ||||
|     total_length = len(imagenet) | ||||
|     assert ( | ||||
|         total_length == 1281166 or total_length == 50000 | ||||
|     ), "The length of ImageNet is wrong : {}".format(total_length) | ||||
|     map_id = {} | ||||
|     for index in range(total_length): | ||||
|         path, target = imagenet.imgs[index] | ||||
|         folder, image_name = os.path.split(path) | ||||
|         _, folder = os.path.split(folder) | ||||
|         if folder not in map_id: | ||||
|             map_id[folder] = target | ||||
|         else: | ||||
|             assert map_id[folder] == target, "Class : {} is not {}".format( | ||||
|                 folder, target | ||||
|             ) | ||||
|         assert image_name.find(folder) == 0, "{} is wrong.".format(path) | ||||
|     print("Check ImageNet Dataset OK") | ||||
|   | ||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -13,9 +13,33 @@ print("library path: {:}".format(lib_dir)) | ||||
| if str(lib_dir) not in sys.path: | ||||
|     sys.path.insert(0, str(lib_dir)) | ||||
|  | ||||
| from datasets import QuadraticFunction | ||||
| from datasets import SynAdaptiveEnv | ||||
|  | ||||
|  | ||||
| class TestQuadraticFunction(unittest.TestCase): | ||||
|     """Test the quadratic function.""" | ||||
|  | ||||
|     def test_simple(self): | ||||
|         function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]]) | ||||
|         print(function) | ||||
|         for x in (0, 0.5, 1): | ||||
|             print("f({:})={:}".format(x, function[x])) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function[0] - 1) < thresh) | ||||
|         self.assertTrue(abs(function[0.5] - 4) < thresh) | ||||
|         self.assertTrue(abs(function[1] - 1) < thresh) | ||||
|  | ||||
|     def test_none(self): | ||||
|         function = QuadraticFunction() | ||||
|         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) | ||||
|         print(function) | ||||
|         thresh = 0.2 | ||||
|         self.assertTrue(abs(function[0] - 1) < thresh) | ||||
|         self.assertTrue(abs(function[0.5] - 4) < thresh) | ||||
|         self.assertTrue(abs(function[1] - 1) < thresh) | ||||
|  | ||||
|  | ||||
| class TestSynAdaptiveEnv(unittest.TestCase): | ||||
|     """Test the synethtic adaptive environment.""" | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user