Updates
This commit is contained in:
		
							
								
								
									
										8
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										8
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							@@ -48,5 +48,13 @@ jobs:
 | 
			
		||||
          ls
 | 
			
		||||
          python --version
 | 
			
		||||
          python -m pytest ./tests/test_basic_space.py -s
 | 
			
		||||
        shell: bash
 | 
			
		||||
 | 
			
		||||
      - name: Test Synthetic Data
 | 
			
		||||
        run: |
 | 
			
		||||
          python -m pip install pytest numpy
 | 
			
		||||
          python -m pip install parameterized
 | 
			
		||||
          python -m pip install torch
 | 
			
		||||
          python --version
 | 
			
		||||
          python -m pytest ./tests/test_synthetic.py -s
 | 
			
		||||
        shell: bash
 | 
			
		||||
 
 | 
			
		||||
 Submodule .latent-data/NATS-Bench updated: 3a8794322f...33bfb2eb13
									
								
							@@ -5,118 +5,133 @@ import os, sys, hashlib, torch
 | 
			
		||||
import numpy as np
 | 
			
		||||
from PIL import Image
 | 
			
		||||
import torch.utils.data as data
 | 
			
		||||
 | 
			
		||||
if sys.version_info[0] == 2:
 | 
			
		||||
  import cPickle as pickle
 | 
			
		||||
    import cPickle as pickle
 | 
			
		||||
