update
This commit is contained in:
		| @@ -96,20 +96,6 @@ def project_op(model, input, target, args, cell_type, proj_queue=None, selected_ | ||||
|  | ||||
|             model.candidate_flags[cell_type][selected_eid] = False | ||||
|             # print(model.get_projected_weights()) | ||||
|             if proj_crit == 'comb': | ||||
|                 synflow = predictive.find_measures(model, | ||||
|                                                    proj_queue, | ||||
|                                                    ('random', 1, n_classes), | ||||
|                                                    torch.device("cuda"), | ||||
|                                                    measure_names=['synflow']) | ||||
|                 var = predictive.find_measures(model, | ||||
|                                                proj_queue, | ||||
|                                                ('random', 1, n_classes), | ||||
|                                                torch.device("cuda"), | ||||
|                                                measure_names=['var']) | ||||
|                 # print(synflow, var) | ||||
|                 comb = np.log(synflow['synflow'] + 1) / (var['var'] + 0.1) | ||||
|                 measures = {'comb': comb} | ||||
|             else: | ||||
|                 measures = predictive.find_measures(model, | ||||
|                                                     proj_queue, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user