update scripts
This commit is contained in:
		| @@ -10,6 +10,7 @@ def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7): | ||||
|     while True: # a trick to avoid the gumbels bug | ||||
|       gumbels = -torch.empty_like(logits).exponential_().log() | ||||
|       new_logits = (logits + gumbels) / tau | ||||
|       #new_logits = (logits.log_softmax(dim=1) + gumbels) / tau | ||||
|       probs = nn.functional.softmax(new_logits, dim=1) | ||||
|       if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user