autodl-projects/exps-rnn/train_rnn_base.py
2019-03-30 02:10:20 +08:00

74 lines
2.9 KiB
Python

import os, gc, sys, math, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import multiprocessing
from pathlib import Path
# Put the repository's sibling ``lib`` directory at the front of sys.path so
# the project-local imports that follow (utils, nas_rnn, train_rnn_utils,
# scheduler) can be resolved when this script is run directly.
lib_dir = Path(__file__).parent.joinpath('..', 'lib').resolve()
print('lib-dir : {:}'.format(lib_dir))
if str(lib_dir) not in sys.path:
    sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, time_file_str, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import count_parameters_in_MB
from nas_rnn import DARTS_V1, DARTS_V2, GDAS
from train_rnn_utils import main_procedure
from scheduler import load_config
# Registry of selectable cells: the keys are the valid ``--arch`` choices and
# the value is the genotype that ``main`` hands to ``main_procedure``.
Networks = dict(
    DARTS_V1=DARTS_V1,
    DARTS_V2=DARTS_V2,
    GDAS=GDAS,
)
# Command-line interface; parsed at import time into the module-level ``args``.
parser = argparse.ArgumentParser("RNN")
# --arch is required: without it ``Networks[args.arch]`` would fail with a
# KeyError only after the save directory and log file had been created.
parser.add_argument('--arch', type=str, required=True, choices=Networks.keys(), help='the network architecture')
parser.add_argument('--config_path', type=str, help='the training configure for the discovered model')
# log
# --save_path is required: ``os.path.join(None, ...)`` in main() would raise TypeError.
parser.add_argument('--save_path', type=str, required=True, help='Folder to save checkpoints and log.')
# The help text documents a default of 200; actually apply it so an omitted
# flag does not reach main_procedure as None.
parser.add_argument('--print_freq', type=int, default=200, help='print frequency (default: 200)')
# When omitted, a random seed is drawn in the setup code below.
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--threads', type=int, default=4, help='the number of threads')
args = parser.parse_args()
# Training needs a GPU. Use an explicit raise rather than ``assert`` so the
# check is not stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError('torch.cuda is not available')
# Draw a seed when none was given, so the seed actually used is known; main()
# logs it and embeds it in the save path.
if args.manualSeed is None:
    args.manualSeed = random.randint(1, 10000)
# Seed every RNG the process may draw from: Python, NumPy and torch (CPU + all GPUs).
random.seed(args.manualSeed)
np.random.seed(args.manualSeed)
# benchmark=True lets cuDNN autotune kernels for speed, which trades away
# bit-exact run-to-run determinism.
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)
torch.set_num_threads(args.threads)
def main():
    """Train the architecture selected on the command line and log the run.

    Reads the module-level ``args``. It redirects ``args.save_path`` (in place)
    to ``<save_path>/seed-<manualSeed>``, creates that directory, and opens a
    timestamped log file there. It logs the parsed arguments and the
    Python/Torch/CUDA/cuDNN versions plus GPU/CPU counts. Finally it loads the
    config from ``args.config_path`` and passes it, the ``args.arch`` genotype,
    the save path, the print frequency and the log handle to ``main_procedure``.
    """
    # Init logger
    args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
    # exist_ok avoids the check-then-create race of a separate isdir() test;
    # a non-directory at that path still raises FileExistsError.
    os.makedirs(args.save_path, exist_ok=True)
    log_path = os.path.join(args.save_path, 'log-seed-{:}-{:}.txt'.format(args.manualSeed, time_file_str()))
    # ``with`` ensures the log is flushed and closed even if training raises,
    # so the log of a crashed run is not truncated.
    with open(log_path, 'w') as log:
        print_log('save path : {:}'.format(args.save_path), log)
        # _get_kwargs() yields the parsed arguments sorted by name.
        state = dict(args._get_kwargs())
        print_log(state, log)
        print_log("Random Seed: {}".format(args.manualSeed), log)
        print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
        print_log("Torch version : {}".format(torch.__version__), log)
        print_log("CUDA version : {}".format(torch.version.cuda), log)
        print_log("cuDNN version : {}".format(cudnn.version()), log)
        print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
        print_log("Num of CPUs : {}".format(multiprocessing.cpu_count()), log)
        config = load_config( args.config_path )
        genotype = Networks[ args.arch ]
        main_procedure(config, genotype, args.save_path, args.print_freq, log)


if __name__ == '__main__':
    main()