From 6814816d5fe26fc62476ed255d139b967ee0130b Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Wed, 16 Oct 2019 00:09:10 +1100 Subject: [PATCH] update scripts --- lib/models/searchs/SoftSelect.py | 1 + scripts-search/search-cifar.sh | 22 +++++----------------- scripts-search/search-depth-cifar.sh | 6 +++--- scripts-search/search-width-cifar.sh | 6 +++--- scripts/base-train.sh | 2 +- 5 files changed, 13 insertions(+), 24 deletions(-) diff --git a/lib/models/searchs/SoftSelect.py b/lib/models/searchs/SoftSelect.py index a2132d4..84f2ad8 100644 --- a/lib/models/searchs/SoftSelect.py +++ b/lib/models/searchs/SoftSelect.py @@ -10,6 +10,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7): while True: # a trick to avoid the gumbels bug gumbels = -torch.empty_like(logits).exponential_().log() new_logits = (logits + gumbels) / tau + #new_logits = (logits.log_softmax(dim=1) + gumbels) / tau probs = nn.functional.softmax(new_logits, dim=1) if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break diff --git a/scripts-search/search-cifar.sh b/scripts-search/search-cifar.sh index 4c99dfa..c4d83a0 100644 --- a/scripts-search/search-cifar.sh +++ b/scripts-search/search-cifar.sh @@ -24,23 +24,11 @@ gumbel_max=5 expected_FLOP_ratio=$4 rseed=$5 -PY_C="./env/bin/python" -if [ ! -f ${PY_C} ]; then - echo "Local Run with Python: "`which python` - PY_C="python" - SAVE_ROOT="./output" -else - echo "Cluster Run with Python: "${PY_C} - SAVE_ROOT="./hadoop-data/TAS-checkpoints" - mkdir -p $TORCH_HOME/TAS-checkpoints/ - cp -r ./hadoop-data/TAS-checkpoints/basemodels $TORCH_HOME/TAS-checkpoints/ -fi +save_dir=./output/search-shape/${dataset}-${model}-${optim}-Gumbel_${gumbel_min}_${gumbel_max}-${expected_FLOP_ratio} -save_dir=${SAVE_ROOT}/search-shape/${dataset}-${model}-${optim}-Gumbel_${gumbel_min}_${gumbel_max}-${expected_FLOP_ratio} +python --version -${PY_C} --version - -${PY_C} ./exps/search-transformable.py --dataset ${dataset} \ +OMP_NUM_THREADS=4 python ./exps/search-transformable.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ./configs/archs/CIFAR-${model}.config \ --split_path ./.latent-data/splits/${dataset}-0.5.pth \ @@ -60,7 +48,7 @@ if [ "$rseed" = "-1" ]; then else # normal training xsave_dir=${save_dir}/seed-${rseed}-NMT - ${PY_C} ./exps/basic-main.py --dataset ${dataset} \ + OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ${save_dir}/seed-${rseed}-last.config \ --optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ @@ -71,7 +59,7 @@ else --eval_frequency 1 --print_freq 100 --print_freq_eval 200 # KD training xsave_dir=${save_dir}/seed-${rseed}-KDT - ${PY_C} ./exps/KD-main.py --dataset ${dataset} \ + OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ${save_dir}/seed-${rseed}-last.config \ --optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ diff --git a/scripts-search/search-depth-cifar.sh b/scripts-search/search-depth-cifar.sh index bb982b9..74e9dff 100644 --- a/scripts-search/search-depth-cifar.sh +++ b/scripts-search/search-depth-cifar.sh @@ -32,7 +32,7 @@ save_dir=${SAVE_ROOT}/search-depth/${dataset}-${model}-${optim}-Gumbel_${gumbel_ python --version -python ./exps/search-shape.py --dataset ${dataset} \ +OMP_NUM_THREADS=4 python ./exps/search-shape.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ./configs/archs/CIFAR-${model}.config \ --split_path ./.latent-data/splits/${dataset}-0.5.pth \ @@ -53,7 +53,7 @@ if [ "$rseed" = "-1" ]; then else # normal training xsave_dir=${save_dir}/seed-${rseed}-NMT - python ./exps/basic-main.py --dataset ${dataset} \ + OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ${save_dir}/seed-${rseed}-last.config \ --optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ @@ -64,7 +64,7 @@ else --eval_frequency 1 --print_freq 100 --print_freq_eval 200 # KD training xsave_dir=${save_dir}/seed-${rseed}-KDT - python ./exps/KD-main.py --dataset ${dataset} \ + OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ${save_dir}/seed-${rseed}-last.config \ --optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ diff --git a/scripts-search/search-width-cifar.sh b/scripts-search/search-width-cifar.sh index c156893..480d384 100644 --- a/scripts-search/search-width-cifar.sh +++ b/scripts-search/search-width-cifar.sh @@ -32,7 +32,7 @@ save_dir=${SAVE_ROOT}/search-width/${dataset}-${model}-${optim}-Gumbel_${gumbel_ python --version -python ./exps/search-shape.py --dataset ${dataset} \ +OMP_NUM_THREADS=4 python ./exps/search-shape.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ./configs/archs/CIFAR-${model}.config \ --split_path ./.latent-data/splits/${dataset}-0.5.pth \ @@ -53,7 +53,7 @@ if [ "$rseed" = "-1" ]; then else # normal training xsave_dir=${save_dir}/seed-${rseed}-NMT - python ./exps/basic-main.py --dataset ${dataset} \ + OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ${save_dir}/seed-${rseed}-last.config \ --optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ @@ -64,7 +64,7 @@ else --eval_frequency 1 --print_freq 100 --print_freq_eval 200 # KD training xsave_dir=${save_dir}/seed-${rseed}-KDT - python ./exps/KD-main.py --dataset ${dataset} \ + OMP_NUM_THREADS=4 python ./exps/KD-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ${save_dir}/seed-${rseed}-last.config \ --optim_config ./configs/opts/CIFAR-E300-W5-L1-COS.config \ diff --git a/scripts/base-train.sh b/scripts/base-train.sh index 521dd02..abde051 100644 --- a/scripts/base-train.sh +++ b/scripts/base-train.sh @@ -28,7 +28,7 @@ save_dir=${SAVE_ROOT}/basic/${dataset}/${model}-${epoch}-${LR}-${batch} python --version -python ./exps/basic-main.py --dataset ${dataset} \ +OMP_NUM_THREADS=4 python ./exps/basic-main.py --dataset ${dataset} \ --data_path $TORCH_HOME/cifar.python \ --model_config ./configs/archs/CIFAR-${model}.config \ --optim_config ./configs/opts/CIFAR-${epoch}-W5-${LR}-COS.config \