To answer issue #119
This commit is contained in:
		| @@ -24,6 +24,9 @@ | |||||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||||
| # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 | ||||||
|  | #### | ||||||
|  | # The following scripts are added in 20 Mar 2022 | ||||||
|  | # python ./exps/NATS-algos/search-cell.py --dataset cifar10  --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777 | ||||||
| ###################################################################################### | ###################################################################################### | ||||||
| import os, sys, time, random, argparse | import os, sys, time, random, argparse | ||||||
| import numpy as np | import numpy as np | ||||||
| @@ -166,6 +169,8 @@ def search_func( | |||||||
|             network.set_cal_mode("dynamic", sampled_arch) |             network.set_cal_mode("dynamic", sampled_arch) | ||||||
|         elif algo == "gdas": |         elif algo == "gdas": | ||||||
|             network.set_cal_mode("gdas", None) |             network.set_cal_mode("gdas", None) | ||||||
|  |         elif algo == "gdas_v1": | ||||||
|  |             network.set_cal_mode("gdas_v1", None) | ||||||
|         elif algo.startswith("darts"): |         elif algo.startswith("darts"): | ||||||
|             network.set_cal_mode("joint", None) |             network.set_cal_mode("joint", None) | ||||||
|         elif algo == "random": |         elif algo == "random": | ||||||
| @@ -196,6 +201,8 @@ def search_func( | |||||||
|             network.set_cal_mode("joint") |             network.set_cal_mode("joint") | ||||||
|         elif algo == "gdas": |         elif algo == "gdas": | ||||||
|             network.set_cal_mode("gdas", None) |             network.set_cal_mode("gdas", None) | ||||||
|  |         elif algo == "gdas_v1": | ||||||
|  |             network.set_cal_mode("gdas_v1", None) | ||||||
|         elif algo.startswith("darts"): |         elif algo.startswith("darts"): | ||||||
|             network.set_cal_mode("joint", None) |             network.set_cal_mode("joint", None) | ||||||
|         elif algo == "random": |         elif algo == "random": | ||||||
| @@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo): | |||||||
|             archs, valid_accs = network.return_topK(n_samples, True), [] |             archs, valid_accs = network.return_topK(n_samples, True), [] | ||||||
|         elif algo == "setn": |         elif algo == "setn": | ||||||
|             archs, valid_accs = network.return_topK(n_samples, False), [] |             archs, valid_accs = network.return_topK(n_samples, False), [] | ||||||
|         elif algo.startswith("darts") or algo == "gdas": |         elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1": | ||||||
|             arch = network.genotype |             arch = network.genotype | ||||||
|             archs, valid_accs = [arch], [] |             archs, valid_accs = [arch], [] | ||||||
|         elif algo == "enas": |         elif algo == "enas": | ||||||
| @@ -568,7 +575,7 @@ def main(xargs): | |||||||
|         ) |         ) | ||||||
|  |  | ||||||
|         network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) |         network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) | ||||||
|         if xargs.algo == "gdas": |         if xargs.algo == "gdas" or xargs.algo == "gdas_v1": | ||||||
|             network.set_tau( |             network.set_tau( | ||||||
|                 xargs.tau_max |                 xargs.tau_max | ||||||
|                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) |                 - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) | ||||||
| @@ -632,6 +639,8 @@ def main(xargs): | |||||||
|             network.set_cal_mode("dynamic", genotype) |             network.set_cal_mode("dynamic", genotype) | ||||||
|         elif xargs.algo == "gdas": |         elif xargs.algo == "gdas": | ||||||
|             network.set_cal_mode("gdas", None) |             network.set_cal_mode("gdas", None) | ||||||
|  |         elif xargs.algo == "gdas_v1": | ||||||
|  |             network.set_cal_mode("gdas_v1", None) | ||||||
|         elif xargs.algo.startswith("darts"): |         elif xargs.algo.startswith("darts"): | ||||||
|             network.set_cal_mode("joint", None) |             network.set_cal_mode("joint", None) | ||||||
|         elif xargs.algo == "random": |         elif xargs.algo == "random": | ||||||
| @@ -699,6 +708,8 @@ def main(xargs): | |||||||
|         network.set_cal_mode("dynamic", genotype) |         network.set_cal_mode("dynamic", genotype) | ||||||
|     elif xargs.algo == "gdas": |     elif xargs.algo == "gdas": | ||||||
|         network.set_cal_mode("gdas", None) |         network.set_cal_mode("gdas", None) | ||||||
|  |     elif xargs.algo == "gdas_v1": | ||||||
|  |         network.set_cal_mode("gdas_v1", None) | ||||||
|     elif xargs.algo.startswith("darts"): |     elif xargs.algo.startswith("darts"): | ||||||
|         network.set_cal_mode("joint", None) |         network.set_cal_mode("joint", None) | ||||||
|     elif xargs.algo == "random": |     elif xargs.algo == "random": | ||||||
| @@ -747,7 +758,7 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         "--algo", |         "--algo", | ||||||
|         type=str, |         type=str, | ||||||
|         choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"], |         choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"], | ||||||
|         help="The search space name.", |         help="The search space name.", | ||||||
|     ) |     ) | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|   | |||||||
| @@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module): | |||||||
|                     feature = cell.forward_gdas(feature, alphas, index) |                     feature = cell.forward_gdas(feature, alphas, index) | ||||||
|                     if self.verbose: |                     if self.verbose: | ||||||
|                         verbose_str += "-forward_gdas" |                         verbose_str += "-forward_gdas" | ||||||
|  |                 elif self.mode == "gdas_v1": | ||||||
|  |                     feature = cell.forward_gdas_v1(feature, alphas, index) | ||||||
|  |                     if self.verbose: | ||||||
|  |                         verbose_str += "-forward_gdas_v1" | ||||||
|                 else: |                 else: | ||||||
|                     raise ValueError("invalid mode={:}".format(self.mode)) |                     raise ValueError("invalid mode={:}".format(self.mode)) | ||||||
|             else: |             else: | ||||||
|   | |||||||
| @@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module): | |||||||
|             nodes.append(sum(inter_nodes)) |             nodes.append(sum(inter_nodes)) | ||||||
|         return nodes[-1] |         return nodes[-1] | ||||||
|  |  | ||||||
|  |     # GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119 | ||||||
|  |     def forward_gdas_v1(self, inputs, hardwts, index): | ||||||
|  |         nodes = [inputs] | ||||||
|  |         for i in range(1, self.max_nodes): | ||||||
|  |             inter_nodes = [] | ||||||
|  |             for j in range(i): | ||||||
|  |                 node_str = "{:}<-{:}".format(i, j) | ||||||
|  |                 weights = hardwts[self.edge2index[node_str]] | ||||||
|  |                 argmaxs = index[self.edge2index[node_str]].item() | ||||||
|  |                 weigsum = weights[argmaxs] * self.edges[node_str](nodes[j]) | ||||||
|  |                 inter_nodes.append(weigsum) | ||||||
|  |             nodes.append(sum(inter_nodes)) | ||||||
|  |         return nodes[-1] | ||||||
|  |  | ||||||
|     # joint |     # joint | ||||||
|     def forward_joint(self, inputs, weightss): |     def forward_joint(self, inputs, weightss): | ||||||
|         nodes = [inputs] |         nodes = [inputs] | ||||||
| @@ -152,6 +166,9 @@ class NAS201SearchCell(nn.Module): | |||||||
|         return nodes[-1] |         return nodes[-1] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 | ||||||
|  |  | ||||||
|  |  | ||||||
| class MixedOp(nn.Module): | class MixedOp(nn.Module): | ||||||
|     def __init__(self, space, C, stride, affine, track_running_stats): |     def __init__(self, space, C, stride, affine, track_running_stats): | ||||||
|         super(MixedOp, self).__init__() |         super(MixedOp, self).__init__() | ||||||
| @@ -167,7 +184,6 @@ class MixedOp(nn.Module): | |||||||
|         return sum(w * op(x) for w, op in zip(weights, self._ops)) |         return sum(w * op(x) for w, op in zip(weights, self._ops)) | ||||||
|  |  | ||||||
|  |  | ||||||
| # Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 |  | ||||||
| class NASNetSearchCell(nn.Module): | class NASNetSearchCell(nn.Module): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, |         self, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user