Update SuperMLP

D-X-Y 2021-03-19 23:57:23 +08:00
parent 31b8122cc1
commit 0c56a729ad
13 changed files with 412 additions and 85 deletions

View File

@ -35,6 +35,7 @@ jobs:
python -m black ./lib/xlayers -l 88 --check --diff --verbose
python -m black ./lib/spaces -l 88 --check --diff --verbose
python -m black ./lib/trade_models -l 88 --check --diff --verbose
python -m black ./lib/procedures -l 88 --check --diff --verbose
- name: Test Search Space
run: |

View File

@ -7,22 +7,63 @@ from log_utils import time_string
from utils import obtain_accuracy
def basic_train(xloader, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger):
def basic_train(
xloader,
network,
criterion,
scheduler,
optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
loss, acc1, acc5 = procedure(
xloader, network, criterion, scheduler, optimizer, "train", optim_config, extra_info, print_freq, logger
xloader,
network,
criterion,
scheduler,
optimizer,
"train",
optim_config,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def basic_valid(xloader, network, criterion, optim_config, extra_info, print_freq, logger):
def basic_valid(
xloader, network, criterion, optim_config, extra_info, print_freq, logger
):
with torch.no_grad():
loss, acc1, acc5 = procedure(
xloader, network, criterion, None, None, "valid", None, extra_info, print_freq, logger
xloader,
network,
criterion,
None,
None,
"valid",
None,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
def procedure(
xloader,
network,
criterion,
scheduler,
optimizer,
mode,
config,
extra_info,
print_freq,
logger,
):
data_time, batch_time, losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
@ -39,7 +80,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
# logger.log('[{:5s}] config :: auxiliary={:}, message={:}'.format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1, network.module.get_message()))
logger.log(
"[{:5s}] config :: auxiliary={:}".format(mode, config.auxiliary if hasattr(config, "auxiliary") else -1)
"[{:5s}] config :: auxiliary={:}".format(
mode, config.auxiliary if hasattr(config, "auxiliary") else -1
)
)
end = time.time()
for i, (inputs, targets) in enumerate(xloader):
@ -55,7 +98,9 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
features, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
2, len(logits)
)
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
@ -97,7 +142,12 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode, config, e
logger.log(
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
mode=mode.upper(),
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
return losses.avg, top1.avg, top5.avg
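
Note: the branch above unpacks network(inputs) into either a single logits tensor or a [logits, logits_aux] pair from an auxiliary head. The weighting of the auxiliary term falls outside the displayed hunk, so the sketch below uses the usual auxiliary-tower pattern with an assumed weight; treat the combination as illustrative, not as this file's exact formula.

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
aux_weight = 0.4  # assumed stand-in for config.auxiliary

def classification_loss(network_output, targets):
    # mirror the unpacking in the hunk above
    if isinstance(network_output, list):
        assert len(network_output) == 2, "expect [logits, logits_aux]"
        logits, logits_aux = network_output
    else:
        logits, logits_aux = network_output, None
    loss = criterion(logits, targets)
    if logits_aux is not None and aux_weight > 0:
        # assumed weighting; the exact combination in this file is not shown in the hunk
        loss = loss + aux_weight * criterion(logits_aux, targets)
    return loss

targets = torch.randint(0, 10, (8,))
print(float(classification_loss(torch.randn(8, 10), targets)))
print(float(classification_loss([torch.randn(8, 10), torch.randn(8, 10)], targets)))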

View File

@ -78,23 +78,42 @@ def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
return losses.avg, top1.avg, top5.avg, batch_time.sum
def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed: int, logger):
def evaluate_for_seed(
arch_config, opt_config, train_loader, valid_loaders, seed: int, logger
):
prepare_seed(seed) # random seed
net = get_cell_based_tiny_net(arch_config)
# net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
flop, param = get_model_infos(net, opt_config.xshape)
logger.log("Network : {:}".format(net.get_message()), False)
logger.log("{:} Seed-------------------------- {:} --------------------------".format(time_string(), seed))
logger.log(
"{:} Seed-------------------------- {:} --------------------------".format(
time_string(), seed
)
)
logger.log("FLOP = {:} MB, Param = {:} MB".format(flop, param))
# train and valid
optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), opt_config)
default_device = torch.cuda.current_device()
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(device=default_device)
network = torch.nn.DataParallel(net, device_ids=[default_device]).cuda(
device=default_device
)
criterion = criterion.cuda(device=default_device)
# start training
start_time, epoch_time, total_epoch = time.time(), AverageMeter(), opt_config.epochs + opt_config.warmup
train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
start_time, epoch_time, total_epoch = (
time.time(),
AverageMeter(),
opt_config.epochs + opt_config.warmup,
)
(
train_losses,
train_acc1es,
train_acc5es,
valid_losses,
valid_acc1es,
valid_acc5es,
) = ({}, {}, {}, {}, {}, {})
train_times, valid_times, lrs = {}, {}, {}
for epoch in range(total_epoch):
scheduler.update(epoch, 0.0)
@ -120,7 +139,9 @@ def evaluate_for_seed(arch_config, opt_config, train_loader, valid_loaders, seed
# measure elapsed time
epoch_time.update(time.time() - start_time)
start_time = time.time()
need_time = "Time Left: {:}".format(convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True))
need_time = "Time Left: {:}".format(
convert_secs2time(epoch_time.avg * (total_epoch - epoch - 1), True)
)
logger.log(
"{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%], lr={:}".format(
time_string(),
@ -171,14 +192,29 @@ def get_nas_bench_loaders(workers):
break_line = "-" * 150
print("{:} Create data-loader for all datasets".format(time_string()))
print(break_line)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets("cifar10", str(torch_dir / "cifar.python"), -1)
TRAIN_CIFAR10, VALID_CIFAR10, xshape, class_num = get_datasets(
"cifar10", str(torch_dir / "cifar.python"), -1
)
print(
"original CIFAR-10 : {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_CIFAR10), len(VALID_CIFAR10), xshape, class_num
)
)
cifar10_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None)
assert cifar10_splits.train[:10] == [0, 5, 7, 11, 13, 15, 16, 17, 20, 24] and cifar10_splits.valid[:10] == [
cifar10_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "cifar-split.txt", None, None
)
assert cifar10_splits.train[:10] == [
0,
5,
7,
11,
13,
15,
16,
17,
20,
24,
] and cifar10_splits.valid[:10] == [
1,
2,
3,
@ -194,7 +230,11 @@ def get_nas_bench_loaders(workers):
temp_dataset.transform = VALID_CIFAR10.transform
# data loader
trainval_cifar10_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR10, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
TRAIN_CIFAR10,
batch_size=cifar_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
train_cifar10_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR10,
@ -211,7 +251,11 @@ def get_nas_bench_loaders(workers):
pin_memory=True,
)
test__cifar10_loader = torch.utils.data.DataLoader(
VALID_CIFAR10, batch_size=cifar_config.batch_size, shuffle=False, num_workers=workers, pin_memory=True
VALID_CIFAR10,
batch_size=cifar_config.batch_size,
shuffle=False,
num_workers=workers,
pin_memory=True,
)
print(
"CIFAR-10 : trval-loader has {:3d} batch with {:} per batch".format(
@ -235,14 +279,29 @@ def get_nas_bench_loaders(workers):
)
print(break_line)
# CIFAR-100
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets("cifar100", str(torch_dir / "cifar.python"), -1)
TRAIN_CIFAR100, VALID_CIFAR100, xshape, class_num = get_datasets(
"cifar100", str(torch_dir / "cifar.python"), -1
)
print(
"original CIFAR-100: {:} training images and {:} test images : {:} input shape : {:} number of classes".format(
len(TRAIN_CIFAR100), len(VALID_CIFAR100), xshape, class_num
)
)
cifar100_splits = load_config(root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None)
assert cifar100_splits.xvalid[:10] == [1, 3, 4, 5, 8, 10, 13, 14, 15, 16] and cifar100_splits.xtest[:10] == [
cifar100_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "cifar100-test-split.txt", None, None
)
assert cifar100_splits.xvalid[:10] == [
1,
3,
4,
5,
8,
10,
13,
14,
15,
16,
] and cifar100_splits.xtest[:10] == [
0,
2,
6,
@ -255,7 +314,11 @@ def get_nas_bench_loaders(workers):
24,
]
train_cifar100_loader = torch.utils.data.DataLoader(
TRAIN_CIFAR100, batch_size=cifar_config.batch_size, shuffle=True, num_workers=workers, pin_memory=True
TRAIN_CIFAR100,
batch_size=cifar_config.batch_size,
shuffle=True,
num_workers=workers,
pin_memory=True,
)
valid_cifar100_loader = torch.utils.data.DataLoader(
VALID_CIFAR100,
@ -271,9 +334,15 @@ def get_nas_bench_loaders(workers):
num_workers=workers,
pin_memory=True,
)
print("CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader)))
print("CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader)))
print("CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader)))
print(
"CIFAR-100 : train-loader has {:3d} batch".format(len(train_cifar100_loader))
)
print(
"CIFAR-100 : valid-loader has {:3d} batch".format(len(valid_cifar100_loader))
)
print(
"CIFAR-100 : test--loader has {:3d} batch".format(len(test__cifar100_loader))
)
print(break_line)
imagenet16_config_path = "configs/nas-benchmark/ImageNet-16.config"
@ -286,8 +355,23 @@ def get_nas_bench_loaders(workers):
len(TRAIN_ImageNet16_120), len(VALID_ImageNet16_120), xshape, class_num
)
)
imagenet_splits = load_config(root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt", None, None)
assert imagenet_splits.xvalid[:10] == [1, 2, 3, 6, 7, 8, 9, 12, 16, 18] and imagenet_splits.xtest[:10] == [
imagenet_splits = load_config(
root_dir / "configs" / "nas-benchmark" / "imagenet-16-120-test-split.txt",
None,
None,
)
assert imagenet_splits.xvalid[:10] == [
1,
2,
3,
6,
7,
8,
9,
12,
16,
18,
] and imagenet_splits.xtest[:10] == [
0,
4,
5,

View File

@ -14,7 +14,9 @@ class _LRScheduler(object):
self.optimizer = optimizer
for group in optimizer.param_groups:
group.setdefault("initial_lr", group["lr"])
self.base_lrs = list(map(lambda group: group["initial_lr"], optimizer.param_groups))
self.base_lrs = list(
map(lambda group: group["initial_lr"], optimizer.param_groups)
)
self.max_epochs = epochs
self.warmup_epochs = warmup_epochs
self.current_epoch = 0
@ -31,7 +33,9 @@ class _LRScheduler(object):
)
def state_dict(self):
return {key: value for key, value in self.__dict__.items() if key != "optimizer"}
return {
key: value for key, value in self.__dict__.items() if key != "optimizer"
}
def load_state_dict(self, state_dict):
self.__dict__.update(state_dict)
@ -50,10 +54,14 @@ class _LRScheduler(object):
def update(self, cur_epoch, cur_iter):
if cur_epoch is not None:
assert isinstance(cur_epoch, int) and cur_epoch >= 0, "invalid cur-epoch : {:}".format(cur_epoch)
assert (
isinstance(cur_epoch, int) and cur_epoch >= 0
), "invalid cur-epoch : {:}".format(cur_epoch)
self.current_epoch = cur_epoch
if cur_iter is not None:
assert isinstance(cur_iter, float) and cur_iter >= 0, "invalid cur-iter : {:}".format(cur_iter)
assert (
isinstance(cur_iter, float) and cur_iter >= 0
), "invalid cur-iter : {:}".format(cur_iter)
self.current_iter = cur_iter
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group["lr"] = lr
@ -66,29 +74,44 @@ class CosineAnnealingLR(_LRScheduler):
super(CosineAnnealingLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return "type={:}, T-max={:}, eta-min={:}".format("cosine", self.T_max, self.eta_min)
return "type={:}, T-max={:}, eta-min={:}".format(
"cosine", self.T_max, self.eta_min
)
def get_lr(self):
lrs = []
for base_lr in self.base_lrs:
if self.current_epoch >= self.warmup_epochs and self.current_epoch < self.max_epochs:
if (
self.current_epoch >= self.warmup_epochs
and self.current_epoch < self.max_epochs
):
last_epoch = self.current_epoch - self.warmup_epochs
# if last_epoch < self.T_max:
# if last_epoch < self.max_epochs:
lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * last_epoch / self.T_max)) / 2
lr = (
self.eta_min
+ (base_lr - self.eta_min)
* (1 + math.cos(math.pi * last_epoch / self.T_max))
/ 2
)
# else:
# lr = self.eta_min + (base_lr - self.eta_min) * (1 + math.cos(math.pi * (self.T_max-1.0) / self.T_max)) / 2
elif self.current_epoch >= self.max_epochs:
lr = self.eta_min
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
class MultiStepLR(_LRScheduler):
def __init__(self, optimizer, warmup_epochs, epochs, milestones, gammas):
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(len(milestones), len(gammas))
assert len(milestones) == len(gammas), "invalid {:} vs {:}".format(
len(milestones), len(gammas)
)
self.milestones = milestones
self.gammas = gammas
super(MultiStepLR, self).__init__(optimizer, warmup_epochs, epochs)
@ -108,7 +131,10 @@ class MultiStepLR(_LRScheduler):
for x in self.gammas[:idx]:
lr *= x
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
@ -119,7 +145,9 @@ class ExponentialLR(_LRScheduler):
super(ExponentialLR, self).__init__(optimizer, warmup_epochs, epochs)
def extra_repr(self):
return "type={:}, gamma={:}, base-lrs={:}".format("exponential", self.gamma, self.base_lrs)
return "type={:}, gamma={:}, base-lrs={:}".format(
"exponential", self.gamma, self.base_lrs
)
def get_lr(self):
lrs = []
@ -129,7 +157,10 @@ class ExponentialLR(_LRScheduler):
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
lr = base_lr * (self.gamma ** last_epoch)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
@ -151,10 +182,18 @@ class LinearLR(_LRScheduler):
if self.current_epoch >= self.warmup_epochs:
last_epoch = self.current_epoch - self.warmup_epochs
assert last_epoch >= 0, "invalid last_epoch : {:}".format(last_epoch)
ratio = (self.max_LR - self.min_LR) * last_epoch / self.max_epochs / self.max_LR
ratio = (
(self.max_LR - self.min_LR)
* last_epoch
/ self.max_epochs
/ self.max_LR
)
lr = base_lr * (1 - ratio)
else:
lr = (self.current_epoch / self.warmup_epochs + self.current_iter / self.warmup_epochs) * base_lr
lr = (
self.current_epoch / self.warmup_epochs
+ self.current_iter / self.warmup_epochs
) * base_lr
lrs.append(lr)
return lrs
@ -176,26 +215,42 @@ class CrossEntropyLabelSmooth(nn.Module):
def get_optim_scheduler(parameters, config):
assert (
hasattr(config, "optim") and hasattr(config, "scheduler") and hasattr(config, "criterion")
), "config must have optim / scheduler / criterion keys instead of {:}".format(config)
hasattr(config, "optim")
and hasattr(config, "scheduler")
and hasattr(config, "criterion")
), "config must have optim / scheduler / criterion keys instead of {:}".format(
config
)
if config.optim == "SGD":
optim = torch.optim.SGD(
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay, nesterov=config.nesterov
parameters,
config.LR,
momentum=config.momentum,
weight_decay=config.decay,
nesterov=config.nesterov,
)
elif config.optim == "RMSprop":
optim = torch.optim.RMSprop(parameters, config.LR, momentum=config.momentum, weight_decay=config.decay)
optim = torch.optim.RMSprop(
parameters, config.LR, momentum=config.momentum, weight_decay=config.decay
)
else:
raise ValueError("invalid optim : {:}".format(config.optim))
if config.scheduler == "cos":
T_max = getattr(config, "T_max", config.epochs)
scheduler = CosineAnnealingLR(optim, config.warmup, config.epochs, T_max, config.eta_min)
scheduler = CosineAnnealingLR(
optim, config.warmup, config.epochs, T_max, config.eta_min
)
elif config.scheduler == "multistep":
scheduler = MultiStepLR(optim, config.warmup, config.epochs, config.milestones, config.gammas)
scheduler = MultiStepLR(
optim, config.warmup, config.epochs, config.milestones, config.gammas
)
elif config.scheduler == "exponential":
scheduler = ExponentialLR(optim, config.warmup, config.epochs, config.gamma)
elif config.scheduler == "linear":
scheduler = LinearLR(optim, config.warmup, config.epochs, config.LR, config.LR_min)
scheduler = LinearLR(
optim, config.warmup, config.epochs, config.LR, config.LR_min
)
else:
raise ValueError("invalid scheduler : {:}".format(config.scheduler))

View File

@ -41,7 +41,10 @@ def update_gpu(config, gpu):
if "task" in config and "model" in config["task"]:
if "GPU" in config["task"]["model"]:
config["task"]["model"]["GPU"] = gpu
elif "kwargs" in config["task"]["model"] and "GPU" in config["task"]["model"]["kwargs"]:
elif (
"kwargs" in config["task"]["model"]
and "GPU" in config["task"]["model"]["kwargs"]
):
config["task"]["model"]["kwargs"]["GPU"] = gpu
elif "model" in config:
if "GPU" in config["model"]:
@ -68,7 +71,12 @@ def run_exp(task_config, dataset, experiment_name, recorder_name, uri):
model_fit_kwargs = dict(dataset=dataset)
# Let's start the experiment.
with R.start(experiment_name=experiment_name, recorder_name=recorder_name, uri=uri, resume=True):
with R.start(
experiment_name=experiment_name,
recorder_name=recorder_name,
uri=uri,
resume=True,
):
# Setup log
recorder_root_dir = R.get_recorder().get_local_dir()
log_file = os.path.join(recorder_root_dir, "{:}.log".format(experiment_name))

View File

@ -36,7 +36,12 @@ def search_train(
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = (
extra_info["epoch-str"],
@ -46,10 +51,16 @@ def search_train(
)
network.train()
logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
logger.log(
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
epoch_str, flop_need, flop_weight
)
)
end = time.time()
network.apply(change_key("search_mode", "search"))
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
search_loader
):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
@ -75,7 +86,9 @@ def search_train(
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop("genotype", None, None)
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
flop_loss, flop_loss_scale = get_flop_loss(
expected_flop, flop_cur, flop_need, flop_tolerant
)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
@ -90,7 +103,11 @@ def search_train(
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step + 1) == len(search_loader):
Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
Sstr = (
"**TRAIN** "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
@ -153,7 +170,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
end = time.time()
if i % print_freq == 0 or (i + 1) == len(xloader):
Sstr = "**VALID** " + time_string() + " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
Sstr = (
"**VALID** "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(extra_info, i, len(xloader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)
@ -165,7 +186,11 @@ def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
logger.log(
" **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
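
Note: the architecture step above folds a FLOP penalty into the classification loss, arch_loss = acls_loss + flop_loss * flop_weight, with get_flop_loss comparing a differentiable expected-FLOP estimate against a budget and tolerance. Below is a minimal sketch of that pattern using a hypothetical hinge-style penalty; it is not the repository's get_flop_loss, only an illustration of how the budget term enters the loss.

import torch

def hinge_flop_penalty(expected_flop, flop_need, flop_tolerant):
    # hypothetical stand-in: zero inside the tolerance band, linear beyond it
    excess = expected_flop - (flop_need + flop_tolerant)
    return torch.clamp(excess, min=0.0) / flop_need

# illustrative numbers (MB), not taken from any config in the repository
expected_flop = torch.tensor(310.0, requires_grad=True)  # differentiable w.r.t. architecture params
flop_need, flop_tolerant, flop_weight = 280.0, 10.0, 2.0

acls_loss = torch.tensor(1.35)  # stands in for criterion(logits, arch_targets)
flop_loss = hinge_flop_penalty(expected_flop, flop_need, flop_tolerant)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()  # the budget term's gradient reaches the architecture parameters
print(float(flop_loss), float(arch_loss), float(expected_flop.grad))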

View File

@ -36,7 +36,12 @@ def search_train_v2(
logger,
):
data_time, batch_time = AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
base_losses, arch_losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
AverageMeter(),
AverageMeter(),
)
arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
epoch_str, flop_need, flop_weight, flop_tolerant = (
extra_info["epoch-str"],
@ -46,10 +51,16 @@ def search_train_v2(
)
network.train()
logger.log("[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(epoch_str, flop_need, flop_weight))
logger.log(
"[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}".format(
epoch_str, flop_need, flop_weight
)
)
end = time.time()
network.apply(change_key("search_mode", "search"))
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(
search_loader
):
scheduler.update(None, 1.0 * step / len(search_loader))
# calculate prediction and loss
base_targets = base_targets.cuda(non_blocking=True)
@ -73,7 +84,9 @@ def search_train_v2(
arch_optimizer.zero_grad()
logits, expected_flop = network(arch_inputs)
flop_cur = network.module.get_flop("genotype", None, None)
flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant)
flop_loss, flop_loss_scale = get_flop_loss(
expected_flop, flop_cur, flop_need, flop_tolerant
)
acls_loss = criterion(logits, arch_targets)
arch_loss = acls_loss + flop_loss * flop_weight
arch_loss.backward()
@ -88,7 +101,11 @@ def search_train_v2(
batch_time.update(time.time() - end)
end = time.time()
if step % print_freq == 0 or (step + 1) == len(search_loader):
Sstr = "**TRAIN** " + time_string() + " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
Sstr = (
"**TRAIN** "
+ time_string()
+ " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(search_loader))
)
Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
batch_time=batch_time, data_time=data_time
)

View File

@ -10,7 +10,16 @@ from utils import obtain_accuracy
def simple_KD_train(
xloader, teacher, network, criterion, scheduler, optimizer, optim_config, extra_info, print_freq, logger
xloader,
teacher,
network,
criterion,
scheduler,
optimizer,
optim_config,
extra_info,
print_freq,
logger,
):
loss, acc1, acc5 = procedure(
xloader,
@ -28,25 +37,58 @@ def simple_KD_train(
return loss, acc1, acc5
def simple_KD_valid(xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger):
def simple_KD_valid(
xloader, teacher, network, criterion, optim_config, extra_info, print_freq, logger
):
with torch.no_grad():
loss, acc1, acc5 = procedure(
xloader, teacher, network, criterion, None, None, "valid", optim_config, extra_info, print_freq, logger
xloader,
teacher,
network,
criterion,
None,
None,
"valid",
optim_config,
extra_info,
print_freq,
logger,
)
return loss, acc1, acc5
def loss_KD_fn(
criterion, student_logits, teacher_logits, studentFeatures, teacherFeatures, targets, alpha, temperature
criterion,
student_logits,
teacher_logits,
studentFeatures,
teacherFeatures,
targets,
alpha,
temperature,
):
basic_loss = criterion(student_logits, targets) * (1.0 - alpha)
log_student = F.log_softmax(student_logits / temperature, dim=1)
sof_teacher = F.softmax(teacher_logits / temperature, dim=1)
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (alpha * temperature * temperature)
KD_loss = F.kl_div(log_student, sof_teacher, reduction="batchmean") * (
alpha * temperature * temperature
)
return basic_loss + KD_loss
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode, config, extra_info, print_freq, logger):
def procedure(
xloader,
teacher,
network,
criterion,
scheduler,
optimizer,
mode,
config,
extra_info,
print_freq,
logger,
):
data_time, batch_time, losses, top1, top5 = (
AverageMeter(),
AverageMeter(),
@ -65,7 +107,10 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
logger.log(
"[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]".format(
mode, config.auxiliary if hasattr(config, "auxiliary") else -1, config.KD_alpha, config.KD_temperature
mode,
config.auxiliary if hasattr(config, "auxiliary") else -1,
config.KD_alpha,
config.KD_temperature,
)
)
end = time.time()
@ -82,7 +127,9 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
student_f, logits = network(inputs)
if isinstance(logits, list):
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(2, len(logits))
assert len(logits) == 2, "logits must has {:} items instead of {:}".format(
2, len(logits)
)
logits, logits_aux = logits
else:
logits, logits_aux = logits, None
@ -90,7 +137,14 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
teacher_f, teacher_logits = teacher(inputs)
loss = loss_KD_fn(
criterion, logits, teacher_logits, student_f, teacher_f, targets, config.KD_alpha, config.KD_temperature
criterion,
logits,
teacher_logits,
student_f,
teacher_f,
targets,
config.KD_alpha,
config.KD_temperature,
)
if config is not None and hasattr(config, "auxiliary") and config.auxiliary > 0:
loss_aux = criterion(logits_aux, targets)
@ -139,7 +193,12 @@ def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
)
logger.log(
" **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}".format(
mode=mode.upper(), top1=top1, top5=top5, error1=100 - top1.avg, error5=100 - top5.avg, loss=losses.avg
mode=mode.upper(),
top1=top1,
top5=top5,
error1=100 - top1.avg,
error5=100 - top5.avg,
loss=losses.avg,
)
)
return losses.avg, top1.avg, top5.avg
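
Note: loss_KD_fn above is the standard distillation loss, hard-label cross-entropy weighted by (1 - alpha) plus a temperature-softened KL term weighted by alpha * T^2, where the T^2 factor restores the gradient scale that dividing the logits by T removes. A self-contained check with random tensors; the alpha and temperature values here are illustrative, the real ones come from config.KD_alpha and config.KD_temperature.

import torch
import torch.nn.functional as F

def kd_loss(student_logits, teacher_logits, targets, alpha=0.9, temperature=4.0):
    # same formula as loss_KD_fn, with cross_entropy as the hard-label criterion
    basic_loss = F.cross_entropy(student_logits, targets) * (1.0 - alpha)
    log_student = F.log_softmax(student_logits / temperature, dim=1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
    kd = F.kl_div(log_student, soft_teacher, reduction="batchmean") * (alpha * temperature * temperature)
    return basic_loss + kd

student = torch.randn(16, 100, requires_grad=True)
teacher = torch.randn(16, 100)
targets = torch.randint(0, 100, (16,))
loss = kd_loss(student, teacher, targets)
loss.backward()
print(float(loss))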

View File

@ -31,7 +31,9 @@ def prepare_logger(xargs):
logger.log("CUDA GPU numbers : {:}".format(torch.cuda.device_count()))
logger.log(
"CUDA_VISIBLE_DEVICES : {:}".format(
os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "None"
os.environ["CUDA_VISIBLE_DEVICES"]
if "CUDA_VISIBLE_DEVICES" in os.environ
else "None"
)
)
return logger
@ -54,10 +56,14 @@ def get_machine_info():
def save_checkpoint(state, filename, logger):
if osp.isfile(filename):
if hasattr(logger, "log"):
logger.log("Find {:} exist, delete is at first before saving".format(filename))
logger.log(
"Find {:} exist, delete is at first before saving".format(filename)
)
os.remove(filename)
torch.save(state, filename)
assert osp.isfile(filename), "save filename : {:} failed, which is not found.".format(filename)
assert osp.isfile(
filename
), "save filename : {:} failed, which is not found.".format(filename)
if hasattr(logger, "log"):
logger.log("save checkpoint into {:}".format(filename))
return filename

View File

@ -50,6 +50,10 @@ class Space(metaclass=abc.ABCMeta):
def clean_last_abstract(self):
raise NotImplementedError
def clean_last(self):
self.clean_last_sample()
self.clean_last_abstract()
@abc.abstractproperty
def determined(self) -> bool:
raise NotImplementedError
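
Note: clean_last() added above is a convenience that forgets both the cached sample and the cached abstract child in one call; the test changes later in this commit use it right before random(reuse_last=True). A toy stand-in (not the repository's Space/Categorical classes) showing why the cached sample matters:

import random

class TinySpace:
    # toy stand-in for the Space contract above, not the repository's implementation
    def __init__(self, *candidates):
        self.candidates = list(candidates)
        self._last_sample = None
        self._last_abstract = None

    def random(self, reuse_last=False):
        if reuse_last and self._last_sample is not None:
            return self._last_sample
        self._last_sample = random.choice(self.candidates)
        return self._last_sample

    def clean_last_sample(self):
        self._last_sample = None

    def clean_last_abstract(self):
        self._last_abstract = None

    def clean_last(self):  # the convenience wrapper introduced above
        self.clean_last_sample()
        self.clean_last_abstract()

space = TinySpace(12, 24, 36)
space.clean_last()                     # drop any earlier draw
first = space.random(reuse_last=True)
print(first, space.random(reuse_last=True) == first)  # reuse keeps dependent dims consistent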

View File

@ -147,11 +147,18 @@ class SuperMLP(SuperModule):
root_node.append("fc2", space_fc2)
return root_node
def apply_candidate(self, abstract_child: spaces.VirtualNode):
super(SuperMLP, self).apply_candidate(abstract_child)
if "fc1" in abstract_child:
self.fc1.apply_candidate(abstract_child["fc1"])
if "fc2" in abstract_child:
self.fc2.apply_candidate(abstract_child["fc2"])
def forward_candidate(self, input: torch.Tensor) -> torch.Tensor:
return self._unified_forward(x)
return self._unified_forward(input)
def forward_raw(self, input: torch.Tensor) -> torch.Tensor:
return self._unified_forward(x)
return self._unified_forward(input)
def _unified_forward(self, x):
x = self.fc1(x)
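
Note: two things change in SuperMLP above. forward_candidate and forward_raw now pass their real argument to _unified_forward (previously they referenced an undefined name x, which would raise NameError on those paths), and the new apply_candidate forwards the fc1/fc2 sub-programs of the abstract child to the child layers. A toy sketch of that recursive pattern; ToyCandidateLinear and ToyMLP are stand-ins, not the repository's SuperLinear/SuperMLP.

import torch
import torch.nn as nn

class ToyCandidateLinear(nn.Linear):
    # stand-in for SuperLinear: just records the sub-program it was given
    def apply_candidate(self, sub_child):
        self.sub_child = sub_child

class ToyMLP(nn.Module):
    def __init__(self, in_dim, hidden, out_dim):
        super().__init__()
        self.fc1 = ToyCandidateLinear(in_dim, hidden)
        self.fc2 = ToyCandidateLinear(hidden, out_dim)

    def apply_candidate(self, abstract_child):
        # recurse into sub-layers, mirroring SuperMLP.apply_candidate above
        if "fc1" in abstract_child:
            self.fc1.apply_candidate(abstract_child["fc1"])
        if "fc2" in abstract_child:
            self.fc2.apply_candidate(abstract_child["fc2"])

    def forward(self, input):
        return self._unified_forward(input)  # pass the actual argument, not an undefined x

    def _unified_forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

mlp = ToyMLP(10, 24, 36)
mlp.apply_candidate({"fc1": {"_out_features": 24}, "fc2": {"_out_features": 36}})
print(mlp(torch.rand(4, 10)).shape)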

View File

@ -32,7 +32,7 @@ class SuperModule(abc.ABC, nn.Module):
self.apply(_reset_super_run)
def apply_candiate(self, abstract_child):
def apply_candidate(self, abstract_child):
if not isinstance(abstract_child, spaces.VirtualNode):
raise ValueError(
"Invalid abstract child program: {:}".format(abstract_child)

View File

@ -30,32 +30,37 @@ class TestSuperLinear(unittest.TestCase):
print(model.super_run_type)
self.assertTrue(model.bias)
inputs = torch.rand(32, 10)
inputs = torch.rand(20, 10)
print("Input shape: {:}".format(inputs.shape))
print("Weight shape: {:}".format(model._super_weight.shape))
print("Bias shape: {:}".format(model._super_bias.shape))
outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), (32, 36))
self.assertEqual(tuple(outputs.shape), (20, 36))
abstract_space = model.abstract_search_space
abstract_space.clean_last()
abstract_child = abstract_space.random()
print("The abstract searc space:\n{:}".format(abstract_space))
print("The abstract child program:\n{:}".format(abstract_child))
model.set_super_run_type(super_core.SuperRunMode.Candidate)
model.apply_candiate(abstract_child)
model.apply_candidate(abstract_child)
output_shape = (32, abstract_child["_out_features"].value)
output_shape = (20, abstract_child["_out_features"].value)
outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), output_shape)
def test_super_mlp(self):
hidden_features = spaces.Categorical(12, 24, 36)
out_features = spaces.Categorical(12, 24, 36)
out_features = spaces.Categorical(24, 36, 48)
mlp = super_core.SuperMLP(10, hidden_features, out_features)
print(mlp)
self.assertTrue(mlp.fc1._out_features, mlp.fc2._in_features)
inputs = torch.rand(4, 10)
outputs = mlp(inputs)
self.assertEqual(tuple(outputs.shape), (4, 48))
abstract_space = mlp.abstract_search_space
print("The abstract search space for SuperMLP is:\n{:}".format(abstract_space))
self.assertEqual(
@ -67,10 +72,16 @@ class TestSuperLinear(unittest.TestCase):
is abstract_space["fc2"]["_in_features"]
)
abstract_space.clean_last_sample()
abstract_space.clean_last()
abstract_child = abstract_space.random(reuse_last=True)
print("The abstract child program is:\n{:}".format(abstract_child))
self.assertEqual(
abstract_child["fc1"]["_out_features"].value,
abstract_child["fc2"]["_in_features"].value,
)
mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
mlp.apply_candidate(abstract_child)
outputs = mlp(inputs)
output_shape = (4, abstract_child["fc2"]["_out_features"].value)
self.assertEqual(tuple(outputs.shape), output_shape)
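
Note: condensed, the flow this test exercises looks like the sketch below. The import lines are assumptions based on the library layout (lib/spaces and lib/xlayers) and are not shown in the hunk; adjust them to however the test module sets up its path.

import torch
import spaces                      # assumed: lib/spaces on sys.path
from xlayers import super_core     # assumed: lib/xlayers on sys.path

hidden_features = spaces.Categorical(12, 24, 36)
out_features = spaces.Categorical(24, 36, 48)
mlp = super_core.SuperMLP(10, hidden_features, out_features)

inputs = torch.rand(4, 10)
print(mlp(inputs).shape)                              # full super-network: (4, 48)

space = mlp.abstract_search_space
space.clean_last()                                    # forget earlier draws
child = space.random(reuse_last=True)                 # fc1 out dim == fc2 in dim

mlp.set_super_run_type(super_core.SuperRunMode.Candidate)
mlp.apply_candidate(child)                            # renamed from apply_candiate in this commit
outputs = mlp(inputs)
print(outputs.shape)                                  # (4, child["fc2"]["_out_features"].value)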