32 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			32 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#####################################################
 | 
						|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.06 #
 | 
						|
#####################################################
 | 
						|
# python exps/experimental/test-resnest.py
 | 
						|
#####################################################
 | 
						|
import sys, time, torch, random, argparse
 | 
						|
from PIL import ImageFile
 | 
						|
 | 
						|
ImageFile.LOAD_TRUNCATED_IMAGES = True
 | 
						|
from copy import deepcopy
 | 
						|
from pathlib import Path
 | 
						|
 | 
						|
lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
 | 
						|
if str(lib_dir) not in sys.path:
 | 
						|
    sys.path.insert(0, str(lib_dir))
 | 
						|
from utils import get_model_infos
 | 
						|
 | 
						|
torch.hub.list("zhanghang1989/ResNeSt", force_reload=True)
 | 
						|
 | 
						|
for model_name, xshape in [
 | 
						|
    ("resnest50", (1, 3, 224, 224)),
 | 
						|
    ("resnest101", (1, 3, 256, 256)),
 | 
						|
    ("resnest200", (1, 3, 320, 320)),
 | 
						|
    ("resnest269", (1, 3, 416, 416)),
 | 
						|
]:
 | 
						|
    # net = torch.hub.load('zhanghang1989/ResNeSt', model_name, pretrained=True)
 | 
						|
    net = torch.hub.load("zhanghang1989/ResNeSt", model_name, pretrained=False)
 | 
						|
    print("Model : {:}, input shape : {:}".format(model_name, xshape))
 | 
						|
    flops, param = get_model_infos(net, xshape)
 | 
						|
    print("flops  : {:.3f}M".format(flops))
 | 
						|
    print("params : {:.3f}M".format(param))
 |