##################################################### # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # ##################################################### import torch import torch.nn as nn def additive_func(A, B): assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size()) C = min(A.size(1), B.size(1)) if A.size(1) == B.size(1): return A + B elif A.size(1) < B.size(1): out = B.clone() out[:,:C] += A return out else: out = A.clone() out[:,:C] += B return out def change_key(key, value): def func(m): if hasattr(m, key): setattr(m, key, value) return func def parse_channel_info(xstring): blocks = xstring.split(' ') blocks = [x.split('-') for x in blocks] blocks = [[int(_) for _ in x] for x in blocks] return blocks