To answer issue #119

This commit is contained in:
D-X-Y 2022-03-20 23:12:12 -07:00
parent d2cef525f3
commit 8d0799dfb1
3 changed files with 35 additions and 4 deletions

View File

@ -24,6 +24,9 @@
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset cifar100 --data_path $TORCH_HOME/cifar.python --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
# python ./exps/NATS-algos/search-cell.py --dataset ImageNet16-120 --data_path $TORCH_HOME/cifar.python/ImageNet16 --algo enas --arch_weight_decay 0 --arch_learning_rate 0.001 --arch_eps 0.001 --rand_seed 777
####
# The following scripts are added in 20 Mar 2022
# python ./exps/NATS-algos/search-cell.py --dataset cifar10 --data_path $TORCH_HOME/cifar.python --algo gdas_v1 --rand_seed 777
######################################################################################
import os, sys, time, random, argparse
import numpy as np
@ -166,6 +169,8 @@ def search_func(
network.set_cal_mode("dynamic", sampled_arch)
elif algo == "gdas":
network.set_cal_mode("gdas", None)
elif algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif algo == "random":
@ -196,6 +201,8 @@ def search_func(
network.set_cal_mode("joint")
elif algo == "gdas":
network.set_cal_mode("gdas", None)
elif algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif algo == "random":
@ -373,7 +380,7 @@ def get_best_arch(xloader, network, n_samples, algo):
archs, valid_accs = network.return_topK(n_samples, True), []
elif algo == "setn":
archs, valid_accs = network.return_topK(n_samples, False), []
elif algo.startswith("darts") or algo == "gdas":
elif algo.startswith("darts") or algo == "gdas" or algo == "gdas_v1":
arch = network.genotype
archs, valid_accs = [arch], []
elif algo == "enas":
@ -568,7 +575,7 @@ def main(xargs):
)
network.set_drop_path(float(epoch + 1) / total_epoch, xargs.drop_path_rate)
if xargs.algo == "gdas":
if xargs.algo == "gdas" or xargs.algo == "gdas_v1":
network.set_tau(
xargs.tau_max
- (xargs.tau_max - xargs.tau_min) * epoch / (total_epoch - 1)
@ -632,6 +639,8 @@ def main(xargs):
network.set_cal_mode("dynamic", genotype)
elif xargs.algo == "gdas":
network.set_cal_mode("gdas", None)
elif xargs.algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif xargs.algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif xargs.algo == "random":
@ -699,6 +708,8 @@ def main(xargs):
network.set_cal_mode("dynamic", genotype)
elif xargs.algo == "gdas":
network.set_cal_mode("gdas", None)
elif xargs.algo == "gdas_v1":
network.set_cal_mode("gdas_v1", None)
elif xargs.algo.startswith("darts"):
network.set_cal_mode("joint", None)
elif xargs.algo == "random":
@ -747,7 +758,7 @@ if __name__ == "__main__":
parser.add_argument(
"--algo",
type=str,
choices=["darts-v1", "darts-v2", "gdas", "setn", "random", "enas"],
choices=["darts-v1", "darts-v2", "gdas", "gdas_v1", "setn", "random", "enas"],
help="The search space name.",
)
parser.add_argument(

View File

@ -347,6 +347,10 @@ class GenericNAS201Model(nn.Module):
feature = cell.forward_gdas(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas"
elif self.mode == "gdas_v1":
feature = cell.forward_gdas_v1(feature, alphas, index)
if self.verbose:
verbose_str += "-forward_gdas_v1"
else:
raise ValueError("invalid mode={:}".format(self.mode))
else:

View File

@ -85,6 +85,20 @@ class NAS201SearchCell(nn.Module):
nodes.append(sum(inter_nodes))
return nodes[-1]
# GDAS Variant: https://github.com/D-X-Y/AutoDL-Projects/issues/119
def forward_gdas_v1(self, inputs, hardwts, index):
nodes = [inputs]
for i in range(1, self.max_nodes):
inter_nodes = []
for j in range(i):
node_str = "{:}<-{:}".format(i, j)
weights = hardwts[self.edge2index[node_str]]
argmaxs = index[self.edge2index[node_str]].item()
weigsum = weights[argmaxs] * self.edges[node_str](nodes[j])
inter_nodes.append(weigsum)
nodes.append(sum(inter_nodes))
return nodes[-1]
# joint
def forward_joint(self, inputs, weightss):
nodes = [inputs]
@ -152,6 +166,9 @@ class NAS201SearchCell(nn.Module):
return nodes[-1]
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class MixedOp(nn.Module):
def __init__(self, space, C, stride, affine, track_running_stats):
super(MixedOp, self).__init__()
@ -167,7 +184,6 @@ class MixedOp(nn.Module):
return sum(w * op(x) for w, op in zip(weights, self._ops))
# Learning Transferable Architectures for Scalable Image Recognition, CVPR 2018
class NASNetSearchCell(nn.Module):
def __init__(
self,