xautodl/exps/experimental/test-flops.py

29 lines
787 B
Python
Raw Normal View History

2020-03-31 01:20:01 +02:00
import sys, time, random, argparse
from copy import deepcopy
import torchvision.models as models
from pathlib import Path
2021-03-17 10:25:58 +01:00
# Locate the `lib` directory two levels above this script's own directory and
# put it at the front of sys.path, so the `from utils import ...` line below
# can be resolved. The membership check avoids adding a duplicate entry.
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
2020-03-31 01:20:01 +02:00
from utils import get_model_infos
2021-03-17 10:25:58 +01:00
# from models.ImageNet_MobileNetV2 import MobileNetV2
2020-03-31 01:20:01 +02:00
from torchvision.models.mobilenet import MobileNetV2
2021-03-17 10:25:58 +01:00
2020-03-31 01:20:01 +02:00
def main(width_mult: float, input_shape: tuple = (2, 3, 224, 224)) -> None:
    """Build a torchvision MobileNetV2 and print its FLOPs and parameter count.

    Args:
        width_mult: Width multiplier forwarded to ``MobileNetV2``.
        input_shape: Shape handed to ``get_model_infos`` for the measurement.
            Defaults to ``(2, 3, 224, 224)``, the value this script used to
            hard-code, so existing ``main(width_mult)`` calls are unaffected.
    """
    # model = MobileNetV2(1001, width_mult, 32, 1280, 'InvertedResidual', 0.2)
    model = MobileNetV2(width_mult=width_mult)
    print(model)
    flops, params = get_model_infos(model, input_shape)
    print("FLOPs : {:}".format(flops))
    print("Params : {:}".format(params))
    print("-" * 50)
2020-03-31 01:20:01 +02:00
2021-03-17 10:25:58 +01:00
if __name__ == "__main__":
    # Report the statistics for width multipliers 1.0 and 1.4, in that order.
    for width_mult in (1.0, 1.4):
        main(width_mult)