################################################## # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # ################################################## import math, torch import torch.nn as nn def select2withP(logits, tau, just_prob=False, num=2, eps=1e-7): if tau <= 0: new_logits = logits probs = nn.functional.softmax(new_logits, dim=1) else : while True: # a trick to avoid the gumbels bug gumbels = -torch.empty_like(logits).exponential_().log() new_logits = (logits.log_softmax(dim=1) + gumbels) / tau probs = nn.functional.softmax(new_logits, dim=1) if (not torch.isinf(gumbels).any()) and (not torch.isinf(probs).any()) and (not torch.isnan(probs).any()): break if just_prob: return probs #with torch.no_grad(): # add eps for unexpected torch error # probs = nn.functional.softmax(new_logits, dim=1) # selected_index = torch.multinomial(probs + eps, 2, False) with torch.no_grad(): # add eps for unexpected torch error probs = probs.cpu() selected_index = torch.multinomial(probs + eps, num, False).to(logits.device) selected_logit = torch.gather(new_logits, 1, selected_index) selcted_probs = nn.functional.softmax(selected_logit, dim=1) return selected_index, selcted_probs def ChannelWiseInter(inputs, oC, mode='v2'): if mode == 'v1': return ChannelWiseInterV1(inputs, oC) elif mode == 'v2': return ChannelWiseInterV2(inputs, oC) else: raise ValueError('invalid mode : {:}'.format(mode)) def ChannelWiseInterV1(inputs, oC): assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size()) def start_index(a, b, c): return int( math.floor(float(a * c) / b) ) def end_index(a, b, c): return int( math.ceil(float((a + 1) * c) / b) ) batch, iC, H, W = inputs.size() outputs = torch.zeros((batch, oC, H, W), dtype=inputs.dtype, device=inputs.device) if iC == oC: return inputs for ot in range(oC): istartT, iendT = start_index(ot, oC, iC), end_index(ot, oC, iC) values = inputs[:, istartT:iendT].mean(dim=1) outputs[:, ot, :, :] = values return outputs def ChannelWiseInterV2(inputs, oC): assert inputs.dim() == 4, 'invalid dimension : {:}'.format(inputs.size()) batch, C, H, W = inputs.size() if C == oC: return inputs else : return nn.functional.adaptive_avg_pool3d(inputs, (oC,H,W)) #inputs_5D = inputs.view(batch, 1, C, H, W) #otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'area', None) #otputs = otputs_5D.view(batch, oC, H, W) #otputs_5D = nn.functional.interpolate(inputs_5D, (oC,H,W), None, 'trilinear', False) #return otputs def linear_forward(inputs, linear): if linear is None: return inputs iC = inputs.size(1) weight = linear.weight[:, :iC] if linear.bias is None: bias = None else : bias = linear.bias return nn.functional.linear(inputs, weight, bias) def get_width_choices(nOut): xsrange = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] if nOut is None: return len(xsrange) else: Xs = [int(nOut * i) for i in xsrange] #xs = [ int(nOut * i // 10) for i in range(2, 11)] #Xs = [x for i, x in enumerate(xs) if i+1 == len(xs) or xs[i+1] > x+1] Xs = sorted( list( set(Xs) ) ) return tuple(Xs) def get_depth_choices(nDepth): if nDepth is None: return 3 else: assert nDepth >= 3, 'nDepth should be greater than 2 vs {:}'.format(nDepth) if nDepth == 1 : return (1, 1, 1) elif nDepth == 2: return (1, 1, 2) elif nDepth >= 3: return (nDepth//3, nDepth*2//3, nDepth) else: raise ValueError('invalid Depth : {:}'.format(nDepth)) def drop_path(x, drop_prob): if drop_prob > 0.: keep_prob = 1. - drop_prob mask = x.new_zeros(x.size(0), 1, 1, 1) mask = mask.bernoulli_(keep_prob) x = x * (mask / keep_prob) #x.div_(keep_prob) #x.mul_(mask) return x