Fix small bugs
This commit is contained in:
parent
3b1d8f1e4b
commit
36bb07ef1a
4
.gitignore
vendored
4
.gitignore
vendored
@ -103,3 +103,7 @@ main_main.py
|
|||||||
scripts-nas/.nfs00*
|
scripts-nas/.nfs00*
|
||||||
*/.nfs00*
|
*/.nfs00*
|
||||||
*.DS_Store
|
*.DS_Store
|
||||||
|
|
||||||
|
# logs and snapshots
|
||||||
|
output
|
||||||
|
logs
|
||||||
|
@ -108,11 +108,11 @@ def main_procedure_imagenet(config, data_path, args, genotype, init_channels, la
|
|||||||
for epoch in range(start_epoch, config.epochs):
|
for epoch in range(start_epoch, config.epochs):
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
need_time = convert_secs2time(epoch_time.val * (config.epochs-epoch), True)
|
|
||||||
print_log("\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} LR={:6.4f} ~ {:6.4f}, Batch={:d}".format(time_string(), epoch, config.epochs, need_time, min(scheduler.get_lr()), max(scheduler.get_lr()), config.batch_size), log)
|
|
||||||
|
|
||||||
basemodel.update_drop_path(config.drop_path_prob * epoch / config.epochs)
|
basemodel.update_drop_path(config.drop_path_prob * epoch / config.epochs)
|
||||||
|
|
||||||
|
need_time = convert_secs2time(epoch_time.val * (config.epochs-epoch), True)
|
||||||
|
print_log("\n==>>{:s} [Epoch={:03d}/{:03d}] {:s} LR={:6.4f} ~ {:6.4f}, Batch={:d}, Drop-Path-Prob={:}".format(time_string(), epoch, config.epochs, need_time, min(scheduler.get_lr()), max(scheduler.get_lr()), config.batch_size, basemodel.get_drop_path()), log)
|
||||||
|
|
||||||
train_acc1, train_acc5, train_los = _train(train_queue, model, criterion_smooth, optimizer, 'train', epoch, config, args.print_freq, log)
|
train_acc1, train_acc5, train_los = _train(train_queue, model, criterion_smooth, optimizer, 'train', epoch, config, args.print_freq, log)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -60,14 +60,14 @@ def get_datasets(name, root, cutout):
|
|||||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||||
|
|
||||||
if name == 'cifar10':
|
if name == 'cifar10':
|
||||||
train_data = dset.CIFAR10(root, train=True , transform=train_transform, download=True)
|
train_data = dset.CIFAR10 (root, train=True , transform=train_transform, download=True)
|
||||||
test_data = dset.CIFAR10(root, train=False, transform=test_transform , download=True)
|
test_data = dset.CIFAR10 (root, train=False, transform=test_transform , download=True)
|
||||||
elif name == 'cifar100':
|
elif name == 'cifar100':
|
||||||
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
train_data = dset.CIFAR100(root, train=True , transform=train_transform, download=True)
|
||||||
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
test_data = dset.CIFAR100(root, train=False, transform=test_transform , download=True)
|
||||||
elif name == 'imagenet-1k' or name == 'imagenet-100':
|
elif name == 'imagenet-1k' or name == 'imagenet-100':
|
||||||
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
|
||||||
test_data = dset.ImageFolder(osp.join(root, 'val'), train_transform)
|
test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
|
||||||
else: raise TypeError("Unknow dataset : {:}".format(name))
|
else: raise TypeError("Unknow dataset : {:}".format(name))
|
||||||
|
|
||||||
class_num = Dataset2Class[name]
|
class_num = Dataset2Class[name]
|
||||||
|
@ -80,6 +80,9 @@ class NetworkImageNet(nn.Module):
|
|||||||
def update_drop_path(self, drop_path_prob):
|
def update_drop_path(self, drop_path_prob):
|
||||||
self.drop_path_prob = drop_path_prob
|
self.drop_path_prob = drop_path_prob
|
||||||
|
|
||||||
|
def get_drop_path(self):
|
||||||
|
return self.drop_path_prob
|
||||||
|
|
||||||
def auxiliary_param(self):
|
def auxiliary_param(self):
|
||||||
if self.auxiliary_head is None: return []
|
if self.auxiliary_head is None: return []
|
||||||
else: return list( self.auxiliary_head.parameters() )
|
else: return list( self.auxiliary_head.parameters() )
|
||||||
|
Loading…
Reference in New Issue
Block a user