Refine lib -> xautodl
This commit is contained in:
parent
bd407ac4dc
commit
1c6c3e7166
5
.github/workflows/super_model_test.yml
vendored
5
.github/workflows/super_model_test.yml
vendored
@ -35,3 +35,8 @@ jobs:
|
||||
python -m pip install torch torchvision
|
||||
python -m pytest ./tests/test_super_*.py
|
||||
shell: bash
|
||||
|
||||
- name: Test TAS (NeurIPS 2019)
|
||||
run: |
|
||||
python -m pytest ./tests/test_tas.py
|
||||
shell: bash
|
||||
|
@ -25,7 +25,7 @@ def main(args):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
@ -470,7 +470,7 @@ if __name__ == "__main__":
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
main(
|
||||
save_dir,
|
||||
|
@ -340,7 +340,7 @@ def train_single_model(
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = True
|
||||
torch.set_num_threads(workers)
|
||||
# torch.set_num_threads(workers)
|
||||
|
||||
save_dir = (
|
||||
Path(save_dir)
|
||||
@ -675,7 +675,7 @@ if __name__ == "__main__":
|
||||
assert torch.cuda.is_available(), "CUDA is not available."
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers if args.workers > 0 else 1)
|
||||
# torch.set_num_threads(args.workers if args.workers > 0 else 1)
|
||||
|
||||
main(
|
||||
save_dir,
|
||||
|
@ -132,7 +132,7 @@ def select_action(policy):
|
||||
|
||||
|
||||
def main(xargs, api):
|
||||
torch.set_num_threads(4)
|
||||
# torch.set_num_threads(4)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
|
@ -204,7 +204,7 @@ def main(xargs):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(xargs.workers)
|
||||
# torch.set_num_threads(xargs.workers)
|
||||
prepare_seed(xargs.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
||||
|
@ -8,17 +8,14 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import load_config, obtain_basic_args as obtain_args
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from procedures import get_optim_scheduler, get_procedures
|
||||
from datasets import get_datasets
|
||||
from models import obtain_model
|
||||
from nas_infer_model import obtain_nas_infer_model
|
||||
from utils import get_model_infos
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.datasets import get_datasets
|
||||
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
|
||||
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from xautodl.procedures import get_optim_scheduler, get_procedures
|
||||
from xautodl.models import obtain_model
|
||||
from xautodl.nas_infer_model import obtain_nas_infer_model
|
||||
from xautodl.utils import get_model_infos
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -26,7 +23,7 @@ def main(args):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
@ -10,21 +10,17 @@ import numpy as np
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
|
||||
print("lib_dir : {:}".format(lib_dir))
|
||||
if str(lib_dir) not in sys.path:
|
||||
sys.path.insert(0, str(lib_dir))
|
||||
from config_utils import (
|
||||
from xautodl.config_utils import (
|
||||
load_config,
|
||||
configure2str,
|
||||
obtain_search_single_args as obtain_args,
|
||||
)
|
||||
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from procedures import get_optim_scheduler, get_procedures
|
||||
from datasets import get_datasets, SearchDataset
|
||||
from models import obtain_search_model, obtain_model, change_key
|
||||
from utils import get_model_infos
|
||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
|
||||
from xautodl.procedures import get_optim_scheduler, get_procedures
|
||||
from xautodl.datasets import get_datasets, SearchDataset
|
||||
from xautodl.models import obtain_search_model, obtain_model, change_key
|
||||
from xautodl.utils import get_model_infos
|
||||
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
|
||||
|
||||
|
||||
def main(args):
|
||||
@ -32,7 +28,7 @@ def main(args):
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
torch.set_num_threads(args.workers)
|
||||
# torch.set_num_threads(args.workers)
|
||||
|
||||
prepare_seed(args.rand_seed)
|
||||
logger = prepare_logger(args)
|
||||
|
29
tests/test_loader.py
Normal file
29
tests/test_loader.py
Normal file
@ -0,0 +1,29 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
# pytest tests/test_loader.py -s #
|
||||
#####################################################
|
||||
import unittest
|
||||
import tempfile
|
||||
import torch
|
||||
|
||||
from xautodl.datasets import get_datasets
|
||||
|
||||
|
||||
def test_simple():
|
||||
xdir = tempfile.mkdtemp()
|
||||
train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1)
|
||||
print(train_data)
|
||||
print(valid_data)
|
||||
|
||||
xloader = torch.utils.data.DataLoader(
|
||||
train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True
|
||||
)
|
||||
print(xloader)
|
||||
print(next(iter(xloader)))
|
||||
|
||||
for i, data in enumerate(xloader):
|
||||
print(i)
|
||||
|
||||
|
||||
test_simple()
|
23
tests/test_tas.py
Normal file
23
tests/test_tas.py
Normal file
@ -0,0 +1,23 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter
|
||||
|
||||
|
||||
class TestTASFunc(unittest.TestCase):
|
||||
"""Test the TAS function."""
|
||||
|
||||
def test_channel_interplation(self):
|
||||
tensors = torch.rand((16, 128, 7, 7))
|
||||
|
||||
for oc in range(200, 210):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
||||
for oc in range(48, 160):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
@ -9,9 +9,9 @@ from typing import List, Text, Any
|
||||
import random, torch
|
||||
import torch.nn as nn
|
||||
|
||||
from models.cell_operations import ResNetBasicblock
|
||||
from models.cell_infers.cells import InferCell
|
||||
from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter
|
||||
from ..cell_operations import ResNetBasicblock
|
||||
from ..cell_infers.cells import InferCell
|
||||
from .shape_searchs.SoftSelect import select2withP, ChannelWiseInter
|
||||
|
||||
|
||||
class GenericNAS301Model(nn.Module):
|
||||
|
@ -1,20 +0,0 @@
|
||||
##################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||
##################################################
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from SoftSelect import ChannelWiseInter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
tensors = torch.rand((16, 128, 7, 7))
|
||||
|
||||
for oc in range(200, 210):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
||||
for oc in range(48, 160):
|
||||
out_v1 = ChannelWiseInter(tensors, oc, "v1")
|
||||
out_v2 = ChannelWiseInter(tensors, oc, "v2")
|
||||
assert (out_v1 == out_v2).any().item() == 1
|
Loading…
Reference in New Issue
Block a user