Fix bugs in TAS: missing ReLU in the end of each searching block
This commit is contained in:
parent
569b9b406a
commit
076f9c2d41
59
exps/NAS-Bench-201/xshape-file.py
Normal file
59
exps/NAS-Bench-201/xshape-file.py
Normal file
@ -0,0 +1,59 @@
|
||||
###############################################################
|
||||
# NAS-Bench-201, ICLR 2020 (https://arxiv.org/abs/2001.00326) #
|
||||
###############################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
|
||||
###############################################################
|
||||
# Usage: python exps/NAS-Bench-201/xshape-file.py --mode check
|
||||
###############################################################
|
||||
import os, sys, time, torch, argparse
|
||||
from typing import List, Text, Dict, Any
|
||||
from tqdm import tqdm
|
||||
from collections import defaultdict
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import dict2config, load_config
|
||||
from procedures import bench_evaluate_for_seed
|
||||
from procedures import get_machine_info
|
||||
from datasets import get_datasets
|
||||
from log_utils import Logger, AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def obtain_valid_ckp(save_dir: Text, total: int):
|
||||
possible_seeds = [777, 888]
|
||||
seed2ckps = defaultdict(list)
|
||||
miss2ckps = defaultdict(list)
|
||||
for i in range(total):
|
||||
for seed in possible_seeds:
|
||||
path = os.path.join(save_dir, 'arch-{:06d}-seed-{:04d}.pth'.format(i, seed))
|
||||
if os.path.exists(path):
|
||||
seed2ckps[seed].append(i)
|
||||
else:
|
||||
miss2ckps[seed].append(i)
|
||||
"""
|
||||
ckps = [x for x in save_dir.glob('arch-{:06d}-seed-*.pth'.format(i))]
|
||||
for ckp in ckps:
|
||||
seed = ckp.name.split('-seed-')[-1].split('.pth')[0]
|
||||
seed2ckps[int(seed)].append(i)
|
||||
"""
|
||||
for seed, xlist in seed2ckps.items():
|
||||
print('[{:}] [seed={:}] has {:}/{:}'.format(save_dir, seed, len(xlist), total))
|
||||
return dict(seed2ckps), dict(miss2ckps)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='NAS-Bench-X', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--mode', type=str, required=True, choices=['check', 'copy'], help='The script mode.')
|
||||
parser.add_argument('--save_dir', type=str, default='output/NAS-BENCH-202', help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--check_N', type=int, default=32768, help='For safety.')
|
||||
# use for train the model
|
||||
args = parser.parse_args()
|
||||
possible_configs = ['01', '12', '90']
|
||||
if args.mode == 'check':
|
||||
for config in possible_configs:
|
||||
cur_save_dir = '{:}/raw-data-{:}'.format(args.save_dir, config)
|
||||
seed2ckps, miss2ckps = obtain_valid_ckp(cur_save_dir, args.check_N)
|
||||
torch.save(dict(seed2ckps=seed2ckps, miss2ckps=miss2ckps), '{:}/meta-{:}.pth'.format(args.save_dir, config))
|
||||
|
@ -165,7 +165,7 @@ def filter_indexes(xlist, mode, save_dir, seeds):
|
||||
if not temp_path.exists():
|
||||
all_indexes.append(index)
|
||||
break
|
||||
print('{:} [FILTER-INDEXES] : there are {:} architectures in total'.format(time_string(), len(all_indexes)))
|
||||
print('{:} [FILTER-INDEXES] : there are {:}/{:} architectures in total'.format(time_string(), len(all_indexes), len(xlist)))
|
||||
|
||||
SLURM_PROCID, SLURM_NTASKS = 'SLURM_PROCID', 'SLURM_NTASKS'
|
||||
if SLURM_PROCID in os.environ and SLURM_NTASKS in os.environ: # run on the slurm
|
||||
|
@ -172,7 +172,7 @@ class ResNetBasicblock(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
|
||||
return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
@ -244,8 +244,7 @@ class ResNetBottleneck(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
|
||||
|
||||
return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
|
||||
|
||||
|
||||
class SearchShapeCifarResNet(nn.Module):
|
||||
|
@ -156,7 +156,7 @@ class ResNetBasicblock(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
|
||||
return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
@ -228,8 +228,7 @@ class ResNetBottleneck(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
|
||||
|
||||
return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
|
||||
|
||||
|
||||
class SearchWidthCifarResNet(nn.Module):
|
||||
|
@ -171,7 +171,7 @@ class ResNetBasicblock(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_b)
|
||||
return out, expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
|
||||
return nn.functional.relu(out, inplace=True), expected_inC_b, sum([expected_flop_a, expected_flop_b, expected_flop_c])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv_a(inputs)
|
||||
@ -243,8 +243,7 @@ class ResNetBottleneck(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out_1x4)
|
||||
return out, expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
|
||||
|
||||
return nn.functional.relu(out, inplace=True), expected_inC_1x4, sum([expected_flop_1x1, expected_flop_3x3, expected_flop_1x4, expected_flop_c])
|
||||
|
||||
|
||||
class SearchShapeImagenetResNet(nn.Module):
|
||||
|
@ -153,7 +153,7 @@ class SimBlock(nn.Module):
|
||||
else:
|
||||
residual, expected_flop_c = inputs, 0
|
||||
out = additive_func(residual, out)
|
||||
return out, expected_next_inC, sum([expected_flop, expected_flop_c])
|
||||
return nn.functional.relu(out, inplace=True), expected_next_inC, sum([expected_flop, expected_flop_c])
|
||||
|
||||
def basic_forward(self, inputs):
|
||||
basicblock = self.conv(inputs)
|
||||
|
36
scripts-search/X-X/train-shapes-01.sh
Normal file
36
scripts-search/X-X/train-shapes-01.sh
Normal file
@ -0,0 +1,36 @@
|
||||
#!/bin/bash
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.01 #
|
||||
#####################################################
|
||||
# bash ./scripts-search/X-X/train-shapes-01.sh 0 4
|
||||
echo script name: $0
|
||||
echo $# arguments
|
||||
if [ "$#" -ne 2 ] ;then
|
||||
echo "Input illegal number of parameters " $#
|
||||
echo "Need 2 parameters for hyper-parameters-opt-file, and seeds"
|
||||
exit 1
|
||||
fi
|
||||
if [ "$TORCH_HOME" = "" ]; then
|
||||
echo "Must set TORCH_HOME envoriment variable for data dir saving"
|
||||
exit 1
|
||||
else
|
||||
echo "TORCH_HOME : $TORCH_HOME"
|
||||
fi
|
||||
|
||||
srange=00000-32767
|
||||
opt=01
|
||||
all_seeds=777
|
||||
cpus=4
|
||||
|
||||
save_dir=./output/NAS-BENCH-202/
|
||||
|
||||
SLURM_PROCID=$1 SLURM_NTASKS=$2 OMP_NUM_THREADS=${cpus} python exps/NAS-Bench-201/xshapes.py \
|
||||
--mode new --srange ${srange} --hyper ${opt} --save_dir ${save_dir} \
|
||||
--datasets cifar10 cifar10 cifar100 ImageNet16-120 \
|
||||
--splits 1 0 0 0 \
|
||||
--xpaths $TORCH_HOME/cifar.python \
|
||||
$TORCH_HOME/cifar.python \
|
||||
$TORCH_HOME/cifar.python \
|
||||
$TORCH_HOME/cifar.python/ImageNet16 \
|
||||
--workers ${cpus} \
|
||||
--seeds ${all_seeds}
|
3
scripts-search/algos/README.md
Normal file
3
scripts-search/algos/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# 10 NAS algorithms in NAS-Bench-201
|
||||
|
||||
Each script in this folder corresponds to one NAS algorithm, you can simple run it by one command.
|
Loading…
Reference in New Issue
Block a user