35 lines
828 B
Python
35 lines
828 B
Python
#####################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
|
#####################################################
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
def additive_func(A, B):
    """Add two tensors whose sizes along dimension 1 may differ.

    If both tensors have the same size along dim 1, the plain sum ``A + B``
    is returned. Otherwise the tensor with more entries along dim 1 is
    cloned, and the other tensor is added into that clone's leading slice
    ``[:, :k]``, where ``k`` is the smaller of the two sizes. The trailing
    part of the wider tensor passes through unchanged. Neither input is
    modified.

    Args:
        A: tensor with at least two dimensions.
        B: tensor with the same number of dimensions as ``A`` and the same
            size along dim 0. Its other trailing dims must broadcast against
            ``A``'s in the sliced addition.

    Returns:
        A new tensor. Its size along dim 1 is ``max(A.size(1), B.size(1))``.

    Raises:
        AssertionError: if the ranks or dim-0 sizes differ. Note that this
            check is skipped under ``python -O``.
    """
    assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
    width_a, width_b = A.size(1), B.size(1)
    if width_a == width_b:
        return A + B
    # Pick which operand gets copied: always the one with more entries along dim 1.
    if width_a < width_b:
        wide, narrow = B, A
    else:
        wide, narrow = A, B
    # Clone first so the in-place accumulation never touches the caller's tensor.
    merged = wide.clone()
    merged[:, :narrow.size(1)] += narrow
    return merged
|
|
|
|
|
|
def change_key(key, value):
    """Build a callback that overwrites attribute ``key`` with ``value``.

    The returned function takes a single object. It assigns
    ``obj.<key> = value`` only if the object already has that attribute;
    otherwise it does nothing. The callback always returns ``None``.
    Presumably it is meant for ``nn.Module.apply``, which calls a function
    on every submodule (confirm against callers).

    Args:
        key: name of the attribute to overwrite.
        value: new value to assign.

    Returns:
        A one-argument callable that mutates its argument in place.
    """
    def _assign_if_present(target):
        # Never create the attribute; only update objects that already define it.
        if not hasattr(target, key):
            return
        setattr(target, key, value)

    return _assign_if_present
|
|
|
|
|
|
def parse_channel_info(xstring):
    """Parse a string of dash-joined integer groups into nested int lists.

    Groups are separated by whitespace, and the integers inside a group are
    joined with ``'-'``. For example, ``'16-32 32-64'`` becomes
    ``[[16, 32], [32, 64]]``.

    Splitting uses ``str.split()`` with no separator, so runs of spaces,
    tabs, newlines, and leading or trailing whitespace do not produce empty
    groups. Splitting on a single ``' '`` would yield ``''`` tokens there,
    and ``int('')`` raises.

    Args:
        xstring: the encoded channel string.

    Returns:
        A list with one list of ints per group. An empty or all-whitespace
        string yields ``[]``.

    Raises:
        ValueError: if any dash-separated piece is not an integer literal.
    """
    groups = xstring.split()
    return [[int(piece) for piece in group.split('-')] for group in groups]
|