update GDAS
This commit is contained in:
		| @@ -88,7 +88,9 @@ class TinyNetworkGDAS(nn.Module): | ||||
|       index   = probs.max(-1, keepdim=True)[1] | ||||
|       one_h   = torch.zeros_like(logits).scatter_(-1, index, 1.0) | ||||
|       hardwts = one_h - probs.detach() + probs | ||||
|       if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): continue | ||||
|       if (torch.isinf(gumbels).any()) or (torch.isinf(probs).any()) or (torch.isnan(probs).any()): | ||||
|         continue | ||||
|       else: break | ||||
|  | ||||
|     feature = self.stem(inputs) | ||||
|     for i, cell in enumerate(self.cells): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user