update
This commit is contained in:
parent
fd43e67da1
commit
5a1dc89756
@ -96,12 +96,11 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_
|
|||||||
|
|
||||||
model.candidate_flags[cell_type][selected_eid] = False
|
model.candidate_flags[cell_type][selected_eid] = False
|
||||||
# print(model.get_projected_weights())
|
# print(model.get_projected_weights())
|
||||||
else:
|
measures = predictive.find_measures(model,
|
||||||
measures = predictive.find_measures(model,
|
proj_queue,
|
||||||
proj_queue,
|
('random', 1, n_classes),
|
||||||
('random', 1, n_classes),
|
torch.device("cuda"),
|
||||||
torch.device("cuda"),
|
measure_names=[proj_crit])
|
||||||
measure_names=[proj_crit])
|
|
||||||
|
|
||||||
# print(measures)
|
# print(measures)
|
||||||
for idx in range(num_ops):
|
for idx in range(num_ops):
|
||||||
|
@ -223,12 +223,11 @@ def main():
|
|||||||
else:
|
else:
|
||||||
#score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
|
#score = score_loop(network, None, train_queue, args.gpu, None, args.proj_crit)
|
||||||
network.requires_feature = False
|
network.requires_feature = False
|
||||||
else:
|
measures = predictive.find_measures(network,
|
||||||
measures = predictive.find_measures(network,
|
train_queue,
|
||||||
train_queue,
|
('random', 1, n_classes),
|
||||||
('random', 1, n_classes),
|
torch.device("cuda"),
|
||||||
torch.device("cuda"),
|
measure_names=[args.proj_crit])
|
||||||
measure_names=[args.proj_crit])
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user