2020-01-09 12:26:23 +01:00
# This file is for experimental usage
2020-03-11 08:44:39 +01:00
import torch , random
2020-01-09 12:26:23 +01:00
import numpy as np
from copy import deepcopy
import torch . nn as nn
2021-04-11 15:45:20 +02:00
# modules in AutoDL
2020-01-09 12:26:23 +01:00
from models import CellStructure
from log_utils import time_string
2020-07-13 12:04:52 +02:00
2020-01-09 12:26:23 +01:00
def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
    """Correlate one-shot estimates of architectures with benchmark accuracies.

    Two estimates are computed for every architecture produced by
    ``CellStructure.gen_all``:

    1. the sum of the log-softmax architecture parameters of the operations
       the architecture selects, and
    2. the accuracy of ``model`` on a single batch drawn from ``xloader``
       after ``model.set_cal_mode`` has been called for that architecture.

    The Pearson correlation (``np.corrcoef``) between each estimate and the
    accuracies returned by ``api.get_more_info`` is printed.

    Args:
        model: network exposing ``state_dict``, ``load_state_dict``,
            ``train``, ``arch_parameters``, ``op_names``, ``max_nodes``,
            ``edge2index`` and ``set_cal_mode``; calling it must return a
            2-tuple whose second element holds the class scores.
        xloader: iterable yielding ``(inputs, targets)`` batches; it is
            re-iterated from the start whenever it is exhausted.
        api: benchmark API providing ``query_index_by_arch`` and
            ``get_more_info``.
        cal_mode: forwarded to ``model.train``; truthy is reported as
            "Train" and falsy as "Eval" in the progress message.
        seed: seed given to ``random.seed`` before shuffling the
            architectures. This reseeds the global ``random`` module.

    Returns:
        A tuple ``(archs, probs, accuracies)``: the shuffled architectures,
        their summed log-probabilities and their single-batch accuracies,
        all in the same order.

    Note:
        Batches are moved with ``.cuda()``, so a CUDA device is required.
        The model weights are restored from a deep copy before returning;
        the mode set through ``model.train(cal_mode)`` is not reverted.
    """
    # NOTE(review): the string literals in this function carry leading and
    # trailing spaces (e.g. " cifar10-valid "). They are kept verbatim here;
    # this looks like it may be an extraction artifact -- confirm against the
    # keys expected by `api.get_more_info` and `model.edge2index`.
    print(
        " This is an old version of codes to use NAS-Bench-API, and should be modified to align with the new version. Please contact me for more details if you use this function. "
    )
    # Snapshot the weights so they can be restored once evaluation is done.
    weights = deepcopy(model.state_dict())
    model.train(cal_mode)
    with torch.no_grad():
        arch_log_probs = nn.functional.log_softmax(model.arch_parameters, dim=-1)
        archs = CellStructure.gen_all(model.op_names, model.max_nodes, False)
        probs, accuracies, gt_accs_10_valid, gt_accs_10_test = [], [], [], []
        loader_iter = iter(xloader)
        random.seed(seed)
        random.shuffle(archs)

        # Phase 1: benchmark accuracies and summed log-probabilities.
        for arch in archs:
            arch_index = api.query_index_by_arch(arch)
            metrics = api.get_more_info(
                arch_index, " cifar10-valid ", None, False, False
            )
            gt_accs_10_valid.append(metrics[" valid-accuracy "])
            metrics = api.get_more_info(arch_index, " cifar10 ", None, False, False)
            gt_accs_10_test.append(metrics[" test-accuracy "])
            select_logits = []
            for i, node_info in enumerate(arch.nodes):
                for op, xin in node_info:
                    node_str = " {:} <- {:} ".format(i + 1, xin)
                    op_index = model.op_names.index(op)
                    select_logits.append(
                        arch_log_probs[model.edge2index[node_str], op_index]
                    )
            probs.append(sum(select_logits).item())
        cor_prob_valid = np.corrcoef(probs, gt_accs_10_valid)[0, 1]
        cor_prob_test = np.corrcoef(probs, gt_accs_10_test)[0, 1]
        print(
            " {:} correlation for probabilities : {:.6f} on CIFAR-10 validation and {:.6f} on CIFAR-10 test ".format(
                time_string(), cor_prob_valid, cor_prob_test
            )
        )

        # Phase 2: one-batch accuracy of every architecture.
        for idx, arch in enumerate(archs):
            model.set_cal_mode(" dynamic ", arch)
            try:
                inputs, targets = next(loader_iter)
            except StopIteration:
                # Only loader exhaustion restarts the iterator; any other
                # error (data loading failure, interrupt, ...) propagates.
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
            _, outputs = model(inputs.cuda())
            _, preds = torch.max(outputs, dim=-1)
            correct = (preds == targets.cuda()).float()
            accuracies.append(correct.mean().item())
            # Report every 500 architectures and once more at the very end.
            if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
                cor_accs_valid = np.corrcoef(
                    accuracies, gt_accs_10_valid[: idx + 1]
                )[0, 1]
                cor_accs_test = np.corrcoef(
                    accuracies, gt_accs_10_test[: idx + 1]
                )[0, 1]
                print(
                    " {:} {:05d} / {:05d} mode= {:5s} , correlation : accs= {:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test. ".format(
                        time_string(),
                        idx,
                        len(archs),
                        " Train " if cal_mode else " Eval ",
                        cor_accs_valid,
                        cor_accs_test,
                    )
                )
    model.load_state_dict(weights)
    return archs, probs, accuracies