# MeCo/nasbench201/init_projection.py
import os
import sys
import numpy as np
import torch
import torch.nn.functional as f
sys.path.insert(0, '../')
import nasbench201.utils as ig_utils
import logging
import torch.utils
import copy
import scipy.stats as ss
from collections import OrderedDict
from foresight.pruners import *
from op_score import Jocab_Score, get_ntk_n
import gc
from nasbench201.linear_region import Linear_Region_Collector
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
# global-op-iter: at every iteration, scores every operation on every remaining edge and discretizes the best (edge, operation) pair
def global_op_greedy_pt_project(proj_queue, model, args):
    """Greedy global op projection: each round, score every (edge, op) masking
    candidate over all un-discretized edges and commit the best pair.

    Args:
        proj_queue: data loader yielding (input, target) batches for scoring.
        model: supernet exposing num_op, candidate_flags,
            get_projected_weights(), arch_parameters() and project_op().
        args: namespace with `proj_crit`; only 'jacob' is supported here.
    """
    def project(model, args):
        ## macros
        num_op = model.num_op
        ## ids of edges that are still un-discretized
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        # lower score is better for the 'jacob' criterion
        compare = lambda x, y: x < y
        crit_extrema = None
        best_eid = None
        # initialize so an empty search fails loudly instead of raising NameError
        best_opid = None
        input, target = next(iter(proj_queue))
        for eid in remain_eids:
            for opid in range(num_op):
                ## projection: mask out op `opid` on edge `eid`
                weights = model.get_projected_weights()
                proj_mask = torch.ones_like(weights[eid])
                proj_mask[opid] = 0
                weights[eid] = weights[eid] * proj_mask
                ## proj evaluation
                if args.proj_crit == 'jacob':
                    crit = Jocab_Score(model, input, target, weights=weights)
                else:
                    # original code left `crit` unbound here; fail explicitly
                    raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
                if crit_extrema is None or compare(crit, crit_extrema):
                    crit_extrema = crit
                    best_opid = opid
                    best_eid = eid
        logging.info('best opid %d', best_opid)
        return best_eid, best_opid

    # one projection step per edge (rows of the arch parameter matrix)
    tune_epochs = model.arch_parameters()[0].shape[0]
    for epoch in range(tune_epochs):
        logging.info('epoch %d', epoch)
        logging.info('project')
        selected_eid, best_opid = project(model, args)
        model.project_op(selected_eid, best_opid)
    return
# global-edge-iter: similar to global-op-iter, but iteratively selects edge e based on the average score of all operations on that edge
def global_edge_greedy_pt_project(proj_queue, model, args):
    """global-edge-iter projection: each round, pick the edge whose operations
    have the best (lowest) average score, then discretize the single best
    operation on that edge.

    Args:
        proj_queue: data loader yielding (input, target) batches for scoring.
        model: supernet exposing num_op, candidate_flags,
            get_projected_weights(), arch_parameters() and project_op().
        args: namespace with `proj_crit`; only 'jacob' is supported here.
    """
    def select_eid(model, args):
        """Return the id of the remaining edge with the best mean op score."""
        num_op = model.num_op
        ## ids of edges that are still un-discretized
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        compare = lambda x, y: x < y  # lower average score is better
        crit_extrema = None
        best_eid = None
        input, target = next(iter(proj_queue))
        for eid in remain_eids:
            eid_score = []
            for opid in range(num_op):
                ## projection: mask out op `opid` on edge `eid`
                weights = model.get_projected_weights()
                proj_mask = torch.ones_like(weights[eid])
                proj_mask[opid] = 0
                weights[eid] = weights[eid] * proj_mask
                ## proj evaluation
                if args.proj_crit == 'jacob':
                    crit = Jocab_Score(model, input, target, weights=weights)
                else:
                    # original code left `crit` unbound here; fail explicitly
                    raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
                eid_score.append(crit)
            eid_score = np.mean(eid_score)
            if crit_extrema is None or compare(eid_score, crit_extrema):
                crit_extrema = eid_score
                best_eid = eid
        return best_eid

    def project(model, args, selected_eid):
        """Pick the best operation on `selected_eid` by exhaustive scoring."""
        num_op = model.num_op
        crit_list = []
        op_ids = []
        input, target = next(iter(proj_queue))
        for opid in range(num_op):
            ## projection
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask
            ## proj evaluation
            if args.proj_crit == 'jacob':
                crit = Jocab_Score(model, input, target, weights=weights)
            else:
                raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
            crit_list.append(crit)
            op_ids.append(opid)
        # lowest score wins; nanargmin ignores NaN scores
        best_opid = op_ids[np.nanargmin(crit_list)]
        logging.info('best opid %d', best_opid)
        logging.info(crit_list)
        return selected_eid, best_opid

    num_edges = model.arch_parameters()[0].shape[0]
    for epoch in range(num_edges):
        logging.info('epoch %d', epoch)
        logging.info('project')
        selected_eid = select_eid(model, args)
        selected_eid, best_opid = project(model, args, selected_eid)
        model.project_op(selected_eid, best_opid)
    return
