update
This commit is contained in:
parent
fd43e67da1
commit
5a1dc89756
@ -96,12 +96,11 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_
|
||||
|
||||
model.candidate_flags[cell_type][selected_eid] = False
|
||||
# print(model.get_projected_weights())
|
||||
else:
|
||||
measures = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[proj_crit])
|
||||
measures = predictive.find_measures(model,
|
||||
proj_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[proj_crit])
|
||||
|
||||
# print(measures)
|
||||
for idx in range(num_ops):
|
||||
|
@ -223,12 +223,11 @@ def main():
|
||||
else:
|
||||
#score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
|
||||
network.requires_feature = False
|
||||
else:
|
||||
measures = predictive.find_measures(network,
|
||||
train_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[args.proj_crit])
|
||||
measures = predictive.find_measures(network,
|
||||
train_queue,
|
||||
('random', 1, n_classes),
|
||||
torch.device("cuda"),
|
||||
measure_names=[args.proj_crit])
|
||||
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user