Refine lib -> xautodl

This commit is contained in:
D-X-Y 2021-05-19 07:19:20 +00:00
parent bda202ce87
commit 5b9a028e60
19 changed files with 46 additions and 46 deletions

View File

@ -7,8 +7,8 @@ In this paper, we proposed a differentiable searching strategy for transformable
You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html). You could see the highlight of our Transformable Architecture Search (TAS) at our [project page](https://xuanyidong.com/assets/projects/NeurIPS-2019-TAS.html).
<p float="left"> <p float="left">
<img src="https://d-x-y.github.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/> <img src="http://xuanyidong.com/resources/paper-icon/NIPS-2019-TAS.png" width="680px"/>
<img src="https://d-x-y.github.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/> <img src="http://xuanyidong.com/resources/videos/NeurIPS-2019-TAS/TAS-arch.gif?raw=true" width="180px"/>
</p> </p>
@ -24,7 +24,7 @@ We provide some logs at [Google Drive](https://drive.google.com/open?id=1_qUY4DT
## Usage ## Usage
Use `bash ./scripts/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`. Use `bash ./scripts/TAS/prepare.sh` to prepare data splits for `CIFAR-10`, `CIFARR-100`, and `ILSVRC2012`.
If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`. If you do not have `ILSVRC2012` data, please comment L12 in `./scripts/prepare.sh`.
args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed. args: `cifar10` indicates the dataset name, `ResNet56` indicates the basemodel name, `CIFARX` indicates the searching hyper-parameters, `0.47/0.57` indicates the expected FLOP ratio, `-1` indicates the random seed.

View File

@ -27,8 +27,8 @@ from xautodl.datasets.synthetic_core import get_synthetic_env, EnvSampler
from xautodl.models.xcore import get_model from xautodl.models.xcore import get_model
from xautodl.xlayers import super_core, trunc_normal_ from xautodl.xlayers import super_core, trunc_normal_
from xautodl.lfna_utils import lfna_setup, train_model, TimeData from lfna_utils import lfna_setup, train_model, TimeData
from xautodl.lfna_meta_model import LFNA_Meta from lfna_meta_model import LFNA_Meta
def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger): def epoch_train(loader, meta_model, base_model, optimizer, criterion, device, logger):

View File

@ -4,8 +4,8 @@
import copy import copy
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from procedures import prepare_seed, prepare_logger from xautodl.procedures import prepare_seed, prepare_logger
from datasets.synthetic_core import get_synthetic_env from xautodl.datasets.synthetic_core import get_synthetic_env
def lfna_setup(args): def lfna_setup(args):

View File

@ -665,7 +665,7 @@ if __name__ == "__main__":
len(args.datasets), len(args.xpaths), len(args.splits) len(args.datasets), len(args.xpaths), len(args.splits)
) )
) )
if args.workers <= 0: if args.workers < 0:
raise ValueError("invalid number of workers : {:}".format(args.workers)) raise ValueError("invalid number of workers : {:}".format(args.workers))
target_indexes = filter_indexes( target_indexes = filter_indexes(
@ -675,7 +675,7 @@ if __name__ == "__main__":
assert torch.cuda.is_available(), "CUDA is not available." assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers) torch.set_num_threads(args.workers if args.workers > 0 else 1)
main( main(
save_dir, save_dir,

View File

@ -1,6 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.01 #
#####################################################
# python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth # python exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ./data/cifar10.split.pth
# python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth # python exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ./data/cifar100.split.pth
# python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth # python exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ./data/imagenet-1k.split.pth
#####################################################
import sys, time, torch, random, argparse import sys, time, torch, random, argparse
from collections import defaultdict from collections import defaultdict
import os.path as osp import os.path as osp
@ -12,9 +16,6 @@ from pathlib import Path
import torchvision import torchvision
import torchvision.datasets as dset import torchvision.datasets as dset
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Prepare splits for searching", description="Prepare splits for searching",
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,
@ -35,9 +36,9 @@ def main():
print("torchvision version : {:}".format(torchvision.__version__)) print("torchvision version : {:}".format(torchvision.__version__))
if name == "cifar10": if name == "cifar10":
dataset = dset.CIFAR10(args.root, train=True) dataset = dset.CIFAR10(args.root, train=True, download=True)
elif name == "cifar100": elif name == "cifar100":
dataset = dset.CIFAR100(args.root, train=True) dataset = dset.CIFAR100(args.root, train=True, download=True)
elif name == "imagenet-1k": elif name == "imagenet-1k":
dataset = dset.ImageFolder(osp.join(args.root, "train")) dataset = dset.ImageFolder(osp.join(args.root, "train"))
else: else:

13
scripts/TAS/prepare.sh Normal file
View File

@ -0,0 +1,13 @@
#!/bin/bash
# bash ./scripts/TAS/prepare.sh
#datasets='cifar10 cifar100 imagenet-1k'
#ratios='0.5 0.8 0.9'
ratios='0.5'
save_dir=./.latent-data/splits
for ratio in ${ratios}
do
python ./exps/TAS/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio}
python ./exps/TAS/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio}
python ./exps/TAS/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio}
done

View File

@ -1,13 +0,0 @@
#!/bin/bash
# bash ./scripts/prepare.sh
#datasets='cifar10 cifar100 imagenet-1k'
#ratios='0.5 0.8 0.9'
ratios='0.5'
save_dir=./.latent-data/splits
for ratio in ${ratios}
do
python ./exps/prepare.py --name cifar10 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar10-${ratio}.pth --ratio ${ratio}
python ./exps/prepare.py --name cifar100 --root $TORCH_HOME/cifar.python --save ${save_dir}/cifar100-${ratio}.pth --ratio ${ratio}
python ./exps/prepare.py --name imagenet-1k --root $TORCH_HOME/ILSVRC2012 --save ${save_dir}/imagenet-1k-${ratio}.pth --ratio ${ratio}
done

View File

@ -6,7 +6,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from copy import deepcopy from copy import deepcopy
from models.cell_operations import OPS from xautodl.models.cell_operations import OPS
# Cell for NAS-Bench-201 # Cell for NAS-Bench-201

View File

@ -4,6 +4,7 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from copy import deepcopy from copy import deepcopy
from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR from .cells import NASNetInferCell as InferCell, AuxiliaryHeadCIFAR

View File

@ -9,11 +9,11 @@ import torch
__all__ = ["get_model"] __all__ = ["get_model"]
from xlayers.super_core import SuperSequential from xautodl.xlayers.super_core import SuperSequential
from xlayers.super_core import SuperLinear from xautodl.xlayers.super_core import SuperLinear
from xlayers.super_core import SuperDropout from xautodl.xlayers.super_core import SuperDropout
from xlayers.super_core import super_name2norm from xautodl.xlayers.super_core import super_name2norm
from xlayers.super_core import super_name2activation from xautodl.xlayers.super_core import super_name2activation
def get_model(config: Dict[Text, Any], **kwargs): def get_model(config: Dict[Text, Any], **kwargs):

View File

@ -7,8 +7,7 @@ import os, sys, time, torch
from typing import Optional, Text, Callable from typing import Optional, Text, Callable
# modules in AutoDL # modules in AutoDL
from log_utils import AverageMeter from xautodl.log_utils import AverageMeter, time_string
from log_utils import time_string
from .eval_funcs import obtain_accuracy from .eval_funcs import obtain_accuracy

View File

@ -4,8 +4,7 @@
import os, sys, time, torch import os, sys, time, torch
# modules in AutoDL # modules in AutoDL
from log_utils import AverageMeter from xautodl.log_utils import AverageMeter, time_string
from log_utils import time_string
from .eval_funcs import obtain_accuracy from .eval_funcs import obtain_accuracy

View File

@ -15,6 +15,6 @@ def obtain_accuracy(output, target, topk=(1,)):
res = [] res = []
for k in topk: for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size)) res.append(correct_k.mul_(100.0 / batch_size))
return res return res

View File

@ -4,7 +4,7 @@
import os, time, copy, torch, pathlib import os, time, copy, torch, pathlib
# modules in AutoDL # modules in AutoDL
import xautodl.datasets from xautodl import datasets
from xautodl.config_utils import load_config from xautodl.config_utils import load_config
from xautodl.procedures import prepare_seed, get_optim_scheduler from xautodl.procedures import prepare_seed, get_optim_scheduler
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time from xautodl.log_utils import AverageMeter, time_string, convert_secs2time

View File

@ -8,7 +8,6 @@ import pprint
import logging import logging
from copy import deepcopy from copy import deepcopy
from log_utils import pickle_load
import qlib import qlib
from qlib.utils import init_instance_by_config from qlib.utils import init_instance_by_config
from qlib.workflow import R from qlib.workflow import R

View File

@ -2,8 +2,9 @@
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
################################################## ##################################################
import os, sys, time, torch import os, sys, time, torch
from log_utils import AverageMeter, time_string
from models import change_key from xautodl.log_utils import AverageMeter, time_string
from xautodl.models import change_key
from .eval_funcs import obtain_accuracy from .eval_funcs import obtain_accuracy

View File

@ -4,8 +4,8 @@
import os, sys, time, torch import os, sys, time, torch
# modules in AutoDL # modules in AutoDL
from log_utils import AverageMeter, time_string from xautodl.log_utils import AverageMeter, time_string
from models import change_key from xautodl.models import change_key
from .eval_funcs import obtain_accuracy from .eval_funcs import obtain_accuracy

View File

@ -5,7 +5,7 @@ import os, sys, time, torch
import torch.nn.functional as F import torch.nn.functional as F
# modules in AutoDL # modules in AutoDL
from log_utils import AverageMeter, time_string from xautodl.log_utils import AverageMeter, time_string
from .eval_funcs import obtain_accuracy from .eval_funcs import obtain_accuracy

View File

@ -16,7 +16,7 @@ def prepare_seed(rand_seed):
def prepare_logger(xargs): def prepare_logger(xargs):
args = copy.deepcopy(xargs) args = copy.deepcopy(xargs)
from log_utils import Logger from xautodl.log_utils import Logger
logger = Logger(args.save_dir, args.rand_seed) logger = Logger(args.save_dir, args.rand_seed)
logger.log("Main Function with logger : {:}".format(logger)) logger.log("Main Function with logger : {:}".format(logger))