80 lines
3.0 KiB
Python
80 lines
3.0 KiB
Python
|
##################################################
|
||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||
|
##################################################
|
||
|
import os, sys, torch, random, PIL, copy, numpy as np
|
||
|
from os import path as osp
|
||
|
from shutil import copyfile
|
||
|
|
||
|
|
||
|
def prepare_seed(rand_seed):
    """Seed every random number generator this module knows about.

    The same seed is handed, in order, to Python's ``random`` module,
    NumPy's global generator, PyTorch's default generator, and PyTorch's
    CUDA generators (current device, then all devices).

    Args:
        rand_seed: the seed value forwarded unchanged to each seeding call.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(rand_seed)
|
||
|
|
||
|
|
||
|
def prepare_logger(xargs):
    """Build a ``Logger`` and write the run arguments and environment to it.

    Args:
        xargs: argument namespace; it is deep-copied before use. The copy must
            expose ``save_dir``, ``rand_seed`` and ``_get_kwargs()``.

    Returns:
        The ``Logger`` instance created from ``save_dir`` and ``rand_seed``.
    """
    args = copy.deepcopy(xargs)
    from xautodl.log_utils import Logger

    logger = Logger(args.save_dir, args.rand_seed)
    emit = logger.log
    emit("Main Function with logger : {:}".format(logger))
    emit("Arguments : -------------------------------")
    for arg_name, arg_value in args._get_kwargs():
        emit("{:16} : {:}".format(arg_name, arg_value))

    # Each probe is evaluated lazily, right before its line is logged, so the
    # order of evaluation and logging matches a plain sequence of log calls.
    environment_report = (
        ("Python Version : {:}", lambda: sys.version.replace("\n", " ")),
        ("Pillow Version : {:}", lambda: PIL.__version__),
        ("PyTorch Version : {:}", lambda: torch.__version__),
        ("cuDNN Version : {:}", lambda: torch.backends.cudnn.version()),
        ("CUDA available : {:}", lambda: torch.cuda.is_available()),
        ("CUDA GPU numbers : {:}", lambda: torch.cuda.device_count()),
        (
            "CUDA_VISIBLE_DEVICES : {:}",
            lambda: os.environ.get("CUDA_VISIBLE_DEVICES", "None"),
        ),
    )
    for template, probe in environment_report:
        emit(template.format(probe()))
    return logger
|
||
|
|
||
|
|
||
|
def get_machine_info():
    """Return a newline-separated description of the runtime environment.

    The report lists the Python, Pillow, PyTorch and cuDNN versions, whether
    CUDA is available, the number of CUDA devices, and finally either the
    value of ``CUDA_VISIBLE_DEVICES`` or a note that it is not set.

    Returns:
        str: one line per item, joined with ``"\\n"`` (no trailing newline).
    """
    report = [
        "Python Version : {:}".format(sys.version.replace("\n", " ")),
        "Pillow Version : {:}".format(PIL.__version__),
        "PyTorch Version : {:}".format(torch.__version__),
        "cuDNN Version : {:}".format(torch.backends.cudnn.version()),
        "CUDA available : {:}".format(torch.cuda.is_available()),
        "CUDA GPU numbers : {:}".format(torch.cuda.device_count()),
    ]
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        report.append("Does not set CUDA_VISIBLE_DEVICES")
    else:
        report.append("CUDA_VISIBLE_DEVICES={:}".format(visible_devices))
    return "\n".join(report)
|
||
|
|
||
|
|
||
|
def save_checkpoint(state, filename, logger):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    The data is first written to a temporary sibling file and then moved over
    ``filename`` with ``os.replace``. An already existing checkpoint is thus
    only replaced once the new one has been written completely; if saving
    fails, the previous checkpoint is left untouched.

    Args:
        state: the object handed to ``torch.save``.
        filename: destination path of the checkpoint.
        logger: object used for progress messages; anything without a ``log``
            attribute (e.g. ``None``) disables logging.

    Returns:
        ``filename``, unchanged.

    Raises:
        AssertionError: if ``filename`` does not exist after saving.
    """
    can_log = hasattr(logger, "log")
    if can_log and osp.isfile(filename):
        # Message text is kept as-is so existing log output stays comparable.
        logger.log(
            "Find {:} exist, delete is at first before saving".format(filename)
        )

    # Same directory as the target so that os.replace stays on one filesystem.
    tmp_filename = "{:}.tmp-{:}".format(filename, os.getpid())
    try:
        torch.save(state, tmp_filename)
        os.replace(tmp_filename, filename)
    finally:
        # Only present when saving or replacing failed part-way.
        if osp.isfile(tmp_filename):
            os.remove(tmp_filename)

    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not osp.isfile(filename):
        raise AssertionError(
            "save filename : {:} failed, which is not found.".format(filename)
        )
    if can_log:
        logger.log("save checkpoint into {:}".format(filename))
    return filename
|
||
|
|
||
|
|
||
|
def copy_checkpoint(src, dst, logger):
    """Copy the file ``src`` to ``dst``, replacing ``dst`` if it exists.

    The content is first copied to a temporary sibling of ``dst`` and then
    moved into place with ``os.replace``. An existing ``dst`` is therefore
    only replaced after the copy has succeeded: a missing or unreadable
    ``src`` no longer destroys ``dst``, and ``src`` and ``dst`` naming the
    same file is harmless.

    Args:
        src: path of the file to copy.
        dst: path of the destination file.
        logger: object used for progress messages; anything without a ``log``
            attribute (e.g. ``None``) disables logging.

    Raises:
        OSError: whatever ``shutil.copyfile`` or ``os.replace`` raises, e.g.
            ``FileNotFoundError`` when ``src`` does not exist.
    """
    can_log = hasattr(logger, "log")
    if can_log and osp.isfile(dst):
        # Message text is kept as-is so existing log output stays comparable.
        logger.log("Find {:} exist, delete is at first before saving".format(dst))

    # Same directory as the target so that os.replace stays on one filesystem.
    tmp_dst = "{:}.tmp-{:}".format(dst, os.getpid())
    try:
        copyfile(src, tmp_dst)
        os.replace(tmp_dst, dst)
    finally:
        # Only present when copying or replacing failed part-way.
        if osp.isfile(tmp_dst):
            os.remove(tmp_dst)

    if can_log:
        logger.log("copy the file from {:} into {:}".format(src, dst))
|