# global-op-once: evaluates S(A(e,o)) for all operations only once to obtain a ranking of the operations, and discretizes the edges E according to this order
def global_op_once_pt_project(proj_queue, model, args):
    """global-op-once projection: score all (edge, op) candidates once up
    front, then discretize the edges one per epoch, re-scoring the operations
    of each edge at discretization time.

    Args:
        proj_queue: data loader yielding (input, target) batches for scoring.
        model: supernet exposing num_op, candidate_flags,
            get_projected_weights(), arch_parameters() and project_op().
        args: namespace with `proj_crit`; only 'jacob' is supported here.
    """
    def order(model, args):
        """Score every edge once; returns OrderedDict {eid: best op index}."""
        num_op = model.num_op
        ## ids of edges that are still un-discretized
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        edge_score = OrderedDict()
        input, target = next(iter(proj_queue))
        for eid in remain_eids:
            crit_list = []
            for opid in range(num_op):
                ## projection: mask out op `opid` on edge `eid`
                weights = model.get_projected_weights()
                proj_mask = torch.ones_like(weights[eid])
                proj_mask[opid] = 0
                weights[eid] = weights[eid] * proj_mask
                ## proj evaluation
                if args.proj_crit == 'jacob':
                    crit = Jocab_Score(model, input, target, weights=weights)
                else:
                    # original code left `crit` unbound here; fail explicitly
                    raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
                crit_list.append(crit)
            # stored value is the per-edge best op index; it is discarded by the
            # caller, which re-scores the edge when it is actually discretized
            edge_score[eid] = np.nanargmin(crit_list)
        return edge_score

    def project(model, args, selected_eid):
        """Pick the best operation on `selected_eid` by exhaustive scoring."""
        num_op = model.num_op
        crit_list = []
        op_ids = []
        input, target = next(iter(proj_queue))
        for opid in range(num_op):
            ## projection
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask
            ## proj evaluation
            if args.proj_crit == 'jacob':
                crit = Jocab_Score(model, input, target, weights=weights)
            else:
                raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
            crit_list.append(crit)
            op_ids.append(opid)
        best_opid = op_ids[np.nanargmin(crit_list)]
        logging.info('best opid %d', best_opid)
        logging.info(crit_list)
        return selected_eid, best_opid

    num_edges = model.arch_parameters()[0].shape[0]
    eid_order = order(model, args)
    for epoch in range(num_edges):
        logging.info('epoch %d', epoch)
        logging.info('project')
        # NOTE(review): popitem() pops LIFO, so edges are discretized in
        # reverse insertion order rather than by score — confirm intent
        selected_eid, _ = eid_order.popitem()
        selected_eid, best_opid = project(model, args, selected_eid)
        model.project_op(selected_eid, best_opid)
    return
