89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""Model and loss construction functions."""
|
|
|
|
import torch
|
|
from pycls.core.config import cfg
|
|
from pycls.models.anynet import AnyNet
|
|
from pycls.models.effnet import EffNet
|
|
from pycls.models.regnet import RegNet
|
|
from pycls.models.resnet import ResNet
|
|
from pycls.models.nas.nas import NAS
|
|
from pycls.models.nas.nas_search import NAS_Search
|
|
from pycls.models.nas_bench.model_builder import NAS_Bench
|
|
|
|
|
|
class LabelSmoothedCrossEntropyLoss(torch.nn.Module):
    """Cross-entropy loss evaluated against label-smoothed targets.

    The true class receives probability ``1 - eps`` and every other class
    receives ``eps / (num_classes - 1)``. Both ``eps`` and ``num_classes``
    are read from the global config when the module is constructed.
    """

    def __init__(self):
        super().__init__()
        self.num_classes = cfg.MODEL.NUM_CLASSES
        self.eps = cfg.MODEL.LABEL_SMOOTHING_EPS

    def forward(self, logits, target):
        """Return the mean smoothed cross-entropy over all samples.

        Args:
            logits: Unnormalized scores; classes are taken along the last
                dimension.
            target: Class indices; a trailing dimension is appended before
                scattering, so it presumably has the shape of ``logits``
                without the class dimension (TODO confirm against callers).

        Returns:
            A scalar tensor holding the loss averaged over samples.
        """
        log_probs = torch.log_softmax(logits, dim=-1)
        # The smoothed targets are constants; keep them out of the autograd graph.
        with torch.no_grad():
            smoothed = torch.ones_like(log_probs) * self.eps
            smoothed = smoothed / (self.num_classes - 1)
            true_class = target.unsqueeze(-1)
            smoothed.scatter_(-1, true_class, 1 - self.eps)
        per_sample = torch.sum(-smoothed * log_probs, dim=-1)
        return per_sample.mean()
|
|
|
|
|
|
# Registry of supported model constructors, keyed by the model type name
# that get_model() looks up from the config.
_models = dict(
    anynet=AnyNet,
    effnet=EffNet,
    resnet=ResNet,
    regnet=RegNet,
    nas=NAS,
    nas_search=NAS_Search,
    nas_bench=NAS_Bench,
)
|
|
|
|
# Registry of supported loss function constructors, keyed by the loss name
# that get_loss_fun() looks up from the config.
_loss_funs = dict(
    cross_entropy=torch.nn.CrossEntropyLoss,
    label_smoothed_cross_entropy=LabelSmoothedCrossEntropyLoss,
)
|
|
|
|
|
|
def get_model():
    """Look up the model class selected by ``cfg.MODEL.TYPE``.

    Returns:
        The constructor registered in ``_models`` under the configured type.

    Raises:
        AssertionError: If the configured type is not registered.
    """
    model_type = cfg.MODEL.TYPE
    assert model_type in _models, "Model type '{}' not supported".format(model_type)
    return _models[model_type]
|
|
|
|
|
|
def get_loss_fun():
    """Gets the loss function class specified in the config.

    Returns:
        The constructor registered in ``_loss_funs`` under
        ``cfg.MODEL.LOSS_FUN``.

    Raises:
        AssertionError: If ``cfg.MODEL.LOSS_FUN`` is not a registered loss.
    """
    err_str = "Loss function type '{}' not supported"
    # Report the same config key that is being validated; formatting a
    # different key would mask the real problem (or fail on its own).
    assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
    return _loss_funs[cfg.MODEL.LOSS_FUN]
|
|
|
|
|
|
def build_model():
    """Instantiate the configured model.

    Returns:
        A new instance of the class returned by ``get_model()``, constructed
        with no arguments.
    """
    model_cls = get_model()
    return model_cls()
|
|
|
|
|
|
def build_loss_fun():
    """Instantiate the configured loss function.

    For the ``"seg"`` task the loss is constructed with ``ignore_index=255``;
    otherwise it is constructed with no arguments.

    Returns:
        A new instance of the class returned by ``get_loss_fun()``.
    """
    # NOTE: for the "seg" task the selected loss class must accept an
    # ``ignore_index`` keyword argument.
    ctor_kwargs = {"ignore_index": 255} if cfg.TASK == "seg" else {}
    loss_cls = get_loss_fun()
    return loss_cls(**ctor_kwargs)
|
|
|
|
|
|
def register_model(name, ctor):
    """Add (or replace) a model constructor in the model registry.

    Args:
        name: Key under which the model is looked up by ``get_model()``.
        ctor: Callable stored for that key.
    """
    _models.update({name: ctor})
|
|
|
|
|
|
def register_loss_fun(name, ctor):
    """Add (or replace) a loss constructor in the loss function registry.

    Args:
        name: Key under which the loss is looked up by ``get_loss_fun()``.
        ctor: Callable stored for that key.
    """
    _loss_funs.update({name: ctor})
|