else:
 | 
			
		||||
  import pickle
 | 
			
		||||
    import pickle
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def calculate_md5(fpath, chunk_size=1024 * 1024):
 | 
			
		||||
  md5 = hashlib.md5()
 | 
			
		||||
  with open(fpath, 'rb') as f:
 | 
			
		||||
    for chunk in iter(lambda: f.read(chunk_size), b''):
 | 
			
		||||
      md5.update(chunk)
 | 
			
		||||
  return md5.hexdigest()
 | 
			
		||||
    md5 = hashlib.md5()
 | 
			
		||||
    with open(fpath, "rb") as f:
 | 
			
		||||
        for chunk in iter(lambda: f.read(chunk_size), b""):
 | 
			
		||||
            md5.update(chunk)
 | 
			
		||||
    return md5.hexdigest()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_md5(fpath, md5, **kwargs):
 | 
			
		||||
  return md5 == calculate_md5(fpath, **kwargs)
 | 
			
		||||
    return md5 == calculate_md5(fpath, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def check_integrity(fpath, md5=None):
 | 
			
		||||
  if not os.path.isfile(fpath): return False
 | 
			
		||||
  if md5 is None: return True
 | 
			
		||||
  else          : return check_md5(fpath, md5)
 | 
			
		||||
    if not os.path.isfile(fpath):
 | 
			
		||||
        return False
 | 
			
		||||
    if md5 is None:
 | 
			
		||||
        return True
 | 
			
		||||
    else:
 | 
			
		||||
        return check_md5(fpath, md5)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ImageNet16(data.Dataset):
 | 
			
		||||
  # http://image-net.org/download-images
 | 
			
		||||
  # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
 | 
			
		||||
  # https://arxiv.org/pdf/1707.08819.pdf
 | 
			
		||||
  
 | 
			
		||||
  train_list = [
 | 
			
		||||
        ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
 | 
			
		||||
        ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
 | 
			
		||||
        ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
 | 
			
		||||
        ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
 | 
			
		||||
        ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
 | 
			
		||||
        ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
 | 
			
		||||
        ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
 | 
			
		||||
        ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
 | 
			
		||||
        ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
 | 
			
		||||
        ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
 | 
			
		||||
    # http://image-net.org/download-images
 | 
			
		||||
    # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
 | 
			
		||||
    # https://arxiv.org/pdf/1707.08819.pdf
 | 
			
		||||
 | 
			
		||||
    train_list = [
 | 
			
		||||
        ["train_data_batch_1", "27846dcaa50de8e21a7d1a35f30f0e91"],
 | 
			
		||||
        ["train_data_batch_2", "c7254a054e0e795c69120a5727050e3f"],
 | 
			
		||||
        ["train_data_batch_3", "4333d3df2e5ffb114b05d2ffc19b1e87"],
 | 
			
		||||
        ["train_data_batch_4", "1620cdf193304f4a92677b695d70d10f"],
 | 
			
		||||
        ["train_data_batch_5", "348b3c2fdbb3940c4e9e834affd3b18d"],
 | 
			
		||||
        ["train_data_batch_6", "6e765307c242a1b3d7d5ef9139b48945"],
 | 
			
		||||
        ["train_data_batch_7", "564926d8cbf8fc4818ba23d2faac7564"],
 | 
			
		||||
        ["train_data_batch_8", "f4755871f718ccb653440b9dd0ebac66"],
 | 
			
		||||
        ["train_data_batch_9", "bb6dd660c38c58552125b1a92f86b5d4"],
 | 
			
		||||
        ["train_data_batch_10", "8f03f34ac4b42271a294f91bf480f29b"],
 | 
			
		||||
    ]
 | 
			
		||||
  valid_list = [
 | 
			
		||||
        ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
 | 
			
		||||
    valid_list = [
 | 
			
		||||
        ["val_data", "3410e3017fdaefba8d5073aaa65e4bd6"],
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
  def __init__(self, root, train, transform, use_num_of_class_only=None):
 | 
			
		||||
    self.root      = root
 | 
			
		||||
    self.transform = transform
 | 
			
		||||
    self.train     = train  # training set or valid set
 | 
			
		||||
    if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
 | 
			
		||||
    def __init__(self, root, train, transform, use_num_of_class_only=None):
 | 
			
		||||
        self.root = root
 | 
			
		||||
        self.transform = transform
 | 
			
		||||
        self.train = train  # training set or valid set
 | 
			
		||||
        if not self._check_integrity():
 | 
			
		||||
            raise RuntimeError("Dataset not found or corrupted.")
 | 
			
		||||
 | 
			
		||||
    if self.train: downloaded_list = self.train_list
 | 
			
		||||
    else         : downloaded_list = self.valid_list
 | 
			
		||||
    self.data    = []
 | 
			
		||||
    self.targets = []
 | 
			
		||||
  
 | 
			
		||||
    # now load the picked numpy arrays
 | 
			
		||||
    for i, (file_name, checksum) in enumerate(downloaded_list):
 | 
			
		||||
      file_path = os.path.join(self.root, file_name)
 | 
			
		||||
      #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
 | 
			
		||||
      with open(file_path, 'rb') as f:
 | 
			
		||||
        if sys.version_info[0] == 2:
 | 
			
		||||
          entry = pickle.load(f)
 | 
			
		||||
        if self.train:
 | 
			
		||||
            downloaded_list = self.train_list
 | 
			
		||||
        else:
 | 
			
		||||
          entry = pickle.load(f, encoding='latin1')
 | 
			
		||||
        self.data.append(entry['data'])
 | 
			
		||||
        self.targets.extend(entry['labels'])
 | 
			
		||||
    self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
 | 
			
		||||
    self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
 | 
			
		||||
    if use_num_of_class_only is not None:
 | 
			
		||||
      assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
 | 
			
		||||
      new_data, new_targets = [], []
 | 
			
		||||
      for I, L in zip(self.data, self.targets):
 | 
			
		||||
        if 1 <= L <= use_num_of_class_only:
 | 
			
		||||
          new_data.append( I )
 | 
			
		||||
          new_targets.append( L )
 | 
			
		||||
      self.data    = new_data
 | 
			
		||||
      self.targets = new_targets
 | 
			
		||||
    #    self.mean.append(entry['mean'])
 | 
			
		||||
    #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
 | 
			
		||||
    #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
 | 
			
		||||
    #print ('Mean : {:}'.format(self.mean))
 | 
			
		||||
    #temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3))
 | 
			
		||||
    #std_data  = np.std(temp, axis=0)
 | 
			
		||||
    #std_data  = np.mean(np.mean(std_data, axis=0), axis=0)
 | 
			
		||||
    #print ('Std  : {:}'.format(std_data))
 | 
			
		||||
            downloaded_list = self.valid_list
 | 
			
		||||
        self.data = []
 | 
			
		||||
        self.targets = []
 | 
			
		||||
 | 
			
		||||
  def __repr__(self):
 | 
			
		||||
    return ('{name}({num} images, {classes} classes)'.format(name=self.__class__.__name__, num=len(self.data), classes=len(set(self.targets))))
 | 
			
		||||
        # now load the picked numpy arrays
 | 
			
		||||
        for i, (file_name, checksum) in enumerate(downloaded_list):
 | 
			
		||||
            file_path = os.path.join(self.root, file_name)
 | 
			
		||||
            # print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
 | 
			
		||||
            with open(file_path, "rb") as f:
 | 
			
		||||
                if sys.version_info[0] == 2:
 | 
			
		||||
                    entry = pickle.load(f)
 | 
			
		||||
                else:
 | 
			
		||||
                    entry = pickle.load(f, encoding="latin1")
 | 
			
		||||
                self.data.append(entry["data"])
 | 
			
		||||
                self.targets.extend(entry["labels"])
 | 
			
		||||
        self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
 | 
			
		||||
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
 | 
			
		||||
        if use_num_of_class_only is not None:
 | 
			
		||||
            assert (
 | 
			
		||||
                isinstance(use_num_of_class_only, int)
 | 
			
		||||
                and use_num_of_class_only > 0
 | 
			
		||||
                and use_num_of_class_only < 1000
 | 
			
		||||
            ), "invalid use_num_of_class_only : {:}".format(use_num_of_class_only)
 | 
			
		||||
            new_data, new_targets = [], []
 | 
			
		||||
            for I, L in zip(self.data, self.targets):
 | 
			
		||||
                if 1 <= L <= use_num_of_class_only:
 | 
			
		||||
                    new_data.append(I)
 | 
			
		||||
                    new_targets.append(L)
 | 
			
		||||
            self.data = new_data
 | 
			
		||||
            self.targets = new_targets
 | 
			
		||||
        #    self.mean.append(entry['mean'])
 | 
			
		||||
        # self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
 | 
			
		||||
        # self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
 | 
			
		||||
        # print ('Mean : {:}'.format(self.mean))
 | 
			
		||||
        # temp      = self.data - np.reshape(self.mean, (1, 1, 1, 3))
 | 
			
		||||
        # std_data  = np.std(temp, axis=0)
 | 
			
		||||
        # std_data  = np.mean(np.mean(std_data, axis=0), axis=0)
 | 
			
		||||
        # print ('Std  : {:}'.format(std_data))
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "{name}({num} images, {classes} classes)".format(
 | 
			
		||||
            name=self.__class__.__name__,
 | 
			
		||||
            num=len(self.data),
 | 
			
		||||
            classes=len(set(self.targets)),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
  def __getitem__(self, index):
 | 
			
		||||
    img, target = self.data[index], self.targets[index] - 1
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        img, target = self.data[index], self.targets[index] - 1
 | 
			
		||||
 | 
			
		||||
    img = Image.fromarray(img)
 | 
			
		||||
        img = Image.fromarray(img)
 | 
			
		||||
 | 
			
		||||
    if self.transform is not None:
 | 
			
		||||
      img = self.transform(img)
 | 
			
		||||
        if self.transform is not None:
 | 
			
		||||
            img = self.transform(img)
 | 
			
		||||
 | 
			
		||||
    return img, target
 | 
			
		||||
        return img, target
 | 
			
		||||
 | 
			
		||||
  def __len__(self):
 | 
			
		||||
    return len(self.data)
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self.data)
 | 
			
		||||
 | 
			
		||||
    def _check_integrity(self):
 | 
			
		||||
        root = self.root
 | 
			
		||||
        for fentry in self.train_list + self.valid_list:
 | 
			
		||||
            filename, md5 = fentry[0], fentry[1]
 | 
			
		||||
            fpath = os.path.join(root, filename)
 | 
			
		||||
            if not check_integrity(fpath, md5):
 | 
			
		||||
                return False
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
  def _check_integrity(self):
 | 
			
		||||
    root = self.root
 | 
			
		||||
    for fentry in (self.train_list + self.valid_list):
 | 
			
		||||
      filename, md5 = fentry[0], fentry[1]
 | 
			
		||||
      fpath = os.path.join(root, filename)
 | 
			
		||||
      if not check_integrity(fpath, md5):
 | 
			
		||||
        return False
 | 
			
		||||
    return True
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
 
 | 
			
		||||
@@ -20,172 +20,282 @@ import torch.utils.data as data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LandmarkDataset(data.Dataset):
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        transform,
 | 
			
		||||
        sigma,
 | 
			
		||||
        downsample,
 | 
			
		||||
        heatmap_type,
 | 
			
		||||
        shape,
 | 
			
		||||
        use_gray,
 | 
			
		||||
        mean_file,
 | 
			
		||||
        data_indicator,
 | 
			
		||||
        cache_images=None,
 | 
			
		||||
    ):
 | 
			
		||||
 | 
			
		||||
  def __init__(self, transform, sigma, downsample, heatmap_type, shape, use_gray, mean_file, data_indicator, cache_images=None):
 | 
			
		||||
 | 
			
		||||
    self.transform    = transform
 | 
			
		||||
    self.sigma        = sigma
 | 
			
		||||
    self.downsample   = downsample
 | 
			
		||||
    self.heatmap_type = heatmap_type
 | 
			
		||||
    self.dataset_name = data_indicator
 | 
			
		||||
    self.shape        = shape # [H,W]
 | 
			
		||||
    self.use_gray     = use_gray
 | 
			
		||||
    assert transform is not None, 'transform : {:}'.format(transform)
 | 
			
		||||
    self.mean_file    = mean_file
 | 
			
		||||
    if mean_file is None:
 | 
			
		||||
      self.mean_data  = None
 | 
			
		||||
      warnings.warn('LandmarkDataset initialized with mean_data = None')
 | 
			
		||||
    else:
 | 
			
		||||
      assert osp.isfile(mean_file), '{:} is not a file.'.format(mean_file)
 | 
			
		||||
      self.mean_data  = torch.load(mean_file)
 | 
			
		||||
    self.reset()
 | 
			
		||||
    self.cutout       = None
 | 
			
		||||
    self.cache_images = cache_images
 | 
			
		||||
    print ('The general dataset initialization done : {:}'.format(self))
 | 
			
		||||
    warnings.simplefilter( 'once' )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def __repr__(self):
 | 
			
		||||
    return ('{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})'.format(name=self.__class__.__name__, **self.__dict__))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def set_cutout(self, length):
 | 
			
		||||
    if length is not None and length >= 1:
 | 
			
		||||
      self.cutout = CutOut( int(length) )
 | 
			
		||||
    else: self.cutout = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def reset(self, num_pts=-1, boxid='default', only_pts=False):
 | 
			
		||||
    self.NUM_PTS = num_pts
 | 
			
		||||
    if only_pts: return
 | 
			
		||||
    self.length  = 0
 | 
			
		||||
    self.datas   = []
 | 
			
		||||
    self.labels  = []
 | 
			
		||||
    self.NormDistances = []
 | 
			
		||||
    self.BOXID = boxid
 | 
			
		||||
    if self.mean_data is None:
 | 
			
		||||
      self.mean_face = None
 | 
			
		||||
    else:
 | 
			
		||||
      self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
 | 
			
		||||
      assert (self.mean_face >= -1).all() and (self.mean_face <= 1).all(), 'mean-{:}-face : {:}'.format(boxid, self.mean_face)
 | 
			
		||||
    #assert self.dataset_name is not None, 'The dataset name is None'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def __len__(self):
 | 
			
		||||
    assert len(self.datas) == self.length, 'The length is not correct : {}'.format(self.length)
 | 
			
		||||
    return self.length
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def append(self, data, label, distance):
 | 
			
		||||
    assert osp.isfile(data), 'The image path is not a file : {:}'.format(data)
 | 
			
		||||
    self.datas.append( data )             ;  self.labels.append( label )
 | 
			
		||||
    self.NormDistances.append( distance )
 | 
			
		||||
    self.length = self.length + 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
 | 
			
		||||
    if reset: self.reset(num_pts, boxindicator)
 | 
			
		||||
    else    : assert self.NUM_PTS == num_pts and self.BOXID == boxindicator, 'The number of point is inconsistance : {:} vs {:}'.format(self.NUM_PTS, num_pts)
 | 
			
		||||
    if isinstance(file_lists, str): file_lists = [file_lists]
 | 
			
		||||
    samples = []
 | 
			
		||||
    for idx, file_path in enumerate(file_lists):
 | 
			
		||||
      print (':::: load list {:}/{:} : {:}'.format(idx, len(file_lists), file_path))
 | 
			
		||||
      xdata = torch.load(file_path)
 | 
			
		||||
      if isinstance(xdata, list)  : data = xdata          # image or video dataset list
 | 
			
		||||
      elif isinstance(xdata, dict): data = xdata['datas'] # multi-view dataset list
 | 
			
		||||
      else: raise ValueError('Invalid Type Error : {:}'.format( type(xdata) ))
 | 
			
		||||
      samples = samples + data
 | 
			
		||||
    # samples is a dict, where the key is the image-path and the value is the annotation
 | 
			
		||||
    # each annotation is a dict, contains 'points' (3,num_pts), and various box
 | 
			
		||||
    print ('GeneralDataset-V2 : {:} samples'.format(len(samples)))
 | 
			
		||||
 | 
			
		||||
    #for index, annotation in enumerate(samples):
 | 
			
		||||
    for index in tqdm( range( len(samples) ) ):
 | 
			
		||||
      annotation = samples[index]
 | 
			
		||||
      image_path  = annotation['current_frame']
 | 
			
		||||
      points, box = annotation['points'], annotation['box-{:}'.format(boxindicator)]
 | 
			
		||||
      label = PointMeta2V(self.NUM_PTS, points, box, image_path, self.dataset_name)
 | 
			
		||||
      if normalizeL is None: normDistance = None
 | 
			
		||||
      else                 : normDistance = annotation['normalizeL-{:}'.format(normalizeL)]
 | 
			
		||||
      self.append(image_path, label, normDistance)
 | 
			
		||||
 | 
			
		||||
    assert len(self.datas) == self.length, 'The length and the data is not right {} vs {}'.format(self.length, len(self.datas))
 | 
			
		||||
    assert len(self.labels) == self.length, 'The length and the labels is not right {} vs {}'.format(self.length, len(self.labels))
 | 
			
		||||
    assert len(self.NormDistances) == self.length, 'The length and the NormDistances is not right {} vs {}'.format(self.length, len(self.NormDistance))
 | 
			
		||||
    print ('Load data done for LandmarkDataset, which has {:} images.'.format(self.length))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def __getitem__(self, index):
 | 
			
		||||
    assert index >= 0 and index < self.length, 'Invalid index : {:}'.format(index)
 | 
			
		||||
    if self.cache_images is not None and self.datas[index] in self.cache_images:
 | 
			
		||||
      image = self.cache_images[ self.datas[index] ].clone()
 | 
			
		||||
    else:
 | 
			
		||||
      image = pil_loader(self.datas[index], self.use_gray)
 | 
			
		||||
    target = self.labels[index].copy()
 | 
			
		||||
    return self._process_(image, target, index)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
  def _process_(self, image, target, index):
 | 
			
		||||
 | 
			
		||||
    # transform the image and points
 | 
			
		||||
    image, target, theta = self.transform(image, target)
 | 
			
		||||
    (C, H, W), (height, width) = image.size(), self.shape
 | 
			
		||||
 | 
			
		||||
    # obtain the visiable indicator vector
 | 
			
		||||
    if target.is_none(): nopoints = True
 | 
			
		||||
    else               : nopoints = False
 | 
			
		||||
    if index == -1: __path = None
 | 
			
		||||
    else          : __path = self.datas[index]
 | 
			
		||||
    if isinstance(theta, list) or isinstance(theta, tuple):
 | 
			
		||||
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = [], [], [], [], [], []
 | 
			
		||||
      for _theta in theta:
 | 
			
		||||
        _affineImage, _heatmaps, _mask, _norm_trans_points, _theta, _transpose_theta \
 | 
			
		||||
          = self.__process_affine(image, target, _theta, nopoints, 'P[{:}]@{:}'.format(index, __path))
 | 
			
		||||
        affineImage.append(_affineImage)
 | 
			
		||||
        heatmaps.append(_heatmaps)
 | 
			
		||||
        mask.append(_mask)
 | 
			
		||||
        norm_trans_points.append(_norm_trans_points)
 | 
			
		||||
        THETA.append(_theta)
 | 
			
		||||
        transpose_theta.append(_transpose_theta)
 | 
			
		||||
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = \
 | 
			
		||||
          torch.stack(affineImage), torch.stack(heatmaps), torch.stack(mask), torch.stack(norm_trans_points), torch.stack(THETA), torch.stack(transpose_theta)
 | 
			
		||||
    else:
 | 
			
		||||
      affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = self.__process_affine(image, target, theta, nopoints, 'S[{:}]@{:}'.format(index, __path))
 | 
			
		||||
 | 
			
		||||
    torch_index = torch.IntTensor([index])
 | 
			
		||||
    torch_nopoints = torch.ByteTensor( [ nopoints ] )
 | 
			
		||||
    torch_shape = torch.IntTensor([H,W])
 | 
			
		||||
 | 
			
		||||
    return affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta, torch_index, torch_nopoints, torch_shape
 | 
			
		||||
 | 
			
		||||
  
 | 
			
		||||
  def __process_affine(self, image, target, theta, nopoints, aux_info=None):
 | 
			
		||||
    image, target, theta = image.clone(), target.copy(), theta.clone()
 | 
			
		||||
    (C, H, W), (height, width) = image.size(), self.shape
 | 
			
		||||
    if nopoints: # do not have label
 | 
			
		||||
      norm_trans_points = torch.zeros((3, self.NUM_PTS))
 | 
			
		||||
      heatmaps          = torch.zeros((self.NUM_PTS+1, height//self.downsample, width//self.downsample))
 | 
			
		||||
      mask              = torch.ones((self.NUM_PTS+1, 1, 1), dtype=torch.uint8)
 | 
			
		||||
      transpose_theta   = identity2affine(False)
 | 
			
		||||
    else:
 | 
			
		||||
      norm_trans_points = apply_affine2point(target.get_points(), theta, (H,W))
 | 
			
		||||
      norm_trans_points = apply_boundary(norm_trans_points)
 | 
			
		||||
      real_trans_points = norm_trans_points.clone()
 | 
			
		||||
      real_trans_points[:2, :] = denormalize_points(self.shape, real_trans_points[:2,:])
 | 
			
		||||
      heatmaps, mask = generate_label_map(real_trans_points.numpy(), height//self.downsample, width//self.downsample, self.sigma, self.downsample, nopoints, self.heatmap_type) # H*W*C
 | 
			
		||||
      heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(torch.FloatTensor)
 | 
			
		||||
      mask     = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
 | 
			
		||||
      if self.mean_face is None:
 | 
			
		||||
        #warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
 | 
			
		||||
        transpose_theta = identity2affine(False)
 | 
			
		||||
      else:
 | 
			
		||||
        if torch.sum(norm_trans_points[2,:] == 1) < 3:
 | 
			
		||||
          warnings.warn('In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}'.format(aux_info))
 | 
			
		||||
          transpose_theta = identity2affine(False)
 | 
			
		||||
        self.transform = transform
 | 
			
		||||
        self.sigma = sigma
 | 
			
		||||
        self.downsample = downsample
 | 
			
		||||
        self.heatmap_type = heatmap_type
 | 
			
		||||
        self.dataset_name = data_indicator
 | 
			
		||||
        self.shape = shape  # [H,W]
 | 
			
		||||
        self.use_gray = use_gray
 | 
			
		||||
        assert transform is not None, "transform : {:}".format(transform)
 | 
			
		||||
        self.mean_file = mean_file
 | 
			
		||||
        if mean_file is None:
 | 
			
		||||
            self.mean_data = None
 | 
			
		||||
            warnings.warn("LandmarkDataset initialized with mean_data = None")
 | 
			
		||||
        else:
 | 
			
		||||
          transpose_theta = solve2theta(norm_trans_points, self.mean_face.clone())
 | 
			
		||||
            assert osp.isfile(mean_file), "{:} is not a file.".format(mean_file)
 | 
			
		||||
            self.mean_data = torch.load(mean_file)
 | 
			
		||||
        self.reset()
 | 
			
		||||
        self.cutout = None
 | 
			
		||||
        self.cache_images = cache_images
 | 
			
		||||
        print("The general dataset initialization done : {:}".format(self))
 | 
			
		||||
        warnings.simplefilter("once")
 | 
			
		||||
 | 
			
		||||
    affineImage = affine2image(image, theta, self.shape)
 | 
			
		||||
    if self.cutout is not None: affineImage = self.cutout( affineImage )
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "{name}(point-num={NUM_PTS}, shape={shape}, sigma={sigma}, heatmap_type={heatmap_type}, length={length}, cutout={cutout}, dataset={dataset_name}, mean={mean_file})".format(
 | 
			
		||||
            name=self.__class__.__name__, **self.__dict__
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
 | 
			
		||||
    def set_cutout(self, length):
 | 
			
		||||
        if length is not None and length >= 1:
 | 
			
		||||
            self.cutout = CutOut(int(length))
 | 
			
		||||
        else:
 | 
			
		||||
            self.cutout = None
 | 
			
		||||
 | 
			
		||||
    def reset(self, num_pts=-1, boxid="default", only_pts=False):
 | 
			
		||||
        self.NUM_PTS = num_pts
 | 
			
		||||
        if only_pts:
 | 
			
		||||
            return
 | 
			
		||||
        self.length = 0
 | 
			
		||||
        self.datas = []
 | 
			
		||||
        self.labels = []
 | 
			
		||||
        self.NormDistances = []
 | 
			
		||||
        self.BOXID = boxid
 | 
			
		||||
        if self.mean_data is None:
 | 
			
		||||
            self.mean_face = None
 | 
			
		||||
        else:
 | 
			
		||||
            self.mean_face = torch.Tensor(self.mean_data[boxid].copy().T)
 | 
			
		||||
            assert (self.mean_face >= -1).all() and (
 | 
			
		||||
                self.mean_face <= 1
 | 
			
		||||
            ).all(), "mean-{:}-face : {:}".format(boxid, self.mean_face)
 | 
			
		||||
        # assert self.dataset_name is not None, 'The dataset name is None'
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        assert len(self.datas) == self.length, "The length is not correct : {}".format(
 | 
			
		||||
            self.length
 | 
			
		||||
        )
 | 
			
		||||
        return self.length
 | 
			
		||||
 | 
			
		||||
    def append(self, data, label, distance):
 | 
			
		||||
        assert osp.isfile(data), "The image path is not a file : {:}".format(data)
 | 
			
		||||
        self.datas.append(data)
 | 
			
		||||
        self.labels.append(label)
 | 
			
		||||
        self.NormDistances.append(distance)
 | 
			
		||||
        self.length = self.length + 1
 | 
			
		||||
 | 
			
		||||
    def load_list(self, file_lists, num_pts, boxindicator, normalizeL, reset):
 | 
			
		||||
        if reset:
 | 
			
		||||
            self.reset(num_pts, boxindicator)
 | 
			
		||||
        else:
 | 
			
		||||
            assert (
 | 
			
		||||
                self.NUM_PTS == num_pts and self.BOXID == boxindicator
 | 
			
		||||
            ), "The number of point is inconsistance : {:} vs {:}".format(
 | 
			
		||||
                self.NUM_PTS, num_pts
 | 
			
		||||
            )
 | 
			
		||||
        if isinstance(file_lists, str):
 | 
			
		||||
            file_lists = [file_lists]
 | 
			
		||||
        samples = []
 | 
			
		||||
        for idx, file_path in enumerate(file_lists):
 | 
			
		||||
            print(
 | 
			
		||||
                ":::: load list {:}/{:} : {:}".format(idx, len(file_lists), file_path)
 | 
			
		||||
            )
 | 
			
		||||
            xdata = torch.load(file_path)
 | 
			
		||||
            if isinstance(xdata, list):
 | 
			
		||||
                data = xdata  # image or video dataset list
 | 
			
		||||
            elif isinstance(xdata, dict):
 | 
			
		||||
                data = xdata["datas"]  # multi-view dataset list
 | 
			
		||||
            else:
 | 
			
		||||
                raise ValueError("Invalid Type Error : {:}".format(type(xdata)))
 | 
			
		||||
            samples = samples + data
 | 
			
		||||
        # samples is a dict, where the key is the image-path and the value is the annotation
 | 
			
		||||
        # each annotation is a dict, contains 'points' (3,num_pts), and various box
 | 
			
		||||
        print("GeneralDataset-V2 : {:} samples".format(len(samples)))
 | 
			
		||||
 | 
			
		||||
        # for index, annotation in enumerate(samples):
 | 
			
		||||
        for index in tqdm(range(len(samples))):
 | 
			
		||||
            annotation = samples[index]
 | 
			
		||||
            image_path = annotation["current_frame"]
 | 
			
		||||
            points, box = (
 | 
			
		||||
                annotation["points"],
 | 
			
		||||
                annotation["box-{:}".format(boxindicator)],
 | 
			
		||||
            )
 | 
			
		||||
            label = PointMeta2V(
 | 
			
		||||
                self.NUM_PTS, points, box, image_path, self.dataset_name
 | 
			
		||||
            )
 | 
			
		||||
            if normalizeL is None:
 | 
			
		||||
                normDistance = None
 | 
			
		||||
            else:
 | 
			
		||||
                normDistance = annotation["normalizeL-{:}".format(normalizeL)]
 | 
			
		||||
            self.append(image_path, label, normDistance)
 | 
			
		||||
 | 
			
		||||
        assert (
 | 
			
		||||
            len(self.datas) == self.length
 | 
			
		||||
        ), "The length and the data is not right {} vs {}".format(
 | 
			
		||||
            self.length, len(self.datas)
 | 
			
		||||
        )
 | 
			
		||||
        assert (
 | 
			
		||||
            len(self.labels) == self.length
 | 
			
		||||
        ), "The length and the labels is not right {} vs {}".format(
 | 
			
		||||
            self.length, len(self.labels)
 | 
			
		||||
        )
 | 
			
		||||
        assert (
 | 
			
		||||
            len(self.NormDistances) == self.length
 | 
			
		||||
        ), "The length and the NormDistances is not right {} vs {}".format(
 | 
			
		||||
            self.length, len(self.NormDistance)
 | 
			
		||||
        )
 | 
			
		||||
        print(
 | 
			
		||||
            "Load data done for LandmarkDataset, which has {:} images.".format(
 | 
			
		||||
                self.length
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        assert index >= 0 and index < self.length, "Invalid index : {:}".format(index)
 | 
			
		||||
        if self.cache_images is not None and self.datas[index] in self.cache_images:
 | 
			
		||||
            image = self.cache_images[self.datas[index]].clone()
 | 
			
		||||
        else:
 | 
			
		||||
            image = pil_loader(self.datas[index], self.use_gray)
 | 
			
		||||
        target = self.labels[index].copy()
 | 
			
		||||
        return self._process_(image, target, index)
 | 
			
		||||
 | 
			
		||||
    def _process_(self, image, target, index):
 | 
			
		||||
 | 
			
		||||
        # transform the image and points
 | 
			
		||||
        image, target, theta = self.transform(image, target)
 | 
			
		||||
        (C, H, W), (height, width) = image.size(), self.shape
 | 
			
		||||
 | 
			
		||||
        # obtain the visiable indicator vector
 | 
			
		||||
        if target.is_none():
 | 
			
		||||
            nopoints = True
 | 
			
		||||
        else:
 | 
			
		||||
            nopoints = False
 | 
			
		||||
        if index == -1:
 | 
			
		||||
            __path = None
 | 
			
		||||
        else:
 | 
			
		||||
            __path = self.datas[index]
 | 
			
		||||
        if isinstance(theta, list) or isinstance(theta, tuple):
 | 
			
		||||
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
 | 
			
		||||
                [],
 | 
			
		||||
                [],
 | 
			
		||||
                [],
 | 
			
		||||
                [],
 | 
			
		||||
                [],
 | 
			
		||||
                [],
 | 
			
		||||
            )
 | 
			
		||||
            for _theta in theta:
 | 
			
		||||
                (
 | 
			
		||||
                    _affineImage,
 | 
			
		||||
                    _heatmaps,
 | 
			
		||||
                    _mask,
 | 
			
		||||
                    _norm_trans_points,
 | 
			
		||||
                    _theta,
 | 
			
		||||
                    _transpose_theta,
 | 
			
		||||
                ) = self.__process_affine(
 | 
			
		||||
                    image, target, _theta, nopoints, "P[{:}]@{:}".format(index, __path)
 | 
			
		||||
                )
 | 
			
		||||
                affineImage.append(_affineImage)
 | 
			
		||||
                heatmaps.append(_heatmaps)
 | 
			
		||||
                mask.append(_mask)
 | 
			
		||||
                norm_trans_points.append(_norm_trans_points)
 | 
			
		||||
                THETA.append(_theta)
 | 
			
		||||
                transpose_theta.append(_transpose_theta)
 | 
			
		||||
            affineImage, heatmaps, mask, norm_trans_points, THETA, transpose_theta = (
 | 
			
		||||
                torch.stack(affineImage),
 | 
			
		||||
                torch.stack(heatmaps),
 | 
			
		||||
                torch.stack(mask),
 | 
			
		||||
                torch.stack(norm_trans_points),
 | 
			
		||||
                torch.stack(THETA),
 | 
			
		||||
                torch.stack(transpose_theta),
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            (
 | 
			
		||||
                affineImage,
 | 
			
		||||
                heatmaps,
 | 
			
		||||
                mask,
 | 
			
		||||
                norm_trans_points,
 | 
			
		||||
                THETA,
 | 
			
		||||
                transpose_theta,
 | 
			
		||||
            ) = self.__process_affine(
 | 
			
		||||
                image, target, theta, nopoints, "S[{:}]@{:}".format(index, __path)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        torch_index = torch.IntTensor([index])
 | 
			
		||||
        torch_nopoints = torch.ByteTensor([nopoints])
 | 
			
		||||
        torch_shape = torch.IntTensor([H, W])
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            affineImage,
 | 
			
		||||
            heatmaps,
 | 
			
		||||
            mask,
 | 
			
		||||
            norm_trans_points,
 | 
			
		||||
            THETA,
 | 
			
		||||
            transpose_theta,
 | 
			
		||||
            torch_index,
 | 
			
		||||
            torch_nopoints,
 | 
			
		||||
            torch_shape,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def __process_affine(self, image, target, theta, nopoints, aux_info=None):
 | 
			
		||||
        image, target, theta = image.clone(), target.copy(), theta.clone()
 | 
			
		||||
        (C, H, W), (height, width) = image.size(), self.shape
 | 
			
		||||
        if nopoints:  # do not have label
 | 
			
		||||
            norm_trans_points = torch.zeros((3, self.NUM_PTS))
 | 
			
		||||
            heatmaps = torch.zeros(
 | 
			
		||||
                (self.NUM_PTS + 1, height // self.downsample, width // self.downsample)
 | 
			
		||||
            )
 | 
			
		||||
            mask = torch.ones((self.NUM_PTS + 1, 1, 1), dtype=torch.uint8)
 | 
			
		||||
            transpose_theta = identity2affine(False)
 | 
			
		||||
        else:
 | 
			
		||||
            norm_trans_points = apply_affine2point(target.get_points(), theta, (H, W))
 | 
			
		||||
            norm_trans_points = apply_boundary(norm_trans_points)
 | 
			
		||||
            real_trans_points = norm_trans_points.clone()
 | 
			
		||||
            real_trans_points[:2, :] = denormalize_points(
 | 
			
		||||
                self.shape, real_trans_points[:2, :]
 | 
			
		||||
            )
 | 
			
		||||
            heatmaps, mask = generate_label_map(
 | 
			
		||||
                real_trans_points.numpy(),
 | 
			
		||||
                height // self.downsample,
 | 
			
		||||
                width // self.downsample,
 | 
			
		||||
                self.sigma,
 | 
			
		||||
                self.downsample,
 | 
			
		||||
                nopoints,
 | 
			
		||||
                self.heatmap_type,
 | 
			
		||||
            )  # H*W*C
 | 
			
		||||
            heatmaps = torch.from_numpy(heatmaps.transpose((2, 0, 1))).type(
 | 
			
		||||
                torch.FloatTensor
 | 
			
		||||
            )
 | 
			
		||||
            mask = torch.from_numpy(mask.transpose((2, 0, 1))).type(torch.ByteTensor)
 | 
			
		||||
            if self.mean_face is None:
 | 
			
		||||
                # warnings.warn('In LandmarkDataset use identity2affine for transpose_theta because self.mean_face is None.')
 | 
			
		||||
                transpose_theta = identity2affine(False)
 | 
			
		||||
            else:
 | 
			
		||||
                if torch.sum(norm_trans_points[2, :] == 1) < 3:
 | 
			
		||||
                    warnings.warn(
 | 
			
		||||
                        "In LandmarkDataset after transformation, no visiable point, using identity instead. Aux: {:}".format(
 | 
			
		||||
                            aux_info
 | 
			
		||||
                        )
 | 
			
		||||
                    )
 | 
			
		||||
                    transpose_theta = identity2affine(False)
 | 
			
		||||
                else:
 | 
			
		||||
                    transpose_theta = solve2theta(
 | 
			
		||||
                        norm_trans_points, self.mean_face.clone()
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
        affineImage = affine2image(image, theta, self.shape)
 | 
			
		||||
        if self.cutout is not None:
 | 
			
		||||
            affineImage = self.cutout(affineImage)
 | 
			
		||||
 | 
			
		||||
        return affineImage, heatmaps, mask, norm_trans_points, theta, transpose_theta
 | 
			
		||||
 
 | 
			
		||||
@@ -6,41 +6,49 @@ import torch.utils.data as data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SearchDataset(data.Dataset):
 | 
			
		||||
    def __init__(self, name, data, train_split, valid_split, check=True):
 | 
			
		||||
        self.datasetname = name
 | 
			
		||||
        if isinstance(data, (list, tuple)):  # new type of SearchDataset
 | 
			
		||||
            assert len(data) == 2, "invalid length: {:}".format(len(data))
 | 
			
		||||
            self.train_data = data[0]
 | 
			
		||||
            self.valid_data = data[1]
 | 
			
		||||
            self.train_split = train_split.copy()
 | 
			
		||||
            self.valid_split = valid_split.copy()
 | 
			
		||||
            self.mode_str = "V2"  # new mode
 | 
			
		||||
        else:
 | 
			
		||||
            self.mode_str = "V1"  # old mode
 | 
			
		||||
            self.data = data
 | 
			
		||||
            self.train_split = train_split.copy()
 | 
			
		||||
            self.valid_split = valid_split.copy()
 | 
			
		||||
            if check:
 | 
			
		||||
                intersection = set(train_split).intersection(set(valid_split))
 | 
			
		||||
                assert (
 | 
			
		||||
                    len(intersection) == 0
 | 
			
		||||
                ), "the splitted train and validation sets should have no intersection"
 | 
			
		||||
        self.length = len(self.train_split)
 | 
			
		||||
 | 
			
		||||
  def __init__(self, name, data, train_split, valid_split, check=True):
 | 
			
		||||
    self.datasetname = name
 | 
			
		||||
    if isinstance(data, (list, tuple)): # new type of SearchDataset
 | 
			
		||||
      assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
 | 
			
		||||
      self.train_data  = data[0]
 | 
			
		||||
      self.valid_data  = data[1]
 | 
			
		||||
      self.train_split = train_split.copy()
 | 
			
		||||
      self.valid_split = valid_split.copy()
 | 
			
		||||
      self.mode_str    = 'V2' # new mode 
 | 
			
		||||
    else:
 | 
			
		||||
      self.mode_str    = 'V1' # old mode 
 | 
			
		||||
      self.data        = data
 | 
			
		||||
      self.train_split = train_split.copy()
 | 
			
		||||
      self.valid_split = valid_split.copy()
 | 
			
		||||
      if check:
 | 
			
		||||
        intersection = set(train_split).intersection(set(valid_split))
 | 
			
		||||
        assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection'
 | 
			
		||||
    self.length      = len(self.train_split)
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})".format(
 | 
			
		||||
            name=self.__class__.__name__,
 | 
			
		||||
            datasetname=self.datasetname,
 | 
			
		||||
            tr_L=len(self.train_split),
 | 
			
		||||
            val_L=len(self.valid_split),
 | 
			
		||||
            ver=self.mode_str,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
  def __repr__(self):
 | 
			
		||||
    return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return self.length
 | 
			
		||||
 | 
			
		||||
  def __len__(self):
 | 
			
		||||
    return self.length
 | 
			
		||||
 | 
			
		||||
  def __getitem__(self, index):
 | 
			
		||||
    assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
 | 
			
		||||
    train_index = self.train_split[index]
 | 
			
		||||
    valid_index = random.choice( self.valid_split )
 | 
			
		||||
    if self.mode_str == 'V1':
 | 
			
		||||
      train_image, train_label = self.data[train_index]
 | 
			
		||||
      valid_image, valid_label = self.data[valid_index]
 | 
			
		||||
    elif self.mode_str == 'V2':
 | 
			
		||||
      train_image, train_label = self.train_data[train_index]
 | 
			
		||||
      valid_image, valid_label = self.valid_data[valid_index]
 | 
			
		||||
    else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
 | 
			
		||||
    return train_image, train_label, valid_image, valid_label
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        assert index >= 0 and index < self.length, "invalid index = {:}".format(index)
 | 
			
		||||
        train_index = self.train_split[index]
 | 
			
		||||
        valid_index = random.choice(self.valid_split)
 | 
			
		||||
        if self.mode_str == "V1":
 | 
			
		||||
            train_image, train_label = self.data[train_index]
 | 
			
		||||
            valid_image, valid_label = self.data[valid_index]
 | 
			
		||||
        elif self.mode_str == "V2":
 | 
			
		||||
            train_image, train_label = self.train_data[train_index]
 | 
			
		||||
            valid_image, valid_label = self.valid_data[valid_index]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("invalid mode : {:}".format(self.mode_str))
 | 
			
		||||
        return train_image, train_label, valid_image, valid_label
 | 
			
		||||
 
 | 
			
		||||
@@ -4,4 +4,5 @@
 | 
			
		||||
from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
 | 
			
		||||
from .SearchDatasetWrap import SearchDataset
 | 
			
		||||
 | 
			
		||||
from .synthetic_adaptive_environment import QuadraticFunction
 | 
			
		||||
from .synthetic_adaptive_environment import SynAdaptiveEnv
 | 
			
		||||
 
 | 
			
		||||
@@ -14,214 +14,349 @@ from .SearchDatasetWrap import SearchDataset
 | 
			
		||||
from config_utils import load_config
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Dataset2Class = {'cifar10' : 10,
 | 
			
		||||
                 'cifar100': 100,
 | 
			
		||||
                 'imagenet-1k-s':1000,
 | 
			
		||||
                 'imagenet-1k' : 1000,
 | 
			
		||||
                 'ImageNet16'  : 1000,
 | 
			
		||||
                 'ImageNet16-150': 150,
 | 
			
		||||
                 'ImageNet16-120': 120,
 | 
			
		||||
                 'ImageNet16-200': 200}
 | 
			
		||||
Dataset2Class = {
 | 
			
		||||
    "cifar10": 10,
 | 
			
		||||
    "cifar100": 100,
 | 
			
		||||
    "imagenet-1k-s": 1000,
 | 
			
		||||
    "imagenet-1k": 1000,
 | 
			
		||||
    "ImageNet16": 1000,
 | 
			
		||||
    "ImageNet16-150": 150,
 | 
			
		||||
    "ImageNet16-120": 120,
 | 
			
		||||
    "ImageNet16-200": 200,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CUTOUT(object):
 | 
			
		||||
    def __init__(self, length):
 | 
			
		||||
        self.length = length
 | 
			
		||||
 | 
			
		||||
  def __init__(self, length):
 | 
			
		||||
    self.length = length
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "{name}(length={length})".format(
 | 
			
		||||
            name=self.__class__.__name__, **self.__dict__
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
  def __repr__(self):
 | 
			
		||||
    return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
 | 
			
		||||
    def __call__(self, img):
 | 
			
		||||
        h, w = img.size(1), img.size(2)
 | 
			
		||||
        mask = np.ones((h, w), np.float32)
 | 
			
		||||
        y = np.random.randint(h)
 | 
			
		||||
        x = np.random.randint(w)
 | 
			
		||||
 | 
			
		||||
  def __call__(self, img):
 | 
			
		||||
    h, w = img.size(1), img.size(2)
 | 
			
		||||
    mask = np.ones((h, w), np.float32)
 | 
			
		||||
    y = np.random.randint(h)
 | 
			
		||||
    x = np.random.randint(w)
 | 
			
		||||
        y1 = np.clip(y - self.length // 2, 0, h)
 | 
			
		||||
        y2 = np.clip(y + self.length // 2, 0, h)
 | 
			
		||||
        x1 = np.clip(x - self.length // 2, 0, w)
 | 
			
		||||
        x2 = np.clip(x + self.length // 2, 0, w)
 | 
			
		||||
 | 
			
		||||
    y1 = np.clip(y - self.length // 2, 0, h)
 | 
			
		||||
    y2 = np.clip(y + self.length // 2, 0, h)
 | 
			
		||||
    x1 = np.clip(x - self.length // 2, 0, w)
 | 
			
		||||
    x2 = np.clip(x + self.length // 2, 0, w)
 | 
			
		||||
 | 
			
		||||
    mask[y1: y2, x1: x2] = 0.
 | 
			
		||||
    mask = torch.from_numpy(mask)
 | 
			
		||||
    mask = mask.expand_as(img)
 | 
			
		||||
    img *= mask
 | 
			
		||||
    return img
 | 
			
		||||
        mask[y1:y2, x1:x2] = 0.0
 | 
			
		||||
        mask = torch.from_numpy(mask)
 | 
			
		||||
        mask = mask.expand_as(img)
 | 
			
		||||
        img *= mask
 | 
			
		||||
        return img
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
imagenet_pca = {
 | 
			
		||||
    'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
 | 
			
		||||
    'eigvec': np.asarray([
 | 
			
		||||
        [-0.5675, 0.7192, 0.4009],
 | 
			
		||||
        [-0.5808, -0.0045, -0.8140],
 | 
			
		||||
        [-0.5836, -0.6948, 0.4203],
 | 
			
		||||
    ])
 | 
			
		||||
    "eigval": np.asarray([0.2175, 0.0188, 0.0045]),
 | 
			
		||||
    "eigvec": np.asarray(
 | 
			
		||||
        [
 | 
			
		||||
            [-0.5675, 0.7192, 0.4009],
 | 
			
		||||
            [-0.5808, -0.0045, -0.8140],
 | 
			
		||||
            [-0.5836, -0.6948, 0.4203],
 | 
			
		||||
        ]
 | 
			
		||||
    ),
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Lighting(object):
 | 
			
		||||
  def __init__(self, alphastd,
 | 
			
		||||
         eigval=imagenet_pca['eigval'],
 | 
			
		||||
         eigvec=imagenet_pca['eigvec']):
 | 
			
		||||
    self.alphastd = alphastd
 | 
			
		||||
    assert eigval.shape == (3,)
 | 
			
		||||
    assert eigvec.shape == (3, 3)
 | 
			
		||||
    self.eigval = eigval
 | 
			
		||||
    self.eigvec = eigvec
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self, alphastd, eigval=imagenet_pca["eigval"], eigvec=imagenet_pca["eigvec"]
 | 
			
		||||
    ):
 | 
			
		||||
        self.alphastd = alphastd
 | 
			
		||||
        assert eigval.shape == (3,)
 | 
			
		||||
        assert eigvec.shape == (3, 3)
 | 
			
		||||
        self.eigval = eigval
 | 
			
		||||
        self.eigvec = eigvec
 | 
			
		||||
 | 
			
		||||
  def __call__(self, img):
 | 
			
		||||
    if self.alphastd == 0.:
 | 
			
		||||
      return img
 | 
			
		||||
    rnd = np.random.randn(3) * self.alphastd
 | 
			
		||||
    rnd = rnd.astype('float32')
 | 
			
		||||
    v = rnd
 | 
			
		||||
    old_dtype = np.asarray(img).dtype
 | 
			
		||||
    v = v * self.eigval
 | 
			
		||||
    v = v.reshape((3, 1))
 | 
			
		||||
    inc = np.dot(self.eigvec, v).reshape((3,))
 | 
			
		||||
    img = np.add(img, inc)
 | 
			
		||||
    if old_dtype == np.uint8:
 | 
			
		||||
      img = np.clip(img, 0, 255)
 | 
			
		||||
    img = Image.fromarray(img.astype(old_dtype), 'RGB')
 | 
			
		||||
    return img
 | 
			
		||||
    def __call__(self, img):
 | 
			
		||||
        if self.alphastd == 0.0:
 | 
			
		||||
            return img
 | 
			
		||||
        rnd = np.random.randn(3) * self.alphastd
 | 
			
		||||
        rnd = rnd.astype("float32")
 | 
			
		||||
        v = rnd
 | 
			
		||||
        old_dtype = np.asarray(img).dtype
 | 
			
		||||
        v = v * self.eigval
 | 
			
		||||
        v = v.reshape((3, 1))
 | 
			
		||||
        inc = np.dot(self.eigvec, v).reshape((3,))
 | 
			
		||||
        img = np.add(img, inc)
 | 
			
		||||
        if old_dtype == np.uint8:
 | 
			
		||||
            img = np.clip(img, 0, 255)
 | 
			
		||||
        img = Image.fromarray(img.astype(old_dtype), "RGB")
 | 
			
		||||
        return img
 | 
			
		||||
 | 
			
		||||
  def __repr__(self):
 | 
			
		||||
    return self.__class__.__name__ + '()'
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return self.__class__.__name__ + "()"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_datasets(name, root, cutout):
 | 
			
		||||
 | 
			
		||||
  if name == 'cifar10':
 | 
			
		||||
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
 | 
			
		||||
    std  = [x / 255 for x in [63.0, 62.1, 66.7]]
 | 
			
		||||
  elif name == 'cifar100':
 | 
			
		||||
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
 | 
			
		||||
    std  = [x / 255 for x in [68.2, 65.4, 70.4]]
 | 
			
		||||
  elif name.startswith('imagenet-1k'):
 | 
			
		||||
    mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
 | 
			
		||||
  elif name.startswith('ImageNet16'):
 | 
			
		||||
    mean = [x / 255 for x in [122.68, 116.66, 104.01]]
 | 
			
		||||
    std  = [x / 255 for x in [63.22,  61.26 , 65.09]]
 | 
			
		||||
  else:
 | 
			
		||||
    raise TypeError("Unknow dataset : {:}".format(name))
 | 
			
		||||
    if name == "cifar10":
 | 
			
		||||
        mean = [x / 255 for x in [125.3, 123.0, 113.9]]
 | 
			
		||||
        std = [x / 255 for x in [63.0, 62.1, 66.7]]
 | 
			
		||||
    elif name == "cifar100":
 | 
			
		||||
        mean = [x / 255 for x in [129.3, 124.1, 112.4]]
 | 
			
		||||
        std = [x / 255 for x in [68.2, 65.4, 70.4]]
 | 
			
		||||
    elif name.startswith("imagenet-1k"):
 | 
			
		||||
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
 | 
			
		||||
    elif name.startswith("ImageNet16"):
 | 
			
		||||
        mean = [x / 255 for x in [122.68, 116.66, 104.01]]
 | 
			
		||||
        std = [x / 255 for x in [63.22, 61.26, 65.09]]
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError("Unknow dataset : {:}".format(name))
 | 
			
		||||
 | 
			
		||||
  # Data Argumentation
 | 
			
		||||
  if name == 'cifar10' or name == 'cifar100':
 | 
			
		||||
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
 | 
			
		||||
    if cutout > 0 : lists += [CUTOUT(cutout)]
 | 
			
		||||
    train_transform = transforms.Compose(lists)
 | 
			
		||||
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
 | 
			
		||||
    xshape = (1, 3, 32, 32)
 | 
			
		||||
  elif name.startswith('ImageNet16'):
 | 
			
		||||
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
 | 
			
		||||
    if cutout > 0 : lists += [CUTOUT(cutout)]
 | 
			
		||||
    train_transform = transforms.Compose(lists)
 | 
			
		||||
    test_transform  = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
 | 
			
		||||
    xshape = (1, 3, 16, 16)
 | 
			
		||||
  elif name == 'tiered':
 | 
			
		||||
    lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
 | 
			
		||||
    if cutout > 0 : lists += [CUTOUT(cutout)]
 | 
			
		||||
    train_transform = transforms.Compose(lists)
 | 
			
		||||
    test_transform  = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
 | 
			
		||||
    xshape = (1, 3, 32, 32)
 | 
			
		||||
  elif name.startswith('imagenet-1k'):
 | 
			
		||||
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 | 
			
		||||
    if name == 'imagenet-1k':
 | 
			
		||||
      xlists    = [transforms.RandomResizedCrop(224)]
 | 
			
		||||
      xlists.append(
 | 
			
		||||
        transforms.ColorJitter(
 | 
			
		||||
        brightness=0.4,
 | 
			
		||||
        contrast=0.4,
 | 
			
		||||
        saturation=0.4,
 | 
			
		||||
        hue=0.2))
 | 
			
		||||
      xlists.append( Lighting(0.1))
 | 
			
		||||
    elif name == 'imagenet-1k-s':
 | 
			
		||||
      xlists    = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
 | 
			
		||||
    else: raise ValueError('invalid name : {:}'.format(name))
 | 
			
		||||
    xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
 | 
			
		||||
    xlists.append( transforms.ToTensor() )
 | 
			
		||||
    xlists.append( normalize )
 | 
			
		||||
    train_transform = transforms.Compose(xlists)
 | 
			
		||||
    test_transform  = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
 | 
			
		||||
    xshape = (1, 3, 224, 224)
 | 
			
		||||
  else:
 | 
			
		||||
    raise TypeError("Unknow dataset : {:}".format(name))
 | 
			
		||||
    # Data Argumentation
 | 
			
		||||
    if name == "cifar10" or name == "cifar100":
 | 
			
		||||
        lists = [
 | 
			
		||||
            transforms.RandomHorizontalFlip(),
 | 
			
		||||
            transforms.RandomCrop(32, padding=4),
 | 
			
		||||
            transforms.ToTensor(),
 | 
			
		||||
            transforms.Normalize(mean, std),
 | 
			
		||||
        ]
 | 
			
		||||
        if cutout > 0:
 | 
			
		||||
            lists += [CUTOUT(cutout)]
 | 
			
		||||
        train_transform = transforms.Compose(lists)
 | 
			
		||||
        test_transform = transforms.Compose(
 | 
			
		||||
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
 | 
			
		||||
        )
 | 
			
		||||
        xshape = (1, 3, 32, 32)
 | 
			
		||||
    elif name.startswith("ImageNet16"):
 | 
			
		||||
        lists = [
 | 
			
		||||
            transforms.RandomHorizontalFlip(),
 | 
			
		||||
            transforms.RandomCrop(16, padding=2),
 | 
			
		||||
            transforms.ToTensor(),
 | 
			
		||||
            transforms.Normalize(mean, std),
 | 
			
		||||
        ]
 | 
			
		||||
        if cutout > 0:
 | 
			
		||||
            lists += [CUTOUT(cutout)]
 | 
			
		||||
        train_transform = transforms.Compose(lists)
 | 
			
		||||
        test_transform = transforms.Compose(
 | 
			
		||||
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
 | 
			
		||||
        )
 | 
			
		||||
        xshape = (1, 3, 16, 16)
 | 
			
		||||
    elif name == "tiered":
 | 
			
		||||
        lists = [
 | 
			
		||||
            transforms.RandomHorizontalFlip(),
 | 
			
		||||
            transforms.RandomCrop(80, padding=4),
 | 
			
		||||
            transforms.ToTensor(),
 | 
			
		||||
            transforms.Normalize(mean, std),
 | 
			
		||||
        ]
 | 
			
		||||
        if cutout > 0:
 | 
			
		||||
            lists += [CUTOUT(cutout)]
 | 
			
		||||
        train_transform = transforms.Compose(lists)
 | 
			
		||||
        test_transform = transforms.Compose(
 | 
			
		||||
            [
 | 
			
		||||
                transforms.CenterCrop(80),
 | 
			
		||||
                transforms.ToTensor(),
 | 
			
		||||
                transforms.Normalize(mean, std),
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        xshape = (1, 3, 32, 32)
 | 
			
		||||
    elif name.startswith("imagenet-1k"):
 | 
			
		||||
        normalize = transforms.Normalize(
 | 
			
		||||
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
 | 
			
		||||
        )
 | 
			
		||||
        if name == "imagenet-1k":
 | 
			
		||||
            xlists = [transforms.RandomResizedCrop(224)]
 | 
			
		||||
            xlists.append(
 | 
			
		||||
                transforms.ColorJitter(
 | 
			
		||||
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
            xlists.append(Lighting(0.1))
 | 
			
		||||
        elif name == "imagenet-1k-s":
 | 
			
		||||
            xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("invalid name : {:}".format(name))
 | 
			
		||||
        xlists.append(transforms.RandomHorizontalFlip(p=0.5))
 | 
			
		||||
        xlists.append(transforms.ToTensor())
 | 
			
		||||
        xlists.append(normalize)
 | 
			
		||||
        train_transform = transforms.Compose(xlists)
 | 
			
		||||
        test_transform = transforms.Compose(
 | 
			
		||||
            [
 | 
			
		||||
                transforms.Resize(256),
 | 
			
		||||
                transforms.CenterCrop(224),
 | 
			
		||||
                transforms.ToTensor(),
 | 
			
		||||
                normalize,
 | 
			
		||||
            ]
 | 
			
		||||
        )
 | 
			
		||||
        xshape = (1, 3, 224, 224)
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError("Unknow dataset : {:}".format(name))
 | 
			
		||||
 | 
			
		||||
  if name == 'cifar10':
 | 
			
		||||
    train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
 | 
			
		||||
    test_data  = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
 | 
			
		||||
    assert len(train_data) == 50000 and len(test_data) == 10000
 | 
			
		||||
  elif name == 'cifar100':
 | 
			
		||||
    train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
 | 
			
		||||
    test_data  = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
 | 
			
		||||
    assert len(train_data) == 50000 and len(test_data) == 10000
 | 
			
		||||
  elif name.startswith('imagenet-1k'):
 | 
			
		||||
    train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
 | 
			
		||||
    test_data  = dset.ImageFolder(osp.join(root, 'val'),   test_transform)
 | 
			
		||||
    assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
 | 
			
		||||
  elif name == 'ImageNet16':
 | 
			
		||||
    train_data = ImageNet16(root, True , train_transform)
 | 
			
		||||
    test_data  = ImageNet16(root, False, test_transform)
 | 
			
		||||
    assert len(train_data) == 1281167 and len(test_data) == 50000
 | 
			
		||||
  elif name == 'ImageNet16-120':
 | 
			
		||||
    train_data = ImageNet16(root, True , train_transform, 120)
 | 
			
		||||
    test_data  = ImageNet16(root, False, test_transform , 120)
 | 
			
		||||
    assert len(train_data) == 151700 and len(test_data) == 6000
 | 
			
		||||
  elif name == 'ImageNet16-150':
 | 
			
		||||
    train_data = ImageNet16(root, True , train_transform, 150)
 | 
			
		||||
    test_data  = ImageNet16(root, False, test_transform , 150)
 | 
			
		||||
    assert len(train_data) == 190272 and len(test_data) == 7500
 | 
			
		||||
  elif name == 'ImageNet16-200':
 | 
			
		||||
    train_data = ImageNet16(root, True , train_transform, 200)
 | 
			
		||||
    test_data  = ImageNet16(root, False, test_transform , 200)
 | 
			
		||||
    assert len(train_data) == 254775 and len(test_data) == 10000
 | 
			
		||||
  else: raise TypeError("Unknow dataset : {:}".format(name))
 | 
			
		||||
  
 | 
			
		||||
  class_num = Dataset2Class[name]
 | 
			
		||||
  return train_data, test_data, xshape, class_num
 | 
			
		||||
    if name == "cifar10":
 | 
			
		||||
        train_data = dset.CIFAR10(
 | 
			
		||||
            root, train=True, transform=train_transform, download=True
 | 
			
		||||
        )
 | 
			
		||||
        test_data = dset.CIFAR10(
 | 
			
		||||
            root, train=False, transform=test_transform, download=True
 | 
			
		||||
        )
 | 
			
		||||
        assert len(train_data) == 50000 and len(test_data) == 10000
 | 
			
		||||
    elif name == "cifar100":
 | 
			
		||||
        train_data = dset.CIFAR100(
 | 
			
		||||
            root, train=True, transform=train_transform, download=True
 | 
			
		||||
        )
 | 
			
		||||
        test_data = dset.CIFAR100(
 | 
			
		||||
            root, train=False, transform=test_transform, download=True
 | 
			
		||||
        )
 | 
			
		||||
        assert len(train_data) == 50000 and len(test_data) == 10000
 | 
			
		||||
    elif name.startswith("imagenet-1k"):
 | 
			
		||||
        train_data = dset.ImageFolder(osp.join(root, "train"), train_transform)
 | 
			
		||||
        test_data = dset.ImageFolder(osp.join(root, "val"), test_transform)
 | 
			
		||||
        assert (
 | 
			
		||||
            len(train_data) == 1281167 and len(test_data) == 50000
 | 
			
		||||
        ), "invalid number of images : {:} & {:} vs {:} & {:}".format(
 | 
			
		||||
            len(train_data), len(test_data), 1281167, 50000
 | 
			
		||||
        )
 | 
			
		||||
    elif name == "ImageNet16":
 | 
			
		||||
        train_data = ImageNet16(root, True, train_transform)
 | 
			
		||||
        test_data = ImageNet16(root, False, test_transform)
 | 
			
		||||
        assert len(train_data) == 1281167 and len(test_data) == 50000
 | 
			
		||||
    elif name == "ImageNet16-120":
 | 
			
		||||
        train_data = ImageNet16(root, True, train_transform, 120)
 | 
			
		||||
        test_data = ImageNet16(root, False, test_transform, 120)
 | 
			
		||||
        assert len(train_data) == 151700 and len(test_data) == 6000
 | 
			
		||||
    elif name == "ImageNet16-150":
 | 
			
		||||
        train_data = ImageNet16(root, True, train_transform, 150)
 | 
			
		||||
        test_data = ImageNet16(root, False, test_transform, 150)
 | 
			
		||||
        assert len(train_data) == 190272 and len(test_data) == 7500
 | 
			
		||||
    elif name == "ImageNet16-200":
 | 
			
		||||
        train_data = ImageNet16(root, True, train_transform, 200)
 | 
			
		||||
        test_data = ImageNet16(root, False, test_transform, 200)
 | 
			
		||||
        assert len(train_data) == 254775 and len(test_data) == 10000
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError("Unknow dataset : {:}".format(name))
 | 
			
		||||
 | 
			
		||||
    class_num = Dataset2Class[name]
 | 
			
		||||
    return train_data, test_data, xshape, class_num
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers):
 | 
			
		||||
  if isinstance(batch_size, (list,tuple)):
 | 
			
		||||
    batch, test_batch = batch_size
 | 
			
		||||
  else:
 | 
			
		||||
    batch, test_batch = batch_size, batch_size
 | 
			
		||||
  if dataset == 'cifar10':
 | 
			
		||||
    #split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
 | 
			
		||||
    cifar_split = load_config('{:}/cifar-split.txt'.format(config_root), None, None)
 | 
			
		||||
    train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
 | 
			
		||||
    #logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set
 | 
			
		||||
    # To split data
 | 
			
		||||
    xvalid_data  = deepcopy(train_data)
 | 
			
		||||
    if hasattr(xvalid_data, 'transforms'): # to avoid a print issue
 | 
			
		||||
      xvalid_data.transforms = valid_data.transform
 | 
			
		||||
    xvalid_data.transform  = deepcopy( valid_data.transform )
 | 
			
		||||
    search_data   = SearchDataset(dataset, train_data, train_split, valid_split)
 | 
			
		||||
    # data loader
 | 
			
		||||
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
 | 
			
		||||
    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
 | 
			
		||||
    valid_loader  = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
 | 
			
		||||
  elif dataset == 'cifar100':
 | 
			
		||||
    cifar100_test_split = load_config('{:}/cifar100-test-split.txt'.format(config_root), None, None)
 | 
			
		||||
    search_train_data = train_data
 | 
			
		||||
    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
 | 
			
		||||
    search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), cifar100_test_split.xvalid)
 | 
			
		||||
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
 | 
			
		||||
    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
 | 
			
		||||
    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(cifar100_test_split.xvalid), num_workers=workers, pin_memory=True)
 | 
			
		||||
  elif dataset == 'ImageNet16-120':
 | 
			
		||||
    imagenet_test_split = load_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
 | 
			
		||||
    search_train_data = train_data
 | 
			
		||||
    search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
 | 
			
		||||
    search_data   = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
 | 
			
		||||
    search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
 | 
			
		||||
    train_loader  = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
 | 
			
		||||
    valid_loader  = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
 | 
			
		||||
  else:
 | 
			
		||||
    raise ValueError('invalid dataset : {:}'.format(dataset))
 | 
			
		||||
  return search_loader, train_loader, valid_loader
 | 
			
		||||
def get_nas_search_loaders(
 | 
			
		||||
    train_data, valid_data, dataset, config_root, batch_size, workers
 | 
			
		||||
):
 | 
			
		||||
    if isinstance(batch_size, (list, tuple)):
 | 
			
		||||
        batch, test_batch = batch_size
 | 
			
		||||
    else:
 | 
			
		||||
        batch, test_batch = batch_size, batch_size
 | 
			
		||||
    if dataset == "cifar10":
 | 
			
		||||
        # split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
 | 
			
		||||
        cifar_split = load_config("{:}/cifar-split.txt".format(config_root), None, None)
 | 
			
		||||
        train_split, valid_split = (
 | 
			
		||||
            cifar_split.train,
 | 
			
		||||
            cifar_split.valid,
 | 
			
		||||
        )  # search over the proposed training and validation set
 | 
			
		||||
        # logger.log('Load split file from {:}'.format(split_Fpath))      # they are two disjoint groups in the original CIFAR-10 training set
 | 
			
		||||
        # To split data
 | 
			
		||||
        xvalid_data = deepcopy(train_data)
 | 
			
		||||
        if hasattr(xvalid_data, "transforms"):  # to avoid a print issue
 | 
			
		||||
            xvalid_data.transforms = valid_data.transform
 | 
			
		||||
        xvalid_data.transform = deepcopy(valid_data.transform)
 | 
			
		||||
        search_data = SearchDataset(dataset, train_data, train_split, valid_split)
 | 
			
		||||
        # data loader
 | 
			
		||||
        search_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            search_data,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            shuffle=True,
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
        train_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            train_data,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split),
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
        valid_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            xvalid_data,
 | 
			
		||||
            batch_size=test_batch,
 | 
			
		||||
            sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split),
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
    elif dataset == "cifar100":
 | 
			
		||||
        cifar100_test_split = load_config(
 | 
			
		||||
            "{:}/cifar100-test-split.txt".format(config_root), None, None
 | 
			
		||||
        )
 | 
			
		||||
        search_train_data = train_data
 | 
			
		||||
        search_valid_data = deepcopy(valid_data)
 | 
			
		||||
        search_valid_data.transform = train_data.transform
 | 
			
		||||
        search_data = SearchDataset(
 | 
			
		||||
            dataset,
 | 
			
		||||
            [search_train_data, search_valid_data],
 | 
			
		||||
            list(range(len(search_train_data))),
 | 
			
		||||
            cifar100_test_split.xvalid,
 | 
			
		||||
        )
 | 
			
		||||
        search_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            search_data,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            shuffle=True,
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
        train_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            train_data,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            shuffle=True,
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
        valid_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            valid_data,
 | 
			
		||||
            batch_size=test_batch,
 | 
			
		||||
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
 | 
			
		||||
                cifar100_test_split.xvalid
 | 
			
		||||
            ),
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
    elif dataset == "ImageNet16-120":
 | 
			
		||||
        imagenet_test_split = load_config(
 | 
			
		||||
            "{:}/imagenet-16-120-test-split.txt".format(config_root), None, None
 | 
			
		||||
        )
 | 
			
		||||
        search_train_data = train_data
 | 
			
		||||
        search_valid_data = deepcopy(valid_data)
 | 
			
		||||
        search_valid_data.transform = train_data.transform
 | 
			
		||||
        search_data = SearchDataset(
 | 
			
		||||
            dataset,
 | 
			
		||||
            [search_train_data, search_valid_data],
 | 
			
		||||
            list(range(len(search_train_data))),
 | 
			
		||||
            imagenet_test_split.xvalid,
 | 
			
		||||
        )
 | 
			
		||||
        search_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            search_data,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            shuffle=True,
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
        train_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            train_data,
 | 
			
		||||
            batch_size=batch,
 | 
			
		||||
            shuffle=True,
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
        valid_loader = torch.utils.data.DataLoader(
 | 
			
		||||
            valid_data,
 | 
			
		||||
            batch_size=test_batch,
 | 
			
		||||
            sampler=torch.utils.data.sampler.SubsetRandomSampler(
 | 
			
		||||
                imagenet_test_split.xvalid
 | 
			
		||||
            ),
 | 
			
		||||
            num_workers=workers,
 | 
			
		||||
            pin_memory=True,
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        raise ValueError("invalid dataset : {:}".format(dataset))
 | 
			
		||||
    return search_loader, train_loader, valid_loader
 | 
			
		||||
 | 
			
		||||
#if __name__ == '__main__':
 | 
			
		||||
 | 
			
		||||
# if __name__ == '__main__':
 | 
			
		||||
#  train_data, test_data, xshape, class_num = dataset = get_datasets('cifar10', '/data02/dongxuanyi/.torch/cifar.python/', -1)
 | 
			
		||||
#  import pdb; pdb.set_trace()
 | 
			
		||||
 
 | 
			
		||||
@@ -9,108 +9,211 @@ from xvision import normalize_points
 | 
			
		||||
from xvision import denormalize_points
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PointMeta():
 | 
			
		||||
  # points    : 3 x num_pts (x, y, oculusion)
 | 
			
		||||
  # image_size: original [width, height]
 | 
			
		||||
  def __init__(self, num_point, points, box, image_path, dataset_name):
 | 
			
		||||
class PointMeta:
 | 
			
		||||
    # points    : 3 x num_pts (x, y, oculusion)
 | 
			
		||||
    # image_size: original [width, height]
 | 
			
		||||
    def __init__(self, num_point, points, box, image_path, dataset_name):
 | 
			
		||||
 | 
			
		||||
    self.num_point = num_point
 | 
			
		||||
    if box is not None:
 | 
			
		||||
      assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
 | 
			
		||||
      self.box = torch.Tensor(box)
 | 
			
		||||
    else: self.box = None
 | 
			
		||||
    if points is None:
 | 
			
		||||
      self.points = points
 | 
			
		||||
    else:
 | 
			
		||||
      assert len(points.shape) == 2 and points.shape[0] == 3 and points.shape[1] == self.num_point, 'The shape of point is not right : {}'.format( points )
 | 
			
		||||
      self.points = torch.Tensor(points.copy())
 | 
			
		||||
    self.image_path = image_path
 | 
			
		||||
    self.datasets = dataset_name
 | 
			
		||||
        self.num_point = num_point
 | 
			
		||||
        if box is not None:
 | 
			
		||||
            assert (isinstance(box, tuple) or isinstance(box, list)) and len(box) == 4
 | 
			
		||||
            self.box = torch.Tensor(box)
 | 
			
		||||
        else:
 | 
			
		||||
            self.box = None
 | 
			
		||||
        if points is None:
 | 
			
		||||
            self.points = points
 | 
			
		||||
        else:
 | 
			
		||||
            assert (
 | 
			
		||||
                len(points.shape) == 2
 | 
			
		||||
                and points.shape[0] == 3
 | 
			
		||||
                and points.shape[1] == self.num_point
 | 
			
		||||
            ), "The shape of point is not right : {}".format(points)
 | 
			
		||||
            self.points = torch.Tensor(points.copy())
 | 
			
		||||
        self.image_path = image_path
 | 
			
		||||
        self.datasets = dataset_name
 | 
			
		||||
 | 
			
		||||
  def __repr__(self):
 | 
			
		||||
    if self.box is None: boxstr = 'None'
 | 
			
		||||
    else               : boxstr = 'box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]'.format(*self.box.tolist())
 | 
			
		||||
    return ('{name}(points={num_point}, '.format(name=self.__class__.__name__, **self.__dict__) + boxstr + ')')
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        if self.box is None:
 | 
			
		||||
            boxstr = "None"
 | 
			
		||||
        else:
 | 
			
		||||
            boxstr = "box=[{:.1f}, {:.1f}, {:.1f}, {:.1f}]".format(*self.box.tolist())
 | 
			
		||||
        return (
 | 
			
		||||
            "{name}(points={num_point}, ".format(
 | 
			
		||||
                name=self.__class__.__name__, **self.__dict__
 | 
			
		||||
            )
 | 
			
		||||
            + boxstr
 | 
			
		||||
            + ")"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
  def get_box(self, return_diagonal=False):
 | 
			
		||||
    if self.box is None: return None
 | 
			
		||||
    if not return_diagonal:
 | 
			
		||||
      return self.box.clone()
 | 
			
		||||
    else:
 | 
			
		||||
      W = (self.box[2]-self.box[0]).item()
 | 
			
		||||
      H = (self.box[3]-self.box[1]).item()
 | 
			
		||||
      return math.sqrt(H*H+W*W)
 | 
			
		||||
    def get_box(self, return_diagonal=False):
 | 
			
		||||
        if self.box is None:
 | 
			
		||||
            return None
 | 
			
		||||
        if not return_diagonal:
 | 
			
		||||
            return self.box.clone()
 | 
			
		||||
        else:
 | 
			
		||||
            W = (self.box[2] - self.box[0]).item()
 | 
			
		||||
            H = (self.box[3] - self.box[1]).item()
 | 
			
		||||
            return math.sqrt(H * H + W * W)
 | 
			
		||||
 | 
			
		||||
  def get_points(self, ignore_indicator=False):
 | 
			
		||||
    if ignore_indicator: last = 2
 | 
			
		||||
    else               : last = 3
 | 
			
		||||
    if self.points is not None: return self.points.clone()[:last, :]
 | 
			
		||||
    else                      : return torch.zeros((last, self.num_point))
 | 
			
		||||
    def get_points(self, ignore_indicator=False):
 | 
			
		||||
        if ignore_indicator:
 | 
			
		||||
            last = 2
 | 
			
		||||
        else:
 | 
			
		||||
            last = 3
 | 
			
		||||
        if self.points is not None:
 | 
			
		||||
            return self.points.clone()[:last, :]
 | 
			
		||||
        else:
 | 
			
		||||
            return torch.zeros((last, self.num_point))
 | 
			
		||||
 | 
			
		||||
  def is_none(self):
 | 
			
		||||
    #assert self.box is not None, 'The box should not be None'
 | 
			
		||||
    return self.points is None
 | 
			
		||||
    #if self.box is None: return True
 | 
			
		||||
    #else               : return self.points is None
 | 
			
		||||
    def is_none(self):
 | 
			
		||||
        # assert self.box is not None, 'The box should not be None'
 | 
			
		||||
        return self.points is None
 | 
			
		||||
        # if self.box is None: return True
 | 
			
		||||
        # else               : return self.points is None
 | 
			
		||||
 | 
			
		||||
  def copy(self):
 | 
			
		||||
    return copy.deepcopy(self)
 | 
			
		||||
    def copy(self):
 | 
			
		||||
        return copy.deepcopy(self)
 | 
			
		||||
 | 
			
		||||
  def visiable_pts_num(self):
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
      ans = self.points[2,:] > 0
 | 
			
		||||
      ans = torch.sum(ans)
 | 
			
		||||
      ans = ans.item()
 | 
			
		||||
    return ans
 | 
			
		||||
  
 | 
			
		||||
  def special_fun(self, indicator):
 | 
			
		||||
    if indicator == '68to49': # For 300W or 300VW, convert the default 68 points to 49 points.
 | 
			
		||||
      assert self.num_point == 68, 'num-point must be 68 vs. {:}'.format(self.num_point)
 | 
			
		||||
      self.num_point = 49
 | 
			
		||||
      out = torch.ones((68), dtype=torch.uint8)
 | 
			
		||||
      out[[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,60,64]] = 0
 | 
			
		||||
      if self.points is not None: self.points = self.points.clone()[:, out]
 | 
			
		||||
    else:
 | 
			
		||||
      raise ValueError('Invalid indicator : {:}'.format( indicator ))
 | 
			
		||||
    def visiable_pts_num(self):
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            ans = self.points[2, :] > 0
 | 
			
		||||
            ans = torch.sum(ans)
 | 
			
		||||
            ans = ans.item()
 | 
			
		||||
        return ans
 | 
			
		||||
 | 
			
		||||
  def apply_horizontal_flip(self):
 | 
			
		||||
    #self.points[0, :] = width - self.points[0, :] - 1
 | 
			
		||||
    # Mugsy spefic or Synthetic
 | 
			
		||||
    if self.datasets.startswith('HandsyROT'):
 | 
			
		||||
      ori = np.array(list(range(0, 42)))
 | 
			
		||||
      pos = np.array(list(range(21,42)) + list(range(0,21)))
 | 
			
		||||
      self.points[:, pos] = self.points[:, ori]
 | 
			
		||||
    elif self.datasets.startswith('face68'):
 | 
			
		||||
      ori = np.array(list(range(0, 68)))
 | 
			
		||||
      pos = np.array([17,16,15,14,13,12,11,10, 9, 8,7,6,5,4,3,2,1, 27,26,25,24,23,22,21,20,19,18, 28,29,30,31, 36,35,34,33,32, 46,45,44,43,48,47, 40,39,38,37,42,41, 55,54,53,52,51,50,49,60,59,58,57,56,65,64,63,62,61,68,67,66])-1
 | 
			
		||||
      self.points[:, ori] = self.points[:, pos]
 | 
			
		||||
    else:
 | 
			
		||||
      raise ValueError('Does not support {:}'.format(self.datasets))
 | 
			
		||||
    def special_fun(self, indicator):
 | 
			
		||||
        if (
 | 
			
		||||
            indicator == "68to49"
 | 
			
		||||
        ):  # For 300W or 300VW, convert the default 68 points to 49 points.
 | 
			
		||||
            assert self.num_point == 68, "num-point must be 68 vs. {:}".format(
 | 
			
		||||
                self.num_point
 | 
			
		||||
            )
 | 
			
		||||
            self.num_point = 49
 | 
			
		||||
            out = torch.ones((68), dtype=torch.uint8)
 | 
			
		||||
            out[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 60, 64]] = 0
 | 
			
		||||
            if self.points is not None:
 | 
			
		||||
                self.points = self.points.clone()[:, out]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Invalid indicator : {:}".format(indicator))
 | 
			
		||||
 | 
			
		||||
    def apply_horizontal_flip(self):
 | 
			
		||||
        # self.points[0, :] = width - self.points[0, :] - 1
 | 
			
		||||
        # Mugsy spefic or Synthetic
 | 
			
		||||
        if self.datasets.startswith("HandsyROT"):
 | 
			
		||||
            ori = np.array(list(range(0, 42)))
 | 
			
		||||
            pos = np.array(list(range(21, 42)) + list(range(0, 21)))
 | 
			
		||||
            self.points[:, pos] = self.points[:, ori]
 | 
			
		||||
        elif self.datasets.startswith("face68"):
 | 
			
		||||
            ori = np.array(list(range(0, 68)))
 | 
			
		||||
            pos = (
 | 
			
		||||
                np.array(
 | 
			
		||||
                    [
 | 
			
		||||
                        17,
 | 
			
		||||
                        16,
 | 
			
		||||
                        15,
 | 
			
		||||
                        14,
 | 
			
		||||
                        13,
 | 
			
		||||
                        12,
 | 
			
		||||
                        11,
 | 
			
		||||
                        10,
 | 
			
		||||
                        9,
 | 
			
		||||
                        8,
 | 
			
		||||
                        7,
 | 
			
		||||
                        6,
 | 
			
		||||
                        5,
 | 
			
		||||
                        4,
 | 
			
		||||
                        3,
 | 
			
		||||
                        2,
 | 
			
		||||
                        1,
 | 
			
		||||
                        27,
 | 
			
		||||
                        26,
 | 
			
		||||
                        25,
 | 
			
		||||
                        24,
 | 
			
		||||
                        23,
 | 
			
		||||
                        22,
 | 
			
		||||
                        21,
 | 
			
		||||
                        20,
 | 
			
		||||
                        19,
 | 
			
		||||
                        18,
 | 
			
		||||
                        28,
 | 
			
		||||
                        29,
 | 
			
		||||
                        30,
 | 
			
		||||
                        31,
 | 
			
		||||
                        36,
 | 
			
		||||
                        35,
 | 
			
		||||
                        34,
 | 
			
		||||
                        33,
 | 
			
		||||
                        32,
 | 
			
		||||
                        46,
 | 
			
		||||
                        45,
 | 
			
		||||
                        44,
 | 
			
		||||
                        43,
 | 
			
		||||
                        48,
 | 
			
		||||
                        47,
 | 
			
		||||
                        40,
 | 
			
		||||
                        39,
 | 
			
		||||
                        38,
 | 
			
		||||
                        37,
 | 
			
		||||
                        42,
 | 
			
		||||
                        41,
 | 
			
		||||
                        55,
 | 
			
		||||
                        54,
 | 
			
		||||
                        53,
 | 
			
		||||
                        52,
 | 
			
		||||
                        51,
 | 
			
		||||
                        50,
 | 
			
		||||
                        49,
 | 
			
		||||
                        60,
 | 
			
		||||
                        59,
 | 
			
		||||
                        58,
 | 
			
		||||
                        57,
 | 
			
		||||
                        56,
 | 
			
		||||
                        65,
 | 
			
		||||
                        64,
 | 
			
		||||
                        63,
 | 
			
		||||
                        62,
 | 
			
		||||
                        61,
 | 
			
		||||
                        68,
 | 
			
		||||
                        67,
 | 
			
		||||
                        66,
 | 
			
		||||
                    ]
 | 
			
		||||
                )
 | 
			
		||||
                - 1
 | 
			
		||||
            )
 | 
			
		||||
            self.points[:, ori] = self.points[:, pos]
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("Does not support {:}".format(self.datasets))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# shape = (H,W)
 | 
			
		||||
def apply_affine2point(points, theta, shape):
 | 
			
		||||
  assert points.size(0) == 3, 'invalid points shape : {:}'.format(points.size())
 | 
			
		||||
  with torch.no_grad():
 | 
			
		||||
    ok_points = points[2,:] == 1
 | 
			
		||||
    assert torch.sum(ok_points).item() > 0, 'there is no visiable point'
 | 
			
		||||
    points[:2,:] = normalize_points(shape, points[:2,:])
 | 
			
		||||
    assert points.size(0) == 3, "invalid points shape : {:}".format(points.size())
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        ok_points = points[2, :] == 1
 | 
			
		||||
        assert torch.sum(ok_points).item() > 0, "there is no visiable point"
 | 
			
		||||
        points[:2, :] = normalize_points(shape, points[:2, :])
 | 
			
		||||
 | 
			
		||||
    norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
 | 
			
		||||
        norm_trans_points = ok_points.unsqueeze(0).repeat(3, 1).float()
 | 
			
		||||
 | 
			
		||||
    trans_points, ___ = torch.gesv(points[:, ok_points], theta)
 | 
			
		||||
        trans_points, ___ = torch.gesv(points[:, ok_points], theta)
 | 
			
		||||
 | 
			
		||||
    norm_trans_points[:, ok_points] = trans_points
 | 
			
		||||
    
 | 
			
		||||
  return norm_trans_points
 | 
			
		||||
        norm_trans_points[:, ok_points] = trans_points
 | 
			
		||||
 | 
			
		||||
    return norm_trans_points
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def apply_boundary(norm_trans_points):
 | 
			
		||||
  with torch.no_grad():
 | 
			
		||||
    norm_trans_points = norm_trans_points.clone()
 | 
			
		||||
    oks = torch.stack((norm_trans_points[0]>-1, norm_trans_points[0]<1, norm_trans_points[1]>-1, norm_trans_points[1]<1, norm_trans_points[2]>0))
 | 
			
		||||
    oks = torch.sum(oks, dim=0) == 5
 | 
			
		||||
    norm_trans_points[2, :] = oks
 | 
			
		||||
  return norm_trans_points
 | 
			
		||||
    with torch.no_grad():
 | 
			
		||||
        norm_trans_points = norm_trans_points.clone()
 | 
			
		||||
        oks = torch.stack(
 | 
			
		||||
            (
 | 
			
		||||
                norm_trans_points[0] > -1,
 | 
			
		||||
                norm_trans_points[0] < 1,
 | 
			
		||||
                norm_trans_points[1] > -1,
 | 
			
		||||
                norm_trans_points[1] < 1,
 | 
			
		||||
                norm_trans_points[2] > 0,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
        oks = torch.sum(oks, dim=0) == 5
 | 
			
		||||
        norm_trans_points[2, :] = oks
 | 
			
		||||
    return norm_trans_points
 | 
			
		||||
 
 | 
			
		||||
@@ -1,39 +1,123 @@
 | 
			
		||||
#####################################################
 | 
			
		||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
 | 
			
		||||
#####################################################
 | 
			
		||||
import math
 | 
			
		||||
import numpy as np
 | 
			
		||||
from typing import Optional
 | 
			
		||||
import torch
 | 
			
		||||
import torch.utils.data as data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class QuadraticFunction:
 | 
			
		||||
    """The quadratic function that outputs f(x) = a * x^2 + b * x + c."""
 | 
			
		||||
 | 
			
		||||
    def __init__(self, list_of_points=None):
 | 
			
		||||
        self._params = dict(a=None, b=None, c=None)
 | 
			
		||||
        if list_of_points is not None:
 | 
			
		||||
            self.fit(list_of_points)
 | 
			
		||||
 | 
			
		||||
    def set(self, a, b, c):
 | 
			
		||||
        self._params["a"] = a
 | 
			
		||||
        self._params["b"] = b
 | 
			
		||||
        self._params["c"] = c
 | 
			
		||||
 | 
			
		||||
    def check_valid(self):
 | 
			
		||||
        for key, value in self._params.items():
 | 
			
		||||
            if value is None:
 | 
			
		||||
                raise ValueError("The {:} is None".format(key))
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, x):
 | 
			
		||||
        self.check_valid()
 | 
			
		||||
        return self._params["a"] * x * x + self._params["b"] * x + self._params["c"]
 | 
			
		||||
 | 
			
		||||
    def fit(
 | 
			
		||||
        self,
 | 
			
		||||
        list_of_points,
 | 
			
		||||
        transf=lambda x: x,
 | 
			
		||||
        max_iter=900,
 | 
			
		||||
        lr_max=1.0,
 | 
			
		||||
        verbose=False,
 | 
			
		||||
    ):
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            data = torch.Tensor(list_of_points).type(torch.float32)
 | 
			
		||||
            assert data.ndim == 2 and data.size(1) == 2, "Invalid shape : {:}".format(
 | 
			
		||||
                data.shape
 | 
			
		||||
            )
 | 
			
		||||
            x, y = data[:, 0], data[:, 1]
 | 
			
		||||
        weights = torch.nn.Parameter(torch.Tensor(3))
 | 
			
		||||
        torch.nn.init.normal_(weights, mean=0.0, std=1.0)
 | 
			
		||||
        optimizer = torch.optim.Adam([weights], lr=lr_max, amsgrad=True)
 | 
			
		||||
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(max_iter*0.25), int(max_iter*0.5), int(max_iter*0.75)], gamma=0.1)
 | 
			
		||||
        if verbose:
 | 
			
		||||
            print("The optimizer: {:}".format(optimizer))
 | 
			
		||||
 | 
			
		||||
        best_loss = None
 | 
			
		||||
        for _iter in range(max_iter):
 | 
			
		||||
            y_hat = transf(weights[0] * x * x + weights[1] * x + weights[2])
 | 
			
		||||
            loss = torch.mean(torch.abs(y - y_hat))
 | 
			
		||||
            optimizer.zero_grad()
 | 
			
		||||
            loss.backward()
 | 
			
		||||
            optimizer.step()
 | 
			
		||||
            lr_scheduler.step()
 | 
			
		||||
            if verbose:
 | 
			
		||||
                print(
 | 
			
		||||
                    "In QuadraticFunction's fit, loss at the {:02d}/{:02d}-th iter is {:}".format(
 | 
			
		||||
                        _iter, max_iter, loss.item()
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
            # Update the params
 | 
			
		||||
            if best_loss is None or best_loss > loss.item():
 | 
			
		||||
                best_loss = loss.item()
 | 
			
		||||
                self._params["a"] = weights[0].item()
 | 
			
		||||
                self._params["b"] = weights[1].item()
 | 
			
		||||
                self._params["c"] = weights[2].item()
 | 
			
		||||
 | 
			
		||||
    def __repr__(self):
 | 
			
		||||
        return "{name}(y = {a} * x^2 + {b} * x + {c})".format(
 | 
			
		||||
            name=self.__class__.__name__,
 | 
			
		||||
            a=self._params["a"],
 | 
			
		||||
            b=self._params["b"],
 | 
			
		||||
            c=self._params["c"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SynAdaptiveEnv(data.Dataset):
 | 
			
		||||
    """The synethtic dataset for adaptive environment."""
 | 
			
		||||
    """The synethtic dataset for adaptive environment.
 | 
			
		||||
 | 
			
		||||
    - x in [0, 1]
 | 
			
		||||
    - y = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
 | 
			
		||||
    - where
 | 
			
		||||
    - the amplitude scale is a quadratic function of x
 | 
			
		||||
    - the period-phase-shift is another quadratic function of x
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        max_num_phase: int = 100,
 | 
			
		||||
        interval: float = 0.1,
 | 
			
		||||
        max_scale: float = 4,
 | 
			
		||||
        offset_scale: float = 1.5,
 | 
			
		||||
        num: int = 100,
 | 
			
		||||
        num_sin_phase: int = 4,
 | 
			
		||||
        min_amplitude: float = 1,
 | 
			
		||||
        max_amplitude: float = 4,
 | 
			
		||||
        phase_shift: float = 0,
 | 
			
		||||
        mode: Optional[str] = None,
 | 
			
		||||
    ):
 | 
			
		||||
        self._amplitude_scale = QuadraticFunction(
 | 
			
		||||
            [(0, min_amplitude), (0.5, max_amplitude), (0, min_amplitude)]
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self._max_num_phase = max_num_phase
 | 
			
		||||
        self._interval = interval
 | 
			
		||||
        self._num_sin_phase = num_sin_phase
 | 
			
		||||
        self._interval = 1.0 / (float(num) - 1)
 | 
			
		||||
        self._total_num = num
 | 
			
		||||
 | 
			
		||||
        self._period_phase_shift = QuadraticFunction()
 | 
			
		||||
 | 
			
		||||
        fitting_data = []
 | 
			
		||||
        temp_max_scalar = 2 ** num_sin_phase
 | 
			
		||||
        for i in range(num_sin_phase):
 | 
			
		||||
            value = (2 ** i) / temp_max_scalar
 | 
			
		||||
            fitting_data.append((value, math.sin(value)))
 | 
			
		||||
        self._period_phase_shift.fit(fitting_data, transf=lambda x: torch.sin(x))
 | 
			
		||||
 | 
			
		||||
        self._times = np.arange(0, np.pi * self._max_num_phase, self._interval)
 | 
			
		||||
        xmin, xmax = self._times.min(), self._times.max()
 | 
			
		||||
        self._inputs = []
 | 
			
		||||
        self._total_num = len(self._times)
 | 
			
		||||
        for i in range(self._total_num):
 | 
			
		||||
            scale = (i + 1.0) / self._total_num * max_scale
 | 
			
		||||
            sin_scale = (i + 1.0) / self._total_num * 0.7
 | 
			
		||||
            sin_scale = -4 * (sin_scale - 0.5) ** 2 + 1
 | 
			
		||||
            # scale = -(self._times[i] - (xmin - xmax) / 2) + max_scale
 | 
			
		||||
            self._inputs.append(
 | 
			
		||||
                np.sin(self._times[i] * sin_scale) * (offset_scale - scale)
 | 
			
		||||
            )
 | 
			
		||||
        self._inputs = np.array(self._inputs)
 | 
			
		||||
        # Training Set 60%
 | 
			
		||||
        num_of_train = int(self._total_num * 0.6)
 | 
			
		||||
        # Validation Set 20%
 | 
			
		||||
@@ -70,10 +154,11 @@ class SynAdaptiveEnv(data.Dataset):
 | 
			
		||||
    def __getitem__(self, index):
 | 
			
		||||
        assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
 | 
			
		||||
        index = self._indexes[index]
 | 
			
		||||
        value = float(self._inputs[index])
 | 
			
		||||
        if self._transform is not None:
 | 
			
		||||
            value = self._transform(value)
 | 
			
		||||
        return index, float(self._times[index]), value
 | 
			
		||||
        position = self._interval * index
 | 
			
		||||
        value = self._amplitude_scale[position] * math.sin(
 | 
			
		||||
            self._period_phase_shift[position]
 | 
			
		||||
        )
 | 
			
		||||
        return index, position, value
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return len(self._indexes)
 | 
			
		||||
 
 | 
			
		||||
@@ -5,16 +5,20 @@ import os
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_imagenet_data(imagenet):
 | 
			
		||||
  total_length = len(imagenet)
 | 
			
		||||
  assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
 | 
			
		||||
  map_id = {}
 | 
			
		||||
  for index in range(total_length):
 | 
			
		||||
    path, target = imagenet.imgs[index]
 | 
			
		||||
    folder, image_name = os.path.split(path)
 | 
			
		||||
    _, folder = os.path.split(folder)
 | 
			
		||||
    if folder not in map_id:
 | 
			
		||||
      map_id[folder] = target
 | 
			
		||||
    else:
 | 
			
		||||
      assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
 | 
			
		||||
    assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
 | 
			
		||||
  print ('Check ImageNet Dataset OK')
 | 
			
		||||
    total_length = len(imagenet)
 | 
			
		||||
    assert (
 | 
			
		||||
        total_length == 1281166 or total_length == 50000
 | 
			
		||||
    ), "The length of ImageNet is wrong : {}".format(total_length)
 | 
			
		||||
    map_id = {}
 | 
			
		||||
    for index in range(total_length):
 | 
			
		||||
        path, target = imagenet.imgs[index]
 | 
			
		||||
        folder, image_name = os.path.split(path)
 | 
			
		||||
        _, folder = os.path.split(folder)
 | 
			
		||||
        if folder not in map_id:
 | 
			
		||||
            map_id[folder] = target
 | 
			
		||||
        else:
 | 
			
		||||
            assert map_id[folder] == target, "Class : {} is not {}".format(
 | 
			
		||||
                folder, target
 | 
			
		||||
            )
 | 
			
		||||
        assert image_name.find(folder) == 0, "{} is wrong.".format(path)
 | 
			
		||||
    print("Check ImageNet Dataset OK")
 | 
			
		||||
 
 | 
			
		||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							@@ -13,9 +13,33 @@ print("library path: {:}".format(lib_dir))
 | 
			
		||||
if str(lib_dir) not in sys.path:
 | 
			
		||||
    sys.path.insert(0, str(lib_dir))
 | 
			
		||||
 | 
			
		||||
from datasets import QuadraticFunction
 | 
			
		||||
from datasets import SynAdaptiveEnv
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestQuadraticFunction(unittest.TestCase):
 | 
			
		||||
    """Test the quadratic function."""
 | 
			
		||||
 | 
			
		||||
    def test_simple(self):
 | 
			
		||||
        function = QuadraticFunction([[0, 1], [0.5, 4], [1, 1]])
 | 
			
		||||
        print(function)
 | 
			
		||||
        for x in (0, 0.5, 1):
 | 
			
		||||
            print("f({:})={:}".format(x, function[x]))
 | 
			
		||||
        thresh = 0.2
 | 
			
		||||
        self.assertTrue(abs(function[0] - 1) < thresh)
 | 
			
		||||
        self.assertTrue(abs(function[0.5] - 4) < thresh)
 | 
			
		||||
        self.assertTrue(abs(function[1] - 1) < thresh)
 | 
			
		||||
 | 
			
		||||
    def test_none(self):
 | 
			
		||||
        function = QuadraticFunction()
 | 
			
		||||
        function.fit([[0, 1], [0.5, 4], [1, 1]], max_iter=3000, verbose=True)
 | 
			
		||||
        print(function)
 | 
			
		||||
        thresh = 0.2
 | 
			
		||||
        self.assertTrue(abs(function[0] - 1) < thresh)
 | 
			
		||||
        self.assertTrue(abs(function[0.5] - 4) < thresh)
 | 
			
		||||
        self.assertTrue(abs(function[1] - 1) < thresh)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestSynAdaptiveEnv(unittest.TestCase):
 | 
			
		||||
    """Test the synethtic adaptive environment."""
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user