This commit is contained in:
HamsterMimi 2023-05-04 13:41:59 +08:00
parent fd43e67da1
commit 5a1dc89756
2 changed files with 10 additions and 12 deletions

View File

@ -96,12 +96,11 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_
model.candidate_flags[cell_type][selected_eid] = False model.candidate_flags[cell_type][selected_eid] = False
# print(model.get_projected_weights()) # print(model.get_projected_weights())
else: measures = predictive.find_measures(model,
measures = predictive.find_measures(model, proj_queue,
proj_queue, ('random', 1, n_classes),
('random', 1, n_classes), torch.device("cuda"),
torch.device("cuda"), measure_names=[proj_crit])
measure_names=[proj_crit])
# print(measures) # print(measures)
for idx in range(num_ops): for idx in range(num_ops):

View File

@ -223,12 +223,11 @@ def main():
else: else:
#score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit) #score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
network.requires_feature = False network.requires_feature = False
else: measures = predictive.find_measures(network,
measures = predictive.find_measures(network, train_queue,
train_queue, ('random', 1, n_classes),
('random', 1, n_classes), torch.device("cuda"),
torch.device("cuda"), measure_names=[args.proj_crit])
measure_names=[args.proj_crit])