Update SuperMLP
parent 31b8122cc1
commit 0c56a729ad

.github/workflows/basic_test.yml (vendored)
@@ -35,6 +35,7 @@ jobs:
         python -m black ./lib/xlayers -l 88 --check --diff --verbose
         python -m black ./lib/spaces -l 88 --check --diff --verbose
         python -m black ./lib/trade_models -l 88 --check --diff --verbose
+        python -m black ./lib/procedures -l 88 --check --diff --verbose

     - name: Test Search Space
       run: |
@@ -7,22 +7,63 @@ from log_utils import time_string
 from utils import obtain_accuracy


-def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
+def basic_train(
+    xloader,
+    network,
+    criterion,
+    scheduler,
+    optimizer,
+    optim_config,
+    extra_info,
+    print_freq,
+    logger,
+):
     loss, acc1, acc5 = procedure(
-        xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger
+        xloader,
+        network,
+        criterion,
+        scheduler,
+        optimizer,
+        "train",
+        optim_config,
+        extra_info,
+        print_freq,
+        logger,
     )
     return loss, acc1, acc5


-def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
+def basic_valid(
+    xloader, network, criterion, optim_config, extra_info, print_freq, logger
+):
     with torch.no_grad():
         loss, acc1, acc5 = procedure(
-            xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger
+            xloader,
+            network,
+            criterion,
+            None,
+            None,
+            "valid",
+            None,
+            extra_info,
+            print_freq,
+            logger,
         )
     return loss, acc1, acc5


-def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
+def procedure(
+    xloader,
+    network,
+    criterion,
+    scheduler,
+    optimizer,
+    mode,
+    config,
+    extra_info,
+    print_freq,
+    logger,
+):
     data_time, batch_time, losses, top1, top5 = (
         AverageMeter(),
         AverageMeter(),
@@ -39,7 +80,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e

     # logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
     logger.log(
-        "[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1)
+        "[{:5s}] config :: auxiliary={:}".format(
+            mode, config.auxiliary if hasattr(config, "auxiliary") else -1
+        )
     )
     end = time.time()
     for i, (inputs, targets) in enumerate(xloader):
@@ -55,7 +98,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e

         features, logits = network(inputs)
         if isinstance(logits, list):
-            assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
+            assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
+                2, len(logits)
+            )
             logits, logits_aux = logits
         else:
             logits, logits_aux = logits, None
@@ -97,7 +142,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e

     logger.log(
         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
+            mode=mode.upper(),
+            top1=top1,
+            top5=top5,
+            error1=100 - top1.avg,
+            error5=100 - top5.avg,
+            loss=losses.avg,
         )
     )
     return losses.avg, top1.avg, top5.avg
@@ -78,23 +78,42 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
     return losses.avg, top1.avg, top5.avg, batch_time.sum


-def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
+def evaluate_for_seed(
+    arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
+):

     prepare_seed(seed)  # random seed
     net = get_cell_based_tiny_net(arch_config)
     # net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
     flop, param = get_model_infos(net, opt_config.xshape)
     logger.log("Network : {:}".format(net.get_message()), False)
-    logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed))
+    logger.log(
+        "{:} Seed-------------------------- {:} --------------------------".format(
+            time_string(), seed
+        )
+    )
     logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
     # train and valid
     optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
     default_device = torch.cuda.current_device()
-    network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
+    network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(
+        device=default_device
+    )
     criterion = criterion.cuda(device=default_device)
     # start training
-    start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
-    train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
+    start_time, epoch_time, total_epoch = (
+        time.time(),
+        AverageMeter(),
+        opt_config.epochs + opt_config.warmup,
+    )
+    (
+        train_losses,
+        train_acc1es,
+        train_acc5es,
+        valid_losses,
+        valid_acc1es,
+        valid_acc5es,
+    ) = ({}, {}, {}, {}, {}, {})
     train_times, valid_times, lrs = {}, {}, {}
     for epoch in range(total_epoch):
         scheduler.update(epoch, 0.0)
@@ -120,7 +139,9 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed
         # measure elapsed time
         epoch_time.update(time.time() - start_time)
         start_time = time.time()
-        need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True))
+        need_time = "Time Left: {:}".format(
+            convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
+        )
         logger.log(
             "{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
                 time_string(),
@@ -171,14 +192,29 @@ def get_nas_bench_loaders(workers):
     break_line = "-" * 150
     print("{:} Create data-loader for all datasets".format(time_string()))
     print(break_line)
-    TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1)
+    TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
+        "cifar10", str(torch_dir / "cifar.python"), -1
+    )
     print(
         "original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
             len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
         )
     )
-    cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None)
-    assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [
+    cifar10_splits = load_config(
+        root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
+    )
+    assert cifar10_splits.train[:10] == [
+        0,
+        5,
+        7,
+        11,
+        13,
+        15,
+        16,
+        17,
+        20,
+        24,
+    ] and cifar10_splits.valid[:10] == [
         1,
         2,
         3,
@@ -194,7 +230,11 @@ def get_nas_bench_loaders(workers):
     temp_dataset.transform = VALID_CIFAR10.transform
     # data loader
     trainval_cifar10_loader = torch.utils.data.DataLoader(
-        TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
+        TRAIN_CIFAR10,
+        batch_size=cifar_config.batch_size,
+        shuffle=True,
+        num_workers=workers,
+        pin_memory=True,
     )
     train_cifar10_loader = torch.utils.data.DataLoader(
         TRAIN_CIFAR10,
@@ -211,7 +251,11 @@ def get_nas_bench_loaders(workers):
         pin_memory=True,
     )
     test__cifar10_loader = torch.utils.data.DataLoader(
-        VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
+        VALID_CIFAR10,
+        batch_size=cifar_config.batch_size,
+        shuffle=False,
+        num_workers=workers,
+        pin_memory=True,
     )
     print(
         "CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
@@ -235,14 +279,29 @@ def get_nas_bench_loaders(workers):
     )
     print(break_line)
     # CIFAR-100
-    TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1)
+    TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
+        "cifar100", str(torch_dir / "cifar.python"), -1
+    )
     print(
         "original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
             len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
         )
     )
-    cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None)
-    assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [
+    cifar100_splits = load_config(
+        root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
+    )
+    assert cifar100_splits.xvalid[:10] == [
+        1,
+        3,
+        4,
+        5,
+        8,
+        10,
+        13,
+        14,
+        15,
+        16,
+    ] and cifar100_splits.xtest[:10] == [
         0,
         2,
         6,
@@ -255,7 +314,11 @@ def get_nas_bench_loaders(workers):
         24,
     ]
     train_cifar100_loader = torch.utils.data.DataLoader(
-        TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
+        TRAIN_CIFAR100,
+        batch_size=cifar_config.batch_size,
+        shuffle=True,
+        num_workers=workers,
+        pin_memory=True,
     )
     valid_cifar100_loader = torch.utils.data.DataLoader(
         VALID_CIFAR100,
@@ -271,9 +334,15 @@ def get_nas_bench_loaders(workers):
         num_workers=workers,
         pin_memory=True,
     )
-    print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)))
-    print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)))
-    print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)))
+    print(
+        "CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))
+    )
+    print(
+        "CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))
+    )
+    print(
+        "CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))
+    )
     print(break_line)

     imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
@@ -286,8 +355,23 @@ def get_nas_bench_loaders(workers):
             len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
         )
     )
-    imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None)
-    assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [
+    imagenet_splits = load_config(
+        root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
+        None,
+        None,
+    )
+    assert imagenet_splits.xvalid[:10] == [
+        1,
+        2,
+        3,
+        6,
+        7,
+        8,
+        9,
+        12,
+        16,
+        18,
+    ] and imagenet_splits.xtest[:10] == [
         0,
         4,
         5,
@@ -14,7 +14,9 @@ class _LRScheduler(object):
         self.optimizer = optimizer
         for group in optimizer.param_groups:
             group.setdefault("initial_lr", group["lr"])
-        self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
+        self.base_lrs = list(
+            map(lambda group: group["initial_lr"], optimizer.param_groups)
+        )
         self.max_epochs = epochs
         self.warmup_epochs = warmup_epochs
         self.current_epoch = 0
@@ -31,7 +33,9 @@ class _LRScheduler(object):
         )

     def state_dict(self):
-        return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
+        return {
+            key: value for key, value in self.__dict__.items() if key != "optimizer"
+        }

     def load_state_dict(self, state_dict):
         self.__dict__.update(state_dict)
@@ -50,10 +54,14 @@ class _LRScheduler(object):

     def update(self, cur_epoch, cur_iter):
         if cur_epoch is not None:
-            assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch)
+            assert (
+                isinstance(cur_epoch, int) and cur_epoch >= 0
+            ), "invalid cur-epoch : {:}".format(cur_epoch)
             self.current_epoch = cur_epoch
         if cur_iter is not None:
-            assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter)
+            assert (
+                isinstance(cur_iter, float) and cur_iter >= 0
+            ), "invalid cur-iter : {:}".format(cur_iter)
             self.current_iter = cur_iter
         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
             param_group["lr"] = lr
@@ -66,29 +74,44 @@ class CosineAnnealingLR(_LRScheduler):
         super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)

     def extra_repr(self):
-        return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min)
+        return "type={:}, T-max={:}, eta-min={:}".format(
+            "cosine", self.T_max, self.eta_min
+        )

     def get_lr(self):
         lrs = []
         for base_lr in self.base_lrs:
-            if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
+            if (
+                self.current_epoch >= self.warmup_epochs
+                and self.current_epoch < self.max_epochs
+            ):
                 last_epoch = self.current_epoch - self.warmup_epochs
                 # if last_epoch < self.T_max:
                 # if last_epoch < self.max_epochs:
-                lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
+                lr = (
+                    self.eta_min
+                    + (base_lr - self.eta_min)
+                    * (1 + math.cos(math.pi * last_epoch / self.T_max))
+                    / 2
+                )
                 # else:
                 # lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
             elif self.current_epoch >= self.max_epochs:
                 lr = self.eta_min
             else:
-                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
+                lr = (
+                    self.current_epoch / self.warmup_epochs
+                    + self.current_iter / self.warmup_epochs
+                ) * base_lr
             lrs.append(lr)
         return lrs


 class MultiStepLR(_LRScheduler):
     def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
-        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas))
+        assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
+            len(milestones), len(gammas)
+        )
         self.milestones = milestones
         self.gammas = gammas
         super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
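The three branches in CosineAnnealingLR.get_lr implement a linear warmup followed by cosine annealing down to eta_min. A minimal standalone sketch of the same schedule for a single base LR (the helper name and arguments are assumptions for illustration, not code from this repository):

import math

def cosine_with_warmup(epoch, warmup_epochs, max_epochs, T_max, base_lr, eta_min, cur_iter=0.0):
    # Mirrors the three branches of CosineAnnealingLR.get_lr above.
    if warmup_epochs <= epoch < max_epochs:
        last_epoch = epoch - warmup_epochs
        return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * last_epoch / T_max)) / 2
    if epoch >= max_epochs:
        return eta_min
    # During warmup the LR ramps linearly toward base_lr.
    return (epoch / warmup_epochs + cur_iter / warmup_epochs) * base_lr

# Example: base_lr=0.1, eta_min=0.0, warmup_epochs=5, max_epochs=200, T_max=195
# epoch 5 -> 0.1 (cosine starts at base_lr); epoch >= 200 -> 0.0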
@@ -108,7 +131,10 @@ class MultiStepLR(_LRScheduler):
                 for x in self.gammas[:idx]:
                     lr *= x
             else:
-                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
+                lr = (
+                    self.current_epoch / self.warmup_epochs
+                    + self.current_iter / self.warmup_epochs
+                ) * base_lr
             lrs.append(lr)
         return lrs

@@ -119,7 +145,9 @@ class ExponentialLR(_LRScheduler):
         super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)

     def extra_repr(self):
-        return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs)
+        return "type={:}, gamma={:}, base-lrs={:}".format(
+            "exponential", self.gamma, self.base_lrs
+        )

     def get_lr(self):
         lrs = []
@@ -129,7 +157,10 @@ class ExponentialLR(_LRScheduler):
                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
                 lr = base_lr * (self.gamma ** last_epoch)
             else:
-                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
+                lr = (
+                    self.current_epoch / self.warmup_epochs
+                    + self.current_iter / self.warmup_epochs
+                ) * base_lr
             lrs.append(lr)
         return lrs

@@ -151,10 +182,18 @@ class LinearLR(_LRScheduler):
             if self.current_epoch >= self.warmup_epochs:
                 last_epoch = self.current_epoch - self.warmup_epochs
                 assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
-                ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
+                ratio = (
+                    (self.max_LR - self.min_LR)
+                    * last_epoch
+                    / self.max_epochs
+                    / self.max_LR
+                )
                 lr = base_lr * (1 - ratio)
             else:
-                lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
+                lr = (
+                    self.current_epoch / self.warmup_epochs
+                    + self.current_iter / self.warmup_epochs
+                ) * base_lr
             lrs.append(lr)
         return lrs

@@ -176,26 +215,42 @@ class CrossEntropyLabelSmooth(nn.Module):

 def get_optim_scheduler(parameters, config):
     assert (
-        hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion")
-    ), "config must have optim / scheduler / criterion keys instead of {:}".format(config)
+        hasattr(config, "optim")
+        and hasattr(config, "scheduler")
+        and hasattr(config, "criterion")
+    ), "config must have optim / scheduler / criterion keys instead of {:}".format(
+        config
+    )
     if config.optim == "SGD":
         optim = torch.optim.SGD(
-            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov
+            parameters,
+            config.LR,
+            momentum=config.momentum,
+            weight_decay=config.decay,
+            nesterov=config.nesterov,
         )
     elif config.optim == "RMSprop":
-        optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
+        optim = torch.optim.RMSprop(
+            parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
+        )
     else:
         raise ValueError("invalid optim : {:}".format(config.optim))

     if config.scheduler == "cos":
         T_max = getattr(config, "T_max", config.epochs)
-        scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
+        scheduler = CosineAnnealingLR(
+            optim, config.warmup, config.epochs, T_max, config.eta_min
+        )
     elif config.scheduler == "multistep":
-        scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
+        scheduler = MultiStepLR(
+            optim, config.warmup, config.epochs, config.milestones, config.gammas
+        )
     elif config.scheduler == "exponential":
         scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
     elif config.scheduler == "linear":
-        scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
+        scheduler = LinearLR(
+            optim, config.warmup, config.epochs, config.LR, config.LR_min
+        )
     else:
         raise ValueError("invalid scheduler : {:}".format(config.scheduler))

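get_optim_scheduler only reads attributes off the config object, so any attribute bag with the fields used above works. A minimal sketch of a call; the concrete field values are assumptions for illustration, and the criterion branch that consumes config.criterion lies outside this hunk:

from types import SimpleNamespace

import torch.nn as nn

config = SimpleNamespace(
    optim="SGD", LR=0.1, momentum=0.9, decay=0.0005, nesterov=True,
    scheduler="cos", warmup=5, epochs=200, eta_min=0.0,
    criterion="Softmax",  # assumed value; only its presence is asserted in this hunk
)
model = nn.Linear(10, 2)  # stand-in network
optimizer, scheduler, criterion = get_optim_scheduler(model.parameters(), config)
scheduler.update(0, 0.0)  # matches how evaluate_for_seed drives the scheduler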
@@ -41,7 +41,10 @@ def update_gpu(config, gpu):
     if "task" in config and "model" in config["task"]:
         if "GPU" in config["task"]["model"]:
             config["task"]["model"]["GPU"] = gpu
-        elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]:
+        elif (
+            "kwargs" in config["task"]["model"]
+            and "GPU" in config["task"]["model"]["kwargs"]
+        ):
             config["task"]["model"]["kwargs"]["GPU"] = gpu
     elif "model" in config:
         if "GPU" in config["model"]:
@@ -68,7 +71,12 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
     model_fit_kwargs = dict(dataset=dataset)

     # Let's start the experiment.
-    with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True):
+    with R.start(
+        experiment_name=experiment_name,
+        recorder_name=recorder_name,
+        uri=uri,
+        resume=True,
+    ):
         # Setup log
         recorder_root_dir = R.get_recorder().get_local_dir()
         log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))
@@ -36,7 +36,12 @@ def search_train(
     logger,
 ):
     data_time, batch_time = AverageMeter(), AverageMeter()
-    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
+    base_losses, arch_losses, top1, top5 = (
+        AverageMeter(),
+        AverageMeter(),
+        AverageMeter(),
+        AverageMeter(),
+    )
     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
     epoch_str, flop_need, flop_weight, flop_tolerant = (
         extra_info["epoch-str"],
@@ -46,10 +51,16 @@ def search_train(
     )

     network.train()
-    logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
+    logger.log(
+        "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
+            epoch_str, flop_need, flop_weight
+        )
+    )
     end = time.time()
     network.apply(change_key("search_mode", "search"))
-    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
+    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
+        search_loader
+    ):
         scheduler.update(None, 1.0 * step / len(search_loader))
         # calculate prediction and loss
         base_targets = base_targets.cuda(non_blocking=True)
@@ -75,7 +86,9 @@ def search_train(
         arch_optimizer.zero_grad()
         logits, expected_flop = network(arch_inputs)
         flop_cur = network.module.get_flop("genotype", None, None)
-        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
+        flop_loss, flop_loss_scale = get_flop_loss(
+            expected_flop, flop_cur, flop_need, flop_tolerant
+        )
         acls_loss = criterion(logits, arch_targets)
         arch_loss = acls_loss + flop_loss * flop_weight
         arch_loss.backward()
@@ -90,7 +103,11 @@ def search_train(
         batch_time.update(time.time() - end)
         end = time.time()
         if step % print_freq == 0 or (step + 1) == len(search_loader):
-            Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
+            Sstr = (
+                "**TRAIN** "
+                + time_string()
+                + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
+            )
             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                 batch_time=batch_time, data_time=data_time
             )
@@ -153,7 +170,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
             end = time.time()

             if i % print_freq == 0 or (i + 1) == len(xloader):
-                Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
+                Sstr = (
+                    "**VALID** "
+                    + time_string()
+                    + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
+                )
                 Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                     batch_time=batch_time, data_time=data_time
                 )
@@ -165,7 +186,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):

     logger.log(
         " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
+            top1=top1,
+            top5=top5,
+            error1=100 - top1.avg,
+            error5=100 - top5.avg,
+            loss=losses.avg,
         )
     )

@@ -36,7 +36,12 @@ def search_train_v2(
     logger,
 ):
     data_time, batch_time = AverageMeter(), AverageMeter()
-    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
+    base_losses, arch_losses, top1, top5 = (
+        AverageMeter(),
+        AverageMeter(),
+        AverageMeter(),
+        AverageMeter(),
+    )
     arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
     epoch_str, flop_need, flop_weight, flop_tolerant = (
         extra_info["epoch-str"],
@@ -46,10 +51,16 @@ def search_train_v2(
     )

     network.train()
-    logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
+    logger.log(
+        "[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
+            epoch_str, flop_need, flop_weight
+        )
+    )
     end = time.time()
     network.apply(change_key("search_mode", "search"))
-    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
+    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
+        search_loader
+    ):
         scheduler.update(None, 1.0 * step / len(search_loader))
         # calculate prediction and loss
         base_targets = base_targets.cuda(non_blocking=True)
@@ -73,7 +84,9 @@ def search_train_v2(
         arch_optimizer.zero_grad()
         logits, expected_flop = network(arch_inputs)
         flop_cur = network.module.get_flop("genotype", None, None)
-        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
+        flop_loss, flop_loss_scale = get_flop_loss(
+            expected_flop, flop_cur, flop_need, flop_tolerant
+        )
        acls_loss = criterion(logits, arch_targets)
         arch_loss = acls_loss + flop_loss * flop_weight
         arch_loss.backward()
@@ -88,7 +101,11 @@ def search_train_v2(
         batch_time.update(time.time() - end)
         end = time.time()
         if step % print_freq == 0 or (step + 1) == len(search_loader):
-            Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
+            Sstr = (
+                "**TRAIN** "
+                + time_string()
+                + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
+            )
             Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                 batch_time=batch_time, data_time=data_time
             )
@@ -10,7 +10,16 @@ from utils import obtain_accuracy


 def simple_KD_train(
-    xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger
+    xloader,
+    teacher,
+    network,
+    criterion,
+    scheduler,
+    optimizer,
+    optim_config,
+    extra_info,
+    print_freq,
+    logger,
 ):
     loss, acc1, acc5 = procedure(
         xloader,
@@ -28,25 +37,58 @@ def simple_KD_train(
     return loss, acc1, acc5


-def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
+def simple_KD_valid(
+    xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
+):
     with torch.no_grad():
         loss, acc1, acc5 = procedure(
-            xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger
+            xloader,
+            teacher,
+            network,
+            criterion,
+            None,
+            None,
+            "valid",
+            optim_config,
+            extra_info,
+            print_freq,
+            logger,
         )
     return loss, acc1, acc5


 def loss_KD_fn(
-    criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature
+    criterion,
+    student_logits,
+    teacher_logits,
+    studentFeatures,
+    teacherFeatures,
+    targets,
+    alpha,
+    temperature,
 ):
     basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
     log_student = F.log_softmax(student_logits / temperature, dim=1)
     sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
-    KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature)
+    KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
+        alpha * temperature * temperature
+    )
     return basic_loss + KD_loss


-def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
+def procedure(
+    xloader,
+    teacher,
+    network,
+    criterion,
+    scheduler,
+    optimizer,
+    mode,
+    config,
+    extra_info,
+    print_freq,
+    logger,
+):
     data_time, batch_time, losses, top1, top5 = (
         AverageMeter(),
         AverageMeter(),
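loss_KD_fn blends the hard-label cross-entropy (weighted by 1 - alpha) with a temperature-scaled KL term between the student and teacher distributions. A small self-contained sketch of calling it; the shapes and hyper-parameter values are assumptions for illustration only:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
student_logits = torch.randn(8, 100)  # batch of 8, 100 classes
teacher_logits = torch.randn(8, 100)
targets = torch.randint(0, 100, (8,))
# The feature arguments are not used inside the loss body shown above, so None stands in here.
loss = loss_KD_fn(
    criterion, student_logits, teacher_logits, None, None, targets,
    alpha=0.9, temperature=4.0,
)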
@@ -65,7 +107,10 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,

     logger.log(
         "[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
-            mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature
+            mode,
+            config.auxiliary if hasattr(config, "auxiliary") else -1,
+            config.KD_alpha,
+            config.KD_temperature,
         )
     )
     end = time.time()
@@ -82,7 +127,9 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,

         student_f, logits = network(inputs)
         if isinstance(logits, list):
-            assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
+            assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
+                2, len(logits)
+            )
             logits, logits_aux = logits
         else:
             logits, logits_aux = logits, None
@@ -90,7 +137,14 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
             teacher_f, teacher_logits = teacher(inputs)

         loss = loss_KD_fn(
-            criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature
+            criterion,
+            logits,
+            teacher_logits,
+            student_f,
+            teacher_f,
+            targets,
+            config.KD_alpha,
+            config.KD_temperature,
         )
         if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
             loss_aux = criterion(logits_aux, targets)
@@ -139,7 +193,12 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
             )
     logger.log(
         " **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
-            mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
+            mode=mode.upper(),
+            top1=top1,
+            top5=top5,
+            error1=100 - top1.avg,
+            error5=100 - top5.avg,
+            loss=losses.avg,
         )
     )
     return losses.avg, top1.avg, top5.avg
@@ -31,7 +31,9 @@ def prepare_logger(xargs):
     logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
     logger.log(
         "CUDA_VISIBLE_DEVICES : {:}".format(
-            os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None"
+            os.environ["CUDA_VISIBLE_DEVICES"]
+            if "CUDA_VISIBLE_DEVICES" in os.environ
+            else "None"
         )
     )
     return logger
@@ -54,10 +56,14 @@ def get_machine_info():
 def save_checkpoint(state, filename, logger):
     if osp.isfile(filename):
         if hasattr(logger, "log"):
-            logger.log("Find {:} exist, delete is at first before saving".format(filename))
+            logger.log(
+                "Find {:} exist, delete is at first before saving".format(filename)
+            )
         os.remove(filename)
     torch.save(state, filename)
-    assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename)
+    assert osp.isfile(
+        filename
+    ), "save filename : {:} failed, which is not found.".format(filename)
     if hasattr(logger, "log"):
         logger.log("save checkpoint into {:}".format(filename))
     return filename
@@ -50,6 +50,10 @@ class Space(metaclass=abc.ABCMeta):
     def clean_last_abstract(self):
         raise NotImplementedError

+    def clean_last(self):
+        self.clean_last_sample()
+        self.clean_last_abstract()
+
     @abc.abstractproperty
     def determined(self) -> bool:
         raise NotImplementedError
@@ -147,11 +147,18 @@ class SuperMLP(SuperModule):
         root_node.append("fc2", space_fc2)
         return root_node

+    def apply_candidate(self, abstract_child: spaces.VirtualNode):
+        super(SuperMLP, self).apply_candidate(abstract_child)
+        if "fc1" in abstract_child:
+            self.fc1.apply_candidate(abstract_child["fc1"])
+        if "fc2" in abstract_child:
+            self.fc2.apply_candidate(abstract_child["fc2"])
+
     def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
-        return self._unified_forward(x)
+        return self._unified_forward(input)

     def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
-        return self._unified_forward(x)
+        return self._unified_forward(input)

     def _unified_forward(self, x):
         x = self.fc1(x)
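With this method in place, the candidate workflow for SuperMLP follows the test changes at the end of this commit. A rough sketch assembled from those tests (not additional documented API):

import torch

mlp = super_core.SuperMLP(10, spaces.Categorical(12, 24, 36), spaces.Categorical(24, 36, 48))
space = mlp.abstract_search_space
space.clean_last()                          # drop any cached sample/abstract state
child = space.random(reuse_last=True)       # sample one candidate sub-network
mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
mlp.apply_candidate(child)                  # pushes the choice into fc1 / fc2
outputs = mlp(torch.rand(4, 10))            # output width follows child["fc2"]["_out_features"]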
@@ -32,7 +32,7 @@ class SuperModule(abc.ABC, nn.Module):

         self.apply(_reset_super_run)

-    def apply_candiate(self, abstract_child):
+    def apply_candidate(self, abstract_child):
         if not isinstance(abstract_child, spaces.VirtualNode):
             raise ValueError(
                 "Invalid abstract child program: {:}".format(abstract_child)
@@ -30,32 +30,37 @@ class TestSuperLinear(unittest.TestCase):
         print(model.super_run_type)
         self.assertTrue(model.bias)

-        inputs = torch.rand(32, 10)
+        inputs = torch.rand(20, 10)
         print("Input shape: {:}".format(inputs.shape))
         print("Weight shape: {:}".format(model._super_weight.shape))
         print("Bias shape: {:}".format(model._super_bias.shape))
         outputs = model(inputs)
-        self.assertEqual(tuple(outputs.shape), (32, 36))
+        self.assertEqual(tuple(outputs.shape), (20, 36))

         abstract_space = model.abstract_search_space
+        abstract_space.clean_last()
         abstract_child = abstract_space.random()
         print("The abstract searc space:\n{:}".format(abstract_space))
         print("The abstract child program:\n{:}".format(abstract_child))

         model.set_super_run_type(super_core.SuperRunMode.Candidate)
-        model.apply_candiate(abstract_child)
+        model.apply_candidate(abstract_child)

-        output_shape = (32, abstract_child["_out_features"].value)
+        output_shape = (20, abstract_child["_out_features"].value)
         outputs = model(inputs)
         self.assertEqual(tuple(outputs.shape), output_shape)

     def test_super_mlp(self):
         hidden_features = spaces.Categorical(12, 24, 36)
-        out_features = spaces.Categorical(12, 24, 36)
+        out_features = spaces.Categorical(24, 36, 48)
         mlp = super_core.SuperMLP(10, hidden_features, out_features)
         print(mlp)
         self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)

+        inputs = torch.rand(4, 10)
+        outputs = mlp(inputs)
+        self.assertEqual(tuple(outputs.shape), (4, 48))
+
         abstract_space = mlp.abstract_search_space
         print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space))
         self.assertEqual(
@@ -67,10 +72,16 @@ class TestSuperLinear(unittest.TestCase):
             is abstract_space["fc2"]["_in_features"]
         )

-        abstract_space.clean_last_sample()
+        abstract_space.clean_last()
         abstract_child = abstract_space.random(reuse_last=True)
         print("The abstract child program is:\n{:}".format(abstract_child))
         self.assertEqual(
             abstract_child["fc1"]["_out_features"].value,
             abstract_child["fc2"]["_in_features"].value,
         )

+        mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
+        mlp.apply_candidate(abstract_child)
+        outputs = mlp(inputs)
+        output_shape = (4, abstract_child["fc2"]["_out_features"].value)
+        self.assertEqual(tuple(outputs.shape), output_shape)