To answer issue #119
This commit is contained in:
		| @@ -24,6 +24,9 @@ | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||
| #### | ||||
| # The following scripts are added in 20 Mar 2022 | ||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777 | ||||
| ###################################################################################### | ||||
| import os, sys, time, random, argparse | ||||
| import numpy as np | ||||
| @@ -166,6 +169,8 @@ def search_func( | ||||
|             network.set_cal_mode("dynamic", sampled_arch) | ||||
|         elif algo == "gdas": | ||||
|             network.set_cal_mode("gdas", None) | ||||
|         elif algo == "gdas_v1": | ||||
|             network.set_cal_mode("gdas_v1", None) | ||||
|         elif algo.startswith("darts"): | ||||
|             network.set_cal_mode("joint", None) | ||||
|         elif algo == "random": | ||||
| @@ -196,6 +201,8 @@ def search_func( | ||||
|             network.set_cal_mode("joint") | ||||
|         elif algo == "gdas": | ||||
|             network.set_cal_mode("gdas", None) | ||||
|         elif algo == "gdas_v1": | ||||
|             network.set_cal_mode("gdas_v1", None) | ||||
|         elif algo.startswith("darts"): | ||||
|             network.set_cal_mode("joint", None) | ||||
|         elif algo == "random": | ||||
| @@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo): | ||||
|             archs, valid_accs = network.return_topK(n_samples, True), [] | ||||
|         elif algo == "setn": | ||||
|             archs, valid_accs = network.return_topK(n_samples, False), [] | ||||
|         elif algo.startswith("darts") or algo == "gdas": | ||||
|         elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1": | ||||
|             arch = network.genotype | ||||
|             archs, valid_accs = [arch], [] | ||||
|         elif algo == "enas": | ||||
| @@ -568,7 +575,7 @@ def main(xargs): | ||||
|         ) | ||||
|  | ||||
|         network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) | ||||
|         if xargs.algo == "gdas": | ||||
|         if xargs.algo == "gdas" or xargs.algo == "gdas_v1": | ||||
|             network.set_tau( | ||||
|                 xargs.tau_max | ||||
|                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) | ||||
| @@ -632,6 +639,8 @@ def main(xargs): | ||||
|             network.set_cal_mode("dynamic", genotype) | ||||
|         elif xargs.algo == "gdas": | ||||
|             network.set_cal_mode("gdas", None) | ||||
|         elif xargs.algo == "gdas_v1": | ||||
|             network.set_cal_mode("gdas_v1", None) | ||||
|         elif xargs.algo.startswith("darts"): | ||||
|             network.set_cal_mode("joint", None) | ||||
|         elif xargs.algo == "random": | ||||
| @@ -699,6 +708,8 @@ def main(xargs): | ||||
|         network.set_cal_mode("dynamic", genotype) | ||||
|     elif xargs.algo == "gdas": | ||||
|         network.set_cal_mode("gdas", None) | ||||
|     elif xargs.algo == "gdas_v1": | ||||
|         network.set_cal_mode("gdas_v1", None) | ||||
|     elif xargs.algo.startswith("darts"): | ||||
|         network.set_cal_mode("joint", None) | ||||
|     elif xargs.algo == "random": | ||||
| @@ -747,7 +758,7 @@ if __name__ == "__main__": | ||||
|     parser.add_argument( | ||||
|         "--algo", | ||||
|         type=str, | ||||
|         choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"], | ||||
|         choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"], | ||||
|         help="The search space name.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user