commit 4eb1a5ccf9
parent 13e908f4df

    init

README.md (10 lines changed)
@@ -15,3 +15,13 @@ conda install pytorch torchvision cuda100 -c pytorch

Searching CNNs
```
```

Train the Searched RNN
```
bash ./scripts-rnn/train-PTB.sh 0 DARTS_V1
bash ./scripts-rnn/train-PTB.sh 0 DARTS_V2
bash ./scripts-rnn/train-PTB.sh 0 GDAS
bash ./scripts-rnn/train-WT2.sh 0 DARTS_V1
bash ./scripts-rnn/train-WT2.sh 0 DARTS_V2
bash ./scripts-rnn/train-WT2.sh 0 GDAS
```
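The six commands above differ only in the dataset script and the architecture name. As a rough illustration (assuming, as the script names suggest, that the first argument selects the GPU index and the second the searched genotype), a small Python driver could launch them all sequentially:

```python
# Hypothetical helper: loops over the PTB/WT2 training scripts shown above.
# Assumes argument 1 = GPU index and argument 2 = architecture/genotype name.
import subprocess

GPU = 0
for script in ('./scripts-rnn/train-PTB.sh', './scripts-rnn/train-WT2.sh'):
    for arch in ('DARTS_V1', 'DARTS_V2', 'GDAS'):
        subprocess.run(['bash', script, str(GPU), arch], check=True)
```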
data/ImageNet-100.txt (new file, 100 lines)
@@ -0,0 +1,100 @@
n01532829
n01560419
n01580077
n01614925
n01664065
n01751748
n01871265
n01924916
n02087394
n02091134
n02091244
n02094433
n02097209
n02102040
n02102480
n02105251
n02106662
n02108422
n02108551
n02123597
n02165105
n02190166
n02268853
n02279972
n02408429
n02412080
n02443114
n02488702
n02509815
n02606052
n02701002
n02782093
n02794156
n02802426
n02804414
n02808440
n02906734
n02917067
n02950826
n02963159
n03017168
n03042490
n03045698
n03063689
n03065424
n03100240
n03109150
n03124170
n03131574
n03272562
n03345487
n03443371
n03461385
n03527444
n03690938
n03692522
n03721384
n03729826
n03792782
n03838899
n03843555
n03874293
n03877472
n03877845
n03908618
n03929660
n03930630
n03933933
n03970156
n03976657
n03982430
n04004767
n04065272
n04141975
n04146614
n04152593
n04192698
n04200800
n04204347
n04317175
n04326547
n04344873
n04370456
n04389033
n04501370
n04515003
n04542943
n04554684
n04562935
n04596742
n04597913
n04606251
n07583066
n07718472
n07734744
n07873807
n07880968
n09229709
n12768682
n12998815
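As a quick sanity check (a sketch, not part of the commit), the new class list can be validated to contain exactly 100 unique WordNet IDs in the `nXXXXXXXX` format used above:

```python
# Sketch: verify data/ImageNet-100.txt holds 100 unique WordNet IDs.
import re

with open('data/ImageNet-100.txt') as f:
    wnids = [line.strip() for line in f if line.strip()]

assert len(wnids) == 100 and len(set(wnids)) == 100
assert all(re.fullmatch(r'n\d{8}', w) for w in wnids)
print('ImageNet-100 class list looks consistent')
```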
@ -1,90 +0,0 @@
|
||||
# EraseReLU: A Simple Way to Ease the Training of Deep Convolution Neural Networks
|
||||
|
||||
This project implements [this paper](https://arxiv.org/abs/1709.07634) in [PyTorch](pytorch.org). The implementation refers to [ResNeXt-DenseNet](https://github.com/D-X-Y/ResNeXt-DenseNet)
|
||||
|
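For orientation only: the idea suggested by the title is to remove ("erase") the last ReLU of a module. A minimal PyTorch sketch of a residual block with an optional final ReLU is shown below; this is an illustration under that assumption, not the repository's `models` code.

```python
import torch.nn as nn

class BasicBlock(nn.Module):
    """Residual block whose final ReLU can be erased (erase_relu=True)."""
    def __init__(self, planes, erase_relu=False):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(planes, planes, 3, padding=1, bias=False), nn.BatchNorm2d(planes), nn.ReLU(inplace=True),
            nn.Conv2d(planes, planes, 3, padding=1, bias=False), nn.BatchNorm2d(planes))
        # EraseReLU: drop the non-linearity after the residual addition.
        self.last_act = nn.Identity() if erase_relu else nn.ReLU(inplace=True)

    def forward(self, x):
        return self.last_act(x + self.body(x))
```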
## Usage

All the model definitions are located in the directory `models`.

All the training scripts are located in the directories `scripts` and `Xscripts`.

To train ResNet-110 with EraseReLU on CIFAR-10:
```bash
sh scripts/warmup_train_2gpu.sh resnet110_erase cifar10
```

To train the original ResNet-110 on CIFAR-10:
```bash
sh scripts/warmup_train_2gpu.sh resnet110 cifar10
```

### MiniImageNet for PatchShuffle
```
sh scripts-shuffle/train_resnet_00000.sh ResNet18
sh scripts-shuffle/train_resnet_10000.sh ResNet18
sh scripts-shuffle/train_resnet_11000.sh ResNet18
```

```
sh scripts-shuffle/train_pmd_00000.sh PMDNet18_300
sh scripts-shuffle/train_pmd_00000.sh PMDNet34_300
sh scripts-shuffle/train_pmd_00000.sh PMDNet50_300

sh scripts-shuffle/train_pmd_11000.sh PMDNet18_300
sh scripts-shuffle/train_pmd_11000.sh PMDNet34_300
sh scripts-shuffle/train_pmd_11000.sh PMDNet50_300
```

### ImageNet
- Use the script `train_imagenet.sh` to train models in PyTorch.
- Or use the code in `extra_torch` to train models in Torch.

#### Group Normalization
```
sh Xscripts/train_vgg_gn.sh 0,1,2,3,4,5,6,7 vgg16_gn 256
sh Xscripts/train_vgg_gn.sh 0,1,2,3,4,5,6,7 vgg16_gn 64
sh Xscripts/train_vgg_gn.sh 0,1,2,3,4,5,6,7 vgg16_gn 16
sh Xscripts/train_res_gn.sh 0,1,2,3,4,5,6,7 resnext50_32_4_gn 16
```

| Model    | Batch Size | Top-1 Error | Top-5 Error |
|:--------:|:----------:|:-----------:|:-----------:|
| VGG16-GN | 256        | 28.82       | 9.64        |

## Results

| Model          | Error on CIFAR-10 | Error on CIFAR-100 |
|:--------------:|:-----------------:|:------------------:|
| ResNet-56      | 6.97              | 30.60              |
| ResNet-56 (ER) | 6.23              | 28.56              |

## Citation
If you find this project helps your research, please consider citing the paper:
```
@article{dong2017eraserelu,
  title   = {EraseReLU: A Simple Way to Ease the Training of Deep Convolution Neural Networks},
  author  = {Dong, Xuanyi and Kang, Guoliang and Zhan, Kun and Yang, Yi},
  journal = {arXiv preprint arXiv:1709.07634},
  year    = {2017}
}
```

## Download the ImageNet dataset
The ImageNet Large Scale Visual Recognition Challenge (ILSVRC) dataset has 1000 categories and 1.2 million images. The images do not need to be preprocessed or packaged in any database, but the validation images need to be moved into appropriate subfolders.

1. Download the images from http://image-net.org/download-images

2. Extract the training data:
```bash
mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..
```

3. Extract the validation data and move images to subfolders:
```bash
mkdir val && mv ILSVRC2012_img_val.tar val/ && cd val && tar -xvf ILSVRC2012_img_val.tar
wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh | bash
```
@@ -1,5 +1,15 @@
# ImageNet

The class names of ImageNet-1K are in `classes.txt`.

# A 100-class subset of ImageNet-1K : ImageNet-100

The class names of ImageNet-100 are in `ImageNet-100.txt`.

Running `python split-imagenet.py` will automatically create ImageNet-100 based on the data of ImageNet-1K. By default, we assume the ImageNet-1K data is located at `~/.torch/ILSVRC2012`. If your data is in a different location, modify lines 19 and 20 in `split-imagenet.py`.

# Tiny-ImageNet
The official website is [here](https://tiny-imagenet.herokuapp.com/). Please run `python tiny-imagenet.py` to generate the correct format of Tiny-ImageNet for training.

# PTB and WT2
Run `bash Get-PTB-WT2.sh` to download the data.
data/split-imagenet.py (new file, 37 lines)
@@ -0,0 +1,37 @@
import os, sys, random
from pathlib import Path


def sample_100_cls():
  with open('classes.txt') as f:
    content = f.readlines()
  content = [x.strip() for x in content]
  random.seed(111)
  classes = random.sample(content, 100)
  classes.sort()
  with open('ImageNet-100.txt', 'w') as f:
    for cls in classes: f.write('{:}\n'.format(cls))
  print('-'*100)


if __name__ == "__main__":
  #sample_100_cls()
  IN1K_root  = Path.home() / '.torch' / 'ILSVRC2012'
  IN100_root = Path.home() / '.torch' / 'ILSVRC2012-100'
  assert IN1K_root.exists(), 'ImageNet directory does not exist : {:}'.format(IN1K_root)
  print('Create soft link from ImageNet directory into : {:}'.format(IN100_root))
  with open('ImageNet-100.txt', 'r') as f:
    classes = f.readlines()
  classes = [x.strip() for x in classes]
  for sub in ['train', 'val']:
    xdir = IN100_root / sub
    if not xdir.exists(): xdir.mkdir(parents=True, exist_ok=True)

  for idx, cls in enumerate(classes):
    xdir = IN1K_root / 'train' / cls
    assert xdir.exists(), '{:} does not exist'.format(xdir)
    os.system('ln -s {:} {:}'.format(xdir, IN100_root / 'train' / cls))

    xdir = IN1K_root / 'val' / cls
    assert xdir.exists(), '{:} does not exist'.format(xdir)
    os.system('ln -s {:} {:}'.format(xdir, IN100_root / 'val' / cls))
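A small follow-up check (a sketch, not part of the commit) can confirm that the symlinked ImageNet-100 tree produced by the script above has the expected shape:

```python
# Sketch: confirm ~/.torch/ILSVRC2012-100 contains 100 class symlinks per split.
from pathlib import Path

root = Path.home() / '.torch' / 'ILSVRC2012-100'
for sub in ('train', 'val'):
    dirs = [d for d in (root / sub).iterdir() if d.is_dir() or d.is_symlink()]
    assert len(dirs) == 100, '{:} has {:} entries'.format(root / sub, len(dirs))
print('ImageNet-100 symlink tree looks complete')
```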
@ -1,397 +0,0 @@
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from utils import AverageMeter, time_string, convert_secs2time
|
||||
from utils import print_log, obtain_accuracy
|
||||
from utils import Cutout, count_parameters_in_MB
|
||||
from nas import Network, NetworkACC2, NetworkV3, NetworkV4, NetworkV5, NetworkFACC1
|
||||
from nas import return_alphas_str
|
||||
from train_utils import main_procedure
|
||||
from scheduler import load_config
|
||||
|
||||
Networks = {'base': Network, 'acc2': NetworkACC2, 'facc1': NetworkFACC1, 'NetworkV3': NetworkV3, 'NetworkV4': NetworkV4, 'NetworkV5': NetworkV5}
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser("cifar")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--dataset', type=str, choices=['cifar10', 'cifar100'], help='Choose between Cifar10/100 and ImageNet.')
|
||||
parser.add_argument('--arch', type=str, choices=Networks.keys(), help='Choose networks.')
|
||||
parser.add_argument('--batch_size', type=int, help='the batch size')
|
||||
parser.add_argument('--learning_rate_max', type=float, help='initial learning rate')
|
||||
parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate')
|
||||
parser.add_argument('--tau_max', type=float, help='initial tau')
|
||||
parser.add_argument('--tau_min', type=float, help='minimum tau')
|
||||
parser.add_argument('--momentum', type=float, help='momentum')
|
||||
parser.add_argument('--weight_decay', type=float, help='weight decay')
|
||||
parser.add_argument('--epochs', type=int, help='num of training epochs')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
#
|
||||
parser.add_argument('--init_channels', type=int, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, help='total number of layers')
|
||||
#
|
||||
parser.add_argument('--cutout', type=int, help='cutout length, negative means no cutout')
|
||||
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
|
||||
parser.add_argument('--model_config', type=str , help='the model configuration')
|
||||
|
||||
# resume
|
||||
parser.add_argument('--resume', type=str , help='the resume path')
|
||||
parser.add_argument('--only_base',action='store_true', default=False, help='only train the searched model')
|
||||
# split data
|
||||
parser.add_argument('--validate', action='store_true', default=False, help='split train-data int train/val or not')
|
||||
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--manualSeed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert torch.cuda.is_available(), 'torch.cuda is not available'
|
||||
|
||||
if args.manualSeed is None:
|
||||
args.manualSeed = random.randint(1, 10000)
|
||||
random.seed(args.manualSeed)
|
||||
cudnn.benchmark = True
|
||||
cudnn.enabled = True
|
||||
torch.manual_seed(args.manualSeed)
|
||||
torch.cuda.manual_seed_all(args.manualSeed)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Init logger
|
||||
args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
|
||||
if not os.path.isdir(args.save_path):
|
||||
os.makedirs(args.save_path)
|
||||
log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
|
||||
print_log('save path : {}'.format(args.save_path), log)
|
||||
state = {k: v for k, v in args._get_kwargs()}
|
||||
print_log(state, log)
|
||||
print_log("Random Seed: {}".format(args.manualSeed), log)
|
||||
print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
|
||||
print_log("Torch version : {}".format(torch.__version__), log)
|
||||
print_log("CUDA version : {}".format(torch.version.cuda), log)
|
||||
print_log("cuDNN version : {}".format(cudnn.version()), log)
|
||||
print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
|
||||
args.dataset = args.dataset.lower()
|
||||
|
||||
# Mean + Std
|
||||
if args.dataset == 'cifar10':
|
||||
mean = [x / 255 for x in [125.3, 123.0, 113.9]]
|
||||
std = [x / 255 for x in [63.0, 62.1, 66.7]]
|
||||
elif args.dataset == 'cifar100':
|
||||
mean = [x / 255 for x in [129.3, 124.1, 112.4]]
|
||||
std = [x / 255 for x in [68.2, 65.4, 70.4]]
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(args.dataset))
|
||||
# Data Argumentation
|
||||
if args.dataset == 'cifar10' or args.dataset == 'cifar100':
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
|
||||
transforms.Normalize(mean, std)]
|
||||
if args.cutout > 0 : lists += [Cutout(args.cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(args.dataset))
|
||||
# Datasets
|
||||
if args.dataset == 'cifar10':
|
||||
train_data = dset.CIFAR10(args.data_path, train= True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform , download=True)
|
||||
num_classes = 10
|
||||
elif args.dataset == 'cifar100':
|
||||
train_data = dset.CIFAR100(args.data_path, train= True, transform=train_transform, download=True)
|
||||
test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform , download=True)
|
||||
num_classes = 100
|
||||
else:
|
||||
raise TypeError("Unknow dataset : {:}".format(args.dataset))
|
||||
# Data Loader
|
||||
if args.validate:
|
||||
indices = list(range(len(train_data)))
|
||||
split = int(args.train_portion * len(indices))
|
||||
random.shuffle(indices)
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
|
||||
pin_memory=True, num_workers=args.workers)
|
||||
test_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size,
|
||||
sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:]),
|
||||
pin_memory=True, num_workers=args.workers)
|
||||
else:
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
|
||||
test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)
|
||||
|
||||
# network and criterion
|
||||
criterion = torch.nn.CrossEntropyLoss().cuda()
|
||||
basemodel = Networks[args.arch](args.init_channels, num_classes, args.layers)
|
||||
model = torch.nn.DataParallel(basemodel).cuda()
|
||||
print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log)
|
||||
print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log)
|
||||
|
||||
# optimizer and LR-scheduler
|
||||
base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay)
|
||||
#base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay)
|
||||
base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
|
||||
arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
|
||||
|
||||
# snapshot
|
||||
checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth')
|
||||
if args.resume is not None and os.path.isfile(args.resume):
|
||||
checkpoint = torch.load(args.resume)
|
||||
start_epoch = checkpoint['epoch']
|
||||
basemodel.load_state_dict( checkpoint['state_dict'] )
|
||||
base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
|
||||
arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
|
||||
base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
|
||||
genotypes = checkpoint['genotypes']
|
||||
print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log)
|
||||
elif os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
start_epoch = checkpoint['epoch']
|
||||
basemodel.load_state_dict( checkpoint['state_dict'] )
|
||||
base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
|
||||
arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
|
||||
base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
|
||||
genotypes = checkpoint['genotypes']
|
||||
print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
|
||||
else:
|
||||
start_epoch, genotypes = 0, {}
|
||||
print_log('Train model-search from scratch.', log)
|
||||
|
||||
config = load_config(args.model_config)
|
||||
|
||||
if args.only_base:
|
||||
print_log('---- Only Train the Searched Model ----', log)
|
||||
main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
|
||||
return
|
||||
|
||||
# Main loop
|
||||
start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
base_scheduler.step()
|
||||
|
||||
basemodel.set_tau( args.tau_max - epoch*1.0/args.epochs*(args.tau_max-args.tau_min) )
|
||||
#if epoch + 1 == args.epochs:
|
||||
# torch.cuda.empty_cache()
|
||||
# basemodel.set_gumbel(False)
|
||||
|
||||
need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
|
||||
print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}] [Batch={:d}], tau={:}'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr()), args.batch_size, basemodel.get_tau()), log)
|
||||
|
||||
genotype = basemodel.genotype()
|
||||
print_log('genotype = {:}'.format(genotype), log)
|
||||
|
||||
print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log)
|
||||
|
||||
# training
|
||||
if epoch + 1 == args.epochs:
|
||||
train_acc1, train_acc5, train_obj, train_time \
|
||||
= train_joint(train_loader, test_loader, model, criterion, base_optimizer, arch_optimizer, epoch, log)
|
||||
total_train_time += train_time
|
||||
else:
|
||||
train_acc1, train_acc5, train_obj, train_time \
|
||||
= train_base(train_loader, None, model, criterion, base_optimizer, None, epoch, log)
|
||||
total_train_time += train_time
|
||||
Arch__acc1, Arch__acc5, Arch__obj, train_time \
|
||||
= train_arch(None , test_loader, model, criterion, None, arch_optimizer, epoch, log)
|
||||
total_train_time += train_time
|
||||
# validation
|
||||
valid_acc1, valid_acc5, valid_obj = infer(test_loader, model, criterion, epoch, log)
|
||||
print_log('{:03d}/{:03d}, Train-Accuracy = {:.2f}, Arch-Accuracy = {:.2f}, Test-Accuracy = {:.2f}'.format(epoch, args.epochs, train_acc1, Arch__acc1, valid_acc1), log)
|
||||
|
||||
# save genotype
|
||||
genotypes[epoch] = basemodel.genotype()
|
||||
# save checkpoint
|
||||
torch.save({'epoch' : epoch + 1,
|
||||
'args' : deepcopy(args),
|
||||
'state_dict': basemodel.state_dict(),
|
||||
'genotypes' : genotypes,
|
||||
'base_optimizer' : base_optimizer.state_dict(),
|
||||
'arch_optimizer' : arch_optimizer.state_dict(),
|
||||
'base_scheduler' : base_scheduler.state_dict()},
|
||||
checkpoint_path)
|
||||
print_log('----> Save into {:}'.format(checkpoint_path), log)
|
||||
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)
|
||||
|
||||
# clear GPU cache
|
||||
torch.cuda.empty_cache()
|
||||
main_procedure(config, args.dataset, args.data_path, args, basemodel.genotype(), 36, 20, log)
|
||||
log.close()
|
||||
|
||||
|
||||
def train_base(train_queue, _, model, criterion, base_optimizer, __, epoch, log):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
model.train()
|
||||
|
||||
end = time.time()
|
||||
for step, (inputs, targets) in enumerate(train_queue):
|
||||
batch, C, H, W = inputs.size()
|
||||
|
||||
#inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the parameters
|
||||
base_optimizer.zero_grad()
|
||||
logits = model(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
|
||||
base_optimizer.step()
|
||||
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
objs.update(loss.item() , batch)
|
||||
top1.update(prec1.item(), batch)
|
||||
top5.update(prec5.item(), batch)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % args.print_freq == 0 or (step+1) == len(train_queue):
|
||||
Sstr = ' TRAIN-BASE ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
|
||||
print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)
|
||||
|
||||
return top1.avg, top5.avg, objs.avg, batch_time.sum
|
||||
|
||||
|
||||
def train_arch(_, valid_queue, model, criterion, __, arch_optimizer, epoch, log):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
model.train()
|
||||
|
||||
end = time.time()
|
||||
for step, (inputs, targets) in enumerate(valid_queue):
|
||||
batch, C, H, W = inputs.size()
|
||||
|
||||
#inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
outputs = model(inputs)
|
||||
arch_loss = criterion(outputs, targets)
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
prec1, prec5 = obtain_accuracy(outputs.data, targets.data, topk=(1, 5))
|
||||
objs.update(arch_loss.item() , batch)
|
||||
top1.update(prec1.item(), batch)
|
||||
top5.update(prec5.item(), batch)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % args.print_freq == 0 or (step+1) == len(valid_queue):
|
||||
Sstr = ' TRAIN-ARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
|
||||
print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)
|
||||
|
||||
return top1.avg, top5.avg, objs.avg, batch_time.sum
|
||||
|
||||
|
||||
def train_joint(train_queue, valid_queue, model, criterion, base_optimizer, arch_optimizer, epoch, log):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
model.train()
|
||||
|
||||
valid_iter = iter(valid_queue)
|
||||
end = time.time()
|
||||
for step, (inputs, targets) in enumerate(train_queue):
|
||||
batch, C, H, W = inputs.size()
|
||||
|
||||
#inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# get a random minibatch from the search queue with replacement
|
||||
try:
|
||||
input_search, target_search = next(valid_iter)
|
||||
except:
|
||||
valid_iter = iter(valid_queue)
|
||||
input_search, target_search = next(valid_iter)
|
||||
|
||||
target_search = target_search.cuda(non_blocking=True)
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
output_search = model(input_search)
|
||||
arch_loss = criterion(output_search, target_search)
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# update the parameters
|
||||
base_optimizer.zero_grad()
|
||||
logits = model(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
|
||||
base_optimizer.step()
|
||||
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
objs.update(loss.item() , batch)
|
||||
top1.update(prec1.item(), batch)
|
||||
top5.update(prec5.item(), batch)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % args.print_freq == 0 or (step+1) == len(train_queue):
|
||||
Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
|
||||
print_log(Sstr + ' ' + Tstr + ' ' + Lstr, log)
|
||||
|
||||
return top1.avg, top5.avg, objs.avg, batch_time.sum
|
||||
|
||||
|
||||
def infer(valid_queue, model, criterion, epoch, log):
|
||||
objs, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for step, (inputs, targets) in enumerate(valid_queue):
|
||||
batch, C, H, W = inputs.size()
|
||||
targets = targets.cuda(non_blocking=True)
|
||||
|
||||
logits = model(inputs)
|
||||
loss = criterion(logits, targets)
|
||||
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
objs.update(loss.item() , batch)
|
||||
top1.update(prec1.item(), batch)
|
||||
top5.update(prec5.item(), batch)
|
||||
|
||||
if step % args.print_freq == 0 or (step+1) == len(valid_queue):
|
||||
Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(loss=objs, top1=top1, top5=top5)
|
||||
print_log(Sstr + ' ' + Lstr, log)
|
||||
|
||||
return top1.avg, top5.avg, objs.avg
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
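The search script above anneals the Gumbel-softmax temperature linearly from `--tau_max` to `--tau_min` via `basemodel.set_tau(...)`. A tiny standalone illustration of that schedule (tau_max, tau_min, and epochs are made-up values):

```python
# Sketch of the linear tau schedule used above: tau_max -> tau_min over all epochs.
tau_max, tau_min, epochs = 10.0, 4.0, 5   # illustrative values only
for epoch in range(epochs):
    tau = tau_max - epoch * 1.0 / epochs * (tau_max - tau_min)
    print('epoch {:d}: tau = {:.2f}'.format(epoch, tau))
# epoch 0: tau = 10.00 ... epoch 4: tau = 5.20 (tau_min is approached, not reached)
```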
@ -1,312 +0,0 @@
|
||||
import os, sys, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from datasets import TieredImageNet, MetaBatchSampler
|
||||
from utils import AverageMeter, time_string, convert_secs2time
|
||||
from utils import print_log, obtain_accuracy
|
||||
from utils import Cutout, count_parameters_in_MB
|
||||
from meta_nas import return_alphas_str, MetaNetwork
|
||||
from train_utils import main_procedure
|
||||
from scheduler import load_config
|
||||
|
||||
Networks = {'meta': MetaNetwork}
|
||||
|
||||
parser = argparse.ArgumentParser("cifar")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--arch', type=str, choices=Networks.keys(), help='Choose networks.')
|
||||
parser.add_argument('--n_way', type=int, help='N-WAY.')
|
||||
parser.add_argument('--k_shot', type=int, help='K-SHOT.')
|
||||
# Learning Parameters
|
||||
parser.add_argument('--learning_rate_max', type=float, help='initial learning rate')
|
||||
parser.add_argument('--learning_rate_min', type=float, help='minimum learning rate')
|
||||
parser.add_argument('--momentum', type=float, help='momentum')
|
||||
parser.add_argument('--weight_decay', type=float, help='weight decay')
|
||||
parser.add_argument('--epochs', type=int, help='num of training epochs')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
#
|
||||
parser.add_argument('--init_channels', type=int, help='num of init channels')
|
||||
parser.add_argument('--layers', type=int, help='total number of layers')
|
||||
#
|
||||
parser.add_argument('--cutout', type=int, help='cutout length, negative means no cutout')
|
||||
parser.add_argument('--grad_clip', type=float, help='gradient clipping')
|
||||
parser.add_argument('--model_config', type=str , help='the model configuration')
|
||||
|
||||
# resume
|
||||
parser.add_argument('--resume', type=str , help='the resume path')
|
||||
parser.add_argument('--only_base',action='store_true', default=False, help='only train the searched model')
|
||||
# split data
|
||||
parser.add_argument('--validate', action='store_true', default=False, help='split train-data int train/val or not')
|
||||
parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
|
||||
# log
|
||||
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers (default: 2)')
|
||||
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--manualSeed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert torch.cuda.is_available(), 'torch.cuda is not available'
|
||||
|
||||
if args.manualSeed is None:
|
||||
args.manualSeed = random.randint(1, 10000)
|
||||
random.seed(args.manualSeed)
|
||||
cudnn.benchmark = True
|
||||
cudnn.enabled = True
|
||||
torch.manual_seed(args.manualSeed)
|
||||
torch.cuda.manual_seed_all(args.manualSeed)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Init logger
|
||||
args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
|
||||
if not os.path.isdir(args.save_path):
|
||||
os.makedirs(args.save_path)
|
||||
log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
|
||||
print_log('save path : {}'.format(args.save_path), log)
|
||||
state = {k: v for k, v in args._get_kwargs()}
|
||||
print_log(state, log)
|
||||
print_log("Random Seed: {}".format(args.manualSeed), log)
|
||||
print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
|
||||
print_log("Torch version : {}".format(torch.__version__), log)
|
||||
print_log("CUDA version : {}".format(torch.version.cuda), log)
|
||||
print_log("cuDNN version : {}".format(cudnn.version()), log)
|
||||
print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
|
||||
|
||||
# Mean + Std
|
||||
means, stds = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
|
||||
# Data Argumentation
|
||||
lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(),
|
||||
transforms.Normalize(means, stds)]
|
||||
if args.cutout > 0 : lists += [Cutout(args.cutout)]
|
||||
train_transform = transforms.Compose(lists)
|
||||
test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(means, stds)])
|
||||
|
||||
train_data = TieredImageNet(args.data_path, 'train', train_transform)
|
||||
test_data = TieredImageNet(args.data_path, 'val' , test_transform )
|
||||
|
||||
train_sampler = MetaBatchSampler(train_data.labels, args.n_way, args.k_shot * 2, len(train_data) // (args.n_way*args.k_shot))
|
||||
test_sampler = MetaBatchSampler( test_data.labels, args.n_way, args.k_shot * 2, len( test_data) // (args.n_way*args.k_shot))
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(train_data, batch_sampler=train_sampler)
|
||||
test_loader = torch.utils.data.DataLoader( test_data, batch_sampler= test_sampler)
|
||||
|
||||
# network
|
||||
basemodel = Networks[args.arch](args.init_channels, args.layers, head='imagenet')
|
||||
model = torch.nn.DataParallel(basemodel).cuda()
|
||||
print_log("Parameter size = {:.3f} MB".format(count_parameters_in_MB(basemodel.base_parameters())), log)
|
||||
print_log("Train-transformation : {:}\nTest--transformation : {:}".format(train_transform, test_transform), log)
|
||||
|
||||
# optimizer and LR-scheduler
|
||||
#base_optimizer = torch.optim.SGD (basemodel.base_parameters(), args.learning_rate_max, momentum=args.momentum, weight_decay=args.weight_decay)
|
||||
base_optimizer = torch.optim.Adam(basemodel.base_parameters(), lr=args.learning_rate_max, betas=(0.5, 0.999), weight_decay=args.weight_decay)
|
||||
base_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(base_optimizer, float(args.epochs), eta_min=args.learning_rate_min)
|
||||
arch_optimizer = torch.optim.Adam(basemodel.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
|
||||
|
||||
# snapshot
|
||||
checkpoint_path = os.path.join(args.save_path, 'checkpoint-meta-search.pth')
|
||||
if args.resume is not None and os.path.isfile(args.resume):
|
||||
checkpoint = torch.load(args.resume)
|
||||
start_epoch = checkpoint['epoch']
|
||||
basemodel.load_state_dict( checkpoint['state_dict'] )
|
||||
base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
|
||||
arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
|
||||
base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
|
||||
genotypes = checkpoint['genotypes']
|
||||
print_log('Load resume from {:} with start-epoch = {:}'.format(args.resume, start_epoch), log)
|
||||
elif os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
start_epoch = checkpoint['epoch']
|
||||
basemodel.load_state_dict( checkpoint['state_dict'] )
|
||||
base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
|
||||
arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
|
||||
base_scheduler.load_state_dict( checkpoint['base_scheduler'] )
|
||||
genotypes = checkpoint['genotypes']
|
||||
print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
|
||||
else:
|
||||
start_epoch, genotypes = 0, {}
|
||||
print_log('Train model-search from scratch.', log)
|
||||
|
||||
config = load_config(args.model_config)
|
||||
|
||||
if args.only_base:
|
||||
print_log('---- Only Train the Searched Model ----', log)
|
||||
CIFAR_DATA_DIR = os.environ['TORCH_HOME'] + '/cifar.python'
|
||||
main_procedure(config, 'cifar10', CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log)
|
||||
return
|
||||
|
||||
# Main loop
|
||||
start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
base_scheduler.step()
|
||||
|
||||
need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
|
||||
print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} [LR={:6.4f} ~ {:6.4f}]'.format(time_string(), epoch, args.epochs, need_time, min(base_scheduler.get_lr()), max(base_scheduler.get_lr())), log)
|
||||
|
||||
genotype = basemodel.genotype()
|
||||
print_log('genotype = {:}'.format(genotype), log)
|
||||
print_log('{:03d}/{:03d} alphas :\n{:}'.format(epoch, args.epochs, return_alphas_str(basemodel)), log)
|
||||
|
||||
# training
|
||||
train_acc1, train_obj, train_time \
|
||||
= train(train_loader, test_loader, model, args.n_way, base_optimizer, arch_optimizer, epoch, log)
|
||||
total_train_time += train_time
|
||||
# validation
|
||||
valid_acc1, valid_obj = infer(test_loader, model, epoch, args.n_way, log)
|
||||
|
||||
print_log('META -> {:}-way {:}-shot : {:03d}/{:03d} : Train Acc : {:.2f}, Test Acc : {:.2f}'.format(args.n_way, args.k_shot, epoch, args.epochs, train_acc1, valid_acc1), log)
|
||||
# save genotype
|
||||
genotypes[epoch] = basemodel.genotype()
|
||||
|
||||
# save checkpoint
|
||||
torch.save({'epoch' : epoch + 1,
|
||||
'args' : deepcopy(args),
|
||||
'state_dict': basemodel.state_dict(),
|
||||
'genotypes' : genotypes,
|
||||
'base_optimizer' : base_optimizer.state_dict(),
|
||||
'arch_optimizer' : arch_optimizer.state_dict(),
|
||||
'base_scheduler' : base_scheduler.state_dict()},
|
||||
checkpoint_path)
|
||||
print_log('----> Save into {:}'.format(checkpoint_path), log)
|
||||
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)
|
||||
|
||||
# clear GPU cache
|
||||
CIFAR_DATA_DIR = os.environ['TORCH_HOME'] + '/cifar.python'
|
||||
print_log('test for CIFAR-10', log)
|
||||
torch.cuda.empty_cache()
|
||||
main_procedure(config, 'cifar10' , CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log)
|
||||
print_log('test for CIFAR-100', log)
|
||||
torch.cuda.empty_cache()
|
||||
main_procedure(config, 'cifar100', CIFAR_DATA_DIR, args, basemodel.genotype(), 36, 20, log)
|
||||
log.close()
|
||||
|
||||
|
||||
|
||||
def euclidean_dist(A, B):
|
||||
na, da = A.size()
|
||||
nb, db = B.size()
|
||||
assert da == db, 'invalid feature dim : {:} vs. {:}'.format(da, db)
|
||||
X, Y = A.view(na, 1, da), B.view(1, nb, db)
|
||||
return torch.pow(X-Y, 2).sum(2)
|
||||
|
||||
|
||||
|
||||
def get_loss(features, targets, n_way):
|
||||
classes = torch.unique(targets)
|
||||
shot = features.size(0) // n_way // 2
|
||||
|
||||
support_index, query_index, labels = [], [], []
|
||||
for idx, cls in enumerate( classes.tolist() ):
|
||||
indexs = (targets == cls).nonzero().view(-1).tolist()
|
||||
support_index.append(indexs[:shot])
|
||||
query_index += indexs[shot:]
|
||||
labels += [idx] * shot
|
||||
query_features = features[query_index, :]
|
||||
support_features = features[support_index, :]
|
||||
support_features = torch.mean(support_features, dim=1)
|
||||
|
||||
labels = torch.LongTensor(labels).cuda(non_blocking=True)
|
||||
logits = -euclidean_dist(query_features, support_features)
|
||||
loss = F.cross_entropy(logits, labels)
|
||||
accuracy = obtain_accuracy(logits.data, labels.data, topk=(1,))[0]
|
||||
return loss, accuracy
|
||||
|
||||
|
||||
|
||||
def train(train_queue, valid_queue, model, n_way, base_optimizer, arch_optimizer, epoch, log):
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
objs, accuracies = AverageMeter(), AverageMeter()
|
||||
model.train()
|
||||
|
||||
valid_iter = iter(valid_queue)
|
||||
end = time.time()
|
||||
for step, (inputs, targets) in enumerate(train_queue):
|
||||
batch, C, H, W = inputs.size()
|
||||
|
||||
#inputs, targets = inputs.cuda(), targets.cuda(non_blocking=True)
|
||||
#targets = targets.cuda(non_blocking=True)
|
||||
data_time.update(time.time() - end)
|
||||
|
||||
# get a random minibatch from the search queue with replacement
|
||||
try:
|
||||
input_search, target_search = next(valid_iter)
|
||||
except:
|
||||
valid_iter = iter(valid_queue)
|
||||
input_search, target_search = next(valid_iter)
|
||||
|
||||
#target_search = target_search.cuda(non_blocking=True)
|
||||
|
||||
# update the architecture
|
||||
arch_optimizer.zero_grad()
|
||||
feature_search = model(input_search)
|
||||
arch_loss, arch_accuracy = get_loss(feature_search, target_search, n_way)
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# update the parameters
|
||||
base_optimizer.zero_grad()
|
||||
feature_model = model(inputs)
|
||||
model_loss, model_accuracy = get_loss(feature_model, targets, n_way)
|
||||
|
||||
model_loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.module.base_parameters(), args.grad_clip)
|
||||
base_optimizer.step()
|
||||
|
||||
objs.update(model_loss.item() , batch)
|
||||
accuracies.update(model_accuracy.item(), batch)
|
||||
|
||||
# measure elapsed time
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if step % args.print_freq == 0 or (step+1) == len(train_queue):
|
||||
Sstr = ' TRAIN-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(train_queue))
|
||||
Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f})'.format(loss=objs, top1=accuracies)
|
||||
Istr = 'I : {:}'.format( list(inputs.size()) )
|
||||
print_log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr, log)
|
||||
|
||||
return accuracies.avg, objs.avg, batch_time.sum
|
||||
|
||||
|
||||
|
||||
def infer(valid_queue, model, epoch, n_way, log):
|
||||
objs, accuracies = AverageMeter(), AverageMeter()
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for step, (inputs, targets) in enumerate(valid_queue):
|
||||
batch, C, H, W = inputs.size()
|
||||
#targets = targets.cuda(non_blocking=True)
|
||||
|
||||
features = model(inputs)
|
||||
loss, accuracy = get_loss(features, targets, n_way)
|
||||
|
||||
objs.update(loss.item() , batch)
|
||||
accuracies.update(accuracy.item(), batch)
|
||||
|
||||
if step % (args.print_freq*4) == 0 or (step+1) == len(valid_queue):
|
||||
Sstr = ' VALID-SEARCH ' + time_string() + ' Epoch: [{:03d}][{:03d}/{:03d}]'.format(epoch, step, len(valid_queue))
|
||||
Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f}) Prec@1 {top1.val:.2f} ({top1.avg:.2f})'.format(loss=objs, top1=accuracies)
|
||||
print_log(Sstr + ' ' + Lstr, log)
|
||||
|
||||
return accuracies.avg, objs.avg
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
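The `euclidean_dist`/`get_loss` pair above implements a prototypical-network style episode loss: class prototypes are the mean of the support features, and queries are classified by negative squared distance to each prototype. A toy usage sketch on random features (CPU-only analogue, without the `.cuda()` call):

```python
# Sketch: episode loss on random features, mirroring get_loss() above.
import torch
import torch.nn.functional as F

n_way, k_shot, dim = 5, 3, 16
features = torch.randn(n_way * 2 * k_shot, dim)                                 # 2*k_shot samples per class
support  = features.view(n_way, 2 * k_shot, dim)[:, :k_shot].mean(dim=1)        # class prototypes
queries  = features.view(n_way, 2 * k_shot, dim)[:, k_shot:].reshape(-1, dim)   # query samples
labels   = torch.arange(n_way).repeat_interleave(k_shot)
logits   = -((queries[:, None, :] - support[None, :, :]) ** 2).sum(-1)          # negative squared distance
print(F.cross_entropy(logits, labels))
```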
@ -1,276 +0,0 @@
|
||||
import os, gc, sys, math, time, glob, random, argparse
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision.datasets as dset
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torchvision.transforms as transforms
|
||||
from pathlib import Path
|
||||
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
|
||||
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
||||
from utils import AverageMeter, time_string, convert_secs2time
|
||||
from utils import print_log, obtain_accuracy
|
||||
from utils import count_parameters_in_MB
|
||||
from datasets import Corpus
|
||||
from nas_rnn import batchify, get_batch, repackage_hidden
|
||||
from nas_rnn import DARTSCellSearch, RNNModelSearch
|
||||
from train_rnn_utils import main_procedure
|
||||
from scheduler import load_config
|
||||
|
||||
parser = argparse.ArgumentParser("RNN")
|
||||
parser.add_argument('--data_path', type=str, help='Path to dataset')
|
||||
parser.add_argument('--emsize', type=int, default=300, help='size of word embeddings')
|
||||
parser.add_argument('--nhid', type=int, default=300, help='number of hidden units per layer')
|
||||
parser.add_argument('--nhidlast', type=int, default=300, help='number of hidden units for the last rnn layer')
|
||||
parser.add_argument('--clip', type=float, default=0.25, help='gradient clipping')
|
||||
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
|
||||
parser.add_argument('--batch_size', type=int, default=256, help='the batch size')
|
||||
parser.add_argument('--eval_batch_size', type=int, default=10, help='the evaluation batch size')
|
||||
parser.add_argument('--bptt', type=int, default=35, help='the sequence length')
|
||||
# DropOut
|
||||
parser.add_argument('--dropout', type=float, default=0.75, help='dropout applied to layers (0 = no dropout)')
|
||||
parser.add_argument('--dropouth', type=float, default=0.25, help='dropout for hidden nodes in rnn layers (0 = no dropout)')
|
||||
parser.add_argument('--dropoutx', type=float, default=0.75, help='dropout for input nodes in rnn layers (0 = no dropout)')
|
||||
parser.add_argument('--dropouti', type=float, default=0.2, help='dropout for input embedding layers (0 = no dropout)')
|
||||
parser.add_argument('--dropoute', type=float, default=0, help='dropout to remove words from embedding layer (0 = no dropout)')
|
||||
# Regularization
|
||||
parser.add_argument('--lr', type=float, default=20, help='initial learning rate')
|
||||
parser.add_argument('--alpha', type=float, default=0, help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)')
|
||||
parser.add_argument('--beta', type=float, default=1e-3, help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)')
|
||||
parser.add_argument('--wdecay', type=float, default=5e-7, help='weight decay applied to all weights')
|
||||
# architecture leraning rate
|
||||
parser.add_argument('--arch_lr', type=float, default=3e-3, help='learning rate for arch encoding')
|
||||
parser.add_argument('--arch_wdecay', type=float, default=1e-3, help='weight decay for arch encoding')
|
||||
parser.add_argument('--config_path', type=str, help='the training configure for the discovered model')
|
||||
# acceleration
|
||||
parser.add_argument('--tau_max', type=float, help='initial tau')
|
||||
parser.add_argument('--tau_min', type=float, help='minimum tau')
|
||||
# log
|
||||
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
|
||||
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
|
||||
parser.add_argument('--manualSeed', type=int, help='manual seed')
|
||||
args = parser.parse_args()
|
||||
|
||||
assert torch.cuda.is_available(), 'torch.cuda is not available'
|
||||
|
||||
if args.manualSeed is None:
|
||||
args.manualSeed = random.randint(1, 10000)
|
||||
if args.nhidlast < 0:
|
||||
args.nhidlast = args.emsize
|
||||
random.seed(args.manualSeed)
|
||||
cudnn.benchmark = True
|
||||
cudnn.enabled = True
|
||||
torch.manual_seed(args.manualSeed)
|
||||
torch.cuda.manual_seed_all(args.manualSeed)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
# Init logger
|
||||
args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
|
||||
if not os.path.isdir(args.save_path):
|
||||
os.makedirs(args.save_path)
|
||||
log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
|
||||
print_log('save path : {}'.format(args.save_path), log)
|
||||
state = {k: v for k, v in args._get_kwargs()}
|
||||
print_log(state, log)
|
||||
print_log("Random Seed: {}".format(args.manualSeed), log)
|
||||
print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
|
||||
print_log("Torch version : {}".format(torch.__version__), log)
|
||||
print_log("CUDA version : {}".format(torch.version.cuda), log)
|
||||
print_log("cuDNN version : {}".format(cudnn.version()), log)
|
||||
print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)
|
||||
|
||||
# Dataset
|
||||
corpus = Corpus(args.data_path)
|
||||
train_data = batchify(corpus.train, args.batch_size, True)
|
||||
search_data = batchify(corpus.valid, args.batch_size, True)
|
||||
valid_data = batchify(corpus.valid, args.eval_batch_size, True)
|
||||
print_log("Train--Data Size : {:}".format(train_data.size()), log)
|
||||
print_log("Search-Data Size : {:}".format(search_data.size()), log)
|
||||
print_log("Valid--Data Size : {:}".format(valid_data.size()), log)
|
||||
|
||||
ntokens = len(corpus.dictionary)
|
||||
model = RNNModelSearch(ntokens, args.emsize, args.nhid, args.nhidlast,
|
||||
args.dropout, args.dropouth, args.dropoutx, args.dropouti, args.dropoute,
|
||||
DARTSCellSearch, None)
|
||||
model = model.cuda()
|
||||
print_log('model ==>> : {:}'.format(model), log)
|
||||
print_log('Parameter size : {:} MB'.format(count_parameters_in_MB(model)), log)
|
||||
|
||||
base_optimizer = torch.optim.SGD(model.base_parameters(), lr=args.lr, weight_decay=args.wdecay)
|
||||
arch_optimizer = torch.optim.Adam(model.arch_parameters(), lr=args.arch_lr, weight_decay=args.arch_wdecay)
|
||||
|
||||
config = load_config(args.config_path)
|
||||
print_log('Load config from {:} ==>>\n {:}'.format(args.config_path, config), log)
|
||||
|
||||
# snapshot
|
||||
checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth')
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
start_epoch = checkpoint['epoch']
|
||||
model.load_state_dict( checkpoint['state_dict'] )
|
||||
base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
|
||||
arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
|
||||
genotypes = checkpoint['genotypes']
|
||||
valid_losses = checkpoint['valid_losses']
|
||||
print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
|
||||
else:
|
||||
start_epoch, genotypes, valid_losses = 0, {}, {-1:1e8}
|
||||
print_log('Train model-search from scratch.', log)
|
||||
|
||||
model.set_gumbel(True, False)
|
||||
|
||||
# Main loop
|
||||
start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
|
||||
model.set_tau( args.tau_max - epoch*1.0/args.epochs*(args.tau_max-args.tau_min) )
|
||||
need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
|
||||
print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} || tau={:}'.format(time_string(), epoch, args.epochs, need_time, model.get_tau()), log)
|
||||
|
||||
# training
|
||||
data_time, train_time = train(model, base_optimizer, arch_optimizer, corpus, train_data, search_data, epoch, log)
|
||||
total_train_time += train_time
|
||||
# evaluation
|
||||
|
||||
# validation
|
||||
valid_loss = infer(model, corpus, valid_data, args.eval_batch_size)
|
||||
# save genotype
|
||||
if valid_loss < min( valid_losses.values() ): is_best = True
|
||||
else : is_best = False
|
||||
print_log('-'*10 + ' [Epoch={:03d}/{:03d}] : is-best={:}, validation-loss={:}, validation-PPL={:}'.format(epoch, args.epochs, is_best, valid_loss, math.exp(valid_loss)), log)
|
||||
print_log('{:}'.format(F.softmax(model.arch_weights, dim=-1)), log)
|
||||
print_log('genotype : {:}'.format(model.genotype()), log)
|
||||
|
||||
valid_losses[epoch] = valid_loss
|
||||
genotypes[epoch] = model.genotype()
|
||||
print_log(' the {:}-th genotype = {:}'.format(epoch, genotypes[epoch]), log)
|
||||
# save checkpoint
|
||||
if is_best:
|
||||
genotypes['best'] = model.genotype()
|
||||
torch.save({'epoch' : epoch + 1,
|
||||
'args' : deepcopy(args),
|
||||
'state_dict': model.state_dict(),
|
||||
'genotypes' : genotypes,
|
||||
'valid_losses' : valid_losses,
|
||||
'base_optimizer' : base_optimizer.state_dict(),
|
||||
'arch_optimizer' : arch_optimizer.state_dict()},
|
||||
checkpoint_path)
|
||||
print_log('----> Save into {:}'.format(checkpoint_path), log)
|
||||
|
||||
|
||||
# measure elapsed time
|
||||
epoch_time.update(time.time() - start_time)
|
||||
start_time = time.time()
|
||||
|
||||
print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)
|
||||
|
||||
# clear GPU cache
|
||||
torch.cuda.empty_cache()
|
||||
main_procedure(config, genotypes['best'], args.save_path, args.print_freq, log)
|
||||
log.close()
|
||||
|
||||
|
||||
def train(model, base_optimizer, arch_optimizer, corpus, train_data, search_data, epoch, log):
|
||||
|
||||
data_time, batch_time = AverageMeter(), AverageMeter()
|
||||
# Turn on training mode which enables dropout.
|
||||
total_loss = 0
|
||||
start_time = time.time()
|
||||
ntokens = len(corpus.dictionary)
|
||||
hidden_train, hidden_valid = model.init_hidden(args.batch_size), model.init_hidden(args.batch_size)
|
||||
|
||||
batch, i = 0, 0
|
||||
|
||||
while i < train_data.size(0) - 1 - 1:
|
||||
seq_len = int( args.bptt if np.random.random() < 0.95 else args.bptt / 2. )
|
||||
# Prevent excessively small or negative sequence lengths
|
||||
# seq_len = max(5, int(np.random.normal(bptt, 5)))
|
||||
# # There's a very small chance that it could select a very long sequence length resulting in OOM
|
||||
# seq_len = min(seq_len, args.bptt + args.max_seq_len_delta)
|
||||
for param_group in base_optimizer.param_groups:
|
||||
param_group['lr'] *= float( seq_len / args.bptt )
|
||||
|
||||
model.train()
|
||||
|
||||
data_valid, targets_valid = get_batch(search_data, i % (search_data.size(0) - 1), args.bptt)
|
||||
data_train, targets_train = get_batch(train_data , i, seq_len)
|
||||
|
||||
hidden_train = repackage_hidden(hidden_train)
|
||||
hidden_valid = repackage_hidden(hidden_valid)
|
||||
|
||||
data_time.update(time.time() - start_time)
|
||||
|
||||
# validation loss
|
||||
targets_valid = targets_valid.contiguous().view(-1)
|
||||
|
||||
arch_optimizer.step()
|
||||
log_prob, hidden_valid = model(data_valid, hidden_valid, return_h=False)
|
||||
arch_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets_valid)
|
||||
arch_loss.backward()
|
||||
arch_optimizer.step()
|
||||
|
||||
# model update
|
||||
base_optimizer.zero_grad()
|
||||
targets_train = targets_train.contiguous().view(-1)
|
||||
|
||||
log_prob, hidden_train, rnn_hs, dropped_rnn_hs = model(data_train, hidden_train, return_h=True)
|
||||
raw_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets_train)
|
||||
|
||||
loss = raw_loss
|
||||
# Activiation Regularization
|
||||
if args.alpha > 0:
|
||||
loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
|
||||
# Temporal Activation Regularization (slowness)
|
||||
loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
    loss.backward()
    # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs.
    nn.utils.clip_grad_norm_(model.base_parameters(), args.clip)
    base_optimizer.step()

    for param_group in base_optimizer.param_groups:
      param_group['lr'] /= float( seq_len / args.bptt )

    total_loss += raw_loss.item()
    gc.collect()

    batch_time.update(time.time() - start_time)
    start_time = time.time()
    batch, i = batch + 1, i + seq_len

    if batch % args.print_freq == 0 or i >= train_data.size(0) - 1 - 1:
      print_log(' || Epoch: {:03d} :: {:03d}/{:03d}'.format(epoch, batch, len(train_data) // args.bptt), log)
      #print_log(' || Epoch: {:03d} :: {:03d}/{:03d} = {:}'.format(epoch, batch, len(train_data) // args.bptt, model.genotype()), log)
      cur_loss = total_loss / args.print_freq
      print_log(' [TRAIN] Time : data {:.3f} ({:.3f}) batch {:.3f} ({:.3f}) Loss : {:}, PPL : {:}'.format(data_time.val, data_time.avg, batch_time.val, batch_time.avg, cur_loss, math.exp(cur_loss)), log)
      #print(F.softmax(model.arch_weights, dim=-1))
      total_loss = 0

  return data_time.sum, batch_time.sum


def infer(model, corpus, data_source, batch_size):

  model.eval()
  with torch.no_grad():
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
      data, targets = get_batch(data_source, i, args.bptt)
      targets = targets.view(-1)

      log_prob, hidden = model(data, hidden)
      loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets)

      total_loss += loss.item() * len(data)

      hidden = repackage_hidden(hidden)
  return total_loss / len(data_source)


if __name__ == '__main__':
  main()
@ -1,75 +0,0 @@
import os, gc, sys, time, math
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import print_log, obtain_accuracy, AverageMeter
from utils import time_string, convert_secs2time
from utils import count_parameters_in_MB
from datasets import Corpus
from nas_rnn import batchify, get_batch, repackage_hidden
from nas_rnn import DARTS
from nas_rnn import DARTSCell, RNNModel
from nas_rnn import basemodel as model
from scheduler import load_config


def main_procedure(config, genotype, print_freq, log):

  print_log('-'*90, log)
  print_log('genotype : {:}'.format(genotype), log)
  print_log('config   : {:}'.format(config.bptt), log)

  corpus = Corpus(config.data_path)
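  # batchify (from nas_rnn) reshapes each token stream into a [num-steps, batch-size]
  # tensor for the given batch size; get_batch later slices bptt-step chunks from it.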
  train_data = batchify(corpus.train, config.train_batch, True)
  valid_data = batchify(corpus.valid, config.eval_batch , True)
  test_data  = batchify(corpus.test , config.test_batch , True)
  ntokens = len(corpus.dictionary)
  print_log("Train--Data Size : {:}".format(train_data.size()), log)
  print_log("Valid--Data Size : {:}".format(valid_data.size()), log)
  print_log("Test---Data Size : {:}".format( test_data.size()), log)
  print_log("ntokens = {:}".format(ntokens), log)

  model = RNNModel(ntokens, config.emsize, config.nhid, config.nhidlast,
                   config.dropout, config.dropouth, config.dropoutx, config.dropouti, config.dropoute,
                   cell_cls=DARTSCell, genotype=genotype)
  model = model.cuda()
  print_log('Network =>\n{:}'.format(model), log)
  print_log('Genotype : {:}'.format(genotype), log)
  print_log('Parameters : {:.3f} MB'.format(count_parameters_in_MB(model)), log)


  print_log('--------------------- Finish Training ----------------', log)
  test_loss = evaluate(model, corpus, test_data , config.test_batch, config.bptt)
  print_log('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)), log)
  vali_loss = evaluate(model, corpus, valid_data, config.eval_batch, config.bptt)
  print_log('| End of training | valid loss {:5.2f} | valid ppl {:8.2f}'.format(vali_loss, math.exp(vali_loss)), log)



def evaluate(model, corpus, data_source, batch_size, bptt):
  # Turn on evaluation mode which disables dropout.
  model.eval()
  total_loss, total_length = 0.0, 0.0
  with torch.no_grad():
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, bptt):
      data, targets = get_batch(data_source, i, bptt)
      targets = targets.view(-1)

      log_prob, hidden = model(data, hidden)
      loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets)

      total_loss += loss.item() * len(data)
      total_length += len(data)
      hidden = repackage_hidden(hidden)
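  # average NLL per token, weighted by chunk length; callers report perplexity as the exp of this value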
  return total_loss / total_length


if __name__ == '__main__':
  path = './configs/NAS-PTB-BASE.config'
  config = load_config(path)
  main_procedure(config, DARTS, 10, None)
@ -1,267 +0,0 @@
import os, gc, sys, math, time, glob, random, argparse
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from pathlib import Path
lib_dir = (Path(__file__).parent / '..' / '..' / 'lib').resolve()
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
from utils import AverageMeter, time_string, convert_secs2time
from utils import print_log, obtain_accuracy
from utils import count_parameters_in_MB
from datasets import Corpus
from nas_rnn import batchify, get_batch, repackage_hidden
from nas_rnn import DARTSCellSearch, RNNModelSearch
from train_rnn_utils import main_procedure
from scheduler import load_config

parser = argparse.ArgumentParser("RNN")
parser.add_argument('--data_path', type=str, help='Path to dataset')
parser.add_argument('--emsize', type=int, default=300, help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=300, help='number of hidden units per layer')
parser.add_argument('--nhidlast', type=int, default=300, help='number of hidden units for the last rnn layer')
parser.add_argument('--clip', type=float, default=0.25, help='gradient clipping')
parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
parser.add_argument('--batch_size', type=int, default=256, help='the batch size')
parser.add_argument('--eval_batch_size', type=int, default=10, help='the evaluation batch size')
parser.add_argument('--bptt', type=int, default=35, help='the sequence length')
# DropOut
parser.add_argument('--dropout', type=float, default=0.75, help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--dropouth', type=float, default=0.25, help='dropout for hidden nodes in rnn layers (0 = no dropout)')
parser.add_argument('--dropoutx', type=float, default=0.75, help='dropout for input nodes in rnn layers (0 = no dropout)')
parser.add_argument('--dropouti', type=float, default=0.2, help='dropout for input embedding layers (0 = no dropout)')
parser.add_argument('--dropoute', type=float, default=0, help='dropout to remove words from embedding layer (0 = no dropout)')
# Regularization
parser.add_argument('--lr', type=float, default=20, help='initial learning rate')
parser.add_argument('--alpha', type=float, default=0, help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)')
parser.add_argument('--beta', type=float, default=1e-3, help='beta slowness regularization applied on RNN activation (beta = 0 means no regularization)')
parser.add_argument('--wdecay', type=float, default=5e-7, help='weight decay applied to all weights')
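# NOTE: alpha and beta correspond to the activation regularization (AR) and temporal
# activation regularization (TAR) penalties used in AWD-LSTM-style language-model training;
# both terms are added to the training loss inside train() below.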
# architecture learning rate
parser.add_argument('--arch_lr', type=float, default=3e-3, help='learning rate for arch encoding')
parser.add_argument('--arch_wdecay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--config_path', type=str, help='the training configuration for the discovered model')
# log
parser.add_argument('--save_path', type=str, help='Folder to save checkpoints and log.')
parser.add_argument('--print_freq', type=int, help='print frequency (default: 200)')
parser.add_argument('--manualSeed', type=int, help='manual seed')
args = parser.parse_args()

assert torch.cuda.is_available(), 'torch.cuda is not available'

if args.manualSeed is None:
  args.manualSeed = random.randint(1, 10000)
if args.nhidlast < 0:
  args.nhidlast = args.emsize
random.seed(args.manualSeed)
cudnn.benchmark = True
cudnn.enabled = True
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)


def main():

  # Init logger
  args.save_path = os.path.join(args.save_path, 'seed-{:}'.format(args.manualSeed))
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log-seed-{:}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("Python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("Torch version : {}".format(torch.__version__), log)
  print_log("CUDA version : {}".format(torch.version.cuda), log)
  print_log("cuDNN version : {}".format(cudnn.version()), log)
  print_log("Num of GPUs : {}".format(torch.cuda.device_count()), log)

  # Dataset
  corpus = Corpus(args.data_path)
  train_data  = batchify(corpus.train, args.batch_size, True)
  search_data = batchify(corpus.valid, args.batch_size, True)
  valid_data  = batchify(corpus.valid, args.eval_batch_size, True)
  print_log("Train--Data Size : {:}".format(train_data.size()), log)
  print_log("Search-Data Size : {:}".format(search_data.size()), log)
  print_log("Valid--Data Size : {:}".format(valid_data.size()), log)

  ntokens = len(corpus.dictionary)
  model = RNNModelSearch(ntokens, args.emsize, args.nhid, args.nhidlast,
                         args.dropout, args.dropouth, args.dropoutx, args.dropouti, args.dropoute,
                         DARTSCellSearch, None)
  model = model.cuda()
  print_log('model ==>> : {:}'.format(model), log)
  print_log('Parameter size : {:} MB'.format(count_parameters_in_MB(model)), log)

  base_optimizer = torch.optim.SGD(model.base_parameters(), lr=args.lr, weight_decay=args.wdecay)
  arch_optimizer = torch.optim.Adam(model.arch_parameters(), lr=args.arch_lr, weight_decay=args.arch_wdecay)
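  # Two optimizers implement the alternating search: SGD updates the shared RNN weights on
  # training batches, while Adam updates the architecture parameters on held-out search batches.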

  config = load_config(args.config_path)
  print_log('Load config from {:} ==>>\n {:}'.format(args.config_path, config), log)

  # snapshot
  checkpoint_path = os.path.join(args.save_path, 'checkpoint-search.pth')
  if os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    start_epoch = checkpoint['epoch']
    model.load_state_dict( checkpoint['state_dict'] )
    base_optimizer.load_state_dict( checkpoint['base_optimizer'] )
    arch_optimizer.load_state_dict( checkpoint['arch_optimizer'] )
    genotypes = checkpoint['genotypes']
    valid_losses = checkpoint['valid_losses']
    print_log('Load checkpoint from {:} with start-epoch = {:}'.format(checkpoint_path, start_epoch), log)
  else:
    start_epoch, genotypes, valid_losses = 0, {}, {-1:1e8}
    print_log('Train model-search from scratch.', log)

  # Main loop
  start_time, epoch_time, total_train_time = time.time(), AverageMeter(), 0
  for epoch in range(start_epoch, args.epochs):

    need_time = convert_secs2time(epoch_time.val * (args.epochs-epoch), True)
    print_log('\n==>>{:s} [Epoch={:03d}/{:03d}] {:s}'.format(time_string(), epoch, args.epochs, need_time), log)
    # training
    data_time, train_time = train(model, base_optimizer, arch_optimizer, corpus, train_data, search_data, epoch, log)
    total_train_time += train_time
    # evaluation

    # validation
    valid_loss = infer(model, corpus, valid_data, args.eval_batch_size)
    # save genotype
    if valid_loss < min( valid_losses.values() ): is_best = True
    else                                        : is_best = False
    print_log('-'*10 + ' [Epoch={:03d}/{:03d}] : is-best={:}, validation-loss={:}, validation-PPL={:}'.format(epoch, args.epochs, is_best, valid_loss, math.exp(valid_loss)), log)

    valid_losses[epoch] = valid_loss
    genotypes[epoch] = model.genotype()
    print_log(' the {:}-th genotype = {:}'.format(epoch, genotypes[epoch]), log)
    # save checkpoint
    if is_best:
      genotypes['best'] = model.genotype()
    torch.save({'epoch' : epoch + 1,
                'args' : deepcopy(args),
                'state_dict': model.state_dict(),
                'genotypes' : genotypes,
                'valid_losses' : valid_losses,
                'base_optimizer' : base_optimizer.state_dict(),
                'arch_optimizer' : arch_optimizer.state_dict()},
               checkpoint_path)
    print_log('----> Save into {:}'.format(checkpoint_path), log)

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  print_log('Finish with training time = {:}'.format( convert_secs2time(total_train_time, True) ), log)

  # clear GPU cache
  torch.cuda.empty_cache()
  main_procedure(config, genotypes['best'], args.save_path, args.print_freq, log)
  log.close()


def train(model, base_optimizer, arch_optimizer, corpus, train_data, search_data, epoch, log):

  data_time, batch_time = AverageMeter(), AverageMeter()
  # Turn on training mode which enables dropout.
  total_loss = 0
  start_time = time.time()
  ntokens = len(corpus.dictionary)
  hidden_train, hidden_valid = model.init_hidden(args.batch_size), model.init_hidden(args.batch_size)

  batch, i = 0, 0

  while i < train_data.size(0) - 1 - 1:
    seq_len = int( args.bptt if np.random.random() < 0.95 else args.bptt / 2. )
    # Prevent excessively small or negative sequence lengths
    # seq_len = max(5, int(np.random.normal(bptt, 5)))
    # # There's a very small chance that it could select a very long sequence length resulting in OOM
    # seq_len = min(seq_len, args.bptt + args.max_seq_len_delta)
    for param_group in base_optimizer.param_groups:
      param_group['lr'] *= float( seq_len / args.bptt )

    model.train()

    data_valid, targets_valid = get_batch(search_data, i % (search_data.size(0) - 1), args.bptt)
    data_train, targets_train = get_batch(train_data , i, seq_len)

    hidden_train = repackage_hidden(hidden_train)
    hidden_valid = repackage_hidden(hidden_valid)

    data_time.update(time.time() - start_time)

    # validation loss
    targets_valid = targets_valid.contiguous().view(-1)

    arch_optimizer.zero_grad()
    log_prob, hidden_valid = model(data_valid, hidden_valid, return_h=False)
    arch_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets_valid)
    arch_loss.backward()
    arch_optimizer.step()

    # model update
    base_optimizer.zero_grad()
    targets_train = targets_train.contiguous().view(-1)

    log_prob, hidden_train, rnn_hs, dropped_rnn_hs = model(data_train, hidden_train, return_h=True)
    raw_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets_train)

    loss = raw_loss
    # Activation Regularization
    if args.alpha > 0:
      loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
    # Temporal Activation Regularization (slowness)
    loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
    loss.backward()
    # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs.
    nn.utils.clip_grad_norm_(model.base_parameters(), args.clip)
    base_optimizer.step()

    for param_group in base_optimizer.param_groups:
      param_group['lr'] /= float( seq_len / args.bptt )

    total_loss += raw_loss.item()
    gc.collect()

    batch_time.update(time.time() - start_time)
    start_time = time.time()
    batch, i = batch + 1, i + seq_len

    if batch % args.print_freq == 0 or i >= train_data.size(0) - 1 - 1:
      print_log(' || Epoch: {:03d} :: {:03d}/{:03d}'.format(epoch, batch, len(train_data) // args.bptt), log)
      #print_log(' || Epoch: {:03d} :: {:03d}/{:03d} = {:}'.format(epoch, batch, len(train_data) // args.bptt, model.genotype()), log)
      cur_loss = total_loss / args.print_freq
      print_log(' ---> Time : data {:.3f} ({:.3f}) batch {:.3f} ({:.3f}) Loss : {:}, PPL : {:}'.format(data_time.val, data_time.avg, batch_time.val, batch_time.avg, cur_loss, math.exp(cur_loss)), log)
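      # the softmax of the architecture weights gives the current mixing probabilities
      # over the candidate operations of the searched recurrent cell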
      print(F.softmax(model.arch_weights, dim=-1))
      total_loss = 0

  return data_time.sum, batch_time.sum


def infer(model, corpus, data_source, batch_size):

  model.eval()
  with torch.no_grad():
    total_loss, total_length = 0, 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
      data, targets = get_batch(data_source, i, args.bptt)
      targets = targets.view(-1)

      log_prob, hidden = model(data, hidden)
      loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets)

      total_loss += loss.item() * len(data)
      total_length += len(data)
      hidden = repackage_hidden(hidden)
  return total_loss / total_length


if __name__ == '__main__':
  main()
@ -1,12 +0,0 @@
# Search RNN cell
```
bash scripts-nas-rnn/search-baseline.sh 3
bash scripts-nas-rnn/search-accelerate.sh 0 200 10 1
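# arguments: search-baseline.sh <gpu>; search-accelerate.sh <gpu> <epochs> <tau-max> <tau-min>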
```

# Train the Searched Model
```
bash scripts-nas-rnn/train-PTB.sh 3 DARTS_V1
bash scripts-nas-rnn/train-WT2.sh 3 DARTS_V1
bash scripts-nas-rnn/train-PTB.sh 3 DARTS_V2
```
@ -1,26 +0,0 @@
#!/usr/bin/env sh
if [ "$#" -ne 4 ]; then
  echo "Illegal number of input parameters: $#"
  echo "Need 4 parameters: the GPU, the number of epochs, tau-max, and tau-min"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set the TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

gpus=$1
epoch=$2
tau_max=$3
tau_min=$4
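# tau_max / tau_min presumably set the temperature-annealing range for the Gumbel-softmax
# used by the accelerated (GDAS-style) search in acc_rnn_search.py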
SAVED=./snapshots/NAS-RNN/Search-Accelerate-tau_${tau_max}_${tau_min}-${epoch}

CUDA_VISIBLE_DEVICES=${gpus} python ./exps-nas/rnn/acc_rnn_search.py \
  --data_path ./data/data/penn \
  --save_path ${SAVED} \
  --epochs ${epoch} \
  --tau_max ${tau_max} --tau_min ${tau_min} \
  --config_path ./configs/NAS-PTB-BASE.config \
  --print_freq 200
@ -1,23 +0,0 @@
#!/usr/bin/env sh
if [ "$#" -ne 1 ]; then
  echo "Illegal number of input parameters: $#"
  echo "Need 1 parameter: the GPU"
  exit 1
fi
if [ "$TORCH_HOME" = "" ]; then
  echo "Must set the TORCH_HOME environment variable for data dir saving"
  exit 1
else
  echo "TORCH_HOME : $TORCH_HOME"
fi

gpus=$1
epoch=50
SAVED=./snapshots/NAS-RNN/Search-Baseline-${epoch}

CUDA_VISIBLE_DEVICES=${gpus} python ./exps-nas/rnn/train_rnn_search.py \
  --data_path ./data/data/penn \
  --save_path ${SAVED} \
  --epochs ${epoch} \
  --config_path ./configs/NAS-PTB-BASE.config \
  --print_freq 200