From 8d0799dfb168d4410d71c889207b95b17d2ea511 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Sun, 20 Mar 2022 23:12:12 -0700 Subject: [PATCH] To answer issue #119 --- exps/NATS-algos/search-cell.py | 17 ++++++++++++++--- xautodl/models/cell_searchs/generic_model.py | 4 ++++ xautodl/models/cell_searchs/search_cells.py | 18 +++++++++++++++++- 3 files changed, 35 insertions(+), 4 deletions(-) diff --git a/exps/NATS-algos/search-cell.py b/exps/NATS-algos/search-cell.py index b1632fb..66842a8 100644 --- a/exps/NATS-algos/search-cell.py +++ b/exps/NATS-algos/search-cell.py @@ -24,6 +24,9 @@ # python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 # python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 # python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777 +#### +# The following scripts are added in 20 Mar 2022 +# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777 ###################################################################################### import os, sys, time, random, argparse import numpy as np @@ -166,6 +169,8 @@ def search_func( network.set_cal_mode("dynamic", sampled_arch) elif algo == "gdas": network.set_cal_mode("gdas", None) + elif algo == "gdas_v1": + network.set_cal_mode("gdas_v1", None) elif algo.startswith("darts"): network.set_cal_mode("joint", None) elif algo == "random": @@ -196,6 +201,8 @@ def search_func( network.set_cal_mode("joint") elif algo == "gdas": network.set_cal_mode("gdas", None) + elif algo == "gdas_v1": + network.set_cal_mode("gdas_v1", None) elif algo.startswith("darts"): network.set_cal_mode("joint", None) elif algo == "random": @@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo): archs, valid_accs = network.return_topK(n_samples, True), [] elif algo == "setn": archs, valid_accs = network.return_topK(n_samples, False), [] - elif algo.startswith("darts") or algo == "gdas": + elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1": arch = network.genotype archs, valid_accs = [arch], [] elif algo == "enas": @@ -568,7 +575,7 @@ def main(xargs): ) network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate) - if xargs.algo == "gdas": + if xargs.algo == "gdas" or xargs.algo == "gdas_v1": network.set_tau( xargs.tau_max - (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1) @@ -632,6 +639,8 @@ def main(xargs): network.set_cal_mode("dynamic", genotype) elif xargs.algo == "gdas": network.set_cal_mode("gdas", None) + elif xargs.algo == "gdas_v1": + network.set_cal_mode("gdas_v1", None) elif xargs.algo.startswith("darts"): network.set_cal_mode("joint", None) elif xargs.algo == "random": @@ -699,6 +708,8 @@ def main(xargs): network.set_cal_mode("dynamic", genotype) elif xargs.algo == "gdas": network.set_cal_mode("gdas", None) + elif xargs.algo == "gdas_v1": + network.set_cal_mode("gdas_v1", None) elif xargs.algo.startswith("darts"): network.set_cal_mode("joint", None) elif xargs.algo == "random": @@ -747,7 +758,7 @@ if __name__ == "__main__": parser.add_argument( "--algo", type=str, - choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"], + choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"], help="The search space name.", ) parser.add_argument( diff --git a/xautodl/models/cell_searchs/generic_model.py b/xautodl/models/cell_searchs/generic_model.py index ad0cd30..bbbbb1f 100644 --- a/xautodl/models/cell_searchs/generic_model.py +++ b/xautodl/models/cell_searchs/generic_model.py @@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module): feature = cell.forward_gdas(feature, alphas, index) if self.verbose: verbose_str += "-forward_gdas" + elif self.mode == "gdas_v1": + feature = cell.forward_gdas_v1(feature, alphas, index) + if self.verbose: + verbose_str += "-forward_gdas_v1" else: raise ValueError("invalid mode={:}".format(self.mode)) else: diff --git a/xautodl/models/cell_searchs/search_cells.py b/xautodl/models/cell_searchs/search_cells.py index 9235823..6be7c52 100644 --- a/xautodl/models/cell_searchs/search_cells.py +++ b/xautodl/models/cell_searchs/search_cells.py @@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module): nodes.append(sum(inter_nodes)) return nodes[-1] + # GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119 + def forward_gdas_v1(self, inputs, hardwts, index): + nodes = [inputs] + for i in range(1, self.max_nodes): + inter_nodes = [] + for j in range(i): + node_str = "{:}<-{:}".format(i, j) + weights = hardwts[self.edge2index[node_str]] + argmaxs = index[self.edge2index[node_str]].item() + weigsum = weights[argmaxs] * self.edges[node_str](nodes[j]) + inter_nodes.append(weigsum) + nodes.append(sum(inter_nodes)) + return nodes[-1] + # joint def forward_joint(self, inputs, weightss): nodes = [inputs] @@ -152,6 +166,9 @@ class NAS201SearchCell(nn.Module): return nodes[-1] +# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 + + class MixedOp(nn.Module): def __init__(self, space, C, stride, affine, track_running_stats): super(MixedOp, self).__init__() @@ -167,7 +184,6 @@ class MixedOp(nn.Module): return sum(w * op(x) for w, op in zip(weights, self._ops)) -# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018 class NASNetSearchCell(nn.Module): def __init__( self,