Updates
This commit is contained in:
		
							
								
								
									
										8
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -48,5 +48,13 @@ jobs: | |||||||
|           ls |           ls | ||||||
|           python --version |           python --version | ||||||
|           python -m pytest ./tests/test_basic_space.py -s |           python -m pytest ./tests/test_basic_space.py -s | ||||||
|  |         shell: bash | ||||||
|  |  | ||||||
|  |       - name: Test Synthetic Data | ||||||
|  |         run: | | ||||||
|  |           python -m pip install pytest numpy | ||||||
|  |           python -m pip install parameterized | ||||||
|  |           python -m pip install torch | ||||||
|  |           python --version | ||||||
|           python -m pytest ./tests/test_synthetic.py -s |           python -m pytest ./tests/test_synthetic.py -s | ||||||
|         shell: bash |         shell: bash | ||||||
|   | |||||||
 Submodule .latent-data/NATS-Bench updated: 3a8794322f...33bfb2eb13
									
								
							| @@ -5,118 +5,133 @@ import os, sys, hashlib, torch | |||||||
| import numpy as np | import numpy as np | ||||||
| from PIL import Image | from PIL import Image | ||||||
| import torch.utils.data as data | import torch.utils.data as data | ||||||
|  |  | ||||||
| if sys.version_info[0] == 2: | if sys.version_info[0] == 2: | ||||||
|   import cPickle as pickle |     import cPickle as pickle | ||||||
| else: | else: | ||||||
|   import pickle |     import pickle | ||||||
|  |  | ||||||
|  |  | ||||||
| def calculate_md5(fpath, chunk_size=1024 * 1024): | def calculate_md5(fpath, chunk_size=1024 * 1024): | ||||||
|   md5 = hashlib.md5() |     md5 = hashlib.md5() | ||||||
|   with open(fpath, 'rb') as f: |     with open(fpath, "rb") as f: | ||||||
|     for chunk in iter(lambda: f.read(chunk_size), b''): |         for chunk in iter(lambda: f.read(chunk_size), b""): | ||||||
|       md5.update(chunk) |             md5.update(chunk) | ||||||
|   return md5.hexdigest() |     return md5.hexdigest() | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_md5(fpath, md5, **kwargs): | def check_md5(fpath, md5, **kwargs): | ||||||
|   return md5 == calculate_md5(fpath, **kwargs) |     return md5 == calculate_md5(fpath, **kwargs) | ||||||
|  |  | ||||||
|  |  | ||||||
| def check_integrity(fpath, md5=None): | def check_integrity(fpath, md5=None): | ||||||
|   if not os.path.isfile(fpath): return False |     if not os.path.isfile(fpath): | ||||||
|   if md5 is None: return True |         return False | ||||||
|   else          : return check_md5(fpath, md5) |     if md5 is None: | ||||||
|  |         return True | ||||||
|  |     else: | ||||||
|  |         return check_md5(fpath, md5) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ImageNet16(data.Dataset): | class ImageNet16(data.Dataset): | ||||||
|   # http://image-net.org/download-images |     # http://image-net.org/download-images | ||||||
|   # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets |     # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets | ||||||
|   # https://arxiv.org/pdf/1707.08819.pdf |     # https://arxiv.org/pdf/1707.08819.pdf | ||||||
|    |  | ||||||
|   train_list = [ |     train_list = [ | ||||||
|         ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], |         ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"], | ||||||
|         ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], |         ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"], | ||||||
|         ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], |         ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"], | ||||||
|         ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], |         ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"], | ||||||
|         ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], |         ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"], | ||||||
|         ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], |         ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"], | ||||||
|         ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], |         ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"], | ||||||
|         ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], |         ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"], | ||||||
|         ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], |         ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"], | ||||||
|         ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], |         ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"], | ||||||
|     ] |     ] | ||||||
|   valid_list = [ |     valid_list = [ | ||||||
|         ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], |         ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"], | ||||||
|     ] |     ] | ||||||
|  |  | ||||||
|   def __init__(self, root, train, transform, use_num_of_class_only=None): |     def __init__(self, root, train, transform, use_num_of_class_only=None): | ||||||
|     self.root      = root |         self.root = root | ||||||
|     self.transform = transform |         self.transform = transform | ||||||
|     self.train     = train  # training set or valid set |         self.train = train  # training set or valid set | ||||||
|     if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.') |         if not self._check_integrity(): | ||||||
|  |             raise RuntimeError("Dataset not found or corrupted.") | ||||||
|  |  | ||||||
|     if self.train: downloaded_list = self.train_list |         if self.train: | ||||||
|     else         : downloaded_list = self.valid_list |             downloaded_list = self.train_list | ||||||
|     self.data    = [] |  | ||||||
|     self.targets = [] |  | ||||||
|    |  | ||||||
|     # now load the picked numpy arrays |  | ||||||
|     for i, (file_name, checksum) in enumerate(downloaded_list): |  | ||||||
|       file_path = os.path.join(self.root, file_name) |  | ||||||
|       #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) |  | ||||||
|       with open(file_path, 'rb') as f: |  | ||||||
|         if sys.version_info[0] == 2: |  | ||||||
|           entry = pickle.load(f) |  | ||||||
|         else: |         else: | ||||||
|           entry = pickle.load(f, encoding='latin1') |             downloaded_list = self.valid_list | ||||||
|         self.data.append(entry['data']) |         self.data = [] | ||||||
|         self.targets.extend(entry['labels']) |         self.targets = [] | ||||||
|     self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) |  | ||||||
|     self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC |  | ||||||
|     if use_num_of_class_only is not None: |  | ||||||
|       assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) |  | ||||||
|       new_data, new_targets = [], [] |  | ||||||
|       for I, L in zip(self.data, self.targets): |  | ||||||
|         if 1 <= L <= use_num_of_class_only: |  | ||||||
|           new_data.append( I ) |  | ||||||
|           new_targets.append( L ) |  | ||||||
|       self.data    = new_data |  | ||||||
|       self.targets = new_targets |  | ||||||
|     #    self.mean.append(entry['mean']) |  | ||||||
|     #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) |  | ||||||
|     #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) |  | ||||||
|     #print ('Mean : {:}'.format(self.mean)) |  | ||||||
|     #temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3)) |  | ||||||
|     #std_data  = np.std(temp, axis=0) |  | ||||||
|     #std_data  = np.mean(np.mean(std_data, axis=0), axis=0) |  | ||||||
|     #print ('Std  : {:}'.format(std_data)) |  | ||||||
|  |  | ||||||
|   def __repr__(self): |         # now load the picked numpy arrays | ||||||
|     return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets)))) |         for i, (file_name, checksum) in enumerate(downloaded_list): | ||||||
|  |             file_path = os.path.join(self.root, file_name) | ||||||
|  |             # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) | ||||||
|  |             with open(file_path, "rb") as f: | ||||||
|  |                 if sys.version_info[0] == 2: | ||||||
|  |                     entry = pickle.load(f) | ||||||
|  |                 else: | ||||||
|  |                     entry = pickle.load(f, encoding="latin1") | ||||||
|  |                 self.data.append(entry["data"]) | ||||||
|  |                 self.targets.extend(entry["labels"]) | ||||||
|  |         self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) | ||||||
|  |         self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC | ||||||
|  |         if use_num_of_class_only is not None: | ||||||
|  |             assert ( | ||||||
|  |                 isinstance(use_num_of_class_only, int) | ||||||
|  |                 and use_num_of_class_only > 0 | ||||||
|  |                 and use_num_of_class_only < 1000 | ||||||
|  |             ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only) | ||||||
|  |             new_data, new_targets = [], [] | ||||||
|  |             for I, L in zip(self.data, self.targets): | ||||||
|  |                 if 1 <= L <= use_num_of_class_only: | ||||||
|  |                     new_data.append(I) | ||||||
|  |                     new_targets.append(L) | ||||||
|  |             self.data = new_data | ||||||
|  |             self.targets = new_targets | ||||||
|  |         #    self.mean.append(entry['mean']) | ||||||
|  |         # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) | ||||||
|  |         # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) | ||||||
|  |         # print ('Mean : {:}'.format(self.mean)) | ||||||
|  |         # temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3)) | ||||||
|  |         # std_data  = np.std(temp, axis=0) | ||||||
|  |         # std_data  = np.mean(np.mean(std_data, axis=0), axis=0) | ||||||
|  |         # print ('Std  : {:}'.format(std_data)) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({num} images, {classes} classes)".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             num=len(self.data), | ||||||
|  |             classes=len(set(self.targets)), | ||||||
|  |         ) | ||||||
|  |  | ||||||
|   def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|     img, target = self.data[index], self.targets[index] - 1 |         img, target = self.data[index], self.targets[index] - 1 | ||||||
|  |  | ||||||
|     img = Image.fromarray(img) |         img = Image.fromarray(img) | ||||||
|  |  | ||||||
|     if self.transform is not None: |         if self.transform is not None: | ||||||
|       img = self.transform(img) |             img = self.transform(img) | ||||||
|  |  | ||||||
|     return img, target |         return img, target | ||||||
|  |  | ||||||
|   def __len__(self): |     def __len__(self): | ||||||
|     return len(self.data) |         return len(self.data) | ||||||
|  |  | ||||||
|  |     def _check_integrity(self): | ||||||
|  |         root = self.root | ||||||
|  |         for fentry in self.train_list + self.valid_list: | ||||||
|  |             filename, md5 = fentry[0], fentry[1] | ||||||
|  |             fpath = os.path.join(root, filename) | ||||||
|  |             if not check_integrity(fpath, md5): | ||||||
|  |                 return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|   def _check_integrity(self): |  | ||||||
|     root = self.root |  | ||||||
|     for fentry in (self.train_list + self.valid_list): |  | ||||||
|       filename, md5 = fentry[0], fentry[1] |  | ||||||
|       fpath = os.path.join(root, filename) |  | ||||||
|       if not check_integrity(fpath, md5): |  | ||||||
|         return False |  | ||||||
|     return True |  | ||||||
|  |  | ||||||
| """ | """ | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|   | |||||||
| @@ -20,172 +20,282 @@ import torch.utils.data as data | |||||||
|  |  | ||||||
|  |  | ||||||
| class LandmarkDataset(data.Dataset): | class LandmarkDataset(data.Dataset): | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|  |         transform, | ||||||
|  |         sigma, | ||||||
|  |         downsample, | ||||||
|  |         heatmap_type, | ||||||
|  |         shape, | ||||||
|  |         use_gray, | ||||||
|  |         mean_file, | ||||||
|  |         data_indicator, | ||||||
|  |         cache_images=None, | ||||||
|  |     ): | ||||||
|  |  | ||||||
|   def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None): |         self.transform = transform | ||||||
|  |         self.sigma = sigma | ||||||
|     self.transform    = transform |         self.downsample = downsample | ||||||
|     self.sigma        = sigma |         self.heatmap_type = heatmap_type | ||||||
|     self.downsample   = downsample |         self.dataset_name = data_indicator | ||||||
|     self.heatmap_type = heatmap_type |         self.shape = shape  # [H,W] | ||||||
|     self.dataset_name = data_indicator |         self.use_gray = use_gray | ||||||
|     self.shape        = shape # [H,W] |         assert transform is not None, "transform : {:}".format(transform) | ||||||
|     self.use_gray     = use_gray |         self.mean_file = mean_file | ||||||
|     assert transform is not None, 'transform : {:}'.format(transform) |         if mean_file is None: | ||||||
|     self.mean_file    = mean_file |             self.mean_data = None | ||||||
|     if mean_file is None: |             warnings.warn("LandmarkDataset initialized with mean_data = None") | ||||||
|       self.mean_data  = None |  | ||||||
|       warnings.warn('LandmarkDataset initialized with mean_data = None') |  | ||||||
|     else: |  | ||||||
|       assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file) |  | ||||||
|       self.mean_data  = torch.load(mean_file) |  | ||||||
|     self.reset() |  | ||||||
|     self.cutout       = None |  | ||||||
|     self.cache_images = cache_images |  | ||||||
|     print ('The general dataset initialization done : {:}'.format(self)) |  | ||||||
|     warnings.simplefilter( 'once' ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def __repr__(self): |  | ||||||
|     return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def set_cutout(self, length): |  | ||||||
|     if length is not None and length >= 1: |  | ||||||
|       self.cutout = CutOut( int(length) ) |  | ||||||
|     else: self.cutout = None |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def reset(self, num_pts=-1, boxid='default', only_pts=False): |  | ||||||
|     self.NUM_PTS = num_pts |  | ||||||
|     if only_pts: return |  | ||||||
|     self.length  = 0 |  | ||||||
|     self.datas   = [] |  | ||||||
|     self.labels  = [] |  | ||||||
|     self.NormDistances = [] |  | ||||||
|     self.BOXID = boxid |  | ||||||
|     if self.mean_data is None: |  | ||||||
|       self.mean_face = None |  | ||||||
|     else: |  | ||||||
|       self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T) |  | ||||||
|       assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face) |  | ||||||
|     #assert self.dataset_name is not None, 'The dataset name is None' |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def __len__(self): |  | ||||||
|     assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length) |  | ||||||
|     return self.length |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def append(self, data, label, distance): |  | ||||||
|     assert osp.isfile(data), 'The image path is not a file : {:}'.format(data) |  | ||||||
|     self.datas.append( data )             ;  self.labels.append( label ) |  | ||||||
|     self.NormDistances.append( distance ) |  | ||||||
|     self.length = self.length + 1 |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset): |  | ||||||
|     if reset: self.reset(num_pts, boxindicator) |  | ||||||
|     else    : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts) |  | ||||||
|     if isinstance(file_lists, str): file_lists = [file_lists] |  | ||||||
|     samples = [] |  | ||||||
|     for idx, file_path in enumerate(file_lists): |  | ||||||
|       print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path)) |  | ||||||
|       xdata = torch.load(file_path) |  | ||||||
|       if isinstance(xdata, list)  : data = xdata          # image or video dataset list |  | ||||||
|       elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list |  | ||||||
|       else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) )) |  | ||||||
|       samples = samples + data |  | ||||||
|     # samples is a dict, where the key is the image-path and the value is the annotation |  | ||||||
|     # each annotation is a dict, contains 'points' (3,num_pts), and various box |  | ||||||
|     print ('GeneralDataset-V2 : {:} samples'.format(len(samples))) |  | ||||||
|  |  | ||||||
|     #for index, annotation in enumerate(samples): |  | ||||||
|     for index in tqdm( range( len(samples) ) ): |  | ||||||
|       annotation = samples[index] |  | ||||||
|       image_path  = annotation['current_frame'] |  | ||||||
|       points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)] |  | ||||||
|       label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name) |  | ||||||
|       if normalizeL is None: normDistance = None |  | ||||||
|       else                 : normDistance = annotation['normalizeL-{:}'.format(normalizeL)] |  | ||||||
|       self.append(image_path, label, normDistance) |  | ||||||
|  |  | ||||||
|     assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas)) |  | ||||||
|     assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels)) |  | ||||||
|     assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistance)) |  | ||||||
|     print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length)) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def __getitem__(self, index): |  | ||||||
|     assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index) |  | ||||||
|     if self.cache_images is not None and self.datas[index] in self.cache_images: |  | ||||||
|       image = self.cache_images[ self.datas[index] ].clone() |  | ||||||
|     else: |  | ||||||
|       image = pil_loader(self.datas[index], self.use_gray) |  | ||||||
|     target = self.labels[index].copy() |  | ||||||
|     return self._process_(image, target, index) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|   def _process_(self, image, target, index): |  | ||||||
|  |  | ||||||
|     # transform the image and points |  | ||||||
|     image, target, theta = self.transform(image, target) |  | ||||||
|     (C, H, W), (height, width) = image.size(), self.shape |  | ||||||
|  |  | ||||||
|     # obtain the visiable indicator vector |  | ||||||
|     if target.is_none(): nopoints = True |  | ||||||
|     else               : nopoints = False |  | ||||||
|     if index == -1: __path = None |  | ||||||
|     else          : __path = self.datas[index] |  | ||||||
|     if isinstance(theta, list) or isinstance(theta, tuple): |  | ||||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], [] |  | ||||||
|       for _theta in theta: |  | ||||||
|         _affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \ |  | ||||||
|           = self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path)) |  | ||||||
|         affineImage.append(_affineImage) |  | ||||||
|         heatmaps.append(_heatmaps) |  | ||||||
|         mask.append(_mask) |  | ||||||
|         norm_trans_points.append(_norm_trans_points) |  | ||||||
|         THETA.append(_theta) |  | ||||||
|         transpose_theta.append(_transpose_theta) |  | ||||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \ |  | ||||||
|           torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta) |  | ||||||
|     else: |  | ||||||
|       affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path)) |  | ||||||
|  |  | ||||||
|     torch_index = torch.IntTensor([index]) |  | ||||||
|     torch_nopoints = torch.ByteTensor( [ nopoints ] ) |  | ||||||
|     torch_shape = torch.IntTensor([H,W]) |  | ||||||
|  |  | ||||||
|     return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape |  | ||||||
|  |  | ||||||
|    |  | ||||||
|   def __process_affine(self, image, target, theta, nopoints, aux_info=None): |  | ||||||
|     image, target, theta = image.clone(), target.copy(), theta.clone() |  | ||||||
|     (C, H, W), (height, width) = image.size(), self.shape |  | ||||||
|     if nopoints: # do not have label |  | ||||||
|       norm_trans_points = torch.zeros((3, self.NUM_PTS)) |  | ||||||
|       heatmaps          = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample)) |  | ||||||
|       mask              = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8) |  | ||||||
|       transpose_theta   = identity2affine(False) |  | ||||||
|     else: |  | ||||||
|       norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W)) |  | ||||||
|       norm_trans_points = apply_boundary(norm_trans_points) |  | ||||||
|       real_trans_points = norm_trans_points.clone() |  | ||||||
|       real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:]) |  | ||||||
|       heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C |  | ||||||
|       heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor) |  | ||||||
|       mask     = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) |  | ||||||
|       if self.mean_face is None: |  | ||||||
|         #warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.') |  | ||||||
|         transpose_theta = identity2affine(False) |  | ||||||
|       else: |  | ||||||
|         if torch.sum(norm_trans_points[2,:] == 1) < 3: |  | ||||||
|           warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info)) |  | ||||||
|           transpose_theta = identity2affine(False) |  | ||||||
|         else: |         else: | ||||||
|           transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone()) |             assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file) | ||||||
|  |             self.mean_data = torch.load(mean_file) | ||||||
|  |         self.reset() | ||||||
|  |         self.cutout = None | ||||||
|  |         self.cache_images = cache_images | ||||||
|  |         print("The general dataset initialization done : {:}".format(self)) | ||||||
|  |         warnings.simplefilter("once") | ||||||
|  |  | ||||||
|     affineImage = affine2image(image, theta, self.shape) |     def __repr__(self): | ||||||
|     if self.cutout is not None: affineImage = self.cutout( affineImage ) |         return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format( | ||||||
|  |             name=self.__class__.__name__, **self.__dict__ | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta |     def set_cutout(self, length): | ||||||
|  |         if length is not None and length >= 1: | ||||||
|  |             self.cutout = CutOut(int(length)) | ||||||
|  |         else: | ||||||
|  |             self.cutout = None | ||||||
|  |  | ||||||
|  |     def reset(self, num_pts=-1, boxid="default", only_pts=False): | ||||||
|  |         self.NUM_PTS = num_pts | ||||||
|  |         if only_pts: | ||||||
|  |             return | ||||||
|  |         self.length = 0 | ||||||
|  |         self.datas = [] | ||||||
|  |         self.labels = [] | ||||||
|  |         self.NormDistances = [] | ||||||
|  |         self.BOXID = boxid | ||||||
|  |         if self.mean_data is None: | ||||||
|  |             self.mean_face = None | ||||||
|  |         else: | ||||||
|  |             self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T) | ||||||
|  |             assert (self.mean_face >= -1).all() and ( | ||||||
|  |                 self.mean_face <= 1 | ||||||
|  |             ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face) | ||||||
|  |         # assert self.dataset_name is not None, 'The dataset name is None' | ||||||
|  |  | ||||||
|  |     def __len__(self): | ||||||
|  |         assert len(self.datas) == self.length, "The length is not correct : {}".format( | ||||||
|  |             self.length | ||||||
|  |         ) | ||||||
|  |         return self.length | ||||||
|  |  | ||||||
|  |     def append(self, data, label, distance): | ||||||
|  |         assert osp.isfile(data), "The image path is not a file : {:}".format(data) | ||||||
|  |         self.datas.append(data) | ||||||
|  |         self.labels.append(label) | ||||||
|  |         self.NormDistances.append(distance) | ||||||
|  |         self.length = self.length + 1 | ||||||
|  |  | ||||||
|  |     def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset): | ||||||
|  |         if reset: | ||||||
|  |             self.reset(num_pts, boxindicator) | ||||||
|  |         else: | ||||||
|  |             assert ( | ||||||
|  |                 self.NUM_PTS == num_pts and self.BOXID == boxindicator | ||||||
|  |             ), "The number of point is inconsistance : {:} vs {:}".format( | ||||||
|  |                 self.NUM_PTS, num_pts | ||||||
|  |             ) | ||||||
|  |         if isinstance(file_lists, str): | ||||||
|  |             file_lists = [file_lists] | ||||||
|  |         samples = [] | ||||||
|  |         for idx, file_path in enumerate(file_lists): | ||||||
|  |             print( | ||||||
|  |                 ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path) | ||||||
|  |             ) | ||||||
|  |             xdata = torch.load(file_path) | ||||||
|  |             if isinstance(xdata, list): | ||||||
|  |                 data = xdata  # image or video dataset list | ||||||
|  |             elif isinstance(xdata, dict): | ||||||
|  |                 data = xdata["datas"]  # multi-view dataset list | ||||||
|  |             else: | ||||||
|  |                 raise ValueError("Invalid Type Error : {:}".format(type(xdata))) | ||||||
|  |             samples = samples + data | ||||||
|  |         # samples is a dict, where the key is the image-path and the value is the annotation | ||||||
|  |         # each annotation is a dict, contains 'points' (3,num_pts), and various box | ||||||
|  |         print("GeneralDataset-V2 : {:} samples".format(len(samples))) | ||||||
|  |  | ||||||
|  |         # for index, annotation in enumerate(samples): | ||||||
|  |         for index in tqdm(range(len(samples))): | ||||||
|  |             annotation = samples[index] | ||||||
|  |             image_path = annotation["current_frame"] | ||||||
|  |             points, box = ( | ||||||
|  |                 annotation["points"], | ||||||
|  |                 annotation["box-{:}".format(boxindicator)], | ||||||
|  |             ) | ||||||
|  |             label = PointMeta2V( | ||||||
|  |                 self.NUM_PTS, points, box, image_path, self.dataset_name | ||||||
|  |             ) | ||||||
|  |             if normalizeL is None: | ||||||
|  |                 normDistance = None | ||||||
|  |             else: | ||||||
|  |                 normDistance = annotation["normalizeL-{:}".format(normalizeL)] | ||||||
|  |             self.append(image_path, label, normDistance) | ||||||
|  |  | ||||||
|  |         assert ( | ||||||
|  |             len(self.datas) == self.length | ||||||
|  |         ), "The length and the data is not right {} vs {}".format( | ||||||
|  |             self.length, len(self.datas) | ||||||
|  |         ) | ||||||
|  |         assert ( | ||||||
|  |             len(self.labels) == self.length | ||||||
|  |         ), "The length and the labels is not right {} vs {}".format( | ||||||
|  |             self.length, len(self.labels) | ||||||
|  |         ) | ||||||
|  |         assert ( | ||||||
|  |             len(self.NormDistances) == self.length | ||||||
|  |         ), "The length and the NormDistances is not right {} vs {}".format( | ||||||
|  |             self.length, len(self.NormDistance) | ||||||
|  |         ) | ||||||
|  |         print( | ||||||
|  |             "Load data done for LandmarkDataset, which has {:} images.".format( | ||||||
|  |                 self.length | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, index): | ||||||
|  |         assert index >= 0 and index < self.length, "Invalid index : {:}".format(index) | ||||||
|  |         if self.cache_images is not None and self.datas[index] in self.cache_images: | ||||||
|  |             image = self.cache_images[self.datas[index]].clone() | ||||||
|  |         else: | ||||||
|  |             image = pil_loader(self.datas[index], self.use_gray) | ||||||
|  |         target = self.labels[index].copy() | ||||||
|  |         return self._process_(image, target, index) | ||||||
|  |  | ||||||
|  |     def _process_(self, image, target, index): | ||||||
|  |  | ||||||
|  |         # transform the image and points | ||||||
|  |         image, target, theta = self.transform(image, target) | ||||||
|  |         (C, H, W), (height, width) = image.size(), self.shape | ||||||
|  |  | ||||||
|  |         # obtain the visiable indicator vector | ||||||
|  |         if target.is_none(): | ||||||
|  |             nopoints = True | ||||||
|  |         else: | ||||||
|  |             nopoints = False | ||||||
|  |         if index == -1: | ||||||
|  |             __path = None | ||||||
|  |         else: | ||||||
|  |             __path = self.datas[index] | ||||||
|  |         if isinstance(theta, list) or isinstance(theta, tuple): | ||||||
|  |             affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = ( | ||||||
|  |                 [], | ||||||
|  |                 [], | ||||||
|  |                 [], | ||||||
|  |                 [], | ||||||
|  |                 [], | ||||||
|  |                 [], | ||||||
|  |             ) | ||||||
|  |             for _theta in theta: | ||||||
|  |                 ( | ||||||
|  |                     _affineImage, | ||||||
|  |                     _heatmaps, | ||||||
|  |                     _mask, | ||||||
|  |                     _norm_trans_points, | ||||||
|  |                     _theta, | ||||||
|  |                     _transpose_theta, | ||||||
|  |                 ) = self.__process_affine( | ||||||
|  |                     image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path) | ||||||
|  |                 ) | ||||||
|  |                 affineImage.append(_affineImage) | ||||||
|  |                 heatmaps.append(_heatmaps) | ||||||
|  |                 mask.append(_mask) | ||||||
|  |                 norm_trans_points.append(_norm_trans_points) | ||||||
|  |                 THETA.append(_theta) | ||||||
|  |                 transpose_theta.append(_transpose_theta) | ||||||
|  |             affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = ( | ||||||
|  |                 torch.stack(affineImage), | ||||||
|  |                 torch.stack(heatmaps), | ||||||
|  |                 torch.stack(mask), | ||||||
|  |                 torch.stack(norm_trans_points), | ||||||
|  |                 torch.stack(THETA), | ||||||
|  |                 torch.stack(transpose_theta), | ||||||
|  |             ) | ||||||
|  |         else: | ||||||
|  |             ( | ||||||
|  |                 affineImage, | ||||||
|  |                 heatmaps, | ||||||
|  |                 mask, | ||||||
|  |                 norm_trans_points, | ||||||
|  |                 THETA, | ||||||
|  |                 transpose_theta, | ||||||
|  |             ) = self.__process_affine( | ||||||
|  |                 image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         torch_index = torch.IntTensor([index]) | ||||||
|  |         torch_nopoints = torch.ByteTensor([nopoints]) | ||||||
|  |         torch_shape = torch.IntTensor([H, W]) | ||||||
|  |  | ||||||
|  |         return ( | ||||||
|  |             affineImage, | ||||||
|  |             heatmaps, | ||||||
|  |             mask, | ||||||
|  |             norm_trans_points, | ||||||
|  |             THETA, | ||||||
|  |             transpose_theta, | ||||||
|  |             torch_index, | ||||||
|  |             torch_nopoints, | ||||||
|  |             torch_shape, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |     def __process_affine(self, image, target, theta, nopoints, aux_info=None): | ||||||
|  |         image, target, theta = image.clone(), target.copy(), theta.clone() | ||||||
|  |         (C, H, W), (height, width) = image.size(), self.shape | ||||||
|  |         if nopoints:  # do not have label | ||||||
|  |             norm_trans_points = torch.zeros((3, self.NUM_PTS)) | ||||||
|  |             heatmaps = torch.zeros( | ||||||
|  |                 (self.NUM_PTS + 1, height // self.downsample, width // self.downsample) | ||||||
|  |             ) | ||||||
|  |             mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8) | ||||||
|  |             transpose_theta = identity2affine(False) | ||||||
|  |         else: | ||||||
|  |             norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W)) | ||||||
|  |             norm_trans_points = apply_boundary(norm_trans_points) | ||||||
|  |             real_trans_points = norm_trans_points.clone() | ||||||
|  |             real_trans_points[:2, :] = denormalize_points( | ||||||
|  |                 self.shape, real_trans_points[:2, :] | ||||||
|  |             ) | ||||||
|  |             heatmaps, mask = generate_label_map( | ||||||
|  |                 real_trans_points.numpy(), | ||||||
|  |                 height // self.downsample, | ||||||
|  |                 width // self.downsample, | ||||||
|  |                 self.sigma, | ||||||
|  |                 self.downsample, | ||||||
|  |                 nopoints, | ||||||
|  |                 self.heatmap_type, | ||||||
|  |             )  # H*W*C | ||||||
|  |             heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type( | ||||||
|  |                 torch.FloatTensor | ||||||
|  |             ) | ||||||
|  |             mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor) | ||||||
|  |             if self.mean_face is None: | ||||||
|  |                 # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.') | ||||||
|  |                 transpose_theta = identity2affine(False) | ||||||
|  |             else: | ||||||
|  |                 if torch.sum(norm_trans_points[2, :] == 1) < 3: | ||||||
|  |                     warnings.warn( | ||||||
|  |                         "In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format( | ||||||
|  |                             aux_info | ||||||
|  |                         ) | ||||||
|  |                     ) | ||||||
|  |                     transpose_theta = identity2affine(False) | ||||||
|  |                 else: | ||||||
|  |                     transpose_theta = solve2theta( | ||||||
|  |                         norm_trans_points, self.mean_face.clone() | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|  |         affineImage = affine2image(image, theta, self.shape) | ||||||
|  |         if self.cutout is not None: | ||||||
|  |             affineImage = self.cutout(affineImage) | ||||||
|  |  | ||||||
|  |         return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta | ||||||
|   | |||||||
| @@ -6,41 +6,49 @@ import torch.utils.data as data | |||||||
|  |  | ||||||
|  |  | ||||||
| class SearchDataset(data.Dataset): | class SearchDataset(data.Dataset): | ||||||
|  |     def __init__(self, name, data, train_split, valid_split, check=True): | ||||||
|  |         self.datasetname = name | ||||||
|  |         if isinstance(data, (list, tuple)):  # new type of SearchDataset | ||||||
|  |             assert len(data) == 2, "invalid length: {:}".format(len(data)) | ||||||
|  |             self.train_data = data[0] | ||||||
|  |             self.valid_data = data[1] | ||||||
|  |             self.train_split = train_split.copy() | ||||||
|  |             self.valid_split = valid_split.copy() | ||||||
|  |             self.mode_str = "V2"  # new mode | ||||||
|  |         else: | ||||||
|  |             self.mode_str = "V1"  # old mode | ||||||
|  |             self.data = data | ||||||
|  |             self.train_split = train_split.copy() | ||||||
|  |             self.valid_split = valid_split.copy() | ||||||
|  |             if check: | ||||||
|  |                 intersection = set(train_split).intersection(set(valid_split)) | ||||||
|  |                 assert ( | ||||||
|  |                     len(intersection) == 0 | ||||||
|  |                 ), "the splitted train and validation sets should have no intersection" | ||||||
|  |         self.length = len(self.train_split) | ||||||
|  |  | ||||||
|   def __init__(self, name, data, train_split, valid_split, check=True): |     def __repr__(self): | ||||||
|     self.datasetname = name |         return "{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})".format( | ||||||
|     if isinstance(data, (list, tuple)): # new type of SearchDataset |             name=self.__class__.__name__, | ||||||
|       assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) |             datasetname=self.datasetname, | ||||||
|       self.train_data  = data[0] |             tr_L=len(self.train_split), | ||||||
|       self.valid_data  = data[1] |             val_L=len(self.valid_split), | ||||||
|       self.train_split = train_split.copy() |             ver=self.mode_str, | ||||||
|       self.valid_split = valid_split.copy() |         ) | ||||||
|       self.mode_str    = 'V2' # new mode  |  | ||||||
|     else: |  | ||||||
|       self.mode_str    = 'V1' # old mode  |  | ||||||
|       self.data        = data |  | ||||||
|       self.train_split = train_split.copy() |  | ||||||
|       self.valid_split = valid_split.copy() |  | ||||||
|       if check: |  | ||||||
|         intersection = set(train_split).intersection(set(valid_split)) |  | ||||||
|         assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' |  | ||||||
|     self.length      = len(self.train_split) |  | ||||||
|  |  | ||||||
|   def __repr__(self): |     def __len__(self): | ||||||
|     return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) |         return self.length | ||||||
|  |  | ||||||
|   def __len__(self): |     def __getitem__(self, index): | ||||||
|     return self.length |         assert index >= 0 and index < self.length, "invalid index = {:}".format(index) | ||||||
|  |         train_index = self.train_split[index] | ||||||
|   def __getitem__(self, index): |         valid_index = random.choice(self.valid_split) | ||||||
|     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) |         if self.mode_str == "V1": | ||||||
|     train_index = self.train_split[index] |             train_image, train_label = self.data[train_index] | ||||||
|     valid_index = random.choice( self.valid_split ) |             valid_image, valid_label = self.data[valid_index] | ||||||
|     if self.mode_str == 'V1': |         elif self.mode_str == "V2": | ||||||
|       train_image, train_label = self.data[train_index] |             train_image, train_label = self.train_data[train_index] | ||||||
|       valid_image, valid_label = self.data[valid_index] |             valid_image, valid_label = self.valid_data[valid_index] | ||||||
|     elif self.mode_str == 'V2': |         else: | ||||||
|       train_image, train_label = self.train_data[train_index] |             raise ValueError("invalid mode : {:}".format(self.mode_str)) | ||||||
|       valid_image, valid_label = self.valid_data[valid_index] |         return train_image, train_label, valid_image, valid_label | ||||||
|     else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) |  | ||||||
|     return train_image, train_label, valid_image, valid_label |  | ||||||
|   | |||||||
| @@ -4,4 +4,5 @@ | |||||||
| from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders | ||||||
| from .SearchDatasetWrap import SearchDataset | from .SearchDatasetWrap import SearchDataset | ||||||
|  |  | ||||||
|  | from .synthetic_adaptive_environment import QuadraticFunction | ||||||
| from .synthetic_adaptive_environment import SynAdaptiveEnv | from .synthetic_adaptive_environment import SynAdaptiveEnv | ||||||
|   | |||||||
| @@ -14,214 +14,349 @@ from .SearchDatasetWrap import SearchDataset | |||||||
| from config_utils import load_config | from config_utils import load_config | ||||||
|  |  | ||||||
|  |  | ||||||
| Dataset2Class = {'cifar10' : 10, | Dataset2Class = { | ||||||
|                  'cifar100': 100, |     "cifar10": 10, | ||||||
|                  'imagenet-1k-s':1000, |     "cifar100": 100, | ||||||
|                  'imagenet-1k' : 1000, |     "imagenet-1k-s": 1000, | ||||||
|                  'ImageNet16'  : 1000, |     "imagenet-1k": 1000, | ||||||
|                  'ImageNet16-150': 150, |     "ImageNet16": 1000, | ||||||
|                  'ImageNet16-120': 120, |     "ImageNet16-150": 150, | ||||||
|                  'ImageNet16-200': 200} |     "ImageNet16-120": 120, | ||||||
|  |     "ImageNet16-200": 200, | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
| class CUTOUT(object): | class CUTOUT(object): | ||||||
|  |     def __init__(self, length): | ||||||
|  |         self.length = length | ||||||
|  |  | ||||||
|   def __init__(self, length): |     def __repr__(self): | ||||||
|     self.length = length |         return "{name}(length={length})".format( | ||||||
|  |             name=self.__class__.__name__, **self.__dict__ | ||||||
|  |         ) | ||||||
|  |  | ||||||
|   def __repr__(self): |     def __call__(self, img): | ||||||
|     return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) |         h, w = img.size(1), img.size(2) | ||||||
|  |         mask = np.ones((h, w), np.float32) | ||||||
|  |         y = np.random.randint(h) | ||||||
|  |         x = np.random.randint(w) | ||||||
|  |  | ||||||
|   def __call__(self, img): |         y1 = np.clip(y - self.length // 2, 0, h) | ||||||
|     h, w = img.size(1), img.size(2) |         y2 = np.clip(y + self.length // 2, 0, h) | ||||||
|     mask = np.ones((h, w), np.float32) |         x1 = np.clip(x - self.length // 2, 0, w) | ||||||
|     y = np.random.randint(h) |         x2 = np.clip(x + self.length // 2, 0, w) | ||||||
|     x = np.random.randint(w) |  | ||||||
|  |  | ||||||
|     y1 = np.clip(y - self.length // 2, 0, h) |         mask[y1:y2, x1:x2] = 0.0 | ||||||
|     y2 = np.clip(y + self.length // 2, 0, h) |         mask = torch.from_numpy(mask) | ||||||
|     x1 = np.clip(x - self.length // 2, 0, w) |         mask = mask.expand_as(img) | ||||||
|     x2 = np.clip(x + self.length // 2, 0, w) |         img *= mask | ||||||
|  |         return img | ||||||
|     mask[y1: y2, x1: x2] = 0. |  | ||||||
|     mask = torch.from_numpy(mask) |  | ||||||
|     mask = mask.expand_as(img) |  | ||||||
|     img *= mask |  | ||||||
|     return img |  | ||||||
|  |  | ||||||
|  |  | ||||||
| imagenet_pca = { | imagenet_pca = { | ||||||
|     'eigval': np.asarray([0.2175, 0.0188, 0.0045]), |     "eigval": np.asarray([0.2175, 0.0188, 0.0045]), | ||||||
|     'eigvec': np.asarray([ |     "eigvec": np.asarray( | ||||||
|         [-0.5675, 0.7192, 0.4009], |         [ | ||||||
|         [-0.5808, -0.0045, -0.8140], |             [-0.5675, 0.7192, 0.4009], | ||||||
|         [-0.5836, -0.6948, 0.4203], |             [-0.5808, -0.0045, -0.8140], | ||||||
|     ]) |             [-0.5836, -0.6948, 0.4203], | ||||||
|  |         ] | ||||||
|  |     ), | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
| class Lighting(object): | class Lighting(object): | ||||||
|   def __init__(self, alphastd, |     def __init__( | ||||||
|          eigval=imagenet_pca['eigval'], |         self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"] | ||||||
|          eigvec=imagenet_pca['eigvec']): |     ): | ||||||
|     self.alphastd = alphastd |         self.alphastd = alphastd | ||||||
|     assert eigval.shape == (3,) |         assert eigval.shape == (3,) | ||||||
|     assert eigvec.shape == (3, 3) |         assert eigvec.shape == (3, 3) | ||||||
|     self.eigval = eigval |         self.eigval = eigval | ||||||
|     self.eigvec = eigvec |         self.eigvec = eigvec | ||||||
|  |  | ||||||
|   def __call__(self, img): |     def __call__(self, img): | ||||||
|     if self.alphastd == 0.: |         if self.alphastd == 0.0: | ||||||
|       return img |             return img | ||||||
|     rnd = np.random.randn(3) * self.alphastd |         rnd = np.random.randn(3) * self.alphastd | ||||||
|     rnd = rnd.astype('float32') |         rnd = rnd.astype("float32") | ||||||
|     v = rnd |         v = rnd | ||||||
|     old_dtype = np.asarray(img).dtype |         old_dtype = np.asarray(img).dtype | ||||||
|     v = v * self.eigval |         v = v * self.eigval | ||||||
|     v = v.reshape((3, 1)) |         v = v.reshape((3, 1)) | ||||||
|     inc = np.dot(self.eigvec, v).reshape((3,)) |         inc = np.dot(self.eigvec, v).reshape((3,)) | ||||||
|     img = np.add(img, inc) |         img = np.add(img, inc) | ||||||
|     if old_dtype == np.uint8: |         if old_dtype == np.uint8: | ||||||
|       img = np.clip(img, 0, 255) |             img = np.clip(img, 0, 255) | ||||||
|     img = Image.fromarray(img.astype(old_dtype), 'RGB') |         img = Image.fromarray(img.astype(old_dtype), "RGB") | ||||||
|     return img |         return img | ||||||
|  |  | ||||||
|   def __repr__(self): |     def __repr__(self): | ||||||
|     return self.__class__.__name__ + '()' |         return self.__class__.__name__ + "()" | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_datasets(name, root, cutout): | def get_datasets(name, root, cutout): | ||||||
|  |  | ||||||
|   if name == 'cifar10': |     if name == "cifar10": | ||||||
|     mean = [x / 255 for x in [125.3, 123.0, 113.9]] |         mean = [x / 255 for x in [125.3, 123.0, 113.9]] | ||||||
|     std  = [x / 255 for x in [63.0, 62.1, 66.7]] |         std = [x / 255 for x in [63.0, 62.1, 66.7]] | ||||||
|   elif name == 'cifar100': |     elif name == "cifar100": | ||||||
|     mean = [x / 255 for x in [129.3, 124.1, 112.4]] |         mean = [x / 255 for x in [129.3, 124.1, 112.4]] | ||||||
|     std  = [x / 255 for x in [68.2, 65.4, 70.4]] |         std = [x / 255 for x in [68.2, 65.4, 70.4]] | ||||||
|   elif name.startswith('imagenet-1k'): |     elif name.startswith("imagenet-1k"): | ||||||
|     mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] |         mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] | ||||||
|   elif name.startswith('ImageNet16'): |     elif name.startswith("ImageNet16"): | ||||||
|     mean = [x / 255 for x in [122.68, 116.66, 104.01]] |         mean = [x / 255 for x in [122.68, 116.66, 104.01]] | ||||||
|     std  = [x / 255 for x in [63.22,  61.26 , 65.09]] |         std = [x / 255 for x in [63.22, 61.26, 65.09]] | ||||||
|   else: |     else: | ||||||
|     raise TypeError("Unknow dataset : {:}".format(name)) |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
|   # Data Argumentation |     # Data Argumentation | ||||||
|   if name == 'cifar10' or name == 'cifar100': |     if name == "cifar10" or name == "cifar100": | ||||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] |         lists = [ | ||||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] |             transforms.RandomHorizontalFlip(), | ||||||
|     train_transform = transforms.Compose(lists) |             transforms.RandomCrop(32, padding=4), | ||||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) |             transforms.ToTensor(), | ||||||
|     xshape = (1, 3, 32, 32) |             transforms.Normalize(mean, std), | ||||||
|   elif name.startswith('ImageNet16'): |         ] | ||||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] |         if cutout > 0: | ||||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] |             lists += [CUTOUT(cutout)] | ||||||
|     train_transform = transforms.Compose(lists) |         train_transform = transforms.Compose(lists) | ||||||
|     test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) |         test_transform = transforms.Compose( | ||||||
|     xshape = (1, 3, 16, 16) |             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|   elif name == 'tiered': |         ) | ||||||
|     lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] |         xshape = (1, 3, 32, 32) | ||||||
|     if cutout > 0 : lists += [CUTOUT(cutout)] |     elif name.startswith("ImageNet16"): | ||||||
|     train_transform = transforms.Compose(lists) |         lists = [ | ||||||
|     test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) |             transforms.RandomHorizontalFlip(), | ||||||
|     xshape = (1, 3, 32, 32) |             transforms.RandomCrop(16, padding=2), | ||||||
|   elif name.startswith('imagenet-1k'): |             transforms.ToTensor(), | ||||||
|     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |             transforms.Normalize(mean, std), | ||||||
|     if name == 'imagenet-1k': |         ] | ||||||
|       xlists    = [transforms.RandomResizedCrop(224)] |         if cutout > 0: | ||||||
|       xlists.append( |             lists += [CUTOUT(cutout)] | ||||||
|         transforms.ColorJitter( |         train_transform = transforms.Compose(lists) | ||||||
|         brightness=0.4, |         test_transform = transforms.Compose( | ||||||
|         contrast=0.4, |             [transforms.ToTensor(), transforms.Normalize(mean, std)] | ||||||
|         saturation=0.4, |         ) | ||||||
|         hue=0.2)) |         xshape = (1, 3, 16, 16) | ||||||
|       xlists.append( Lighting(0.1)) |     elif name == "tiered": | ||||||
|     elif name == 'imagenet-1k-s': |         lists = [ | ||||||
|       xlists    = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] |             transforms.RandomHorizontalFlip(), | ||||||
|     else: raise ValueError('invalid name : {:}'.format(name)) |             transforms.RandomCrop(80, padding=4), | ||||||
|     xlists.append( transforms.RandomHorizontalFlip(p=0.5) ) |             transforms.ToTensor(), | ||||||
|     xlists.append( transforms.ToTensor() ) |             transforms.Normalize(mean, std), | ||||||
|     xlists.append( normalize ) |         ] | ||||||
|     train_transform = transforms.Compose(xlists) |         if cutout > 0: | ||||||
|     test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]) |             lists += [CUTOUT(cutout)] | ||||||
|     xshape = (1, 3, 224, 224) |         train_transform = transforms.Compose(lists) | ||||||
|   else: |         test_transform = transforms.Compose( | ||||||
|     raise TypeError("Unknow dataset : {:}".format(name)) |             [ | ||||||
|  |                 transforms.CenterCrop(80), | ||||||
|  |                 transforms.ToTensor(), | ||||||
|  |                 transforms.Normalize(mean, std), | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         xshape = (1, 3, 32, 32) | ||||||
|  |     elif name.startswith("imagenet-1k"): | ||||||
|  |         normalize = transforms.Normalize( | ||||||
|  |             mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] | ||||||
|  |         ) | ||||||
|  |         if name == "imagenet-1k": | ||||||
|  |             xlists = [transforms.RandomResizedCrop(224)] | ||||||
|  |             xlists.append( | ||||||
|  |                 transforms.ColorJitter( | ||||||
|  |                     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2 | ||||||
|  |                 ) | ||||||
|  |             ) | ||||||
|  |             xlists.append(Lighting(0.1)) | ||||||
|  |         elif name == "imagenet-1k-s": | ||||||
|  |             xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] | ||||||
|  |         else: | ||||||
|  |             raise ValueError("invalid name : {:}".format(name)) | ||||||
|  |         xlists.append(transforms.RandomHorizontalFlip(p=0.5)) | ||||||
|  |         xlists.append(transforms.ToTensor()) | ||||||
|  |         xlists.append(normalize) | ||||||
|  |         train_transform = transforms.Compose(xlists) | ||||||
|  |         test_transform = transforms.Compose( | ||||||
|  |             [ | ||||||
|  |                 transforms.Resize(256), | ||||||
|  |                 transforms.CenterCrop(224), | ||||||
|  |                 transforms.ToTensor(), | ||||||
|  |                 normalize, | ||||||
|  |             ] | ||||||
|  |         ) | ||||||
|  |         xshape = (1, 3, 224, 224) | ||||||
|  |     else: | ||||||
|  |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
|   if name == 'cifar10': |     if name == "cifar10": | ||||||
|     train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True) |         train_data = dset.CIFAR10( | ||||||
|     test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True) |             root, train=True, transform=train_transform, download=True | ||||||
|     assert len(train_data) == 50000 and len(test_data) == 10000 |         ) | ||||||
|   elif name == 'cifar100': |         test_data = dset.CIFAR10( | ||||||
|     train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True) |             root, train=False, transform=test_transform, download=True | ||||||
|     test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True) |         ) | ||||||
|     assert len(train_data) == 50000 and len(test_data) == 10000 |         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||||
|   elif name.startswith('imagenet-1k'): |     elif name == "cifar100": | ||||||
|     train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) |         train_data = dset.CIFAR100( | ||||||
|     test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform) |             root, train=True, transform=train_transform, download=True | ||||||
|     assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000) |         ) | ||||||
|   elif name == 'ImageNet16': |         test_data = dset.CIFAR100( | ||||||
|     train_data = ImageNet16(root, True , train_transform) |             root, train=False, transform=test_transform, download=True | ||||||
|     test_data  = ImageNet16(root, False, test_transform) |         ) | ||||||
|     assert len(train_data) == 1281167 and len(test_data) == 50000 |         assert len(train_data) == 50000 and len(test_data) == 10000 | ||||||
|   elif name == 'ImageNet16-120': |     elif name.startswith("imagenet-1k"): | ||||||
|     train_data = ImageNet16(root, True , train_transform, 120) |         train_data = dset.ImageFolder(osp.join(root, "train"), train_transform) | ||||||
|     test_data  = ImageNet16(root, False, test_transform , 120) |         test_data = dset.ImageFolder(osp.join(root, "val"), test_transform) | ||||||
|     assert len(train_data) == 151700 and len(test_data) == 6000 |         assert ( | ||||||
|   elif name == 'ImageNet16-150': |             len(train_data) == 1281167 and len(test_data) == 50000 | ||||||
|     train_data = ImageNet16(root, True , train_transform, 150) |         ), "invalid number of images : {:} & {:} vs {:} & {:}".format( | ||||||
|     test_data  = ImageNet16(root, False, test_transform , 150) |             len(train_data), len(test_data), 1281167, 50000 | ||||||
|     assert len(train_data) == 190272 and len(test_data) == 7500 |         ) | ||||||
|   elif name == 'ImageNet16-200': |     elif name == "ImageNet16": | ||||||
|     train_data = ImageNet16(root, True , train_transform, 200) |         train_data = ImageNet16(root, True, train_transform) | ||||||
|     test_data  = ImageNet16(root, False, test_transform , 200) |         test_data = ImageNet16(root, False, test_transform) | ||||||
|     assert len(train_data) == 254775 and len(test_data) == 10000 |         assert len(train_data) == 1281167 and len(test_data) == 50000 | ||||||
|   else: raise TypeError("Unknow dataset : {:}".format(name)) |     elif name == "ImageNet16-120": | ||||||
|    |         train_data = ImageNet16(root, True, train_transform, 120) | ||||||
|   class_num = Dataset2Class[name] |         test_data = ImageNet16(root, False, test_transform, 120) | ||||||
|   return train_data, test_data, xshape, class_num |         assert len(train_data) == 151700 and len(test_data) == 6000 | ||||||
|  |     elif name == "ImageNet16-150": | ||||||
|  |         train_data = ImageNet16(root, True, train_transform, 150) | ||||||
|  |         test_data = ImageNet16(root, False, test_transform, 150) | ||||||
|  |         assert len(train_data) == 190272 and len(test_data) == 7500 | ||||||
|  |     elif name == "ImageNet16-200": | ||||||
|  |         train_data = ImageNet16(root, True, train_transform, 200) | ||||||
|  |         test_data = ImageNet16(root, False, test_transform, 200) | ||||||
|  |         assert len(train_data) == 254775 and len(test_data) == 10000 | ||||||
|  |     else: | ||||||
|  |         raise TypeError("Unknow dataset : {:}".format(name)) | ||||||
|  |  | ||||||
|  |     class_num = Dataset2Class[name] | ||||||
|  |     return train_data, test_data, xshape, class_num | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers): | def get_nas_search_loaders( | ||||||
|   if isinstance(batch_size, (list,tuple)): |     train_data, valid_data, dataset, config_root, batch_size, workers | ||||||
|     batch, test_batch = batch_size | ): | ||||||
|   else: |     if isinstance(batch_size, (list, tuple)): | ||||||
|     batch, test_batch = batch_size, batch_size |         batch, test_batch = batch_size | ||||||
|   if dataset == 'cifar10': |     else: | ||||||
|     #split_Fpath = 'configs/nas-benchmark/cifar-split.txt' |         batch, test_batch = batch_size, batch_size | ||||||
|     cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None) |     if dataset == "cifar10": | ||||||
|     train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set |         # split_Fpath = 'configs/nas-benchmark/cifar-split.txt' | ||||||
|     #logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set |         cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None) | ||||||
|     # To split data |         train_split, valid_split = ( | ||||||
|     xvalid_data  = deepcopy(train_data) |             cifar_split.train, | ||||||
|     if hasattr(xvalid_data, 'transforms'): # to avoid a print issue |             cifar_split.valid, | ||||||
|       xvalid_data.transforms = valid_data.transform |         )  # search over the proposed training and validation set | ||||||
|     xvalid_data.transform  = deepcopy( valid_data.transform ) |         # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set | ||||||
|     search_data   = SearchDataset(dataset, train_data, train_split, valid_split) |         # To split data | ||||||
|     # data loader |         xvalid_data = deepcopy(train_data) | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) |         if hasattr(xvalid_data, "transforms"):  # to avoid a print issue | ||||||
|     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True) |             xvalid_data.transforms = valid_data.transform | ||||||
|     valid_loader  = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True) |         xvalid_data.transform = deepcopy(valid_data.transform) | ||||||
|   elif dataset == 'cifar100': |         search_data = SearchDataset(dataset, train_data, train_split, valid_split) | ||||||
|     cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None) |         # data loader | ||||||
|     search_train_data = train_data |         search_loader = torch.utils.data.DataLoader( | ||||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform |             search_data, | ||||||
|     search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid) |             batch_size=batch, | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) |             shuffle=True, | ||||||
|     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) |             num_workers=workers, | ||||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True) |             pin_memory=True, | ||||||
|   elif dataset == 'ImageNet16-120': |         ) | ||||||
|     imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None) |         train_loader = torch.utils.data.DataLoader( | ||||||
|     search_train_data = train_data |             train_data, | ||||||
|     search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform |             batch_size=batch, | ||||||
|     search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) |             sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), | ||||||
|     search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) |             num_workers=workers, | ||||||
|     train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) |             pin_memory=True, | ||||||
|     valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True) |         ) | ||||||
|   else: |         valid_loader = torch.utils.data.DataLoader( | ||||||
|     raise ValueError('invalid dataset : {:}'.format(dataset)) |             xvalid_data, | ||||||
|   return search_loader, train_loader, valid_loader |             batch_size=test_batch, | ||||||
|  |             sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |     elif dataset == "cifar100": | ||||||
|  |         cifar100_test_split = load_config( | ||||||
|  |             "{:}/cifar100-test-split.txt".format(config_root), None, None | ||||||
|  |         ) | ||||||
|  |         search_train_data = train_data | ||||||
|  |         search_valid_data = deepcopy(valid_data) | ||||||
|  |         search_valid_data.transform = train_data.transform | ||||||
|  |         search_data = SearchDataset( | ||||||
|  |             dataset, | ||||||
|  |             [search_train_data, search_valid_data], | ||||||
|  |             list(range(len(search_train_data))), | ||||||
|  |             cifar100_test_split.xvalid, | ||||||
|  |         ) | ||||||
|  |         search_loader = torch.utils.data.DataLoader( | ||||||
|  |             search_data, | ||||||
|  |             batch_size=batch, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |         train_loader = torch.utils.data.DataLoader( | ||||||
|  |             train_data, | ||||||
|  |             batch_size=batch, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |         valid_loader = torch.utils.data.DataLoader( | ||||||
|  |             valid_data, | ||||||
|  |             batch_size=test_batch, | ||||||
|  |             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||||
|  |                 cifar100_test_split.xvalid | ||||||
|  |             ), | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |     elif dataset == "ImageNet16-120": | ||||||
|  |         imagenet_test_split = load_config( | ||||||
|  |             "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None | ||||||
|  |         ) | ||||||
|  |         search_train_data = train_data | ||||||
|  |         search_valid_data = deepcopy(valid_data) | ||||||
|  |         search_valid_data.transform = train_data.transform | ||||||
|  |         search_data = SearchDataset( | ||||||
|  |             dataset, | ||||||
|  |             [search_train_data, search_valid_data], | ||||||
|  |             list(range(len(search_train_data))), | ||||||
|  |             imagenet_test_split.xvalid, | ||||||
|  |         ) | ||||||
|  |         search_loader = torch.utils.data.DataLoader( | ||||||
|  |             search_data, | ||||||
|  |             batch_size=batch, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |         train_loader = torch.utils.data.DataLoader( | ||||||
|  |             train_data, | ||||||
|  |             batch_size=batch, | ||||||
|  |             shuffle=True, | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |         valid_loader = torch.utils.data.DataLoader( | ||||||
|  |             valid_data, | ||||||
|  |             batch_size=test_batch, | ||||||
|  |             sampler=torch.utils.data.sampler.SubsetRandomSampler( | ||||||
|  |                 imagenet_test_split.xvalid | ||||||
|  |             ), | ||||||
|  |             num_workers=workers, | ||||||
|  |             pin_memory=True, | ||||||
|  |         ) | ||||||
|  |     else: | ||||||
|  |         raise ValueError("invalid dataset : {:}".format(dataset)) | ||||||
|  |     return search_loader, train_loader, valid_loader | ||||||
|  |  | ||||||
| #if __name__ == '__main__': |  | ||||||
|  | # if __name__ == '__main__': | ||||||
| #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | #  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1) | ||||||
| #  import pdb; pdb.set_trace() | #  import pdb; pdb.set_trace() | ||||||
|   | |||||||
| @@ -9,108 +9,211 @@ from xvision import normalize_points | |||||||
| from xvision import denormalize_points | from xvision import denormalize_points | ||||||
|  |  | ||||||
|  |  | ||||||
| class PointMeta(): | class PointMeta: | ||||||
|   # points    : 3 x num_pts (x, y, oculusion) |     # points    : 3 x num_pts (x, y, oculusion) | ||||||
|   # image_size: original [width, height] |     # image_size: original [width, height] | ||||||
|   def __init__(self, num_point, points, box, image_path, dataset_name): |     def __init__(self, num_point, points, box, image_path, dataset_name): | ||||||
|  |  | ||||||
|     self.num_point = num_point |         self.num_point = num_point | ||||||
|     if box is not None: |         if box is not None: | ||||||
|       assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 |             assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4 | ||||||
|       self.box = torch.Tensor(box) |             self.box = torch.Tensor(box) | ||||||
|     else: self.box = None |         else: | ||||||
|     if points is None: |             self.box = None | ||||||
|       self.points = points |         if points is None: | ||||||
|     else: |             self.points = points | ||||||
|       assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points ) |         else: | ||||||
|       self.points = torch.Tensor(points.copy()) |             assert ( | ||||||
|     self.image_path = image_path |                 len(points.shape) == 2 | ||||||
|     self.datasets = dataset_name |                 and points.shape[0] == 3 | ||||||
|  |                 and points.shape[1] == self.num_point | ||||||
|  |             ), "The shape of point is not right : {}".format(points) | ||||||
|  |             self.points = torch.Tensor(points.copy()) | ||||||
|  |         self.image_path = image_path | ||||||
|  |         self.datasets = dataset_name | ||||||
|  |  | ||||||
|   def __repr__(self): |     def __repr__(self): | ||||||
|     if self.box is None: boxstr = 'None' |         if self.box is None: | ||||||
|     else               : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist()) |             boxstr = "None" | ||||||
|     return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')') |         else: | ||||||
|  |             boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist()) | ||||||
|  |         return ( | ||||||
|  |             "{name}(points={num_point}, ".format( | ||||||
|  |                 name=self.__class__.__name__, **self.__dict__ | ||||||
|  |             ) | ||||||
|  |             + boxstr | ||||||
|  |             + ")" | ||||||
|  |         ) | ||||||
|  |  | ||||||
|   def get_box(self, return_diagonal=False): |     def get_box(self, return_diagonal=False): | ||||||
|     if self.box is None: return None |         if self.box is None: | ||||||
|     if not return_diagonal: |             return None | ||||||
|       return self.box.clone() |         if not return_diagonal: | ||||||
|     else: |             return self.box.clone() | ||||||
|       W = (self.box[2]-self.box[0]).item() |         else: | ||||||
|       H = (self.box[3]-self.box[1]).item() |             W = (self.box[2] - self.box[0]).item() | ||||||
|       return math.sqrt(H*H+W*W) |             H = (self.box[3] - self.box[1]).item() | ||||||
|  |             return math.sqrt(H * H + W * W) | ||||||
|  |  | ||||||
|   def get_points(self, ignore_indicator=False): |     def get_points(self, ignore_indicator=False): | ||||||
|     if ignore_indicator: last = 2 |         if ignore_indicator: | ||||||
|     else               : last = 3 |             last = 2 | ||||||
|     if self.points is not None: return self.points.clone()[:last, :] |         else: | ||||||
|     else                      : return torch.zeros((last, self.num_point)) |             last = 3 | ||||||
|  |         if self.points is not None: | ||||||
|  |             return self.points.clone()[:last, :] | ||||||
|  |         else: | ||||||
|  |             return torch.zeros((last, self.num_point)) | ||||||
|  |  | ||||||
|   def is_none(self): |     def is_none(self): | ||||||
|     #assert self.box is not None, 'The box should not be None' |         # assert self.box is not None, 'The box should not be None' | ||||||
|     return self.points is None |         return self.points is None | ||||||
|     #if self.box is None: return True |         # if self.box is None: return True | ||||||
|     #else               : return self.points is None |         # else               : return self.points is None | ||||||
|  |  | ||||||
|   def copy(self): |     def copy(self): | ||||||
|     return copy.deepcopy(self) |         return copy.deepcopy(self) | ||||||
|  |  | ||||||
|   def visiable_pts_num(self): |     def visiable_pts_num(self): | ||||||
|     with torch.no_grad(): |         with torch.no_grad(): | ||||||
|       ans = self.points[2,:] > 0 |             ans = self.points[2, :] > 0 | ||||||
|       ans = torch.sum(ans) |             ans = torch.sum(ans) | ||||||
|       ans = ans.item() |             ans = ans.item() | ||||||
|     return ans |         return ans | ||||||
|    |  | ||||||
|   def special_fun(self, indicator): |  | ||||||
|     if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points. |  | ||||||
|       assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point) |  | ||||||
|       self.num_point = 49 |  | ||||||
|       out = torch.ones((68), dtype=torch.uint8) |  | ||||||
|       out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0 |  | ||||||
|       if self.points is not None: self.points = self.points.clone()[:, out] |  | ||||||
|     else: |  | ||||||
|       raise ValueError('Invalid indicator : {:}'.format( indicator )) |  | ||||||
|  |  | ||||||
|   def apply_horizontal_flip(self): |     def special_fun(self, indicator): | ||||||
|     #self.points[0, :] = width - self.points[0, :] - 1 |         if ( | ||||||
|     # Mugsy spefic or Synthetic |             indicator == "68to49" | ||||||
|     if self.datasets.startswith('HandsyROT'): |         ):  # For 300W or 300VW, convert the default 68 points to 49 points. | ||||||
|       ori = np.array(list(range(0, 42))) |             assert self.num_point == 68, "num-point must be 68 vs. {:}".format( | ||||||
|       pos = np.array(list(range(21,42)) + list(range(0,21))) |                 self.num_point | ||||||
|       self.points[:, pos] = self.points[:, ori] |             ) | ||||||
|     elif self.datasets.startswith('face68'): |             self.num_point = 49 | ||||||
|       ori = np.array(list(range(0, 68))) |             out = torch.ones((68), dtype=torch.uint8) | ||||||
|       pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1 |             out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0 | ||||||
|       self.points[:, ori] = self.points[:, pos] |             if self.points is not None: | ||||||
|     else: |                 self.points = self.points.clone()[:, out] | ||||||
|       raise ValueError('Does not support {:}'.format(self.datasets)) |         else: | ||||||
|  |             raise ValueError("Invalid indicator : {:}".format(indicator)) | ||||||
|  |  | ||||||
|  |     def apply_horizontal_flip(self): | ||||||
|  |         # self.points[0, :] = width - self.points[0, :] - 1 | ||||||
|  |         # Mugsy spefic or Synthetic | ||||||
|  |         if self.datasets.startswith("HandsyROT"): | ||||||
|  |             ori = np.array(list(range(0, 42))) | ||||||
|  |             pos = np.array(list(range(21, 42)) + list(range(0, 21))) | ||||||
|  |             self.points[:, pos] = self.points[:, ori] | ||||||
|  |         elif self.datasets.startswith("face68"): | ||||||
|  |             ori = np.array(list(range(0, 68))) | ||||||
|  |             pos = ( | ||||||
|  |                 np.array( | ||||||
|  |                     [ | ||||||
|  |                         17, | ||||||
|  |                         16, | ||||||
|  |                         15, | ||||||
|  |                         14, | ||||||
|  |                         13, | ||||||
|  |                         12, | ||||||
|  |                         11, | ||||||
|  |                         10, | ||||||
|  |                         9, | ||||||
|  |                         8, | ||||||
|  |                         7, | ||||||
|  |                         6, | ||||||
|  |                         5, | ||||||
|  |                         4, | ||||||
|  |                         3, | ||||||
|  |                         2, | ||||||
|  |                         1, | ||||||
|  |                         27, | ||||||
|  |                         26, | ||||||
|  |                         25, | ||||||
|  |                         24, | ||||||
|  |                         23, | ||||||
|  |                         22, | ||||||
|  |                         21, | ||||||
|  |                         20, | ||||||
|  |                         19, | ||||||
|  |                         18, | ||||||
|  |                         28, | ||||||
|  |                         29, | ||||||
|  |                         30, | ||||||
|  |                         31, | ||||||
|  |                         36, | ||||||
|  |                         35, | ||||||
|  |                         34, | ||||||
|  |                         33, | ||||||
|  |                         32, | ||||||
|  |                         46, | ||||||
|  |                         45, | ||||||
|  |                         44, | ||||||
|  |                         43, | ||||||
|  |                         48, | ||||||
|  |                         47, | ||||||
|  |                         40, | ||||||
|  |                         39, | ||||||
|  |                         38, | ||||||
|  |                         37, | ||||||
|  |                         42, | ||||||
|  |                         41, | ||||||
|  |                         55, | ||||||
|  |                         54, | ||||||
|  |                         53, | ||||||
|  |                         52, | ||||||
|  |                         51, | ||||||
|  |                         50, | ||||||
|  |                         49, | ||||||
|  |                         60, | ||||||
|  |                         59, | ||||||
|  |                         58, | ||||||
|  |                         57, | ||||||
|  |                         56, | ||||||
|  |                         65, | ||||||
|  |                         64, | ||||||
|  |                         63, | ||||||
|  |                         62, | ||||||
|  |                         61, | ||||||
|  |                         68, | ||||||
|  |                         67, | ||||||
|  |                         66, | ||||||
|  |                     ] | ||||||
|  |                 ) | ||||||
|  |                 - 1 | ||||||
|  |             ) | ||||||
|  |             self.points[:, ori] = self.points[:, pos] | ||||||
|  |         else: | ||||||
|  |             raise ValueError("Does not support {:}".format(self.datasets)) | ||||||
|  |  | ||||||
|  |  | ||||||
| # shape = (H,W) | # shape = (H,W) | ||||||
| def apply_affine2point(points, theta, shape): | def apply_affine2point(points, theta, shape): | ||||||
|   assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size()) |     assert points.size(0) == 3, "invalid points shape : {:}".format(points.size()) | ||||||
|   with torch.no_grad(): |     with torch.no_grad(): | ||||||
|     ok_points = points[2,:] == 1 |         ok_points = points[2, :] == 1 | ||||||
|     assert torch.sum(ok_points).item() > 0, 'there is no visiable point' |         assert torch.sum(ok_points).item() > 0, "there is no visiable point" | ||||||
|     points[:2,:] = normalize_points(shape, points[:2,:]) |         points[:2, :] = normalize_points(shape, points[:2, :]) | ||||||
|  |  | ||||||
|     norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() |         norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float() | ||||||
|  |  | ||||||
|     trans_points, ___ = torch.gesv(points[:, ok_points], theta) |         trans_points, ___ = torch.gesv(points[:, ok_points], theta) | ||||||
|  |  | ||||||
|     norm_trans_points[:, ok_points] = trans_points |         norm_trans_points[:, ok_points] = trans_points | ||||||
|      |  | ||||||
|   return norm_trans_points |  | ||||||
|  |  | ||||||
|  |     return norm_trans_points | ||||||
|  |  | ||||||
|  |  | ||||||
| def apply_boundary(norm_trans_points): | def apply_boundary(norm_trans_points): | ||||||
|   with torch.no_grad(): |     with torch.no_grad(): | ||||||
|     norm_trans_points = norm_trans_points.clone() |         norm_trans_points = norm_trans_points.clone() | ||||||
|     oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0)) |         oks = torch.stack( | ||||||
|     oks = torch.sum(oks, dim=0) == 5 |             ( | ||||||
|     norm_trans_points[2, :] = oks |                 norm_trans_points[0] > -1, | ||||||
|   return norm_trans_points |                 norm_trans_points[0] < 1, | ||||||
|  |                 norm_trans_points[1] > -1, | ||||||
|  |                 norm_trans_points[1] < 1, | ||||||
|  |                 norm_trans_points[2] > 0, | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  |         oks = torch.sum(oks, dim=0) == 5 | ||||||
|  |         norm_trans_points[2, :] = oks | ||||||
|  |     return norm_trans_points | ||||||
|   | |||||||
| @@ -1,39 +1,123 @@ | |||||||
| ##################################################### | ##################################################### | ||||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
| ##################################################### | ##################################################### | ||||||
|  | import math | ||||||
| import numpy as np | import numpy as np | ||||||
| from typing import Optional | from typing import Optional | ||||||
|  | import torch | ||||||
| import torch.utils.data as data | import torch.utils.data as data | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class QuadraticFunction: | ||||||
|  |     """The quadratic function that outputs f(x) = a * x^2 + b * x + c.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, list_of_points=None): | ||||||
|  |         self._params = dict(a=None, b=None, c=None) | ||||||
|  |         if list_of_points is not None: | ||||||
|  |             self.fit(list_of_points) | ||||||
|  |  | ||||||
|  |     def set(self, a, b, c): | ||||||
|  |         self._params["a"] = a | ||||||
|  |         self._params["b"] = b | ||||||
|  |         self._params["c"] = c | ||||||
|  |  | ||||||
|  |     def check_valid(self): | ||||||
|  |         for key, value in self._params.items(): | ||||||
|  |             if value is None: | ||||||
|  |                 raise ValueError("The {:} is None".format(key)) | ||||||
|  |  | ||||||
|  |     def __getitem__(self, x): | ||||||
|  |         self.check_valid() | ||||||
|  |         return self._params["a"] * x * x + self._params["b"] * x + self._params["c"] | ||||||
|  |  | ||||||
|  |     def fit( | ||||||
|  |         self, | ||||||
|  |         list_of_points, | ||||||
|  |         transf=lambda x: x, | ||||||
|  |         max_iter=900, | ||||||
|  |         lr_max=1.0, | ||||||
|  |         verbose=False, | ||||||
|  |     ): | ||||||
|  |         with torch.no_grad(): | ||||||
|  |             data = torch.Tensor(list_of_points).type(torch.float32) | ||||||
|  |             assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format( | ||||||
|  |                 data.shape | ||||||
|  |             ) | ||||||
|  |             x, y = data[:, 0], data[:, 1] | ||||||
|  |         weights = torch.nn.Parameter(torch.Tensor(3)) | ||||||
|  |         torch.nn.init.normal_(weights, mean=0.0, std=1.0) | ||||||
|  |         optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True) | ||||||
|  |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1) | ||||||
|  |         if verbose: | ||||||
|  |             print("The optimizer: {:}".format(optimizer)) | ||||||
|  |  | ||||||
|  |         best_loss = None | ||||||
|  |         for _iter in range(max_iter): | ||||||
|  |             y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2]) | ||||||
|  |             loss = torch.mean(torch.abs(y - y_hat)) | ||||||
|  |             optimizer.zero_grad() | ||||||
|  |             loss.backward() | ||||||
|  |             optimizer.step() | ||||||
|  |             lr_scheduler.step() | ||||||
|  |             if verbose: | ||||||
|  |                 print( | ||||||
|  |                     "In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format( | ||||||
|  |                         _iter, max_iter, loss.item() | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |             # Update the params | ||||||
|  |             if best_loss is None or best_loss > loss.item(): | ||||||
|  |                 best_loss = loss.item() | ||||||
|  |                 self._params["a"] = weights[0].item() | ||||||
|  |                 self._params["b"] = weights[1].item() | ||||||
|  |                 self._params["c"] = weights[2].item() | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}(y = {a} * x^2 + {b} * x + {c})".format( | ||||||
|  |             name=self.__class__.__name__, | ||||||
|  |             a=self._params["a"], | ||||||
|  |             b=self._params["b"], | ||||||
|  |             c=self._params["c"], | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class SynAdaptiveEnv(data.Dataset): | class SynAdaptiveEnv(data.Dataset): | ||||||
|     """The synethtic dataset for adaptive environment.""" |     """The synethtic dataset for adaptive environment. | ||||||
|  |  | ||||||
|  |     - x in [0, 1] | ||||||
|  |     - y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) ) | ||||||
|  |     - where | ||||||
|  |     - the amplitude scale is a quadratic function of x | ||||||
|  |     - the period-phase-shift is another quadratic function of x | ||||||
|  |  | ||||||
|  |     """ | ||||||
|  |  | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|         max_num_phase: int = 100, |         num: int = 100, | ||||||
|         interval: float = 0.1, |         num_sin_phase: int = 4, | ||||||
|         max_scale: float = 4, |         min_amplitude: float = 1, | ||||||
|         offset_scale: float = 1.5, |         max_amplitude: float = 4, | ||||||
|  |         phase_shift: float = 0, | ||||||
|         mode: Optional[str] = None, |         mode: Optional[str] = None, | ||||||
|     ): |     ): | ||||||
|  |         self._amplitude_scale = QuadraticFunction( | ||||||
|  |             [(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)] | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self._max_num_phase = max_num_phase |         self._num_sin_phase = num_sin_phase | ||||||
|         self._interval = interval |         self._interval = 1.0 / (float(num) - 1) | ||||||
|  |         self._total_num = num | ||||||
|  |  | ||||||
|  |         self._period_phase_shift = QuadraticFunction() | ||||||
|  |  | ||||||
|  |         fitting_data = [] | ||||||
|  |         temp_max_scalar = 2 ** num_sin_phase | ||||||
|  |         for i in range(num_sin_phase): | ||||||
|  |             value = (2 ** i) / temp_max_scalar | ||||||
|  |             fitting_data.append((value, math.sin(value))) | ||||||
|  |         self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x)) | ||||||
|  |  | ||||||
|         self._times = np.arange(0, np.pi * self._max_num_phase, self._interval) |  | ||||||
|         xmin, xmax = self._times.min(), self._times.max() |  | ||||||
|         self._inputs = [] |  | ||||||
|         self._total_num = len(self._times) |  | ||||||
|         for i in range(self._total_num): |  | ||||||
|             scale = (i + 1.0) / self._total_num * max_scale |  | ||||||
|             sin_scale = (i + 1.0) / self._total_num * 0.7 |  | ||||||
|             sin_scale = -4 * (sin_scale - 0.5) ** 2 + 1 |  | ||||||
|             # scale = -(self._times[i] - (xmin - xmax) / 2) + max_scale |  | ||||||
|             self._inputs.append( |  | ||||||
|                 np.sin(self._times[i] * sin_scale) * (offset_scale - scale) |  | ||||||
|             ) |  | ||||||
|         self._inputs = np.array(self._inputs) |  | ||||||
|         # Training Set 60% |         # Training Set 60% | ||||||
|         num_of_train = int(self._total_num * 0.6) |         num_of_train = int(self._total_num * 0.6) | ||||||
|         # Validation Set 20% |         # Validation Set 20% | ||||||
| @@ -70,10 +154,11 @@ class SynAdaptiveEnv(data.Dataset): | |||||||
|     def __getitem__(self, index): |     def __getitem__(self, index): | ||||||
|         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) |         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self)) | ||||||
|         index = self._indexes[index] |         index = self._indexes[index] | ||||||
|         value = float(self._inputs[index]) |         position = self._interval * index | ||||||
|         if self._transform is not None: |         value = self._amplitude_scale[position] * math.sin( | ||||||
|             value = self._transform(value) |             self._period_phase_shift[position] | ||||||
|         return index, float(self._times[index]), value |         ) | ||||||
|  |         return index, position, value | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._indexes) |         return len(self._indexes) | ||||||
|   | |||||||
| @@ -5,16 +5,20 @@ import os | |||||||
|  |  | ||||||
|  |  | ||||||
| def test_imagenet_data(imagenet): | def test_imagenet_data(imagenet): | ||||||
|   total_length = len(imagenet) |     total_length = len(imagenet) | ||||||
|   assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) |     assert ( | ||||||
|   map_id = {} |         total_length == 1281166 or total_length == 50000 | ||||||
|   for index in range(total_length): |     ), "The length of ImageNet is wrong : {}".format(total_length) | ||||||
|     path, target = imagenet.imgs[index] |     map_id = {} | ||||||
|     folder, image_name = os.path.split(path) |     for index in range(total_length): | ||||||
|     _, folder = os.path.split(folder) |         path, target = imagenet.imgs[index] | ||||||
|     if folder not in map_id: |         folder, image_name = os.path.split(path) | ||||||
|       map_id[folder] = target |         _, folder = os.path.split(folder) | ||||||
|     else: |         if folder not in map_id: | ||||||
|       assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) |             map_id[folder] = target | ||||||
|     assert image_name.find(folder) == 0, '{} is wrong.'.format(path) |         else: | ||||||
|   print ('Check ImageNet Dataset OK') |             assert map_id[folder] == target, "Class : {} is not {}".format( | ||||||
|  |                 folder, target | ||||||
|  |             ) | ||||||
|  |         assert image_name.find(folder) == 0, "{} is wrong.".format(path) | ||||||
|  |     print("Check ImageNet Dataset OK") | ||||||
|   | |||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -13,9 +13,33 @@ print("library path: {:}".format(lib_dir)) | |||||||
| if str(lib_dir) not in sys.path: | if str(lib_dir) not in sys.path: | ||||||
|     sys.path.insert(0, str(lib_dir)) |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | from datasets import QuadraticFunction | ||||||
| from datasets import SynAdaptiveEnv | from datasets import SynAdaptiveEnv | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class TestQuadraticFunction(unittest.TestCase): | ||||||
|  |     """Test the quadratic function.""" | ||||||
|  |  | ||||||
|  |     def test_simple(self): | ||||||
|  |         function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]]) | ||||||
|  |         print(function) | ||||||
|  |         for x in (0, 0.5, 1): | ||||||
|  |             print("f({:})={:}".format(x, function[x])) | ||||||
|  |         thresh = 0.2 | ||||||
|  |         self.assertTrue(abs(function[0] - 1) < thresh) | ||||||
|  |         self.assertTrue(abs(function[0.5] - 4) < thresh) | ||||||
|  |         self.assertTrue(abs(function[1] - 1) < thresh) | ||||||
|  |  | ||||||
|  |     def test_none(self): | ||||||
|  |         function = QuadraticFunction() | ||||||
|  |         function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True) | ||||||
|  |         print(function) | ||||||
|  |         thresh = 0.2 | ||||||
|  |         self.assertTrue(abs(function[0] - 1) < thresh) | ||||||
|  |         self.assertTrue(abs(function[0.5] - 4) < thresh) | ||||||
|  |         self.assertTrue(abs(function[1] - 1) < thresh) | ||||||
|  |  | ||||||
|  |  | ||||||
| class TestSynAdaptiveEnv(unittest.TestCase): | class TestSynAdaptiveEnv(unittest.TestCase): | ||||||
|     """Test the synethtic adaptive environment.""" |     """Test the synethtic adaptive environment.""" | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user