Refine lib -> xautodl
This commit is contained in:
parent
bda202ce87
commit
5b9a028e60
@ -7,8 +7,8 @@ In this paper, we proposed a differentiable searching strategy for transformable
|
|||||||
You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
|
You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
|
||||||
|
|
||||||
<p float="left">
|
<p float="left">
|
||||||
<img src="https://d-x-y.github.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
|
<img src="http://xuanyidong.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
|
||||||
<img src="https://d-x-y.github.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
|
<img src="http://xuanyidong.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ We provide some logs at [Google Drive](https://drive.google.com/open?id=1_qUY4DT
|
|||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
Use `bash ./scripts/TAS/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
|
||||||
If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`.
|
If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`.
|
||||||
|
|
||||||
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.
|
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.
|
||||||
|
@ -27,8 +27,8 @@ from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
|
|||||||
from xautodl.models.xcore import get_model
|
from xautodl.models.xcore import get_model
|
||||||
from xautodl.xlayers import super_core, trunc_normal_
|
from xautodl.xlayers import super_core, trunc_normal_
|
||||||
|
|
||||||
from xautodl.lfna_utils import lfna_setup, train_model, TimeData
|
from lfna_utils import lfna_setup, train_model, TimeData
|
||||||
from xautodl.lfna_meta_model import LFNA_Meta
|
from lfna_meta_model import LFNA_Meta
|
||||||
|
|
||||||
|
|
||||||
def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger):
|
def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger):
|
||||||
|
@ -4,8 +4,8 @@
|
|||||||
import copy
|
import copy
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from procedures import prepare_seed, prepare_logger
|
from xautodl.procedures import prepare_seed, prepare_logger
|
||||||
from datasets.synthetic_core import get_synthetic_env
|
from xautodl.datasets.synthetic_core import get_synthetic_env
|
||||||
|
|
||||||
|
|
||||||
def lfna_setup(args):
|
def lfna_setup(args):
|
||||||
|
@ -665,7 +665,7 @@ if __name__ == "__main__":
|
|||||||
len(args.datasets), len(args.xpaths), len(args.splits)
|
len(args.datasets), len(args.xpaths), len(args.splits)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if args.workers <= 0:
|
if args.workers < 0:
|
||||||
raise ValueError("invalid number of workers : {:}".format(args.workers))
|
raise ValueError("invalid number of workers : {:}".format(args.workers))
|
||||||
|
|
||||||
target_indexes = filter_indexes(
|
target_indexes = filter_indexes(
|
||||||
@ -675,7 +675,7 @@ if __name__ == "__main__":
|
|||||||
assert torch.cuda.is_available(), "CUDA is not available."
|
assert torch.cuda.is_available(), "CUDA is not available."
|
||||||
torch.backends.cudnn.enabled = True
|
torch.backends.cudnn.enabled = True
|
||||||
torch.backends.cudnn.deterministic = True
|
torch.backends.cudnn.deterministic = True
|
||||||
torch.set_num_threads(args.workers)
|
torch.set_num_threads(args.workers if args.workers > 0 else 1)
|
||||||
|
|
||||||
main(
|
main(
|
||||||
save_dir,
|
save_dir,
|
||||||
|
@ -1,6 +1,10 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
|
||||||
|
#####################################################
|
||||||
# python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
|
# python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
|
||||||
# python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
|
# python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
|
||||||
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth
|
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth
|
||||||
|
#####################################################
|
||||||
import sys, time, torch, random, argparse
|
import sys, time, torch, random, argparse
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
@ -12,9 +16,6 @@ from pathlib import Path
|
|||||||
import torchvision
|
import torchvision
|
||||||
import torchvision.datasets as dset
|
import torchvision.datasets as dset
|
||||||
|
|
||||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
|
||||||
if str(lib_dir) not in sys.path:
|
|
||||||
sys.path.insert(0, str(lib_dir))
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Prepare splits for searching",
|
description="Prepare splits for searching",
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
@ -35,9 +36,9 @@ def main():
|
|||||||
print("torchvision version : {:}".format(torchvision.__version__))
|
print("torchvision version : {:}".format(torchvision.__version__))
|
||||||
|
|
||||||
if name == "cifar10":
|
if name == "cifar10":
|
||||||
dataset = dset.CIFAR10(args.root, train=True)
|
dataset = dset.CIFAR10(args.root, train=True, download=True)
|
||||||
elif name == "cifar100":
|
elif name == "cifar100":
|
||||||
dataset = dset.CIFAR100(args.root, train=True)
|
dataset = dset.CIFAR100(args.root, train=True, download=True)
|
||||||
elif name == "imagenet-1k":
|
elif name == "imagenet-1k":
|
||||||
dataset = dset.ImageFolder(osp.join(args.root, "train"))
|
dataset = dset.ImageFolder(osp.join(args.root, "train"))
|
||||||
else:
|
else:
|
13
scripts/TAS/prepare.sh
Normal file
13
scripts/TAS/prepare.sh
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# bash ./scripts/TAS/prepare.sh
|
||||||
|
#datasets='cifar10 cifar100 imagenet-1k'
|
||||||
|
#ratios='0.5 0.8 0.9'
|
||||||
|
ratios='0.5'
|
||||||
|
save_dir=./.latent-data/splits
|
||||||
|
|
||||||
|
for ratio in ${ratios}
|
||||||
|
do
|
||||||
|
python ./exps/TAS/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio}
|
||||||
|
python ./exps/TAS/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio}
|
||||||
|
python ./exps/TAS/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio}
|
||||||
|
done
|
@ -1,13 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# bash ./scripts/prepare.sh
|
|
||||||
#datasets='cifar10 cifar100 imagenet-1k'
|
|
||||||
#ratios='0.5 0.8 0.9'
|
|
||||||
ratios='0.5'
|
|
||||||
save_dir=./.latent-data/splits
|
|
||||||
|
|
||||||
for ratio in ${ratios}
|
|
||||||
do
|
|
||||||
python ./exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio}
|
|
||||||
python ./exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio}
|
|
||||||
python ./exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio}
|
|
||||||
done
|
|
@ -6,7 +6,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from models.cell_operations import OPS
|
from xautodl.models.cell_operations import OPS
|
||||||
|
|
||||||
|
|
||||||
# Cell for NAS-Bench-201
|
# Cell for NAS-Bench-201
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
|
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,11 +9,11 @@ import torch
|
|||||||
__all__ = ["get_model"]
|
__all__ = ["get_model"]
|
||||||
|
|
||||||
|
|
||||||
from xlayers.super_core import SuperSequential
|
from xautodl.xlayers.super_core import SuperSequential
|
||||||
from xlayers.super_core import SuperLinear
|
from xautodl.xlayers.super_core import SuperLinear
|
||||||
from xlayers.super_core import SuperDropout
|
from xautodl.xlayers.super_core import SuperDropout
|
||||||
from xlayers.super_core import super_name2norm
|
from xautodl.xlayers.super_core import super_name2norm
|
||||||
from xlayers.super_core import super_name2activation
|
from xautodl.xlayers.super_core import super_name2activation
|
||||||
|
|
||||||
|
|
||||||
def get_model(config: Dict[Text, Any], **kwargs):
|
def get_model(config: Dict[Text, Any], **kwargs):
|
||||||
|
@ -7,8 +7,7 @@ import os, sys, time, torch
|
|||||||
from typing import Optional, Text, Callable
|
from typing import Optional, Text, Callable
|
||||||
|
|
||||||
# modules in AutoDL
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter
|
from xautodl.log_utils import AverageMeter, time_string
|
||||||
from log_utils import time_string
|
|
||||||
from .eval_funcs import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,8 +4,7 @@
|
|||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
|
|
||||||
# modules in AutoDL
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter
|
from xautodl.log_utils import AverageMeter, time_string
|
||||||
from log_utils import time_string
|
|
||||||
from .eval_funcs import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,6 +15,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
|||||||
|
|
||||||
res = []
|
res = []
|
||||||
for k in topk:
|
for k in topk:
|
||||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
|
||||||
res.append(correct_k.mul_(100.0 / batch_size))
|
res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
return res
|
return res
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
import os, time, copy, torch, pathlib
|
import os, time, copy, torch, pathlib
|
||||||
|
|
||||||
# modules in AutoDL
|
# modules in AutoDL
|
||||||
import xautodl.datasets
|
from xautodl import datasets
|
||||||
from xautodl.config_utils import load_config
|
from xautodl.config_utils import load_config
|
||||||
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
from xautodl.procedures import prepare_seed, get_optim_scheduler
|
||||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
|
@ -8,7 +8,6 @@ import pprint
|
|||||||
import logging
|
import logging
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
from log_utils import pickle_load
|
|
||||||
import qlib
|
import qlib
|
||||||
from qlib.utils import init_instance_by_config
|
from qlib.utils import init_instance_by_config
|
||||||
from qlib.workflow import R
|
from qlib.workflow import R
|
||||||
|
@ -2,8 +2,9 @@
|
|||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
from log_utils import AverageMeter, time_string
|
|
||||||
from models import change_key
|
from xautodl.log_utils import AverageMeter, time_string
|
||||||
|
from xautodl.models import change_key
|
||||||
|
|
||||||
from .eval_funcs import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@
|
|||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
|
|
||||||
# modules in AutoDL
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter, time_string
|
from xautodl.log_utils import AverageMeter, time_string
|
||||||
from models import change_key
|
from xautodl.models import change_key
|
||||||
from .eval_funcs import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,7 +5,7 @@ import os, sys, time, torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# modules in AutoDL
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter, time_string
|
from xautodl.log_utils import AverageMeter, time_string
|
||||||
from .eval_funcs import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ def prepare_seed(rand_seed):
|
|||||||
|
|
||||||
def prepare_logger(xargs):
|
def prepare_logger(xargs):
|
||||||
args = copy.deepcopy(xargs)
|
args = copy.deepcopy(xargs)
|
||||||
from log_utils import Logger
|
from xautodl.log_utils import Logger
|
||||||
|
|
||||||
logger = Logger(args.save_dir, args.rand_seed)
|
logger = Logger(args.save_dir, args.rand_seed)
|
||||||
logger.log("Main Function with logger : {:}".format(logger))
|
logger.log("Main Function with logger : {:}".format(logger))
|
||||||
|
Loading…
Reference in New Issue
Block a user