import torch
import torch.nn as nn


def additive_func(A, B):
  """Add two tensors that may differ in size along dimension 1.

  Both tensors must have the same number of dimensions and the same size
  along dimension 0 (checked with an ``assert``). Dimension 1 is presumably
  a channel axis; that is an assumption to confirm with callers.

  Args:
    A: first tensor (assumed to have at least 2 dimensions).
    B: second tensor (assumed to have at least 2 dimensions).

  Returns:
    If ``A.size(1) == B.size(1)``: ``A + B``.
    Otherwise: a clone of whichever input is larger along dimension 1, with
    the other input added into its first ``min(A.size(1), B.size(1))``
    entries along dimension 1. Trailing entries of the larger input are
    passed through unchanged. Neither input is modified.

  Raises:
    AssertionError: if the dimension counts or the dimension-0 sizes differ.
  """
  assert A.dim() == B.dim() and A.size(0) == B.size(0), '{:} vs {:}'.format(A.size(), B.size())
  width_a, width_b = A.size(1), B.size(1)
  if width_a == width_b:
    return A + B
  # Work on a copy of the larger tensor so the caller's tensors are untouched.
  wider, narrower = (B, A) if width_a < width_b else (A, B)
  result = wider.clone()
  result[:, :narrower.size(1)] += narrower
  return result


def change_key(key, value):
  """Build a callback that overwrites an existing attribute on an object.

  Args:
    key: attribute name to overwrite.
    value: value to assign to that attribute.

  Returns:
    A one-argument callable. Given an object, it sets ``obj.<key> = value``
    only when ``hasattr(obj, key)`` is true; otherwise it leaves the object
    alone. The callable always returns ``None``. Its shape suits callbacks
    such as ``nn.Module.apply``, but that usage is an assumption about
    callers not visible here.
  """
  def _assign_if_present(target):
    # Only touch objects that already expose the attribute; never create it.
    if not hasattr(target, key):
      return
    setattr(target, key, value)
  return _assign_if_present


def parse_channel_info(xstring):
  """Parse a string of '-'-joined integer groups into nested int lists.

  The string is a whitespace-separated sequence of groups, and each group
  is a '-'-separated run of integers. For example, ``'16-16 32-32-64'``
  becomes ``[[16, 16], [32, 32, 64]]``.

  Args:
    xstring: the string to parse.

  Returns:
    A list containing one list of ints per group. An empty or
    whitespace-only string yields ``[]``.

  Raises:
    ValueError: if any '-'-separated field is not an integer literal
      (e.g. ``'1--2'`` or ``'a-3'``).
  """
  # split() with no argument collapses runs of whitespace and ignores
  # leading/trailing whitespace. split(' ') would leave empty fields there,
  # which int('') rejects.
  return [[int(field) for field in group.split('-')] for group in xstring.split()]