90 lines
3.2 KiB
Python
90 lines
3.2 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""RegNet models."""
|
|
|
|
import numpy as np
|
|
from pycls.core.config import cfg
|
|
from pycls.models.anynet import AnyNet
|
|
|
|
|
|
def quantize_float(f, q):
    """Converts a float to the closest non-zero int divisible by q.

    The value is rounded to the nearest multiple of ``q`` using Python's
    built-in ``round`` (round-half-to-even, so exact ties go to the even
    multiple). The multiple count is clamped to at least 1, so the result is
    never zero, as this docstring promises.

    Args:
        f: value to quantize (expected to be non-negative).
        q: positive quantization step.

    Returns:
        An int equal to ``k * q`` for some integer ``k >= 1``.
    """
    # Without the clamp, any f < q / 2 would round to 0, violating the contract.
    return int(max(1, round(f / q)) * q)
|
|
|
|
|
|
def adjust_ws_gs_comp(ws, bms, gs):
    """Makes stage widths and group widths mutually compatible.

    Each stage is processed as follows:
    - Compute its bottleneck width ``int(w * b)``.
    - Cap its group width at that bottleneck width.
    - Snap the bottleneck width to a multiple of the capped group width with
      ``quantize_float``.
    - Recover the stage width by dividing the bottleneck width back by ``b``.

    Args:
        ws: per-stage widths.
        bms: per-stage bottleneck multipliers.
        gs: per-stage group widths.

    Returns:
        Tuple ``(ws, gs)`` of adjusted widths and group widths. Like ``zip``,
        the result is truncated to the shortest of the three inputs.
    """
    new_ws, new_gs = [], []
    for width, mult, group in zip(ws, bms, gs):
        bottleneck = int(width * mult)
        # A group cannot be wider than the bottleneck it partitions.
        group = min(group, bottleneck)
        bottleneck = quantize_float(bottleneck, group)
        new_ws.append(int(bottleneck / mult))
        new_gs.append(group)
    return new_ws, new_gs
|
|
|
|
|
|
def get_stages_from_blocks(ws, rs):
    """Collapses per-block values into per-stage widths and depths.

    A stage starts at every block whose ``(w, r)`` pair differs from the pair
    of the block before it. The pair ``(0, 0)`` stands in for the positions
    before the first block and after the last one.

    Args:
        ws: per-block widths (a list).
        rs: per-block values that must also stay constant within a stage
            (a list).

    Returns:
        Tuple ``(s_ws, s_ds)``: the width of each stage and the number of
        blocks in each stage.
    """
    n_blocks = min(len(ws), len(rs))
    starts = []
    prev_w, prev_r = 0, 0
    # The trailing sentinel closes the final stage so its depth can be computed.
    for idx, (w, r) in enumerate(zip(ws + [0], rs + [0])):
        if w != prev_w or r != prev_r:
            starts.append(idx)
        prev_w, prev_r = w, r
    s_ws = [ws[idx] for idx in starts if idx < n_blocks]
    s_ds = np.diff(starts).tolist()
    return s_ws, s_ds
|
|
|
|
|
|
def generate_regnet(w_a, w_0, w_m, d, q=8):
    """Generates per-block widths from the RegNet parameters.

    Widths are built in three steps:
    - Block ``j`` gets the linear width ``w_0 + w_a * j``.
    - That width is snapped to ``w_0 * w_m ** k``, where ``k`` is the rounded
      base-``w_m`` logarithm of ``width / w_0``.
    - The result is rounded to a multiple of ``q``.

    Args:
        w_a: slope of the linear width function (must be >= 0).
        w_0: initial width (must be > 0 and divisible by ``q``).
        w_m: width multiplier (must be > 1).
        d: number of blocks.
        q: quantization step for the widths.

    Returns:
        Tuple ``(ws, num_stages, max_stage, ws_cont)``:
        - ``ws``: quantized per-block widths as ints.
        - ``num_stages``: the number of distinct quantized widths.
        - ``max_stage``: the largest exponent ``k`` plus one.
        - ``ws_cont``: the unquantized linear widths.
    """
    assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
    linear = w_0 + w_a * np.arange(d)
    exponents = np.round(np.log(linear / w_0) / np.log(w_m))
    snapped = w_0 * w_m ** exponents
    quantized = np.round(snapped / q) * q
    num_stages = len(np.unique(quantized))
    max_stage = exponents.max() + 1
    return quantized.astype(int).tolist(), num_stages, max_stage, linear.tolist()
|
|
|
|
|
|
class RegNet(AnyNet):
    """RegNet model.

    An AnyNet whose per-stage widths and depths are derived from the RegNet
    parameters in ``cfg.REGNET`` via ``generate_regnet``.
    """

    @staticmethod
    def get_args():
        """Converts the RegNet config into AnyNet constructor arguments.

        Returns:
            Dict of AnyNet keyword arguments built from ``cfg.REGNET`` and
            ``cfg.MODEL.NUM_CLASSES``.
        """
        # Generate RegNet widths per block
        w_a, w_0, w_m, d = cfg.REGNET.WA, cfg.REGNET.W0, cfg.REGNET.WM, cfg.REGNET.DEPTH
        ws, _, _, _ = generate_regnet(w_a, w_0, w_m, d)
        # Convert to per-stage format. The widths are passed as both arguments,
        # so a new stage starts exactly where the block width changes.
        s_ws, s_ds = get_stages_from_blocks(ws, ws)
        # Size the per-stage lists from the stages actually produced, so they
        # always line up with s_ws. A separately computed count could differ,
        # and zip would then silently truncate in adjust_ws_gs_comp.
        num_stages = len(s_ws)
        # Use the same gw, bm and ss for each stage
        s_gs = [cfg.REGNET.GROUP_W] * num_stages
        s_bs = [cfg.REGNET.BOT_MUL] * num_stages
        s_ss = [cfg.REGNET.STRIDE] * num_stages
        # Adjust the compatibility of ws and gws
        s_ws, s_gs = adjust_ws_gs_comp(s_ws, s_bs, s_gs)
        # Get AnyNet arguments defining the RegNet
        return {
            "stem_type": cfg.REGNET.STEM_TYPE,
            "stem_w": cfg.REGNET.STEM_W,
            "block_type": cfg.REGNET.BLOCK_TYPE,
            "ds": s_ds,
            "ws": s_ws,
            "ss": s_ss,
            "bms": s_bs,
            "gws": s_gs,
            "se_r": cfg.REGNET.SE_R if cfg.REGNET.SE_ON else None,
            "nc": cfg.MODEL.NUM_CLASSES,
        }

    def __init__(self):
        """Builds the AnyNet described by the current RegNet config."""
        kwargs = RegNet.get_args()
        super().__init__(**kwargs)

    @staticmethod
    def complexity(cx, **kwargs):
        """Computes model complexity. If you alter the model, make sure to update.

        Args:
            cx: complexity accumulator forwarded to ``AnyNet.complexity``.
            **kwargs: AnyNet arguments. When empty, they are derived from the
                current config via ``get_args``.
        """
        kwargs = RegNet.get_args() if not kwargs else kwargs
        return AnyNet.complexity(cx, **kwargs)
|