# global-edge-once: similar to global-op-once, but uses the average score of operations on edges to obtain the edge discretization order
def global_edge_once_pt_project(proj_queue, model, args):
    """global-edge-once projection: compute each edge's average op score once,
    then discretize the edges one per epoch, re-scoring the operations of each
    edge at discretization time.

    Args:
        proj_queue: data loader yielding (input, target) batches for scoring.
        model: supernet exposing num_op, candidate_flags,
            get_projected_weights(), arch_parameters() and project_op().
        args: namespace with `proj_crit`; only 'jacob' is supported here.
    """
    def order(model, args):
        """Score every edge once; returns OrderedDict {eid: mean op score}."""
        num_op = model.num_op
        ## ids of edges that are still un-discretized
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        edge_score = OrderedDict()
        input, target = next(iter(proj_queue))
        for eid in remain_eids:
            crit_list = []
            for opid in range(num_op):
                ## projection: mask out op `opid` on edge `eid`
                weights = model.get_projected_weights()
                proj_mask = torch.ones_like(weights[eid])
                proj_mask[opid] = 0
                weights[eid] = weights[eid] * proj_mask
                ## proj evaluation
                if args.proj_crit == 'jacob':
                    crit = Jocab_Score(model, input, target, weights=weights)
                else:
                    # original code left `crit` unbound here; fail explicitly
                    raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
                crit_list.append(crit)
            edge_score[eid] = np.mean(crit_list)
        return edge_score

    def project(model, args, selected_eid):
        """Pick the best operation on `selected_eid` by exhaustive scoring."""
        num_op = model.num_op
        crit_list = []
        op_ids = []
        input, target = next(iter(proj_queue))
        for opid in range(num_op):
            ## projection
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask
            ## proj evaluation
            if args.proj_crit == 'jacob':
                crit = Jocab_Score(model, input, target, weights=weights)
            else:
                raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
            crit_list.append(crit)
            op_ids.append(opid)
        best_opid = op_ids[np.nanargmin(crit_list)]
        logging.info('best opid %d', best_opid)
        logging.info(crit_list)
        return selected_eid, best_opid

    num_edges = model.arch_parameters()[0].shape[0]
    eid_order = order(model, args)
    for epoch in range(num_edges):
        logging.info('epoch %d', epoch)
        logging.info('project')
        # NOTE(review): popitem() pops LIFO, so edges are discretized in
        # reverse insertion order, not sorted by the computed mean score —
        # confirm this matches the intended "edge discretization order"
        selected_eid, _ = eid_order.popitem()
        selected_eid, best_opid = project(model, args, selected_eid)
        model.project_op(selected_eid, best_opid)
    return
