import torch
import torch.nn as nn


def copy_conv(module: nn.Conv2d, init: nn.Conv2d) -> None:
    """Initialise a convolution from the leading channels of another one.

    The first ``module.out_channels`` output filters and the first
    ``module.in_channels`` input slices of ``init.weight`` are copied in place
    into ``module.weight``. The bias is copied the same way when both layers
    have one; if ``init`` has no bias, ``module.bias`` keeps its current
    values.

    Args:
        module: Convolution that receives the values (modified in place).
        init: Convolution that provides the values.
            NOTE(review): presumably ``init`` is at least as wide as
            ``module`` and uses the same kernel size -- confirm with callers.

    Raises:
        AssertionError: If either argument is not an ``nn.Conv2d``.
    """
    assert isinstance(module, nn.Conv2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Conv2d), "invalid module : {:}".format(init)
    new_i, new_o = module.in_channels, module.out_channels
    # In-place writes to parameters that require grad are rejected by autograd
    # unless gradient tracking is off, so do not rely on the caller for it.
    with torch.no_grad():
        module.weight.copy_(init.weight.detach()[:new_o, :new_i])
        # Copy the bias only when the source actually has one.
        if module.bias is not None and init.bias is not None:
            module.bias.copy_(init.bias.detach()[:new_o])


def copy_bn(module: nn.BatchNorm2d, init: nn.BatchNorm2d) -> None:
    """Initialise a batch-norm layer from the leading features of another one.

    For each of ``weight``, ``bias``, ``running_mean`` and ``running_var``,
    the first ``module.num_features`` entries of the tensor on ``init`` are
    copied in place into the tensor on ``module``. A tensor is skipped when it
    is ``None`` on either layer, leaving the value on ``module`` untouched.

    Args:
        module: Batch-norm layer that receives the values (modified in place).
        init: Batch-norm layer that provides the values.
            NOTE(review): presumably ``init.num_features`` is at least
            ``module.num_features`` -- confirm with callers.

    Raises:
        AssertionError: If either argument is not an ``nn.BatchNorm2d``.
    """
    assert isinstance(module, nn.BatchNorm2d), "invalid module : {:}".format(module)
    assert isinstance(init, nn.BatchNorm2d), "invalid module : {:}".format(init)
    num_features = module.num_features
    # The affine parameters require grad, so in-place writes need tracking off;
    # do it here rather than depending on the caller's context.
    with torch.no_grad():
        for name in ("weight", "bias", "running_mean", "running_var"):
            target = getattr(module, name)
            source = getattr(init, name)
            # Either side may lack the tensor; skip instead of failing on None.
            if target is not None and source is not None:
                target.copy_(source.detach()[:num_features])


def copy_fc(module: nn.Linear, init: nn.Linear) -> None:
    """Initialise a linear layer from the leading rows/columns of another one.

    The first ``module.out_features`` rows and the first
    ``module.in_features`` columns of ``init.weight`` are copied in place into
    ``module.weight``. The bias is copied the same way when both layers have
    one; if ``init`` has no bias, ``module.bias`` keeps its current values.

    Args:
        module: Linear layer that receives the values (modified in place).
        init: Linear layer that provides the values.
            NOTE(review): presumably ``init`` is at least as large as
            ``module`` in both dimensions -- confirm with callers.

    Raises:
        AssertionError: If either argument is not an ``nn.Linear``.
    """
    assert isinstance(module, nn.Linear), "invalid module : {:}".format(module)
    assert isinstance(init, nn.Linear), "invalid module : {:}".format(init)
    new_i, new_o = module.in_features, module.out_features
    # In-place writes to parameters that require grad are rejected by autograd
    # unless gradient tracking is off, so do not rely on the caller for it.
    with torch.no_grad():
        module.weight.copy_(init.weight.detach()[:new_o, :new_i])
        # Copy the bias only when the source actually has one.
        if module.bias is not None and init.bias is not None:
            module.bias.copy_(init.bias.detach()[:new_o])


def copy_base(module, init):
    """Copy the conv and batch-norm sub-layers of one building block to another.

    Both arguments must be instances of a class named ``ConvBNReLU`` or
    ``Downsample`` (checked by class name). ``module.conv`` and ``module.bn``
    are each initialised from the matching attribute of ``init`` via
    ``copy_conv`` / ``copy_bn``; a sub-layer that is ``None`` on ``module`` is
    skipped.

    Args:
        module: Block that receives the values.
        init: Block that provides the values.

    Raises:
        AssertionError: If either argument's class name is not accepted.
    """
    accepted_names = ("ConvBNReLU", "Downsample")
    # Validate the destination first, then the source.
    for candidate in (module, init):
        assert (
            type(candidate).__name__ in accepted_names
        ), "invalid module : {:}".format(candidate)

    target_conv = module.conv
    if target_conv is not None:
        copy_conv(target_conv, init.conv)

    target_bn = module.bn
    if target_bn is not None:
        copy_bn(target_bn, init.bn)


def copy_basic(module, init):
    """Copy the sub-blocks of a residual basic block from ``init`` to ``module``.

    ``conv_a`` and ``conv_b`` are always copied with ``copy_base``. The
    ``downsample`` branch is copied only when it exists on both blocks.

    Args:
        module: Block that receives the values.
        init: Block that provides the values.
    """
    for branch in ("conv_a", "conv_b"):
        copy_base(getattr(module, branch), getattr(init, branch))

    target_down = module.downsample
    if target_down is None:
        return
    source_down = init.downsample
    # When init has no downsample branch there is nothing to copy from, so
    # module.downsample is left as it is.
    if source_down is not None:
        copy_base(target_down, source_down)


def init_from_model(network, init_model):
    """Initialise ``network`` in place from the parameters of ``init_model``.

    The classifier is copied with ``copy_fc``. The two ``layers`` sequences
    are then walked in lockstep (``zip`` stops at the shorter one), and each
    pair is copied according to its class name. Everything runs with gradient
    tracking disabled.

    Args:
        network: Model that receives the values; must expose ``classifier``
            and ``layers``.
        init_model: Model that provides the values; same attributes.

    Raises:
        AssertionError: If a pair of layers has different class names.
        ValueError: If a layer's class name is neither ``ConvBNReLU`` nor
            ``ResNetBasicblock``.
    """
    with torch.no_grad():
        copy_fc(network.classifier, init_model.classifier)
        for source, dest in zip(init_model.layers, network.layers):
            kind = type(source).__name__
            assert kind == type(
                dest
            ).__name__, "invalid type : {:} vs {:}".format(source, dest)
            if kind == "ResNetBasicblock":
                copy_basic(dest, source)
                continue
            if kind != "ConvBNReLU":
                raise ValueError("unknown type name : {:}".format(kind))
            copy_base(dest, source)