update code styles
This commit is contained in:
		| @@ -9,16 +9,25 @@ class SearchDataset(data.Dataset): | ||||
|  | ||||
|   def __init__(self, name, data, train_split, valid_split, check=True): | ||||
|     self.datasetname = name | ||||
|     self.data        = data | ||||
|     self.train_split = train_split.copy() | ||||
|     self.valid_split = valid_split.copy() | ||||
|     if check: | ||||
|       intersection = set(train_split).intersection(set(valid_split)) | ||||
|       assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' | ||||
|     if isinstance(data, (list, tuple)): # new type of SearchDataset | ||||
|       assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) | ||||
|       self.train_data  = data[0] | ||||
|       self.valid_data  = data[1] | ||||
|       self.train_split = train_split.copy() | ||||
|       self.valid_split = valid_split.copy() | ||||
|       self.mode_str    = 'V2' # new mode  | ||||
|     else: | ||||
|       self.mode_str    = 'V1' # old mode  | ||||
|       self.data        = data | ||||
|       self.train_split = train_split.copy() | ||||
|       self.valid_split = valid_split.copy() | ||||
|       if check: | ||||
|         intersection = set(train_split).intersection(set(valid_split)) | ||||
|         assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' | ||||
|     self.length      = len(self.train_split) | ||||
|  | ||||
|   def __repr__(self): | ||||
|     return ('{name}(name={datasetname}, train={tr_L}, valid={val_L})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split))) | ||||
|     return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) | ||||
|  | ||||
|   def __len__(self): | ||||
|     return self.length | ||||
| @@ -27,6 +36,11 @@ class SearchDataset(data.Dataset): | ||||
|     assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) | ||||
|     train_index = self.train_split[index] | ||||
|     valid_index = random.choice( self.valid_split ) | ||||
|     train_image, train_label = self.data[train_index] | ||||
|     valid_image, valid_label = self.data[valid_index] | ||||
|     if self.mode_str == 'V1': | ||||
|       train_image, train_label = self.data[train_index] | ||||
|       valid_image, valid_label = self.data[valid_index] | ||||
|     elif self.mode_str == 'V2': | ||||
|       train_image, train_label = self.train_data[train_index] | ||||
|       valid_image, valid_label = self.valid_data[valid_index] | ||||
|     else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) | ||||
|     return train_image, train_label, valid_image, valid_label | ||||
|   | ||||
| @@ -34,7 +34,7 @@ class PointMeta(): | ||||
|  | ||||
|   def get_box(self, return_diagonal=False): | ||||
|     if self.box is None: return None | ||||
|     if return_diagonal == False: | ||||
|     if not return_diagonal: | ||||
|       return self.box.clone() | ||||
|     else: | ||||
|       W = (self.box[2]-self.box[0]).item() | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| from copy import deepcopy | ||||
| from ..cell_operations import OPS | ||||
|   | ||||
| @@ -68,7 +68,7 @@ class Structure: | ||||
|     for i, node_info in enumerate(self.nodes): | ||||
|       sums = [] | ||||
|       for op, xin in node_info: | ||||
|         if op == 'none' or nodes[xin] == False: x = False | ||||
|         if op == 'none' or nodes[xin] is False: x = False | ||||
|         else: x = True | ||||
|         sums.append( x ) | ||||
|       nodes[i+1] = sum(sums) > 0 | ||||
|   | ||||
| @@ -85,7 +85,7 @@ class SearchCell(nn.Module): | ||||
|           candidates = self.edges[node_str] | ||||
|           select_op  = random.choice(candidates) | ||||
|           sops.append( select_op ) | ||||
|           if not hasattr(select_op, 'is_zero') or select_op.is_zero == False: has_non_zero=True | ||||
|           if not hasattr(select_op, 'is_zero') or select_op.is_zero is False: has_non_zero=True | ||||
|         if has_non_zero: break | ||||
|       inter_nodes = [] | ||||
|       for j, select_op in enumerate(sops): | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| import math, torch | ||||
| import math | ||||
| import torch.nn as nn | ||||
| import torch.nn.functional as F | ||||
| from ..initialization import initialize_resnet | ||||
|   | ||||
| @@ -70,6 +70,9 @@ class NASBench102API(object): | ||||
|   def __repr__(self): | ||||
|     return ('{name}({num}/{total} architectures)'.format(name=self.__class__.__name__, num=len(self.evaluated_indexes), total=len(self.meta_archs))) | ||||
|  | ||||
|   def random(self): | ||||
|     return random.randint(0, len(self.meta_archs)-1) | ||||
|  | ||||
|   def query_index_by_arch(self, arch): | ||||
|     if isinstance(arch, str): | ||||
|       if arch in self.archstr2index: arch_index = self.archstr2index[ arch ] | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| ################################################## | ||||
| # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # | ||||
| ################################################## | ||||
| import os, sys, time, torch, random, PIL, copy, numpy as np | ||||
| import os, sys, torch, random, PIL, copy, numpy as np | ||||
| from os import path as osp | ||||
| from shutil  import copyfile | ||||
|  | ||||
|   | ||||
| @@ -1,3 +1,5 @@ | ||||
| from .evaluation_utils import obtain_accuracy | ||||
| from .gpu_manager      import GPUManager | ||||
| from .flop_benchmark   import get_model_infos | ||||
| from .affine_utils     import normalize_points, denormalize_points | ||||
| from .affine_utils     import identity2affine, solve2theta, affine2image | ||||
|   | ||||
| @@ -1,10 +1,3 @@ | ||||
| # Copyright (c) Facebook, Inc. and its affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the license found in the | ||||
| # LICENSE file in the root directory of this source tree. | ||||
| # | ||||
| # | ||||
| # functions for affine transformation | ||||
| import math, torch | ||||
| import numpy as np | ||||
| @@ -1,4 +1,4 @@ | ||||
| import copy, torch | ||||
| import torch | ||||
| import torch.nn as nn | ||||
| import numpy as np | ||||
|  | ||||
|   | ||||
| @@ -27,7 +27,7 @@ class GPUManager(): | ||||
|         find = False | ||||
|         for gpu in all_gpus: | ||||
|           if gpu['index'] == CUDA_VISIBLE_DEVICE: | ||||
|             assert find==False, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE) | ||||
|             assert not find, 'Duplicate cuda device index : {}'.format(CUDA_VISIBLE_DEVICE) | ||||
|             find = True | ||||
|             selected_gpus.append( gpu.copy() ) | ||||
|             selected_gpus[-1]['index'] = '{}'.format(idx) | ||||
|   | ||||
							
								
								
									
										52
									
								
								lib/utils/nas_utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								lib/utils/nas_utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,52 @@ | ||||
