This commit is contained in:
HamsterMimi 2023-05-04 13:41:59 +08:00
parent fd43e67da1
commit 5a1dc89756
2 changed files with 10 additions and 12 deletions

View File

@ -96,12 +96,11 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_
model.candidate_flags[cell_type][selected_eid] = False
# print(model.get_projected_weights())
else:
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[proj_crit])
measures = predictive.find_measures(model,
proj_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[proj_crit])
# print(measures)
for idx in range(num_ops):

View File

@ -223,12 +223,11 @@ def main():
else:
#score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
network.requires_feature = False
else:
measures = predictive.find_measures(network,
train_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[args.proj_crit])
measures = predictive.find_measures(network,
train_queue,
('random', 1, n_classes),
torch.device("cuda"),
measure_names=[args.proj_crit])