Refine lib -> xautodl
This commit is contained in:
parent
bda202ce87
commit
5b9a028e60
@ -7,8 +7,8 @@ In this paper, we proposed a differentiable searching strategy for transformable
|
||||
You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
|
||||
|
||||
<p float="left">
|
||||
<img src="https://d-x-y.github.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
|
||||
<img src="https://d-x-y.github.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
|
||||
<img src="http://xuanyidong.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
|
||||
<img src="http://xuanyidong.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
|
||||
</p>
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ We provide some logs at [Google Drive](https://drive.google.com/open?id=1_qUY4DT
|
||||
|
||||
## Usage
|
||||
|
||||
Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
||||
Use `bash ./scripts/TAS/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
||||
If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`.
|
||||
|
||||
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.
|
||||
|
@ -27,8 +27,8 @@ from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
|
||||
from xautodl.models.xcore import get_model
|
||||
from xautodl.xlayers import super_core, trunc_normal_
|
||||
|
||||
from xautodl.lfna_utils import lfna_setup, train_model, TimeData
|
||||
from xautodl.lfna_meta_model import LFNA_Meta
|
||||
from lfna_utils import lfna_setup, train_model, TimeData
|
||||
from lfna_meta_model import LFNA_Meta
|
||||
|
||||
|
||||
def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger):
|
||||
|
@ -4,8 +4,8 @@
|
||||
import copy
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from procedures import prepare_seed, prepare_logger
|
||||
from datasets.synthetic_core import get_synthetic_env
|
||||
from xautodl.procedures import prepare_seed, prepare_logger
|
||||
from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||
|
||||
|
||||
def lfna_setup(args):
|
||||
|
@ -665,7 +665,7 @@ if __name__ == "__main__":
|
||||
len(args.datasets), len(args.xpaths), len(args.splits)
|
||||
)
|
||||
)
|
||||
if args.workers <= 0:
|
||||
if args.workers < 0:
|
||||
raise ValueError("invalid number of workers : {:}".format(args.workers))
|
||||
|
||||
target_indexes = filter_indexes(
|
||||
@ -675,7 +675,7 @@ if __name__ == "__main__":
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
torch.set_num_threads(args.workers if args.workers > 0 else 1)
|
||||
|
||||
main(
|
||||
save_dir,
|
||||
|
@ -1,6 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
||||
#####################################################
|
||||
# python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
|
||||
# python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
|
||||
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth
|
||||
#####################################################
|
||||
import sys, time, torch, random, argparse
|
||||
from collections import defaultdict
|
||||
import os.path as osp
|
||||
@ -12,9 +16,6 @@ from pathlib import Path
|
||||
import torchvision
|
||||
import torchvision.datasets as dset
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Prepare splits for searching",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
@ -35,9 +36,9 @@ def main():
|
||||
print("torchvision version : {:}".format(torchvision.__version__))
|
||||
|
||||
if name == "cifar10":
|
||||
dataset = dset.CIFAR10(args.root, train=True)
|
||||
dataset = dset.CIFAR10(args.root, train=True, download=True)
|
||||
elif name == "cifar100":
|
||||
dataset = dset.CIFAR100(args.root, train=True)
|
||||
dataset = dset.CIFAR100(args.root, train=True, download=True)
|
||||
elif name == "imagenet-1k":
|
||||
dataset = dset.ImageFolder(osp.join(args.root, "train"))
|
||||
else:
|
13
scripts/TAS/prepare.sh
Normal file
13
scripts/TAS/prepare.sh
Normal file
@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
# bash ./scripts/TAS/prepare.sh
|
||||
#datasets='cifar10 cifar100 imagenet-1k'
|
||||
#ratios='0.5 0.8 0.9'
|
||||
ratios='0.5'
|
||||
save_dir=./.latent-data/splits
|
||||
|
||||
for ratio in ${ratios}
|
||||
do
|
||||
python ./exps/TAS/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio}
|
||||
python ./exps/TAS/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio}
|
||||
python ./exps/TAS/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio}
|
||||
done
|
@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
# bash ./scripts/prepare.sh
|
||||
#datasets='cifar10 cifar100 imagenet-1k'
|
||||
#ratios='0.5 0.8 0.9'
|
||||
ratios='0.5'
|
||||
save_dir=./.latent-data/splits
|
||||
|
||||
for ratio in ${ratios}
|
||||
do
|
||||
python ./exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio}
|
||||
python ./exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio}
|
||||
python ./exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio}
|
||||
done
|
@ -6,7 +6,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from models.cell_operations import OPS
|
||||
from xautodl.models.cell_operations import OPS
|
||||
|
||||
|
||||
# Cell for NAS-Bench-201
|
||||
|
@ -4,6 +4,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from copy import deepcopy
|
||||
|
||||
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
|
||||
|
||||
|
||||
|
@ -9,11 +9,11 @@ import torch
|
||||
__all__ = ["get_model"]
|
||||
|
||||
|
||||
from xlayers.super_core import SuperSequential
|
||||
from xlayers.super_core import SuperLinear
|
||||
from xlayers.super_core import SuperDropout
|
||||
from xlayers.super_core import super_name2norm
|
||||
from xlayers.super_core import super_name2activation
|
||||
from xautodl.xlayers.super_core import SuperSequential
|
||||
from xautodl.xlayers.super_core import SuperLinear
|
||||
from xautodl.xlayers.super_core import SuperDropout
|
||||
from xautodl.xlayers.super_core import super_name2norm
|
||||
from xautodl.xlayers.super_core import super_name2activation
|
||||
|
||||
|
||||
def get_model(config: Dict[Text, Any], **kwargs):
|
||||
|
@ -7,8 +7,7 @@ import os, sys, time, torch
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@ -4,8 +4,7 @@
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
from log_utils import time_string
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@ -15,6 +15,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
||||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
@ -4,7 +4,7 @@
|
||||
import os, time, copy, torch, pathlib
|
||||
|
||||
# modules in AutoDL
|
||||
import xautodl.datasets
|
||||
from xautodl import datasets
|
||||
from xautodl.config_utils import load_config
|
||||
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
@ -8,7 +8,6 @@ import pprint
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
from log_utils import pickle_load
|
||||
import qlib
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
|
@ -2,8 +2,9 @@
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import os, sys, time, torch
|
||||
from log_utils import AverageMeter, time_string
|
||||
from models import change_key
|
||||
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from xautodl.models import change_key
|
||||
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
@ -4,8 +4,8 @@
|
||||
import os, sys, time, torch
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter, time_string
|
||||
from models import change_key
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from xautodl.models import change_key
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ import os, sys, time, torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter, time_string
|
||||
from xautodl.log_utils import AverageMeter, time_string
|
||||
from .eval_funcs import obtain_accuracy
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@ def prepare_seed(rand_seed):
|
||||
|
||||
def prepare_logger(xargs):
|
||||
args = copy.deepcopy(xargs)
|
||||
from log_utils import Logger
|
||||
from xautodl.log_utils import Logger
|
||||
|
||||
logger = Logger(args.save_dir, args.rand_seed)
|
||||
logger.log("Main Function with logger : {:}".format(logger))
|
||||
|
Loading…
Reference in New Issue
Block a user