# fixed [reverse, order]: discretizes the edges in a fixed order; in our experiments we discretize from the input towards the output of the cell
# random: discretizes the edges in a random order (DARTS-PT)
# NOTE: only this method allows using other zero-cost proxy metrics
def pt_project(proj_queue, model, args):
    """Discretize the supernet one edge per epoch, choosing the edge by
    `args.edge_decision` (random / reverse / first) and the operation by a
    zero-cost proxy criterion (`args.proj_crit`: 'jacob', 'comb', or any
    measure name understood by `predictive.find_measures`).

    Args:
        proj_queue: data loader yielding (input, target) batches for scoring.
        model: supernet exposing num_edge/num_op, candidate_flags,
            proj_weights, get_projected_weights(), arch_parameters(),
            project_op().
        args: namespace with edge_decision, proj_crit, dataset.
    """
    def project(model, args):
        ## macros; per the original (Chinese) note: 6 edges in total, each
        ## with 5 candidate operations (NAS-Bench-201 cell)
        num_edge, num_op = model.num_edge, model.num_op
        ## select an edge among those not yet discretized
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        # print('candidate_flags:', model.candidate_flags)
        # print(model.candidate_flags)
        # edge-selection strategy
        if args.edge_decision == "random":
            # np.random.choice returns an array; take its single element
            selected_eid = np.random.choice(remain_eids, size=1)[0]
        elif args.edge_decision == "reverse":
            selected_eid = remain_eids[-1]
        else:
            selected_eid = remain_eids[0]
        ## select the best operation
        # NOTE(review): crit_idx/compare are assigned but never used below,
        # and both branches are identical — likely leftovers
        if args.proj_crit == 'jacob':
            crit_idx = 3
            compare = lambda x, y: x < y
        else:
            crit_idx = 4
            compare = lambda x, y: x < y
        # number of classes for the proxy's synthetic data spec
        if args.dataset == 'cifar100':
            n_classes = 100
        elif args.dataset == 'imagenet16-120':
            n_classes = 120
        else:
            n_classes = 10
        best_opid = 0
        crit_extrema = None
        crit_list = []
        op_ids = []
        input, target = next(iter(proj_queue))
        for opid in range(num_op):
            ## projection: mask out op `opid` on the selected edge
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            # print(selected_eid, weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask
            ## proj evaluation
            if args.proj_crit == 'jacob':
                crit = Jocab_Score(model, input, target, weights=weights)
            else:
                # Temporarily hard-project this edge in the model itself so the
                # predictive measures see the masked supernet, then restore.
                # NOTE(review): cache_weight aliases the tensor mutated below,
                # so it does not preserve the old values; restoration instead
                # zeroes the row — this assumes proj_weights rows of
                # un-projected edges are all-zero. Verify against the model.
                cache_weight = model.proj_weights[selected_eid]
                cache_flag = model.candidate_flags[selected_eid]
                for idx in range(num_op):
                    if idx == opid:
                        model.proj_weights[selected_eid][opid] = 0
                    else:
                        # spread weight uniformly over the surviving ops
                        model.proj_weights[selected_eid][idx] = 1.0/num_op
                model.candidate_flags[selected_eid] = False
                # print(model.get_projected_weights())
                if args.proj_crit == 'comb':
                    # 'comb' combines synflow and var into a single score
                    synflow = predictive.find_measures(model,
                                                       proj_queue,
                                                       ('random', 1, n_classes),
                                                       torch.device("cuda"),
                                                       measure_names=['synflow'])
                    var = predictive.find_measures(model,
                                                   proj_queue,
                                                   ('random', 1, n_classes),
                                                   torch.device("cuda"),
                                                   measure_names=['var'])
                    # print(synflow, var)
                    comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1)
                    measures = {'comb': comb}
                else:
                    measures = predictive.find_measures(model,
                                                        proj_queue,
                                                        ('random', 1, n_classes),
                                                        torch.device("cuda"),
                                                        measure_names=[args.proj_crit])
                # print(measures)
                # restore the un-projected state (see NOTE above)
                for idx in range(num_op):
                    model.proj_weights[selected_eid][idx] = 0
                model.candidate_flags[selected_eid] = cache_flag
                crit = measures[args.proj_crit]
            crit_list.append(crit)
            op_ids.append(opid)
        # lowest criterion value wins; nanargmin skips NaN scores
        best_opid = op_ids[np.nanargmin(crit_list)]
        # best_opid = op_ids[np.nanargmax(crit_list)]
        logging.info('best opid %d', best_opid)
        logging.info('current edge id %d', selected_eid)
        logging.info(crit_list)
        return selected_eid, best_opid

    # one projection step per edge (rows of the arch parameter matrix)
    num_edges = model.arch_parameters()[0].shape[0]
    for epoch in range(num_edges):
        logging.info('epoch %d', epoch)
        logging.info('project')
        selected_eid, best_opid = project(model, args)
        model.project_op(selected_eid, best_opid)
    return
