Update yaml configs

parent 1a7440d2af
commit 9bf0fa5f04
(deleted file)
@@ -1,7 +0,0 @@
-class_or_func: CIFAR10
-module_path: torchvision.datasets
-args: []
-kwargs:
-  train: False
-  download: True
-  transform: null
(deleted file)
@@ -1,7 +0,0 @@
-class_or_func: CIFAR10
-module_path: torchvision.datasets
-args: []
-kwargs:
-  train: True
-  download: True
-  transform: null
configs/yaml.data/cifar10.test (new file, 22 lines)
@@ -0,0 +1,22 @@
+class_or_func: CIFAR10
+module_path: torchvision.datasets
+args: []
+kwargs:
+  train: False
+  download: True
+  transform:
+    class_or_func: Compose
+    module_path: torchvision.transforms
+    args:
+      -
+        - class_or_func: ToTensor
+          module_path: torchvision.transforms
+          args: []
+          kwargs: {}
+        - class_or_func: Normalize
+          module_path: torchvision.transforms
+          args: []
+          kwargs:
+            mean: (0.491, 0.482, 0.447)
+            std: (0.247, 0.244, 0.262)
+    kwargs: {}
configs/yaml.data/cifar10.train (new file, 30 lines)
@@ -0,0 +1,30 @@
+class_or_func: CIFAR10
+module_path: torchvision.datasets
+args: []
+kwargs:
+  train: True
+  download: True
+  transform:
+    class_or_func: Compose
+    module_path: torchvision.transforms
+    args:
+      -
+        - class_or_func: RandomHorizontalFlip
+          module_path: torchvision.transforms
+          args: []
+          kwargs: {}
+        - class_or_func: RandomCrop
+          module_path: torchvision.transforms
+          args: [32]
+          kwargs: {padding: 4}
+        - class_or_func: ToTensor
+          module_path: torchvision.transforms
+          args: []
+          kwargs: {}
+        - class_or_func: Normalize
+          module_path: torchvision.transforms
+          args: []
+          kwargs:
+            mean: (0.491, 0.482, 0.447)
+            std: (0.247, 0.244, 0.262)
+    kwargs: {}
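A minimal sketch (not part of the commit) of how these data configs are consumed, mirroring the nested_call_by_yaml call in exps/basic/xmain.py; the config path and the "./data/cifar.python" root below are placeholders:

from xautodl import xmisc

# The extra positional argument is appended to the config's `args` list, so it becomes
# the CIFAR10 root directory; the kwargs (train/download/transform) are expanded
# recursively, which also builds the nested Compose/ToTensor/Normalize objects.
train_data = xmisc.nested_call_by_yaml(
    "./configs/yaml.data/cifar10.train", "./data/cifar.python"
)
print(train_data)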
configs/yaml.loss/cross-entropy (new file, 4 lines)
@@ -0,0 +1,4 @@
+class_or_func: CrossEntropyLoss
+module_path: torch.nn
+args: []
+kwargs: {}
configs/yaml.model/vit-cifar10.s0 (new file, 4 lines)
@@ -0,0 +1,4 @@
+class_or_func: get_transformer
+module_path: xautodl.xmodels.transformers
+args: [vit-cifar10-p4-d4-h4-c32]
+kwargs: {}
configs/yaml.opt/vit.cifar (new file, 7 lines)
@@ -0,0 +1,7 @@
+class_or_func: Adam
+module_path: torch.optim
+args: []
+kwargs:
+  betas: [0.9, 0.999]
+  weight_decay: 0.1
+  amsgrad: False
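A minimal sketch (not part of the commit) of how this optimizer config combines with the command-line options in exps/basic/xmain.py: the parameters arrive as a positional argument, and lr/weight_decay given at the call site are merged over the yaml kwargs, so a command-line --weight_decay overrides the 0.1 stored here. The nn.Linear model and the literal values are placeholders for the real ViT and flags.

import torch.nn as nn
from xautodl import xmisc

model = nn.Linear(8, 2)  # placeholder model
optimizer = xmisc.nested_call_by_yaml(
    "./configs/yaml.opt/vit.cifar",
    model.parameters(),
    lr=0.003,
    weight_decay=0.3,  # merged over the yaml kwargs, so it wins over 0.1
)
print(optimizer)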
exps/basic/xmain.py
@@ -3,7 +3,7 @@
 #####################################################
 # python exps/basic/xmain.py --save_dir outputs/x #
 #####################################################
-import sys, time, torch, random, argparse
+import os, sys, time, torch, random, argparse
 from copy import deepcopy
 from pathlib import Path
 
@@ -12,24 +12,38 @@ print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))
 
-from xautodl.xmisc import nested_call_by_yaml
+from xautodl import xmisc
 
 
 def main(args):
 
-    train_data = nested_call_by_yaml(args.train_data_config, args.data_path)
-    valid_data = nested_call_by_yaml(args.valid_data_config, args.data_path)
+    train_data = xmisc.nested_call_by_yaml(args.train_data_config, args.data_path)
+    valid_data = xmisc.nested_call_by_yaml(args.valid_data_config, args.data_path)
+    logger = xmisc.Logger(args.save_dir, prefix="seed-{:}-".format(args.rand_seed))
 
-    import pdb
-    pdb.set_trace()
-    prepare_seed(args.rand_seed)
-    logger = prepare_logger(args)
-    train_data, valid_data, xshape, class_num = get_datasets(
-        args.dataset, args.data_path, args.cutout_length
+    logger.log("Create the logger: {:}".format(logger))
+    logger.log("Arguments : -------------------------------")
+    for name, value in args._get_kwargs():
+        logger.log("{:16} : {:}".format(name, value))
+    logger.log("Python Version : {:}".format(sys.version.replace("\n", " ")))
+    logger.log("PyTorch Version : {:}".format(torch.__version__))
+    logger.log("cuDNN Version : {:}".format(torch.backends.cudnn.version()))
+    logger.log("CUDA available : {:}".format(torch.cuda.is_available()))
+    logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
+    logger.log(
+        "CUDA_VISIBLE_DEVICES : {:}".format(
+            os.environ["CUDA_VISIBLE_DEVICES"]
+            if "CUDA_VISIBLE_DEVICES" in os.environ
+            else "None"
+        )
     )
+    logger.log("The training data is:\n{:}".format(train_data))
+    logger.log("The validation data is:\n{:}".format(valid_data))
 
+    model = xmisc.nested_call_by_yaml(args.model_config)
+    logger.log("The model is:\n{:}".format(model))
+    logger.log("The model size is {:.4f} M".format(xmisc.count_parameters(model)))
 
     train_loader = torch.utils.data.DataLoader(
         train_data,
         batch_size=args.batch_size,
@@ -44,100 +58,25 @@ def main(args):
         num_workers=args.workers,
         pin_memory=True,
     )
-    # get configures
-    model_config = load_config(args.model_config, {"class_num": class_num}, logger)
-    optim_config = load_config(args.optim_config, {"class_num": class_num}, logger)
 
-    if args.model_source == "normal":
-        base_model = obtain_model(model_config)
-    elif args.model_source == "nas":
-        base_model = obtain_nas_infer_model(model_config, args.extra_model_path)
-    elif args.model_source == "autodl-searched":
-        base_model = obtain_model(model_config, args.extra_model_path)
-    elif args.model_source in ("x", "xmodel"):
-        base_model = obtain_xmodel(model_config)
-    else:
-        raise ValueError("invalid model-source : {:}".format(args.model_source))
-    flop, param = get_model_infos(base_model, xshape)
-    logger.log("model ====>>>>:\n{:}".format(base_model))
-    logger.log("model information : {:}".format(base_model.get_message()))
-    logger.log("-" * 50)
-    logger.log(
-        "Params={:.2f} MB, FLOPs={:.2f} M ... = {:.2f} G".format(
-            param, flop, flop / 1e3
-        )
+    logger.log("The training loader: {:}".format(train_loader))
+    logger.log("The validation loader: {:}".format(valid_loader))
+    optimizer = xmisc.nested_call_by_yaml(
+        args.optim_config,
+        model.parameters(),
+        lr=args.lr,
+        weight_decay=args.weight_decay,
     )
-    logger.log("-" * 50)
-    logger.log("train_data : {:}".format(train_data))
-    logger.log("valid_data : {:}".format(valid_data))
-    optimizer, scheduler, criterion = get_optim_scheduler(
-        base_model.parameters(), optim_config
-    )
-    logger.log("optimizer : {:}".format(optimizer))
-    logger.log("scheduler : {:}".format(scheduler))
-    logger.log("criterion : {:}".format(criterion))
+    loss = xmisc.nested_call_by_yaml(args.loss_config)
 
-    last_info, model_base_path, model_best_path = (
-        logger.path("info"),
-        logger.path("model"),
-        logger.path("best"),
-    )
-    network, criterion = torch.nn.DataParallel(base_model).cuda(), criterion.cuda()
+    logger.log("The optimizer is:\n{:}".format(optimizer))
+    logger.log("The loss is {:}".format(loss))
 
-    if last_info.exists():  # automatically resume from previous checkpoint
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start".format(last_info)
-        )
-        last_infox = torch.load(last_info)
-        start_epoch = last_infox["epoch"] + 1
-        last_checkpoint_path = last_infox["last_checkpoint"]
-        if not last_checkpoint_path.exists():
-            logger.log(
-                "Does not find {:}, try another path".format(last_checkpoint_path)
-            )
-            last_checkpoint_path = (
-                last_info.parent
-                / last_checkpoint_path.parent.name
-                / last_checkpoint_path.name
-            )
-        checkpoint = torch.load(last_checkpoint_path)
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch.".format(
-                last_info, start_epoch
-            )
-        )
-    elif args.resume is not None:
-        assert Path(args.resume).exists(), "Can not find the resume file : {:}".format(
-            args.resume
-        )
-        checkpoint = torch.load(args.resume)
-        start_epoch = checkpoint["epoch"] + 1
-        base_model.load_state_dict(checkpoint["base-model"])
-        scheduler.load_state_dict(checkpoint["scheduler"])
-        optimizer.load_state_dict(checkpoint["optimizer"])
-        valid_accuracies = checkpoint["valid_accuracies"]
-        max_bytes = checkpoint["max_bytes"]
-        logger.log(
-            "=> loading checkpoint from '{:}' start with {:}-th epoch.".format(
-                args.resume, start_epoch
-            )
-        )
-    elif args.init_model is not None:
-        assert Path(
-            args.init_model
-        ).exists(), "Can not find the initialization file : {:}".format(args.init_model)
-        checkpoint = torch.load(args.init_model)
-        base_model.load_state_dict(checkpoint["base-model"])
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
-        logger.log("=> initialize the model from {:}".format(args.init_model))
-    else:
-        logger.log("=> do not find the last-info file : {:}".format(last_info))
-        start_epoch, valid_accuracies, max_bytes = 0, {"best": -1}, {}
+    model, loss = torch.nn.DataParallel(model).cuda(), loss.cuda()
+
+    import pdb
+    pdb.set_trace()
 
     train_func, valid_func = get_procedures(args.procedure)
 
@@ -284,7 +223,7 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Train a model with a loss function.",
+        description="Train a classification model with a loss function.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
@@ -293,27 +232,21 @@ if __name__ == "__main__":
     parser.add_argument("--resume", type=str, help="Resume path.")
     parser.add_argument("--init_model", type=str, help="The initialization model path.")
     parser.add_argument("--model_config", type=str, help="The path to the model config")
+    parser.add_argument("--optim_config", type=str, help="The optimizer config file.")
+    parser.add_argument("--loss_config", type=str, help="The loss config file.")
     parser.add_argument(
-        "--optim_config", type=str, help="The path to the optimizer config"
+        "--train_data_config", type=str, help="The training dataset config path."
     )
     parser.add_argument(
-        "--train_data_config", type=str, help="The dataset config path."
-    )
-    parser.add_argument(
-        "--valid_data_config", type=str, help="The dataset config path."
-    )
-    parser.add_argument(
-        "--data_path", type=str, help="The path to the dataset."
+        "--valid_data_config", type=str, help="The validation dataset config path."
     )
+    parser.add_argument("--data_path", type=str, help="The path to the dataset.")
     parser.add_argument("--algorithm", type=str, help="The algorithm.")
     # Optimization options
+    parser.add_argument("--lr", type=float, help="The learning rate")
+    parser.add_argument("--weight_decay", type=float, help="The weight decay")
     parser.add_argument("--batch_size", type=int, default=2, help="The batch size.")
-    parser.add_argument(
-        "--workers",
-        type=int,
-        default=8,
-        help="number of data loading workers (default: 8)",
-    )
+    parser.add_argument("--workers", type=int, default=4, help="The number of workers")
     # Random Seed
     parser.add_argument("--rand_seed", type=int, default=-1, help="manual seed")
 
(shell script)
@@ -22,6 +22,10 @@ save_dir=./outputs/${dataset}/vit-experimental
 python --version
 
 python ./exps/basic/xmain.py --save_dir ${save_dir} --rand_seed ${rseed} \
-    --train_data_config ./configs/data.yaml/${dataset}.train \
-    --valid_data_config ./configs/data.yaml/${dataset}.test \
-    --data_path $TORCH_HOME/cifar.python
+    --train_data_config ./configs/yaml.data/${dataset}.train \
+    --valid_data_config ./configs/yaml.data/${dataset}.test \
+    --data_path $TORCH_HOME/cifar.python \
+    --model_config ./configs/yaml.model/vit-cifar10.s0 \
+    --optim_config ./configs/yaml.opt/vit.cifar \
+    --loss_config ./configs/yaml.loss/cross-entropy \
+    --lr 0.003 --weight_decay 0.3
tests/test_basic_space.py
@@ -3,10 +3,8 @@
 #####################################################
 # pytest tests/test_basic_space.py -s #
 #####################################################
-import sys, random
+import random
 import unittest
-import pytest
-from pathlib import Path
 
 from xautodl.spaces import Categorical
 from xautodl.spaces import Continuous
tests/test_import.py
@@ -3,12 +3,6 @@
 #####################################################
 # pytest ./tests/test_import.py #
 #####################################################
-import os, sys, time, torch
-import pickle
-import tempfile
-from pathlib import Path
-
-
 def test_import():
     from xautodl import config_utils
     from xautodl import datasets
@@ -19,6 +13,9 @@ def test_import():
     from xautodl import spaces
     from xautodl import trade_models
     from xautodl import utils
 
     from xautodl import xlayers
+    from xautodl import xmisc
+    from xautodl import xmmodels
+
     print("Check all imports done")
tests/test_super_att.py
@@ -3,13 +3,11 @@
 #####################################################
 # pytest ./tests/test_super_att.py -s #
 #####################################################
-import sys, random
+import random
 import unittest
 from parameterized import parameterized
-from pathlib import Path
 
 import torch
 
 from xautodl import spaces
 from xautodl.xlayers import super_core
 
tests/test_super_container.py
@@ -3,10 +3,9 @@
 #####################################################
 # pytest ./tests/test_super_container.py -s #
 #####################################################
-import sys, random
+import random
 import unittest
 import pytest
-from pathlib import Path
 
 import torch
 from xautodl import spaces
tests/test_super_rearrange.py
@@ -3,7 +3,6 @@
 #####################################################
 # pytest ./tests/test_super_rearrange.py -s #
 #####################################################
-import sys
 import unittest
 
 import torch
tests/test_super_vit.py
@@ -3,8 +3,8 @@
 #####################################################
 # pytest ./tests/test_super_vit.py -s #
 #####################################################
-import sys
 import unittest
+from parameterized import parameterized
 
 import torch
 from xautodl.xmodels import transformers
@@ -16,25 +16,28 @@ class TestSuperViT(unittest.TestCase):
 
     def test_super_vit(self):
         model = transformers.get_transformer("vit-base-16")
-        tensor = torch.rand((16, 3, 224, 224))
+        tensor = torch.rand((2, 3, 224, 224))
         print("The tensor shape: {:}".format(tensor.shape))
         # print(model)
         outs = model(tensor)
         print("The output tensor shape: {:}".format(outs.shape))
 
-    def test_imagenet(self):
-        name2config = transformers.name2config
-        print("There are {:} models in total.".format(len(name2config)))
-        for name, config in name2config.items():
-            if "cifar" in name:
-                tensor = torch.rand((16, 3, 32, 32))
-            else:
-                tensor = torch.rand((16, 3, 224, 224))
-            model = transformers.get_transformer(config)
-            outs = model(tensor)
-            size = count_parameters(model, "mb", True)
-            print(
-                "{:10s} : size={:.2f}MB, out-shape: {:}".format(
-                    name, size, tuple(outs.shape)
-                )
+    @parameterized.expand(
+        [
+            ["vit-cifar10-p4-d4-h4-c32", 32],
+            ["vit-base-16", 224],
+            ["vit-large-16", 224],
+            ["vit-huge-14", 224],
+        ]
+    )
+    def test_imagenet(self, name, resolution):
+        tensor = torch.rand((2, 3, resolution, resolution))
+        config = transformers.name2config[name]
+        model = transformers.get_transformer(config)
+        outs = model(tensor)
+        size = count_parameters(model, "mb", True)
+        print(
+            "{:10s} : size={:.2f}MB, out-shape: {:}".format(
+                name, size, tuple(outs.shape)
             )
+        )
xautodl/xmisc/__init__.py
@@ -6,3 +6,7 @@ from .module_utils import call_by_yaml
 from .module_utils import nested_call_by_dict
 from .module_utils import nested_call_by_yaml
 from .yaml_utils import load_yaml
+
+from .torch_utils import count_parameters
+
+from .logger_utils import Logger
xautodl/xmisc/logger_utils.py (new file, 49 lines)
@@ -0,0 +1,49 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+import sys
+from pathlib import Path
+
+from .time_utils import time_for_file, time_string
+
+
+class Logger:
+    """A logger used in xautodl."""
+
+    def __init__(self, root_dir, prefix="", log_time=True):
+        """Create a summary writer logging to log_dir."""
+        self.root_dir = Path(root_dir)
+        self.log_dir = self.root_dir / "logs"
+        self.log_dir.mkdir(parents=True, exist_ok=True)
+
+        self._prefix = prefix
+        self._log_time = log_time
+        self.logger_path = self.log_dir / "{:}{:}.log".format(
+            self._prefix, time_for_file()
+        )
+        self._logger_file = open(self.logger_path, "w")
+
+    @property
+    def logger(self):
+        return self._logger_file
+
+    def log(self, string, save=True, stdout=False):
+        string = "{:} {:}".format(time_string(), string) if self._log_time else string
+        if stdout:
+            sys.stdout.write(string)
+            sys.stdout.flush()
+        else:
+            print(string)
+        if save:
+            self._logger_file.write("{:}\n".format(string))
+            self._logger_file.flush()
+
+    def close(self):
+        self._logger_file.close()
+        if self.writer is not None:
+            self.writer.close()
+
+    def __repr__(self):
+        return "{name}(dir={log_dir}, prefix={_prefix}, log_time={_log_time})".format(
+            name=self.__class__.__name__, **self.__dict__
+        )
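A minimal usage sketch (not part of the commit), mirroring how exps/basic/xmain.py drives the new Logger; the "./outputs/demo" directory and the "seed-1-" prefix are placeholders:

from xautodl.xmisc import Logger

logger = Logger("./outputs/demo", prefix="seed-1-")
# Each call is timestamped (log_time=True by default), printed, and appended to
# ./outputs/demo/logs/seed-1-<timestamp>.log.
logger.log("Arguments : -------------------------------")
logger.log("The training data is ready")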
xautodl/xmisc/module_utils.py
@@ -62,18 +62,25 @@ def call_by_yaml(path, *args, **kwargs) -> object:
 
 def nested_call_by_dict(config: Union[Dict[Text, Any], Any], *args, **kwargs) -> object:
     """Similar to `call_by_dict`, but differently, the args may contain another dict needs to be called."""
-    if not has_key_words(config):
+    if isinstance(config, list):
+        return [nested_call_by_dict(x) for x in config]
+    elif isinstance(config, tuple):
+        return (nested_call_by_dict(x) for x in config)
+    elif not isinstance(config, dict):
         return config
-    module = get_module_by_module_path(config["module_path"])
-    cls_or_func = getattr(module, config[CLS_FUNC_KEY])
-    args = tuple(list(config["args"]) + list(args))
-    kwargs = {**config["kwargs"], **kwargs}
-    # check whether there are nested special dict
-    new_args = [nested_call_by_dict(x) for x in args]
-    new_kwargs = {}
-    for key, x in kwargs.items():
-        new_kwargs[key] = nested_call_by_dict(x)
-    return cls_or_func(*new_args, **new_kwargs)
+    elif not has_key_words(config):
+        return {key: nested_call_by_dict(x) for x, key in config.items()}
+    else:
+        module = get_module_by_module_path(config["module_path"])
+        cls_or_func = getattr(module, config[CLS_FUNC_KEY])
+        args = tuple(list(config["args"]) + list(args))
+        kwargs = {**config["kwargs"], **kwargs}
+        # check whether there are nested special dict
+        new_args = [nested_call_by_dict(x) for x in args]
+        new_kwargs = {}
+        for key, x in kwargs.items():
+            new_kwargs[key] = nested_call_by_dict(x)
+        return cls_or_func(*new_args, **new_kwargs)
 
 
 def nested_call_by_yaml(path, *args, **kwargs) -> object:
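A minimal sketch (not part of the commit) of the behaviour this hunk adds: lists (and tuples) inside a config are now recursed into, so a list of transform dicts under "args" is instantiated element by element, which is what the new cifar10 data configs rely on. The concrete Compose/ToTensor dict below is an illustration only:

from xautodl.xmisc import nested_call_by_dict

config = {
    "class_or_func": "Compose",
    "module_path": "torchvision.transforms",
    "args": [
        [  # this inner list is handled by the new isinstance(config, list) branch
            {"class_or_func": "ToTensor", "module_path": "torchvision.transforms",
             "args": [], "kwargs": {}},
        ]
    ],
    "kwargs": {},
}
transform = nested_call_by_dict(config)  # builds Compose([ToTensor()])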
xautodl/xmisc/scheduler_utils.py (new file, 136 lines)
@@ -0,0 +1,136 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class CosineDecayWithWarmup(_LRScheduler):
+    r"""Set the learning rate of each parameter group using a cosine annealing
+    schedule, where :math:`\eta_{max}` is set to the initial lr, :math:`T_{cur}`
+    is the number of epochs since the last restart and :math:`T_{i}` is the number
+    of epochs between two warm restarts in SGDR:
+    .. math::
+        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
+        \cos\left(\frac{T_{cur}}{T_{i}}\pi\right)\right)
+    When :math:`T_{cur}=T_{i}`, set :math:`\eta_t = \eta_{min}`.
+    When :math:`T_{cur}=0` after restart, set :math:`\eta_t=\eta_{max}`.
+    It has been proposed in
+    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        T_0 (int): Number of iterations for the first restart.
+        T_mult (int, optional): A factor increases :math:`T_{i}` after a restart. Default: 1.
+        eta_min (float, optional): Minimum learning rate. Default: 0.
+        last_epoch (int, optional): The index of last epoch. Default: -1.
+        verbose (bool): If ``True``, prints a message to stdout for
+            each update. Default: ``False``.
+    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+        https://arxiv.org/abs/1608.03983
+    """
+
+    def __init__(
+        self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1, verbose=False
+    ):
+        if T_0 <= 0 or not isinstance(T_0, int):
+            raise ValueError("Expected positive integer T_0, but got {}".format(T_0))
+        if T_mult < 1 or not isinstance(T_mult, int):
+            raise ValueError("Expected integer T_mult >= 1, but got {}".format(T_mult))
+        self.T_0 = T_0
+        self.T_i = T_0
+        self.T_mult = T_mult
+        self.eta_min = eta_min
+
+        super(CosineDecayWithWarmup, self).__init__(optimizer, last_epoch, verbose)
+
+        self.T_cur = self.last_epoch
+
+    def get_lr(self):
+        if not self._get_lr_called_within_step:
+            warnings.warn(
+                "To get the last learning rate computed by the scheduler, "
+                "please use `get_last_lr()`.",
+                UserWarning,
+            )
+
+        return [
+            self.eta_min
+            + (base_lr - self.eta_min)
+            * (1 + math.cos(math.pi * self.T_cur / self.T_i))
+            / 2
+            for base_lr in self.base_lrs
+        ]
+
+    def step(self, epoch=None):
+        """Step could be called after every batch update
+        Example:
+            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
+            >>> iters = len(dataloader)
+            >>> for epoch in range(20):
+            >>>     for i, sample in enumerate(dataloader):
+            >>>         inputs, labels = sample['inputs'], sample['labels']
+            >>>         optimizer.zero_grad()
+            >>>         outputs = net(inputs)
+            >>>         loss = criterion(outputs, labels)
+            >>>         loss.backward()
+            >>>         optimizer.step()
+            >>>         scheduler.step(epoch + i / iters)
+        This function can be called in an interleaved way.
+        Example:
+            >>> scheduler = CosineDecayWithWarmup(optimizer, T_0, T_mult)
+            >>> for epoch in range(20):
+            >>>     scheduler.step()
+            >>> scheduler.step(26)
+            >>> scheduler.step() # scheduler.step(27), instead of scheduler(20)
+        """
+
+        if epoch is None and self.last_epoch < 0:
+            epoch = 0
+
+        if epoch is None:
+            epoch = self.last_epoch + 1
+            self.T_cur = self.T_cur + 1
+            if self.T_cur >= self.T_i:
+                self.T_cur = self.T_cur - self.T_i
+                self.T_i = self.T_i * self.T_mult
+        else:
+            if epoch < 0:
+                raise ValueError(
+                    "Expected non-negative epoch, but got {}".format(epoch)
+                )
+            if epoch >= self.T_0:
+                if self.T_mult == 1:
+                    self.T_cur = epoch % self.T_0
+                else:
+                    n = int(
+                        math.log(
+                            (epoch / self.T_0 * (self.T_mult - 1) + 1), self.T_mult
+                        )
+                    )
+                    self.T_cur = epoch - self.T_0 * (self.T_mult ** n - 1) / (
+                        self.T_mult - 1
+                    )
+                    self.T_i = self.T_0 * self.T_mult ** (n)
+            else:
+                self.T_i = self.T_0
+                self.T_cur = epoch
+        self.last_epoch = math.floor(epoch)
+
+        class _enable_get_lr_call:
+            def __init__(self, o):
+                self.o = o
+
+            def __enter__(self):
+                self.o._get_lr_called_within_step = True
+                return self
+
+            def __exit__(self, type, value, traceback):
+                self.o._get_lr_called_within_step = False
+                return self
+
+        with _enable_get_lr_call(self):
+            for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())):
+                param_group, lr = data
+                param_group["lr"] = lr
+                self.print_lr(self.verbose, i, lr, epoch)
+
+        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
xautodl/xmisc/time_utils.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+import time
+
+
+def time_for_file():
+    ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S"
+    return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
+
+
+def time_string():
+    ISOTIMEFORMAT = "%Y-%m-%d %X"
+    string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time())))
+    return string
+
+
+def convert_secs2time(epoch_time, return_str=False):
+    need_hour = int(epoch_time / 3600)
+    need_mins = int((epoch_time - 3600 * need_hour) / 60)
+    need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins)
+    if return_str:
+        str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs)
+        return str
+    else:
+        return need_hour, need_mins, need_secs
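A minimal usage sketch (not part of the commit) of the new time helpers; the literal 3723 seconds is just an example value:

from xautodl.xmisc.time_utils import time_string, convert_secs2time

print(time_string())                  # e.g. "[2021-06-01 12:00:00]" (UTC, bracketed)
print(convert_secs2time(3723, True))  # "[01:02:03]"
hours, mins, secs = convert_secs2time(3723)  # (1, 2, 3)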
xautodl/xmisc/torch_utils.py (new file, 26 lines)
@@ -0,0 +1,26 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.06 #
+#####################################################
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def count_parameters(model_or_parameters, unit="mb"):
+    if isinstance(model_or_parameters, nn.Module):
+        counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters())
+    elif isinstance(model_or_parameters, nn.Parameter):
+        counts = models_or_parameters.numel()
+    elif isinstance(model_or_parameters, (list, tuple)):
+        counts = sum(count_parameters(x, None) for x in models_or_parameters)
+    else:
+        counts = sum(np.prod(v.size()) for v in model_or_parameters)
+    if unit.lower() == "kb" or unit.lower() == "k":
+        counts /= 1e3
+    elif unit.lower() == "mb" or unit.lower() == "m":
+        counts /= 1e6
+    elif unit.lower() == "gb" or unit.lower() == "g":
+        counts /= 1e9
+    elif unit is not None:
+        raise ValueError("Unknow unit: {:}".format(unit))
+    return counts
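A minimal usage sketch (not part of the commit): counting the parameters of a small torch module in different units, as exps/basic/xmain.py does for the ViT model; the nn.Linear(128, 10) module is a placeholder.

import torch.nn as nn
from xautodl.xmisc import count_parameters

net = nn.Linear(128, 10)            # 128*10 weights + 10 biases = 1290 parameters
print(count_parameters(net))        # default unit "mb": 0.00129 (millions of parameters)
print(count_parameters(net, "kb"))  # 1.29 (thousands of parameters)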