Update VIS-CODES and SCRIPTS
This commit is contained in:
parent
4a2292a863
commit
0b0643c820
@ -363,9 +363,9 @@ def main(xargs):
|
|||||||
params = count_parameters_in_MB(search_model)
|
params = count_parameters_in_MB(search_model)
|
||||||
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
logger.log('The parameters of the search model = {:.2f} MB'.format(params))
|
||||||
logger.log('search-space : {:}'.format(search_space))
|
logger.log('search-space : {:}'.format(search_space))
|
||||||
try:
|
if bool(xargs.use_api):
|
||||||
api = API(verbose=False)
|
api = API(verbose=False)
|
||||||
except:
|
else:
|
||||||
api = None
|
api = None
|
||||||
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
logger.log('{:} create API = {:} done'.format(time_string(), api))
|
||||||
|
|
||||||
@ -486,6 +486,7 @@ if __name__ == '__main__':
|
|||||||
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
parser.add_argument('--dataset' , type=str, choices=['cifar10', 'cifar100', 'ImageNet16-120'], help='Choose between Cifar10/100 and ImageNet-16.')
|
||||||
parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.')
|
parser.add_argument('--search_space', type=str, default='tss', choices=['tss'], help='The search space name.')
|
||||||
parser.add_argument('--algo' , type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.')
|
parser.add_argument('--algo' , type=str, choices=['darts-v1', 'darts-v2', 'gdas', 'setn', 'random', 'enas'], help='The search space name.')
|
||||||
|
parser.add_argument('--use_api' , type=int, default=1, choices=[0,1], help='Whether use API or not (which will cost much memory).')
|
||||||
# FOR GDAS
|
# FOR GDAS
|
||||||
parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
|
parser.add_argument('--tau_min', type=float, default=0.1, help='The minimum tau for Gumbel Softmax.')
|
||||||
parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
|
parser.add_argument('--tau_max', type=float, default=10, help='The maximum tau for Gumbel Softmax.')
|
||||||
|
@ -30,14 +30,16 @@ def fetch_data(root_dir='./output/search', search_space='tss', dataset=None):
|
|||||||
ss_dir = '{:}-{:}'.format(root_dir, search_space)
|
ss_dir = '{:}-{:}'.format(root_dir, search_space)
|
||||||
alg2name, alg2path = OrderedDict(), OrderedDict()
|
alg2name, alg2path = OrderedDict(), OrderedDict()
|
||||||
seeds = [777]
|
seeds = [777]
|
||||||
|
if search_space == 'tss':
|
||||||
alg2name['GDAS'] = 'gdas-affine0_BN0-None'
|
alg2name['GDAS'] = 'gdas-affine0_BN0-None'
|
||||||
alg2name['RSPS'] = 'random-affine0_BN0-None'
|
alg2name['RSPS'] = 'random-affine0_BN0-None'
|
||||||
alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None'
|
alg2name['DARTS (1st)'] = 'darts-v1-affine0_BN0-None'
|
||||||
|
alg2name['DARTS (2nd)'] = 'darts-v2-affine0_BN0-None'
|
||||||
alg2name['ENAS'] = 'enas-affine0_BN0-None'
|
alg2name['ENAS'] = 'enas-affine0_BN0-None'
|
||||||
"""
|
alg2name['SETN'] = 'setn-affine0_BN0-None'
|
||||||
alg2name['DARTS (2nd)'] = 'darts-v2-affine1_BN0-None'
|
else:
|
||||||
alg2name['SETN'] = 'setn-affine1_BN0-None'
|
alg2name['TAS'] = 'tas-affine0_BN0'
|
||||||
"""
|
alg2name['FBNetV2'] = 'fbv2-affine0_BN0'
|
||||||
for alg, name in alg2name.items():
|
for alg, name in alg2name.items():
|
||||||
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
alg2path[alg] = os.path.join(ss_dir, dataset, name, 'seed-{:}-last-info.pth')
|
||||||
alg2data = OrderedDict()
|
alg2data = OrderedDict()
|
||||||
@ -66,6 +68,10 @@ y_max_s = {('cifar10', 'tss'): 94.5,
|
|||||||
('ImageNet16-120', 'tss'): 44,
|
('ImageNet16-120', 'tss'): 44,
|
||||||
('ImageNet16-120', 'sss'): 46}
|
('ImageNet16-120', 'sss'): 46}
|
||||||
|
|
||||||
|
name2label = {'cifar10': 'CIFAR-10',
|
||||||
|
'cifar100': 'CIFAR-100',
|
||||||
|
'ImageNet16-120': 'ImageNet-16-120'}
|
||||||
|
|
||||||
def visualize_curve(api, vis_save_dir, search_space):
|
def visualize_curve(api, vis_save_dir, search_space):
|
||||||
vis_save_dir = vis_save_dir.resolve()
|
vis_save_dir = vis_save_dir.resolve()
|
||||||
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
vis_save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@ -94,8 +100,8 @@ def visualize_curve(api, vis_save_dir, search_space):
|
|||||||
alg2accuracies[alg] = accuracies
|
alg2accuracies[alg] = accuracies
|
||||||
ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg))
|
ax.plot(xs, accuracies, c=colors[idx], label='{:}'.format(alg))
|
||||||
ax.set_xlabel('The searching epoch', fontsize=LabelSize)
|
ax.set_xlabel('The searching epoch', fontsize=LabelSize)
|
||||||
ax.set_ylabel('Test accuracy on {:}'.format(dataset), fontsize=LabelSize)
|
ax.set_ylabel('Test accuracy on {:}'.format(name2label[dataset]), fontsize=LabelSize)
|
||||||
ax.set_title('Searching results on {:}'.format(dataset), fontsize=LabelSize+4)
|
ax.set_title('Searching results on {:}'.format(name2label[dataset]), fontsize=LabelSize+4)
|
||||||
ax.legend(loc=4, fontsize=LegendFontsize)
|
ax.legend(loc=4, fontsize=LegendFontsize)
|
||||||
|
|
||||||
fig, axs = plt.subplots(1, 3, figsize=figsize)
|
fig, axs = plt.subplots(1, 3, figsize=figsize)
|
||||||
|
Loading…
Reference in New Issue
Block a user