130 lines
4.6 KiB
Python
130 lines
4.6 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
#
|
||
|
# This source code is licensed under the MIT license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""Functions for manipulating networks."""
|
||
|
|
||
|
import itertools
|
||
|
import math
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
from pycls.core.config import cfg
|
||
|
|
||
|
|
||
|
def init_weights(m):
    """Performs ResNet-style weight initialization of a single module, in place.

    - ``nn.Conv2d``: weights ~ N(0, 2 / fan_out) with
      fan_out = kernel_h * kernel_w * out_channels (He / fan-out init).
      Any conv bias is left untouched.
    - ``nn.BatchNorm2d``: gamma is set to 1, or to 0 when
      ``cfg.BN.ZERO_INIT_FINAL_GAMMA`` is set and the layer carries a truthy
      ``final_bn`` attribute. Beta is set to 0.
    - ``nn.Linear``: weights ~ N(0, 0.01^2), bias set to 0.

    Modules of any other type are left unchanged. Missing affine parameters
    (``bias=False`` linears, ``affine=False`` BN) are skipped rather than crashing.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, nn.Conv2d):
        # Note that there is no bias due to BN
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(mean=0.0, std=math.sqrt(2.0 / fan_out))
    elif isinstance(m, nn.BatchNorm2d):
        # Zero-init gamma only for BN layers explicitly flagged as "final_bn"
        zero_init_gamma = cfg.BN.ZERO_INIT_FINAL_GAMMA
        zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma
        # affine=False BN layers have weight/bias set to None
        if m.weight is not None:
            m.weight.data.fill_(0.0 if zero_init_gamma else 1.0)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(mean=0.0, std=0.01)
        # nn.Linear(..., bias=False) has bias set to None
        if m.bias is not None:
            m.bias.data.zero_()
|
||
|
|
||
|
|
||
|
@torch.no_grad()
def compute_precise_bn_stats(model, loader, num_samples=None):
    """Computes precise BN stats on training data.

    Forward passes are run over up to ``num_samples`` training samples with BN
    momentum temporarily set to 1.0. After each minibatch, every BN layer's running
    buffers therefore hold that minibatch's statistics. These per-batch statistics
    are averaged and written back into the BN buffers:

        mean = E[m],   var = E[v + m^2] - E[m]^2

    Args:
        model: network whose ``nn.BatchNorm2d`` layers (with running stats) are
            updated. BN only updates running stats in training mode, so the caller
            is expected to have put the model in train mode.
        loader: iterable of ``(inputs, labels)`` minibatches that also exposes
            ``batch_size`` and ``__len__`` (e.g. a ``torch.utils.data.DataLoader``).
        num_samples: number of samples to use; defaults to
            ``cfg.BN.NUM_SAMPLES_PRECISE``.

    If the loader yields no batches, the BN stats are left unchanged. Original BN
    momentum values are always restored, even if a forward pass raises.
    """
    # Retrieve the BN layers (layers without running stats have nothing to update)
    bns = [
        m
        for m in model.modules()
        if isinstance(m, torch.nn.BatchNorm2d) and m.track_running_stats
    ]
    if not bns:
        return
    if num_samples is None:
        num_samples = cfg.BN.NUM_SAMPLES_PRECISE
    # Compute the number of minibatches to use. Use at least one: a budget smaller
    # than a batch would otherwise run zero iterations and zero out the stats.
    num_iter = max(1, min(num_samples // loader.batch_size, len(loader)))
    # Inputs go to the device holding the BN buffers (previously hardcoded .cuda())
    device = bns[0].running_mean.device
    # Initialize stats storage (running sums, divided by the batch count at the end)
    mus = [torch.zeros_like(bn.running_mean) for bn in bns]
    sqs = [torch.zeros_like(bn.running_var) for bn in bns]
    # Remember momentum values
    moms = [bn.momentum for bn in bns]
    num_seen = 0
    try:
        # Momentum 1.0 makes the running buffers equal the latest batch stats
        for bn in bns:
            bn.momentum = 1.0
        # Accumulate the stats across the data samples
        for inputs, _labels in itertools.islice(loader, num_iter):
            model(inputs.to(device))
            for mu, sq, bn in zip(mus, sqs, bns):
                m, v = bn.running_mean, bn.running_var
                sq += v + m * m
                mu += m
            num_seen += 1
    finally:
        # Restore momentum values even if a forward pass failed
        for bn, mom in zip(bns, moms):
            bn.momentum = mom
    if num_seen == 0:
        return
    # Set the stats in place, dividing by the number of batches actually seen
    for mu, sq, bn in zip(mus, sqs, bns):
        mu /= num_seen
        sq /= num_seen
        bn.running_mean.copy_(mu)
        bn.running_var.copy_(sq - mu * mu)
|
||
|
|
||
|
|
||
|
def reset_bn_stats(model):
    """Resets running BN stats.

    Calls ``reset_running_stats()`` on every ``torch.nn.BatchNorm2d`` submodule of
    ``model``. Other modules are ignored.
    """
    bn_layers = (
        layer for layer in model.modules() if isinstance(layer, torch.nn.BatchNorm2d)
    )
    for layer in bn_layers:
        layer.reset_running_stats()
|
||
|
|
||
|
|
||
|
def complexity_conv2d(cx, w_in, w_out, k, stride, padding, groups=1, bias=False):
    """Accumulates complexity of Conv2D into cx = (h, w, flops, params, acts).

    Args:
        cx: dict with keys "h", "w" (input spatial size) and "flops", "params",
            "acts" (totals accumulated so far).
        w_in: number of input channels.
        w_out: number of output channels.
        k: square kernel size.
        stride: convolution stride.
        padding: zero padding on each side.
        groups: number of channel groups.
        bias: whether the convolution has a bias term.

    Returns:
        A new dict with the output spatial size ("h", "w") and updated totals.
    """
    h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"]
    h = (h + 2 * padding - k) // stride + 1
    w = (w + 2 * padding - k) // stride + 1
    flops += k * k * w_in * w_out * h * w // groups
    params += k * k * w_in * w_out // groups
    if bias:
        # The bias is added once per output element, not once per channel
        flops += w_out * h * w
        params += w_out
    acts += w_out * h * w
    return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts}
|
||
|
|
||
|
|
||
|
def complexity_batchnorm2d(cx, w_in):
    """Accumulates complexity of BatchNorm2D into cx = (h, w, flops, params, acts).

    Only the parameter count changes: one scale and one shift per channel. The
    spatial size, FLOPs and activations are carried over unchanged. A new dict
    containing exactly those five keys is returned.
    """
    tracked = ("h", "w", "flops", "params", "acts")
    result = {name: cx[name] for name in tracked}
    result["params"] += 2 * w_in
    return result
|
||
|
|
||
|
|
||
|
def complexity_maxpool2d(cx, k, stride, padding):
    """Accumulates complexity of MaxPool2d into cx = (h, w, flops, params, acts).

    Pooling only shrinks the spatial size. FLOPs, params and activations are
    passed through unchanged in the returned dict.
    """
    height, width = cx["h"], cx["w"]
    flops, params, acts = cx["flops"], cx["params"], cx["acts"]

    def pooled(size):
        # Output length of a k-wide window sliding with the given stride/padding
        return (size + 2 * padding - k) // stride + 1

    return {
        "h": pooled(height),
        "w": pooled(width),
        "flops": flops,
        "params": params,
        "acts": acts,
    }
|
||
|
|
||
|
|
||
|
def complexity(model):
    """Compute model complexity (model can be model instance or model class).

    Seeds a square input of side ``cfg.TRAIN.IM_SIZE`` with zero totals and passes
    it through ``model.complexity(cx)``. Only the "flops", "params" and "acts"
    totals of the result are returned.
    """
    side = cfg.TRAIN.IM_SIZE
    seed = {"h": side, "w": side, "flops": 0, "params": 0, "acts": 0}
    totals = model.complexity(seed)
    return {name: totals[name] for name in ("flops", "params", "acts")}
|
||
|
|
||
|
|
||
|
def drop_connect(x, drop_ratio):
    """Drop connect (adapted from DARTS).

    Each sample along dim 0 is zeroed with probability ``drop_ratio``. Surviving
    samples are scaled by ``1 / (1 - drop_ratio)``. ``x`` is modified in place and
    returned.

    Args:
        x: tensor whose first dimension indexes samples (any rank >= 1).
        drop_ratio: probability of dropping a sample. A value >= 1 drops every
            sample.

    Returns:
        ``x`` after the in-place update.
    """
    keep_ratio = 1.0 - drop_ratio
    if keep_ratio <= 0.0:
        # Everything is dropped; the rescale below would compute x / 0 -> inf/NaN
        return x.zero_()
    # One Bernoulli draw per sample, broadcast over all remaining dimensions
    mask_shape = [x.shape[0]] + [1] * (x.dim() - 1)
    mask = torch.empty(mask_shape, dtype=x.dtype, device=x.device)
    mask.bernoulli_(keep_ratio)
    x.div_(keep_ratio)
    x.mul_(mask)
    return x
|
||
|
|
||
|
|
||
|
def get_flat_weights(model):
    """Gets all model weights as a single flat vector.

    Parameters are visited in ``model.parameters()`` order. Each one is viewed as
    a column with ``view(-1, 1)``, and the columns are stacked, giving a tensor of
    shape (total_numel, 1).
    """
    columns = tuple(param.data.view(-1, 1) for param in model.parameters())
    return torch.cat(columns, dim=0)
|
||
|
|
||
|
|
||
|
def set_flat_weights(model, flat_weights):
    """Sets all model weights from a single flat vector.

    Inverse of ``get_flat_weights``: consecutive slices of ``flat_weights`` (along
    dim 0) are copied into the parameters in ``model.parameters()`` order.

    Args:
        model: module whose parameters are overwritten in place.
        flat_weights: tensor holding exactly as many elements as the model has
            parameters, e.g. the (N, 1) output of ``get_flat_weights``.

    Raises:
        ValueError: if the element count does not match. The check runs before
            any parameter is modified, so a mismatch leaves the model intact.
    """
    params = list(model.parameters())
    total = sum(p.data.numel() for p in params)
    # Validate up front (an assert is stripped under -O and would only fire after
    # parameters had already been partially overwritten)
    if flat_weights.numel() != total:
        raise ValueError(
            f"flat_weights has {flat_weights.numel()} elements, "
            f"but the model has {total} parameters"
        )
    k = 0
    for p in params:
        n = p.data.numel()
        p.data.copy_(flat_weights[k : (k + n)].view_as(p.data))
        k += n
|