import sys
import os
import json
import argparse
import random
import glob
import logging
import shutil

import tqdm
import numpy as np
import torch
import torch.utils.data
import torch.backends.cudnn as cudnn
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

sys.path.insert(0, '../')
from nasbench201.cell_infers.tiny_network import TinyNetwork
from nasbench201.genotypes import Structure
from nas_201_api import NASBench201API as API
from pycls.models.nas.nas import NetworkImageNet, NetworkCIFAR
from pycls.models.nas.genotypes import Genotype
import nasbench201.utils as ig_utils
from foresight.pruners import predictive  # only `predictive` is used from foresight.pruners
from Scorers.scorer import Jocab_Scorer
from mobilenet_search_space.retrain_architecture.model import Network
from sota.cnn.hdf5 import H5Dataset
parser = argparse.ArgumentParser("sota")
parser.add_argument('--data', type=str, default='../data',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='cifar10', help='choose dataset')
parser.add_argument('--gpu', type=str, default='auto', help='gpu device id')
parser.add_argument('--save', type=str, default='exp', help='experiment name')
parser.add_argument('--save_path', type=str, default='../experiments/sota', help='root directory for experiment outputs')
parser.add_argument('--seed', type=int, default=2, help='random seed')
parser.add_argument('--ckpt_path', type=str, help='path to the saved networks pool')
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
parser.add_argument('--maxiter', default=1, type=int, help='score is the max of this many evaluations of the network')
parser.add_argument('--batch_size', type=int, default=256, help='batch size for alpha')
parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
parser.add_argument('--cutout_prob', type=float, default=1.0, help='cutout probability')
parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
parser.add_argument('--layers', type=int, default=8, help='total number of layers')
parser.add_argument('--validate_rounds', type=int, default=10, help='number of scoring rounds per network')
parser.add_argument('--proj_crit', type=str, default='jacob', choices=['loss', 'acc', 'var', 'cor', 'norm', 'jacob', 'snip', 'fisher', 'synflow', 'grad_norm', 'grasp', 'jacob_cov', 'comb', 'meco', 'zico'], help='criterion for projection')
parser.add_argument('--edge_decision', type=str, default='random', choices=['random', 'reverse', 'order', 'global_op_greedy', 'global_op_once', 'global_edge_greedy', 'global_edge_once'], help='which edge to be projected next')
args = parser.parse_args()
#### reproducibility
cudnn.deterministic = True
cudnn.benchmark = False
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
def load_network_pool(ckpt_path):
    """Load the candidate pool from `networks_pool.json`; the pool size falls back to the list length."""
    with open(os.path.join(ckpt_path, 'networks_pool.json'), 'r') as save_file:
        for line in save_file:
            networks_pool = json.loads(line)
            if 'pool_size' in networks_pool:
                return networks_pool['search_space'], networks_pool['dataset'], networks_pool['networks'], networks_pool['pool_size']
            else:
                return networks_pool['search_space'], networks_pool['dataset'], networks_pool['networks'], len(networks_pool['networks'])
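# A minimal sketch of the expected networks_pool.json layout, inferred from the
# keys read above and the 'id'/'genotype' fields consumed in main(); the genotype
# string here is illustrative only:
# {"search_space": "nas-bench-201", "dataset": "cifar10", "pool_size": 2,
#  "networks": [{"id": 0, "genotype": "|nor_conv_3x3~0|+..."}, ...]}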
#### args augment
search_space, dataset, networks_pool, pool_size = load_network_pool(args.ckpt_path)
search_space = search_space.strip()
dataset = dataset.strip()
expid = args.save
args.save = '{}/{}-valid-{}-{}-{}-{}'.format(args.save_path, search_space, args.save, args.seed, pool_size, args.validate_rounds)
if not dataset == 'cifar10':
    args.save += '-' + dataset
if not args.edge_decision == 'random':
    args.save += '-' + args.edge_decision
if not args.proj_crit == 'jacob':
    args.save += '-' + args.proj_crit
scripts_to_save = glob.glob('*.py') + ['../exp_scripts/{}.sh'.format(expid)]
if os.path.exists(args.save):
    if input("WARNING: {} exists, override? [y/n] ".format(args.save)) == 'y':
        print('proceed to override saving directory')
        shutil.rmtree(args.save)
    else:
        exit(0)
ig_utils.create_exp_dir(args.save, scripts_to_save=scripts_to_save)  # back up the scripts collected above
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
log_file = 'log.txt'
log_path = os.path.join(args.save, log_file)
logging.info('======> log filename: %s', log_file)
logging.info('load pool from space: %s and dataset: %s', search_space, dataset)
if os.path.exists(log_path):
    if input("WARNING: {} exists, override? [y/n] ".format(log_file)) == 'y':
        print('proceed to override log file')
    else:
        exit(0)
fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
writer = SummaryWriter(args.save + '/runs')
#### macros
if dataset == 'cifar100':
    n_classes = 100
elif dataset == 'imagenet16-120':
    n_classes = 120
elif dataset == 'imagenet':
    n_classes = 1000
else:
    n_classes = 10
if search_space == 'nas-bench-201':
    api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
if search_space == 'nb_macro':
    import pickle as pkl
    with open('../data/nbmacro-base-0.pickle', 'rb') as f:
        head = pkl.load(f)
        value = pkl.load(f)
    api = {}
    for v in value:
        h, val_t1, test_t1, t_time = v
        api[h] = test_t1
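    # Sketch of the assumed pickle layout, inferred from the unpacking above:
    # `value` is a list of (arch, val_top1, test_top1, train_time) tuples, so
    # `api` ends up mapping each architecture to its benchmarked test accuracy.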
def main():
    #### data
    if dataset == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.2),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        train_data = H5Dataset(os.path.join(args.data, 'imagenet-train-256.h5'), transform=train_transform)
        num_train = len(train_data)
        indices = list(range(num_train))
        # sample exactly validate_rounds batches from a random subset of the training set
        split = int(np.floor(args.validate_rounds * args.batch_size))
        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size, num_workers=4, pin_memory=True,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]))
    else:
        if dataset == 'cifar10':
            train_transform, valid_transform = ig_utils._data_transforms_cifar10(args)
            train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
        elif dataset == 'cifar100':
            train_transform, valid_transform = ig_utils._data_transforms_cifar100(args)
            train_data = dset.CIFAR100(root=args.data, train=True, download=True, transform=train_transform)
            valid_data = dset.CIFAR100(root=args.data, train=False, download=True, transform=valid_transform)
        elif dataset == 'svhn':
            train_transform, valid_transform = ig_utils._data_transforms_svhn(args)
            train_data = dset.SVHN(root=args.data, split='train', download=True, transform=train_transform)
            valid_data = dset.SVHN(root=args.data, split='test', download=True, transform=valid_transform)
        elif dataset == 'imagenet16-120':
            from nasbench201.DownsampledImageNet import ImageNet16
            mean = [x / 255 for x in [122.68, 116.66, 104.01]]
            std = [x / 255 for x in [63.22, 61.26, 65.09]]
            lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
            train_transform = transforms.Compose(lists)
            train_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=True, transform=train_transform, use_num_of_class_only=120)
            valid_data = ImageNet16(root=os.path.join(args.data, 'imagenet16'), train=False, transform=train_transform, use_num_of_class_only=120)
            assert len(train_data) == 151700
        num_train = len(train_data)
        indices = list(range(num_train))
        # sample exactly validate_rounds batches from a random subset of the training set
        split = int(np.floor(args.validate_rounds * args.batch_size))
        train_queue = torch.utils.data.DataLoader(
            train_data, batch_size=args.batch_size,
            sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
            pin_memory=True, num_workers=4)
    gpu = ig_utils.pick_gpu_lowest_memory() if args.gpu == 'auto' else int(args.gpu)
    torch.cuda.set_device(gpu)
    if args.proj_crit == 'jacob':
        validate_scorer = Jocab_Scorer(gpu)
    best_id = None
    best_score = 0
    best_networks = None
    crit_list = []
    logging.info('number of scoring batches: %d', len(train_queue))
    net_history = []
    for net_config in tqdm.tqdm(networks_pool, desc="networks", position=0):
        net_id = net_config['id']
        net_genotype = net_config['genotype']
        # score each distinct genotype once; crit_list stays index-aligned with net_history
        if net_genotype not in net_history:
            net_history.append(net_genotype)
            network = get_networks_from_genotype(net_genotype, dataset, search_space)
            if args.proj_crit == 'jacob':
                validate_scorer.setup_hooks(network, args.batch_size)
            for step, (input, target) in tqdm.tqdm(enumerate(train_queue), desc="validate_rounds", position=1, leave=False):
                input = input.cuda()
                target = target.cuda()
                if args.proj_crit == 'jacob':
                    score = validate_scorer.score(network, input, target)
                else:
                    network.requires_feature = False
                    measures = predictive.find_measures(network,
                                                        train_queue,
                                                        ('random', 1, n_classes),
                                                        torch.device("cuda"),
                                                        measure_names=[args.proj_crit])
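                    # Assumption about the foresight API: the ('random', 1, n_classes)
                    # tuple is find_measures' dataload spec (loading mode, number of
                    # batches, number of classes).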
                    score = measures[args.proj_crit]
                # accumulate the score over rounds for the current network
                if step == 0:
                    crit_list.append(score)
                else:
                    crit_list[-1] += score
                # find_measures iterates over train_queue itself, so non-jacob
                # criteria need only a single pass through this loop
                if args.proj_crit != 'jacob':
                    break
    # keep the genotype with the highest accumulated score
    best_networks = net_history[np.nanargmax(crit_list)]
    if search_space == 'nas-bench-201':
        cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
            cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = query(api, best_networks, logging)
    networks_info = {}
    networks_info['search_space'] = search_space
    networks_info['dataset'] = dataset
    networks_info['networks'] = best_networks
    with open(os.path.join(args.save, 'best_networks.json'), 'w') as save_file:
        json.dump(networks_info, save_file)
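    # best_networks.json reuses the pool file's top-level keys (search_space,
    # dataset, networks), with `networks` holding the single best genotype string.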
#### util functions
def distill(result):
    """Pull the train/valid/test accuracies out of a NAS-Bench-201 result string."""
    result = result.split('\n')
    cifar10 = result[5].replace(' ', '').split(':')
    cifar100 = result[7].replace(' ', '').split(':')
    imagenet16 = result[9].replace(' ', '').split(':')
    cifar10_train = float(cifar10[1].strip(',test')[-7:-2].strip('='))
    cifar10_test = float(cifar10[2][-7:-2].strip('='))
    cifar100_train = float(cifar100[1].strip(',valid')[-7:-2].strip('='))
    cifar100_valid = float(cifar100[2].strip(',test')[-7:-2].strip('='))
    cifar100_test = float(cifar100[3][-7:-2].strip('='))
    imagenet16_train = float(imagenet16[1].strip(',valid')[-7:-2].strip('='))
    imagenet16_valid = float(imagenet16[2].strip(',test')[-7:-2].strip('='))
    imagenet16_test = float(imagenet16[3][-7:-2].strip('='))
    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
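# distill() assumes the pretty-printed string returned by query_by_arch, in which
# lines 5, 7, and 9 carry the cifar10 / cifar100 / ImageNet16-120 accuracies as
# 'key=value' pairs.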
def query(api, genotype, logging):
    result = api.query_by_arch(genotype, hp='200')
    logging.info('{:}'.format(result))
    cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test = distill(result)
    logging.info('cifar10 train %f test %f', cifar10_train, cifar10_test)
    logging.info('cifar100 train %f valid %f test %f', cifar100_train, cifar100_valid, cifar100_test)
    logging.info('imagenet16 train %f valid %f test %f', imagenet16_train, imagenet16_valid, imagenet16_test)
    return cifar10_train, cifar10_test, cifar100_train, cifar100_valid, \
        cifar100_test, imagenet16_train, imagenet16_valid, imagenet16_test
def get_networks_from_genotype(genotype_str, dataset, search_space):
    if search_space == 'nas-bench-201':
        net_index = api.query_index_by_arch(genotype_str)
        net_config = api.get_net_config(net_index, 'cifar10-valid')
        logging.info('net config: %s', net_config)
        genotype = Structure.str2structure(net_config['arch_str'])
        network = TinyNetwork(net_config['C'], net_config['N'], genotype, n_classes)
        return network
    elif search_space == 'mobilenet':
        # mobilenet genotypes are space-separated per-block op indices
        rngs = [int(id) for id in genotype_str.split(' ')]
        network = Network(rngs, n_class=n_classes)
        return network
    else:
        # DARTS-style genotypes arrive serialized as JSON
        genotype_config = json.loads(genotype_str)
        genotype = Genotype(normal=genotype_config['normal'], normal_concat=genotype_config['normal_concat'], reduce=genotype_config['reduce'], reduce_concat=genotype_config['reduce_concat'])
        if dataset == 'imagenet':
            network = NetworkImageNet(args.init_channels, n_classes, args.layers, False, genotype)
        else:
            network = NetworkCIFAR(args.init_channels, n_classes, args.layers, False, genotype)
        network.drop_path_prob = 0.
        return network
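# Illustrative call; the arch string is a made-up NAS-Bench-201 genotype:
#   net = get_networks_from_genotype(
#       '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|skip_connect~1|avg_pool_3x3~2|',
#       'cifar10', 'nas-bench-201')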
if __name__ == '__main__':
    main()