def tenas_project(proj_queue, model, model_thin, args):
    """TE-NAS-style projection: rank every (edge, op) masking candidate by NTK
    condition number and by the number of linear regions, then discretize the
    candidate with the best combined rank.

    Args:
        proj_queue: data loader used by get_ntk_n.
        model: supernet being projected.
        model_thin: thin copy of the supernet used by the linear-region counter.
        args: namespace with proj_crit ('tenas'), dataset, data, seed.
    """
    def project(model, args):
        ## macros
        num_op = model.num_op
        ## ids of edges that are still un-discretized
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        if args.proj_crit != 'tenas':
            # original code left `lr`/`ntk` unbound for other criteria
            raise ValueError('tenas_project only supports proj_crit == "tenas"')
        ntks = []
        lrs = []
        edge_op_id = []
        lrc_model = Linear_Region_Collector(input_size=(1000, 1, 3, 3), sample_batch=3, dataset=args.dataset, data_path=args.data, seed=args.seed)
        for eid in remain_eids:
            for opid in range(num_op):
                ## projection: mask out op `opid` on edge `eid`
                weights = model.get_projected_weights()
                proj_mask = torch.ones_like(weights[eid])
                proj_mask[opid] = 0
                weights[eid] = weights[eid] * proj_mask
                ## proj evaluation: linear regions (higher better) + NTK
                lrc_model.reinit(ori_models=[model_thin], seed=args.seed, weights=weights)
                lr = lrc_model.forward_batch_sample()
                lrc_model.clear()
                ntk = get_ntk_n(proj_queue, [model], recalbn=0, train_mode=True, num_batch=1, weights=weights)
                ntks.append(ntk)
                lrs.append(lr)
                edge_op_id.append('{}:{}'.format(eid, opid))
        print('ntks', ntks)  # typo fix: was 'ntls'
        print('lrs', lrs)
        # larger NTK condition number is worse, so invert its rank before
        # combining with the linear-region rank (higher rank = better)
        ntks_ranks = ss.rankdata(ntks)
        lrs_ranks = ss.rankdata(lrs)
        ntks_ranks = len(ntks_ranks) - ntks_ranks.astype(int)
        op_ranks = [ntks_ranks[i] + lrs_ranks[i] for i in range(len(edge_op_id))]
        # NOTE(review): only the first num_op entries (i.e. the first remaining
        # edge) are considered when picking the winner — confirm this is the
        # intended TE-NAS behavior rather than searching all candidates
        best_op_index = edge_op_id[np.nanargmin(op_ranks[0:num_op])]
        best_eid, best_opid = [int(x) for x in best_op_index.split(':')]
        logging.info(op_ranks)
        logging.info('best eid %d', best_eid)
        logging.info('best opid %d', best_opid)
        return best_eid, best_opid

    num_edges = model.arch_parameters()[0].shape[0]
    for epoch in range(num_edges):
        logging.info('epoch %d', epoch)
        logging.info('project')
        selected_eid, best_opid = project(model, args)
        model.project_op(selected_eid, best_opid)
    return