| # This file is for experimental usage | ||||
| import os, sys, torch, random | ||||
| import numpy as np | ||||
| from copy import deepcopy | ||||
| from tqdm import tqdm | ||||
| import torch.nn as nn | ||||
|  | ||||
| from utils  import obtain_accuracy | ||||
| from models import CellStructure | ||||
| from log_utils import time_string | ||||
|  | ||||
| def evaluate_one_shot(model, xloader, api, cal_mode, seed=111): | ||||
|   weights = deepcopy(model.state_dict()) | ||||
|   model.train(cal_mode) | ||||
|   with torch.no_grad(): | ||||
|     logits = nn.functional.log_softmax(model.arch_parameters, dim=-1) | ||||
|     archs = CellStructure.gen_all(model.op_names, model.max_nodes, False) | ||||
|     probs, accuracies, gt_accs = [], [], [] | ||||
|     loader_iter = iter(xloader) | ||||
|     random.seed(seed) | ||||
|     random.shuffle(archs) | ||||
|     for idx, arch in enumerate(archs): | ||||
|       arch_index = api.query_index_by_arch( arch ) | ||||
|       metrics = api.get_more_info(arch_index, 'cifar10-valid', None, False, False) | ||||
|       gt_accs.append( metrics['valid-accuracy'] ) | ||||
|       select_logits = [] | ||||
|       for i, node_info in enumerate(arch.nodes): | ||||
|         for op, xin in node_info: | ||||
|           node_str = '{:}<-{:}'.format(i+1, xin) | ||||
|           op_index = model.op_names.index(op) | ||||
|           select_logits.append( logits[model.edge2index[node_str], op_index] ) | ||||
|       cur_prob = sum(select_logits).item() | ||||
|       probs.append( cur_prob ) | ||||
|     cor_prob = np.corrcoef(probs, gt_accs)[0,1] | ||||
|     print ('correlation for probabilities : {:}'.format(cor_prob)) | ||||
|        | ||||
|     for idx, arch in enumerate(archs): | ||||
|       model.set_cal_mode('dynamic', arch) | ||||
|       try: | ||||
|         inputs, targets = next(loader_iter) | ||||
|       except: | ||||
|         loader_iter = iter(xloader) | ||||
|         inputs, targets = next(loader_iter) | ||||
|       _, logits = model(inputs.cuda()) | ||||
|       _, preds  = torch.max(logits, dim=-1) | ||||
|       correct = (preds == targets.cuda() ).float() | ||||
|       accuracies.append( correct.mean().item() ) | ||||
|       if idx != 0 and (idx % 300 == 0 or idx + 1 == len(archs) or idx == 10): | ||||
|         cor_accs = np.corrcoef(accuracies, gt_accs[:idx+1])[0,1] | ||||
|         print ('{:} {:03d}/{:03d} mode={:5s}, correlation : accs={:.4f}, arch={:}'.format(time_string(), idx, len(archs), 'Train' if cal_mode else 'Eval', cor_accs, arch)) | ||||
|   model.load_state_dict(weights) | ||||
|   return archs, probs, accuracies | ||||
| @@ -1 +0,0 @@ | ||||
| from .affine_utils import normalize_points, denormalize_points | ||||
		Reference in New Issue
	
	Block a user