Refine lib -> xautodl

This commit is contained in:
D-X-Y 2021-05-19 08:10:42 +00:00
parent bd407ac4dc
commit 1c6c3e7166
12 changed files with 83 additions and 53 deletions

View File

@ -35,3 +35,8 @@ jobs:
python -m pip install torch torchvision
python -m pytest ./tests/test_super_*.py
shell: bash
- name: Test TAS (NeurIPS 2019)
run: |
python -m pytest ./tests/test_tas.py
shell: bash

View File

@ -25,7 +25,7 @@ def main(args):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
prepare_seed(args.rand_seed)
logger = prepare_logger(args)

View File

@ -470,7 +470,7 @@ if __name__ == "__main__":
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
main(
save_dir,

View File

@ -340,7 +340,7 @@ def train_single_model(
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True
torch.set_num_threads(workers)
# torch.set_num_threads(workers)
save_dir = (
Path(save_dir)
@ -675,7 +675,7 @@ if __name__ == "__main__":
assert torch.cuda.is_available(), "CUDA is not available."
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers if args.workers > 0 else 1)
# torch.set_num_threads(args.workers if args.workers > 0 else 1)
main(
save_dir,

View File

@ -132,7 +132,7 @@ def select_action(policy):
def main(xargs, api):
torch.set_num_threads(4)
# torch.set_num_threads(4)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)

View File

@ -204,7 +204,7 @@ def main(xargs):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.set_num_threads(xargs.workers)
# torch.set_num_threads(xargs.workers)
prepare_seed(xargs.rand_seed)
logger = prepare_logger(args)

View File

@ -8,17 +8,14 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import load_config, obtain_basic_args as obtain_args
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from procedures import get_optim_scheduler, get_procedures
from datasets import get_datasets
from models import obtain_model
from nas_infer_model import obtain_nas_infer_model
from utils import get_model_infos
from log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.datasets import get_datasets
from xautodl.config_utils import load_config, obtain_basic_args as obtain_args
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.models import obtain_model
from xautodl.nas_infer_model import obtain_nas_infer_model
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def main(args):
@ -26,7 +23,7 @@ def main(args):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
prepare_seed(args.rand_seed)
logger = prepare_logger(args)

View File

@ -10,21 +10,17 @@ import numpy as np
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "lib").resolve()
print("lib_dir : {:}".format(lib_dir))
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from config_utils import (
from xautodl.config_utils import (
load_config,
configure2str,
obtain_search_single_args as obtain_args,
)
from procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from procedures import get_optim_scheduler, get_procedures
from datasets import get_datasets, SearchDataset
from models import obtain_search_model, obtain_model, change_key
from utils import get_model_infos
from log_utils import AverageMeter, time_string, convert_secs2time
from xautodl.procedures import prepare_seed, prepare_logger, save_checkpoint, copy_checkpoint
from xautodl.procedures import get_optim_scheduler, get_procedures
from xautodl.datasets import get_datasets, SearchDataset
from xautodl.models import obtain_search_model, obtain_model, change_key
from xautodl.utils import get_model_infos
from xautodl.log_utils import AverageMeter, time_string, convert_secs2time
def main(args):
@ -32,7 +28,7 @@ def main(args):
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_num_threads(args.workers)
# torch.set_num_threads(args.workers)
prepare_seed(args.rand_seed)
logger = prepare_logger(args)

29
tests/test_loader.py Normal file
View File

@ -0,0 +1,29 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest tests/test_loader.py -s #
#####################################################
import unittest
import tempfile
import torch
from xautodl.datasets import get_datasets
def test_simple():
xdir = tempfile.mkdtemp()
train_data, valid_data, xshape, class_num = get_datasets("cifar10", xdir, -1)
print(train_data)
print(valid_data)
xloader = torch.utils.data.DataLoader(
train_data, batch_size=256, shuffle=True, num_workers=4, pin_memory=True
)
print(xloader)
print(next(iter(xloader)))
for i, data in enumerate(xloader):
print(i)
test_simple()

23
tests/test_tas.py Normal file
View File

@ -0,0 +1,23 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
from xautodl.models.shape_searchs.SoftSelect import ChannelWiseInter
class TestTASFunc(unittest.TestCase):
"""Test the TAS function."""
def test_channel_interplation(self):
tensors = torch.rand((16, 128, 7, 7))
for oc in range(200, 210):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1
for oc in range(48, 160):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1

View File

@ -9,9 +9,9 @@ from typing import List, Text, Any
import random, torch
import torch.nn as nn
from models.cell_operations import ResNetBasicblock
from models.cell_infers.cells import InferCell
from models.shape_searchs.SoftSelect import select2withP, ChannelWiseInter
from ..cell_operations import ResNetBasicblock
from ..cell_infers.cells import InferCell
from .shape_searchs.SoftSelect import select2withP, ChannelWiseInter
class GenericNAS301Model(nn.Module):

View File

@ -1,20 +0,0 @@
##################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
##################################################
import torch
import torch.nn as nn
from SoftSelect import ChannelWiseInter
if __name__ == "__main__":
tensors = torch.rand((16, 128, 7, 7))
for oc in range(200, 210):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1
for oc in range(48, 160):
out_v1 = ChannelWiseInter(tensors, oc, "v1")
out_v2 = ChannelWiseInter(tensors, oc, "v2")
assert (out_v1 == out_v2).any().item() == 1