# new methods
# Randomly propose candidate networks and transfer them to the supernet, then perform global op selection in this subspace
def shrink_pt_project(proj_queue, model, args):
    """Shrink-then-search projection: run 20 random-edge-order projections to
    propose candidate subnetworks, merge them into a subspace mask over
    proj_weights, then greedily discretize within that subspace.

    Bug fixed vs. original: the predictive-measure call referenced an
    undefined `train_queue`; it now uses `proj_queue` like pt_project does.

    Args:
        proj_queue: data loader yielding (input, target) batches for scoring.
        model: supernet exposing num_op, candidate_flags, proj_weights,
            get_projected_weights(), arch_parameters(), project_op(),
            reset_arch_parameters(), printing().
        args: namespace with proj_crit and dataset.
    """
    def project(model, args):
        """One random-edge projection step; returns (edge id, best op id)."""
        ## macros
        num_op = model.num_op
        ## select a random remaining edge
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        selected_eid = np.random.choice(remain_eids, size=1)[0]
        # number of classes for the proxy's synthetic data spec
        if args.dataset == 'cifar100':
            n_classes = 100
        elif args.dataset == 'imagenet16-120':
            n_classes = 120
        else:
            n_classes = 10
        crit_list = []
        op_ids = []
        input, target = next(iter(proj_queue))
        for opid in range(num_op):
            ## projection: mask out op `opid` on the selected edge
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask
            ## proj evaluation
            if args.proj_crit == 'jacob':
                crit = Jocab_Score(model, input, target, weights=weights)
            else:
                # temporarily hard-project the edge so the predictive measure
                # sees the masked supernet, then restore the all-zero row
                cache_flag = model.candidate_flags[selected_eid]
                for idx in range(num_op):
                    if idx == opid:
                        model.proj_weights[selected_eid][opid] = 0
                    else:
                        model.proj_weights[selected_eid][idx] = 1.0 / num_op
                model.candidate_flags[selected_eid] = False
                # bug fix: original passed undefined `train_queue`
                measures = predictive.find_measures(model,
                                                    proj_queue,
                                                    ('random', 1, n_classes),
                                                    torch.device("cuda"),
                                                    measure_names=[args.proj_crit])
                for idx in range(num_op):
                    model.proj_weights[selected_eid][idx] = 0
                model.candidate_flags[selected_eid] = cache_flag
                crit = measures[args.proj_crit]
            crit_list.append(crit)
            op_ids.append(opid)
        best_opid = op_ids[np.nanargmin(crit_list)]
        logging.info('best opid %d', best_opid)
        logging.info('current edge id %d', selected_eid)
        logging.info(crit_list)
        return selected_eid, best_opid

    def global_project(model, args):
        """Greedy global selection restricted to the proposed subspace."""
        ## edges still active in the subspace search
        remain_eids = torch.nonzero(model.subspace_candidate_flags).cpu().numpy().T[0]
        compare = lambda x, y: x < y  # lower score is better
        crit_extrema = None
        best_eid = None
        best_opid = None
        input, target = next(iter(proj_queue))
        for eid in remain_eids:
            # only ops kept in the subspace (non-zero proj weight) are candidates
            remain_oids = torch.nonzero(model.proj_weights[eid]).cpu().numpy().T[0]
            for opid in remain_oids:
                ## projection: mask out op `opid` on edge `eid`
                weights = model.get_projected_weights()
                proj_mask = torch.ones_like(weights[eid])
                proj_mask[opid] = 0
                weights[eid] = weights[eid] * proj_mask
                ## proj evaluation
                if args.proj_crit == 'jacob':
                    crit = Jocab_Score(model, input, target, weights=weights)
                else:
                    # original code left `crit` unbound here; fail explicitly
                    raise ValueError('unsupported proj_crit: %s' % args.proj_crit)
                if crit_extrema is None or compare(crit, crit_extrema):
                    crit_extrema = crit
                    best_opid = opid
                    best_eid = eid
        logging.info('best eid %d', best_eid)
        logging.info('best opid %d', best_opid)
        # commit: keep only the winning op on the winning edge
        model.subspace_candidate_flags[best_eid] = False
        proj_mask = torch.zeros_like(model.proj_weights[best_eid])
        model.proj_weights[best_eid] = model.proj_weights[best_eid] * proj_mask
        model.proj_weights[best_eid][best_opid] = 1
        return best_eid, best_opid

    num_edges = model.arch_parameters()[0].shape[0]
    ## subspace proposal: union of 20 randomly-ordered projections
    logging.info('Start subspace proposal')
    subspace = copy.deepcopy(model.proj_weights)
    for i in range(20):
        model.reset_arch_parameters()
        for epoch in range(num_edges):
            logging.info('epoch %d', epoch)
            logging.info('project')
            selected_eid, best_opid = project(model, args)
            model.project_op(selected_eid, best_opid)
        subspace += model.proj_weights
    model.reset_arch_parameters()
    # binarize the union of proposals, then L1-normalize per edge
    subspace = torch.gt(subspace, 0).int().float()
    subspace = f.normalize(subspace, p=1, dim=1)
    model.proj_weights += subspace
    for i in range(num_edges):
        model.candidate_flags[i] = False
    logging.info('Start final search in subspace')
    logging.info(subspace)
    model.subspace_candidate_flags = torch.tensor(len(model._arch_parameters) * [True], requires_grad=False, dtype=torch.bool).cuda()
    for epoch in range(num_edges):
        logging.info('epoch %d', epoch)
        logging.info('project')
        selected_eid, best_opid = global_project(model, args)
        model.printing(logging)
        # model.project_op(selected_eid, best_opid)
    return