first commit

100	MobileNetV3/main_exp/transfer_nag_lib/DeepKernelGPHelpers.py	(new file)
@@ -0,0 +1,100 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  6 14:02:53 2021

@author: hsjomaa
"""
import numpy as np
from scipy.stats import norm
import pandas as pd
from torch import autograd as ag
import torch
from sklearn.preprocessing import PowerTransformer


def regret(output, response):
    # track the running incumbent (best observation so far)
    incumbent = output[0]
    best_output = []
    for value in output:
        incumbent = value if value > incumbent else incumbent
        best_output.append(incumbent)
    opt = max(response)
    orde = list(np.sort(np.unique(response))[::-1])
    tmp = pd.DataFrame(best_output, columns=['regret_validation'])

    tmp['rank_valid'] = tmp['regret_validation'].map(lambda x: orde.index(x))
    tmp['regret_validation'] = opt - tmp['regret_validation']
    return tmp


def EI(incumbent, model_fn, support, queries, return_variance, return_score=False):
    # Expected Improvement acquisition over the candidate pool
    mu, stddev = model_fn(queries)
    mu = mu.reshape(-1,)
    stddev = stddev.reshape(-1,)
    if return_variance:
        # the model returned a variance, so convert it to a standard deviation
        stddev = np.sqrt(stddev)
    with np.errstate(divide='warn'):
        imp = mu - incumbent
        Z = imp / stddev
        score = imp * norm.cdf(Z) + stddev * norm.pdf(Z)
        if not return_score:
            # never re-select points that are already in the support set
            score[support] = 0
            return np.argmax(score)
        else:
            return score


class Metric(object):
    def __init__(self, prefix='train: '):
        self.reset()
        self.message = prefix + "loss: {loss:.2f} - noise: {log_var:.2f} - mse: {mse:.2f}"

    def update(self, loss, noise, mse):
        # np.asscalar was removed from NumPy; float() handles numpy scalars
        # and one-element tensors alike
        self.loss.append(float(loss))
        self.noise.append(float(noise))
        self.mse.append(float(mse))

    def reset(self,):
        self.loss = []
        self.noise = []
        self.mse = []

    def report(self):
        return self.message.format(loss=np.mean(self.loss),
                                   log_var=np.mean(self.noise),
                                   mse=np.mean(self.mse))

    def get(self):
        return {"loss": np.mean(self.loss),
                "noise": np.mean(self.noise),
                "mse": np.mean(self.mse)}
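For orientation, a minimal sketch of how `Metric` is driven; the numbers are made up, and `update` accepts anything `float()` can convert, including one-element tensors:

```python
# Hypothetical usage of Metric; the values are illustrative only.
m = Metric(prefix='train: ')
for loss, noise, mse in [(0.9, 0.12, 0.50), (0.7, 0.08, 0.40)]:
    m.update(loss, noise, mse)
print(m.report())  # train: loss: 0.80 - noise: 0.10 - mse: 0.45
print(m.get())     # {'loss': 0.8, 'noise': 0.1, 'mse': 0.45}
m.reset()
```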
def totorch(x, device):
    # note: torch.autograd.Variable is a deprecated no-op wrapper in modern
    # PyTorch; plain tensors already carry autograd state
    if type(x) is tuple:
        return tuple([ag.Variable(torch.Tensor(e)).to(device) for e in x])
    return torch.Tensor(x).to(device)


def prepare_data(indexes, support, Lambda, response, metafeatures=None, output_transform=False):
    # generate the (support, query) arrays for one batch
    X, E, Z, y, r = [], [], [], [], []
    #### get support data
    for dim in indexes:
        if metafeatures is not None:
            Z.append(metafeatures)
        E.append(Lambda[support])
        X.append(Lambda[dim])
        r_ = response[support, np.newaxis]
        y_ = response[dim]
        if output_transform:
            power = PowerTransformer(method="yeo-johnson")
            r_ = power.fit_transform(r_)
            y_ = power.transform(y_.reshape(-1, 1)).reshape(-1,)
        r.append(r_)
        y.append(y_)
    X = np.array(X)
    E = np.array(E)
    Z = np.array(Z)
    y = np.array(y)
    r = np.array(r)
    return (np.expand_dims(E, axis=-1), r, np.expand_dims(X, axis=-1), Z), y
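A minimal sketch of how `EI` is intended to be called. The surrogate here is a stand-in `toy_model_fn` invented for illustration; in the real pipeline `model_fn` would wrap `DeepKernelGP.predict` from the modules file below.

```python
import numpy as np

def toy_model_fn(queries):
    # stand-in surrogate: fabricated posterior mean and stddev per candidate
    return np.array([0.2, 0.5, 0.9, 0.4]), np.array([0.3, 0.2, 0.1, 0.4])

incumbent = 0.5         # best observed response so far
support = [1]           # indices already evaluated get their score zeroed out
queries = np.arange(4)  # candidate pool

# return_variance=False because toy_model_fn already returns a stddev
next_idx = EI(incumbent, toy_model_fn, support, queries, return_variance=False)
print(next_idx)         # 2, the candidate with the highest expected improvement
```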
581	MobileNetV3/main_exp/transfer_nag_lib/DeepKernelGPModules.py	(new file)
@@ -0,0 +1,581 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jul  6 14:03:42 2021

@author: hsjomaa
"""
## Original packages
import torch
import torch.nn as nn
from sklearn.preprocessing import MinMaxScaler
import copy
import numpy as np
import os
from torch.utils.tensorboard import SummaryWriter  # required by setup_writers below
import json
import time
## Our packages
import gpytorch
import logging
from transfer_nag_lib.DeepKernelGPHelpers import totorch, prepare_data, Metric, EI
from transfer_nag_lib.MetaD2A_nas_bench_201.generator import Generator
from transfer_nag_lib.MetaD2A_nas_bench_201.main import get_parser

np.random.seed(1203)
RandomQueryGenerator = np.random.RandomState(413)
RandomSupportGenerator = np.random.RandomState(413)
RandomTaskGenerator = np.random.RandomState(413)


class DeepKernelGP(nn.Module):

    def __init__(self, X, Y, Z, kernel, backbone_fn, config, support, log_dir, seed):
        super(DeepKernelGP, self).__init__()
        torch.manual_seed(seed)
        ## GP parameters
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.X, self.Y, self.Z = X, Y, Z
        self.feature_extractor = backbone_fn().to(self.device)
        self.config = config
        self.get_model_likelihood_mll(len(support), kernel, backbone_fn)

        logging.basicConfig(filename=log_dir, level=logging.DEBUG)

    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):
        train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
        train_y = torch.ones(train_size).to(self.device)

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,
                             dims=self.feature_extractor.out_features)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def train(self, support, load_model, optimizer, checkpoint=None, epochs=1000, verbose=False):
        if load_model:
            assert checkpoint is not None
            self.load_checkpoint(os.path.join(checkpoint, "weights"))
            print("KEYS MATCHED")

        inputs, labels = prepare_data(support, support, self.X, self.Y, self.Z)
        inputs, labels = totorch(inputs, device=self.device), totorch(labels.reshape(-1,), device=self.device)
        losses = [np.inf]
        best_loss = np.inf
        starttime = time.time()
        initial_weights = copy.deepcopy(self.state_dict())
        weights = initial_weights  # fallback in case training exits before any improvement
        patience = 0
        max_patience = self.config["patience"]
        for _ in range(epochs):
            optimizer.zero_grad()
            z = self.feature_extractor(inputs)
            self.model.set_train_data(inputs=z, targets=labels)
            predictions = self.model(z)
            try:
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
            except Exception as ada:
                logging.info(f"Exception {ada}")
                break

            if verbose:
                print("Iter {iter}/{epochs} - Loss: {loss:.5f}   noise: {noise:.5f}".format(
                    iter=_ + 1, epochs=epochs, loss=loss.item(), noise=self.likelihood.noise.item()))
            losses.append(loss.detach().to("cpu").item())
            if best_loss > losses[-1]:
                best_loss = losses[-1]
                weights = copy.deepcopy(self.state_dict())
            # early stopping once the loss has plateaued for max_patience steps
            if np.allclose(losses[-1], losses[-2], atol=self.config["loss_tol"]):
                patience += 1
            else:
                patience = 0
            if patience > max_patience:
                break
        self.load_state_dict(weights)
        logging.info(f"Current Iteration: {len(support)} | Incumbent {max(self.Y[support])} | Duration {np.round(time.time()-starttime)} | Epochs {_} | Noise {self.likelihood.noise.item()}")
        return losses, weights, initial_weights
    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint, map_location=torch.device(self.device))
        self.model.load_state_dict(ckpt['gp'], strict=False)
        self.likelihood.load_state_dict(ckpt['likelihood'], strict=False)
        self.feature_extractor.load_state_dict(ckpt['net'], strict=False)

    def predict(self, support, query_range=None, noise_fn=None):
        card = len(self.Y)
        if noise_fn:
            self.Y = noise_fn(self.Y)
        x_support, y_support = prepare_data(support, support,
                                            self.X, self.Y, self.Z)
        if query_range is None:
            x_query, _ = prepare_data(np.arange(card), support,
                                      self.X, self.Y, self.Z)
        else:
            x_query, _ = prepare_data(query_range, support,
                                      self.X, self.Y, self.Z)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        z_support = self.feature_extractor(totorch(x_support, self.device)).detach()
        self.model.set_train_data(inputs=z_support, targets=totorch(y_support.reshape(-1,), self.device), strict=False)

        with torch.no_grad():
            z_query = self.feature_extractor(totorch(x_query, self.device)).detach()
            pred = self.likelihood(self.model(z_query))

        mu = pred.mean.detach().to("cpu").numpy().reshape(-1,)
        stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1,)

        return mu, stddev
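Putting `train`, `predict`, and the `EI` helper together, one Bayesian-optimization step could look like the sketch below. `dkgp`, `optimizer`, and the initial `support` array are assumptions, not defined in this file.

```python
# One hypothetical BO iteration (dkgp is a DeepKernelGP instance).
losses, weights, init_weights = dkgp.train(support=support, load_model=False,
                                           optimizer=optimizer, epochs=100)
mu, stddev = dkgp.predict(support)                # posterior over all candidates
incumbent = np.max(dkgp.Y[support])               # best response seen so far
next_idx = EI(incumbent, lambda q: (mu, stddev),  # reuse the cached posterior
              support, np.arange(len(dkgp.Y)), return_variance=False)
support = np.append(support, next_idx)            # evaluate the pick, grow the support
```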
class DKT(nn.Module):
    def __init__(self, train_data, valid_data, kernel, backbone_fn, config):
        super(DKT, self).__init__()
        ## GP parameters
        self.train_data = train_data
        self.valid_data = valid_data
        self.fixed_context_size = config["fixed_context_size"]
        self.minibatch_size = config["minibatch_size"]
        self.n_inner_steps = config["n_inner_steps"]
        self.checkpoint_path = config["checkpoint_path"]
        os.makedirs(self.checkpoint_path, exist_ok=False)
        json.dump(config, open(os.path.join(self.checkpoint_path, "configuration.json"), "w"))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.basicConfig(filename=os.path.join(self.checkpoint_path, "log.txt"), level=logging.DEBUG)
        self.feature_extractor = backbone_fn().to(self.device)
        self.config = config
        self.get_model_likelihood_mll(self.fixed_context_size, kernel, backbone_fn)
        self.mse = nn.MSELoss()
        self.curr_valid_loss = np.inf
        self.get_tasks()
        self.setup_writers()

        self.train_metrics = Metric()
        self.valid_metrics = Metric(prefix="valid: ")
        print(self)

    def setup_writers(self,):
        train_log_dir = os.path.join(self.checkpoint_path, "train")
        os.makedirs(train_log_dir, exist_ok=True)
        self.train_summary_writer = SummaryWriter(train_log_dir)

        valid_log_dir = os.path.join(self.checkpoint_path, "valid")
        os.makedirs(valid_log_dir, exist_ok=True)
        self.valid_summary_writer = SummaryWriter(valid_log_dir)

    def get_tasks(self,):
        # tasks are (search space, dataset) pairs
        pairs = []
        for space in self.train_data.keys():
            for task in self.train_data[space].keys():
                pairs.append([space, task])
        self.tasks = pairs
        ##########
        pairs = []
        for space in self.valid_data.keys():
            for task in self.valid_data[space].keys():
                pairs.append([space, task])
        self.valid_tasks = pairs

    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):
        train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
        train_y = torch.ones(train_size).to(self.device)

        likelihood = gpytorch.likelihoods.GaussianLikelihood()
        model = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood, config=self.config,
                             dims=self.feature_extractor.out_features)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def epoch_end(self):
        RandomTaskGenerator.shuffle(self.tasks)

    def train_loop(self, epoch, optimizer, scheduler_fn=None):
        if scheduler_fn:
            scheduler = scheduler_fn(optimizer, len(self.tasks))
        self.epoch_end()
        assert self.training
        for task in self.tasks:
            inputs, labels = self.get_batch(task)
            for _ in range(self.n_inner_steps):
                optimizer.zero_grad()
                z = self.feature_extractor(inputs)
                self.model.set_train_data(inputs=z, targets=labels, strict=False)
                predictions = self.model(z)
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
                mse = self.mse(predictions.mean, labels)
                self.train_metrics.update(loss, self.model.likelihood.noise, mse)
            if scheduler_fn:
                scheduler.step()

        training_results = self.train_metrics.get()
        for k, v in training_results.items():
            self.train_summary_writer.add_scalar(k, v, epoch)
        for task in self.valid_tasks:
            mse, loss = self.test_loop(task, train=False)
            self.valid_metrics.update(loss, np.array(0), mse)

        logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
        validation_results = self.valid_metrics.get()
        for k, v in validation_results.items():
            self.valid_summary_writer.add_scalar(k, v, epoch)
        self.feature_extractor.train()
        self.likelihood.train()
        self.model.train()

        if validation_results["loss"] < self.curr_valid_loss:
            self.save_checkpoint(os.path.join(self.checkpoint_path, "weights"))
            self.curr_valid_loss = validation_results["loss"]
        self.valid_metrics.reset()
        self.train_metrics.reset()

    def test_loop(self, task, train, optimizer=None):  # no optimizer needed for GP
        (x_support, y_support), (x_query, y_query) = self.get_support_and_queries(task, train)
        z_support = self.feature_extractor(x_support).detach()
        self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            z_query = self.feature_extractor(x_query).detach()
            pred = self.likelihood(self.model(z_query))
            loss = -self.mll(pred, y_query)
            lower, upper = pred.confidence_region()  # 2 standard deviations above and below the mean

        mse = self.mse(pred.mean, y_query)

        return mse, loss
    def get_batch(self, task):
        # we want to fit the GP given the context info to new observations
        # task is a (search space, dataset) pair
        space, task = task
        Lambda, response = np.array(self.train_data[space][task]["X"]), MinMaxScaler().fit_transform(np.array(self.train_data[space][task]["y"])).reshape(-1,)

        card, dim = Lambda.shape

        support = RandomSupportGenerator.choice(np.arange(card),
                                                replace=False, size=self.fixed_context_size)
        remaining = np.setdiff1d(np.arange(card), support)
        indexes = RandomQueryGenerator.choice(
            remaining, replace=False, size=self.minibatch_size if len(remaining) > self.minibatch_size else len(remaining))

        inputs, labels = prepare_data(support, indexes, Lambda, response, np.zeros(32))
        inputs, labels = totorch(inputs, device=self.device), totorch(labels.reshape(-1,), device=self.device)
        return inputs, labels

    def get_support_and_queries(self, task, train=False):
        # task is a (search space, dataset) pair
        space, task = task

        hpo_data = self.valid_data if not train else self.train_data
        Lambda, response = np.array(hpo_data[space][task]["X"]), MinMaxScaler().fit_transform(np.array(hpo_data[space][task]["y"])).reshape(-1,)
        card, dim = Lambda.shape

        support = RandomSupportGenerator.choice(np.arange(card),
                                                replace=False, size=self.fixed_context_size)
        indexes = RandomQueryGenerator.choice(
            np.setdiff1d(np.arange(card), support), replace=False, size=self.minibatch_size)

        support_x, support_y = prepare_data(support, support, Lambda, response, np.zeros(32))
        query_x, query_y = prepare_data(support, indexes, Lambda, response, np.zeros(32))

        return (totorch(support_x, self.device), totorch(support_y.reshape(-1,), self.device)),\
               (totorch(query_x, self.device), totorch(query_y.reshape(-1,), self.device))

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        nn_state_dict = self.feature_extractor.state_dict()
        torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net': nn_state_dict}, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['net'])


class ExactGPLayer(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, config, dims):
        super(ExactGPLayer, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()

        ## RBF kernel
        if config["kernel"] == 'rbf' or config["kernel"] == 'RBF':
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(ard_num_dims=dims if config["ard"] else None))
        ## Matern kernel (config["nu"] selects the smoothness, e.g. 5/2)
        elif config["kernel"] == '52':
            self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel(nu=config["nu"], ard_num_dims=dims if config["ard"] else None))
        else:
            raise ValueError("[ERROR] the kernel '" + str(config["kernel"]) + "' is not supported for regression, use 'rbf' or '52'.")

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
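As a standalone sanity check, `ExactGPLayer` can be driven directly with random features, mirroring the calls in `get_model_likelihood_mll` above; the shapes and config values here are arbitrary assumptions.

```python
import torch
import gpytorch

train_x = torch.randn(20, 8)   # 20 points in an 8-dimensional latent space
train_y = torch.randn(20)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
gp = ExactGPLayer(train_x=train_x, train_y=train_y, likelihood=likelihood,
                  config={"kernel": "rbf", "ard": False}, dims=8)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, gp)

gp.train(); likelihood.train()
loss = -mll(gp(train_x), train_y)  # negative marginal log-likelihood
loss.backward()
```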
class batch_mlp(nn.Module):
    def __init__(self, d_in, output_sizes, nonlinearity="relu", dropout=0.0):
        super(batch_mlp, self).__init__()
        assert nonlinearity == "relu"
        self.nonlinearity = nn.ReLU()

        self.fc = nn.ModuleList([nn.Linear(in_features=d_in, out_features=output_sizes[0])])
        for d_out in output_sizes[1:]:
            self.fc.append(nn.Linear(in_features=self.fc[-1].out_features, out_features=d_out))
        self.out_features = output_sizes[-1]
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        for fc in self.fc[:-1]:
            x = fc(x)
            x = self.dropout(x)
            x = self.nonlinearity(x)
        x = self.fc[-1](x)
        x = self.dropout(x)
        return x
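A quick shape check for `batch_mlp`; the sizes are chosen arbitrarily. Since every layer is a plain `nn.Linear`, the MLP is applied to the last dimension and any leading batch/set dimensions pass through unchanged.

```python
import torch

net = batch_mlp(d_in=16, output_sizes=[64, 64, 32], dropout=0.1)
x = torch.randn(5, 10, 16)  # e.g., 5 sets of 10 elements with 16 features each
out = net(x)
print(out.shape)            # torch.Size([5, 10, 32])
print(net.out_features)     # 32
```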
class StandardDeepGP(nn.Module):
    def __init__(self, configuration):
        super(StandardDeepGP, self).__init__()
        self.A = batch_mlp(configuration["dim"], configuration["output_size_A"], dropout=configuration["dropout"])
        self.out_features = configuration["output_size_A"][-1]

    def forward(self, x):
        # e,r,x,z = x
        hidden = self.A(x.squeeze(dim=-1))  ### NxA
        return hidden
class DKTNAS(nn.Module):
    def __init__(self, kernel, backbone_fn, config, pretrained_encoder=True, GP_only=False):
        super(DKTNAS, self).__init__()
        ## GP parameters

        self.fixed_context_size = config["fixed_context_size"]
        self.minibatch_size = config["minibatch_size"]
        self.n_inner_steps = config["n_inner_steps"]
        self.set_encoder_args = get_parser()
        if not os.path.exists(self.set_encoder_args.save_path):
            os.makedirs(self.set_encoder_args.save_path)
        self.set_encoder_args.model_path = os.path.join(self.set_encoder_args.save_path,
                                                        self.set_encoder_args.model_name, 'model')
        if not os.path.exists(self.set_encoder_args.model_path):
            os.makedirs(self.set_encoder_args.model_path)
        self.set_encoder = Generator(self.set_encoder_args)
        if pretrained_encoder:
            self.dataset_enc, self.arch, self.acc = self.set_encoder.train_dgp(encode=False)
            self.dataset_enc_val, self.acc_val = self.set_encoder.test_dgp(data_name='cifar100', encode=False)
        else:  # in case we want to train the set-encoder from scratch
            self.dataset_enc = np.load("train_data_path.npy")
            self.acc = np.load("train_acc.npy")
            self.dataset_enc_val = np.load("cifar100_data_path.npy")
            self.acc_val = np.load("cifar100_acc.npy")
        self.valid_data = self.dataset_enc_val
        self.checkpoint_path = config["checkpoint_path"]
        os.makedirs(self.checkpoint_path, exist_ok=False)
        json.dump(config, open(os.path.join(self.checkpoint_path, "configuration.json"), "w"))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.basicConfig(filename=os.path.join(self.checkpoint_path, "log.txt"), level=logging.DEBUG)
        self.feature_extractor = backbone_fn().to(self.device)
        self.config = config
        self.GP_only = GP_only
        self.get_model_likelihood_mll(self.fixed_context_size, kernel, backbone_fn)
        self.mse = nn.MSELoss()
        self.curr_valid_loss = np.inf
        # self.get_tasks()
        self.setup_writers()

        self.train_metrics = Metric()
        self.valid_metrics = Metric(prefix="valid: ")
        self.tasks = len(self.dataset_enc)

        print(self)

    def setup_writers(self,):
        # the writers are used by train_loop below
        train_log_dir = os.path.join(self.checkpoint_path, "train")
        os.makedirs(train_log_dir, exist_ok=True)
        self.train_summary_writer = SummaryWriter(train_log_dir)

        valid_log_dir = os.path.join(self.checkpoint_path, "valid")
        os.makedirs(valid_log_dir, exist_ok=True)
        self.valid_summary_writer = SummaryWriter(valid_log_dir)

    def get_model_likelihood_mll(self, train_size, kernel, backbone_fn):
        if not self.GP_only:
            train_x = torch.ones(train_size, self.feature_extractor.out_features).to(self.device)
            train_y = torch.ones(train_size).to(self.device)

            likelihood = gpytorch.likelihoods.GaussianLikelihood()

            model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood, config=self.config,
                                 dims=self.feature_extractor.out_features)
        else:
            train_x = torch.ones(train_size, self.fixed_context_size).to(self.device)
            train_y = torch.ones(train_size).to(self.device)

            likelihood = gpytorch.likelihoods.GaussianLikelihood()

            model = ExactGPLayer(train_x=None, train_y=None, likelihood=likelihood, config=self.config,
                                 dims=self.fixed_context_size)
        self.model = model.to(self.device)
        self.likelihood = likelihood.to(self.device)
        self.mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model).to(self.device)

    def set_forward(self, x, is_feature=False):
        pass

    def set_forward_loss(self, x):
        pass

    def epoch_end(self):
        RandomTaskGenerator.shuffle([1])

    def train_loop(self, epoch, optimizer, scheduler_fn=None):
        if scheduler_fn:
            scheduler = scheduler_fn(optimizer, 1)
        self.epoch_end()
        assert self.training
        for task in range(self.tasks):
            inputs, labels = self.get_batch(task)
            for _ in range(self.n_inner_steps):
                optimizer.zero_grad()
                z = self.feature_extractor(inputs)
                self.model.set_train_data(inputs=z, targets=labels, strict=False)
                predictions = self.model(z)
                loss = -self.mll(predictions, self.model.train_targets)
                loss.backward()
                optimizer.step()
                mse = self.mse(predictions.mean, labels)
                self.train_metrics.update(loss, self.model.likelihood.noise, mse)
            if scheduler_fn:
                scheduler.step()

        training_results = self.train_metrics.get()
        for k, v in training_results.items():
            self.train_summary_writer.add_scalar(k, v, epoch)
        mse, loss = self.test_loop(train=False)
        self.valid_metrics.update(loss, np.array(0), mse)

        logging.info(self.train_metrics.report() + " " + self.valid_metrics.report())
        validation_results = self.valid_metrics.get()
        for k, v in validation_results.items():
            self.valid_summary_writer.add_scalar(k, v, epoch)
        self.feature_extractor.train()
        self.likelihood.train()
        self.model.train()

        if validation_results["loss"] < self.curr_valid_loss:
            self.save_checkpoint(os.path.join(self.checkpoint_path, "weights"))
            self.curr_valid_loss = validation_results["loss"]
        self.valid_metrics.reset()
        self.train_metrics.reset()
    def test_loop(self, train=None, optimizer=None):  # no optimizer needed for GP
        (x_support, y_support), (x_query, y_query) = self.get_support_and_queries(train)
        z_support = self.feature_extractor(x_support).detach()
        self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            z_query = self.feature_extractor(x_query).detach()
            pred = self.likelihood(self.model(z_query))
            loss = -self.mll(pred, y_query)
            lower, upper = pred.confidence_region()  # 2 standard deviations above and below the mean

        mse = self.mse(pred.mean, y_query)

        return mse, loss

    def get_batch(self, task, valid=False):
        # we want to fit the GP given the context info to new observations
        # TODO: scale the response as in FSBO (needed for train)
        Lambda, response = np.array(self.dataset_enc), np.array(self.acc)

        inputs, labels = Lambda[task], response[task]
        inputs, labels = totorch([inputs], device=self.device), totorch([labels], device=self.device)
        return inputs, labels

    def get_support_and_queries(self, task, train=False):
        # TODO: scale the response as in FSBO (not necessary for test)
        Lambda, response = np.array(self.dataset_enc_val), np.array(self.acc_val)
        card, dim = Lambda.shape

        support = RandomSupportGenerator.choice(np.arange(card),
                                                replace=False, size=self.fixed_context_size)
        indexes = RandomQueryGenerator.choice(
            np.setdiff1d(np.arange(card), support), replace=False, size=self.minibatch_size)

        support_x, support_y = Lambda[support], response[support]
        query_x, query_y = Lambda[indexes], response[indexes]

        return (totorch(support_x, self.device), totorch(support_y.reshape(-1,), self.device)), \
               (totorch(query_x, self.device), totorch(query_y.reshape(-1,), self.device))

    def save_checkpoint(self, checkpoint):
        # save state
        gp_state_dict = self.model.state_dict()
        likelihood_state_dict = self.likelihood.state_dict()
        nn_state_dict = self.feature_extractor.state_dict()
        torch.save({'gp': gp_state_dict, 'likelihood': likelihood_state_dict, 'net': nn_state_dict}, checkpoint)

    def load_checkpoint(self, checkpoint):
        ckpt = torch.load(checkpoint)
        self.model.load_state_dict(ckpt['gp'])
        self.likelihood.load_state_dict(ckpt['likelihood'])
        self.feature_extractor.load_state_dict(ckpt['net'])

    def predict(self, x_support, y_support, x_query, y_query, GP_only=False):
        if not GP_only:
            z_support = self.feature_extractor(x_support).detach()
        else:
            z_support = x_support
        self.model.set_train_data(inputs=z_support, targets=y_support, strict=False)
        self.model.eval()
        self.feature_extractor.eval()
        self.likelihood.eval()

        with torch.no_grad():
            if not GP_only:
                z_query = self.feature_extractor(x_query).detach()
            else:
                z_query = x_query
            pred = self.likelihood(self.model(z_query))
            mu = pred.mean.detach().to("cpu").numpy().reshape(-1,)
            stddev = pred.stddev.detach().to("cpu").numpy().reshape(-1,)
        return mu, stddev
@@ -0,0 +1,168 @@
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets
This code is for the MobileNetV3 search space experiments.


## Prerequisites
- Python 3.6 (Anaconda)
- PyTorch 1.6.0
- CUDA 10.2
- python-igraph==0.8.2
- tqdm==4.50.2
- torchvision==0.7.0
- scipy==1.5.2
- ofa==0.0.4-2007200808


## MobileNetV3 Search Space
Go to the folder for the MobileNetV3 experiments (i.e. ```MetaD2A_mobilenetV3```).

The overall flow is summarized as follows:
- Building the database for the predictor
- Meta-training the predictor
- Building the database for the generator with the trained predictor
- Meta-training the generator
- Meta-testing (searching)
- Evaluating the searched architecture


## Data Preparation
To download the preprocessed data files, run ```get_files/get_preprocessed_data.py```:
```shell script
$ python get_files/get_preprocessed_data.py
```
It will take some time to download and preprocess each dataset.

## Meta Test and Evaluation
### Meta-Test

You can download the trained checkpoint files for the generator and the predictor:
```shell script
$ python get_files/get_generator_checkpoint.py
$ python get_files/get_predictor_checkpoint.py
```

If you want to meta-test with your own dataset, first build your own preprocessed data by modifying ```process_dataset.py```:
```shell script
$ python process_dataset.py
```

This code automatically generates neural architectures and then selects high-performing architectures among the candidates.
By setting ```--data-name``` to the name of a dataset (i.e. ```cifar10```, ```cifar100```, ```aircraft100```, ```pets```), you can evaluate on that specific dataset.

```shell script
# Meta-testing
$ python main.py --gpu 0 --model generator --hs 56 --nz 56 --test --load-epoch 120 --num-gen-arch 200 --data-name {DATASET_NAME}
```

### Architecture Evaluation (MetaD2A vs NSGANetV2)
##### Dataset Preparation
You need to download the Oxford-IIIT Pet dataset to evaluate with ```--data-name pets```:
```shell script
$ python get_files/get_pets.py
```
All other datasets (```cifar10```, ```cifar100```, ```aircraft100```) will be downloaded automatically.

##### Evaluation
You can run the searched architecture by running ```evaluation/main```. The code is based on NSGANetV2.

Go to the evaluation folder (i.e. ```evaluation```):
```shell script
$ cd evaluation
```

This automatically runs the top-1 predicted architecture derived by MetaD2A:
```shell script
python main.py --data-name cifar10 --num-gen-arch 200
```
You can also impose a FLOPs constraint with the ```--bound``` option:
```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --bound 300
```

You can compare MetaD2A with NSGANetV2, but you first need to download some files provided by [NSGANetV2](https://github.com/human-analysis/nsganetv2):

```shell script
python main.py --data-name cifar10 --num-gen-arch 200 --model-config flops@232
```

## Meta-Training the MetaD2A Model
To build the database for meta-training, you need to set ```IMGNET_PATH```, which is your ILSVRC2012 (ImageNet) directory.

### Database Building for the Predictor
We recommend running multiple instances of ```create_database.sh``` simultaneously to build the database faster.
You need to set ```IMGNET_PATH``` in the shell script.
```shell script
# Examples
bash create_database.sh 0,1,2,3 0 49 predictor
bash create_database.sh all 50 99 predictor
...
```
After enough data has been gathered, run ```build_database.py``` to collect it into one file.
```shell script
python build_database.py --model_name predictor --collect
```

We also provide the database we used. To download it, run ```get_files/get_predictor_database.py```:
```shell script
$ python get_files/get_predictor_database.py
```

### Meta-Train the Predictor
You can train the predictor as follows:
```shell script
# Meta-training for the predictor
$ python main.py --gpu 0 --model predictor --hs 512 --nz 56
```
### Database Building for the Generator
We recommend running multiple instances of ```create_database.sh``` simultaneously to build the database faster.
```shell script
# Examples
bash create_database.sh 4,5,6,7 0 49 generator
bash create_database.sh all 50 99 generator
...
```
After enough data has been gathered, run ```build_database.py``` to collect it into one file.
```shell script
python build_database.py --model_name generator --collect
```

We also provide the database we used. To download it, run ```get_files/get_generator_database.py```:
```shell script
$ python get_files/get_generator_database.py
```


### Meta-Train the Generator
You can train the generator as follows:
```shell script
# Meta-training for the generator
$ python main.py --gpu 0 --model generator --hs 56 --nz 56
```


## Citation
If you found the provided code useful, please cite our work.
```
@inproceedings{
    lee2021rapid,
    title={Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets},
    author={Hayeon Lee and Eunyoung Hyung and Sung Ju Hwang},
    booktitle={ICLR},
    year={2021}
}
```

## Reference
- [Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks (ICML 2019)](https://github.com/juho-lee/set_transformer)
- [D-VAE: A Variational Autoencoder for Directed Acyclic Graphs (NeurIPS 2019)](https://github.com/muhanzhang/D-VAE)
- [Once for All: Train One Network and Specialize it for Efficient Deployment (ICLR 2020)](https://github.com/mit-han-lab/once-for-all)
- [NSGANetV2: Evolutionary Multi-Objective Surrogate-Assisted Neural Architecture Search (ECCV 2020)](https://github.com/human-analysis/nsganetv2)
@@ -0,0 +1,49 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from predictor import PredictorModel
from database import DatabaseOFA
from utils import load_graph_config


def main():
    args = get_parser()

    if args.gpu == 'all':
        device_list = range(torch.cuda.device_count())
        args.gpu = ','.join(str(_) for _ in device_list)
    else:
        device_list = [int(_) for _ in args.gpu.split(',')]
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    args.device = torch.device("cuda:0")
    args.batch_size = args.batch_size * max(len(device_list), 1)

    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    args.model_path = os.path.join(args.save_path, args.model_name, 'model')

    if args.model_name == 'generator':
        # building the generator database requires the trained predictor to score subnets
        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, args.data_path)
        model = PredictorModel(args, graph_config)
        d = DatabaseOFA(args, model)
    else:
        d = DatabaseOFA(args)

    if args.collect:
        d.collect_db()
    else:
        assert args.index is not None
        assert args.imgnet is not None
        d.make_db()


if __name__ == '__main__':
    main()
@@ -0,0 +1,15 @@
#bash create_database.sh all 0 49 predictor

IMGNET_PATH='/w14/dataset/ILSVRC2012' # PUT YOUR ILSVRC2012 DIR

for ((ind=$2;ind<=$3;ind++))
do
python build_database.py --gpu $1 \
    --model_name $4 \
    --index $ind \
    --imgnet $IMGNET_PATH \
    --hs 512 \
    --nz 56
done
@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .db_ofa import DatabaseOFA
@@ -0,0 +1,57 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################

import numpy as np
import torch

__all__ = ['DataProvider']


class DataProvider:
    SUB_SEED = 937162211  # random seed for sampling the subset
    VALID_SEED = 2147483647  # random seed for the validation set

    @staticmethod
    def name():
        """ Return the name of the dataset """
        raise NotImplementedError

    @property
    def data_shape(self):
        """ Return the shape of one data entry as a Python list """
        raise NotImplementedError

    @property
    def n_classes(self):
        """ Return the number of classes as an `int` """
        raise NotImplementedError

    @property
    def save_path(self):
        """ Local path to save the data """
        raise NotImplementedError

    @property
    def data_url(self):
        """ Link to download the data """
        raise NotImplementedError

    @staticmethod
    def random_sample_valid_set(train_size, valid_size):
        assert train_size > valid_size

        g = torch.Generator()
        g.manual_seed(DataProvider.VALID_SEED)  # set the random seed before sampling the validation set
        rand_indexes = torch.randperm(train_size, generator=g).tolist()

        valid_indexes = rand_indexes[:valid_size]
        train_indexes = rand_indexes[valid_size:]
        return train_indexes, valid_indexes

    @staticmethod
    def labels_to_one_hot(n_classes, labels):
        new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
        new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
        return new_labels
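For example, `labels_to_one_hot` turns a vector of integer labels into a one-hot matrix:

```python
import numpy as np

labels = np.array([0, 2, 1])
print(DataProvider.labels_to_one_hot(3, labels))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```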
@@ -0,0 +1,107 @@
import os
import torch
import time
import copy
import glob
from .imagenet import ImagenetDataProvider
from .imagenet_loader import ImagenetRunConfig
from .run_manager import RunManager
from ofa.model_zoo import ofa_net


class DatabaseOFA:
    def __init__(self, args, predictor=None):
        self.path = f'{args.data_path}/{args.model_name}'
        self.model_name = args.model_name
        self.index = args.index
        self.args = args
        self.predictor = predictor
        ImagenetDataProvider.DEFAULT_PATH = args.imgnet

        if not os.path.exists(self.path):
            os.makedirs(self.path)

    def make_db(self):
        self.ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
        self.run_config = ImagenetRunConfig(test_batch_size=self.args.batch_size,
                                            n_worker=20)
        database = []
        st_time = time.time()
        f = open(f'{self.path}/txt_{self.index}.txt', 'w')
        for dn in range(10000):
            best_pp = -1
            best_info = None
            dls = None
            with torch.no_grad():
                if self.model_name == 'generator':
                    # sample 10 subnets and keep the one with the best predicted accuracy
                    for i in range(10):
                        net_setting = self.ofa_network.sample_active_subnet()
                        subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
                        if i == 0:
                            run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
                                                     self.run_config, init=False, pp=self.predictor)
                            self.run_config.data_provider.assign_active_img_size(224)
                            dls = {j: copy.deepcopy(run_manager.data_loader) for j in range(1, 10)}
                        else:
                            run_manager = RunManager('.tmp/eval_subnet', self.args, subnet,
                                                     self.run_config,
                                                     init=False, data_loader=dls[i], pp=self.predictor)
                        run_manager.reset_running_statistics(net=subnet)

                        loss, (top1, top5), pred_acc \
                            = run_manager.validate(net=subnet, net_setting=net_setting)

                        if best_pp < pred_acc:
                            best_pp = pred_acc
                            print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
                                dn, len(run_manager.cls_lst), loss, top1, top5))
                            info_dict = {'loss': loss,
                                         'top1': top1,
                                         'top5': top5,
                                         'net': net_setting,
                                         'class': run_manager.cls_lst,
                                         'params': run_manager.net_info['params'],
                                         'flops': run_manager.net_info['flops'],
                                         'test_transform': run_manager.test_transform
                                         }
                            best_info = info_dict
                elif self.model_name == 'predictor':
                    net_setting = self.ofa_network.sample_active_subnet()
                    subnet = self.ofa_network.get_active_subnet(preserve_weight=True)
                    run_manager = RunManager('.tmp/eval_subnet', self.args, subnet, self.run_config, init=False)
                    self.run_config.data_provider.assign_active_img_size(224)
                    run_manager.reset_running_statistics(net=subnet)

                    loss, (top1, top5), _ = run_manager.validate(net=subnet)
                    print('[%d] class=%d,\t loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (
                        dn, len(run_manager.cls_lst), loss, top1, top5))
                    best_info = {'loss': loss,
                                 'top1': top1,
                                 'top5': top5,
                                 'net': net_setting,
                                 'class': run_manager.cls_lst,
                                 'params': run_manager.net_info['params'],
                                 'flops': run_manager.net_info['flops'],
                                 'test_transform': run_manager.test_transform
                                 }
            database.append(best_info)
            if len(database) % 10 == 0:
                msg = f'{(time.time() - st_time) / 60.0:0.2f}(min) save {len(database)} database, {self.index} id'
                print(msg)
                f.write(msg + '\n')
                f.flush()
                torch.save(database, f'{self.path}/database_{self.index}.pt')

    def collect_db(self):
        if not os.path.exists(self.path + '/processed'):
            os.makedirs(self.path + '/processed')

        database = []
        dlst = glob.glob(self.path + '/*.pt')
        for filepath in dlst:
            database += torch.load(filepath)

        assert len(database) != 0

        print(f'The number of database entries: {len(database)}')
        torch.save(database, self.path + '/processed/collected_database.pt')
@@ -0,0 +1,240 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import warnings
import os
import torch
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa_local.imagenet_classification.data_providers.base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler
from .metaloader import MetaImageNetDataset, EpisodeSampler, MetaDataLoader


__all__ = ['ImagenetDataProvider']


class ImagenetDataProvider(DataProvider):
    DEFAULT_PATH = '/dataset/imagenet'

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):
        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = 'None' if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from ofa.utils.my_dataloader import MyDataLoader
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        ########################## modification ########################
        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)
            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        # test_dataset = self.test_dataset(valid_transforms)
        test_dataset = self.meta_test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            # episodic loader: samples class-balanced episodes instead of a flat shuffle
            sampler = EpisodeSampler(
                max_way=1000, query=10, ylst=test_dataset.ylst)
            self.test = MetaDataLoader(dataset=test_dataset,
                                       sampler=sampler,
                                       batch_size=test_batch_size,
                                       shuffle=False,
                                       num_workers=4)

        if self.valid is None:
            self.valid = self.test
    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser('~/dataset/imagenet')
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        return datasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        return datasets.ImageFolder(self.valid_path, _transforms)

    def meta_test_dataset(self, _transforms):
        return MetaImageNetDataset('val', max_way=1000, query=10,
                                   dpath='/w14/dataset/ILSVRC2012', transform=_transforms)

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting BN running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
@@ -0,0 +1,40 @@
from .imagenet import ImagenetDataProvider
from ofa_local.imagenet_classification.run_manager import RunConfig


__all__ = ['ImagenetRunConfig']


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size

    @property
    def data_provider(self):
        # lazily construct the data provider on first access
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']
@@ -0,0 +1,210 @@
from torch.utils.data.sampler import Sampler
import os
import random
from PIL import Image
from collections import defaultdict
import torch
from torch.utils.data import Dataset, DataLoader
import glob


class RandCycleIter:
    '''
    Cycle endlessly over a per-class index list,
    reshuffling the order after each full pass.
    '''
    def __init__(self, data, shuffle=True):
        self.data_list = list(data)
        self.length = len(self.data_list)
        self.i = self.length - 1
        self.shuffle = shuffle

    def __iter__(self):
        return self

    def __next__(self):
        self.i += 1

        if self.i == self.length:
            self.i = 0
            if self.shuffle:
                random.shuffle(self.data_list)

        return self.data_list[self.i]
|
||||
|
||||
|
||||
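
# Hedged illustration (toy input, not part of the original file):
# it = RandCycleIter([0, 1, 2])
# [next(it) for _ in range(6)]  # never raises StopIteration; e.g. [1, 0, 2, 2, 1, 0]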


class EpisodeSampler(Sampler):
    def __init__(self, max_way, query, ylst):
        self.max_way = max_way
        self.query = query
        self.ylst = ylst
        # self.n_epi = n_epi

        clswise_xidx = defaultdict(list)
        for i, y in enumerate(ylst):
            clswise_xidx[y].append(i)
        self.cws_xidx_iter = [RandCycleIter(cxidx, shuffle=True)
                              for cxidx in clswise_xidx.values()]
        self.n_cls = len(clswise_xidx)

        self.create_episode()

    def __iter__(self):
        return self.get_index()

    def __len__(self):
        return self.get_len()

    def create_episode(self):
        # sample a way in {10, 20, ..., max_way - 10} (multiples of 10)
        self.way = torch.randperm(int(self.max_way / 10.0) - 1)[0] * 10 + 10
        cls_lst = torch.sort(torch.randperm(self.max_way)[:self.way])[0]
        self.cls_itr = iter(cls_lst)
        self.cls_lst = cls_lst

    def get_len(self):
        return self.way * self.query

    def get_index(self):
        x_itr = self.cws_xidx_iter

        i, j = 0, 0
        while i < self.query * self.way:
            if j >= self.query:
                j = 0
            if j == 0:
                cls_idx = next(self.cls_itr).item()
                bb = [x_itr[cls_idx]] * self.query
                didx = next(zip(*bb))  # `query` successive indices from this class
            yield didx[j]
            # yield (didx[j], self.way)

            i += 1
            j += 1
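
# Hedged illustration (toy labels, not part of the original file): one episode
# yields way * query sample indices, `query` consecutive indices per class.
# sampler = EpisodeSampler(max_way=20, query=2, ylst=[i // 5 for i in range(100)])
# idxs = list(iter(sampler))
# assert len(idxs) == sampler.way * 2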


class MetaImageNetDataset(Dataset):
    def __init__(self, mode='val',
                 max_way=1000, query=10,
                 dpath='/w14/dataset/ILSVRC2012', transform=None):
        self.dpath = dpath
        self.transform = transform
        self.mode = mode

        self.max_way = max_way
        self.query = query
        classes, class_to_idx = self._find_classes(dpath + '/' + mode)
        self.classes, self.class_to_idx = classes, class_to_idx
        # self.class_folder_lst = \
        #     glob.glob(dpath+'/'+mode+'/*')
        # ## sorting alphabetically
        # self.class_folder_lst = sorted(self.class_folder_lst)
        self.file_path_lst, self.ylst = [], []
        for cls in classes:
            xlst = glob.glob(dpath + '/' + mode + '/' + cls + '/*')
            self.file_path_lst += xlst[:self.query]
            y = class_to_idx[cls]
            self.ylst += [y] * len(xlst[:self.query])

        # for y, cls in enumerate(self.class_folder_lst):
        #     xlst = glob.glob(cls+'/*')
        #     self.file_path_lst += xlst[:self.query]
        #     self.ylst += [y] * len(xlst[:self.query])
        #     # self.file_path_lst += [xlst[_] for _ in
        #     #                        torch.randperm(len(xlst))[:self.query]]
        #     # self.ylst += [cls.split('/')[-1]] * len(xlst)

        self.way_idx = 0
        self.x_idx = 0
        self.way = 2
        self.cls_lst = None

    def __len__(self):
        return self.way * self.query

    def __getitem__(self, index):
        # if self.way != index[1]:
        #     self.way = index[1]
        # index = index[0]

        x = Image.open(
            self.file_path_lst[index]).convert('RGB')

        if self.transform is not None:
            x = self.transform(x)
        cls_name = self.ylst[index]
        # map the global class id to its position within the current episode
        y = self.cls_lst.index(cls_name)
        # y = self.way_idx
        # self.x_idx += 1
        # if self.x_idx == self.query:
        #     self.way_idx += 1
        #     self.x_idx = 0
        #     if self.way_idx == self.way:
        #         self.way_idx = 0
        #         self.x_idx = 0
        return x, y  # , cls_name

    def _find_classes(self, dir: str):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx


class MetaDataLoader(DataLoader):
    def __init__(self,
                 dataset, sampler, batch_size, shuffle, num_workers):
        super(MetaDataLoader, self).__init__(
            dataset=dataset,
            sampler=sampler,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers)

    def create_episode(self):
        self.sampler.create_episode()
        self.dataset.way = self.sampler.way
        self.dataset.cls_lst = self.sampler.cls_lst.tolist()

    def get_cls_idx(self):
        return self.sampler.cls_lst


def get_loader(mode='val', way=10, query=10,
               n_epi=100, dpath='/w14/dataset/ILSVRC2012',
               transform=None):
    trans = get_transforms(mode)  # get_transforms is assumed to be provided elsewhere in the codebase
    dataset = MetaImageNetDataset(mode, way, query, dpath, trans)
    # EpisodeSampler takes (max_way, query, ylst); the unused n_epi argument
    # of the original call has been dropped
    sampler = EpisodeSampler(
        way, query, dataset.ylst)
    dataset.way = sampler.way
    # convert to a plain list so MetaImageNetDataset.__getitem__ can call cls_lst.index(...)
    dataset.cls_lst = sampler.cls_lst.tolist()
    loader = MetaDataLoader(dataset=dataset,
                            sampler=sampler,
                            batch_size=10,
                            shuffle=False,
                            num_workers=4)
    return loader


# trloader = get_loader()

# trloader.create_episode()
# print(len(trloader))
# print(trloader.dataset.way)
# print(trloader.sampler.way)
# for i, episode in enumerate(trloader, start=1):
#     print(episode[2])
@@ -0,0 +1,302 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import os
import json
import torch  # made explicit; previously implied by `import torch.nn.parallel`
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm
from utils import decode_ofa_mbv3_to_igraph
from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models

__all__ = ['RunManager']
import torchvision.models as models


class RunManager:

    def __init__(self, path, args, net, run_config, init=True, measure_latency=None,
                 no_gpu=False, data_loader=None, pp=None):
        self.path = path
        self.mode = args.model_name
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)
        # dataloader
        if data_loader is not None:
            self.data_loader = data_loader
            cls_lst = self.data_loader.get_cls_idx()
            self.cls_lst = cls_lst
        else:
            self.data_loader = self.run_config.valid_loader
            self.data_loader.create_episode()
            cls_lst = self.data_loader.get_cls_idx()
            self.cls_lst = cls_lst

        # slice the pretrained classifier head down to the episode's classes
        state_dict = self.net.classifier.state_dict()
        new_state_dict = {'weight': state_dict['linear.weight'][cls_lst],
                          'bias': state_dict['linear.bias'][cls_lst]}

        self.net.classifier = nn.Linear(1280, len(cls_lst), bias=True)
        self.net.classifier.load_state_dict(new_state_dict)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        # net info
        net_info = get_net_info(
            self.net, self.run_config.data_provider.data_shape, measure_latency, False)
        self.net_info = net_info
        self.test_transform = self.run_config.data_provider.test.dataset.transform

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = \
                lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.network.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.network.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

        if self.mode == 'generator':
            # PP (performance predictor)
            save_dir = f'{args.save_path}/predictor/model/ckpt_max_corr.pt'

            self.acc_predictor = pp.to('cuda')
            self.acc_predictor.load_state_dict(torch.load(save_dir))
            self.acc_predictor = torch.nn.DataParallel(self.acc_predictor)
            model = models.resnet18(pretrained=True).eval()
            feature_extractor = torch.nn.Sequential(*list(model.children())[:-1]).to(self.device)
            self.feature_extractor = torch.nn.DataParallel(feature_extractor)

    """ save path and log path """
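
    # Hedged standalone sketch of the head-slicing trick above (toy numbers,
    # not part of the original file): keep only the rows of a pretrained
    # linear head that correspond to the episode's classes.
    # full = nn.Linear(1280, 1000)
    # cls_lst = [3, 17, 42]
    # sub = nn.Linear(1280, len(cls_lst))
    # sub.load_state_dict({'weight': full.weight.data[cls_lst],
    #                      'bias': full.bias.data[cls_lst]})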

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.network.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location='cpu')
        except Exception:
            print('fail to load checkpoint from %s' % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            self.start_epoch = checkpoint['epoch'] + 1
        if 'best_acc' in checkpoint:
            self.best_acc = checkpoint['best_acc']
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

    def save_config(self, extra_run_config=None, extra_net_config=None):
        """ dump run_config and net_config to the model_folder """
        run_save_path = os.path.join(self.path, 'run.config')
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, 'w'), indent=4)
            print('Run configs dump to %s' % run_save_path)

        try:
            net_save_path = os.path.join(self.path, 'net.config')
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, 'w'), indent=4)
            print('Network configs dump to %s' % net_save_path)
        except Exception:
            print('%s does not support net config' % type(self.network))
""" metric related """
|
||||
|
||||
def get_metric_dict(self):
|
||||
return {
|
||||
'top1': AverageMeter(),
|
||||
'top5': AverageMeter(),
|
||||
}
|
||||
|
||||
def update_metric(self, metric_dict, output, labels):
|
||||
acc1, acc5 = accuracy(output, labels, topk=(1, 5))
|
||||
metric_dict['top1'].update(acc1[0].item(), output.size(0))
|
||||
metric_dict['top5'].update(acc5[0].item(), output.size(0))
|
||||
|
||||
def get_metric_vals(self, metric_dict, return_dict=False):
|
||||
if return_dict:
|
||||
return {
|
||||
key: metric_dict[key].avg for key in metric_dict
|
||||
}
|
||||
else:
|
||||
return [metric_dict[key].avg for key in metric_dict]
|
||||
|
||||
def get_metric_names(self):
|
||||
return 'top1', 'top5'
|
||||
|
||||
""" train and test """
|
||||

    def validate(self, epoch=0, is_test=False, run_str='', net=None,
                 data_loader=None, no_logs=False, train_mode=False, net_setting=None):
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)

        if data_loader is not None:
            self.data_loader = data_loader

        if train_mode:
            net.train()
        else:
            net.eval()

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        features_stack = []
        with torch.no_grad():
            with tqdm(total=len(self.data_loader),
                      desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
                for i, (images, labels) in enumerate(self.data_loader):
                    images, labels = images.to(self.device), labels.to(self.device)
                    if self.mode == 'generator':
                        features = self.feature_extractor(images).squeeze()
                        features_stack.append(features)
                    # compute output
                    output = net(images)
                    loss = self.test_criterion(output, labels)
                    # measure accuracy and record loss
                    self.update_metric(metric_dict, output, labels)

                    losses.update(loss.item(), images.size(0))
                    t.set_postfix({
                        'loss': losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        'img_size': images.size(2),
                    })
                    t.update(1)

        if self.mode == 'generator':
            features_stack = torch.cat(features_stack)
            igraph_g = decode_ofa_mbv3_to_igraph(net_setting)[0]
            D_mu = self.acc_predictor.module.set_encode(features_stack.unsqueeze(0).to('cuda'))
            G_mu = self.acc_predictor.module.graph_encode(igraph_g)
            pred_acc = self.acc_predictor.module.predict(D_mu.unsqueeze(0), G_mu).item()

        # returns a 3-tuple: (loss, (top1, top5), predicted accuracy or None)
        return losses.avg, self.get_metric_vals(metric_dict), \
            pred_acc if self.mode == 'generator' else None

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.network
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                # validate() returns a 3-tuple; the third element (predicted
                # accuracy) is unused here
                loss, (top1, top5), _ = self.validate(epoch, is_test, net=net)
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
            return img_size_list, loss_list, top1_list, top5_list
        else:
            loss, (top1, top5), _ = self.validate(epoch, is_test, net=net)
            return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]

    def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
        from ofa_local.imagenet_classification.elastic_nn.utils import set_running_statistics
        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
        set_running_statistics(net, data_loader)
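
    # Hedged note (not in the original file): this re-estimates the BatchNorm
    # running mean/var of the currently active sub-network on a small
    # calibration subset (2000 images, batches of 200) before evaluation,
    # following the OFA recipe.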
@@ -0,0 +1,4 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
@@ -0,0 +1,401 @@
from __future__ import print_function

import os
import math
import warnings
import numpy as np

from PIL import Image  # used for the RandAugment interpolation constant below
# from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
from torchvision.datasets.folder import default_loader

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


def make_dataset(dir, image_ids, targets):
    assert (len(image_ids) == len(targets))
    images = []
    dir = os.path.expanduser(dir)
    for i in range(len(image_ids)):
        item = (os.path.join(dir, 'data', 'images',
                             '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images


def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))

    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]

    return (image_ids, targets, classes, class_to_idx)
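
# Assumed layout of the classes file parsed above (illustration only):
#   0034309 707-320
#   0034958 727-200
# i.e. "<image_id> <class name, possibly containing spaces>" per line.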


class FGVCAircraft(torch.utils.data.Dataset):
    """`FGVC-Aircraft <http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft>`_ Dataset.

    Args:
        root (string): Root directory path to dataset.
        class_type (string, optional): The level of FGVC-Aircraft fine-grain classification
            to label data with (i.e., ``variant``, ``family``, or ``manufacturer``).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g. ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
    class_types = ('variant', 'family', 'manufacturer')
    splits = ('train', 'val', 'trainval', 'test')

    def __init__(self, root, class_type='variant', split='train', transform=None,
                 target_transform=None, loader=default_loader, download=False):
        if split not in self.splits:
            raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
                split, ', '.join(self.splits),
            ))
        if class_type not in self.class_types:
            raise ValueError('Class type "{}" not found. Valid class types are: {}'.format(
                class_type, ', '.join(self.class_types),
            ))
        self.root = os.path.expanduser(root)
        self.class_type = class_type
        self.split = split
        self.classes_file = os.path.join(self.root, 'data',
                                         'images_%s_%s.txt' % (self.class_type, self.split))

        if download:
            self.download()

        (image_ids, targets, classes, class_to_idx) = find_classes(self.classes_file)
        samples = make_dataset(self.root, image_ids, targets)

        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

        self.samples = samples
        self.classes = classes
        self.class_to_idx = class_to_idx

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """

        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        return len(self.samples)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, 'data', 'images')) and \
            os.path.exists(self.classes_file)

    def download(self):
        """Download the FGVC-Aircraft data if it doesn't exist already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        # prepare to download data to PARENT_DIR/fgvc-aircraft-2013.tar.gz
        print('Downloading %s ... (may take a few minutes)' % self.url)

        parent_dir = os.path.abspath(os.path.join(self.root, os.pardir))
        tar_name = self.url.rpartition('/')[-1]
        tar_path = os.path.join(parent_dir, tar_name)
        data = urllib.request.urlopen(self.url)

        # download .tar.gz file
        with open(tar_path, 'wb') as f:
            f.write(data.read())

        # extract .tar.gz to PARENT_DIR/fgvc-aircraft-2013b
        # (str.strip removes a character set, not a suffix, so slice instead)
        data_folder = tar_path[:-len('.tar.gz')]
        print('Extracting %s to %s ... (may take a few minutes)' % (tar_path, data_folder))
        tar = tarfile.open(tar_path)
        tar.extractall(parent_dir)

        # if necessary, rename data folder to self.root
        if not os.path.samefile(data_folder, self.root):
            print('Renaming %s to %s ...' % (data_folder, self.root))
            os.rename(data_folder, self.root)

        # delete .tar.gz file
        print('Deleting %s ...' % tar_path)
        os.remove(tar_path)

        print('Done!')


class FGVCAircraftDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test
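
    # Hedged usage sketch (not in the original file): valid_size may be an int
    # (number of held-out training images) or a float in (0, 1) (fraction of
    # the training set).
    # provider = FGVCAircraftDataProvider(save_path='/mnt/datastore/Aircraft',
    #                                     valid_size=0.2, image_size=224)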

    @staticmethod
    def name():
        return 'aircraft'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Aircraft'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/Aircraft'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = FGVCAircraft(
            root=self.train_path, split='trainval', download=True, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = FGVCAircraft(
            root=self.valid_path, split='test', download=True, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return self.save_path

    @property
    def valid_path(self):
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.48933587508932375, 0.5183537408957618, 0.5387914411673883],
            std=[0.22388883112804625, 0.21641635409388751, 0.24615605842636115])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.48933587508932375, 0.5183537408957618,
                                                               0.5387914411673883]]),
        )
        # timm expects a PIL interpolation constant here; the original passed a
        # transforms.Resize instance, which rand_augment_transform cannot use
        aa_params['interpolation'] = Image.BICUBIC  # _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms
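
    # Hedged note (not in the original file): 'rand-m9-mstd0.5' is a timm
    # RandAugment spec string, i.e. magnitude 9 with magnitude-noise std 0.5;
    # translate_const and img_mean above make the augmentations fill with the
    # dataset mean colour.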

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]


if __name__ == '__main__':
    data = FGVCAircraft(root='/mnt/datastore/Aircraft',
                        split='trainval', download=True)
    print(len(data.classes))
    print(len(data.samples))
@@ -0,0 +1,238 @@
"""
Taken from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
"""

from PIL import Image, ImageEnhance, ImageOps
import numpy as np
import random


class ImageNetPolicy(object):
    """ Randomly choose one of the best 24 Sub-policies on ImageNet.

    Example:
    >>> policy = ImageNetPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     ImageNetPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor),
            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),

            SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor),
            SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor),
            SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor),
            SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor),

            SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor),
            SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor),
            SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),

            SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor),
            SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor),
            SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor),
            SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor),
            SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor),

            SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor),
            SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor),
            SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor),
            SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"


class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.

    Example:
    >>> policy = CIFAR10Policy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     CIFAR10Policy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"


class SVHNPolicy(object):
    """ Randomly choose one of the best 25 Sub-policies on SVHN.

    Example:
    >>> policy = SVHNPolicy()
    >>> transformed = policy(image)

    Example as a PyTorch Transform:
    >>> transform = transforms.Compose([
    >>>     transforms.Resize(256),
    >>>     SVHNPolicy(),
    >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor),
            SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor),
            SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor),
            SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor),
            SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor),

            SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor),
            SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor),
            SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor),
            SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor),
            SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor),

            SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor),
            SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor),
            SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor),
            SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor),
            SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor),

            SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor),
            SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor),
            SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor),
            SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor),
            SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"


class SubPolicy(object):
    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # np.int was removed in NumPy 1.24; plain int is equivalent here
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand
        def rotate_with_fill(img, magnitude):
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
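

# Hedged illustration (not in the original file): a SubPolicy applies
# operation1 with probability p1, then operation2 with probability p2.
# from PIL import Image
# img = Image.new('RGB', (224, 224), (128, 128, 128))
# sp = SubPolicy(0.8, "rotate", 8, 0.4, "color", 0)
# out = sp(img)  # PIL image, possibly rotated and colour-jittered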
@@ -0,0 +1,657 @@
import os
import math
import numpy as np

import torchvision
import torch.utils.data
import torchvision.transforms as transforms

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class CIFAR10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cifar10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        # note: train_path and valid_path both resolve to save_path here
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.CIFAR10(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
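
    # Hedged usage sketch (not in the original file): cache a small
    # calibration subset once per active image size, then reuse it when
    # resetting BatchNorm statistics.
    # provider = CIFAR10DataProvider(save_path='/mnt/datastore/CIFAR', image_size=224)
    # calib_batches = provider.build_sub_train_loader(n_images=2000, batch_size=200)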
class CIFAR100DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cifar100'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 100

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CIFAR'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CIFAR'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.CIFAR100(
            root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.49139968, 0.48215827, 0.44653124], std=[0.24703233, 0.24348505, 0.26158768])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]

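
# --- Editor's usage sketch (not part of the original commit; path and sizes are
# illustrative assumptions). The provider bundles train/valid/test loaders behind
# one object: `assign_active_img_size` swaps the validation-time resolution, and
# `build_sub_train_loader` caches a small fixed subset, e.g. for re-estimating
# BatchNorm running statistics after activating a different sub-network.
if __name__ == '__main__':
    provider = CIFAR100DataProvider(save_path='/tmp/CIFAR', valid_size=0.1, image_size=224)
    images, labels = next(iter(provider.train))
    provider.assign_active_img_size(192)                       # valid/test now use 192x192
    sub_batches = provider.build_sub_train_loader(n_images=2000, batch_size=200)
    print(images.shape, labels.shape, len(sub_batches))

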
class CINIC10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)  # ImageFolder stores its items in .samples

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'cinic10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/CINIC10'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/CINIC10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = torchvision.datasets.ImageFolder(self.train_path, transform=_transforms)
        # dataset = torchvision.datasets.CIFAR10(
        #     root=self.valid_path, train=True, download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = torchvision.datasets.ImageFolder(self.valid_path, transform=_transforms)
        # dataset = torchvision.datasets.CIFAR10(
        #     root=self.valid_path, train=False, download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train_and_valid')
        # return self.save_path

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'test')
        # return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.47889522, 0.47227842, 0.43047404], std=[0.24205776, 0.23828046, 0.25874835])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
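
# --- Editor's sketch (not in the original commit): how a float `valid_size` becomes
# an index split. A fraction in (0, 1) is turned into an absolute count, and two
# disjoint index sets are drawn from one permutation, mirroring what
# random_sample_valid_set is assumed to do with a fixed seed. Numbers are illustrative.
if __name__ == '__main__':
    n_samples, valid_size = 100000, 0.1
    valid_count = int(n_samples * valid_size)
    g = torch.Generator()
    g.manual_seed(0)                          # illustrative seed, not the library's
    perm = torch.randperm(n_samples, generator=g).tolist()
    valid_indexes, train_indexes = perm[:valid_count], perm[valid_count:]
    assert set(valid_indexes).isdisjoint(train_indexes)
    print(len(train_indexes), len(valid_indexes))  # 90000 10000
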
@@ -0,0 +1,237 @@
import os
import warnings
import numpy as np

from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class DTDDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'dtd'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 47

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/dtd'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/dtd'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.5329876098715876, 0.474260843249454, 0.42627281899380676],
            std=[0.26549755708788914, 0.25473554309855373, 0.2631728035662832])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.5329876098715876, 0.474260843249454,
                                                               0.42627281899380676]]),
        )
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            # transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.Resize((image_size, image_size), interpolation=3),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
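
# --- Editor's sketch (not in the original commit): 'rand-m9-mstd0.5' is a timm
# RandAugment spec string -- magnitude 9 with magnitude std 0.5, applying the
# default number of ops per image. Standalone check with an illustrative hparams
# dict; the keys mirror the aa_params built above.
if __name__ == '__main__':
    from PIL import Image
    aa = rand_augment_transform('rand-m9-mstd0.5',
                                dict(translate_const=100, img_mean=(128, 128, 128)))
    out = aa(Image.new('RGB', (224, 224)))    # returns an augmented PIL image
    print(out.size)
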
@@ -0,0 +1,241 @@
import warnings
import os
import math
import numpy as np

import PIL

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class Flowers102DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        # warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        weights = self.make_weights_for_balanced_classes(
            train_dataset.imgs, self.n_classes)
        weights = torch.DoubleTensor(weights)
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))

        if valid_size is not None:
            raise NotImplementedError("validation dataset not yet implemented")
            # valid_dataset = self.valid_dataset(valid_transforms)

            # self.train = train_loader_class(
            #     train_dataset, batch_size=train_batch_size, sampler=train_sampler,
            #     num_workers=n_worker, pin_memory=True)
            # self.valid = torch.utils.data.DataLoader(
            #     valid_dataset, batch_size=test_batch_size,
            #     num_workers=n_worker, pin_memory=True)
        else:
            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        self.test = torch.utils.data.DataLoader(
            test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
        )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'flowers102'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 102

    @property
    def save_path(self):
        if self._save_path is None:
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server

        if not os.path.exists(self._save_path):
            # self._save_path = '/mnt/datastore/Oxford102Flowers'  # home server
            self._save_path = '/mnt/datastore/Flowers102'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    # def valid_dataset(self, _transforms):
    #     dataset = datasets.ImageFolder(self.valid_path, _transforms)
    #     return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.test_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    # @property
    # def valid_path(self):
    #     return os.path.join(self.save_path, 'train')

    @property
    def test_path(self):
        return os.path.join(self.save_path, 'test')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.5178361839861569, 0.4106749456881299, 0.32864167836880803],
            std=[0.2972239085211309, 0.24976049135203868, 0.28533308036347665])

    @staticmethod
    def make_weights_for_balanced_classes(images, nclasses):
        count = [0] * nclasses

        # Counts per label
        for item in images:
            count[item[1]] += 1

        weight_per_class = [0.] * nclasses

        # Total number of images.
        N = float(sum(count))

        # Super-sample the smaller classes.
        for i in range(nclasses):
            weight_per_class[i] = N / float(count[i])

        weight = [0] * len(images)

        # Calculate a weight per image.
        for idx, val in enumerate(images):
            weight[idx] = weight_per_class[val[1]]

        return weight

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            transforms.RandomAffine(
                45, translate=(0.4, 0.4), scale=(0.75, 1.5), shear=None, resample=PIL.Image.BILINEAR, fillcolor=0),
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            # transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
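
# --- Editor's check (not in the original commit): make_weights_for_balanced_classes
# assigns each image the weight N / count(class), so every class contributes the
# same expected mass to WeightedRandomSampler draws. Tiny synthetic example:
if __name__ == '__main__':
    imgs = [('a.jpg', 0)] * 3 + [('b.jpg', 1)] * 1   # class 0 is three times larger
    w = Flowers102DataProvider.make_weights_for_balanced_classes(imgs, 2)
    print(w)                                          # [1.333.., 1.333.., 1.333.., 4.0]
    assert abs(sum(w[:3]) - w[3]) < 1e-9              # equal per-class totals
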
@@ -0,0 +1,225 @@
import warnings
import os
import math
import numpy as np

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class ImagenetDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            # self._save_path = '/dataset/imagenet'
            # self._save_path = '/usr/local/soft/temp-datastore/ILSVRC2012'  # servers
            self._save_path = '/mnt/datastore/ILSVRC2012'  # home server

        if not os.path.exists(self._save_path):
            # self._save_path = os.path.expanduser('~/dataset/imagenet')
            # self._save_path = os.path.expanduser('/usr/local/soft/temp-datastore/ILSVRC2012')
            self._save_path = '/mnt/datastore/ILSVRC2012'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
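
# --- Editor's note (not in the original commit): the validation transform resizes
# to ceil(image_size / 0.875) before center-cropping, i.e. the standard 87.5%
# crop ratio used for ImageNet models (224 -> 256). Quick arithmetic check:
if __name__ == '__main__':
    for s in (160, 192, 224):
        print(s, '->', int(math.ceil(s / 0.875)))     # 160->183, 192->220, 224->256
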
@@ -0,0 +1,237 @@
import os
import math
import warnings
import numpy as np

from timm.data.transforms import _pil_interp
from timm.data.auto_augment import rand_augment_transform

import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class OxfordIIITPetsDataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=32, test_batch_size=200, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.samples) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.samples), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'pets'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 37

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server

        if not os.path.exists(self._save_path):
            self._save_path = '/mnt/datastore/Oxford-IIITPets'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.train_path, _transforms)
        return dataset

    def test_dataset(self, _transforms):
        dataset = datasets.ImageFolder(self.valid_path, _transforms)
        return dataset

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'valid')

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.4828895122298728, 0.4448394893850807, 0.39566558230789783],
            std=[0.25925664613996574, 0.2532760018681693, 0.25981017205097917])

    def build_train_transform(self, image_size=None, print_log=True, auto_augment='rand-m9-mstd0.5'):
        if image_size is None:
            image_size = self.image_size
        # if print_log:
        #     print('Color jitter: %s, resize_scale: %s, img_size: %s' %
        #           (self.distort_color, self.resize_scale, image_size))

        # if self.distort_color == 'torch':
        #     color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        # elif self.distort_color == 'tf':
        #     color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        # else:
        #     color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
            img_size_min = min(image_size)
        else:
            resize_transform_class = transforms.RandomResizedCrop
            img_size_min = image_size

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in [0.4828895122298728, 0.4448394893850807,
                                                               0.39566558230789783]]),
        )
        # RandAugment expects a PIL resampling flag here (bicubic, as in DTDDataProvider)
        aa_params['interpolation'] = _pil_interp('bicubic')
        train_transforms += [rand_augment_transform(auto_augment, aa_params)]

        # if color_transform is not None:
        #     train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test sets
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.samples)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
@@ -0,0 +1,69 @@
import torch
from glob import glob
from torch.utils.data.dataset import Dataset
import os
from PIL import Image


def load_image(filename):
    img = Image.open(filename)
    img = img.convert('RGB')
    return img


class PetDataset(Dataset):
    def __init__(self, root, train=True, num_cl=37, val_split=0.15, transforms=None):
        pt_name = os.path.join(root, '{}{}.pth'.format('train' if train else 'test',
                                                       int(100 * (1 - val_split)) if train else int(100 * val_split)))
        if not os.path.exists(pt_name):
            filenames = glob(os.path.join(root, 'images') + '/*.jpg')
            classes = set()

            data = []
            labels = []

            for image in filenames:
                class_name = image.rsplit("/", 1)[1].rsplit('_', 1)[0]
                classes.add(class_name)
                img = load_image(image)

                data.append(img)
                labels.append(class_name)

            # convert class names to indices; sort first so the mapping is
            # deterministic across runs (set iteration order is not)
            class2idx = {cl: idx for idx, cl in enumerate(sorted(classes))}
            labels = torch.Tensor(list(map(lambda x: class2idx[x], labels))).long()
            data = list(zip(data, labels))

            class_values = [[] for _ in range(num_cl)]

            # group samples by class so the train/val split is stratified
            for d in data:
                class_values[d[1].item()].append(d)

            train_data = []
            val_data = []

            for class_dp in class_values:
                split_idx = int(len(class_dp) * (1 - val_split))
                train_data += class_dp[:split_idx]
                val_data += class_dp[split_idx:]
            torch.save(train_data, os.path.join(root, 'train{}.pth'.format(int(100 * (1 - val_split)))))
            torch.save(val_data, os.path.join(root, 'test{}.pth'.format(int(100 * val_split))))

        self.data = torch.load(pt_name)
        self.len = len(self.data)
        self.transform = transforms

    def __getitem__(self, index):
        img, label = self.data[index]

        if self.transform:
            img = self.transform(img)

        return img, label

    def __len__(self):
        return self.len
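
# Usage sketch (not part of the original commit): building loaders from
# PetDataset. The data root and transform settings below are illustrative
# assumptions, not values taken from this repository; the expected layout
# is <root>/images/*.jpg.
#
# import torchvision.transforms as T
# from torch.utils.data import DataLoader
#
# tf = T.Compose([T.Resize((224, 224)), T.ToTensor()])
# train_set = PetDataset('/path/to/Oxford-IIITPets', train=True, num_cl=37,
#                        val_split=0.15, transforms=tf)
# train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)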
@@ -0,0 +1,226 @@
import os
import math
import numpy as np

import torchvision
import torch.utils.data
import torchvision.transforms as transforms

from ofa.imagenet_codebase.data_providers.base_provider import DataProvider, MyRandomResizedCrop, MyDistributedSampler


class STL10DataProvider(DataProvider):

    def __init__(self, save_path=None, train_batch_size=96, test_batch_size=256, valid_size=None,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224, num_replicas=None, rank=None):

        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            assert isinstance(self.image_size, list)
            from ofa.imagenet_codebase.data_providers.my_data_loader import MyDataLoader
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample an image size for each batch of training images
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_transforms = self.build_train_transform()
        train_dataset = self.train_dataset(train_transforms)

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset.data) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset.data), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True,
                    num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

    @staticmethod
    def name():
        return 'stl10'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 10

    @property
    def save_path(self):
        # default to the home-server location when no path is given; the
        # original fallback branch re-assigned the same path, so a single
        # default suffices
        if self._save_path is None:
            self._save_path = '/mnt/datastore/STL10'  # home server
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.train_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='train', download=False, transform=_transforms)
        return dataset

    def test_dataset(self, _transforms):
        # dataset = datasets.ImageFolder(self.valid_path, _transforms)
        dataset = torchvision.datasets.STL10(
            root=self.valid_path, split='test', download=False, transform=_transforms)
        return dataset

    @property
    def train_path(self):
        # return os.path.join(self.save_path, 'train')
        return self.save_path

    @property
    def valid_path(self):
        # return os.path.join(self.save_path, 'val')
        return self.save_path

    @property
    def normalize(self):
        return transforms.Normalize(
            mean=[0.44671097, 0.4398105, 0.4066468],
            std=[0.2603405, 0.25657743, 0.27126738])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        else:
            color_transform = None

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]
        if color_transform is not None:
            train_transforms.append(color_transform)
        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset.data)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
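
# Usage sketch (not part of the original commit): driving the provider with an
# elastic image-size list. The dataset path and sizes are illustrative
# assumptions; STL10 must already be downloaded at that location.
#
# provider = STL10DataProvider(
#     save_path='/path/to/STL10',
#     train_batch_size=96, test_batch_size=256,
#     valid_size=0.1, image_size=[160, 192, 224])
# provider.assign_active_img_size(192)   # swaps valid/test transforms in place
# images, labels = next(iter(provider.valid))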
@@ -0,0 +1,4 @@
from ofa.imagenet_codebase.networks.proxyless_nets import ProxylessNASNets, proxyless_base, MobileNetV2
from ofa.imagenet_codebase.networks.mobilenet_v3 import MobileNetV3, MobileNetV3Large
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks.nsganetv2 import NSGANetV2
@@ -0,0 +1,126 @@
from timm.models.layers import drop_path
from ofa.imagenet_codebase.modules.layers import *
from ofa.imagenet_codebase.networks import MobileNetV3


class MobileInvertedResidualBlock(MyModule):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/proxyless_nets.py to include drop path in training
    """

    def __init__(self, mobile_inverted_conv, shortcut, drop_connect_rate=0.0):
        super(MobileInvertedResidualBlock, self).__init__()

        self.mobile_inverted_conv = mobile_inverted_conv
        self.shortcut = shortcut
        self.drop_connect_rate = drop_connect_rate

    def forward(self, x):
        if self.mobile_inverted_conv is None or isinstance(self.mobile_inverted_conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
            res = self.mobile_inverted_conv(x)
        else:
            # res = self.mobile_inverted_conv(x) + self.shortcut(x)
            res = self.mobile_inverted_conv(x)

            # drop path is applied to the residual branch only, before the skip add
            if self.drop_connect_rate > 0.:
                res = drop_path(res, drop_prob=self.drop_connect_rate, training=self.training)

            res += self.shortcut(x)

        return res

    @property
    def module_str(self):
        return '(%s, %s)' % (
            self.mobile_inverted_conv.module_str if self.mobile_inverted_conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        return {
            'name': MobileInvertedResidualBlock.__name__,
            'mobile_inverted_conv': self.mobile_inverted_conv.config if self.mobile_inverted_conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        mobile_inverted_conv = set_layer_from_config(config['mobile_inverted_conv'])
        shortcut = set_layer_from_config(config['shortcut'])
        return MobileInvertedResidualBlock(
            mobile_inverted_conv, shortcut, drop_connect_rate=config['drop_connect_rate'])


class NSGANetV2(MobileNetV3):
    """
    Modified from https://github.com/mit-han-lab/once-for-all/blob/master/ofa/
    imagenet_codebase/networks/mobilenet_v3.py to include drop path in training
    and an option to reset the classification layer
    """

    @staticmethod
    def build_from_config(config, drop_connect_rate=0.0):
        first_conv = set_layer_from_config(config['first_conv'])
        final_expand_layer = set_layer_from_config(config['final_expand_layer'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])

        blocks = []
        for block_idx, block_config in enumerate(config['blocks']):
            # linearly scale the drop-path rate with block depth
            block_config['drop_connect_rate'] = drop_connect_rate * block_idx / len(config['blocks'])
            blocks.append(MobileInvertedResidualBlock.build_from_config(block_config))

        net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, MobileInvertedResidualBlock):
                if isinstance(m.mobile_inverted_conv, MBInvertedConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.mobile_inverted_conv.point_linear.bn.weight.data.zero_()

    @staticmethod
    def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
        )
        # build mobile blocks
        feature_dim = input_channel
        blocks = []
        for stage_id, block_config_list in cfg.items():
            for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
                mb_conv = MBInvertedConvLayer(
                    feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
                )
                if stride == 1 and out_channel == feature_dim:
                    shortcut = IdentityLayer(out_channel, out_channel)
                else:
                    shortcut = None
                blocks.append(MobileInvertedResidualBlock(mb_conv, shortcut))
                feature_dim = out_channel
        # final expand layer
        final_expand_layer = ConvLayer(
            feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
        )
        feature_dim = feature_dim * 6
        # feature mix layer
        feature_mix_layer = ConvLayer(
            feature_dim, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
        )
        # classifier
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier

    @staticmethod
    def reset_classifier(model, last_channel, n_classes, dropout_rate=0.0):
        model.classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
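
# Sketch (not part of the original commit): the per-block drop-path schedule
# used in NSGANetV2.build_from_config above, shown standalone. Values are
# illustrative.
def drop_path_schedule(drop_connect_rate, n_blocks):
    # block 0 keeps rate 0.0; deeper blocks approach (but never reach) the max rate
    return [drop_connect_rate * idx / n_blocks for idx in range(n_blocks)]

# schedule = drop_path_schedule(0.2, 20)
# assert schedule[0] == 0.0 and schedule[-1] == 0.2 * 19 / 20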
@@ -0,0 +1,309 @@
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.imagenet import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.cifar import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets import *
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import *

from ofa.imagenet_codebase.run_manager.run_manager import *


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=1e-4, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=128, test_batch_size=512, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/ILSVRC2012',
                 **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class CIFARRunConfig(RunConfig):
    def __init__(self, n_epochs=5, init_lr=0.01, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='cifar10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=2, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/CIFAR',
                 **kwargs):
        super(CIFARRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.cifar_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == CIFAR10DataProvider.name():
                DataProviderClass = CIFAR10DataProvider
            elif self.dataset == CIFAR100DataProvider.name():
                DataProviderClass = CIFAR100DataProvider
            elif self.dataset == CINIC10DataProvider.name():
                DataProviderClass = CINIC10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.cifar_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class Flowers102RunConfig(RunConfig):

    def __init__(self, n_epochs=3, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='flowers102', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/Flowers102',
                 **kwargs):
        super(Flowers102RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.flowers102_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == Flowers102DataProvider.name():
                DataProviderClass = Flowers102DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.flowers102_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class STL10RunConfig(RunConfig):

    def __init__(self, n_epochs=5, init_lr=1e-2, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='stl10', train_batch_size=96, test_batch_size=256, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=4, resize_scale=0.08, distort_color=None, image_size=224,
                 data_path='/mnt/datastore/STL10',
                 **kwargs):
        super(STL10RunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.stl10_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == STL10DataProvider.name():
                DataProviderClass = STL10DataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.stl10_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class DTDRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='dtd', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/dtd',
                 **kwargs):
        super(DTDRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == DTDDataProvider.name():
                DataProviderClass = DTDDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class PetsRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='pets', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Oxford-IIITPets',
                 **kwargs):
        super(PetsRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        # note: attribute name kept from the ImageNet config; it holds the Pets data path here
        self.imagenet_data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == OxfordIIITPetsDataProvider.name():
                DataProviderClass = OxfordIIITPetsDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.imagenet_data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class AircraftRunConfig(RunConfig):

    def __init__(self, n_epochs=1, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='aircraft', train_batch_size=32, test_batch_size=250, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.0, no_decay_keys=None,
                 mixup_alpha=None,
                 model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224,
                 data_path='/mnt/datastore/Aircraft',
                 **kwargs):
        super(AircraftRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )
        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size
        self.data_path = data_path

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == FGVCAircraftDataProvider.name():
                DataProviderClass = FGVCAircraftDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                save_path=self.data_path,
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


def get_run_config(**kwargs):
    if kwargs['dataset'] == 'imagenet':
        run_config = ImagenetRunConfig(**kwargs)
    elif kwargs['dataset'].startswith('cifar') or kwargs['dataset'].startswith('cinic'):
        run_config = CIFARRunConfig(**kwargs)
    elif kwargs['dataset'] == 'flowers102':
        run_config = Flowers102RunConfig(**kwargs)
    elif kwargs['dataset'] == 'stl10':
        run_config = STL10RunConfig(**kwargs)
    elif kwargs['dataset'] == 'dtd':
        run_config = DTDRunConfig(**kwargs)
    elif kwargs['dataset'] == 'pets':
        run_config = PetsRunConfig(**kwargs)
    elif kwargs['dataset'].startswith('aircraft'):  # covers 'aircraft' and 'aircraft100'
        run_config = AircraftRunConfig(**kwargs)
    else:
        raise NotImplementedError

    return run_config
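
# Usage sketch (not part of the original commit): building a run config and
# lazily materializing its data provider. The keyword values and data path are
# illustrative assumptions.
#
# run_config = get_run_config(
#     dataset='cifar10', data_path='/path/to/CIFAR',
#     n_epochs=5, train_batch_size=96, image_size=224)
# provider = run_config.data_provider  # constructed on first access, then cached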
@@ -0,0 +1,122 @@
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.aircraft import FGVCAircraft
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.pets2 import PetDataset
import torch.utils.data as Data
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.data_providers.autoaugment import CIFAR10Policy


def get_dataset(data_name, batch_size, data_path, num_workers,
                img_size, autoaugment, cutout, cutout_length):
    # note: 'aircraft100' is handled by the branches below but has no entry
    # here (see the commented-out keys)
    num_class_dict = {
        'cifar100': 100,
        'cifar10': 10,
        'mnist': 10,
        'aircraft': 100,
        'svhn': 10,
        'pets': 37,
    }
    # 'aircraft30': 30,
    # 'aircraft100': 100,

    train_transform, valid_transform = _data_transforms(
        data_name, img_size, autoaugment, cutout, cutout_length)
    if data_name == 'cifar100':
        train_data = torchvision.datasets.CIFAR100(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR100(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name == 'cifar10':
        train_data = torchvision.datasets.CIFAR10(
            root=data_path, train=True, download=True, transform=train_transform)
        valid_data = torchvision.datasets.CIFAR10(
            root=data_path, train=False, download=True, transform=valid_transform)
    elif data_name.startswith('aircraft'):
        print(data_path)
        if 'aircraft100' in data_path:
            data_path = data_path.replace('aircraft100', 'aircraft/fgvc-aircraft-2013b')
        else:
            data_path = data_path.replace('aircraft', 'aircraft/fgvc-aircraft-2013b')
        train_data = FGVCAircraft(data_path, class_type='variant', split='trainval',
                                  transform=train_transform, download=True)
        valid_data = FGVCAircraft(data_path, class_type='variant', split='test',
                                  transform=valid_transform, download=True)
    elif data_name.startswith('pets'):
        train_data = PetDataset(data_path, train=True, num_cl=37,
                                val_split=0.15, transforms=train_transform)
        valid_data = PetDataset(data_path, train=False, num_cl=37,
                                val_split=0.15, transforms=valid_transform)
    else:
        raise KeyError(data_name)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=batch_size, shuffle=True, pin_memory=True,
        num_workers=num_workers)

    # note: the validation batch size is fixed at 200 regardless of batch_size
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=200, shuffle=False, pin_memory=True,
        num_workers=num_workers)

    return train_queue, valid_queue, num_class_dict[data_name]


class Cutout(object):
    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        # zero out one randomly-placed square patch of side `length`
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)
        y = np.random.randint(h)
        x = np.random.randint(w)

        y1 = np.clip(y - self.length // 2, 0, h)
        y2 = np.clip(y + self.length // 2, 0, h)
        x1 = np.clip(x - self.length // 2, 0, w)
        x2 = np.clip(x + self.length // 2, 0, w)

        mask[y1: y2, x1: x2] = 0.
        mask = torch.from_numpy(mask)
        mask = mask.expand_as(img)
        img *= mask
        return img
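
# Usage sketch (not part of the original commit): applying Cutout to a CHW
# tensor. Shapes and values are illustrative.
#
# img = torch.ones(3, 32, 32)
# img = Cutout(length=16)(img)   # zeroes one random 16x16 patch (clipped at edges)
# print(float(img.mean()))       # strictly below 1.0 once a patch is masked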
def _data_transforms(data_name, img_size, autoaugment, cutout, cutout_length):
    if 'cifar' in data_name:
        norm_mean = [0.49139968, 0.48215827, 0.44653124]
        norm_std = [0.24703233, 0.24348505, 0.26158768]
    elif 'aircraft' in data_name:
        norm_mean = [0.48933587508932375, 0.5183537408957618, 0.5387914411673883]
        norm_std = [0.22388883112804625, 0.21641635409388751, 0.24615605842636115]
    elif 'pets' in data_name:
        norm_mean = [0.4828895122298728, 0.4448394893850807, 0.39566558230789783]
        norm_std = [0.25925664613996574, 0.2532760018681693, 0.25981017205097917]
    else:
        raise KeyError(data_name)

    train_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),  # BICUBIC interpolation
        transforms.RandomHorizontalFlip(),
    ])

    if autoaugment:
        train_transform.transforms.append(CIFAR10Policy())

    train_transform.transforms.append(transforms.ToTensor())

    if cutout:
        train_transform.transforms.append(Cutout(cutout_length))

    train_transform.transforms.append(transforms.Normalize(norm_mean, norm_std))

    valid_transform = transforms.Compose([
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),  # BICUBIC interpolation
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    return train_transform, valid_transform
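
# Usage sketch (not part of the original commit): building the augmentation
# pipelines for CIFAR-10 at 224x224 with AutoAugment and Cutout enabled.
#
# train_tf, valid_tf = _data_transforms('cifar10', img_size=224,
#                                       autoaugment=True, cutout=True, cutout_length=16)
# print(len(train_tf.transforms), len(valid_tf.transforms))  # 6 and 3 stages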
@@ -0,0 +1,233 @@
import os
import torch
import numpy as np
import random
import sys
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation import eval_utils
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.run_manager import get_run_config
from ofa.elastic_nn.networks import OFAMobileNetV3
from ofa.imagenet_codebase.run_manager import RunManager
from ofa.elastic_nn.modules.dynamic_op import DynamicSeparableConv2d
from torchprofile import profile_macs
import copy
import json
import warnings

warnings.simplefilter("ignore")

DynamicSeparableConv2d.KERNEL_TRANSFORM_MODE = 1


class ArchManager:
    def __init__(self):
        self.num_blocks = 20
        self.num_stages = 5
        self.kernel_sizes = [3, 5, 7]
        self.expand_ratios = [3, 4, 6]
        self.depths = [2, 3, 4]
        self.resolutions = [160, 176, 192, 208, 224]

    def random_sample(self):
        d = []
        e = []
        ks = []
        for i in range(self.num_stages):
            d.append(random.choice(self.depths))

        for i in range(self.num_blocks):
            e.append(random.choice(self.expand_ratios))
            ks.append(random.choice(self.kernel_sizes))

        sample = {
            'wid': None,
            'ks': ks,
            'e': e,
            'd': d,
            'r': [random.choice(self.resolutions)]
        }

        return sample

    def random_resample(self, sample, i):
        assert 0 <= i < self.num_blocks
        sample['ks'][i] = random.choice(self.kernel_sizes)
        sample['e'][i] = random.choice(self.expand_ratios)

    def random_resample_depth(self, sample, i):
        assert 0 <= i < self.num_stages
        sample['d'][i] = random.choice(self.depths)

    def random_resample_resolution(self, sample):
        sample['r'][0] = random.choice(self.resolutions)


def parse_string_list(string):
    if isinstance(string, str):
        # convert '[5 5 5 7 7 7 3 3 7 7 7 3 3]' to [5, 5, 5, 7, 7, 7, 3, 3, 7, 7, 7, 3, 3]
        return list(map(int, string[1:-1].split()))
    else:
        return string


def pad_none(x, depth, max_depth):
    new_x, counter = [], 0
    for d in depth:
        for _ in range(d):
            new_x.append(x[counter])
            counter += 1
        if d < max_depth:
            new_x += [None] * (max_depth - d)
    return new_x


def get_net_info(net, data_shape, measure_latency=None, print_info=True, clean=False, lut=None):
    net_info = eval_utils.get_net_info(
        net, data_shape, measure_latency, print_info=print_info, clean=clean, lut=lut)

    gpu_latency, cpu_latency = None, None
    for k in net_info.keys():
        if 'gpu' in k:
            gpu_latency = np.round(net_info[k]['val'], 2)
        if 'cpu' in k:
            cpu_latency = np.round(net_info[k]['val'], 2)

    return {
        'params': np.round(net_info['params'] / 1e6, 2),
        'flops': np.round(net_info['flops'] / 1e6, 2),
        'gpu': gpu_latency, 'cpu': cpu_latency,
    }


def validate_config(config, max_depth=4):
    kernel_size, exp_ratio, depth = config['ks'], config['e'], config['d']

    if isinstance(kernel_size, str):
        kernel_size = parse_string_list(kernel_size)
    if isinstance(exp_ratio, str):
        exp_ratio = parse_string_list(exp_ratio)
    if isinstance(depth, str):
        depth = parse_string_list(depth)

    assert isinstance(kernel_size, (list, int))
    assert isinstance(exp_ratio, (list, int))
    assert isinstance(depth, list)

    if len(kernel_size) < len(depth) * max_depth:
        kernel_size = pad_none(kernel_size, depth, max_depth)
    if len(exp_ratio) < len(depth) * max_depth:
        exp_ratio = pad_none(exp_ratio, depth, max_depth)

    # return {'ks': kernel_size, 'e': exp_ratio, 'd': depth, 'w': config['w']}
    return {'ks': kernel_size, 'e': exp_ratio, 'd': depth}


def set_nas_test_dataset(path, test_data_name, max_img):
    if test_data_name not in ['mnist', 'svhn', 'cifar10',
                              'cifar100', 'aircraft', 'pets']:
        raise ValueError(test_data_name)

    dpath = path
    num_cls = 10  # mnist, svhn, cifar10
    if test_data_name in ['cifar100', 'aircraft']:
        num_cls = 100
    elif test_data_name == 'pets':
        num_cls = 37

    x = torch.load(dpath + f'/{test_data_name}bylabel')
    img_per_cls = min(int(max_img / num_cls), 20)
    return x, img_per_cls, num_cls
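# Usage sketch (not part of the original commit): drawing and mutating a
# random architecture encoding with ArchManager.
#
# manager = ArchManager()
# arch = manager.random_sample()        # dict with 'ks', 'e', 'd', 'r' fields
# manager.random_resample(arch, i=0)    # re-draw kernel size / expand ratio of block 0
# manager.random_resample_resolution(arch)
# print(arch['r'][0], arch['d'])
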
class OFAEvaluator:
    """ based on the OnceForAll supernet taken from https://github.com/mit-han-lab/once-for-all """

    def __init__(self, num_gen_arch, img_size, drop_path,
                 n_classes=1000,
                 model_path=None,
                 kernel_size=None, exp_ratio=None, depth=None):
        # default configurations
        self.kernel_size = [3, 5, 7] if kernel_size is None else kernel_size  # depth-wise conv kernel size
        self.exp_ratio = [3, 4, 6] if exp_ratio is None else exp_ratio  # expansion rate
        self.depth = [2, 3, 4] if depth is None else depth  # number of MB block repetitions
        self.img_size = img_size
        self.drop_path = drop_path

        if 'w1.0' in model_path:
            self.width_mult = 1.0
        elif 'w1.2' in model_path:
            self.width_mult = 1.2
        else:
            raise ValueError(model_path)

        self.engine = OFAMobileNetV3(
            n_classes=n_classes,
            dropout_rate=0, width_mult_list=self.width_mult, ks_list=self.kernel_size,
            expand_ratio_list=self.exp_ratio, depth_list=self.depth)

        init = torch.load(model_path, map_location='cpu')['state_dict']
        self.engine.load_weights_from_net(init)
        print(f'load {model_path}...')

        ## metad2a
        self.arch_manager = ArchManager()
        self.num_gen_arch = num_gen_arch

    def sample_random_architecture(self):
        sampled_architecture = self.arch_manager.random_sample()
        return sampled_architecture

    def get_architecture(self, bound=None):
        # expects self.lines (generator output, one candidate per line,
        # prefixed by its predicted accuracy) to be populated by the caller
        pred_acc_lst = []
        searched_g = None

        with torch.no_grad():
            for n in range(self.num_gen_arch):
                file_acc = self.lines[n].split()[0]
                g_dict = ' '.join(self.lines[n].split())
                g = json.loads(g_dict.replace("'", "\""))

                if bound is not None:
                    # keep scanning candidates until one fits the FLOPs bound
                    subnet, config = self.sample(config=g)
                    net = NSGANetV2.build_from_config(subnet.config,
                                                      drop_connect_rate=self.drop_path)
                    inputs = torch.randn(1, 3, self.img_size, self.img_size)
                    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
                    if flops <= bound:
                        searched_g = g
                        break
                else:
                    searched_g = g
                    pred_acc_lst.append(file_acc)
                    break

        if searched_g is None:
            raise ValueError(searched_g)
        return searched_g, pred_acc_lst

    def sample(self, config=None):
        """ randomly sample a sub-network """
        if config is not None:
            config = validate_config(config)
            self.engine.set_active_subnet(ks=config['ks'], e=config['e'], d=config['d'])
        else:
            config = self.engine.sample_active_subnet()

        subnet = self.engine.get_active_subnet(preserve_weight=True)
        return subnet, config

    @staticmethod
    def save_net_config(path, net, config_name='net.config'):
        """ dump run_config and net_config to the model folder """
        net_save_path = os.path.join(path, config_name)
        json.dump(net.config, open(net_save_path, 'w'), indent=4)
        print('Network configs dump to %s' % net_save_path)

    @staticmethod
    def save_net(path, net, model_name):
        """ dump net weights as a checkpoint """
        if isinstance(net, torch.nn.DataParallel):
            checkpoint = {'state_dict': net.module.state_dict()}
        else:
            checkpoint = {'state_dict': net.state_dict()}
        model_path = os.path.join(path, model_name)
        torch.save(checkpoint, model_path)
        print('Network model dump to %s' % model_path)
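
# Usage sketch (not part of the original commit): extracting a weight-sharing
# subnet from the OFA supernet. The checkpoint path is an assumption; a real
# OFA MobileNetV3 checkpoint (w1.0) must exist at that location.
#
# evaluator = OFAEvaluator(num_gen_arch=None, img_size=224, drop_path=0.2,
#                          model_path='/path/to/ofa_mbv3_d234_e346_k357_w1.0')
# subnet, config = evaluator.sample()   # random subnet with inherited weights
# print(config['ks'][:4], config['e'][:4], config['d'])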
@@ -0,0 +1,169 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim
from evaluator import OFAEvaluator
from torchprofile import profile_macs
from codebase.networks import NSGANetV2
from parser import get_parse  # local parser.py, not the stdlib module
from eval_utils import get_dataset


args = get_parse()
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device_list = [int(_) for _ in args.gpu.split(',')]
args.n_gpus = len(device_list)
args.device = torch.device("cuda:0")

if args.seed is None or args.seed < 0:
    args.seed = random.randint(1, 100000)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)


# OFAEvaluator signature: (num_gen_arch, img_size, drop_path, ...)
evaluator = OFAEvaluator(args.num_gen_arch, args.img_size, args.drop_path,
                         model_path='../.torch/ofa_nets/ofa_mbv3_d234_e346_k357_w1.0')

args.save_path = os.path.join(args.save_path, f'evaluation/{args.data_name}')
if args.model_config.startswith('flops@'):
    args.save_path += f'-nsganetV2-{args.model_config}-{args.seed}'
else:
    args.save_path += f'-metaD2A-{args.bound}-{args.seed}'
if not os.path.exists(args.save_path):
    os.makedirs(args.save_path)

args.data_path = os.path.join(args.data_path, args.data_name)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_path, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)
if not torch.cuda.is_available():
    logging.info('no gpu device available')
    sys.exit(1)
logging.info("args = %s", args)


def set_architecture(n_cls):
    if args.model_config.startswith('flops@'):
        names = {'cifar10': 'CIFAR-10', 'cifar100': 'CIFAR-100',
                 'aircraft100': 'Aircraft', 'pets': 'Pets'}
        p = os.path.join('./searched-architectures/{}/net-{}/net.subnet'.
                         format(names[args.data_name], args.model_config))
        g = json.load(open(p))
    else:
        g, acc = evaluator.get_architecture(args.bound)

    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=args.drop_path)
    net.load_state_dict(subnet.state_dict())

    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=args.drop)
    # calculate #parameters and #FLOPS
    inputs = torch.randn(1, 3, args.img_size, args.img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    logging.info('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    OFAEvaluator.save_net_config(args.save_path, net, net_name + '.config')
    if args.n_gpus > 1:
        net = nn.DataParallel(net)  # data parallel in case more than 1 gpu is available
    net = net.to(args.device)

    return net, net_name


def train(train_queue, net, criterion, optimizer):
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsampled by bicubic to match imagenet training size
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), args.grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % args.report_freq == 0:
            logging.info('train %03d %e %f', step, train_loss / total, 100. * correct / total)
    logging.info('train acc %f', 100. * correct / total)
    return train_loss / total, 100. * correct / total


def infer(valid_queue, net, criterion, early_stop=False):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(args.device), targets.to(args.device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f', step, test_loss / total, 100. * correct / total)
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    logging.info('valid acc %f', 100. * correct / total)

    return test_loss / total, acc


def main():
    best_acc, top_checkpoints = 0, []

    # get_dataset takes the unpacked settings, not the args namespace
    train_queue, valid_queue, n_cls = get_dataset(
        args.data_name, args.batch_size, args.data_path, args.num_workers,
        args.img_size, args.autoaugment, args.cutout, args.cutout_length)
    net, net_name = set_architecture(n_cls)
    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    criterion = nn.CrossEntropyLoss().to(args.device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

    for epoch in range(args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])

        train(train_queue, net, criterion, optimizer)
        _, valid_acc = infer(valid_queue, net, criterion)

        # checkpoint saving: keep the top-k checkpoints by validation accuracy
        if len(top_checkpoints) < args.topk:
            OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
            top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
        else:
            idx = np.argmin([x[1] for x in top_checkpoints])
            if valid_acc > top_checkpoints[idx][1]:
                OFAEvaluator.save_net(args.save_path, net, net_name + '.ckpt{}'.format(epoch))
                top_checkpoints.append((os.path.join(args.save_path, net_name + '.ckpt{}'.format(epoch)), valid_acc))
                # evict the weakest of the previous top-k
                os.remove(top_checkpoints[idx][0])
                top_checkpoints.pop(idx)
            print(top_checkpoints)
        if valid_acc > best_acc:
            OFAEvaluator.save_net(args.save_path, net, net_name + '.best')
            best_acc = valid_acc
        scheduler.step()


if __name__ == '__main__':
    main()
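
# Sketch (not part of the original commit): the top-k checkpoint retention
# policy from main() above, isolated as a pure function for clarity.
# `saved` maps checkpoint path -> validation accuracy; names are illustrative.
def keep_top_k(saved, path, acc, k):
    if len(saved) < k:
        saved[path] = acc
        return None                      # nothing evicted
    worst = min(saved, key=saved.get)
    if acc > saved[worst]:
        saved[path] = acc
        del saved[worst]
        return worst                     # caller removes this file
    return path                          # new checkpoint not worth keeping

# saved = {}
# for epoch, acc in enumerate([70.1, 71.5, 69.8, 72.0]):
#     keep_top_k(saved, f'net.ckpt{epoch}', acc, k=2)
# print(sorted(saved.values()))          # [71.5, 72.0]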
@@ -0,0 +1,43 @@
import argparse


def get_parse():
    parser = argparse.ArgumentParser(description='MetaD2A vs NSGANETv2')
    parser.add_argument('--save-path', type=str, default='../results', help='the path of the save directory')
    parser.add_argument('--data-path', type=str, default='../data', help='the path of the data directory')
    parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
    parser.add_argument('--num-gen-arch', type=int, default=200,
                        help='the number of candidate architectures generated by the generator')
    parser.add_argument('--bound', type=int, default=None)

    # original setting
    parser.add_argument('--seed', type=int, default=-1, help='random seed')
    parser.add_argument('--batch-size', type=int, default=96, help='batch size')
    parser.add_argument('--num_workers', type=int, default=2, help='number of workers for data loading')
    parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
    parser.add_argument('--lr', type=float, default=0.01, help='init learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
    parser.add_argument('--weight_decay', type=float, default=4e-5, help='weight decay')
    parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
    parser.add_argument('--epochs', type=int, default=150, help='num of training epochs')
    parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
    parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
    parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
    parser.add_argument('--autoaugment', action='store_true', default=True, help='use auto augmentation')

    parser.add_argument('--topk', type=int, default=10, help='top k checkpoints to save')
    parser.add_argument('--evaluate', action='store_true', default=False, help='evaluate a pretrained model')
    # model related
    parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                        help='name of model to train (default: "resnet101")')
    parser.add_argument('--model-config', type=str, default='search',
                        help='location of a json file of specific model declaration')
    parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
                        help='initialize model from this checkpoint (default: none)')
    parser.add_argument('--drop', type=float, default=0.2,
                        help='dropout rate')
    parser.add_argument('--drop-path', type=float, default=0.2, metavar='PCT',
                        help='drop path rate (default: 0.2)')
    parser.add_argument('--img-size', type=int, default=224,
                        help='input resolution (192 -> 256)')
    args = parser.parse_args()
    return args
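
# Usage sketch (not part of the original commit): parsing a known argv instead
# of the live process arguments, handy for tests. Values are illustrative.
#
# import sys
# sys.argv = ['train.py', '--data-name', 'pets', '--gpu', '0', '--epochs', '1']
# args = get_parse()
# print(args.data_name, args.topk)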
@@ -0,0 +1,261 @@
import os
import sys
import json
import logging
import numpy as np
import copy
import torch
import torch.nn as nn
import random
import torch.optim as optim

from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.evaluator import OFAEvaluator
from torchprofile import profile_macs
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.codebase.networks import NSGANetV2
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.parser import get_parse
from transfer_nag_lib.MetaD2A_mobilenetV3.evaluation.eval_utils import get_dataset
from transfer_nag_lib.MetaD2A_nas_bench_201.metad2a_utils import reset_seed
from transfer_nag_lib.ofa_net import OFASubNet


def set_architecture(n_cls, evaluator, drop_path, drop, img_size, n_gpus, device, save_path, model_str):
    # g, acc = evaluator.get_architecture(model_str)
    g = OFASubNet(model_str).get_op_dict()
    subnet, config = evaluator.sample(g)
    net = NSGANetV2.build_from_config(subnet.config, drop_connect_rate=drop_path)
    net.load_state_dict(subnet.state_dict())

    NSGANetV2.reset_classifier(
        net, last_channel=net.classifier.in_features,
        n_classes=n_cls, dropout_rate=drop)
    # calculate #parameters and #FLOPS
    inputs = torch.randn(1, 3, img_size, img_size)
    flops = profile_macs(copy.deepcopy(net), inputs) / 1e6
    params = sum(p.numel() for p in net.parameters() if p.requires_grad) / 1e6
    net_name = "net_flops@{:.0f}".format(flops)
    print('#params {:.2f}M, #flops {:.0f}M'.format(params, flops))
    # OFAEvaluator.save_net_config(save_path, net, net_name + '.config')
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net = net.to(device)

    return net, net_name, params, flops


def train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq):
    net.train()
    train_loss, correct, total = 0, 0, 0
    for step, (inputs, targets) in enumerate(train_queue):
        # upsampled by bicubic to match imagenet training size
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
        optimizer.step()
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        if step % report_freq == 0:
            print(f'train step {step:03d} loss {train_loss / total:.4f} train acc {100. * correct / total:.4f}')
    print(f'train acc {100. * correct / total:.4f}')
    return train_loss / total, 100. * correct / total


def infer(valid_queue, net, criterion, device, report_freq, early_stop=False):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            if step % report_freq == 0:
                print(f'valid {step:03d} {test_loss / total:.4f} {100. * correct / total:.4f}')
            if early_stop and step == 10:
                break
    acc = 100. * correct / total
    print('valid acc {:.4f}'.format(100. * correct / total))

    return test_loss / total, acc


def train_single_model(save_path, workers, datasets, xpaths, splits, use_less,
                       seed, model_str, device,
                       lr=0.01,
                       momentum=0.9,
                       weight_decay=4e-5,
                       report_freq=50,
                       epochs=150,
                       grad_clip=5,
                       cutout=True,
                       cutout_length=16,
                       autoaugment=True,
                       drop=0.2,
                       drop_path=0.2,
                       img_size=224,
                       batch_size=96,
                       ):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    reset_seed(seed)
    # save_dir = Path(save_dir)
    # logger = Logger(str(save_dir), 0, False)
    os.makedirs(save_path, exist_ok=True)
    to_save_name = save_path + '/seed-{:04d}.pth'.format(seed)
    print(to_save_name)
    # args = get_parse()
    num_gen_arch = None
    evaluator = OFAEvaluator(num_gen_arch, img_size, drop_path,
                             model_path='/home/data/GTAD/checkpoints/ofa/ofa_net/ofa_mbv3_d234_e346_k357_w1.0')

    train_queue, valid_queue, n_cls = get_dataset(datasets, batch_size,
                                                  xpaths, workers, img_size, autoaugment, cutout, cutout_length)
    net, net_name, params, flops = set_architecture(n_cls, evaluator,
                                                    drop_path, drop, img_size, n_gpus=1, device=device,
                                                    save_path=save_path, model_str=model_str)

    # net.to(device)

    parameters = filter(lambda p: p.requires_grad, net.parameters())
    optimizer = optim.SGD(parameters, lr=lr, momentum=momentum, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss().to(device)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    # assert epochs == 1
    max_valid_acc = 0
    max_epoch = 0
    for epoch in range(epochs):
        print('epoch {:d} lr {:.4f}'.format(epoch, scheduler.get_lr()[0]))

        train(train_queue, net, criterion, optimizer, grad_clip, device, report_freq)
        _, valid_acc = infer(valid_queue, net, criterion, device, report_freq)
        # persist the latest validation accuracy each epoch
        torch.save(valid_acc, to_save_name)
        print(f'seed {seed:04d} last acc {valid_acc:.4f} max acc {max_valid_acc:.4f}')
        if max_valid_acc < valid_acc:
            max_valid_acc = valid_acc
            max_epoch = epoch
        # parent_path = os.path.abspath(os.path.join(save_path, os.pardir))
|
||||
# with open(parent_path + '/accuracy.txt', 'a+') as f:
|
||||
# f.write(f'{model_str} seed {seed:04d} {valid_acc:.4f}\n')
|
||||
|
||||
return valid_acc, max_valid_acc, params, flops
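
# A minimal usage sketch (not executed; the dataset paths and the model string
# below are hypothetical placeholders, not values shipped with this repository):
#
#     valid_acc, max_acc, params, flops = train_single_model(
#         save_path='./exp/cifar10', workers=4, datasets='cifar10',
#         xpaths='./data/cifar10', splits=0, use_less=False,
#         seed=1, model_str='<OFA subnet string>', device=torch.device('cuda:0'))
#
# The returned tuple pairs the final and best top-1 accuracies with the
# subnet's parameter count and FLOPs as computed in set_architecture.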


################ NAS BENCH 201 #####################
# def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less,
#                        seeds, model_str, arch_config):
#     assert torch.cuda.is_available(), 'CUDA is not available.'
#     torch.backends.cudnn.enabled = True
#     torch.backends.cudnn.deterministic = True
#     torch.set_num_threads(workers)

#     save_dir = Path(save_dir)
#     logger = Logger(str(save_dir), 0, False)

#     if model_str in CellArchitectures:
#         arch = CellArchitectures[model_str]
#         logger.log(
#             'The model string is found in pre-defined architecture dict : {:}'.format(model_str))
#     else:
#         try:
#             arch = CellStructure.str2structure(model_str)
#         except:
#             raise ValueError(
#                 'Invalid model string : {:}. It can not be found or parsed.'.format(model_str))

#     assert arch.check_valid_op(get_search_spaces(
#         'cell', 'nas-bench-201')), '{:} has the invalid op.'.format(arch)
#     # assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
#     logger.log('Start train-evaluate {:}'.format(arch.tostr()))
#     logger.log('arch_config : {:}'.format(arch_config))

#     start_time, seed_time = time.time(), AverageMeter()
#     for _is, seed in enumerate(seeds):
#         logger.log(
#             '\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(
#                 _is, len(seeds), seed))
#         to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
#         if to_save_name.exists():
#             logger.log(
#                 'Find the existing file {:}, directly load!'.format(to_save_name))
#             checkpoint = torch.load(to_save_name)
#         else:
#             logger.log(
#                 'Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
#             checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less,
#                                                seed, arch_config, workers, logger)
#             torch.save(checkpoint, to_save_name)
#         # log information
#         logger.log('{:}'.format(checkpoint['info']))
#         all_dataset_keys = checkpoint['all_dataset_keys']
#         for dataset_key in all_dataset_keys:
#             logger.log('\n{:} dataset : {:} {:}'.format(
#                 '-' * 15, dataset_key, '-' * 15))
#             dataset_info = checkpoint[dataset_key]
#             # logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
#             logger.log('Flops = {:} MB, Params = {:} MB'.format(
#                 dataset_info['flop'], dataset_info['param']))
#             logger.log('config : {:}'.format(dataset_info['config']))
#             logger.log('Training State (finish) = {:}'.format(
#                 dataset_info['finish-train']))
#             last_epoch = dataset_info['total_epoch'] - 1
#             train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
#             valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
#         # measure elapsed time
#         seed_time.update(time.time() - start_time)
#         start_time = time.time()
#         need_time = 'Time Left: {:}'.format(convert_secs2time(
#             seed_time.avg * (len(seeds) - _is - 1), True))
#         logger.log(
#             '\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(
#                 _is, len(seeds), seed, need_time))
#     logger.close()
# ###################


if __name__ == '__main__':
    # train_single_model requires explicit arguments (see its signature above);
    # populate them, e.g. from get_parse(), before calling.
    train_single_model()
@@ -0,0 +1,5 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .generator import Generator
@@ -0,0 +1,204 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import random
from tqdm import tqdm
import numpy as np
import time

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from utils import load_graph_config, decode_ofa_mbv3_to_igraph, decode_igraph_to_ofa_mbv3
from utils import Accumulator, Log
from utils import load_model, save_model
from loader import get_meta_train_loader, get_meta_test_loader

from .generator_model import GeneratorModel


class Generator:
    def __init__(self, args):
        self.args = args
        self.batch_size = args.batch_size
        self.data_path = args.data_path
        self.num_sample = args.num_sample
        self.max_epoch = args.max_epoch
        self.save_epoch = args.save_epoch
        self.model_path = args.model_path
        self.save_path = args.save_path
        self.model_name = args.model_name
        self.test = args.test
        self.device = args.device

        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, args.data_path)
        self.model = GeneratorModel(args, graph_config)
        self.model.to(self.device)

        if self.test:
            self.data_name = args.data_name
            self.num_class = args.num_class
            self.load_epoch = args.load_epoch
            self.num_gen_arch = args.num_gen_arch
            load_model(self.model, self.model_path, self.load_epoch)

        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
            self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
                                               factor=0.1, patience=10, verbose=True)
            self.mtrloader = get_meta_train_loader(
                self.batch_size, self.data_path, self.num_sample)
            self.mtrlog = Log(self.args, open(os.path.join(
                self.save_path, self.model_name, 'meta_train_generator.log'), 'w'))
            self.mtrlog.print_args()
            self.mtrlogger = Accumulator('loss', 'recon_loss', 'kld')
            self.mvallogger = Accumulator('loss', 'recon_loss', 'kld')

    def meta_train(self):
        sttime = time.time()
        for epoch in range(1, self.max_epoch + 1):
            self.mtrlog.ep_sttime = time.time()
            loss = self.meta_train_epoch(epoch)
            self.scheduler.step(loss)
            self.mtrlog.print(self.mtrlogger, epoch, tag='train')

            self.meta_validation()
            self.mtrlog.print(self.mvallogger, epoch, tag='valid')

            if epoch % self.save_epoch == 0:
                save_model(epoch, self.model, self.model_path)

        self.mtrlog.save_time_log()

    def meta_train_epoch(self, epoch):
        self.model.to(self.device)
        self.model.train()

        self.mtrloader.dataset.set_mode('train')
        pbar = tqdm(self.mtrloader)

        for batch in pbar:
            for x, g, acc in batch:
                self.optimizer.zero_grad()
                g = decode_ofa_mbv3_to_igraph(g)[0]
                x_ = x.unsqueeze(0).to(self.device)
                mu, logvar = self.model.set_encode(x_)
                loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])
                loss.backward()
                self.optimizer.step()
                cnt = len(x)
                self.mtrlogger.accum([loss.item() / cnt,
                                      recon.item() / cnt,
                                      kld.item() / cnt])

        return self.mtrlogger.get('loss')

    def meta_validation(self):
        self.model.to(self.device)
        self.model.eval()

        self.mtrloader.dataset.set_mode('valid')
        pbar = tqdm(self.mtrloader)

        for batch in pbar:
            for x, g, acc in batch:
                with torch.no_grad():
                    g = decode_ofa_mbv3_to_igraph(g)[0]
                    x_ = x.unsqueeze(0).to(self.device)
                    mu, logvar = self.model.set_encode(x_)
                    loss, recon, kld = self.model.loss(mu.unsqueeze(0), logvar.unsqueeze(0), [g])

                cnt = len(x)
                self.mvallogger.accum([loss.item() / cnt,
                                       recon.item() / cnt,
                                       kld.item() / cnt])

        return self.mvallogger.get('loss')

    def meta_test(self, predictor):
        if self.data_name == 'all':
            for data_name in ['cifar100', 'cifar10', 'mnist', 'svhn', 'aircraft30', 'aircraft100', 'pets']:
                self.meta_test_per_dataset(data_name, predictor)
        else:
            self.meta_test_per_dataset(self.data_name, predictor)

    def meta_test_per_dataset(self, data_name, predictor):
        # meta_test_path = os.path.join(
        #     self.save_path, 'meta_test', data_name, 'generated_arch')
        meta_test_path = os.path.join(
            self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'generated_arch')
        if not os.path.exists(meta_test_path):
            os.makedirs(meta_test_path)

        meta_test_loader = get_meta_test_loader(
            self.data_path, data_name, self.num_sample, self.num_class)

        print(f'==> generate architectures for {data_name}')
        runs = 10 if data_name in ['cifar10', 'cifar100'] else 1
        # num_gen_arch = 500 if data_name in ['cifar100'] else self.num_gen_arch
        elapsed_time = []
        for run in range(1, runs + 1):
            print(f'==> run {run}/{runs}')
            elapsed_time.append(self.generate_architectures(
                meta_test_loader, data_name,
                meta_test_path, run, self.num_gen_arch, predictor))
            print('==> done\n')

        # time_path = os.path.join(self.save_path, 'meta_test', data_name, 'time.txt')
        time_path = os.path.join(self.save_path, 'meta_test', data_name, f'{self.num_gen_arch}', 'time.txt')
        with open(time_path, 'w') as f_time:
            msg = f'generator elapsed time {np.mean(elapsed_time):.2f}s'
            print(f'==> save time in {time_path}')
            f_time.write(msg + '\n')
            print(msg)

    def generate_architectures(self, meta_test_loader, data_name,
                               meta_test_path, run, num_gen_arch, predictor):
        self.model.eval()
        self.model.to(self.device)

        architecture_string_lst, pred_acc_lst = [], []
        total_cnt, valid_cnt = 0, 0
        flag = False

        start = time.time()
        with torch.no_grad():
            for x in meta_test_loader:
                x_ = x.unsqueeze(0).to(self.device)
                mu, logvar = self.model.set_encode(x_)
                z = self.model.reparameterize(mu.unsqueeze(0), logvar.unsqueeze(0))
                g_recon = self.model.graph_decode(z)
                pred_acc = predictor.forward(x_, g_recon)
                architecture_string = decode_igraph_to_ofa_mbv3(g_recon[0])
                total_cnt += 1
                if architecture_string is not None:
                    if architecture_string not in architecture_string_lst:
                        valid_cnt += 1
                        architecture_string_lst.append(architecture_string)
                        pred_acc_lst.append(pred_acc.item())
                        if valid_cnt == num_gen_arch:
                            flag = True
                            break
                if flag:
                    break
        elapsed = time.time() - start
        pred_acc_lst, architecture_string_lst = zip(*sorted(zip(pred_acc_lst,
                                                                architecture_string_lst),
                                                            key=lambda x: x[0], reverse=True))

        spath = os.path.join(meta_test_path, f"run_{run}.txt")
        with open(spath, 'w') as f:
            print(f'==> save generated architectures in {spath}')
            msg = f'elapsed time: {elapsed:6.2f}s '
            print(msg)
            f.write(msg + '\n')
            for i, architecture_string in enumerate(architecture_string_lst):
                f.write(f"{architecture_string}\n")
        return elapsed
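
# A meta-test sketch (assumes an argparse namespace wired as in main.py; the
# Predictor construction there is the reference, and nothing here runs at import):
#
#     args.test = True
#     g = Generator(args)
#     g.meta_test(predictor)   # predictor: a loaded Predictor instance
#
# Each run writes run_<i>.txt under
# <save_path>/meta_test/<data_name>/<num_gen_arch>/generated_arch,
# with architectures sorted by predicted accuracy (highest first).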


@@ -0,0 +1,396 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import igraph
from set_encoder.setenc_models import SetPool


class GeneratorModel(nn.Module):
    def __init__(self, args, graph_config):
        super(GeneratorModel, self).__init__()
        self.max_n = graph_config['max_n']  # maximum number of vertices
        self.nvt = graph_config['num_vertex_type']  # number of vertex types
        self.START_TYPE = graph_config['START_TYPE']
        self.END_TYPE = graph_config['END_TYPE']
        self.hs = args.hs  # hidden state size of each vertex
        self.nz = args.nz  # size of latent representation z
        self.gs = args.hs  # size of graph state
        self.bidir = True  # whether to use bidirectional encoding
        self.vid = True
        self.device = None
        self.num_sample = args.num_sample

        if self.vid:
            self.vs = self.hs + self.max_n  # vertex state size = hidden state + vid
        else:
            self.vs = self.hs

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(self.nvt, self.hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(self.nvt, self.hs)  # backward encoder GRU
        self.enc_g_mu = nn.Linear(self.gs, self.nz)  # latent mean
        self.enc_g_var = nn.Linear(self.gs, self.nz)  # latent var
        self.fc1 = nn.Linear(self.gs, self.nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, self.nz)  # latent logvar

        # 1. decoding-related
        self.grud = nn.GRUCell(self.nvt, self.hs)  # decoder GRU
        self.fc3 = nn.Linear(self.nz, self.hs)  # from latent z to initial hidden state h0
        self.add_vertex = nn.Sequential(
            nn.Linear(self.hs, self.hs * 2),
            nn.ReLU(),
            nn.Linear(self.hs * 2, self.nvt)
        )  # which type of new vertex to add f(h0, hg)
        self.add_edge = nn.Sequential(
            nn.Linear(self.hs * 2, self.hs * 4),
            nn.ReLU(),
            nn.Linear(self.hs * 4, 1)
        )  # whether to add edge between v_i and v_new, f(hvi, hnew)
        self.decoding_gate = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.decoding_mapper = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )  # disable bias to ensure padded zeros also mapped to zeros

        # 2. gate-related
        self.gate_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.gate_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.mapper_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                nn.Linear(self.hs * 2, self.hs),
            )
            self.hg_unify = nn.Sequential(
                nn.Linear(self.gs * 2, self.gs),
            )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

        # 6. predictor
        self.intra_setpool = SetPool(dim_input=512,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.inter_setpool = SetPool(dim_input=self.nz,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.set_fc = nn.Sequential(
            nn.Linear(512, self.nz),
            nn.ReLU())

    def get_device(self):
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        return torch.zeros(n, length).to(self.get_device())  # get a zero hidden state

    def _get_zero_hidden(self, n=1):
        return self._get_zeros(n, self.hs)  # get a zero hidden state

    def _one_hot(self, idx, length):
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)
                            ).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)
                            ).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator,
                      H=None, reverse=False, gate=None, mapper=None):
        # propagate messages to vertex index v for all graphs in G
        # return the new messages (states) at v
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None:
            idx = [i for i, g in enumerate(G) if g.vcount() > v]
            H = H[idx]
        v_types = [g.vs[v]['type'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [
                [g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            if self.vid:
                vids = [
                    self._one_hot(g.predecessors(v), self.max_n) for g in G]
            if gate is None:
                gate, mapper = self.gate_forward, self.mapper_forward
        if self.vid:
            H_pred = [[torch.cat(
                [x[i], y[i:i + 1]], 1) for i in range(len(x))
            ] for x, y in zip(H_pred, vids)]
        # if h is not provided, use gated sum of v's predecessors' states as the input hidden state
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of predecessors
            if max_n_pred == 0:
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred),
                                                     self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]  # pad all to same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i + 1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
        # perform a series of propagation_to steps starting from v following a topo order
        # assume the original vertex indices are in a topological order
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse)
        return Hv

    def _update_v(self, G, v, H0=None):
        # perform a forward propagation step at v when decoding to update v's state
        # self._propagate_to(G, v, self.grud, H0, reverse=False)
        self._propagate_to(G, v, self.grud, H0,
                           reverse=False, gate=self.decoding_gate,
                           mapper=self.decoding_mapper)
        return

    def _get_vertex_state(self, G, v):
        # get the vertex states at v
        Hv = []
        for g in G:
            if v >= g.vcount():
                hv = self._get_zero_hidden()
            else:
                hv = g.vs[v]['H_forward']
            Hv.append(hv)
        Hv = torch.cat(Hv, 0)
        return Hv

    def _get_graph_state(self, G, decode=False):
        # get the graph states
        # when decoding, use the last generated vertex's state as the graph state
        # when encoding, use the ending vertex state or unify the starting and ending vertex states
        Hg = []
        for g in G:
            hg = g.vs[g.vcount() - 1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)
        return Hg

    def graph_encode(self, G):
        # encode graphs G into latent vectors
        if type(G) != list:
            G = [G]
        self._propagate_from(G, 0, self.grue_forward,
                             H0=self._get_zero_hidden(len(G)), reverse=False)
        if self.bidir:
            self._propagate_from(G, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True)
        Hg = self._get_graph_state(G)
        mu, logvar = self.enc_g_mu(Hg), self.enc_g_var(Hg)
        return mu, logvar

    def set_encode(self, X):
        proto_batch = []
        for x in X:  # X.shape: [32, 400, 512]
            cls_protos = self.intra_setpool(
                x.view(-1, self.num_sample, 512)).squeeze(1)
            proto_batch.append(
                self.inter_setpool(cls_protos.unsqueeze(0)))
        v = torch.stack(proto_batch).squeeze()
        mu, logvar = self.fc1(v), self.fc2(v)
        return mu, logvar

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        # return z ~ N(mu, std)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu
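
    # Reparameterization in closed form: std = exp(0.5 * logvar) and
    # eps ~ N(0, I), so z = mu + eps_scale * eps * std. Gradients flow through
    # mu/logvar while sampling stays stochastic; in eval mode the mean is
    # returned, making graph generation deterministic given the encoding.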

    def _get_edge_score(self, Hvi, H, H0):
        # compute scores for edges from vi based on Hvi, H (current vertex) and H0
        # in most cases, H0 need not be explicitly included since Hvi and H contain its information
        return self.sigmoid(self.add_edge(torch.cat([Hvi, H], -1)))

    def graph_decode(self, z, stochastic=True):
        # decode latent vectors z back to graphs
        # if stochastic=True, stochastically sample each action from the predicted distribution;
        # otherwise, select argmax action deterministically.
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0)
        finished = [False] * len(G)
        for idx in range(1, self.max_n):
            # decide the type of the next added vertex
            if idx == self.max_n - 1:  # force the last node to be end_type
                new_types = [self.END_TYPE] * len(G)
            else:
                Hg = self._get_graph_state(G, decode=True)
                type_scores = self.add_vertex(Hg)
                if stochastic:
                    type_probs = F.softmax(type_scores, 1
                                           ).cpu().detach().numpy()
                    new_types = [np.random.choice(range(self.nvt),
                                                  p=type_probs[i]) for i in range(len(G))]
                else:
                    new_types = torch.argmax(type_scores, 1)
                    new_types = new_types.flatten().tolist()
            for i, g in enumerate(G):
                if not finished[i]:
                    g.add_vertex(type=new_types[i])
            self._update_v(G, idx)

            # decide connections
            for vi in range(idx - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, idx)
                ei_score = self._get_edge_score(Hvi, H, H0)
                if stochastic:
                    random_score = torch.rand_like(ei_score)
                    decisions = random_score < ei_score
                else:
                    decisions = ei_score > 0.5
                for i, g in enumerate(G):
                    if finished[i]:
                        continue
                    if new_types[i] == self.END_TYPE:
                        # if new node is end_type, connect it to all loose-end vertices (out_degree==0)
                        end_vertices = set([
                            v.index for v in g.vs.select(_outdegree_eq=0)
                            if v.index != g.vcount() - 1])
                        for v in end_vertices:
                            g.add_edge(v, g.vcount() - 1)
                        finished[i] = True
                        continue
                    if decisions[i, 0]:
                        g.add_edge(vi, g.vcount() - 1)
                self._update_v(G, idx)

        for g in G:
            del g.vs['H_forward']  # delete hidden states to save GPU memory
        return G

    def loss(self, mu, logvar, G_true, beta=0.005):
        # compute the loss of decoding mu and logvar to true graphs using teacher forcing
        # ensure when computing the loss of step i, steps 0 to i-1 are correct
        z = self.reparameterize(mu, logvar)
        H0 = self.tanh(self.fc3(z))  # or relu activation, similar performance
        G = [igraph.Graph(directed=True) for _ in range(len(z))]
        for g in G:
            g.add_vertex(type=self.START_TYPE)
        self._update_v(G, 0, H0)
        res = 0  # log likelihood
        for v_true in range(1, self.max_n):
            # calculate the likelihood of adding true types of nodes
            # use start type to denote padding vertices since start type only appears for vertex 0
            # and will never be a true type for later vertices, thus it's free to use
            true_types = [g_true.vs[v_true]['type']
                          if v_true < g_true.vcount()
                          else self.START_TYPE for g_true in G_true]
            Hg = self._get_graph_state(G, decode=True)
            type_scores = self.add_vertex(Hg)
            # vertex log likelihood
            vll = self.logsoftmax1(type_scores)[
                np.arange(len(G)), true_types].sum()
            res = res + vll
            for i, g in enumerate(G):
                if true_types[i] != self.START_TYPE:
                    g.add_vertex(type=true_types[i])
            self._update_v(G, v_true)

            # calculate the likelihood of adding true edges
            true_edges = []
            for i, g_true in enumerate(G_true):
                true_edges.append(g_true.get_adjlist(igraph.IN)[v_true]
                                  if v_true < g_true.vcount() else [])
            edge_scores = []
            for vi in range(v_true - 1, -1, -1):
                Hvi = self._get_vertex_state(G, vi)
                H = self._get_vertex_state(G, v_true)
                ei_score = self._get_edge_score(Hvi, H, H0)
                edge_scores.append(ei_score)
                for i, g in enumerate(G):
                    if vi in true_edges[i]:
                        g.add_edge(vi, v_true)
                self._update_v(G, v_true)
            edge_scores = torch.cat(edge_scores[::-1], 1)

            ground_truth = torch.zeros_like(edge_scores)
            idx1 = [i for i, x in enumerate(true_edges)
                    for _ in range(len(x))]
            idx2 = [xx for x in true_edges for xx in x]
            ground_truth[idx1, idx2] = 1.0

            # edges log-likelihood
            ell = - F.binary_cross_entropy(
                edge_scores, ground_truth, reduction='sum')
            res = res + ell

        res = -res  # convert likelihood to loss
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return res + beta * kld, res, kld
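
# The objective above is a beta-weighted negative ELBO: `res` is the
# teacher-forced negative log-likelihood of vertex types and edges, and `kld`
# is the Gaussian closed form
#     KL(N(mu, diag(sigma^2)) || N(0, I)) = -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2),
# evaluated directly from mu and logvar as in the last line of loss().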

@@ -0,0 +1,37 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename
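
# Streaming (stream=True) keeps memory flat: the body is consumed in 1 KiB
# chunks rather than buffered whole, and the progress bar is sized from the
# Content-Length header, which is assumed to be present on these direct
# (?dl=1) Dropbox links.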

file_name = 'ckpt_120.pt'
dir_path = 'results/generator/model'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading {file_name}\n")
    download_file('https://www.dropbox.com/s/zss9yt034hen45h/ckpt_120.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'collected_database.pt'
dir_path = 'data/generator/processed'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading generator {file_name}\n")
    download_file('https://www.dropbox.com/s/zgip4aq0w2pkj49/generator_collected_database.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,43 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


dir_path = 'data/pets'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)

full_name = os.path.join(dir_path, 'test15.pth')
if not os.path.exists(full_name):
    print(f"Downloading {full_name}\n")
    download_file('https://www.dropbox.com/s/kzmrwyyk5iaugv0/test15.pth?dl=1', full_name)
    print("Downloading done.\n")
else:
    print(f"{full_name} has already been downloaded. Did not download twice.\n")

full_name = os.path.join(dir_path, 'train85.pth')
if not os.path.exists(full_name):
    print(f"Downloading {full_name}\n")
    download_file('https://www.dropbox.com/s/w7mikpztkamnw9s/train85.pth?dl=1', full_name)
    print("Downloading done.\n")
else:
    print(f"{full_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,35 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'ckpt_max_corr.pt'
dir_path = 'results/predictor/model'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading {file_name}\n")
    download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,38 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


file_name = 'collected_database.pt'
dir_path = 'data/predictor/processed'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)
file_name = os.path.join(dir_path, file_name)
if not os.path.exists(file_name):
    print(f"Downloading predictor {file_name}\n")
    download_file('https://www.dropbox.com/s/ycm4jaojgswp0zm/ckpt_max_corr.pt?dl=1', file_name)
    print("Downloading done.\n")
else:
    print(f"{file_name} has already been downloaded. Did not download twice.\n")
@@ -0,0 +1,47 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
from tqdm import tqdm
import requests
import zipfile


def download_file(url, filename):
    """
    Helper method handling downloading large files from `url`
    to `filename`. Returns a pointer to `filename`.
    """
    chunkSize = 1024
    r = requests.get(url, stream=True)
    with open(filename, 'wb') as f:
        pbar = tqdm(unit="B", total=int(r.headers['Content-Length']))
        for chunk in r.iter_content(chunk_size=chunkSize):
            if chunk:  # filter out keep-alive new chunks
                pbar.update(len(chunk))
                f.write(chunk)
    return filename


dir_path = 'data'
if not os.path.exists(dir_path):
    os.makedirs(dir_path)


def get_preprocessed_data(file_name, url):
    print(f"Downloading {file_name} datasets\n")
    full_name = os.path.join(dir_path, file_name)
    download_file(url, full_name)
    print("Downloading done.\n")


for file_name, url in [
    ('imgnet32bylabel.pt', 'https://www.dropbox.com/s/7r3hpugql8qgi9d/imgnet32bylabel.pt?dl=1'),
    ('aircraft100bylabel.pt', 'https://www.dropbox.com/s/nn6mlrk1jijg108/aircraft100bylabel.pt?dl=1'),
    ('cifar100bylabel.pt', 'https://www.dropbox.com/s/y0xahxgzj29kffk/cifar100bylabel.pt?dl=1'),
    ('cifar10bylabel.pt', 'https://www.dropbox.com/s/wt1pcwi991xyhwr/cifar10bylabel.pt?dl=1'),
    ('petsbylabel.pt', 'https://www.dropbox.com/s/mxh6qz3grhy7wcn/petsbylabel.pt?dl=1'),
    ('mnistbylabel.pt', 'https://www.dropbox.com/s/86rbuic7a7y34e4/mnistbylabel.pt?dl=1'),
    ('svhnbylabel.pt', 'https://www.dropbox.com/s/yywaelhrsl6egvd/svhnbylabel.pt?dl=1')
]:
    get_preprocessed_data(file_name, url)
@@ -0,0 +1,149 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


def get_meta_train_loader(batch_size, data_path, num_sample, is_pred=False):
    dataset = MetaTrainDatabase(data_path, num_sample, is_pred)
    print(f'==> The number of tasks for meta-training: {len(dataset)}')

    loader = DataLoader(dataset=dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=1,
                        collate_fn=collate_fn)
    return loader


def get_meta_test_loader(data_path, data_name, num_sample, num_class=None, is_pred=False):
    # num_sample is forwarded to MetaTestDataset so the signature matches the
    # call in Generator.meta_test_per_dataset
    dataset = MetaTestDataset(data_path, data_name, num_sample, num_class)
    print(f'==> Meta-Test dataset {data_name}')

    loader = DataLoader(dataset=dataset,
                        batch_size=100,
                        shuffle=False,
                        num_workers=1)
    return loader


class MetaTrainDatabase(Dataset):
    def __init__(self, data_path, num_sample, is_pred=False):
        self.mode = 'train'
        self.acc_norm = True
        self.num_sample = num_sample
        self.x = torch.load(os.path.join(data_path, 'imgnet32bylabel.pt'))

        self.dpath = '{}/{}/processed/'.format(data_path, 'predictor' if is_pred else 'generator')
        self.dname = 'database_219152_14.0K'

        if not os.path.exists(self.dpath + f'{self.dname}_train.pt'):
            if not os.path.exists(self.dpath + f'{self.dname}.pt'):
                raise ValueError(
                    f'{self.dpath}{self.dname}.pt not found; fetch it with the download scripts first')
            database = torch.load(self.dpath + f'{self.dname}.pt')

            rand_idx = torch.randperm(len(database))
            test_len = int(len(database) * 0.15)
            idxlst = {'test': rand_idx[:test_len],
                      'valid': rand_idx[test_len:2 * test_len],
                      'train': rand_idx[2 * test_len:]}

            for m in ['train', 'valid', 'test']:
                acc, cls, net, flops = [], [], [], []
                for idx in tqdm(idxlst[m].tolist(), desc=f'data-{m}'):
                    acc.append(database[idx]['top1'])
                    net.append(database[idx]['net'])
                    cls.append(database[idx]['class'])
                    flops.append(database[idx]['flops'])
                if m == 'train':
                    mean = torch.mean(torch.tensor(acc)).item()
                    std = torch.std(torch.tensor(acc)).item()
                torch.save({'acc': acc,
                            'class': cls,
                            'net': net,
                            'flops': flops,
                            'mean': mean,
                            'std': std},
                           self.dpath + f'{self.dname}_{m}.pt')

        self.set_mode(self.mode)

    def set_mode(self, mode):
        self.mode = mode
        data = torch.load(self.dpath + f'{self.dname}_{self.mode}.pt')
        self.acc = data['acc']
        self.cls = data['class']
        self.net = data['net']
        self.flops = data['flops']
        self.mean = data['mean']
        self.std = data['std']

    def __len__(self):
        return len(self.acc)

    def __getitem__(self, index):
        data = []
        classes = self.cls[index]
        acc = self.acc[index]
        graph = self.net[index]

        for i, cls in enumerate(classes):
            cx = self.x[cls.item()][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        if self.acc_norm:
            acc = ((acc - self.mean) / self.std) / 100.0
        else:
            acc = acc / 100.0
        return x, graph, torch.tensor(acc).view(1, 1)
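
    # Targets are standardized with the *train* split's mean/std and scaled by
    # 1/100, so a raw top-1 equal to `mean` maps to 0.0; get_log in utils.py
    # applies the inverse (y * 100 * std + mean) when printing accuracies.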


class MetaTestDataset(Dataset):
    def __init__(self, data_path, data_name, num_sample, num_class=None):
        self.num_sample = num_sample
        self.data_name = data_name
        if data_name == 'aircraft':
            data_name = 'aircraft100'
        num_class_dict = {
            'cifar100': 100,
            'cifar10': 10,
            'mnist': 10,
            'aircraft100': 30,
            'svhn': 10,
            'pets': 37
        }
        # 'aircraft30': 30,
        # 'aircraft100': 100,

        if num_class is not None:
            self.num_class = num_class
        else:
            self.num_class = num_class_dict[data_name]

        self.x = torch.load(os.path.join(data_path, f'{data_name}bylabel.pt'))

    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        data = []
        classes = list(range(self.num_class))
        for cls in classes:
            cx = self.x[cls][0]
            ridx = torch.randperm(len(cx))
            data.append(cx[ridx[:self.num_sample]])
        x = torch.cat(data)
        return x


def collate_fn(batch):
    # identity collate: each item is (x, graph, acc) with a variable-size
    # igraph object that default_collate cannot stack, so the raw list of
    # tuples is returned and unpacked per-sample in the training loop
    # x = torch.stack([item[0] for item in batch])
    # graph = [item[1] for item in batch]
    # acc = torch.stack([item[2] for item in batch])
    return batch
@@ -0,0 +1,48 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import os
import random
import numpy as np
import torch
from parser import get_parser
from generator import Generator
from predictor import Predictor


def main():
    args = get_parser()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    args.device = torch.device("cuda:0")
    torch.cuda.manual_seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    args.model_path = os.path.join(args.save_path, args.model_name, 'model')
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    if args.model_name == 'generator':
        g = Generator(args)
        if args.test:
            # load the predictor that scores generated graphs with its own
            # checkpoint path and hidden size, then restore the generator's settings
            args.model_path = os.path.join(args.save_path, 'predictor', 'model')
            hs = args.hs
            args.hs = 512
            p = Predictor(args)
            args.model_path = os.path.join(args.save_path, args.model_name, 'model')
            args.hs = hs
            g.meta_test(p)
        else:
            g.meta_train()
    elif args.model_name == 'predictor':
        p = Predictor(args)
        p.meta_train()
    else:
        raise ValueError('You should select generator|predictor|train_arch')


if __name__ == '__main__':
    main()
@@ -0,0 +1,344 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import os
import time
import igraph
import random
import numpy as np
import scipy.stats
import argparse
import torch


def load_graph_config(graph_data_name, nvt, data_path):
    max_n = 20
    graph_config = {}
    graph_config['num_vertex_type'] = nvt + 2  # original types + start/end types
    graph_config['max_n'] = max_n + 2  # maximum number of nodes
    graph_config['START_TYPE'] = 0  # predefined start vertex type
    graph_config['END_TYPE'] = 1  # predefined end vertex type

    return graph_config
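
# With nvt = 27 (one type per depth-kernel-expand combination in type_dict
# below), this gives num_vertex_type = 29 and max_n = 22: the 20 MBV3 layers
# plus the reserved START/END vertices that is_valid_ofa_mbv3 later checks
# via len(g.vs['type']) == 22.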


type_dict = {'2-3-3': 0, '2-3-4': 1, '2-3-6': 2,
             '2-5-3': 3, '2-5-4': 4, '2-5-6': 5,
             '2-7-3': 6, '2-7-4': 7, '2-7-6': 8,
             '3-3-3': 9, '3-3-4': 10, '3-3-6': 11,
             '3-5-3': 12, '3-5-4': 13, '3-5-6': 14,
             '3-7-3': 15, '3-7-4': 16, '3-7-6': 17,
             '4-3-3': 18, '4-3-4': 19, '4-3-6': 20,
             '4-5-3': 21, '4-5-4': 22, '4-5-6': 23,
             '4-7-3': 24, '4-7-4': 25, '4-7-6': 26}

edge_dict = {2: (2, 3, 3), 3: (2, 3, 4), 4: (2, 3, 6),
             5: (2, 5, 3), 6: (2, 5, 4), 7: (2, 5, 6),
             8: (2, 7, 3), 9: (2, 7, 4), 10: (2, 7, 6),
             11: (3, 3, 3), 12: (3, 3, 4), 13: (3, 3, 6),
             14: (3, 5, 3), 15: (3, 5, 4), 16: (3, 5, 6),
             17: (3, 7, 3), 18: (3, 7, 4), 19: (3, 7, 6),
             20: (4, 3, 3), 21: (4, 3, 4), 22: (4, 3, 6),
             23: (4, 5, 3), 24: (4, 5, 4), 25: (4, 5, 6),
             26: (4, 7, 3), 27: (4, 7, 4), 28: (4, 7, 6)}


def decode_ofa_mbv3_to_igraph(matrix):
    # 5 stages, 4 layers for each stage
    # d: 2, 3, 4
    # e: 3, 4, 6
    # k: 3, 5, 7

    # stage_depth to one hot
    num_stage = 5
    num_layer = 4

    node_types = torch.zeros(num_stage * num_layer)

    d = []
    for i in range(num_stage):
        for j in range(num_layer):
            d.append(matrix['d'][i])
    for i, (ks, e, d) in enumerate(zip(
            matrix['ks'], matrix['e'], d)):
        node_types[i] = type_dict[f'{d}-{ks}-{e}']

    n = num_stage * num_layer
    g = igraph.Graph(directed=True)
    g.add_vertices(n + 2)  # + in/out nodes
    g.vs[0]['type'] = 0
    for i, v in enumerate(node_types):
        g.vs[i + 1]['type'] = v + 2  # in node: 0, out node: 1
        g.add_edge(i, i + 1)
    g.vs[n + 1]['type'] = 1
    g.add_edge(n, n + 1)
    return g, n + 2
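
# Encoding sketch with a made-up config: for
#     matrix = {'ks': [3] * 20, 'e': [3] * 20, 'd': [2, 2, 2, 2, 2]}
# every layer maps to type_dict['2-3-3'] = 0 and is stored as vertex type
# 0 + 2 = 2, yielding a 22-vertex chain 0 -> 1 -> ... -> 21 whose endpoints
# carry the reserved START (0) and END (1) types.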


def decode_ofa_mbv3_str_to_igraph(gen_str):
    # 5 stages, 4 layers for each stage
    # d: 2, 3, 4
    # e: 3, 4, 6
    # k: 3, 5, 7

    # stage_depth to one hot
    num_stage = 5
    num_layer = 4

    node_types = torch.zeros(num_stage * num_layer)

    d = []
    split_str = gen_str.split('_')
    for i, s in enumerate(split_str):
        if s == '0-0-0':
            node_types[i] = random.randint(0, 26)
        else:
            node_types[i] = type_dict[s]

    n = num_stage * num_layer
    g = igraph.Graph(directed=True)
    g.add_vertices(n + 2)  # + in/out nodes
    g.vs[0]['type'] = 0
    for i, v in enumerate(node_types):
        g.vs[i + 1]['type'] = v + 2  # in node: 0, out node: 1
        g.add_edge(i, i + 1)
    g.vs[n + 1]['type'] = 1
    g.add_edge(n, n + 1)
    return g


def is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
    # first need to be a valid DAG computation graph
    msg = ''
    res = is_valid_DAG(g, START_TYPE, END_TYPE)
    # in addition, node i must connect to node i+1
    res = res and len(g.vs['type']) == 22
    if not res:
        return res
    msg += '{} ({}) '.format(g.vs['type'][1:-1], len(g.vs['type']))

    # all four layers of a stage must agree on the depth choice, i.e. fall in
    # the same depth block of type_dict (indices 0-8: d=2, 9-17: d=3, 18-26: d=4)
    for i in range(5):
        if ((g.vs['type'][1:-1][i * 4]) - 2) // 9 == 0:
            for j in range(1, 4):
                res = res and ((g.vs['type'][1:-1][i * 4 + j]) - 2) // 9 == 0

        elif ((g.vs['type'][1:-1][i * 4]) - 2) // 9 == 1:
            for j in range(1, 4):
                res = res and ((g.vs['type'][1:-1][i * 4 + j]) - 2) // 9 == 1

        elif ((g.vs['type'][1:-1][i * 4]) - 2) // 9 == 2:
            for j in range(1, 4):
                res = res and ((g.vs['type'][1:-1][i * 4 + j]) - 2) // 9 == 2
        else:
            raise ValueError
    return res


def is_valid_DAG(g, START_TYPE=0, END_TYPE=1):
    res = g.is_dag()
    n_start, n_end = 0, 0
    for v in g.vs:
        if v['type'] == START_TYPE:
            n_start += 1
        elif v['type'] == END_TYPE:
            n_end += 1
        if v.indegree() == 0 and v['type'] != START_TYPE:
            return False
        if v.outdegree() == 0 and v['type'] != END_TYPE:
            return False
    return res and n_start == 1 and n_end == 1


def decode_igraph_to_ofa_mbv3(g):
    if not is_valid_ofa_mbv3(g, START_TYPE=0, END_TYPE=1):
        return None

    graph = {'ks': [], 'e': [], 'd': [4, 4, 4, 4, 4]}
    for i, edge_type in enumerate(g.vs['type'][1:-1]):
        edge_type = int(edge_type)
        d, ks, e = edge_dict[edge_type]
        graph['ks'].append(ks)
        graph['e'].append(e)
        graph['d'][i // 4] = d
    return graph
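
# Round-trip note (a sketch, not a test): for any config accepted by
# is_valid_ofa_mbv3,
#     decode_igraph_to_ofa_mbv3(decode_ofa_mbv3_to_igraph(cfg)[0])
# recovers cfg's 'ks' and 'e' lists exactly and rebuilds 'd' per stage via
# edge_dict; an invalid decoded graph returns None, which is why
# generate_architectures counts only non-None results as valid.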


class Accumulator():
    def __init__(self, *args):
        self.args = args
        self.argdict = {}
        for i, arg in enumerate(args):
            self.argdict[arg] = i
        self.sums = [0] * len(args)
        self.cnt = 0

    def accum(self, val):
        val = [val] if type(val) is not list else val
        val = [v for v in val if v is not None]
        assert (len(val) == len(self.args))
        for i in range(len(val)):
            if torch.is_tensor(val[i]):
                val[i] = val[i].item()
            self.sums[i] += val[i]
        self.cnt += 1

    def clear(self):
        self.sums = [0] * len(self.args)
        self.cnt = 0

    def get(self, arg, avg=True):
        i = self.argdict.get(arg, -1)
        assert (i != -1)
        if avg:
            return self.sums[i] / (self.cnt + 1e-8)
        else:
            return self.sums[i]

    def print_(self, header=None, time=None,
               logfile=None, do_not_print=[], as_int=[],
               avg=True):
        msg = '' if header is None else header + ': '
        if time is not None:
            msg += ('(%.3f secs), ' % time)

        args = [arg for arg in self.args if arg not in do_not_print]
        for arg in args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            if arg in as_int:
                msg += ('%s %d, ' % (arg, int(val)))
            else:
                msg += ('%s %.4f, ' % (arg, val))
        print(msg)

        if logfile is not None:
            logfile.write(msg + '\n')
            logfile.flush()

    def add_scalars(self, summary, header=None, tag_scalar=None,
                    step=None, avg=True, args=None):
        for arg in self.args:
            val = self.sums[self.argdict[arg]]
            if avg:
                val /= (self.cnt + 1e-8)
            tag = f'{header}/{arg}' if header is not None else arg
            if tag_scalar is not None:
                summary.add_scalars(main_tag=tag,
                                    tag_scalar_dict={tag_scalar: val},
                                    global_step=step)
            else:
                summary.add_scalar(tag=tag,
                                   scalar_value=val,
                                   global_step=step)
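
# Minimal usage, mirroring Generator.meta_train_epoch: running sums per named
# metric, averaged on read with a 1e-8 guard against an empty accumulator.
#
#     logger = Accumulator('loss', 'recon_loss', 'kld')
#     logger.accum([1.2, 0.9, 0.3])
#     logger.accum([0.8, 0.7, 0.1])
#     logger.get('loss')   # ~= 1.0  (= (1.2 + 0.8) / 2)
#     logger.clear()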
|
||||
|
||||
|
||||
class Log:
    def __init__(self, args, logf, summary=None):
        self.args = args
        self.logf = logf
        self.summary = summary
        self.stime = time.time()
        self.ep_sttime = None

    def print(self, logger, epoch, tag=None, avg=True):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            print(msg)
            self.logf.write(msg + '\n')
        logger.print_(header=tag, logfile=self.logf, avg=avg)

        if self.summary is not None:
            logger.add_scalars(
                self.summary, header=tag, step=epoch, avg=avg)
        logger.clear()

    def print_args(self):
        argdict = vars(self.args)
        print(argdict)
        for k, v in argdict.items():
            self.logf.write(k + ': ' + str(v) + '\n')
        self.logf.write('\n')

    def set_time(self):
        self.stime = time.time()

    def save_time_log(self):
        ct = time.time() - self.stime
        msg = f'({ct:6.2f}s) meta-training phase done'
        print(msg)
        self.logf.write(msg + '\n')

    def print_pred_log(self, loss, corr, tag, epoch=None, max_corr_dict=None):
        if tag == 'train':
            ct = time.time() - self.ep_sttime
            tt = time.time() - self.stime
            msg = f'[total {tt:6.2f}s (ep {ct:6.2f}s)] epoch {epoch:3d}'
            self.logf.write(msg + '\n')
            print(msg)
            self.logf.flush()
        if max_corr_dict is not None:
            max_corr = max_corr_dict['corr']
            max_loss = max_corr_dict['loss']
            msg = f'{tag}: loss {loss:.6f} ({max_loss:.6f}) '
            msg += f'corr {corr:.4f} ({max_corr:.4f})'
        else:
            msg = f'{tag}: loss {loss:.6f} corr {corr:.4f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()

    def max_corr_log(self, max_corr_dict):
        corr = max_corr_dict['corr']
        loss = max_corr_dict['loss']
        epoch = max_corr_dict['epoch']
        msg = f'[epoch {epoch}] max correlation: {corr:.4f}, loss: {loss:.6f}'
        self.logf.write(msg + '\n')
        print(msg)
        self.logf.flush()


def get_log(epoch, loss, y_pred, y, acc_std, acc_mean, tag='train'):
    msg = f'[{tag}] Ep {epoch} loss {loss.item() / len(y):0.4f} '
    msg += f'pacc {y_pred[0]:0.4f}'
    msg += f'({y_pred[0] * 100.0 * acc_std + acc_mean:0.4f}) '
    msg += f'acc {y[0]:0.4f}({y[0] * 100 * acc_std + acc_mean:0.4f})'
    return msg


def load_model(model, model_path, load_epoch=None, load_max_pt=None):
    if load_max_pt is not None:
        ckpt_path = os.path.join(model_path, load_max_pt)
    else:
        ckpt_path = os.path.join(model_path, f'ckpt_{load_epoch}.pt')

    print(f"==> load checkpoint for MetaD2A predictor: {ckpt_path} ...")
    model.cpu()
    model.load_state_dict(torch.load(ckpt_path))


def save_model(epoch, model, model_path, max_corr=None):
    print("==> save current model...")
    if max_corr is not None:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, 'ckpt_max_corr.pt'))
    else:
        torch.save(model.cpu().state_dict(),
                   os.path.join(model_path, f'ckpt_{epoch}.pt'))


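# Returns (m, h) such that m +/- h is the `confidence`-level Student-t
# interval: h = sem(data) * t_{(1 + confidence) / 2, n - 1}.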
def mean_confidence_interval(data, confidence=0.95):
    a = 1.0 * np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    h = se * scipy.stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h
@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .imagenet import *
@@ -0,0 +1,56 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import numpy as np
import torch

__all__ = ['DataProvider']


class DataProvider:
    SUB_SEED = 937162211  # random seed for sampling subset
    VALID_SEED = 2147483647  # random seed for the validation set

    @staticmethod
    def name():
        """ Return name of the dataset """
        raise NotImplementedError

    @property
    def data_shape(self):
        """ Return shape as python list of one data entry """
        raise NotImplementedError

    @property
    def n_classes(self):
        """ Return `int` of num classes """
        raise NotImplementedError

    @property
    def save_path(self):
        """ local path to save the data """
        raise NotImplementedError

    @property
    def data_url(self):
        """ link to download the data """
        raise NotImplementedError

    @staticmethod
    def random_sample_valid_set(train_size, valid_size):
        assert train_size > valid_size

        g = torch.Generator()
        g.manual_seed(DataProvider.VALID_SEED)  # set random seed before sampling validation set
        rand_indexes = torch.randperm(train_size, generator=g).tolist()

        valid_indexes = rand_indexes[:valid_size]
        train_indexes = rand_indexes[valid_size:]
        return train_indexes, valid_indexes

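    # The fixed VALID_SEED above makes this split deterministic: every run and
    # every worker process reproduces the same train/validation indexes.
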
    @staticmethod
    def labels_to_one_hot(n_classes, labels):
        new_labels = np.zeros((labels.shape[0], n_classes), dtype=np.float32)
        new_labels[range(labels.shape[0]), labels] = np.ones(labels.shape)
        return new_labels
@@ -0,0 +1,225 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import warnings
import os
import math
import numpy as np
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from .base_provider import DataProvider
from ofa_local.utils.my_dataloader import MyRandomResizedCrop, MyDistributedSampler

__all__ = ['ImagenetDataProvider']


class ImagenetDataProvider(DataProvider):
    DEFAULT_PATH = '/dataset/imagenet'

    def __init__(self, save_path=None, train_batch_size=256, test_batch_size=512, valid_size=None, n_worker=32,
                 resize_scale=0.08, distort_color=None, image_size=224,
                 num_replicas=None, rank=None):

        warnings.filterwarnings('ignore')
        self._save_path = save_path

        self.image_size = image_size  # int or list of int
        self.distort_color = 'None' if distort_color is None else distort_color
        self.resize_scale = resize_scale

        self._valid_transform_dict = {}
        if not isinstance(self.image_size, int):
            from ofa_local.utils.my_dataloader import MyDataLoader  # was `ofa.utils`, inconsistent with the imports above
            assert isinstance(self.image_size, list)
            self.image_size.sort()  # e.g., 160 -> 224
            MyRandomResizedCrop.IMAGE_SIZE_LIST = self.image_size.copy()
            MyRandomResizedCrop.ACTIVE_SIZE = max(self.image_size)

            for img_size in self.image_size:
                self._valid_transform_dict[img_size] = self.build_valid_transform(img_size)
            self.active_img_size = max(self.image_size)  # active resolution for test
            valid_transforms = self._valid_transform_dict[self.active_img_size]
            train_loader_class = MyDataLoader  # randomly sample image size for each batch of training image
        else:
            self.active_img_size = self.image_size
            valid_transforms = self.build_valid_transform()
            train_loader_class = torch.utils.data.DataLoader

        train_dataset = self.train_dataset(self.build_train_transform())

        if valid_size is not None:
            if not isinstance(valid_size, int):
                assert isinstance(valid_size, float) and 0 < valid_size < 1
                valid_size = int(len(train_dataset) * valid_size)

            valid_dataset = self.train_dataset(valid_transforms)
            train_indexes, valid_indexes = self.random_sample_valid_set(len(train_dataset), valid_size)

            if num_replicas is not None:
                train_sampler = MyDistributedSampler(train_dataset, num_replicas, rank, True, np.array(train_indexes))
                valid_sampler = MyDistributedSampler(valid_dataset, num_replicas, rank, True, np.array(valid_indexes))
            else:
                train_sampler = torch.utils.data.sampler.SubsetRandomSampler(train_indexes)
                valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(valid_indexes)

            self.train = train_loader_class(
                train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                num_workers=n_worker, pin_memory=True,
            )
            self.valid = torch.utils.data.DataLoader(
                valid_dataset, batch_size=test_batch_size, sampler=valid_sampler,
                num_workers=n_worker, pin_memory=True,
            )
        else:
            if num_replicas is not None:
                train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas, rank)
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, sampler=train_sampler,
                    num_workers=n_worker, pin_memory=True
                )
            else:
                self.train = train_loader_class(
                    train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
                )
            self.valid = None

        test_dataset = self.test_dataset(valid_transforms)
        if num_replicas is not None:
            test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, num_replicas, rank)
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, sampler=test_sampler, num_workers=n_worker, pin_memory=True,
            )
        else:
            self.test = torch.utils.data.DataLoader(
                test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=n_worker, pin_memory=True,
            )

        if self.valid is None:
            self.valid = self.test

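    # When valid_size is None, self.valid aliases the test loader, so
    # "validation" metrics in that configuration are computed on test data.
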
    @staticmethod
    def name():
        return 'imagenet'

    @property
    def data_shape(self):
        return 3, self.active_img_size, self.active_img_size  # C, H, W

    @property
    def n_classes(self):
        return 1000

    @property
    def save_path(self):
        if self._save_path is None:
            self._save_path = self.DEFAULT_PATH
            if not os.path.exists(self._save_path):
                self._save_path = os.path.expanduser('~/dataset/imagenet')
        return self._save_path

    @property
    def data_url(self):
        raise ValueError('unable to download %s' % self.name())

    def train_dataset(self, _transforms):
        return datasets.ImageFolder(self.train_path, _transforms)

    def test_dataset(self, _transforms):
        return datasets.ImageFolder(self.valid_path, _transforms)

    @property
    def train_path(self):
        return os.path.join(self.save_path, 'train')

    @property
    def valid_path(self):
        return os.path.join(self.save_path, 'val')

    @property
    def normalize(self):
        return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def build_train_transform(self, image_size=None, print_log=True):
        if image_size is None:
            image_size = self.image_size
        if print_log:
            print('Color jitter: %s, resize_scale: %s, img_size: %s' %
                  (self.distort_color, self.resize_scale, image_size))

        if isinstance(image_size, list):
            resize_transform_class = MyRandomResizedCrop
            print('Use MyRandomResizedCrop: %s, \t %s' % MyRandomResizedCrop.get_candidate_image_size(),
                  'sync=%s, continuous=%s' % (MyRandomResizedCrop.SYNC_DISTRIBUTED, MyRandomResizedCrop.CONTINUOUS))
        else:
            resize_transform_class = transforms.RandomResizedCrop

        # random_resize_crop -> random_horizontal_flip
        train_transforms = [
            resize_transform_class(image_size, scale=(self.resize_scale, 1.0)),
            transforms.RandomHorizontalFlip(),
        ]

        # color augmentation (optional)
        color_transform = None
        if self.distort_color == 'torch':
            color_transform = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
        elif self.distort_color == 'tf':
            color_transform = transforms.ColorJitter(brightness=32. / 255., saturation=0.5)
        if color_transform is not None:
            train_transforms.append(color_transform)

        train_transforms += [
            transforms.ToTensor(),
            self.normalize,
        ]

        train_transforms = transforms.Compose(train_transforms)
        return train_transforms

    def build_valid_transform(self, image_size=None):
        if image_size is None:
            image_size = self.active_img_size
        return transforms.Compose([
            transforms.Resize(int(math.ceil(image_size / 0.875))),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            self.normalize,
        ])

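    # Resizing to ceil(image_size / 0.875) before the center crop follows the
    # standard ImageNet evaluation protocol (0.875 crop ratio).
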
    def assign_active_img_size(self, new_img_size):
        self.active_img_size = new_img_size
        if self.active_img_size not in self._valid_transform_dict:
            self._valid_transform_dict[self.active_img_size] = self.build_valid_transform()
        # change the transform of the valid and test set
        self.valid.dataset.transform = self._valid_transform_dict[self.active_img_size]
        self.test.dataset.transform = self._valid_transform_dict[self.active_img_size]

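    # build_sub_train_loader caches its sampled batches in self.__dict__ keyed
    # by the active resolution, so repeated BN-statistics resets reuse the
    # same data.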
    def build_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        # used for resetting BN running statistics
        if self.__dict__.get('sub_train_%d' % self.active_img_size, None) is None:
            if num_worker is None:
                num_worker = self.train.num_workers

            n_samples = len(self.train.dataset)
            g = torch.Generator()
            g.manual_seed(DataProvider.SUB_SEED)
            rand_indexes = torch.randperm(n_samples, generator=g).tolist()

            new_train_dataset = self.train_dataset(
                self.build_train_transform(image_size=self.active_img_size, print_log=False))
            chosen_indexes = rand_indexes[:n_images]
            if num_replicas is not None:
                sub_sampler = MyDistributedSampler(new_train_dataset, num_replicas, rank, True, np.array(chosen_indexes))
            else:
                sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
            sub_data_loader = torch.utils.data.DataLoader(
                new_train_dataset, batch_size=batch_size, sampler=sub_sampler,
                num_workers=num_worker, pin_memory=True,
            )
            self.__dict__['sub_train_%d' % self.active_img_size] = []
            for images, labels in sub_data_loader:
                self.__dict__['sub_train_%d' % self.active_img_size].append((images, labels))
        return self.__dict__['sub_train_%d' % self.active_img_size]
@@ -0,0 +1,6 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .dynamic_layers import *
from .dynamic_op import *
@@ -0,0 +1,632 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch
import torch.nn as nn
from collections import OrderedDict

from ofa_local.utils.layers import MBConvLayer, ConvLayer, IdentityLayer, set_layer_from_config
from ofa_local.utils.layers import ResNetBottleneckBlock, LinearLayer
from ofa_local.utils import MyModule, val2list, get_net_device, build_activation, make_divisible, SEModule, MyNetwork
from .dynamic_op import DynamicSeparableConv2d, DynamicConv2d, DynamicBatchNorm2d, DynamicSE, DynamicGroupNorm
from .dynamic_op import DynamicLinear

__all__ = [
    'adjust_bn_according_to_idx', 'copy_bn',
    'DynamicMBConvLayer', 'DynamicConvLayer', 'DynamicLinearLayer', 'DynamicResNetBottleneckBlock'
]


def adjust_bn_according_to_idx(bn, idx):
    bn.weight.data = torch.index_select(bn.weight.data, 0, idx)
    bn.bias.data = torch.index_select(bn.bias.data, 0, idx)
    if type(bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        bn.running_mean.data = torch.index_select(bn.running_mean.data, 0, idx)
        bn.running_var.data = torch.index_select(bn.running_var.data, 0, idx)


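# copy_bn assumes the target layer's channels correspond to the leading
# `feature_dim` channels of the source normalization layer.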
def copy_bn(target_bn, src_bn):
    feature_dim = target_bn.num_channels if isinstance(target_bn, nn.GroupNorm) else target_bn.num_features

    target_bn.weight.data.copy_(src_bn.weight.data[:feature_dim])
    target_bn.bias.data.copy_(src_bn.bias.data[:feature_dim])
    if type(src_bn) in [nn.BatchNorm1d, nn.BatchNorm2d]:
        target_bn.running_mean.data.copy_(src_bn.running_mean.data[:feature_dim])
        target_bn.running_var.data.copy_(src_bn.running_var.data[:feature_dim])


class DynamicLinearLayer(MyModule):

    def __init__(self, in_features_list, out_features, bias=True, dropout_rate=0):
        super(DynamicLinearLayer, self).__init__()

        self.in_features_list = in_features_list
        self.out_features = out_features
        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None
        self.linear = DynamicLinear(
            max_in_features=max(self.in_features_list), max_out_features=self.out_features, bias=self.bias
        )

    def forward(self, x):
        if self.dropout is not None:
            x = self.dropout(x)
        return self.linear(x)

    @property
    def module_str(self):
        return 'DyLinear(%d, %d)' % (max(self.in_features_list), self.out_features)

    @property
    def config(self):
        return {
            'name': DynamicLinear.__name__,
            'in_features_list': self.in_features_list,
            'out_features': self.out_features,
            'bias': self.bias,
            'dropout_rate': self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicLinearLayer(**config)

    def get_active_subnet(self, in_features, preserve_weight=True):
        sub_layer = LinearLayer(in_features, self.out_features, self.bias, dropout_rate=self.dropout_rate)
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        sub_layer.linear.weight.data.copy_(
            self.linear.get_active_weight(self.out_features, in_features).data
        )
        if self.bias:
            sub_layer.linear.bias.data.copy_(
                self.linear.get_active_bias(self.out_features).data
            )
        return sub_layer

    def get_active_subnet_config(self, in_features):
        return {
            'name': LinearLayer.__name__,
            'in_features': in_features,
            'out_features': self.out_features,
            'bias': self.bias,
            'dropout_rate': self.dropout_rate,
        }


class DynamicMBConvLayer(MyModule):

    def __init__(self, in_channel_list, out_channel_list,
                 kernel_size_list=3, expand_ratio_list=6, stride=1, act_func='relu6', use_se=False):
        super(DynamicMBConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list

        self.kernel_size_list = val2list(kernel_size_list)
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.stride = stride
        self.act_func = act_func
        self.use_se = use_se

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.in_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)
        if max(self.expand_ratio_list) == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
                ('bn', DynamicBatchNorm2d(max_middle_channel)),
                ('act', build_activation(self.act_func)),
            ]))

        self.depth_conv = nn.Sequential(OrderedDict([
            ('conv', DynamicSeparableConv2d(max_middle_channel, self.kernel_size_list, self.stride)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func))
        ]))
        if self.use_se:
            self.depth_conv.add_module('se', DynamicSE(max_middle_channel))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
            ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
        ]))

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        in_channel = x.size(1)

        if self.inverted_bottleneck is not None:
            self.inverted_bottleneck.conv.active_out_channel = \
                make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)

        self.depth_conv.conv.active_kernel_size = self.active_kernel_size
        self.point_linear.conv.active_out_channel = self.active_out_channel

        if self.inverted_bottleneck is not None:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.use_se:
            return 'SE(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)
        else:
            return '(O%d, E%.1f, K%d)' % (self.active_out_channel, self.active_expand_ratio, self.active_kernel_size)

    @property
    def config(self):
        return {
            'name': DynamicMBConvLayer.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'kernel_size_list': self.kernel_size_list,
            'expand_ratio_list': self.expand_ratio_list,
            'stride': self.stride,
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicMBConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    def active_middle_channel(self, in_channel):
        return make_divisible(round(in_channel * self.active_expand_ratio), MyNetwork.CHANNEL_DIVISIBLE)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        middle_channel = self.active_middle_channel(in_channel)
        # copy weight from current layer
        if sub_layer.inverted_bottleneck is not None:
            sub_layer.inverted_bottleneck.conv.weight.data.copy_(
                self.inverted_bottleneck.conv.get_active_filter(middle_channel, in_channel).data,
            )
            copy_bn(sub_layer.inverted_bottleneck.bn, self.inverted_bottleneck.bn.bn)

        sub_layer.depth_conv.conv.weight.data.copy_(
            self.depth_conv.conv.get_active_filter(middle_channel, self.active_kernel_size).data
        )
        copy_bn(sub_layer.depth_conv.bn, self.depth_conv.bn.bn)

        if self.use_se:
            se_mid = make_divisible(middle_channel // SEModule.REDUCTION, divisor=MyNetwork.CHANNEL_DIVISIBLE)
            sub_layer.depth_conv.se.fc.reduce.weight.data.copy_(
                self.depth_conv.se.get_active_reduce_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.reduce.bias.data.copy_(
                self.depth_conv.se.get_active_reduce_bias(se_mid).data
            )

            sub_layer.depth_conv.se.fc.expand.weight.data.copy_(
                self.depth_conv.se.get_active_expand_weight(se_mid, middle_channel).data
            )
            sub_layer.depth_conv.se.fc.expand.bias.data.copy_(
                self.depth_conv.se.get_active_expand_bias(middle_channel).data
            )

        sub_layer.point_linear.conv.weight.data.copy_(
            self.point_linear.conv.get_active_filter(self.active_out_channel, middle_channel).data
        )
        copy_bn(sub_layer.point_linear.bn, self.point_linear.bn.bn)

        return sub_layer

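    # Usage sketch: set active_kernel_size / active_expand_ratio /
    # active_out_channel on this layer, then call get_active_subnet(in_channel)
    # to materialize a static MBConvLayer with the matching weight slices.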
    def get_active_subnet_config(self, in_channel):
        return {
            'name': MBConvLayer.__name__,
            'in_channels': in_channel,
            'out_channels': self.active_out_channel,
            'kernel_size': self.active_kernel_size,
            'stride': self.stride,
            'expand_ratio': self.active_expand_ratio,
            'mid_channels': self.active_middle_channel(in_channel),
            'act_func': self.act_func,
            'use_se': self.use_se,
        }

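    # Channel-importance sort: ranks middle channels by the L1 norm of the
    # pointwise conv that consumes them, then permutes the depthwise conv, BN,
    # SE, and inverted-bottleneck weights so important channels come first.
    # With expand_ratio_stage > 0, the +base offsets keep channels grouped by
    # the width buckets of the sorted expand-ratio list.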
    def re_organize_middle_weights(self, expand_ratio_stage=0):
        importance = torch.sum(torch.abs(self.point_linear.conv.conv.weight.data), dim=(0, 2, 3))
        if isinstance(self.depth_conv.bn, DynamicGroupNorm):
            channel_per_group = self.depth_conv.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(round(max(self.in_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
                for expand in sorted_expand_list
            ]

            right = len(importance)
            base = - len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.point_linear.conv.conv.weight.data = torch.index_select(
            self.point_linear.conv.conv.weight.data, 1, sorted_idx
        )

        adjust_bn_according_to_idx(self.depth_conv.bn.bn, sorted_idx)
        self.depth_conv.conv.conv.weight.data = torch.index_select(
            self.depth_conv.conv.conv.weight.data, 0, sorted_idx
        )

        if self.use_se:
            # se expand: output dim 0 reorganize
            se_expand = self.depth_conv.se.fc.expand
            se_expand.weight.data = torch.index_select(se_expand.weight.data, 0, sorted_idx)
            se_expand.bias.data = torch.index_select(se_expand.bias.data, 0, sorted_idx)
            # se reduce: input dim 1 reorganize
            se_reduce = self.depth_conv.se.fc.reduce
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 1, sorted_idx)
            # middle weight reorganize
            se_importance = torch.sum(torch.abs(se_expand.weight.data), dim=(0, 2, 3))
            se_importance, se_idx = torch.sort(se_importance, dim=0, descending=True)

            se_expand.weight.data = torch.index_select(se_expand.weight.data, 1, se_idx)
            se_reduce.weight.data = torch.index_select(se_reduce.weight.data, 0, se_idx)
            se_reduce.bias.data = torch.index_select(se_reduce.bias.data, 0, se_idx)

        if self.inverted_bottleneck is not None:
            adjust_bn_according_to_idx(self.inverted_bottleneck.bn.bn, sorted_idx)
            self.inverted_bottleneck.conv.conv.weight.data = torch.index_select(
                self.inverted_bottleneck.conv.conv.weight.data, 0, sorted_idx
            )
            return None
        else:
            return sorted_idx


class DynamicConvLayer(MyModule):

    def __init__(self, in_channel_list, out_channel_list, kernel_size=3, stride=1, dilation=1,
                 use_bn=True, act_func='relu6'):
        super(DynamicConvLayer, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.use_bn = use_bn
        self.act_func = act_func

        self.conv = DynamicConv2d(
            max_in_channels=max(self.in_channel_list), max_out_channels=max(self.out_channel_list),
            kernel_size=self.kernel_size, stride=self.stride, dilation=self.dilation,
        )
        if self.use_bn:
            self.bn = DynamicBatchNorm2d(max(self.out_channel_list))
        self.act = build_activation(self.act_func)

        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        self.conv.active_out_channel = self.active_out_channel

        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        x = self.act(x)
        return x

    @property
    def module_str(self):
        return 'DyConv(O%d, K%d, S%d)' % (self.active_out_channel, self.kernel_size, self.stride)

    @property
    def config(self):
        return {
            'name': DynamicConvLayer.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'dilation': self.dilation,
            'use_bn': self.use_bn,
            'act_func': self.act_func,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicConvLayer(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))

        if not preserve_weight:
            return sub_layer

        sub_layer.conv.weight.data.copy_(self.conv.get_active_filter(self.active_out_channel, in_channel).data)
        if self.use_bn:
            copy_bn(sub_layer.bn, self.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            'name': ConvLayer.__name__,
            'in_channels': in_channel,
            'out_channels': self.active_out_channel,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'dilation': self.dilation,
            'use_bn': self.use_bn,
            'act_func': self.act_func,
        }


class DynamicResNetBottleneckBlock(MyModule):

    def __init__(self, in_channel_list, out_channel_list, expand_ratio_list=0.25,
                 kernel_size=3, stride=1, act_func='relu', downsample_mode='avgpool_conv'):
        super(DynamicResNetBottleneckBlock, self).__init__()

        self.in_channel_list = in_channel_list
        self.out_channel_list = out_channel_list
        self.expand_ratio_list = val2list(expand_ratio_list)

        self.kernel_size = kernel_size
        self.stride = stride
        self.act_func = act_func
        self.downsample_mode = downsample_mode

        # build modules
        max_middle_channel = make_divisible(
            round(max(self.out_channel_list) * max(self.expand_ratio_list)), MyNetwork.CHANNEL_DIVISIBLE)

        self.conv1 = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max(self.in_channel_list), max_middle_channel)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func, inplace=True)),
        ]))

        self.conv2 = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max_middle_channel, max_middle_channel, kernel_size, stride)),
            ('bn', DynamicBatchNorm2d(max_middle_channel)),
            ('act', build_activation(self.act_func, inplace=True))
        ]))

        self.conv3 = nn.Sequential(OrderedDict([
            ('conv', DynamicConv2d(max_middle_channel, max(self.out_channel_list))),
            ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
        ]))

        if self.stride == 1 and self.in_channel_list == self.out_channel_list:
            self.downsample = IdentityLayer(max(self.in_channel_list), max(self.out_channel_list))
        elif self.downsample_mode == 'conv':
            self.downsample = nn.Sequential(OrderedDict([
                ('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list), stride=stride)),
                ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
            ]))
        elif self.downsample_mode == 'avgpool_conv':
            self.downsample = nn.Sequential(OrderedDict([
                ('avg_pool', nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)),
                ('conv', DynamicConv2d(max(self.in_channel_list), max(self.out_channel_list))),
                ('bn', DynamicBatchNorm2d(max(self.out_channel_list))),
            ]))
        else:
            raise NotImplementedError

        self.final_act = build_activation(self.act_func, inplace=True)

        self.active_expand_ratio = max(self.expand_ratio_list)
        self.active_out_channel = max(self.out_channel_list)

    def forward(self, x):
        feature_dim = self.active_middle_channels

        self.conv1.conv.active_out_channel = feature_dim
        self.conv2.conv.active_out_channel = feature_dim
        self.conv3.conv.active_out_channel = self.active_out_channel
        if not isinstance(self.downsample, IdentityLayer):
            self.downsample.conv.active_out_channel = self.active_out_channel

        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x + residual
        x = self.final_act(x)
        return x

    @property
    def module_str(self):
        return '(%s, %s)' % (
            '%dx%d_BottleneckConv_in->%d->%d_S%d' % (
                self.kernel_size, self.kernel_size, self.active_middle_channels, self.active_out_channel, self.stride
            ),
            'Identity' if isinstance(self.downsample, IdentityLayer) else self.downsample_mode,
        )

    @property
    def config(self):
        return {
            'name': DynamicResNetBottleneckBlock.__name__,
            'in_channel_list': self.in_channel_list,
            'out_channel_list': self.out_channel_list,
            'expand_ratio_list': self.expand_ratio_list,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'act_func': self.act_func,
            'downsample_mode': self.downsample_mode,
        }

    @staticmethod
    def build_from_config(config):
        return DynamicResNetBottleneckBlock(**config)

    ############################################################################################

    @property
    def in_channels(self):
        return max(self.in_channel_list)

    @property
    def out_channels(self):
        return max(self.out_channel_list)

    @property
    def active_middle_channels(self):
        feature_dim = round(self.active_out_channel * self.active_expand_ratio)
        feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
        return feature_dim

    ############################################################################################

    def get_active_subnet(self, in_channel, preserve_weight=True):
        # build the new layer
        sub_layer = set_layer_from_config(self.get_active_subnet_config(in_channel))
        sub_layer = sub_layer.to(get_net_device(self))
        if not preserve_weight:
            return sub_layer

        # copy weight from current layer
        sub_layer.conv1.conv.weight.data.copy_(
            self.conv1.conv.get_active_filter(self.active_middle_channels, in_channel).data)
        copy_bn(sub_layer.conv1.bn, self.conv1.bn.bn)

        sub_layer.conv2.conv.weight.data.copy_(
            self.conv2.conv.get_active_filter(self.active_middle_channels, self.active_middle_channels).data)
        copy_bn(sub_layer.conv2.bn, self.conv2.bn.bn)

        sub_layer.conv3.conv.weight.data.copy_(
            self.conv3.conv.get_active_filter(self.active_out_channel, self.active_middle_channels).data)
        copy_bn(sub_layer.conv3.bn, self.conv3.bn.bn)

        if not isinstance(self.downsample, IdentityLayer):
            sub_layer.downsample.conv.weight.data.copy_(
                self.downsample.conv.get_active_filter(self.active_out_channel, in_channel).data)
            copy_bn(sub_layer.downsample.bn, self.downsample.bn.bn)

        return sub_layer

    def get_active_subnet_config(self, in_channel):
        return {
            'name': ResNetBottleneckBlock.__name__,
            'in_channels': in_channel,
            'out_channels': self.active_out_channel,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'expand_ratio': self.active_expand_ratio,
            'mid_channels': self.active_middle_channels,
            'act_func': self.act_func,
            'groups': 1,
            'downsample_mode': self.downsample_mode,
        }

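    # Same channel-importance sort as DynamicMBConvLayer, applied twice:
    # first to the conv3 inputs (permuting conv2 outputs), then to the conv2
    # inputs (permuting conv1 outputs).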
    def re_organize_middle_weights(self, expand_ratio_stage=0):
        # conv3 -> conv2
        importance = torch.sum(torch.abs(self.conv3.conv.conv.weight.data), dim=(0, 2, 3))
        if isinstance(self.conv2.bn, DynamicGroupNorm):
            channel_per_group = self.conv2.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = - len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left

        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)
        self.conv3.conv.conv.weight.data = torch.index_select(self.conv3.conv.conv.weight.data, 1, sorted_idx)
        adjust_bn_according_to_idx(self.conv2.bn.bn, sorted_idx)
        self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 0, sorted_idx)

        # conv2 -> conv1
        importance = torch.sum(torch.abs(self.conv2.conv.conv.weight.data), dim=(0, 2, 3))
        if isinstance(self.conv1.bn, DynamicGroupNorm):
            channel_per_group = self.conv1.bn.channel_per_group
            importance_chunks = torch.split(importance, channel_per_group)
            for chunk in importance_chunks:
                chunk.data.fill_(torch.mean(chunk))
            importance = torch.cat(importance_chunks, dim=0)
        if expand_ratio_stage > 0:
            sorted_expand_list = copy.deepcopy(self.expand_ratio_list)
            sorted_expand_list.sort(reverse=True)
            target_width_list = [
                make_divisible(round(max(self.out_channel_list) * expand), MyNetwork.CHANNEL_DIVISIBLE)
                for expand in sorted_expand_list
            ]
            right = len(importance)
            base = - len(target_width_list) * 1e5
            for i in range(expand_ratio_stage + 1):
                left = target_width_list[i]
                importance[left:right] += base
                base += 1e5
                right = left
        sorted_importance, sorted_idx = torch.sort(importance, dim=0, descending=True)

        self.conv2.conv.conv.weight.data = torch.index_select(self.conv2.conv.conv.weight.data, 1, sorted_idx)
        adjust_bn_according_to_idx(self.conv1.bn.bn, sorted_idx)
        self.conv1.conv.conv.weight.data = torch.index_select(self.conv1.conv.conv.weight.data, 0, sorted_idx)

        return None
@@ -0,0 +1,314 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn.parameter import Parameter

from ofa_local.utils import get_same_padding, sub_filter_start_end, make_divisible, SEModule, MyNetwork, MyConv2d

__all__ = ['DynamicSeparableConv2d', 'DynamicConv2d', 'DynamicGroupConv2d',
           'DynamicBatchNorm2d', 'DynamicGroupNorm', 'DynamicSE', 'DynamicLinear']


class DynamicSeparableConv2d(nn.Module):
    KERNEL_TRANSFORM_MODE = 1  # None or 1

    def __init__(self, max_in_channels, kernel_size_list, stride=1, dilation=1):
        super(DynamicSeparableConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.kernel_size_list = kernel_size_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_in_channels, max(self.kernel_size_list), self.stride,
            groups=self.max_in_channels, bias=False,
        )

        self._ks_set = list(set(self.kernel_size_list))
        self._ks_set.sort()  # e.g., [3, 5, 7]
        if self.KERNEL_TRANSFORM_MODE is not None:
            # register scaling parameters
            # 7to5_matrix, 5to3_matrix
            scale_params = {}
            for i in range(len(self._ks_set) - 1):
                ks_small = self._ks_set[i]
                ks_larger = self._ks_set[i + 1]
                param_name = '%dto%d' % (ks_larger, ks_small)
                # noinspection PyArgumentList
                scale_params['%s_matrix' % param_name] = Parameter(torch.eye(ks_small ** 2))
            for name, param in scale_params.items():
                self.register_parameter(name, param)

        self.active_kernel_size = max(self.kernel_size_list)

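    # Elastic kernels: a k x k filter is the center crop of the largest
    # filter; with KERNEL_TRANSFORM_MODE enabled, each crop is additionally
    # passed through a learned (k^2 x k^2) linear transform (7->5, then 5->3).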
    def get_active_filter(self, in_channel, kernel_size):
        out_channel = in_channel
        max_kernel_size = max(self.kernel_size_list)

        start, end = sub_filter_start_end(max_kernel_size, kernel_size)
        filters = self.conv.weight[:out_channel, :in_channel, start:end, start:end]
        if self.KERNEL_TRANSFORM_MODE is not None and kernel_size < max_kernel_size:
            start_filter = self.conv.weight[:out_channel, :in_channel, :, :]  # start with max kernel
            for i in range(len(self._ks_set) - 1, 0, -1):
                src_ks = self._ks_set[i]
                if src_ks <= kernel_size:
                    break
                target_ks = self._ks_set[i - 1]
                start, end = sub_filter_start_end(src_ks, target_ks)
                _input_filter = start_filter[:, :, start:end, start:end]
                _input_filter = _input_filter.contiguous()
                _input_filter = _input_filter.view(_input_filter.size(0), _input_filter.size(1), -1)
                _input_filter = _input_filter.view(-1, _input_filter.size(2))
                _input_filter = F.linear(
                    _input_filter, self.__getattr__('%dto%d_matrix' % (src_ks, target_ks)),
                )
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks ** 2)
                _input_filter = _input_filter.view(filters.size(0), filters.size(1), target_ks, target_ks)
                start_filter = _input_filter
            filters = start_filter
        return filters

    def forward(self, x, kernel_size=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        in_channel = x.size(1)

        filters = self.get_active_filter(in_channel, kernel_size).contiguous()

        padding = get_same_padding(kernel_size)
        filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
        y = F.conv2d(
            x, filters, None, self.stride, padding, self.dilation, in_channel
        )
        return y


class DynamicConv2d(nn.Module):

    def __init__(self, max_in_channels, max_out_channels, kernel_size=1, stride=1, dilation=1):
        super(DynamicConv2d, self).__init__()

        self.max_in_channels = max_in_channels
        self.max_out_channels = max_out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.max_in_channels, self.max_out_channels, self.kernel_size, stride=self.stride, bias=False,
        )

        self.active_out_channel = self.max_out_channels

    def get_active_filter(self, out_channel, in_channel):
        return self.conv.weight[:out_channel, :in_channel, :, :]

    def forward(self, x, out_channel=None):
        if out_channel is None:
            out_channel = self.active_out_channel
        in_channel = x.size(1)
        filters = self.get_active_filter(out_channel, in_channel).contiguous()

        padding = get_same_padding(self.kernel_size)
        filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
        y = F.conv2d(x, filters, None, self.stride, padding, self.dilation, 1)
        return y


class DynamicGroupConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size_list, groups_list, stride=1, dilation=1):
        super(DynamicGroupConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size_list = kernel_size_list
        self.groups_list = groups_list
        self.stride = stride
        self.dilation = dilation

        self.conv = nn.Conv2d(
            self.in_channels, self.out_channels, max(self.kernel_size_list), self.stride,
            groups=min(self.groups_list), bias=False,
        )

        self.active_kernel_size = max(self.kernel_size_list)
        self.active_groups = min(self.groups_list)

    def get_active_filter(self, kernel_size, groups):
        start, end = sub_filter_start_end(max(self.kernel_size_list), kernel_size)
        filters = self.conv.weight[:, :, start:end, start:end]

        sub_filters = torch.chunk(filters, groups, dim=0)
        sub_in_channels = self.in_channels // groups
        sub_ratio = filters.size(1) // sub_in_channels

        filter_crops = []
        for i, sub_filter in enumerate(sub_filters):
            part_id = i % sub_ratio
            start = part_id * sub_in_channels
            filter_crops.append(sub_filter[:, start:start + sub_in_channels, :, :])
        filters = torch.cat(filter_crops, dim=0)
        return filters

    def forward(self, x, kernel_size=None, groups=None):
        if kernel_size is None:
            kernel_size = self.active_kernel_size
        if groups is None:
            groups = self.active_groups

        filters = self.get_active_filter(kernel_size, groups).contiguous()
        padding = get_same_padding(kernel_size)
        filters = self.conv.weight_standardization(filters) if isinstance(self.conv, MyConv2d) else filters
        y = F.conv2d(
            x, filters, None, self.stride, padding, self.dilation, groups,
        )
        return y


class DynamicBatchNorm2d(nn.Module):
    SET_RUNNING_STATISTICS = False

    def __init__(self, max_feature_dim):
        super(DynamicBatchNorm2d, self).__init__()

        self.max_feature_dim = max_feature_dim
        self.bn = nn.BatchNorm2d(self.max_feature_dim)

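    # Forward through a width-sliced BN: when the input has fewer channels
    # than max_feature_dim, slice the running stats and affine parameters and
    # call F.batch_norm directly, mirroring nn.BatchNorm2d's momentum logic.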
    @staticmethod
    def bn_forward(x, bn: nn.BatchNorm2d, feature_dim):
        if bn.num_features == feature_dim or DynamicBatchNorm2d.SET_RUNNING_STATISTICS:
            return bn(x)
        else:
            exponential_average_factor = 0.0

            if bn.training and bn.track_running_stats:
                if bn.num_batches_tracked is not None:
                    bn.num_batches_tracked += 1
                    if bn.momentum is None:  # use cumulative moving average
                        exponential_average_factor = 1.0 / float(bn.num_batches_tracked)
                    else:  # use exponential moving average
                        exponential_average_factor = bn.momentum
            return F.batch_norm(
                x, bn.running_mean[:feature_dim], bn.running_var[:feature_dim], bn.weight[:feature_dim],
                bn.bias[:feature_dim], bn.training or not bn.track_running_stats,
                exponential_average_factor, bn.eps,
            )

    def forward(self, x):
        feature_dim = x.size(1)
        y = self.bn_forward(x, self.bn, feature_dim)
        return y


class DynamicGroupNorm(nn.GroupNorm):

    def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, channel_per_group=None):
        super(DynamicGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
        self.channel_per_group = channel_per_group

    def forward(self, x):
        n_channels = x.size(1)
        n_groups = n_channels // self.channel_per_group
        return F.group_norm(x, n_groups, self.weight[:n_channels], self.bias[:n_channels], self.eps)

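    # `bn` returns self so DynamicGroupNorm can stand in wherever callers
    # expect a `.bn` attribute (e.g. the copy_bn / adjust_bn helpers).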
    @property
    def bn(self):
        return self


class DynamicSE(SEModule):

    def __init__(self, max_channel):
        super(DynamicSE, self).__init__(max_channel)

    def get_active_reduce_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.reduce.weight[:num_mid, :in_channel, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(self.fc.reduce.weight[:num_mid, :, :, :], groups, dim=1)
            return torch.cat([
                sub_filter[:, :sub_in_channels, :, :] for sub_filter in sub_filters
            ], dim=1)

    def get_active_reduce_bias(self, num_mid):
        return self.fc.reduce.bias[:num_mid] if self.fc.reduce.bias is not None else None

    def get_active_expand_weight(self, num_mid, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.weight[:in_channel, :num_mid, :, :]
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_filters = torch.chunk(self.fc.expand.weight[:, :num_mid, :, :], groups, dim=0)
            return torch.cat([
                sub_filter[:sub_in_channels, :, :, :] for sub_filter in sub_filters
            ], dim=0)

    def get_active_expand_bias(self, in_channel, groups=None):
        if groups is None or groups == 1:
            return self.fc.expand.bias[:in_channel] if self.fc.expand.bias is not None else None
        else:
            assert in_channel % groups == 0
            sub_in_channels = in_channel // groups
            sub_bias_list = torch.chunk(self.fc.expand.bias, groups, dim=0)
            return torch.cat([
                sub_bias[:sub_in_channels] for sub_bias in sub_bias_list
            ], dim=0)

    def forward(self, x, groups=None):
        in_channel = x.size(1)
        num_mid = make_divisible(in_channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)

        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        # reduce
        reduce_filter = self.get_active_reduce_weight(num_mid, in_channel, groups=groups).contiguous()
        reduce_bias = self.get_active_reduce_bias(num_mid)
        y = F.conv2d(y, reduce_filter, reduce_bias, 1, 0, 1, 1)
        # relu
        y = self.fc.relu(y)
        # expand
        expand_filter = self.get_active_expand_weight(num_mid, in_channel, groups=groups).contiguous()
        expand_bias = self.get_active_expand_bias(in_channel, groups=groups)
        y = F.conv2d(y, expand_filter, expand_bias, 1, 0, 1, 1)
        # hard sigmoid
        y = self.fc.h_sigmoid(y)

        return x * y


class DynamicLinear(nn.Module):

    def __init__(self, max_in_features, max_out_features, bias=True):
        super(DynamicLinear, self).__init__()

        self.max_in_features = max_in_features
        self.max_out_features = max_out_features
        self.bias = bias

        self.linear = nn.Linear(self.max_in_features, self.max_out_features, self.bias)

        self.active_out_features = self.max_out_features

    def get_active_weight(self, out_features, in_features):
        return self.linear.weight[:out_features, :in_features]

    def get_active_bias(self, out_features):
        return self.linear.bias[:out_features] if self.bias else None

    def forward(self, x, out_features=None):
        if out_features is None:
            out_features = self.active_out_features

        in_features = x.size(1)
        weight = self.get_active_weight(out_features, in_features).contiguous()
        bias = self.get_active_bias(out_features)
        y = F.linear(x, weight, bias)
        return y
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .ofa_proxyless import OFAProxylessNASNets
from .ofa_mbv3 import OFAMobileNetV3
from .ofa_resnets import OFAResNets
@@ -0,0 +1,336 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import MobileNetV3
from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.utils.layers import set_layer_from_config
import gin

__all__ = ['OFAMobileNetV3']

@gin.configurable
class OFAMobileNetV3(MobileNetV3):

    def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
                 ks_list=3, expand_ratio_list=6, depth_list=4, dropblock=False, block_size=0):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

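        # Note: the fixed list below overrides the `base_stage_width`
        # constructor argument.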
        base_stage_width = [16, 16, 24, 40, 80, 112, 160, 960, 1280]

        final_expand_width = make_divisible(base_stage_width[-2] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        stride_stages = [1, 2, 2, 2, 1, 2]
        act_stages = ['relu', 'relu', 'relu', 'h_swish', 'h_swish', 'h_swish']
        se_stages = [False, False, True, False, True, True]
        n_block_list = [1] + [max(self.depth_list)] * 5
        width_list = []
        for base_width in base_stage_width[:-2]:
            width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            width_list.append(width)

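        # The stride / activation / SE stage settings above follow the
        # MobileNetV3-Large layout; only depth, kernel size, and expand ratio
        # are elastic.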
        input_channel, first_block_dim = width_list[0], width_list[1]
        # first conv layer
        first_conv = ConvLayer(3, input_channel, kernel_size=3, stride=2, act_func='h_swish')
        first_block_conv = MBConvLayer(
            in_channels=input_channel, out_channels=first_block_dim, kernel_size=3, stride=stride_stages[0],
            expand_ratio=1, act_func=act_stages[0], use_se=se_stages[0],
        )
        first_block = ResidualBlock(
            first_block_conv,
            IdentityLayer(first_block_dim, first_block_dim) if input_channel == first_block_dim else None,
            dropout_rate, dropblock, block_size
        )

        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1
        feature_dim = first_block_dim

        for width, n_block, s, act_func, use_se in zip(width_list[2:], n_block_list[1:],
                                                       stride_stages[1:], act_stages[1:], se_stages[1:]):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(feature_dim), out_channel_list=val2list(output_channel),
                    kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list,
                    stride=stride, act_func=act_func, use_se=use_se,
                )
                if stride == 1 and feature_dim == output_channel:
                    shortcut = IdentityLayer(feature_dim, feature_dim)
                else:
                    shortcut = None
                blocks.append(ResidualBlock(mobile_inverted_conv, shortcut,
                                            dropout_rate, dropblock, block_size))
                feature_dim = output_channel
# final expand layer, feature mix layer & classifier
|
||||
final_expand_layer = ConvLayer(feature_dim, final_expand_width, kernel_size=1, act_func='h_swish')
|
||||
feature_mix_layer = ConvLayer(
|
||||
final_expand_width, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
|
||||
)
|
||||
|
||||
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
|
||||
|
||||
super(OFAMobileNetV3, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
|
||||
|
||||
# set bn param
|
||||
self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])
|
||||
|
||||
# runtime_depth
|
||||
self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]
|
||||
|
||||
""" MyNetwork required methods """
|
||||
|
||||
@staticmethod
|
||||
def name():
|
||||
return 'OFAMobileNetV3'
|
||||
|
||||
def forward(self, x):
|
||||
# first conv
|
||||
x = self.first_conv(x)
|
||||
# first block
|
||||
x = self.blocks[0](x)
|
||||
# blocks
|
||||
for stage_id, block_idx in enumerate(self.block_group_info):
|
||||
depth = self.runtime_depth[stage_id]
|
||||
active_idx = block_idx[:depth]
|
||||
for idx in active_idx:
|
||||
x = self.blocks[idx](x)
|
||||
x = self.final_expand_layer(x)
|
||||
x = x.mean(3, keepdim=True).mean(2, keepdim=True) # global average pooling
|
||||
x = self.feature_mix_layer(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def module_str(self):
|
||||
_str = self.first_conv.module_str + '\n'
|
||||
_str += self.blocks[0].module_str + '\n'
|
||||
|
||||
for stage_id, block_idx in enumerate(self.block_group_info):
|
||||
depth = self.runtime_depth[stage_id]
|
||||
active_idx = block_idx[:depth]
|
||||
for idx in active_idx:
|
||||
_str += self.blocks[idx].module_str + '\n'
|
||||
|
||||
_str += self.final_expand_layer.module_str + '\n'
|
||||
_str += self.feature_mix_layer.module_str + '\n'
|
||||
_str += self.classifier.module_str + '\n'
|
||||
return _str
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return {
|
||||
'name': OFAMobileNetV3.__name__,
|
||||
'bn': self.get_bn_param(),
|
||||
'first_conv': self.first_conv.config,
|
||||
'blocks': [
|
||||
block.config for block in self.blocks
|
||||
],
|
||||
'final_expand_layer': self.final_expand_layer.config,
|
||||
'feature_mix_layer': self.feature_mix_layer.config,
|
||||
'classifier': self.classifier.config,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def build_from_config(config):
|
||||
raise ValueError('do not support this function')
|
||||
|
||||
@property
|
||||
def grouped_block_index(self):
|
||||
return self.block_group_info
|
||||
|
||||
def load_state_dict(self, state_dict, **kwargs):
|
||||
model_dict = self.state_dict()
|
||||
for key in state_dict:
|
||||
if '.mobile_inverted_conv.' in key:
|
||||
new_key = key.replace('.mobile_inverted_conv.', '.conv.')
|
||||
else:
|
||||
new_key = key
|
||||
if new_key in model_dict:
|
||||
pass
|
||||
elif '.bn.bn.' in new_key:
|
||||
new_key = new_key.replace('.bn.bn.', '.bn.')
|
||||
elif '.conv.conv.weight' in new_key:
|
||||
new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
|
||||
elif '.linear.linear.' in new_key:
|
||||
new_key = new_key.replace('.linear.linear.', '.linear.')
|
||||
##############################################################################
|
||||
elif '.linear.' in new_key:
|
||||
new_key = new_key.replace('.linear.', '.linear.linear.')
|
||||
elif 'bn.' in new_key:
|
||||
new_key = new_key.replace('bn.', 'bn.bn.')
|
||||
elif 'conv.weight' in new_key:
|
||||
new_key = new_key.replace('conv.weight', 'conv.conv.weight')
|
||||
else:
|
||||
raise ValueError(new_key)
|
||||
assert new_key in model_dict, '%s' % new_key
|
||||
model_dict[new_key] = state_dict[key]
|
||||
super(OFAMobileNetV3, self).load_state_dict(model_dict)
|
||||
|
||||
""" set, sample and get active sub-networks """
|
||||
|
||||
def set_max_net(self):
|
||||
self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))
|
||||
|
||||
def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
|
||||
ks = val2list(ks, len(self.blocks) - 1)
|
||||
expand_ratio = val2list(e, len(self.blocks) - 1)
|
||||
depth = val2list(d, len(self.block_group_info))
|
||||
|
||||
for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
|
||||
if k is not None:
|
||||
block.conv.active_kernel_size = k
|
||||
if e is not None:
|
||||
block.conv.active_expand_ratio = e
|
||||
|
||||
for i, d in enumerate(depth):
|
||||
if d is not None:
|
||||
self.runtime_depth[i] = min(len(self.block_group_info[i]), d)
|
||||
|
||||
def set_constraint(self, include_list, constraint_type='depth'):
|
||||
if constraint_type == 'depth':
|
||||
self.__dict__['_depth_include_list'] = include_list.copy()
|
||||
elif constraint_type == 'expand_ratio':
|
||||
self.__dict__['_expand_include_list'] = include_list.copy()
|
||||
elif constraint_type == 'kernel_size':
|
||||
self.__dict__['_ks_include_list'] = include_list.copy()
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def clear_constraint(self):
|
||||
self.__dict__['_depth_include_list'] = None
|
||||
self.__dict__['_expand_include_list'] = None
|
||||
self.__dict__['_ks_include_list'] = None
|
||||
|
||||
def sample_active_subnet(self):
|
||||
ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
|
||||
else self.__dict__['_ks_include_list']
|
||||
expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
|
||||
else self.__dict__['_expand_include_list']
|
||||
depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
|
||||
self.__dict__['_depth_include_list']
|
||||
|
||||
# sample kernel size
|
||||
ks_setting = []
|
||||
if not isinstance(ks_candidates[0], list):
|
||||
ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
|
||||
for k_set in ks_candidates:
|
||||
k = random.choice(k_set)
|
||||
ks_setting.append(k)
|
||||
|
||||
# sample expand ratio
|
||||
expand_setting = []
|
||||
if not isinstance(expand_candidates[0], list):
|
||||
expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
|
||||
for e_set in expand_candidates:
|
||||
e = random.choice(e_set)
|
||||
expand_setting.append(e)
|
||||
|
||||
# sample depth
|
||||
depth_setting = []
|
||||
if not isinstance(depth_candidates[0], list):
|
||||
depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
|
||||
for d_set in depth_candidates:
|
||||
d = random.choice(d_set)
|
||||
depth_setting.append(d)
|
||||
|
||||
import pdb; pdb.set_trace()
|
||||
self.set_active_subnet(ks_setting, expand_setting, depth_setting)
|
||||
|
||||
return {
|
||||
'ks': ks_setting,
|
||||
'e': expand_setting,
|
||||
'd': depth_setting,
|
||||
}
|
||||
|
||||
def get_active_subnet(self, preserve_weight=True):
|
||||
first_conv = copy.deepcopy(self.first_conv)
|
||||
blocks = [copy.deepcopy(self.blocks[0])]
|
||||
|
||||
final_expand_layer = copy.deepcopy(self.final_expand_layer)
|
||||
feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
|
||||
classifier = copy.deepcopy(self.classifier)
|
||||
|
||||
input_channel = blocks[0].conv.out_channels
|
||||
# blocks
|
||||
for stage_id, block_idx in enumerate(self.block_group_info):
|
||||
depth = self.runtime_depth[stage_id]
|
||||
active_idx = block_idx[:depth]
|
||||
stage_blocks = []
|
||||
for idx in active_idx:
|
||||
stage_blocks.append(ResidualBlock(
|
||||
self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
|
||||
copy.deepcopy(self.blocks[idx].shortcut),
|
||||
copy.deepcopy(self.blocks[idx].dropout_rate),
|
||||
copy.deepcopy(self.blocks[idx].dropblock),
|
||||
copy.deepcopy(self.blocks[idx].block_size),
|
||||
))
|
||||
input_channel = stage_blocks[-1].conv.out_channels
|
||||
blocks += stage_blocks
|
||||
|
||||
_subnet = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
|
||||
_subnet.set_bn_param(**self.get_bn_param())
|
||||
return _subnet
|
||||
|
||||
def get_active_net_config(self):
|
||||
# first conv
|
||||
first_conv_config = self.first_conv.config
|
||||
first_block_config = self.blocks[0].config
|
||||
final_expand_config = self.final_expand_layer.config
|
||||
feature_mix_layer_config = self.feature_mix_layer.config
|
||||
classifier_config = self.classifier.config
|
||||
|
||||
block_config_list = [first_block_config]
|
||||
input_channel = first_block_config['conv']['out_channels']
|
||||
for stage_id, block_idx in enumerate(self.block_group_info):
|
||||
depth = self.runtime_depth[stage_id]
|
||||
active_idx = block_idx[:depth]
|
||||
stage_blocks = []
|
||||
for idx in active_idx:
|
||||
stage_blocks.append({
|
||||
'name': ResidualBlock.__name__,
|
||||
'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
|
||||
'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
|
||||
})
|
||||
input_channel = self.blocks[idx].conv.active_out_channel
|
||||
block_config_list += stage_blocks
|
||||
|
||||
return {
|
||||
'name': MobileNetV3.__name__,
|
||||
'bn': self.get_bn_param(),
|
||||
'first_conv': first_conv_config,
|
||||
'blocks': block_config_list,
|
||||
'final_expand_layer': final_expand_config,
|
||||
'feature_mix_layer': feature_mix_layer_config,
|
||||
'classifier': classifier_config,
|
||||
}
|
||||
|
||||
""" Width Related Methods """
|
||||
|
||||
def re_organize_middle_weights(self, expand_ratio_stage=0):
|
||||
for block in self.blocks[1:]:
|
||||
block.conv.re_organize_middle_weights(expand_ratio_stage)
|
||||
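# Usage sketch (illustrative, not from the original OFA code): exercises the elastic API
# defined above. The candidate lists mirror the values used in the OFA paper; the variable
# names below are assumptions for demonstration only.
if __name__ == '__main__':
    ofa_net = OFAMobileNetV3(n_classes=1000, ks_list=[3, 5, 7],
                             expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4])
    arch = ofa_net.sample_active_subnet()  # random {'ks': [...], 'e': [...], 'd': [...]}
    subnet = ofa_net.get_active_subnet(preserve_weight=True)  # standalone MobileNetV3
    print(arch)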
@@ -0,0 +1,331 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random

from ofa_local.utils import make_divisible, val2list, MyNetwork
from ofa_local.imagenet_classification.elastic_nn.modules import DynamicMBConvLayer
from ofa_local.utils.layers import ConvLayer, IdentityLayer, LinearLayer, MBConvLayer, ResidualBlock
from ofa_local.imagenet_classification.networks.proxyless_nets import ProxylessNASNets

__all__ = ['OFAProxylessNASNets']


class OFAProxylessNASNets(ProxylessNASNets):

    def __init__(self, n_classes=1000, bn_param=(0.1, 1e-3), dropout_rate=0.1, base_stage_width=None, width_mult=1.0,
                 ks_list=3, expand_ratio_list=6, depth_list=4):

        self.width_mult = width_mult
        self.ks_list = val2list(ks_list, 1)
        self.expand_ratio_list = val2list(expand_ratio_list, 1)
        self.depth_list = val2list(depth_list, 1)

        self.ks_list.sort()
        self.expand_ratio_list.sort()
        self.depth_list.sort()

        if base_stage_width == 'google':
            # MobileNetV2 Stage Width
            base_stage_width = [32, 16, 24, 32, 64, 96, 160, 320, 1280]
        else:
            # ProxylessNAS Stage Width
            base_stage_width = [32, 16, 24, 40, 80, 96, 192, 320, 1280]

        input_channel = make_divisible(base_stage_width[0] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        first_block_width = make_divisible(base_stage_width[1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(base_stage_width[-1] * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
        )
        # first block
        first_block_conv = MBConvLayer(
            in_channels=input_channel, out_channels=first_block_width, kernel_size=3, stride=1,
            expand_ratio=1, act_func='relu6',
        )
        first_block = ResidualBlock(first_block_conv, None)

        input_channel = first_block_width
        # inverted residual blocks
        self.block_group_info = []
        blocks = [first_block]
        _block_index = 1

        stride_stages = [2, 2, 2, 1, 2, 1]
        n_block_list = [max(self.depth_list)] * 5 + [1]

        width_list = []
        for base_width in base_stage_width[2:-1]:
            width = make_divisible(base_width * self.width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            width_list.append(width)

        for width, n_block, s in zip(width_list, n_block_list, stride_stages):
            self.block_group_info.append([_block_index + i for i in range(n_block)])
            _block_index += n_block

            output_channel = width
            for i in range(n_block):
                if i == 0:
                    stride = s
                else:
                    stride = 1

                mobile_inverted_conv = DynamicMBConvLayer(
                    in_channel_list=val2list(input_channel, 1), out_channel_list=val2list(output_channel, 1),
                    kernel_size_list=ks_list, expand_ratio_list=expand_ratio_list, stride=stride, act_func='relu6',
                )

                if stride == 1 and input_channel == output_channel:
                    shortcut = IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None

                mb_inverted_block = ResidualBlock(mobile_inverted_conv, shortcut)

                blocks.append(mb_inverted_block)
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6',
        )
        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(OFAProxylessNASNets, self).__init__(first_conv, blocks, feature_mix_layer, classifier)

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

        # runtime_depth
        self.runtime_depth = [len(block_idx) for block_idx in self.block_group_info]

    """ MyNetwork required methods """

    @staticmethod
    def name():
        return 'OFAProxylessNASNets'

    def forward(self, x):
        # first conv
        x = self.first_conv(x)
        # first block
        x = self.blocks[0](x)

        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                x = self.blocks[idx](x)

        # feature_mix_layer
        x = self.feature_mix_layer(x)
        x = x.mean(3).mean(2)

        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = self.first_conv.module_str + '\n'
        _str += self.blocks[0].module_str + '\n'

        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + '\n'
        _str += self.feature_mix_layer.module_str + '\n'
        _str += self.classifier.module_str + '\n'
        return _str

    @property
    def config(self):
        return {
            'name': OFAProxylessNASNets.__name__,
            'bn': self.get_bn_param(),
            'first_conv': self.first_conv.config,
            'blocks': [
                block.config for block in self.blocks
            ],
            'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError('do not support this function')

    @property
    def grouped_block_index(self):
        return self.block_group_info

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            if '.mobile_inverted_conv.' in key:
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            if new_key in model_dict:
                pass
            elif '.bn.bn.' in new_key:
                new_key = new_key.replace('.bn.bn.', '.bn.')
            elif '.conv.conv.weight' in new_key:
                new_key = new_key.replace('.conv.conv.weight', '.conv.weight')
            elif '.linear.linear.' in new_key:
                new_key = new_key.replace('.linear.linear.', '.linear.')
            ##############################################################################
            elif '.linear.' in new_key:
                new_key = new_key.replace('.linear.', '.linear.linear.')
            elif 'bn.' in new_key:
                new_key = new_key.replace('bn.', 'bn.bn.')
            elif 'conv.weight' in new_key:
                new_key = new_key.replace('conv.weight', 'conv.conv.weight')
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, '%s' % new_key
            model_dict[new_key] = state_dict[key]
        super(OFAProxylessNASNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(ks=max(self.ks_list), e=max(self.expand_ratio_list), d=max(self.depth_list))

    def set_active_subnet(self, ks=None, e=None, d=None, **kwargs):
        ks = val2list(ks, len(self.blocks) - 1)
        expand_ratio = val2list(e, len(self.blocks) - 1)
        depth = val2list(d, len(self.block_group_info))

        for block, k, e in zip(self.blocks[1:], ks, expand_ratio):
            if k is not None:
                block.conv.active_kernel_size = k
            if e is not None:
                block.conv.active_expand_ratio = e

        for i, d in enumerate(depth):
            if d is not None:
                self.runtime_depth[i] = min(len(self.block_group_info[i]), d)

    def set_constraint(self, include_list, constraint_type='depth'):
        if constraint_type == 'depth':
            self.__dict__['_depth_include_list'] = include_list.copy()
        elif constraint_type == 'expand_ratio':
            self.__dict__['_expand_include_list'] = include_list.copy()
        elif constraint_type == 'kernel_size':
            self.__dict__['_ks_include_list'] = include_list.copy()
        else:
            raise NotImplementedError

    def clear_constraint(self):
        self.__dict__['_depth_include_list'] = None
        self.__dict__['_expand_include_list'] = None
        self.__dict__['_ks_include_list'] = None

    def sample_active_subnet(self):
        ks_candidates = self.ks_list if self.__dict__.get('_ks_include_list', None) is None \
            else self.__dict__['_ks_include_list']
        expand_candidates = self.expand_ratio_list if self.__dict__.get('_expand_include_list', None) is None \
            else self.__dict__['_expand_include_list']
        depth_candidates = self.depth_list if self.__dict__.get('_depth_include_list', None) is None else \
            self.__dict__['_depth_include_list']

        # sample kernel size
        ks_setting = []
        if not isinstance(ks_candidates[0], list):
            ks_candidates = [ks_candidates for _ in range(len(self.blocks) - 1)]
        for k_set in ks_candidates:
            k = random.choice(k_set)
            ks_setting.append(k)

        # sample expand ratio
        expand_setting = []
        if not isinstance(expand_candidates[0], list):
            expand_candidates = [expand_candidates for _ in range(len(self.blocks) - 1)]
        for e_set in expand_candidates:
            e = random.choice(e_set)
            expand_setting.append(e)

        # sample depth
        depth_setting = []
        if not isinstance(depth_candidates[0], list):
            depth_candidates = [depth_candidates for _ in range(len(self.block_group_info))]
        for d_set in depth_candidates:
            d = random.choice(d_set)
            depth_setting.append(d)

        depth_setting[-1] = 1
        self.set_active_subnet(ks_setting, expand_setting, depth_setting)

        return {
            'ks': ks_setting,
            'e': expand_setting,
            'd': depth_setting,
        }

    def get_active_subnet(self, preserve_weight=True):
        first_conv = copy.deepcopy(self.first_conv)
        blocks = [copy.deepcopy(self.blocks[0])]
        feature_mix_layer = copy.deepcopy(self.feature_mix_layer)
        classifier = copy.deepcopy(self.classifier)

        input_channel = blocks[0].conv.out_channels
        # blocks
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append(ResidualBlock(
                    self.blocks[idx].conv.get_active_subnet(input_channel, preserve_weight),
                    copy.deepcopy(self.blocks[idx].shortcut)
                ))
                input_channel = stage_blocks[-1].conv.out_channels
            blocks += stage_blocks

        _subnet = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        _subnet.set_bn_param(**self.get_bn_param())
        return _subnet

    def get_active_net_config(self):
        first_conv_config = self.first_conv.config
        first_block_config = self.blocks[0].config
        feature_mix_layer_config = self.feature_mix_layer.config
        classifier_config = self.classifier.config

        block_config_list = [first_block_config]
        input_channel = first_block_config['conv']['out_channels']
        for stage_id, block_idx in enumerate(self.block_group_info):
            depth = self.runtime_depth[stage_id]
            active_idx = block_idx[:depth]
            stage_blocks = []
            for idx in active_idx:
                stage_blocks.append({
                    'name': ResidualBlock.__name__,
                    'conv': self.blocks[idx].conv.get_active_subnet_config(input_channel),
                    'shortcut': self.blocks[idx].shortcut.config if self.blocks[idx].shortcut is not None else None,
                })
                try:
                    input_channel = self.blocks[idx].conv.active_out_channel
                except Exception:
                    input_channel = self.blocks[idx].conv.out_channels
            block_config_list += stage_blocks

        return {
            'name': ProxylessNASNets.__name__,
            'bn': self.get_bn_param(),
            'first_conv': first_conv_config,
            'blocks': block_config_list,
            'feature_mix_layer': feature_mix_layer_config,
            'classifier': classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks[1:]:
            block.conv.re_organize_middle_weights(expand_ratio_stage)
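# Sketch (illustrative, not from the original file): restricting the sampling space with
# set_constraint(), the mechanism progressive shrinking relies on. Note that this class's
# sample_active_subnet() always forces the depth of the last stage to 1.
if __name__ == '__main__':
    ofa_net = OFAProxylessNASNets(ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4])
    ofa_net.set_constraint([2], constraint_type='depth')  # only depth 2 may be sampled
    arch = ofa_net.sample_active_subnet()
    ofa_net.clear_constraint()
    print(arch['d'])  # all 2s except the final stage, which is 1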
@@ -0,0 +1,267 @@
import random

from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicConvLayer, DynamicLinearLayer
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_layers import DynamicResNetBottleneckBlock
from ofa_local.utils.layers import IdentityLayer, ResidualBlock
from ofa_local.imagenet_classification.networks import ResNets
from ofa_local.utils import make_divisible, val2list, MyNetwork

__all__ = ['OFAResNets']


class OFAResNets(ResNets):

    def __init__(self, n_classes=1000, bn_param=(0.1, 1e-5), dropout_rate=0,
                 depth_list=2, expand_ratio_list=0.25, width_mult_list=1.0):

        self.depth_list = val2list(depth_list)
        self.expand_ratio_list = val2list(expand_ratio_list)
        self.width_mult_list = val2list(width_mult_list)
        # sort
        self.depth_list.sort()
        self.expand_ratio_list.sort()
        self.width_mult_list.sort()

        input_channel = [
            make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
        ]
        mid_input_channel = [
            make_divisible(channel // 2, MyNetwork.CHANNEL_DIVISIBLE) for channel in input_channel
        ]

        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = [
                make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE) for width_mult in self.width_mult_list
            ]

        n_block_list = [base_depth + max(self.depth_list) for base_depth in ResNets.BASE_DEPTH_LIST]
        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            DynamicConvLayer(val2list(3), mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
            ResidualBlock(
                DynamicConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
                IdentityLayer(mid_input_channel, mid_input_channel)
            ),
            DynamicConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
        ]

        # blocks
        blocks = []
        for d, width, s in zip(n_block_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = DynamicResNetBottleneckBlock(
                    input_channel, width, expand_ratio_list=self.expand_ratio_list,
                    kernel_size=3, stride=stride, act_func='relu', downsample_mode='avgpool_conv',
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = DynamicLinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(OFAResNets, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)

        # runtime_depth
        self.input_stem_skipping = 0
        self.runtime_depth = [0] * len(n_block_list)

    @property
    def ks_list(self):
        return [3]

    @staticmethod
    def name():
        return 'OFAResNets'

    def forward(self, x):
        for layer in self.input_stem:
            if self.input_stem_skipping > 0 \
                    and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
                pass
            else:
                x = layer(x)
        x = self.max_pooling(x)
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                x = self.blocks[idx](x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ''
        for layer in self.input_stem:
            if self.input_stem_skipping > 0 \
                    and isinstance(layer, ResidualBlock) and isinstance(layer.shortcut, IdentityLayer):
                pass
            else:
                _str += layer.module_str + '\n'
        _str += 'max_pooling(ks=3, stride=2)\n'
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                _str += self.blocks[idx].module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': OFAResNets.__name__,
            'bn': self.get_bn_param(),
            'input_stem': [
                layer.config for layer in self.input_stem
            ],
            'blocks': [
                block.config for block in self.blocks
            ],
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        raise ValueError('do not support this function')

    def load_state_dict(self, state_dict, **kwargs):
        model_dict = self.state_dict()
        for key in state_dict:
            new_key = key
            if new_key in model_dict:
                pass
            elif '.linear.' in new_key:
                new_key = new_key.replace('.linear.', '.linear.linear.')
            elif 'bn.' in new_key:
                new_key = new_key.replace('bn.', 'bn.bn.')
            elif 'conv.weight' in new_key:
                new_key = new_key.replace('conv.weight', 'conv.conv.weight')
            else:
                raise ValueError(new_key)
            assert new_key in model_dict, '%s' % new_key
            model_dict[new_key] = state_dict[key]
        super(OFAResNets, self).load_state_dict(model_dict)

    """ set, sample and get active sub-networks """

    def set_max_net(self):
        self.set_active_subnet(d=max(self.depth_list), e=max(self.expand_ratio_list), w=len(self.width_mult_list) - 1)

    def set_active_subnet(self, d=None, e=None, w=None, **kwargs):
        depth = val2list(d, len(ResNets.BASE_DEPTH_LIST) + 1)
        expand_ratio = val2list(e, len(self.blocks))
        width_mult = val2list(w, len(ResNets.BASE_DEPTH_LIST) + 2)

        for block, e in zip(self.blocks, expand_ratio):
            if e is not None:
                block.active_expand_ratio = e

        if width_mult[0] is not None:
            self.input_stem[1].conv.active_out_channel = self.input_stem[0].active_out_channel = \
                self.input_stem[0].out_channel_list[width_mult[0]]
        if width_mult[1] is not None:
            self.input_stem[2].active_out_channel = self.input_stem[2].out_channel_list[width_mult[1]]

        if depth[0] is not None:
            self.input_stem_skipping = (depth[0] != max(self.depth_list))
        for stage_id, (block_idx, d, w) in enumerate(zip(self.grouped_block_index, depth[1:], width_mult[2:])):
            if d is not None:
                self.runtime_depth[stage_id] = max(self.depth_list) - d
            if w is not None:
                for idx in block_idx:
                    self.blocks[idx].active_out_channel = self.blocks[idx].out_channel_list[w]

    def sample_active_subnet(self):
        # sample expand ratio
        expand_setting = []
        for block in self.blocks:
            expand_setting.append(random.choice(block.expand_ratio_list))

        # sample depth
        depth_setting = [random.choice([max(self.depth_list), min(self.depth_list)])]
        for stage_id in range(len(ResNets.BASE_DEPTH_LIST)):
            depth_setting.append(random.choice(self.depth_list))

        # sample width_mult
        width_mult_setting = [
            random.choice(list(range(len(self.input_stem[0].out_channel_list)))),
            random.choice(list(range(len(self.input_stem[2].out_channel_list)))),
        ]
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            stage_first_block = self.blocks[block_idx[0]]
            width_mult_setting.append(
                random.choice(list(range(len(stage_first_block.out_channel_list))))
            )

        arch_config = {
            'd': depth_setting,
            'e': expand_setting,
            'w': width_mult_setting
        }
        self.set_active_subnet(**arch_config)
        return arch_config

    def get_active_subnet(self, preserve_weight=True):
        input_stem = [self.input_stem[0].get_active_subnet(3, preserve_weight)]
        if self.input_stem_skipping <= 0:
            input_stem.append(ResidualBlock(
                self.input_stem[1].conv.get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight),
                IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel)
            ))
        input_stem.append(self.input_stem[2].get_active_subnet(self.input_stem[0].active_out_channel, preserve_weight))
        input_channel = self.input_stem[2].active_out_channel

        blocks = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                blocks.append(self.blocks[idx].get_active_subnet(input_channel, preserve_weight))
                input_channel = self.blocks[idx].active_out_channel
        classifier = self.classifier.get_active_subnet(input_channel, preserve_weight)
        subnet = ResNets(input_stem, blocks, classifier)

        subnet.set_bn_param(**self.get_bn_param())
        return subnet

    def get_active_net_config(self):
        input_stem_config = [self.input_stem[0].get_active_subnet_config(3)]
        if self.input_stem_skipping <= 0:
            input_stem_config.append({
                'name': ResidualBlock.__name__,
                'conv': self.input_stem[1].conv.get_active_subnet_config(self.input_stem[0].active_out_channel),
                'shortcut': IdentityLayer(self.input_stem[0].active_out_channel, self.input_stem[0].active_out_channel),
            })
        input_stem_config.append(self.input_stem[2].get_active_subnet_config(self.input_stem[0].active_out_channel))
        input_channel = self.input_stem[2].active_out_channel

        blocks_config = []
        for stage_id, block_idx in enumerate(self.grouped_block_index):
            depth_param = self.runtime_depth[stage_id]
            active_idx = block_idx[:len(block_idx) - depth_param]
            for idx in active_idx:
                blocks_config.append(self.blocks[idx].get_active_subnet_config(input_channel))
                input_channel = self.blocks[idx].active_out_channel
        classifier_config = self.classifier.get_active_subnet_config(input_channel)
        return {
            'name': ResNets.__name__,
            'bn': self.get_bn_param(),
            'input_stem': input_stem_config,
            'blocks': blocks_config,
            'classifier': classifier_config,
        }

    """ Width Related Methods """

    def re_organize_middle_weights(self, expand_ratio_stage=0):
        for block in self.blocks:
            block.re_organize_middle_weights(expand_ratio_stage)
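# Minimal sketch (illustrative, not from the original file): OFAResNets encodes its search
# space differently from the MobileNet variants above. 'd' counts blocks to *skip* per stage
# (runtime_depth[i] = max(depth_list) - d) and 'w' is an index into out_channel_list rather
# than a width-multiplier value. The constructor arguments below follow the values commonly
# used for OFA-ResNet50; everything else here is an assumption.
if __name__ == '__main__':
    ofa_net = OFAResNets(depth_list=[0, 1, 2], expand_ratio_list=[0.2, 0.25, 0.35],
                         width_mult_list=[0.65, 0.8, 1.0])
    arch = ofa_net.sample_active_subnet()  # {'d': [...], 'e': [...], 'w': [index values]}
    subnet = ofa_net.get_active_subnet(preserve_weight=True)  # standalone ResNets instance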
@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .progressive_shrinking import *
@@ -0,0 +1,320 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch.nn as nn
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm

from ofa.utils import AverageMeter, cross_entropy_loss_with_soft_target
from ofa.utils import DistributedMetric, list_mean, subset_mean, val2list, MyRandomResizedCrop
from ofa.imagenet_classification.run_manager import DistributedRunManager

__all__ = [
    'validate', 'train_one_epoch', 'train', 'load_models',
    'train_elastic_depth', 'train_elastic_expand', 'train_elastic_width_mult',
]


def validate(run_manager, epoch=0, is_test=False, image_size_list=None,
             ks_list=None, expand_ratio_list=None, depth_list=None, width_mult_list=None, additional_setting=None):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    dynamic_net.eval()

    if image_size_list is None:
        image_size_list = val2list(run_manager.run_config.data_provider.image_size, 1)
    if ks_list is None:
        ks_list = dynamic_net.ks_list
    if expand_ratio_list is None:
        expand_ratio_list = dynamic_net.expand_ratio_list
    if depth_list is None:
        depth_list = dynamic_net.depth_list
    if width_mult_list is None:
        if 'width_mult_list' in dynamic_net.__dict__:
            width_mult_list = list(range(len(dynamic_net.width_mult_list)))
        else:
            width_mult_list = [0]

    subnet_settings = []
    for d in depth_list:
        for e in expand_ratio_list:
            for k in ks_list:
                for w in width_mult_list:
                    for img_size in image_size_list:
                        subnet_settings.append([{
                            'image_size': img_size,
                            'd': d,
                            'e': e,
                            'ks': k,
                            'w': w,
                        }, 'R%s-D%s-E%s-K%s-W%s' % (img_size, d, e, k, w)])
    if additional_setting is not None:
        subnet_settings += additional_setting

    losses_of_subnets, top1_of_subnets, top5_of_subnets = [], [], []

    valid_log = ''
    for setting, name in subnet_settings:
        run_manager.write_log('-' * 30 + ' Validate %s ' % name + '-' * 30, 'train', should_print=False)
        run_manager.run_config.data_provider.assign_active_img_size(setting.pop('image_size'))
        dynamic_net.set_active_subnet(**setting)
        run_manager.write_log(dynamic_net.module_str, 'train', should_print=False)

        run_manager.reset_running_statistics(dynamic_net)
        loss, (top1, top5) = run_manager.validate(epoch=epoch, is_test=is_test, run_str=name, net=dynamic_net)
        losses_of_subnets.append(loss)
        top1_of_subnets.append(top1)
        top5_of_subnets.append(top5)
        valid_log += '%s (%.3f), ' % (name, top1)

    return list_mean(losses_of_subnets), list_mean(top1_of_subnets), list_mean(top5_of_subnets), valid_log


def train_one_epoch(run_manager, args, epoch, warmup_epochs=0, warmup_lr=0):
    dynamic_net = run_manager.network
    distributed = isinstance(run_manager, DistributedRunManager)

    # switch to train mode
    dynamic_net.train()
    if distributed:
        run_manager.run_config.train_loader.sampler.set_epoch(epoch)
    MyRandomResizedCrop.EPOCH = epoch

    nBatch = len(run_manager.run_config.train_loader)

    data_time = AverageMeter()
    losses = DistributedMetric('train_loss') if distributed else AverageMeter()
    metric_dict = run_manager.get_metric_dict()

    with tqdm(total=nBatch,
              desc='Train Epoch #{}'.format(epoch + 1),
              disable=distributed and not run_manager.is_root) as t:
        end = time.time()
        for i, (images, labels) in enumerate(run_manager.run_config.train_loader):
            MyRandomResizedCrop.BATCH = i
            data_time.update(time.time() - end)
            if epoch < warmup_epochs:
                new_lr = run_manager.run_config.warmup_adjust_learning_rate(
                    run_manager.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                )
            else:
                new_lr = run_manager.run_config.adjust_learning_rate(
                    run_manager.optimizer, epoch - warmup_epochs, i, nBatch
                )

            images, labels = images.cuda(), labels.cuda()
            target = labels

            # soft target
            if args.kd_ratio > 0:
                args.teacher_model.train()
                with torch.no_grad():
                    soft_logits = args.teacher_model(images).detach()
                    soft_label = F.softmax(soft_logits, dim=1)

            # clean gradients
            dynamic_net.zero_grad()

            loss_of_subnets = []
            # compute output
            subnet_str = ''
            for _ in range(args.dynamic_batch_size):
                # set random seed before sampling
                subnet_seed = int('%d%.3d%.3d' % (epoch * nBatch + i, _, 0))
                random.seed(subnet_seed)
                subnet_settings = dynamic_net.sample_active_subnet()
                subnet_str += '%d: ' % _ + ','.join(['%s_%s' % (
                    key, '%.1f' % subset_mean(val, 0) if isinstance(val, list) else val
                ) for key, val in subnet_settings.items()]) + ' || '

                output = run_manager.net(images)
                if args.kd_ratio == 0:
                    loss = run_manager.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + run_manager.train_criterion(output, labels)
                    loss_type = '%.1fkd-%s & ce' % (args.kd_ratio, args.kd_type)

                # measure accuracy and record loss
                loss_of_subnets.append(loss)
                run_manager.update_metric(metric_dict, output, target)

                loss.backward()
            run_manager.optimizer.step()

            losses.update(list_mean(loss_of_subnets), images.size(0))

            t.set_postfix({
                'loss': losses.avg.item(),
                **run_manager.get_metric_vals(metric_dict, return_dict=True),
                'R': images.size(2),
                'lr': new_lr,
                'loss_type': loss_type,
                'seed': str(subnet_seed),
                'str': subnet_str,
                'data_time': data_time.avg,
            })
            t.update(1)
            end = time.time()
    return losses.avg.item(), run_manager.get_metric_vals(metric_dict)


def train(run_manager, args, validate_func=None):
    distributed = isinstance(run_manager, DistributedRunManager)
    if validate_func is None:
        validate_func = validate

    for epoch in range(run_manager.start_epoch, run_manager.run_config.n_epochs + args.warmup_epochs):
        train_loss, (train_top1, train_top5) = train_one_epoch(
            run_manager, args, epoch, args.warmup_epochs, args.warmup_lr)

        if (epoch + 1) % args.validation_frequency == 0:
            val_loss, val_acc, val_acc5, _val_log = validate_func(run_manager, epoch=epoch, is_test=False)
            # best_acc
            is_best = val_acc > run_manager.best_acc
            run_manager.best_acc = max(run_manager.best_acc, val_acc)
            if not distributed or run_manager.is_root:
                val_log = 'Valid [{0}/{1}] loss={2:.3f}, top-1={3:.3f} ({4:.3f})'. \
                    format(epoch + 1 - args.warmup_epochs, run_manager.run_config.n_epochs, val_loss, val_acc,
                           run_manager.best_acc)
                val_log += ', Train top-1 {top1:.3f}, Train loss {loss:.3f}\t'.format(top1=train_top1, loss=train_loss)
                val_log += _val_log
                run_manager.write_log(val_log, 'valid', should_print=False)

                run_manager.save_model({
                    'epoch': epoch,
                    'best_acc': run_manager.best_acc,
                    'optimizer': run_manager.optimizer.state_dict(),
                    'state_dict': run_manager.network.state_dict(),
                }, is_best=is_best)


def load_models(run_manager, dynamic_net, model_path=None):
    # specify init path
    init = torch.load(model_path, map_location='cpu')['state_dict']
    dynamic_net.load_state_dict(init)
    run_manager.write_log('Loaded init from %s' % model_path, 'valid')


def train_elastic_depth(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    depth_stage_list = dynamic_net.depth_list.copy()
    depth_stage_list.sort(reverse=True)
    n_stages = len(depth_stage_list) - 1
    current_stage = n_stages - 1

    # load pretrained models
    if run_manager.start_epoch == 0 and not args.resume:
        validate_func_dict['depth_list'] = sorted(dynamic_net.depth_list)

        load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
        # validate after loading weights
        run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                              validate(run_manager, is_test=True, **validate_func_dict), 'valid')
    else:
        assert args.resume

    run_manager.write_log(
        '-' * 30 + 'Supporting Elastic Depth: %s -> %s' %
        (depth_stage_list[:current_stage + 1], depth_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
    )
    # add depth list constraints
    if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.expand_ratio_list)) == 1:
        validate_func_dict['depth_list'] = depth_stage_list
    else:
        validate_func_dict['depth_list'] = sorted({min(depth_stage_list), max(depth_stage_list)})

    # train
    train_func(
        run_manager, args,
        lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
    )


def train_elastic_expand(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    expand_stage_list = dynamic_net.expand_ratio_list.copy()
    expand_stage_list.sort(reverse=True)
    n_stages = len(expand_stage_list) - 1
    current_stage = n_stages - 1

    # load pretrained models
    if run_manager.start_epoch == 0 and not args.resume:
        validate_func_dict['expand_ratio_list'] = sorted(dynamic_net.expand_ratio_list)

        load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
        dynamic_net.re_organize_middle_weights(expand_ratio_stage=current_stage)
        run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                              validate(run_manager, is_test=True, **validate_func_dict), 'valid')
    else:
        assert args.resume

    run_manager.write_log(
        '-' * 30 + 'Supporting Elastic Expand Ratio: %s -> %s' %
        (expand_stage_list[:current_stage + 1], expand_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
    )
    if len(set(dynamic_net.ks_list)) == 1 and len(set(dynamic_net.depth_list)) == 1:
        validate_func_dict['expand_ratio_list'] = expand_stage_list
    else:
        validate_func_dict['expand_ratio_list'] = sorted({min(expand_stage_list), max(expand_stage_list)})

    # train
    train_func(
        run_manager, args,
        lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
    )


def train_elastic_width_mult(train_func, run_manager, args, validate_func_dict):
    dynamic_net = run_manager.net
    if isinstance(dynamic_net, nn.DataParallel):
        dynamic_net = dynamic_net.module

    width_stage_list = dynamic_net.width_mult_list.copy()
    width_stage_list.sort(reverse=True)
    n_stages = len(width_stage_list) - 1
    current_stage = n_stages - 1

    if run_manager.start_epoch == 0 and not args.resume:
        load_models(run_manager, dynamic_net, model_path=args.ofa_checkpoint_path)
        if current_stage == 0:
            dynamic_net.re_organize_middle_weights(expand_ratio_stage=len(dynamic_net.expand_ratio_list) - 1)
            run_manager.write_log('reorganize_middle_weights (expand_ratio_stage=%d)'
                                  % (len(dynamic_net.expand_ratio_list) - 1), 'valid')
            try:
                dynamic_net.re_organize_outer_weights()
                run_manager.write_log('reorganize_outer_weights', 'valid')
            except Exception:
                pass
        run_manager.write_log('%.3f\t%.3f\t%.3f\t%s' %
                              validate(run_manager, is_test=True, **validate_func_dict), 'valid')
    else:
        assert args.resume

    run_manager.write_log(
        '-' * 30 + 'Supporting Elastic Width Mult: %s -> %s' %
        (width_stage_list[:current_stage + 1], width_stage_list[:current_stage + 2]) + '-' * 30, 'valid'
    )
    validate_func_dict['width_mult_list'] = sorted({0, len(width_stage_list) - 1})

    # train
    train_func(
        run_manager, args,
        lambda _run_manager, epoch, is_test: validate(_run_manager, epoch, is_test, **validate_func_dict)
    )
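# Sketch of how these entry points are typically chained during progressive shrinking
# (an assumption based on the OFA training scripts; `args`, `run_manager`, and
# `validate_func_dict` must be built by the caller and are not defined in this file):
#
#   validate_func_dict = {'image_size_list': [224], 'width_mult_list': [0],
#                         'ks_list': [7], 'expand_ratio_list': [6], 'depth_list': [4]}
#   train_elastic_depth(train, run_manager, args, validate_func_dict)       # elastic depth stage
#   train_elastic_expand(train, run_manager, args, validate_func_dict)      # then elastic expand
#   train_elastic_width_mult(train, run_manager, args, validate_func_dict)  # then elastic width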
@@ -0,0 +1,70 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import torch.nn.functional as F
import torch.nn as nn
import torch

from ofa_local.utils import AverageMeter, get_net_device, DistributedTensor
from ofa_local.imagenet_classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d

__all__ = ['set_running_statistics']


def set_running_statistics(model, data_loader, distributed=False):
    bn_mean = {}
    bn_var = {}

    forward_model = copy.deepcopy(model)
    for name, m in forward_model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            if distributed:
                bn_mean[name] = DistributedTensor(name + '#mean')
                bn_var[name] = DistributedTensor(name + '#var')
            else:
                bn_mean[name] = AverageMeter()
                bn_var[name] = AverageMeter()

            def new_forward(bn, mean_est, var_est):
                def lambda_forward(x):
                    batch_mean = x.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)  # 1, C, 1, 1
                    batch_var = (x - batch_mean) * (x - batch_mean)
                    batch_var = batch_var.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)

                    batch_mean = torch.squeeze(batch_mean)
                    batch_var = torch.squeeze(batch_var)

                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # bn forward using calculated mean & var
                    _feature_dim = batch_mean.size(0)
                    return F.batch_norm(
                        x, batch_mean, batch_var, bn.weight[:_feature_dim],
                        bn.bias[:_feature_dim], False,
                        0.0, bn.eps,
                    )

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    if len(bn_mean) == 0:
        # skip if there is no batch normalization layers in the network
        return

    with torch.no_grad():
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
        for images, labels in data_loader:
            images = images.to(get_net_device(forward_model))
            forward_model(images)
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, nn.BatchNorm2d)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
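# Usage sketch (illustrative, not part of the original file): recalibrating BatchNorm
# statistics for a freshly extracted sub-network before evaluation, since the parent
# network's running statistics do not match the subnet.
#
#   ofa_net.sample_active_subnet()
#   subnet = ofa_net.get_active_subnet(preserve_weight=True)
#   set_running_statistics(subnet, calib_loader)  # `calib_loader` is an assumed DataLoader
#                                                 # over a few thousand training images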
@@ -0,0 +1,18 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .proxyless_nets import *
from .mobilenet_v3 import *
from .resnets import *


def get_net_by_name(name):
    if name == ProxylessNASNets.__name__:
        return ProxylessNASNets
    elif name == MobileNetV3.__name__:
        return MobileNetV3
    elif name == ResNets.__name__:
        return ResNets
    else:
        raise ValueError('unrecognized type of network: %s' % name)
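# Example (illustrative, not part of the original file): get_net_by_name() resolves the
# 'name' field stored in a serialized config back to its class, e.g.
#
#   net_cls = get_net_by_name(config['name'])  # e.g. 'MobileNetV3'
#   net = net_cls.build_from_config(config)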
@@ -0,0 +1,218 @@
|
||||
# Once for All: Train One Network and Specialize it for Efficient Deployment
|
||||
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
|
||||
# International Conference on Learning Representations (ICLR), 2020.
|
||||
|
||||
import copy
|
||||
import torch.nn as nn
|
||||
|
||||
from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
|
||||
from ofa_local.utils import MyNetwork, make_divisible, MyGlobalAvgPool2d
|
||||
|
||||
__all__ = ['MobileNetV3', 'MobileNetV3Large']
|
||||
|
||||
|
||||
class MobileNetV3(MyNetwork):
|
||||
|
||||
def __init__(self, first_conv, blocks, final_expand_layer, feature_mix_layer, classifier):
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.first_conv = first_conv
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
self.final_expand_layer = final_expand_layer
|
||||
self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=True)
|
||||
self.feature_mix_layer = feature_mix_layer
|
||||
self.classifier = classifier
|
||||
|
||||
def forward(self, x):
|
||||
x = self.first_conv(x)
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
x = self.final_expand_layer(x)
|
||||
x = self.global_avg_pool(x) # global average pooling
|
||||
x = self.feature_mix_layer(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.classifier(x)
|
||||
return x
|
||||
|
||||
@property
|
||||
def module_str(self):
|
||||
_str = self.first_conv.module_str + '\n'
|
||||
for block in self.blocks:
|
||||
_str += block.module_str + '\n'
|
||||
_str += self.final_expand_layer.module_str + '\n'
|
||||
_str += self.global_avg_pool.__repr__() + '\n'
|
||||
_str += self.feature_mix_layer.module_str + '\n'
|
||||
_str += self.classifier.module_str
|
||||
return _str
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return {
|
||||
'name': MobileNetV3.__name__,
|
||||
'bn': self.get_bn_param(),
|
||||
'first_conv': self.first_conv.config,
|
||||
'blocks': [
|
||||
block.config for block in self.blocks
|
||||
],
|
||||
'final_expand_layer': self.final_expand_layer.config,
|
||||
'feature_mix_layer': self.feature_mix_layer.config,
|
||||
'classifier': self.classifier.config,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def build_from_config(config):
|
||||
first_conv = set_layer_from_config(config['first_conv'])
|
||||
final_expand_layer = set_layer_from_config(config['final_expand_layer'])
|
||||
feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
|
||||
classifier = set_layer_from_config(config['classifier'])
|
||||
|
||||
blocks = []
|
||||
for block_config in config['blocks']:
|
||||
blocks.append(ResidualBlock.build_from_config(block_config))
|
||||
|
||||
net = MobileNetV3(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
|
||||
if 'bn' in config:
|
||||
net.set_bn_param(**config['bn'])
|
||||
else:
|
||||
net.set_bn_param(momentum=0.1, eps=1e-5)
|
||||
|
||||
return net
|
||||
|
||||
def zero_last_gamma(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, ResidualBlock):
|
||||
if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
|
||||
m.conv.point_linear.bn.weight.data.zero_()
|
||||
|
||||
@property
|
||||
def grouped_block_index(self):
|
||||
info_list = []
|
||||
block_index_list = []
|
||||
for i, block in enumerate(self.blocks[1:], 1):
|
||||
if block.shortcut is None and len(block_index_list) > 0:
|
||||
info_list.append(block_index_list)
|
||||
block_index_list = []
|
||||
block_index_list.append(i)
|
||||
if len(block_index_list) > 0:
|
||||
info_list.append(block_index_list)
|
||||
return info_list
|
||||
|
||||
@staticmethod
|
||||
def build_net_via_cfg(cfg, input_channel, last_channel, n_classes, dropout_rate):
|
||||
# first conv layer
|
||||
first_conv = ConvLayer(
|
||||
3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='h_swish', ops_order='weight_bn_act'
|
||||
)
|
||||
# build mobile blocks
|
||||
feature_dim = input_channel
|
||||
blocks = []
|
||||
for stage_id, block_config_list in cfg.items():
|
||||
for k, mid_channel, out_channel, use_se, act_func, stride, expand_ratio in block_config_list:
|
||||
mb_conv = MBConvLayer(
|
||||
feature_dim, out_channel, k, stride, expand_ratio, mid_channel, act_func, use_se
|
||||
)
|
||||
if stride == 1 and out_channel == feature_dim:
|
||||
shortcut = IdentityLayer(out_channel, out_channel)
|
||||
else:
|
||||
shortcut = None
|
||||
blocks.append(ResidualBlock(mb_conv, shortcut))
|
||||
feature_dim = out_channel
|
||||
# final expand layer
|
||||
final_expand_layer = ConvLayer(
|
||||
feature_dim, feature_dim * 6, kernel_size=1, use_bn=True, act_func='h_swish', ops_order='weight_bn_act',
|
||||
)
|
||||
# feature mix layer
|
||||
feature_mix_layer = ConvLayer(
|
||||
feature_dim * 6, last_channel, kernel_size=1, bias=False, use_bn=False, act_func='h_swish',
|
||||
)
|
||||
# classifier
|
||||
classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)
|
||||
|
||||
return first_conv, blocks, final_expand_layer, feature_mix_layer, classifier
|
||||
|
||||
    @staticmethod
    def adjust_cfg(cfg, ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
        for i, (stage_id, block_config_list) in enumerate(cfg.items()):
            for block_config in block_config_list:
                if ks is not None and stage_id != '0':
                    block_config[0] = ks
                if expand_ratio is not None and stage_id != '0':
                    block_config[-1] = expand_ratio
                    block_config[1] = None
                if stage_width_list is not None:
                    block_config[2] = stage_width_list[i]
            if depth_param is not None and stage_id != '0':
                new_block_config_list = [block_config_list[0]]
                new_block_config_list += [copy.deepcopy(block_config_list[-1]) for _ in range(depth_param - 1)]
                cfg[stage_id] = new_block_config_list
        return cfg

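    # Each cfg row is [k, exp, c, se, act, stride, expand_ratio]; e.g.
    # adjust_cfg(cfg, ks=5, depth_param=3) (illustrative call) sets every
    # non-stem kernel to 5 and repeats each stage's last block to depth 3.
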
    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert '.mobile_inverted_conv.' in key
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(MobileNetV3, self).load_state_dict(current_state_dict)


class MobileNetV3Large(MobileNetV3):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0.2,
                 ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):
        input_channel = 16
        last_channel = 1280

        input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
            if width_mult > 1.0 else last_channel

        cfg = {
            # k, exp, c, se, nl, s, e,
            '0': [
                [3, 16, 16, False, 'relu', 1, 1],
            ],
            '1': [
                [3, 64, 24, False, 'relu', 2, None],  # 4
                [3, 72, 24, False, 'relu', 1, None],  # 3
            ],
            '2': [
                [5, 72, 40, True, 'relu', 2, None],  # 3
                [5, 120, 40, True, 'relu', 1, None],  # 3
                [5, 120, 40, True, 'relu', 1, None],  # 3
            ],
            '3': [
                [3, 240, 80, False, 'h_swish', 2, None],  # 6
                [3, 200, 80, False, 'h_swish', 1, None],  # 2.5
                [3, 184, 80, False, 'h_swish', 1, None],  # 2.3
                [3, 184, 80, False, 'h_swish', 1, None],  # 2.3
            ],
            '4': [
                [3, 480, 112, True, 'h_swish', 1, None],  # 6
                [3, 672, 112, True, 'h_swish', 1, None],  # 6
            ],
            '5': [
                [5, 672, 160, True, 'h_swish', 2, None],  # 6
                [5, 960, 160, True, 'h_swish', 1, None],  # 6
                [5, 960, 160, True, 'h_swish', 1, None],  # 6
            ]
        }

        cfg = self.adjust_cfg(cfg, ks, expand_ratio, depth_param, stage_width_list)
        # apply the width multiplier on the mobile setting: scale exp (index 1) and c (index 2)
        for stage_id, block_config_list in cfg.items():
            for block_config in block_config_list:
                if block_config[1] is not None:
                    block_config[1] = make_divisible(block_config[1] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
                block_config[2] = make_divisible(block_config[2] * width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        first_conv, blocks, final_expand_layer, feature_mix_layer, classifier = self.build_net_via_cfg(
            cfg, input_channel, last_channel, n_classes, dropout_rate
        )
        super(MobileNetV3Large, self).__init__(first_conv, blocks, final_expand_layer, feature_mix_layer, classifier)
        # set bn param
        self.set_bn_param(*bn_param)

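# Minimal usage sketch (assumes a 3x224x224 ImageNet-style input and that torch
# is imported in the calling scope):
#   net = MobileNetV3Large(n_classes=1000, width_mult=1.0)
#   logits = net(torch.randn(1, 3, 224, 224))  # -> shape (1, 1000)
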
@@ -0,0 +1,210 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch.nn as nn

from ofa_local.utils.layers import set_layer_from_config, MBConvLayer, ConvLayer, IdentityLayer, LinearLayer, ResidualBlock
from ofa_local.utils import download_url, make_divisible, val2list, MyNetwork, MyGlobalAvgPool2d

__all__ = ['proxyless_base', 'ProxylessNASNets', 'MobileNetV2']


def proxyless_base(net_config=None, n_classes=None, bn_param=None, dropout_rate=None,
                   local_path='~/.torch/proxylessnas/'):
    assert net_config is not None, 'Please input a network config'
    if 'http' in net_config:
        net_config_path = download_url(net_config, local_path)
    else:
        net_config_path = net_config
    net_config_json = json.load(open(net_config_path, 'r'))

    if n_classes is not None:
        net_config_json['classifier']['out_features'] = n_classes
    if dropout_rate is not None:
        net_config_json['classifier']['dropout_rate'] = dropout_rate

    net = ProxylessNASNets.build_from_config(net_config_json)
    if bn_param is not None:
        net.set_bn_param(*bn_param)

    return net


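# Usage sketch for proxyless_base (hypothetical config path):
#   net = proxyless_base(net_config='proxyless_gpu.config', n_classes=10, bn_param=(0.1, 1e-3))
# The classifier head and BN parameters are patched before the network is built.
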
class ProxylessNASNets(MyNetwork):

    def __init__(self, first_conv, blocks, feature_mix_layer, classifier):
        super(ProxylessNASNets, self).__init__()

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        if self.feature_mix_layer is not None:
            x = self.feature_mix_layer(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

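    # Forward pipeline: stem conv -> inverted-residual blocks -> optional 1x1
    # feature-mix conv -> global average pool -> linear classifier.
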
    @property
    def module_str(self):
        _str = self.first_conv.module_str + '\n'
        for block in self.blocks:
            _str += block.module_str + '\n'
        _str += self.feature_mix_layer.module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': ProxylessNASNets.__name__,
            'bn': self.get_bn_param(),
            'first_conv': self.first_conv.config,
            'blocks': [
                block.config for block in self.blocks
            ],
            'feature_mix_layer': None if self.feature_mix_layer is None else self.feature_mix_layer.config,
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        first_conv = set_layer_from_config(config['first_conv'])
        feature_mix_layer = set_layer_from_config(config['feature_mix_layer'])
        classifier = set_layer_from_config(config['classifier'])

        blocks = []
        for block_config in config['blocks']:
            blocks.append(ResidualBlock.build_from_config(block_config))

        net = ProxylessNASNets(first_conv, blocks, feature_mix_layer, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-3)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResidualBlock):
                if isinstance(m.conv, MBConvLayer) and isinstance(m.shortcut, IdentityLayer):
                    m.conv.point_linear.bn.weight.data.zero_()

    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks[1:], 1):
            if block.shortcut is None and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        current_state_dict = self.state_dict()

        for key in state_dict:
            if key not in current_state_dict:
                assert '.mobile_inverted_conv.' in key
                new_key = key.replace('.mobile_inverted_conv.', '.conv.')
            else:
                new_key = key
            current_state_dict[new_key] = state_dict[key]
        super(ProxylessNASNets, self).load_state_dict(current_state_dict)


class MobileNetV2(ProxylessNASNets):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-3), dropout_rate=0.2,
                 ks=None, expand_ratio=None, depth_param=None, stage_width_list=None):

        ks = 3 if ks is None else ks
        expand_ratio = 6 if expand_ratio is None else expand_ratio

        input_channel = 32
        last_channel = 1280

        input_channel = make_divisible(input_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        last_channel = make_divisible(last_channel * width_mult, MyNetwork.CHANNEL_DIVISIBLE) \
            if width_mult > 1.0 else last_channel

        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [expand_ratio, 24, 2, 2],
            [expand_ratio, 32, 3, 2],
            [expand_ratio, 64, 4, 2],
            [expand_ratio, 96, 3, 1],
            [expand_ratio, 160, 3, 2],
            [expand_ratio, 320, 1, 1],
        ]

        if depth_param is not None:
            assert isinstance(depth_param, int)
            for i in range(1, len(inverted_residual_setting) - 1):
                inverted_residual_setting[i][2] = depth_param

        if stage_width_list is not None:
            for i in range(len(inverted_residual_setting)):
                inverted_residual_setting[i][1] = stage_width_list[i]

        ks = val2list(ks, sum([n for _, _, n, _ in inverted_residual_setting]) - 1)
        _pt = 0

        # first conv layer
        first_conv = ConvLayer(
            3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act'
        )
        # inverted residual blocks
        blocks = []
        for t, c, n, s in inverted_residual_setting:
            output_channel = make_divisible(c * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
            for i in range(n):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                if t == 1:
                    kernel_size = 3
                else:
                    kernel_size = ks[_pt]
                    _pt += 1
                mobile_inverted_conv = MBConvLayer(
                    in_channels=input_channel, out_channels=output_channel, kernel_size=kernel_size, stride=stride,
                    expand_ratio=t,
                )
                if stride == 1:
                    if input_channel == output_channel:
                        shortcut = IdentityLayer(input_channel, input_channel)
                    else:
                        shortcut = None
                else:
                    shortcut = None
                blocks.append(
                    ResidualBlock(mobile_inverted_conv, shortcut)
                )
                input_channel = output_channel
        # 1x1_conv before global average pooling
        feature_mix_layer = ConvLayer(
            input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act',
        )

        classifier = LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        super(MobileNetV2, self).__init__(first_conv, blocks, feature_mix_layer, classifier)

        # set bn param
        self.set_bn_param(*bn_param)

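# MobileNetV2 usage sketch (illustrative): MobileNetV2(n_classes=1000, width_mult=1.0)
# reproduces the standard t/c/n/s table above; ks, expand_ratio, depth_param and
# stage_width_list let NAS code perturb kernel sizes, widths and depths per stage.
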
@@ -0,0 +1,192 @@
import torch.nn as nn

from ofa_local.utils.layers import set_layer_from_config, ConvLayer, IdentityLayer, LinearLayer
from ofa_local.utils.layers import ResNetBottleneckBlock, ResidualBlock
from ofa_local.utils import make_divisible, MyNetwork, MyGlobalAvgPool2d

__all__ = ['ResNets', 'ResNet50', 'ResNet50D']


class ResNets(MyNetwork):

    BASE_DEPTH_LIST = [2, 2, 4, 2]
    STAGE_WIDTH_LIST = [256, 512, 1024, 2048]

    def __init__(self, input_stem, blocks, classifier):
        super(ResNets, self).__init__()

        self.input_stem = nn.ModuleList(input_stem)
        self.max_pooling = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
        self.blocks = nn.ModuleList(blocks)
        self.global_avg_pool = MyGlobalAvgPool2d(keep_dim=False)
        self.classifier = classifier

    def forward(self, x):
        for layer in self.input_stem:
            x = layer(x)
        x = self.max_pooling(x)
        for block in self.blocks:
            x = block(x)
        x = self.global_avg_pool(x)
        x = self.classifier(x)
        return x

    @property
    def module_str(self):
        _str = ''
        for layer in self.input_stem:
            _str += layer.module_str + '\n'
        _str += 'max_pooling(ks=3, stride=2)\n'
        for block in self.blocks:
            _str += block.module_str + '\n'
        _str += self.global_avg_pool.__repr__() + '\n'
        _str += self.classifier.module_str
        return _str

    @property
    def config(self):
        return {
            'name': ResNets.__name__,
            'bn': self.get_bn_param(),
            'input_stem': [
                layer.config for layer in self.input_stem
            ],
            'blocks': [
                block.config for block in self.blocks
            ],
            'classifier': self.classifier.config,
        }

    @staticmethod
    def build_from_config(config):
        classifier = set_layer_from_config(config['classifier'])

        input_stem = []
        for layer_config in config['input_stem']:
            input_stem.append(set_layer_from_config(layer_config))
        blocks = []
        for block_config in config['blocks']:
            blocks.append(set_layer_from_config(block_config))

        net = ResNets(input_stem, blocks, classifier)
        if 'bn' in config:
            net.set_bn_param(**config['bn'])
        else:
            net.set_bn_param(momentum=0.1, eps=1e-5)

        return net

    def zero_last_gamma(self):
        for m in self.modules():
            if isinstance(m, ResNetBottleneckBlock) and isinstance(m.downsample, IdentityLayer):
                m.conv3.bn.weight.data.zero_()

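    # Zero-initializing the last BN gamma makes each bottleneck start as a near
    # identity mapping (a common trick for deep residual nets); it is applied only
    # where the downsample path is an IdentityLayer, i.e. a true residual add.
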
    @property
    def grouped_block_index(self):
        info_list = []
        block_index_list = []
        for i, block in enumerate(self.blocks):
            if not isinstance(block.downsample, IdentityLayer) and len(block_index_list) > 0:
                info_list.append(block_index_list)
                block_index_list = []
            block_index_list.append(i)
        if len(block_index_list) > 0:
            info_list.append(block_index_list)
        return info_list

    def load_state_dict(self, state_dict, **kwargs):
        super(ResNets, self).load_state_dict(state_dict)


class ResNet50(ResNets):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
                 expand_ratio=None, depth_param=None):

        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [ConvLayer(
            3, input_channel, kernel_size=7, stride=2, use_bn=True, act_func='relu', ops_order='weight_bn_act',
        )]

        # blocks
        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
                    act_func='relu', downsample_mode='conv',
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)


class ResNet50D(ResNets):

    def __init__(self, n_classes=1000, width_mult=1.0, bn_param=(0.1, 1e-5), dropout_rate=0,
                 expand_ratio=None, depth_param=None):

        expand_ratio = 0.25 if expand_ratio is None else expand_ratio

        input_channel = make_divisible(64 * width_mult, MyNetwork.CHANNEL_DIVISIBLE)
        mid_input_channel = make_divisible(input_channel // 2, MyNetwork.CHANNEL_DIVISIBLE)
        stage_width_list = ResNets.STAGE_WIDTH_LIST.copy()
        for i, width in enumerate(stage_width_list):
            stage_width_list[i] = make_divisible(width * width_mult, MyNetwork.CHANNEL_DIVISIBLE)

        depth_list = [3, 4, 6, 3]
        if depth_param is not None:
            for i, depth in enumerate(ResNets.BASE_DEPTH_LIST):
                depth_list[i] = depth + depth_param

        stride_list = [1, 2, 2, 2]

        # build input stem
        input_stem = [
            ConvLayer(3, mid_input_channel, 3, stride=2, use_bn=True, act_func='relu'),
            ResidualBlock(
                ConvLayer(mid_input_channel, mid_input_channel, 3, stride=1, use_bn=True, act_func='relu'),
                IdentityLayer(mid_input_channel, mid_input_channel)
            ),
            ConvLayer(mid_input_channel, input_channel, 3, stride=1, use_bn=True, act_func='relu')
        ]

        # blocks
        blocks = []
        for d, width, s in zip(depth_list, stage_width_list, stride_list):
            for i in range(d):
                stride = s if i == 0 else 1
                bottleneck_block = ResNetBottleneckBlock(
                    input_channel, width, kernel_size=3, stride=stride, expand_ratio=expand_ratio,
                    act_func='relu', downsample_mode='avgpool_conv',
                )
                blocks.append(bottleneck_block)
                input_channel = width
        # classifier
        classifier = LinearLayer(input_channel, n_classes, dropout_rate=dropout_rate)

        super(ResNet50D, self).__init__(input_stem, blocks, classifier)

        # set bn param
        self.set_bn_param(*bn_param)

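# ResNet50D applies the "ResNet-D" tweaks to ResNet50: the 7x7 stem becomes three
# 3x3 convs (with a residual around the middle one), and stage downsampling uses
# 'avgpool_conv' instead of a strided 1x1 'conv'.
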
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .run_config import *
from .run_manager import *
from .distributed_run_manager import *
@@ -0,0 +1,381 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from tqdm import tqdm

from ofa_local.utils import cross_entropy_with_label_smoothing, cross_entropy_loss_with_soft_target, write_log, init_models
from ofa_local.utils import DistributedMetric, list_mean, get_net_info, accuracy, AverageMeter, mix_labels, mix_images
from ofa_local.utils import MyRandomResizedCrop

__all__ = ['DistributedRunManager']


class DistributedRunManager:

    def __init__(self, path, net, run_config, hvd_compression, backward_steps=1, is_root=False, init=True):
        import horovod.torch as hvd

        self.path = path
        self.net = net
        self.run_config = run_config
        self.is_root = is_root

        self.best_acc = 0.0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        self.net.cuda()
        cudnn.benchmark = True
        if init and self.is_root:
            init_models(self.net, self.run_config.model_init)
        if self.is_root:
            # print net info
            net_info = get_net_info(self.net, self.run_config.data_provider.data_shape)
            with open('%s/net_info.txt' % self.path, 'w') as fout:
                fout.write(json.dumps(net_info, indent=4) + '\n')
                try:
                    fout.write(self.net.module_str + '\n')
                except Exception:
                    fout.write('%s does not support `module_str`' % type(self.net))
                fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
                fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
                fout.write('%s\n' % self.net)

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = lambda pred, target: \
                cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.net.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.net.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)
        self.optimizer = hvd.DistributedOptimizer(
            self.optimizer, named_parameters=self.net.named_parameters(), compression=hvd_compression,
            backward_passes_per_step=backward_steps,
        )

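    # With `no_decay_keys` (e.g. 'bn#bias'), parameters are split into two groups so
    # weight decay can skip BN and bias terms; the optimizer is then wrapped in
    # Horovod's DistributedOptimizer, which averages gradients across workers.
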
    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net

    @network.setter
    def network(self, new_val):
        self.net = new_val

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        if self.is_root:
            write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save & load model & save_config & broadcast """

    def save_config(self, extra_run_config=None, extra_net_config=None):
        if self.is_root:
            run_save_path = os.path.join(self.path, 'run.config')
            if not os.path.isfile(run_save_path):
                run_config = self.run_config.config
                if extra_run_config is not None:
                    run_config.update(extra_run_config)
                json.dump(run_config, open(run_save_path, 'w'), indent=4)
                print('Run configs dumped to %s' % run_save_path)

            try:
                net_save_path = os.path.join(self.path, 'net.config')
                net_config = self.net.config
                if extra_net_config is not None:
                    net_config.update(extra_net_config)
                json.dump(net_config, open(net_save_path, 'w'), indent=4)
                print('Network configs dumped to %s' % net_save_path)
            except Exception:
                print('%s does not support net config' % type(self.net))

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if self.is_root:
            if checkpoint is None:
                checkpoint = {'state_dict': self.net.state_dict()}

            if model_name is None:
                model_name = 'checkpoint.pth.tar'

            latest_fname = os.path.join(self.save_path, 'latest.txt')
            model_path = os.path.join(self.save_path, model_name)
            with open(latest_fname, 'w') as _fout:
                _fout.write(model_path + '\n')
            torch.save(checkpoint, model_path)

            if is_best:
                best_path = os.path.join(self.save_path, 'model_best.pth.tar')
                torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        if self.is_root:
            latest_fname = os.path.join(self.save_path, 'latest.txt')
            if model_fname is None and os.path.exists(latest_fname):
                with open(latest_fname, 'r') as fin:
                    model_fname = fin.readline()
                    if model_fname[-1] == '\n':
                        model_fname = model_fname[:-1]
            # noinspection PyBroadException
            try:
                if model_fname is None or not os.path.exists(model_fname):
                    model_fname = '%s/checkpoint.pth.tar' % self.save_path
                    with open(latest_fname, 'w') as fout:
                        fout.write(model_fname + '\n')
                print("=> loading checkpoint '{}'".format(model_fname))
                checkpoint = torch.load(model_fname, map_location='cpu')
            except Exception:
                self.write_log('failed to load checkpoint from %s' % self.save_path, 'valid')
                return

            self.net.load_state_dict(checkpoint['state_dict'])
            if 'epoch' in checkpoint:
                self.start_epoch = checkpoint['epoch'] + 1
            if 'best_acc' in checkpoint:
                self.best_acc = checkpoint['best_acc']
            if 'optimizer' in checkpoint:
                self.optimizer.load_state_dict(checkpoint['optimizer'])

            self.write_log("=> loaded checkpoint '{}'".format(model_fname), 'valid')

    # noinspection PyArgumentList
    def broadcast(self):
        import horovod.torch as hvd
        self.start_epoch = hvd.broadcast(torch.LongTensor(1).fill_(self.start_epoch)[0], 0, name='start_epoch').item()
        self.best_acc = hvd.broadcast(torch.Tensor(1).fill_(self.best_acc)[0], 0, name='best_acc').item()
        hvd.broadcast_parameters(self.net.state_dict(), 0)
        hvd.broadcast_optimizer_state(self.optimizer, 0)

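    # `broadcast` pushes rank 0's epoch counter, best accuracy, model weights and
    # optimizer state to all workers, so every rank resumes from identical state.
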
    """ metric related """

    def get_metric_dict(self):
        return {
            'top1': DistributedMetric('top1'),
            'top5': DistributedMetric('top5'),
        }

    def update_metric(self, metric_dict, output, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1[0], output.size(0))
        metric_dict['top5'].update(acc5[0], output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {
                key: metric_dict[key].avg.item() for key in metric_dict
            }
        else:
            return [metric_dict[key].avg.item() for key in metric_dict]

    def get_metric_names(self):
        return 'top1', 'top5'

    """ train & validate """

    def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False):
        if net is None:
            net = self.net
        if data_loader is None:
            if is_test:
                data_loader = self.run_config.test_loader
            else:
                data_loader = self.run_config.valid_loader

        net.eval()

        losses = DistributedMetric('val_loss')
        metric_dict = self.get_metric_dict()

        with torch.no_grad():
            with tqdm(total=len(data_loader),
                      desc='Validate Epoch #{} {}'.format(epoch + 1, run_str),
                      disable=no_logs or not self.is_root) as t:
                for i, (images, labels) in enumerate(data_loader):
                    images, labels = images.cuda(), labels.cuda()
                    # compute output
                    output = net(images)
                    loss = self.test_criterion(output, labels)
                    # measure accuracy and record loss
                    losses.update(loss, images.size(0))
                    self.update_metric(metric_dict, output, labels)
                    t.set_postfix({
                        'loss': losses.avg.item(),
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        'img_size': images.size(2),
                    })
                    t.update(1)
        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.net
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5) = self.validate(epoch, is_test, net=net)
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
            return img_size_list, loss_list, top1_list, top5_list
        else:
            loss, (top1, top5) = self.validate(epoch, is_test, net=net)
            return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]

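    # In train_one_epoch below, mixup seeds `random` from (batch, epoch) so every
    # worker draws the same lambda, and the distillation branch adds
    # args.kd_ratio * kd_loss on top of the regular cross-entropy term.
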
    def train_one_epoch(self, args, epoch, warmup_epochs=5, warmup_lr=0):
        self.net.train()
        self.run_config.train_loader.sampler.set_epoch(epoch)  # required by distributed sampler
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = DistributedMetric('train_loss')
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(total=nBatch,
                  desc='Train Epoch #{}'.format(epoch + 1),
                  disable=not self.is_root) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)

                images, labels = images.cuda(), labels.cuda()
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # transform data
                    random.seed(int('%d%.3d' % (i, epoch)))
                    lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
                    )

                # soft target
                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)

                if args.teacher_model is None:
                    loss = self.train_criterion(output, labels)
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + self.train_criterion(output, labels)
                    loss_type = '%.1fkd+ce' % args.kd_ratio

                # update
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss, images.size(0))
                self.update_metric(metric_dict, output, target)

                t.set_postfix({
                    'loss': losses.avg.item(),
                    **self.get_metric_vals(metric_dict, return_dict=True),
                    'img_size': images.size(2),
                    'lr': new_lr,
                    'loss_type': loss_type,
                    'data_time': data_time.avg,
                })
                t.update(1)
                end = time.time()

        return losses.avg.item(), self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epochs=5, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epochs):
            train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epochs, warmup_lr)
            img_size, val_loss, val_top1, val_top5 = self.validate_all_resolution(epoch, is_test=False)

            is_best = list_mean(val_top1) > self.best_acc
            self.best_acc = max(self.best_acc, list_mean(val_top1))
            if self.is_root:
                val_log = '[{0}/{1}]\tloss {2:.3f}\t{6} acc {3:.3f} ({4:.3f})\t{7} acc {5:.3f}\t' \
                          'Train {6} {top1:.3f}\tloss {train_loss:.3f}\t'. \
                    format(epoch + 1 - warmup_epochs, self.run_config.n_epochs, list_mean(val_loss),
                           list_mean(val_top1), self.best_acc, list_mean(val_top5), *self.get_metric_names(),
                           top1=train_top1, train_loss=train_loss)
                for i_s, v_a in zip(img_size, val_top1):
                    val_log += '(%d, %.3f), ' % (i_s, v_a)
                self.write_log(val_log, prefix='valid', should_print=False)

            self.save_model({
                'epoch': epoch,
                'best_acc': self.best_acc,
                'optimizer': self.optimizer.state_dict(),
                'state_dict': self.net.state_dict(),
            }, is_best=is_best)

    def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
        from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
        if net is None:
            net = self.net
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
        set_running_statistics(net, data_loader)
@@ -0,0 +1,161 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from ofa_local.utils import calc_learning_rate, build_optimizer
from ofa_local.imagenet_classification.data_providers import ImagenetDataProvider

__all__ = ['RunConfig', 'ImagenetRunConfig', 'DistributedImageNetRunConfig']


class RunConfig:

    def __init__(self, n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
                 dataset, train_batch_size, test_batch_size, valid_size,
                 opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
                 mixup_alpha, model_init, validation_frequency, print_frequency):
        self.n_epochs = n_epochs
        self.init_lr = init_lr
        self.lr_schedule_type = lr_schedule_type
        self.lr_schedule_param = lr_schedule_param

        self.dataset = dataset
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size
        self.valid_size = valid_size

        self.opt_type = opt_type
        self.opt_param = opt_param
        self.weight_decay = weight_decay
        self.label_smoothing = label_smoothing
        self.no_decay_keys = no_decay_keys

        self.mixup_alpha = mixup_alpha

        self.model_init = model_init
        self.validation_frequency = validation_frequency
        self.print_frequency = print_frequency

    @property
    def config(self):
        config = {}
        for key in self.__dict__:
            if not key.startswith('_'):
                config[key] = self.__dict__[key]
        return config

    def copy(self):
        return RunConfig(**self.config)

    """ learning rate """

    def adjust_learning_rate(self, optimizer, epoch, batch=0, nBatch=None):
        """ adjust learning rate of a given optimizer and return the new learning rate """
        new_lr = calc_learning_rate(epoch, self.init_lr, self.n_epochs, batch, nBatch, self.lr_schedule_type)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr

    def warmup_adjust_learning_rate(self, optimizer, T_total, nBatch, epoch, batch=0, warmup_lr=0):
        T_cur = epoch * nBatch + batch + 1
        new_lr = T_cur / T_total * (self.init_lr - warmup_lr) + warmup_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
        return new_lr

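    # Worked example of the linear warmup above: with init_lr=0.05, warmup_lr=0,
    # nBatch=100 and T_total=500 (5 warmup epochs), epoch 0 / batch 49 gives
    # T_cur = 50 and new_lr = 50 / 500 * 0.05 = 0.005.
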
    """ data provider """

    @property
    def data_provider(self):
        raise NotImplementedError

    @property
    def train_loader(self):
        return self.data_provider.train

    @property
    def valid_loader(self):
        return self.data_provider.valid

    @property
    def test_loader(self):
        return self.data_provider.test

    def random_sub_train_loader(self, n_images, batch_size, num_worker=None, num_replicas=None, rank=None):
        return self.data_provider.build_sub_train_loader(n_images, batch_size, num_worker, num_replicas, rank)

    """ optimizer """

    def build_optimizer(self, net_params):
        return build_optimizer(net_params,
                               self.opt_type, self.opt_param, self.init_lr, self.weight_decay, self.no_decay_keys)


class ImagenetRunConfig(RunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=256, test_batch_size=500, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=32, resize_scale=0.08, distort_color='tf', image_size=224, **kwargs):
        super(ImagenetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha,
            model_init, validation_frequency, print_frequency
        )

        self.n_worker = n_worker
        self.resize_scale = resize_scale
        self.distort_color = distort_color
        self.image_size = image_size

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
            )
        return self.__dict__['_data_provider']


class DistributedImageNetRunConfig(ImagenetRunConfig):

    def __init__(self, n_epochs=150, init_lr=0.05, lr_schedule_type='cosine', lr_schedule_param=None,
                 dataset='imagenet', train_batch_size=64, test_batch_size=64, valid_size=None,
                 opt_type='sgd', opt_param=None, weight_decay=4e-5, label_smoothing=0.1, no_decay_keys=None,
                 mixup_alpha=None, model_init='he_fout', validation_frequency=1, print_frequency=10,
                 n_worker=8, resize_scale=0.08, distort_color='tf', image_size=224,
                 **kwargs):
        super(DistributedImageNetRunConfig, self).__init__(
            n_epochs, init_lr, lr_schedule_type, lr_schedule_param,
            dataset, train_batch_size, test_batch_size, valid_size,
            opt_type, opt_param, weight_decay, label_smoothing, no_decay_keys,
            mixup_alpha, model_init, validation_frequency, print_frequency, n_worker, resize_scale, distort_color,
            image_size, **kwargs
        )

        self._num_replicas = kwargs['num_replicas']
        self._rank = kwargs['rank']

    @property
    def data_provider(self):
        if self.__dict__.get('_data_provider', None) is None:
            if self.dataset == ImagenetDataProvider.name():
                DataProviderClass = ImagenetDataProvider
            else:
                raise NotImplementedError
            self.__dict__['_data_provider'] = DataProviderClass(
                train_batch_size=self.train_batch_size, test_batch_size=self.test_batch_size,
                valid_size=self.valid_size, n_worker=self.n_worker, resize_scale=self.resize_scale,
                distort_color=self.distort_color, image_size=self.image_size,
                num_replicas=self._num_replicas, rank=self._rank,
            )
        return self.__dict__['_data_provider']
@@ -0,0 +1,375 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import random
import time
import json
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
from tqdm import tqdm

from ofa_local.utils import get_net_info, cross_entropy_loss_with_soft_target, cross_entropy_with_label_smoothing
from ofa_local.utils import AverageMeter, accuracy, write_log, mix_images, mix_labels, init_models
from ofa_local.utils import MyRandomResizedCrop

__all__ = ['RunManager']


class RunManager:

    def __init__(self, path, net, run_config, init=True, measure_latency=None, no_gpu=False):
        self.path = path
        self.net = net
        self.run_config = run_config

        self.best_acc = 0
        self.start_epoch = 0

        os.makedirs(self.path, exist_ok=True)

        # move network to GPU if available
        if torch.cuda.is_available() and (not no_gpu):
            self.device = torch.device('cuda:0')
            self.net = self.net.to(self.device)
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')
        # initialize model (default); init_models expects the network as its first argument
        if init:
            init_models(self.net, run_config.model_init)

        # net info
        net_info = get_net_info(self.net, self.run_config.data_provider.data_shape, measure_latency, True)
        with open('%s/net_info.txt' % self.path, 'w') as fout:
            fout.write(json.dumps(net_info, indent=4) + '\n')
            # noinspection PyBroadException
            try:
                fout.write(self.network.module_str + '\n')
            except Exception:
                pass
            fout.write('%s\n' % self.run_config.data_provider.train.dataset.transform)
            fout.write('%s\n' % self.run_config.data_provider.test.dataset.transform)
            fout.write('%s\n' % self.network)

        # criterion
        if isinstance(self.run_config.mixup_alpha, float):
            self.train_criterion = cross_entropy_loss_with_soft_target
        elif self.run_config.label_smoothing > 0:
            self.train_criterion = \
                lambda pred, target: cross_entropy_with_label_smoothing(pred, target, self.run_config.label_smoothing)
        else:
            self.train_criterion = nn.CrossEntropyLoss()
        self.test_criterion = nn.CrossEntropyLoss()

        # optimizer
        if self.run_config.no_decay_keys:
            keys = self.run_config.no_decay_keys.split('#')
            net_params = [
                self.network.get_parameters(keys, mode='exclude'),  # parameters with weight decay
                self.network.get_parameters(keys, mode='include'),  # parameters without weight decay
            ]
        else:
            # noinspection PyBroadException
            try:
                net_params = self.network.weight_parameters()
            except Exception:
                net_params = []
                for param in self.network.parameters():
                    if param.requires_grad:
                        net_params.append(param)
        self.optimizer = self.run_config.build_optimizer(net_params)

        self.net = torch.nn.DataParallel(self.net)

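    # The model is wrapped in nn.DataParallel above; the `network` property below
    # returns the unwrapped module, so state dicts are saved without the
    # 'module.' key prefix.
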
    """ save path and log path """

    @property
    def save_path(self):
        if self.__dict__.get('_save_path', None) is None:
            save_path = os.path.join(self.path, 'checkpoint')
            os.makedirs(save_path, exist_ok=True)
            self.__dict__['_save_path'] = save_path
        return self.__dict__['_save_path']

    @property
    def logs_path(self):
        if self.__dict__.get('_logs_path', None) is None:
            logs_path = os.path.join(self.path, 'logs')
            os.makedirs(logs_path, exist_ok=True)
            self.__dict__['_logs_path'] = logs_path
        return self.__dict__['_logs_path']

    @property
    def network(self):
        return self.net.module if isinstance(self.net, nn.DataParallel) else self.net

    def write_log(self, log_str, prefix='valid', should_print=True, mode='a'):
        write_log(self.logs_path, log_str, prefix, should_print, mode)

    """ save and load models """

    def save_model(self, checkpoint=None, is_best=False, model_name=None):
        if checkpoint is None:
            checkpoint = {'state_dict': self.network.state_dict()}

        if model_name is None:
            model_name = 'checkpoint.pth.tar'

        checkpoint['dataset'] = self.run_config.dataset  # add `dataset` info to the checkpoint
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        model_path = os.path.join(self.save_path, model_name)
        with open(latest_fname, 'w') as fout:
            fout.write(model_path + '\n')
        torch.save(checkpoint, model_path)

        if is_best:
            best_path = os.path.join(self.save_path, 'model_best.pth.tar')
            torch.save({'state_dict': checkpoint['state_dict']}, best_path)

    def load_model(self, model_fname=None):
        latest_fname = os.path.join(self.save_path, 'latest.txt')
        if model_fname is None and os.path.exists(latest_fname):
            with open(latest_fname, 'r') as fin:
                model_fname = fin.readline()
                if model_fname[-1] == '\n':
                    model_fname = model_fname[:-1]
        # noinspection PyBroadException
        try:
            if model_fname is None or not os.path.exists(model_fname):
                model_fname = '%s/checkpoint.pth.tar' % self.save_path
                with open(latest_fname, 'w') as fout:
                    fout.write(model_fname + '\n')
            print("=> loading checkpoint '{}'".format(model_fname))
            checkpoint = torch.load(model_fname, map_location='cpu')
        except Exception:
            print('failed to load checkpoint from %s' % self.save_path)
            return {}

        self.network.load_state_dict(checkpoint['state_dict'])
        if 'epoch' in checkpoint:
            self.start_epoch = checkpoint['epoch'] + 1
        if 'best_acc' in checkpoint:
            self.best_acc = checkpoint['best_acc']
        if 'optimizer' in checkpoint:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        print("=> loaded checkpoint '{}'".format(model_fname))
        return checkpoint

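    # Checkpoint bookkeeping: 'latest.txt' records the path of the most recent
    # checkpoint, so load_model can resume without an explicit filename and falls
    # back to 'checkpoint.pth.tar' when the recorded path is missing.
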
    def save_config(self, extra_run_config=None, extra_net_config=None):
        """ dump run_config and net_config to the model_folder """
        run_save_path = os.path.join(self.path, 'run.config')
        if not os.path.isfile(run_save_path):
            run_config = self.run_config.config
            if extra_run_config is not None:
                run_config.update(extra_run_config)
            json.dump(run_config, open(run_save_path, 'w'), indent=4)
            print('Run configs dumped to %s' % run_save_path)

        try:
            net_save_path = os.path.join(self.path, 'net.config')
            net_config = self.network.config
            if extra_net_config is not None:
                net_config.update(extra_net_config)
            json.dump(net_config, open(net_save_path, 'w'), indent=4)
            print('Network configs dumped to %s' % net_save_path)
        except Exception:
            print('%s does not support net config' % type(self.network))

    """ metric related """

    def get_metric_dict(self):
        return {
            'top1': AverageMeter(),
            'top5': AverageMeter(),
        }

    def update_metric(self, metric_dict, output, labels):
        acc1, acc5 = accuracy(output, labels, topk=(1, 5))
        metric_dict['top1'].update(acc1[0].item(), output.size(0))
        metric_dict['top5'].update(acc5[0].item(), output.size(0))

    def get_metric_vals(self, metric_dict, return_dict=False):
        if return_dict:
            return {
                key: metric_dict[key].avg for key in metric_dict
            }
        else:
            return [metric_dict[key].avg for key in metric_dict]

    def get_metric_names(self):
        return 'top1', 'top5'

    """ train and test """

    def validate(self, epoch=0, is_test=False, run_str='', net=None, data_loader=None, no_logs=False, train_mode=False):
        if net is None:
            net = self.net
        if not isinstance(net, nn.DataParallel):
            net = nn.DataParallel(net)

        if data_loader is None:
            data_loader = self.run_config.test_loader if is_test else self.run_config.valid_loader

        if train_mode:
            net.train()
        else:
            net.eval()

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()

        with torch.no_grad():
            with tqdm(total=len(data_loader),
                      desc='Validate Epoch #{} {}'.format(epoch + 1, run_str), disable=no_logs) as t:
                for i, (images, labels) in enumerate(data_loader):
                    images, labels = images.to(self.device), labels.to(self.device)
                    # compute output
                    output = net(images)
                    loss = self.test_criterion(output, labels)
                    # measure accuracy and record loss
                    self.update_metric(metric_dict, output, labels)

                    losses.update(loss.item(), images.size(0))
                    t.set_postfix({
                        'loss': losses.avg,
                        **self.get_metric_vals(metric_dict, return_dict=True),
                        'img_size': images.size(2),
                    })
                    t.update(1)
        return losses.avg, self.get_metric_vals(metric_dict)

    def validate_all_resolution(self, epoch=0, is_test=False, net=None):
        if net is None:
            net = self.network
        if isinstance(self.run_config.data_provider.image_size, list):
            img_size_list, loss_list, top1_list, top5_list = [], [], [], []
            for img_size in self.run_config.data_provider.image_size:
                img_size_list.append(img_size)
                self.run_config.data_provider.assign_active_img_size(img_size)
                self.reset_running_statistics(net=net)
                loss, (top1, top5) = self.validate(epoch, is_test, net=net)
                loss_list.append(loss)
                top1_list.append(top1)
                top5_list.append(top5)
            return img_size_list, loss_list, top1_list, top5_list
        else:
            loss, (top1, top5) = self.validate(epoch, is_test, net=net)
            return [self.run_config.data_provider.active_img_size], [loss], [top1], [top5]

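    # In train_one_epoch below, distillation (when args.teacher_model is set)
    # combines losses as loss = args.kd_ratio * kd_loss + ce_loss, where kd_loss
    # is soft-target cross-entropy ('ce') or MSE on the teacher logits.
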
    def train_one_epoch(self, args, epoch, warmup_epochs=0, warmup_lr=0):
        # switch to train mode
        self.net.train()
        MyRandomResizedCrop.EPOCH = epoch  # required by elastic resolution

        nBatch = len(self.run_config.train_loader)

        losses = AverageMeter()
        metric_dict = self.get_metric_dict()
        data_time = AverageMeter()

        with tqdm(total=nBatch,
                  desc='{} Train Epoch #{}'.format(self.run_config.dataset, epoch + 1)) as t:
            end = time.time()
            for i, (images, labels) in enumerate(self.run_config.train_loader):
                MyRandomResizedCrop.BATCH = i
                data_time.update(time.time() - end)
                if epoch < warmup_epochs:
                    new_lr = self.run_config.warmup_adjust_learning_rate(
                        self.optimizer, warmup_epochs * nBatch, nBatch, epoch, i, warmup_lr,
                    )
                else:
                    new_lr = self.run_config.adjust_learning_rate(self.optimizer, epoch - warmup_epochs, i, nBatch)

                images, labels = images.to(self.device), labels.to(self.device)
                target = labels
                if isinstance(self.run_config.mixup_alpha, float):
                    # transform data
                    lam = random.betavariate(self.run_config.mixup_alpha, self.run_config.mixup_alpha)
                    images = mix_images(images, lam)
                    labels = mix_labels(
                        labels, lam, self.run_config.data_provider.n_classes, self.run_config.label_smoothing
                    )

                # soft target
                if args.teacher_model is not None:
                    args.teacher_model.train()
                    with torch.no_grad():
                        soft_logits = args.teacher_model(images).detach()
                        soft_label = F.softmax(soft_logits, dim=1)

                # compute output
                output = self.net(images)
                loss = self.train_criterion(output, labels)

                if args.teacher_model is None:
                    loss_type = 'ce'
                else:
                    if args.kd_type == 'ce':
                        kd_loss = cross_entropy_loss_with_soft_target(output, soft_label)
                    else:
                        kd_loss = F.mse_loss(output, soft_logits)
                    loss = args.kd_ratio * kd_loss + loss
                    loss_type = '%.1fkd+ce' % args.kd_ratio

                # compute gradient and do SGD step
                self.net.zero_grad()  # or self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # measure accuracy and record loss
                losses.update(loss.item(), images.size(0))
                self.update_metric(metric_dict, output, target)

                t.set_postfix({
                    'loss': losses.avg,
                    **self.get_metric_vals(metric_dict, return_dict=True),
                    'img_size': images.size(2),
                    'lr': new_lr,
                    'loss_type': loss_type,
                    'data_time': data_time.avg,
                })
                t.update(1)
                end = time.time()
        return losses.avg, self.get_metric_vals(metric_dict)

    def train(self, args, warmup_epoch=0, warmup_lr=0):
        for epoch in range(self.start_epoch, self.run_config.n_epochs + warmup_epoch):
            train_loss, (train_top1, train_top5) = self.train_one_epoch(args, epoch, warmup_epoch, warmup_lr)

            if (epoch + 1) % self.run_config.validation_frequency == 0:
                img_size, val_loss, val_acc, val_acc5 = self.validate_all_resolution(epoch=epoch, is_test=False)

                is_best = np.mean(val_acc) > self.best_acc
                self.best_acc = max(self.best_acc, np.mean(val_acc))
                val_log = 'Valid [{0}/{1}]\tloss {2:.3f}\t{5} {3:.3f} ({4:.3f})'. \
                    format(epoch + 1 - warmup_epoch, self.run_config.n_epochs,
                           np.mean(val_loss), np.mean(val_acc), self.best_acc, self.get_metric_names()[0])
                val_log += '\t{2} {0:.3f}\tTrain {1} {top1:.3f}\tloss {train_loss:.3f}\t'. \
                    format(np.mean(val_acc5), *self.get_metric_names(), top1=train_top1, train_loss=train_loss)
                for i_s, v_a in zip(img_size, val_acc):
                    val_log += '(%d, %.3f), ' % (i_s, v_a)
                self.write_log(val_log, prefix='valid', should_print=False)
            else:
                is_best = False

            self.save_model({
                'epoch': epoch,
                'best_acc': self.best_acc,
                'optimizer': self.optimizer.state_dict(),
                'state_dict': self.network.state_dict(),
            }, is_best=is_best)

    def reset_running_statistics(self, net=None, subset_size=2000, subset_batch_size=200, data_loader=None):
        from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics
        if net is None:
            net = self.network
        if data_loader is None:
            data_loader = self.run_config.random_sub_train_loader(subset_size, subset_batch_size)
        set_running_statistics(net, data_loader)
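# `reset_running_statistics` recalibrates BN running mean/var on a small random
# training subset; OFA sub-networks share weights, so inherited BN statistics do
# not match each sub-network's activations until recalibrated.
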
@@ -0,0 +1,87 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import json
import torch

from ofa_local.utils import download_url
from ofa_local.imagenet_classification.networks import get_net_by_name, proxyless_base
from ofa_local.imagenet_classification.elastic_nn.networks import OFAMobileNetV3, OFAProxylessNASNets, OFAResNets

__all__ = [
    'ofa_specialized', 'ofa_net',
    'proxylessnas_net', 'proxylessnas_mobile', 'proxylessnas_cpu', 'proxylessnas_gpu',
]


def ofa_specialized(net_id, pretrained=True):
    url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_specialized/'
    net_config = json.load(open(
        download_url(url_base + net_id + '/net.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
    ))
    net = get_net_by_name(net_config['name']).build_from_config(net_config)

    image_size = json.load(open(
        download_url(url_base + net_id + '/run.config', model_dir='.torch/ofa_specialized/%s/' % net_id)
    ))['image_size']

    if pretrained:
        init = torch.load(
            download_url(url_base + net_id + '/init', model_dir='.torch/ofa_specialized/%s/' % net_id),
            map_location='cpu'
        )['state_dict']
        net.load_state_dict(init)
    return net, image_size


def ofa_net(net_id, pretrained=True):
    if net_id == 'ofa_proxyless_d234_e346_k357_w1.3':
        net = OFAProxylessNASNets(
            dropout_rate=0, width_mult=1.3, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
        )
    elif net_id == 'ofa_mbv3_d234_e346_k357_w1.0':
        net = OFAMobileNetV3(
            dropout_rate=0, width_mult=1.0, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
        )
    elif net_id == 'ofa_mbv3_d234_e346_k357_w1.2':
        net = OFAMobileNetV3(
            dropout_rate=0, width_mult=1.2, ks_list=[3, 5, 7], expand_ratio_list=[3, 4, 6], depth_list=[2, 3, 4],
        )
    elif net_id == 'ofa_resnet50':
        net = OFAResNets(
            dropout_rate=0, depth_list=[0, 1, 2], expand_ratio_list=[0.2, 0.25, 0.35], width_mult_list=[0.65, 0.8, 1.0]
        )
        net_id = 'ofa_resnet50_d=0+1+2_e=0.2+0.25+0.35_w=0.65+0.8+1.0'
    else:
        raise ValueError('Not supported: %s' % net_id)

    if pretrained:
        url_base = 'https://hanlab.mit.edu/files/OnceForAll/ofa_nets/'
        init = torch.load(
            download_url(url_base + net_id, model_dir='.torch/ofa_nets'),
            map_location='cpu')['state_dict']
        net.load_state_dict(init)
    return net


def proxylessnas_net(net_id, pretrained=True):
|
||||
net = proxyless_base(
|
||||
net_config='https://hanlab.mit.edu/files/proxylessNAS/%s.config' % net_id,
|
||||
)
|
||||
if pretrained:
|
||||
net.load_state_dict(torch.load(
|
||||
download_url('https://hanlab.mit.edu/files/proxylessNAS/%s.pth' % net_id), map_location='cpu'
|
||||
)['state_dict'])
|
||||
|
||||
|
||||
def proxylessnas_mobile(pretrained=True):
|
||||
return proxylessnas_net('proxyless_mobile', pretrained)
|
||||
|
||||
|
||||
def proxylessnas_cpu(pretrained=True):
|
||||
return proxylessnas_net('proxyless_cpu', pretrained)
|
||||
|
||||
|
||||
def proxylessnas_gpu(pretrained=True):
|
||||
return proxylessnas_net('proxyless_gpu', pretrained)
|
||||
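A hedged usage sketch of the loaders above. The specialized net ID string is a placeholder, and `set_active_subnet` / `get_active_subnet` are assumed from the elastic-network API rather than defined in this file.

# Load a pretrained OFA super-network by ID, then pick one subnet out of it.
ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
ofa_network.set_active_subnet(ks=7, e=6, d=4)                 # assumed API
subnet = ofa_network.get_active_subnet(preserve_weight=True)  # assumed API

# A specialized, already-finetuned subnet plus its intended input resolution:
net, image_size = ofa_specialized('<specialized_net_id>', pretrained=True)  # placeholder ID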
@@ -0,0 +1,7 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .acc_dataset import *
from .acc_predictor import *
from .arch_encoder import *
@@ -0,0 +1,181 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import json
import numpy as np
from tqdm import tqdm
import torch
import torch.utils.data

from ofa.utils import list_mean

__all__ = ['net_setting2id', 'net_id2setting', 'AccuracyDataset']


def net_setting2id(net_setting):
    return json.dumps(net_setting)


def net_id2setting(net_id):
    return json.loads(net_id)


class RegDataset(torch.utils.data.Dataset):

    def __init__(self, inputs, targets):
        super(RegDataset, self).__init__()
        self.inputs = inputs
        self.targets = targets

    def __getitem__(self, index):
        return self.inputs[index], self.targets[index]

    def __len__(self):
        return self.inputs.size(0)


class AccuracyDataset:

    def __init__(self, path):
        self.path = path
        os.makedirs(self.path, exist_ok=True)

    @property
    def net_id_path(self):
        return os.path.join(self.path, 'net_id.dict')

    @property
    def acc_src_folder(self):
        return os.path.join(self.path, 'src')

    @property
    def acc_dict_path(self):
        return os.path.join(self.path, 'acc.dict')

    # TODO: support parallel building
    def build_acc_dataset(self, run_manager, ofa_network, n_arch=1000, image_size_list=None):
        # load net_id_list; randomly sample architectures if it does not exist yet
        if os.path.isfile(self.net_id_path):
            net_id_list = json.load(open(self.net_id_path))
        else:
            net_id_list = set()
            while len(net_id_list) < n_arch:
                net_setting = ofa_network.sample_active_subnet()
                net_id = net_setting2id(net_setting)
                net_id_list.add(net_id)
            net_id_list = list(net_id_list)
            net_id_list.sort()
            json.dump(net_id_list, open(self.net_id_path, 'w'), indent=4)

        image_size_list = [128, 160, 192, 224] if image_size_list is None else image_size_list

        with tqdm(total=len(net_id_list) * len(image_size_list), desc='Building Acc Dataset') as t:
            for image_size in image_size_list:
                # load the val dataset into memory
                val_dataset = []
                run_manager.run_config.data_provider.assign_active_img_size(image_size)
                for images, labels in run_manager.run_config.valid_loader:
                    val_dataset.append((images, labels))
                # save path
                os.makedirs(self.acc_src_folder, exist_ok=True)
                acc_save_path = os.path.join(self.acc_src_folder, '%d.dict' % image_size)
                acc_dict = {}
                # load existing acc dict
                if os.path.isfile(acc_save_path):
                    existing_acc_dict = json.load(open(acc_save_path, 'r'))
                else:
                    existing_acc_dict = {}
                for net_id in net_id_list:
                    net_setting = net_id2setting(net_id)
                    key = net_setting2id({**net_setting, 'image_size': image_size})
                    if key in existing_acc_dict:
                        acc_dict[key] = existing_acc_dict[key]
                        t.set_postfix({
                            'net_id': net_id,
                            'image_size': image_size,
                            'info_val': acc_dict[key],
                            'status': 'loading',
                        })
                        t.update()
                        continue
                    ofa_network.set_active_subnet(**net_setting)
                    run_manager.reset_running_statistics(ofa_network)
                    net_setting_str = ','.join(['%s_%s' % (
                        key, '%.1f' % list_mean(val) if isinstance(val, list) else val
                    ) for key, val in net_setting.items()])
                    loss, (top1, top5) = run_manager.validate(
                        run_str=net_setting_str, net=ofa_network, data_loader=val_dataset, no_logs=True,
                    )
                    info_val = top1

                    t.set_postfix({
                        'net_id': net_id,
                        'image_size': image_size,
                        'info_val': info_val,
                    })
                    t.update()

                    acc_dict.update({
                        key: info_val
                    })
                    json.dump(acc_dict, open(acc_save_path, 'w'), indent=4)

    def merge_acc_dataset(self, image_size_list=None):
        # load existing data
        merged_acc_dict = {}
        for fname in os.listdir(self.acc_src_folder):
            if '.dict' not in fname:
                continue
            image_size = int(fname.split('.dict')[0])
            if image_size_list is not None and image_size not in image_size_list:
                print('Skip ', fname)
                continue
            full_path = os.path.join(self.acc_src_folder, fname)
            partial_acc_dict = json.load(open(full_path))
            merged_acc_dict.update(partial_acc_dict)
            print('loaded %s' % full_path)
        json.dump(merged_acc_dict, open(self.acc_dict_path, 'w'), indent=4)
        return merged_acc_dict

    def build_acc_data_loader(self, arch_encoder, n_training_sample=None, batch_size=256, n_workers=16):
        # load data
        acc_dict = json.load(open(self.acc_dict_path))
        X_all = []
        Y_all = []
        with tqdm(total=len(acc_dict), desc='Loading data') as t:
            for k, v in acc_dict.items():
                dic = json.loads(k)
                X_all.append(arch_encoder.arch2feature(dic))
                Y_all.append(v / 100.)  # range: 0 - 1
                t.update()
        base_acc = np.mean(Y_all)
        # convert to torch tensors
        X_all = torch.tensor(X_all, dtype=torch.float)
        Y_all = torch.tensor(Y_all)

        # random shuffle
        shuffle_idx = torch.randperm(len(X_all))
        X_all = X_all[shuffle_idx]
        Y_all = Y_all[shuffle_idx]

        # split data
        idx = X_all.size(0) // 5 * 4 if n_training_sample is None else n_training_sample
        val_idx = X_all.size(0) // 5 * 4
        X_train, Y_train = X_all[:idx], Y_all[:idx]
        X_test, Y_test = X_all[val_idx:], Y_all[val_idx:]
        print('Train Size: %d,' % len(X_train), 'Valid Size: %d' % len(X_test))

        # build data loaders
        train_dataset = RegDataset(X_train, Y_train)
        val_dataset = RegDataset(X_test, Y_test)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True, pin_memory=False, num_workers=n_workers
        )
        valid_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=batch_size, shuffle=False, pin_memory=False, num_workers=n_workers
        )

        return train_loader, valid_loader, base_acc
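End-to-end, the class above is used roughly as follows (a sketch; `run_manager`, `ofa_network`, and `arch_encoder` are assumed to exist, and the path is hypothetical):

dataset = AccuracyDataset('exp/acc_dataset')
dataset.build_acc_dataset(run_manager, ofa_network, n_arch=1000,
                          image_size_list=[128, 160, 192, 224])  # (arch, img_size) -> top1
dataset.merge_acc_dataset()                                      # per-resolution dicts -> acc.dict
train_loader, valid_loader, base_acc = dataset.build_acc_data_loader(arch_encoder)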
@@ -0,0 +1,50 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import numpy as np
import torch
import torch.nn as nn

__all__ = ['AccuracyPredictor']


class AccuracyPredictor(nn.Module):

    def __init__(self, arch_encoder, hidden_size=400, n_layers=3,
                 checkpoint_path=None, device='cuda:0'):
        super(AccuracyPredictor, self).__init__()
        self.arch_encoder = arch_encoder
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.device = device

        # build layers
        layers = []
        for i in range(self.n_layers):
            layers.append(nn.Sequential(
                nn.Linear(self.arch_encoder.n_dim if i == 0 else self.hidden_size, self.hidden_size),
                nn.ReLU(inplace=True),
            ))
        layers.append(nn.Linear(self.hidden_size, 1, bias=False))
        self.layers = nn.Sequential(*layers)
        self.base_acc = nn.Parameter(torch.zeros(1, device=self.device), requires_grad=False)

        if checkpoint_path is not None and os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']
            self.load_state_dict(checkpoint)
            print('Loaded checkpoint from %s' % checkpoint_path)

        self.layers = self.layers.to(self.device)

    def forward(self, x):
        y = self.layers(x).squeeze()
        return y + self.base_acc

    def predict_acc(self, arch_dict_list):
        X = [self.arch_encoder.arch2feature(arch_dict) for arch_dict in arch_dict_list]
        X = torch.tensor(np.array(X)).float().to(self.device)
        return self.forward(X)
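Note the `base_acc` offset: the MLP only has to regress the residual around a mean accuracy, which stabilizes training. A hedged usage sketch (filling `base_acc` by hand is an assumption; a trained checkpoint would already contain it):

predictor = AccuracyPredictor(arch_encoder, checkpoint_path=None, device='cpu')
predictor.base_acc.data.fill_(base_acc)  # mean accuracy from build_acc_data_loader above
arch_list = [arch_encoder.random_sample_arch() for _ in range(4)]
print(predictor.predict_acc(arch_list))  # untrained -> values near base_acc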
@@ -0,0 +1,315 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import random
import numpy as np
from ofa.imagenet_classification.networks import ResNets

__all__ = ['MobileNetArchEncoder', 'ResNetArchEncoder']


class MobileNetArchEncoder:
    SPACE_TYPE = 'mbv3'

    def __init__(self, image_size_list=None, ks_list=None, expand_list=None, depth_list=None, n_stage=None):
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.ks_list = [3, 5, 7] if ks_list is None else ks_list
        self.expand_list = [3, 4, 6] if expand_list is None else [int(expand) for expand in expand_list]
        self.depth_list = [2, 3, 4] if depth_list is None else depth_list
        if n_stage is not None:
            self.n_stage = n_stage
        elif self.SPACE_TYPE == 'mbv2':
            self.n_stage = 6
        elif self.SPACE_TYPE == 'mbv3':
            self.n_stage = 5
        else:
            raise NotImplementedError

        # build info dicts
        self.n_dim = 0
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target='r')

        self.k_info = dict(id2val=[], val2id=[], L=[], R=[])
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target='k')
        self._build_info_dict(target='e')

    @property
    def max_n_blocks(self):
        if self.SPACE_TYPE == 'mbv3':
            return self.n_stage * max(self.depth_list)
        elif self.SPACE_TYPE == 'mbv2':
            return (self.n_stage - 1) * max(self.depth_list) + 1
        else:
            raise NotImplementedError

    def _build_info_dict(self, target):
        if target == 'r':
            target_dict = self.r_info
            target_dict['L'].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict['val2id'][img_size] = self.n_dim
                target_dict['id2val'][self.n_dim] = img_size
                self.n_dim += 1
            target_dict['R'].append(self.n_dim)
        else:
            if target == 'k':
                target_dict = self.k_info
                choices = self.ks_list
            elif target == 'e':
                target_dict = self.e_info
                choices = self.expand_list
            else:
                raise NotImplementedError
            for i in range(self.max_n_blocks):
                target_dict['val2id'].append({})
                target_dict['id2val'].append({})
                target_dict['L'].append(self.n_dim)
                for k in choices:
                    target_dict['val2id'][i][k] = self.n_dim
                    target_dict['id2val'][i][self.n_dim] = k
                    self.n_dim += 1
                target_dict['R'].append(self.n_dim)

    def arch2feature(self, arch_dict):
        ks, e, d, r = arch_dict['ks'], arch_dict['e'], arch_dict['d'], arch_dict['image_size']

        feature = np.zeros(self.n_dim)
        for i in range(self.max_n_blocks):
            nowd = i % max(self.depth_list)
            stg = i // max(self.depth_list)
            if nowd < d[stg]:
                feature[self.k_info['val2id'][i][ks[i]]] = 1
                feature[self.e_info['val2id'][i][e[i]]] = 1
        feature[self.r_info['val2id'][r]] = 1
        return feature

    def feature2arch(self, feature):
        img_sz = self.r_info['id2val'][
            int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
        ]
        assert img_sz in self.image_size_list
        arch_dict = {'ks': [], 'e': [], 'd': [], 'image_size': img_sz}

        d = 0
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.k_info['L'][i], self.k_info['R'][i]):
                if feature[j] == 1:
                    arch_dict['ks'].append(self.k_info['id2val'][i][j])
                    skip = False
                    break

            for j in range(self.e_info['L'][i], self.e_info['R'][i]):
                if feature[j] == 1:
                    arch_dict['e'].append(self.e_info['id2val'][i][j])
                    assert not skip
                    skip = False
                    break

            if skip:
                arch_dict['e'].append(0)
                arch_dict['ks'].append(0)
            else:
                d += 1

            if (i + 1) % max(self.depth_list) == 0 or (i + 1) == self.max_n_blocks:
                arch_dict['d'].append(d)
                d = 0
        return arch_dict

    def random_sample_arch(self):
        return {
            'ks': random.choices(self.ks_list, k=self.max_n_blocks),
            'e': random.choices(self.expand_list, k=self.max_n_blocks),
            'd': random.choices(self.depth_list, k=self.n_stage),
            'image_size': random.choice(self.image_size_list)
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        if random.random() < mutate_prob:
            arch_dict['image_size'] = random.choice(self.image_size_list)
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        for i in range(self.max_n_blocks):
            if random.random() < mutate_prob:
                arch_dict['ks'][i] = random.choice(self.ks_list)
                arch_dict['e'][i] = random.choice(self.expand_list)

        for i in range(self.n_stage):
            if random.random() < mutate_prob:
                arch_dict['d'][i] = random.choice(self.depth_list)
        return arch_dict


class ResNetArchEncoder:

    def __init__(self, image_size_list=None, depth_list=None, expand_list=None, width_mult_list=None,
                 base_depth_list=None):
        self.image_size_list = [224] if image_size_list is None else image_size_list
        self.expand_list = [0.2, 0.25, 0.35] if expand_list is None else expand_list
        self.depth_list = [0, 1, 2] if depth_list is None else depth_list
        self.width_mult_list = [0.65, 0.8, 1.0] if width_mult_list is None else width_mult_list

        self.base_depth_list = ResNets.BASE_DEPTH_LIST if base_depth_list is None else base_depth_list

        # build info dicts
        self.n_dim = 0
        # resolution
        self.r_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target='r')
        # input stem skip
        self.input_stem_d_info = dict(id2val={}, val2id={}, L=[], R=[])
        self._build_info_dict(target='input_stem_d')
        # width_mult
        self.width_mult_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target='width_mult')
        # expand ratio
        self.e_info = dict(id2val=[], val2id=[], L=[], R=[])
        self._build_info_dict(target='e')

    @property
    def n_stage(self):
        return len(self.base_depth_list)

    @property
    def max_n_blocks(self):
        return sum(self.base_depth_list) + self.n_stage * max(self.depth_list)

    def _build_info_dict(self, target):
        if target == 'r':
            target_dict = self.r_info
            target_dict['L'].append(self.n_dim)
            for img_size in self.image_size_list:
                target_dict['val2id'][img_size] = self.n_dim
                target_dict['id2val'][self.n_dim] = img_size
                self.n_dim += 1
            target_dict['R'].append(self.n_dim)
        elif target == 'input_stem_d':
            target_dict = self.input_stem_d_info
            target_dict['L'].append(self.n_dim)
            for skip in [0, 1]:
                target_dict['val2id'][skip] = self.n_dim
                target_dict['id2val'][self.n_dim] = skip
                self.n_dim += 1
            target_dict['R'].append(self.n_dim)
        elif target == 'e':
            target_dict = self.e_info
            choices = self.expand_list
            for i in range(self.max_n_blocks):
                target_dict['val2id'].append({})
                target_dict['id2val'].append({})
                target_dict['L'].append(self.n_dim)
                for e in choices:
                    target_dict['val2id'][i][e] = self.n_dim
                    target_dict['id2val'][i][self.n_dim] = e
                    self.n_dim += 1
                target_dict['R'].append(self.n_dim)
        elif target == 'width_mult':
            target_dict = self.width_mult_info
            choices = list(range(len(self.width_mult_list)))
            for i in range(self.n_stage + 2):
                target_dict['val2id'].append({})
                target_dict['id2val'].append({})
                target_dict['L'].append(self.n_dim)
                for w in choices:
                    target_dict['val2id'][i][w] = self.n_dim
                    target_dict['id2val'][i][self.n_dim] = w
                    self.n_dim += 1
                target_dict['R'].append(self.n_dim)

    def arch2feature(self, arch_dict):
        d, e, w, r = arch_dict['d'], arch_dict['e'], arch_dict['w'], arch_dict['image_size']
        input_stem_skip = 1 if d[0] > 0 else 0
        d = d[1:]

        feature = np.zeros(self.n_dim)
        feature[self.r_info['val2id'][r]] = 1
        feature[self.input_stem_d_info['val2id'][input_stem_skip]] = 1
        for i in range(self.n_stage + 2):
            feature[self.width_mult_info['val2id'][i][w[i]]] = 1

        start_pt = 0
        for i, base_depth in enumerate(self.base_depth_list):
            depth = base_depth + d[i]
            for j in range(start_pt, start_pt + depth):
                feature[self.e_info['val2id'][j][e[j]]] = 1
            start_pt += max(self.depth_list) + base_depth

        return feature

    def feature2arch(self, feature):
        img_sz = self.r_info['id2val'][
            int(np.argmax(feature[self.r_info['L'][0]:self.r_info['R'][0]])) + self.r_info['L'][0]
        ]
        input_stem_skip = self.input_stem_d_info['id2val'][
            int(np.argmax(feature[self.input_stem_d_info['L'][0]:self.input_stem_d_info['R'][0]])) +
            self.input_stem_d_info['L'][0]
        ] * 2
        assert img_sz in self.image_size_list
        arch_dict = {'d': [input_stem_skip], 'e': [], 'w': [], 'image_size': img_sz}

        for i in range(self.n_stage + 2):
            arch_dict['w'].append(
                self.width_mult_info['id2val'][i][
                    int(np.argmax(feature[self.width_mult_info['L'][i]:self.width_mult_info['R'][i]])) +
                    self.width_mult_info['L'][i]
                ]
            )

        d = 0
        skipped = 0
        stage_id = 0
        for i in range(self.max_n_blocks):
            skip = True
            for j in range(self.e_info['L'][i], self.e_info['R'][i]):
                if feature[j] == 1:
                    arch_dict['e'].append(self.e_info['id2val'][i][j])
                    skip = False
                    break
            if skip:
                arch_dict['e'].append(0)
                skipped += 1
            else:
                d += 1

            if i + 1 == self.max_n_blocks or (skipped + d) % \
                    (max(self.depth_list) + self.base_depth_list[stage_id]) == 0:
                arch_dict['d'].append(d - self.base_depth_list[stage_id])
                d, skipped = 0, 0
                stage_id += 1
        return arch_dict

    def random_sample_arch(self):
        return {
            'd': [random.choice([0, 2])] + random.choices(self.depth_list, k=self.n_stage),
            'e': random.choices(self.expand_list, k=self.max_n_blocks),
            'w': random.choices(list(range(len(self.width_mult_list))), k=self.n_stage + 2),
            'image_size': random.choice(self.image_size_list)
        }

    def mutate_resolution(self, arch_dict, mutate_prob):
        if random.random() < mutate_prob:
            arch_dict['image_size'] = random.choice(self.image_size_list)
        return arch_dict

    def mutate_arch(self, arch_dict, mutate_prob):
        # input stem skip
        if random.random() < mutate_prob:
            arch_dict['d'][0] = random.choice([0, 2])
        # depth
        for i in range(1, len(arch_dict['d'])):
            if random.random() < mutate_prob:
                arch_dict['d'][i] = random.choice(self.depth_list)
        # width_mult
        for i in range(len(arch_dict['w'])):
            if random.random() < mutate_prob:
                arch_dict['w'][i] = random.choice(list(range(len(self.width_mult_list))))
        # expand ratio
        for i in range(len(arch_dict['e'])):
            if random.random() < mutate_prob:
                arch_dict['e'][i] = random.choice(self.expand_list)
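A quick roundtrip check grounded in the two methods above: encoding a random architecture into the one-hot feature vector and decoding it back must preserve the resolution and the per-stage depths (kernel and expand entries of inactive blocks decode to 0 by design).

encoder = MobileNetArchEncoder(image_size_list=[160, 176, 192, 208, 224])
arch = encoder.random_sample_arch()
feat = encoder.arch2feature(arch)
decoded = encoder.feature2arch(feat)
assert decoded['image_size'] == arch['image_size'] and decoded['d'] == arch['d']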
@@ -0,0 +1,71 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import os
import copy
from .latency_lookup_table import *


class BaseEfficiencyModel:

    def __init__(self, ofa_net):
        self.ofa_net = ofa_net

    def get_active_subnet_config(self, arch_dict):
        arch_dict = copy.deepcopy(arch_dict)
        image_size = arch_dict.pop('image_size')
        self.ofa_net.set_active_subnet(**arch_dict)
        active_net_config = self.ofa_net.get_active_net_config()
        return active_net_config, image_size

    def get_efficiency(self, arch_dict):
        raise NotImplementedError


class ProxylessNASFLOPsModel(BaseEfficiencyModel):

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ProxylessNASLatencyTable.count_flops_given_config(active_net_config, image_size)


class Mbv3FLOPsModel(BaseEfficiencyModel):

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return MBv3LatencyTable.count_flops_given_config(active_net_config, image_size)


class ResNet50FLOPsModel(BaseEfficiencyModel):

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return ResNet50LatencyTable.count_flops_given_config(active_net_config, image_size)


class ProxylessNASLatencyModel(BaseEfficiencyModel):

    def __init__(self, ofa_net, lookup_table_path_dict):
        super(ProxylessNASLatencyModel, self).__init__(ofa_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = ProxylessNASLatencyTable(
                local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)


class Mbv3LatencyModel(BaseEfficiencyModel):

    def __init__(self, ofa_net, lookup_table_path_dict):
        super(Mbv3LatencyModel, self).__init__(ofa_net)
        self.latency_tables = {}
        for image_size, path in lookup_table_path_dict.items():
            self.latency_tables[image_size] = MBv3LatencyTable(
                local_dir='/tmp/.ofa_latency_tools/', url=os.path.join(path, '%d_lookup_table.yaml' % image_size))

    def get_efficiency(self, arch_dict):
        active_net_config, image_size = self.get_active_subnet_config(arch_dict)
        return self.latency_tables[image_size].predict_network_latency_given_config(active_net_config, image_size)
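A sketch of how these wrappers are consumed during search (hedged; `ofa_network` would be an OFAMobileNetV3 from the model zoo above, and the fixed-width lists here just describe one concrete subnet):

efficiency_predictor = Mbv3FLOPsModel(ofa_network)
arch = {'ks': [7] * 20, 'e': [6] * 20, 'd': [4] * 5, 'image_size': 224}
print(efficiency_predictor.get_efficiency(arch))  # MFLOPs of the chosen subnet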
@@ -0,0 +1,387 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import yaml
from ofa.utils import download_url, make_divisible, MyNetwork

__all__ = ['count_conv_flop', 'ProxylessNASLatencyTable', 'MBv3LatencyTable', 'ResNet50LatencyTable']


def count_conv_flop(out_size, in_channels, out_channels, kernel_size, groups):
    out_h = out_w = out_size
    delta_ops = in_channels * out_channels * kernel_size * kernel_size * out_h * out_w / groups
    return delta_ops


class LatencyTable(object):

    def __init__(self, local_dir='~/.ofa/latency_tools/',
                 url='https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml'):
        if url.startswith('http'):
            fname = download_url(url, local_dir, overwrite=True)
        else:
            fname = url
        with open(fname, 'r') as fp:
            self.lut = yaml.safe_load(fp)

    @staticmethod
    def repr_shape(shape):
        if isinstance(shape, (list, tuple)):
            return 'x'.join(str(_) for _ in shape)
        elif isinstance(shape, str):
            return shape
        else:
            raise TypeError('unsupported shape representation: %r' % (shape,))

    def query(self, **kwargs):
        raise NotImplementedError

    def predict_network_latency(self, net, image_size):
        raise NotImplementedError

    def predict_network_latency_given_config(self, net_config, image_size):
        raise NotImplementedError

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        raise NotImplementedError


class ProxylessNASLatencyTable(LatencyTable):

    def query(self, l_type: str, input_shape, output_shape, expand=None, ks=None, stride=None, id_skip=None):
        """
        :param l_type:
            Layer type. Must be one of the following:
            1. `Conv`: the initial 3x3 conv with stride 2
            2. `Conv_1`: feature_mix_layer
            3. `Logits`: all operations after `Conv_1`
            4. `expanded_conv`: MobileInvertedResidual
        :param input_shape: input shape (h, w, #channels)
        :param output_shape: output shape (h, w, #channels)
        :param expand: expansion ratio
        :param ks: kernel size
        :param stride: stride
        :param id_skip: whether the block has a residual connection
        """
        infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]

        if l_type in ('expanded_conv',):
            assert None not in (expand, ks, stride, id_skip)
            infos += ['expand:%d' % expand, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip]
        key = '-'.join(infos)
        return self.lut[key]['mean']

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)  # i.e. ceil(fsize / stride)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
                expand=mb_conv.expand_ratio, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip
            )
            predicted_latency += block_latency
            fsize = out_fz
        # feature mix layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net.feature_mix_layer.in_channels],
            [fsize, fsize, net.feature_mix_layer.out_channels]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [fsize, fsize, net.classifier.in_features], [net.classifier.out_features]  # 1000
        )
        return predicted_latency

    def predict_network_latency_given_config(self, net_config, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            shortcut = block['shortcut']

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
                expand=mb_conv['expand_ratio'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip
            )
            predicted_latency += block_latency
            fsize = out_fz
        # feature mix layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net_config['feature_mix_layer']['in_channels']],
            [fsize, fsize, net_config['feature_mix_layer']['out_channels']]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [fsize, fsize, net_config['classifier']['in_features']],
            [net_config['classifier']['out_features']]  # 1000
        )
        return predicted_latency

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # first conv
        flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            if mb_conv is None:
                continue
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            if mb_conv['mid_channels'] is None:
                mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
            if mb_conv['expand_ratio'] != 1:
                # inverted bottleneck
                flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
            # depth conv
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
                                     mb_conv['kernel_size'], mb_conv['mid_channels'])
            # point linear
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
            fsize = out_fz
        # feature mix layer
        flops += count_conv_flop(fsize, net_config['feature_mix_layer']['in_channels'],
                                 net_config['feature_mix_layer']['out_channels'], 1, 1)
        # classifier
        flops += count_conv_flop(1, net_config['classifier']['in_features'],
                                 net_config['classifier']['out_features'], 1, 1)
        return flops / 1e6  # MFLOPs


class MBv3LatencyTable(LatencyTable):

    def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
              se=None, h_swish=None):
        infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]

        if l_type in ('expanded_conv',):
            assert None not in (mid, ks, stride, id_skip, se, h_swish)
            infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip,
                      'se:%d' % se, 'hs:%d' % h_swish]
        key = '-'.join(infos)
        return self.lut[key]['mean']

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
                mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip,
                se=1 if mb_conv.use_se else 0, h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
            [fsize, fsize, net.final_expand_layer.out_channels],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
            [1, 1, net.final_expand_layer.out_channels],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
            [1, 1, net.feature_mix_layer.out_channels]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
        )
        return predicted_latency

    def predict_network_latency_given_config(self, net_config, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net_config['first_conv']['out_channels']]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            shortcut = block['shortcut']

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            if mb_conv['mid_channels'] is None:
                mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv['in_channels']], [out_fz, out_fz, mb_conv['out_channels']],
                mid=mb_conv['mid_channels'], ks=mb_conv['kernel_size'], stride=mb_conv['stride'], id_skip=idskip,
                se=1 if mb_conv['use_se'] else 0, h_swish=1 if mb_conv['act_func'] == 'h_swish' else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net_config['final_expand_layer']['in_channels']],
            [fsize, fsize, net_config['final_expand_layer']['out_channels']],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, net_config['final_expand_layer']['out_channels']],
            [1, 1, net_config['final_expand_layer']['out_channels']],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, net_config['feature_mix_layer']['in_channels']],
            [1, 1, net_config['feature_mix_layer']['out_channels']]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, net_config['classifier']['in_features']], [net_config['classifier']['out_features']]
        )
        return predicted_latency

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # first conv
        flops += count_conv_flop((image_size + 1) // 2, 3, net_config['first_conv']['out_channels'], 3, 1)
        # blocks
        fsize = (image_size + 1) // 2
        for block in net_config['blocks']:
            mb_conv = block['mobile_inverted_conv'] if 'mobile_inverted_conv' in block else block['conv']
            if mb_conv is None:
                continue
            out_fz = int((fsize - 1) / mb_conv['stride'] + 1)
            if mb_conv['mid_channels'] is None:
                mb_conv['mid_channels'] = round(mb_conv['in_channels'] * mb_conv['expand_ratio'])
            if mb_conv['expand_ratio'] != 1:
                # inverted bottleneck
                flops += count_conv_flop(fsize, mb_conv['in_channels'], mb_conv['mid_channels'], 1, 1)
            # depth conv
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['mid_channels'],
                                     mb_conv['kernel_size'], mb_conv['mid_channels'])
            if mb_conv['use_se']:
                # SE layer
                se_mid = make_divisible(mb_conv['mid_channels'] // 4, divisor=MyNetwork.CHANNEL_DIVISIBLE)
                flops += count_conv_flop(1, mb_conv['mid_channels'], se_mid, 1, 1)
                flops += count_conv_flop(1, se_mid, mb_conv['mid_channels'], 1, 1)
            # point linear
            flops += count_conv_flop(out_fz, mb_conv['mid_channels'], mb_conv['out_channels'], 1, 1)
            fsize = out_fz
        # final expand layer
        flops += count_conv_flop(fsize, net_config['final_expand_layer']['in_channels'],
                                 net_config['final_expand_layer']['out_channels'], 1, 1)
        # feature mix layer
        flops += count_conv_flop(1, net_config['feature_mix_layer']['in_channels'],
                                 net_config['feature_mix_layer']['out_channels'], 1, 1)
        # classifier
        flops += count_conv_flop(1, net_config['classifier']['in_features'],
                                 net_config['classifier']['out_features'], 1, 1)
        return flops / 1e6  # MFLOPs


class ResNet50LatencyTable(LatencyTable):

    def query(self, **kwargs):
        raise NotImplementedError

    def predict_network_latency(self, net, image_size):
        raise NotImplementedError

    def predict_network_latency_given_config(self, net_config, image_size):
        raise NotImplementedError

    @staticmethod
    def count_flops_given_config(net_config, image_size=224):
        flops = 0
        # input stem
        for layer_config in net_config['input_stem']:
            if layer_config['name'] != 'ConvLayer':
                layer_config = layer_config['conv']
            in_channel = layer_config['in_channels']
            out_channel = layer_config['out_channels']
            out_image_size = int((image_size - 1) / layer_config['stride'] + 1)

            flops += count_conv_flop(out_image_size, in_channel, out_channel,
                                     layer_config['kernel_size'], layer_config.get('groups', 1))
            image_size = out_image_size
        # max pooling
        image_size = int((image_size - 1) / 2 + 1)
        # ResNetBottleneckBlocks
        for block_config in net_config['blocks']:
            in_channel = block_config['in_channels']
            out_channel = block_config['out_channels']

            out_image_size = int((image_size - 1) / block_config['stride'] + 1)
            mid_channel = block_config['mid_channels'] if block_config['mid_channels'] is not None \
                else round(out_channel * block_config['expand_ratio'])
            mid_channel = make_divisible(mid_channel, MyNetwork.CHANNEL_DIVISIBLE)

            # conv1
            flops += count_conv_flop(image_size, in_channel, mid_channel, 1, 1)
            # conv2
            flops += count_conv_flop(out_image_size, mid_channel, mid_channel,
                                     block_config['kernel_size'], block_config['groups'])
            # conv3
            flops += count_conv_flop(out_image_size, mid_channel, out_channel, 1, 1)
            # downsample
            if block_config['stride'] == 1 and in_channel == out_channel:
                pass
            else:
                flops += count_conv_flop(out_image_size, in_channel, out_channel, 1, 1)
            image_size = out_image_size
        # final classifier
        flops += count_conv_flop(1, net_config['classifier']['in_features'],
                                 net_config['classifier']['out_features'], 1, 1)
        return flops / 1e6  # MFLOPs
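As a sanity check of `count_conv_flop`, which counts multiply-accumulates as in_channels * out_channels * k^2 * out_h * out_w / groups: a 3x3 stride-2 stem conv producing a 112x112x32 map from 3 input channels costs 3 * 32 * 9 * 112 * 112 = 10,838,016 MACs.

macs = count_conv_flop(out_size=112, in_channels=3, out_channels=32, kernel_size=3, groups=1)
print(macs / 1e6)  # ~10.84 MFLOPs, matching the hand computation above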
@@ -0,0 +1,5 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .evolution import *
@@ -0,0 +1,134 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import copy
import random
import numpy as np
from tqdm import tqdm

__all__ = ['EvolutionFinder']


class EvolutionFinder:

    def __init__(self, efficiency_predictor, accuracy_predictor, **kwargs):
        self.efficiency_predictor = efficiency_predictor
        self.accuracy_predictor = accuracy_predictor

        # evolution hyper-parameters
        self.arch_mutate_prob = kwargs.get('arch_mutate_prob', 0.1)
        self.resolution_mutate_prob = kwargs.get('resolution_mutate_prob', 0.5)
        self.population_size = kwargs.get('population_size', 100)
        self.max_time_budget = kwargs.get('max_time_budget', 500)
        self.parent_ratio = kwargs.get('parent_ratio', 0.25)
        self.mutation_ratio = kwargs.get('mutation_ratio', 0.5)

    @property
    def arch_manager(self):
        return self.accuracy_predictor.arch_encoder

    def update_hyper_params(self, new_param_dict):
        self.__dict__.update(new_param_dict)

    def random_valid_sample(self, constraint):
        while True:
            sample = self.arch_manager.random_sample_arch()
            efficiency = self.efficiency_predictor.get_efficiency(sample)
            if efficiency <= constraint:
                return sample, efficiency

    def mutate_sample(self, sample, constraint):
        while True:
            new_sample = copy.deepcopy(sample)

            self.arch_manager.mutate_resolution(new_sample, self.resolution_mutate_prob)
            self.arch_manager.mutate_arch(new_sample, self.arch_mutate_prob)

            efficiency = self.efficiency_predictor.get_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def crossover_sample(self, sample1, sample2, constraint):
        while True:
            new_sample = copy.deepcopy(sample1)
            for key in new_sample.keys():
                if not isinstance(new_sample[key], list):
                    new_sample[key] = random.choice([sample1[key], sample2[key]])
                else:
                    for i in range(len(new_sample[key])):
                        new_sample[key][i] = random.choice([sample1[key][i], sample2[key][i]])

            efficiency = self.efficiency_predictor.get_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def run_evolution_search(self, constraint, verbose=False, **kwargs):
        """Run a single roll-out of regularized evolution to a fixed time budget."""
        self.update_hyper_params(kwargs)

        mutation_numbers = int(round(self.mutation_ratio * self.population_size))
        parents_size = int(round(self.parent_ratio * self.population_size))

        best_valids = [-100]
        population = []  # (validation, sample, latency) tuples
        child_pool = []
        efficiency_pool = []
        best_info = None
        if verbose:
            print('Generate random population...')
        for _ in range(self.population_size):
            sample, efficiency = self.random_valid_sample(constraint)
            child_pool.append(sample)
            efficiency_pool.append(efficiency)

        accs = self.accuracy_predictor.predict_acc(child_pool)
        for i in range(mutation_numbers):
            population.append((accs[i].item(), child_pool[i], efficiency_pool[i]))

        if verbose:
            print('Start Evolution...')
        # After the population is seeded, proceed with evolving the population.
        with tqdm(total=self.max_time_budget, desc='Searching with constraint (%s)' % constraint,
                  disable=(not verbose)) as t:
            for i in range(self.max_time_budget):
                parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
                acc = parents[0][0]
                t.set_postfix({
                    'acc': parents[0][0]
                })
                if not verbose and (i + 1) % 100 == 0:
                    print('Iter: {} Acc: {}'.format(i + 1, parents[0][0]))

                if acc > best_valids[-1]:
                    best_valids.append(acc)
                    best_info = parents[0]
                else:
                    best_valids.append(best_valids[-1])

                population = parents
                child_pool = []
                efficiency_pool = []

                for j in range(mutation_numbers):
                    par_sample = population[np.random.randint(parents_size)][1]
                    # Mutate
                    new_sample, efficiency = self.mutate_sample(par_sample, constraint)
                    child_pool.append(new_sample)
                    efficiency_pool.append(efficiency)

                for j in range(self.population_size - mutation_numbers):
                    par_sample1 = population[np.random.randint(parents_size)][1]
                    par_sample2 = population[np.random.randint(parents_size)][1]
                    # Crossover
                    new_sample, efficiency = self.crossover_sample(par_sample1, par_sample2, constraint)
                    child_pool.append(new_sample)
                    efficiency_pool.append(efficiency)

                accs = self.accuracy_predictor.predict_acc(child_pool)
                for j in range(self.population_size):
                    population.append((accs[j].item(), child_pool[j], efficiency_pool[j]))

                t.update(1)

        return best_valids, best_info
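Putting the pieces together (a hedged sketch; both predictor objects come from the modules above, and the 400 MFLOPs budget is just an example value):

finder = EvolutionFinder(efficiency_predictor, accuracy_predictor,
                         population_size=100, max_time_budget=500)
best_valids, best_info = finder.run_evolution_search(constraint=400, verbose=True)
predicted_acc, best_arch, efficiency = best_info  # tuple layout as stored in `population`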
@@ -0,0 +1,5 @@
from .accuracy_predictor import AccuracyPredictor
from .flops_table import FLOPsTable
from .latency_table import LatencyTable
from .evolution_finder import EvolutionFinder, ArchManager
from .imagenet_eval_helper import evaluate_ofa_subnet, evaluate_ofa_specialized
@@ -0,0 +1,85 @@
import torch.nn as nn
import torch

import copy

from ofa.utils import download_url


# Helper for constructing the one-hot vectors.
def construct_maps(keys):
    d = dict()
    keys = list(set(keys))
    for k in keys:
        if k not in d:
            d[k] = len(list(d.keys()))
    return d


ks_map = construct_maps(keys=(3, 5, 7))
ex_map = construct_maps(keys=(3, 4, 6))
dp_map = construct_maps(keys=(2, 3, 4))


class AccuracyPredictor:
    def __init__(self, pretrained=True, device='cuda:0'):
        self.device = device

        self.model = nn.Sequential(
            nn.Linear(128, 400),
            nn.ReLU(),
            nn.Linear(400, 400),
            nn.ReLU(),
            nn.Linear(400, 400),
            nn.ReLU(),
            nn.Linear(400, 1),
        )
        if pretrained:
            # load pretrained model
            fname = download_url("https://hanlab.mit.edu/files/OnceForAll/tutorial/acc_predictor.pth")
            self.model.load_state_dict(
                torch.load(fname, map_location=torch.device('cpu'))
            )
        self.model = self.model.to(self.device)

    # TODO: merge it with serialization utils.
    @torch.no_grad()
    def predict_accuracy(self, population):
        all_feats = []
        for sample in population:
            ks_list = copy.deepcopy(sample['ks'])
            ex_list = copy.deepcopy(sample['e'])
            d_list = copy.deepcopy(sample['d'])
            r = copy.deepcopy(sample['r'])[0]
            feats = AccuracyPredictor.spec2feats(ks_list, ex_list, d_list, r).reshape(1, -1).to(self.device)
            all_feats.append(feats)
        all_feats = torch.cat(all_feats, 0)
        pred = self.model(all_feats).cpu()
        return pred

    @staticmethod
    def spec2feats(ks_list, ex_list, d_list, r):
        # This function converts a network config to a feature vector (128-D).
        start = 0
        end = 4
        for d in d_list:
            for j in range(start + d, end):
                ks_list[j] = 0
                ex_list[j] = 0
            start += 4
            end += 4

        # convert to one-hot
        ks_onehot = [0 for _ in range(60)]
        ex_onehot = [0 for _ in range(60)]
        r_onehot = [0 for _ in range(8)]

        for i in range(20):
            start = i * 3
            if ks_list[i] != 0:
                ks_onehot[start + ks_map[ks_list[i]]] = 1
            if ex_list[i] != 0:
                ex_onehot[start + ex_map[ex_list[i]]] = 1

        r_onehot[(r - 112) // 16] = 1
        return torch.Tensor(ks_onehot + ex_onehot + r_onehot)
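The 128-D layout implied by `spec2feats` is 20 blocks x 3 one-hot kernel choices (60) + 20 blocks x 3 one-hot expand choices (60) + 8 one-hot resolution bins covering 112 to 224 in steps of 16. A quick check with a full-depth configuration:

feats = AccuracyPredictor.spec2feats(
    ks_list=[3] * 20, ex_list=[4] * 20, d_list=[4] * 5, r=224)
assert feats.shape[0] == 128  # 60 + 60 + 8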
@@ -0,0 +1,213 @@
|
||||
import copy
|
||||
import random
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['EvolutionFinder']
|
||||
|
||||
|
||||
class ArchManager:
|
||||
def __init__(self):
|
||||
self.num_blocks = 20
|
||||
self.num_stages = 5
|
||||
self.kernel_sizes = [3, 5, 7]
|
||||
self.expand_ratios = [3, 4, 6]
|
||||
self.depths = [2, 3, 4]
|
||||
self.resolutions = [160, 176, 192, 208, 224]
|
||||
|
||||
def random_sample(self):
|
||||
sample = {}
|
||||
d = []
|
||||
e = []
|
||||
ks = []
|
||||
for i in range(self.num_stages):
|
||||
d.append(random.choice(self.depths))
|
||||
|
||||
for i in range(self.num_blocks):
|
||||
e.append(random.choice(self.expand_ratios))
|
||||
ks.append(random.choice(self.kernel_sizes))
|
||||
|
||||
sample = {
|
||||
'wid': None,
|
||||
'ks': ks,
|
||||
'e': e,
|
||||
'd': d,
|
||||
'r': [random.choice(self.resolutions)]
|
||||
}
|
||||
|
||||
return sample
|
||||
|
||||
def random_resample(self, sample, i):
|
||||
assert i >= 0 and i < self.num_blocks
|
||||
sample['ks'][i] = random.choice(self.kernel_sizes)
|
||||
sample['e'][i] = random.choice(self.expand_ratios)
|
||||
|
||||
def random_resample_depth(self, sample, i):
|
||||
assert i >= 0 and i < self.num_stages
|
||||
sample['d'][i] = random.choice(self.depths)
|
||||
|
||||
def random_resample_resolution(self, sample):
|
||||
sample['r'][0] = random.choice(self.resolutions)
|
||||
|
||||
|
||||
class EvolutionFinder:
|
||||
valid_constraint_range = {
|
||||
'flops': [150, 600],
|
||||
'note10': [15, 60],
|
||||
}
|
||||
|
||||
def __init__(self, constraint_type, efficiency_constraint,
|
||||
efficiency_predictor, accuracy_predictor, **kwargs):
|
||||
self.constraint_type = constraint_type
|
||||
if not constraint_type in self.valid_constraint_range.keys():
|
||||
self.invite_reset_constraint_type()
|
||||
self.efficiency_constraint = efficiency_constraint
|
||||
if not (efficiency_constraint <= self.valid_constraint_range[constraint_type][1] and
|
||||
efficiency_constraint >= self.valid_constraint_range[constraint_type][0]):
|
||||
self.invite_reset_constraint()
|
||||
|
||||
self.efficiency_predictor = efficiency_predictor
|
||||
self.accuracy_predictor = accuracy_predictor
|
||||
self.arch_manager = ArchManager()
|
||||
self.num_blocks = self.arch_manager.num_blocks
|
||||
self.num_stages = self.arch_manager.num_stages
|
||||
|
||||
self.mutate_prob = kwargs.get('mutate_prob', 0.1)
|
||||
self.population_size = kwargs.get('population_size', 100)
|
||||
self.max_time_budget = kwargs.get('max_time_budget', 500)
|
||||
self.parent_ratio = kwargs.get('parent_ratio', 0.25)
|
||||
self.mutation_ratio = kwargs.get('mutation_ratio', 0.5)
|
||||
|
||||
def invite_reset_constraint_type(self):
|
||||
print('Invalid constraint type! Please input one of:', list(self.valid_constraint_range.keys()))
|
||||
new_type = input()
|
||||
while new_type not in self.valid_constraint_range.keys():
|
||||
print('Invalid constraint type! Please input one of:', list(self.valid_constraint_range.keys()))
|
||||
new_type = input()
|
||||
self.constraint_type = new_type
|
||||
|
||||
def invite_reset_constraint(self):
|
||||
print('Invalid constraint_value! Please input an integer in interval: [%d, %d]!' % (
|
||||
self.valid_constraint_range[self.constraint_type][0],
|
||||
self.valid_constraint_range[self.constraint_type][1])
|
||||
)
|
||||
|
||||
new_cons = input()
|
||||
        while (not new_cons.isdigit()) or (int(new_cons) > self.valid_constraint_range[self.constraint_type][1]) or \
                (int(new_cons) < self.valid_constraint_range[self.constraint_type][0]):
            print('Invalid constraint_value! Please input an integer in interval: [%d, %d]!' % (
                self.valid_constraint_range[self.constraint_type][0],
                self.valid_constraint_range[self.constraint_type][1])
            )
            new_cons = input()
        new_cons = int(new_cons)
        self.efficiency_constraint = new_cons

    def set_efficiency_constraint(self, new_constraint):
        self.efficiency_constraint = new_constraint

    def random_sample(self):
        constraint = self.efficiency_constraint
        while True:
            sample = self.arch_manager.random_sample()
            efficiency = self.efficiency_predictor.predict_efficiency(sample)
            if efficiency <= constraint:
                return sample, efficiency

    def mutate_sample(self, sample):
        constraint = self.efficiency_constraint
        while True:
            new_sample = copy.deepcopy(sample)

            if random.random() < self.mutate_prob:
                self.arch_manager.random_resample_resolution(new_sample)

            for i in range(self.num_blocks):
                if random.random() < self.mutate_prob:
                    self.arch_manager.random_resample(new_sample, i)

            for i in range(self.num_stages):
                if random.random() < self.mutate_prob:
                    self.arch_manager.random_resample_depth(new_sample, i)

            efficiency = self.efficiency_predictor.predict_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def crossover_sample(self, sample1, sample2):
        constraint = self.efficiency_constraint
        while True:
            new_sample = copy.deepcopy(sample1)
            for key in new_sample.keys():
                if not isinstance(new_sample[key], list):
                    continue
                for i in range(len(new_sample[key])):
                    new_sample[key][i] = random.choice([sample1[key][i], sample2[key][i]])

            efficiency = self.efficiency_predictor.predict_efficiency(new_sample)
            if efficiency <= constraint:
                return new_sample, efficiency

    def run_evolution_search(self, verbose=False):
        """Run a single roll-out of regularized evolution to a fixed time budget."""
        max_time_budget = self.max_time_budget
        population_size = self.population_size
        mutation_numbers = int(round(self.mutation_ratio * population_size))
        parents_size = int(round(self.parent_ratio * population_size))
        constraint = self.efficiency_constraint

        best_valids = [-100]
        population = []  # (validation, sample, latency) tuples
        child_pool = []
        efficiency_pool = []
        best_info = None
        if verbose:
            print('Generate random population...')
        for _ in range(population_size):
            sample, efficiency = self.random_sample()
            child_pool.append(sample)
            efficiency_pool.append(efficiency)

        accs = self.accuracy_predictor.predict_accuracy(child_pool)
        # seed the population with every generated child
        for i in range(population_size):
            population.append((accs[i].item(), child_pool[i], efficiency_pool[i]))

        if verbose:
            print('Start Evolution...')
        # After the population is seeded, proceed with evolving the population.
        for iter in tqdm(range(max_time_budget), desc='Searching with %s constraint (%s)' % (self.constraint_type, self.efficiency_constraint)):
            parents = sorted(population, key=lambda x: x[0])[::-1][:parents_size]
            acc = parents[0][0]
            if verbose:
                print('Iter: {} Acc: {}'.format(iter - 1, parents[0][0]))

            if acc > best_valids[-1]:
                best_valids.append(acc)
                best_info = parents[0]
            else:
                best_valids.append(best_valids[-1])

            population = parents
            child_pool = []
            efficiency_pool = []

            for i in range(mutation_numbers):
                par_sample = population[np.random.randint(parents_size)][1]
                # Mutate
                new_sample, efficiency = self.mutate_sample(par_sample)
                child_pool.append(new_sample)
                efficiency_pool.append(efficiency)

            for i in range(population_size - mutation_numbers):
                par_sample1 = population[np.random.randint(parents_size)][1]
                par_sample2 = population[np.random.randint(parents_size)][1]
                # Crossover
                new_sample, efficiency = self.crossover_sample(par_sample1, par_sample2)
                child_pool.append(new_sample)
                efficiency_pool.append(efficiency)

            accs = self.accuracy_predictor.predict_accuracy(child_pool)
            for i in range(population_size):
                population.append((accs[i].item(), child_pool[i], efficiency_pool[i]))

        return best_valids, best_info
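# Usage sketch for the evolutionary search above. The surrounding class name
# and its constructor are not visible in this hunk, so `finder` is assumed to
# be an instance of it, already wired to an accuracy predictor, an efficiency
# predictor, and an arch_manager; only the call pattern is taken from the code:
#
#     finder.set_efficiency_constraint(25)          # e.g. a 25 ms latency budget
#     best_valids, best_info = finder.run_evolution_search(verbose=True)
#     best_acc, best_sample, best_efficiency = best_info
#
# best_valids traces the incumbent accuracy per iteration; best_info holds the
# best (accuracy, architecture dict, efficiency) tuple found within the budget.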
@@ -0,0 +1,224 @@
import time
import copy
import torch
import torch.nn as nn
import numpy as np
from ofa.utils.layers import *

__all__ = ['FLOPsTable']


def rm_bn_from_net(net):
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            m.forward = lambda x: x


class FLOPsTable:
    def __init__(self, pred_type='flops', device='cuda:0', multiplier=1.2, batch_size=64, load_efficiency_table=None):
        assert pred_type in ['flops', 'latency']
        self.multiplier = multiplier
        self.pred_type = pred_type
        self.device = device
        self.batch_size = batch_size
        self.efficiency_dict = {}
        if load_efficiency_table is not None:
            self.efficiency_dict = np.load(load_efficiency_table, allow_pickle=True).item()
        else:
            self.build_lut(batch_size)

    @torch.no_grad()
    def measure_single_layer_latency(self, layer: nn.Module, input_size: tuple, warmup_steps=10, measure_steps=50):
        total_time = 0
        inputs = torch.randn(*input_size, device=self.device)
        layer.eval()
        rm_bn_from_net(layer)
        network = layer.to(self.device)
        torch.cuda.synchronize()
        for i in range(warmup_steps):
            network(inputs)
        torch.cuda.synchronize()

        torch.cuda.synchronize()
        st = time.time()
        for i in range(measure_steps):
            network(inputs)
        torch.cuda.synchronize()
        ed = time.time()
        total_time += ed - st

        latency = total_time / measure_steps * 1000

        return latency

    @torch.no_grad()
    def measure_single_layer_flops(self, layer: nn.Module, input_size: tuple):
        import thop
        inputs = torch.randn(*input_size, device=self.device)
        network = layer.to(self.device)
        layer.eval()
        rm_bn_from_net(layer)
        flops, params = thop.profile(network, (inputs,), verbose=False)
        return flops / 1e6

    def build_lut(self, batch_size=1, resolutions=[160, 176, 192, 208, 224]):
        for resolution in resolutions:
            self.build_single_lut(batch_size, resolution)

        np.save('local_lut.npy', self.efficiency_dict)

    def build_single_lut(self, batch_size=1, base_resolution=224):
        print('Building the %s lookup table (resolution=%d)...' % (self.pred_type, base_resolution))
        # block, input_size, in_channels, out_channels, expand_ratio, kernel_size, stride, act, se
        configurations = [
            (ConvLayer, base_resolution, 3, 16, 3, 2, 'relu'),
            (ResidualBlock, base_resolution // 2, 16, 16, [1], [3, 5, 7], 1, 'relu', False),
            (ResidualBlock, base_resolution // 2, 16, 24, [3, 4, 6], [3, 5, 7], 2, 'relu', False),
            (ResidualBlock, base_resolution // 4, 24, 24, [3, 4, 6], [3, 5, 7], 1, 'relu', False),
            (ResidualBlock, base_resolution // 4, 24, 24, [3, 4, 6], [3, 5, 7], 1, 'relu', False),
            (ResidualBlock, base_resolution // 4, 24, 24, [3, 4, 6], [3, 5, 7], 1, 'relu', False),
            (ResidualBlock, base_resolution // 4, 24, 40, [3, 4, 6], [3, 5, 7], 2, 'relu', True),
            (ResidualBlock, base_resolution // 8, 40, 40, [3, 4, 6], [3, 5, 7], 1, 'relu', True),
            (ResidualBlock, base_resolution // 8, 40, 40, [3, 4, 6], [3, 5, 7], 1, 'relu', True),
            (ResidualBlock, base_resolution // 8, 40, 40, [3, 4, 6], [3, 5, 7], 1, 'relu', True),
            (ResidualBlock, base_resolution // 8, 40, 80, [3, 4, 6], [3, 5, 7], 2, 'h_swish', False),
            (ResidualBlock, base_resolution // 16, 80, 80, [3, 4, 6], [3, 5, 7], 1, 'h_swish', False),
            (ResidualBlock, base_resolution // 16, 80, 80, [3, 4, 6], [3, 5, 7], 1, 'h_swish', False),
            (ResidualBlock, base_resolution // 16, 80, 80, [3, 4, 6], [3, 5, 7], 1, 'h_swish', False),
            (ResidualBlock, base_resolution // 16, 80, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ResidualBlock, base_resolution // 16, 112, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ResidualBlock, base_resolution // 16, 112, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ResidualBlock, base_resolution // 16, 112, 112, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ResidualBlock, base_resolution // 16, 112, 160, [3, 4, 6], [3, 5, 7], 2, 'h_swish', True),
            (ResidualBlock, base_resolution // 32, 160, 160, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ResidualBlock, base_resolution // 32, 160, 160, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ResidualBlock, base_resolution // 32, 160, 160, [3, 4, 6], [3, 5, 7], 1, 'h_swish', True),
            (ConvLayer, base_resolution // 32, 160, 960, 1, 1, 'h_swish'),
            (ConvLayer, 1, 960, 1280, 1, 1, 'h_swish'),
            (LinearLayer, 1, 1280, 1000, 1, 1),
        ]

        efficiency_dict = {
            'mobile_inverted_blocks': [],
            'other_blocks': {}
        }

        for layer_idx in range(len(configurations)):
            config = configurations[layer_idx]
            op_type = config[0]
            if op_type == ResidualBlock:
                _, input_size, in_channels, out_channels, expand_list, ks_list, stride, act, se = config
                in_channels = int(round(in_channels * self.multiplier))
                out_channels = int(round(out_channels * self.multiplier))
                template_config = {
                    'name': ResidualBlock.__name__,
                    'mobile_inverted_conv': {
                        'name': MBConvLayer.__name__,
                        'in_channels': in_channels,
                        'out_channels': out_channels,
                        'kernel_size': 0,  # placeholder, overwritten for each (ks, e) below
                        'stride': stride,
                        'expand_ratio': 0,  # placeholder, overwritten for each (ks, e) below
                        # 'mid_channels': None,
                        'act_func': act,
                        'use_se': se,
                    },
                    'shortcut': {
                        'name': IdentityLayer.__name__,
                        'in_channels': in_channels,
                        'out_channels': out_channels,
                    } if (in_channels == out_channels and stride == 1) else None
                }
                sub_dict = {}
                for ks in ks_list:
                    for e in expand_list:
                        build_config = copy.deepcopy(template_config)
                        build_config['mobile_inverted_conv']['expand_ratio'] = e
                        build_config['mobile_inverted_conv']['kernel_size'] = ks

                        layer = ResidualBlock.build_from_config(build_config)
                        input_shape = (batch_size, in_channels, input_size, input_size)

                        if self.pred_type == 'flops':
                            measure_result = self.measure_single_layer_flops(layer, input_shape) / batch_size
                        elif self.pred_type == 'latency':
                            measure_result = self.measure_single_layer_latency(layer, input_shape)

                        sub_dict[(ks, e)] = measure_result

                efficiency_dict['mobile_inverted_blocks'].append(sub_dict)

            elif op_type == ConvLayer:
                _, input_size, in_channels, out_channels, kernel_size, stride, activation = config
                in_channels = int(round(in_channels * self.multiplier))
                out_channels = int(round(out_channels * self.multiplier))
                build_config = {
                    # 'name': ConvLayer.__name__,
                    'in_channels': in_channels,
                    'out_channels': out_channels,
                    'kernel_size': kernel_size,
                    'stride': stride,
                    'dilation': 1,
                    'groups': 1,
                    'bias': False,
                    'use_bn': True,
                    'has_shuffle': False,
                    'act_func': activation,
                }
                layer = ConvLayer.build_from_config(build_config)
                input_shape = (batch_size, in_channels, input_size, input_size)

                if self.pred_type == 'flops':
                    measure_result = self.measure_single_layer_flops(layer, input_shape) / batch_size
                elif self.pred_type == 'latency':
                    measure_result = self.measure_single_layer_latency(layer, input_shape)

                efficiency_dict['other_blocks'][layer_idx] = measure_result

            elif op_type == LinearLayer:
                _, input_size, in_channels, out_channels, kernel_size, stride = config
                in_channels = int(round(in_channels * self.multiplier))
                out_channels = int(round(out_channels * self.multiplier))
                build_config = {
                    # 'name': LinearLayer.__name__,
                    'in_features': in_channels,
                    'out_features': out_channels
                }
                layer = LinearLayer.build_from_config(build_config)
                input_shape = (batch_size, in_channels)

                if self.pred_type == 'flops':
                    measure_result = self.measure_single_layer_flops(layer, input_shape) / batch_size
                elif self.pred_type == 'latency':
                    measure_result = self.measure_single_layer_latency(layer, input_shape)

                efficiency_dict['other_blocks'][layer_idx] = measure_result

            else:
                raise NotImplementedError

        self.efficiency_dict[base_resolution] = efficiency_dict
        print('Built the %s lookup table (resolution=%d)!' % (self.pred_type, base_resolution))
        return efficiency_dict

    def predict_efficiency(self, sample):
        input_size = sample.get('r', [224])
        input_size = input_size[0]
        assert 'ks' in sample and 'e' in sample and 'd' in sample
        assert len(sample['ks']) == len(sample['e']) and len(sample['ks']) == 20
        assert len(sample['d']) == 5
        total_stats = 0.
        for i in range(20):
            stage = i // 4
            depth_max = sample['d'][stage]
            depth = i % 4 + 1
            if depth > depth_max:
                continue
            ks, e = sample['ks'][i], sample['e'][i]
            total_stats += self.efficiency_dict[input_size]['mobile_inverted_blocks'][i + 1][(ks, e)]

        for key in self.efficiency_dict[input_size]['other_blocks']:
            total_stats += self.efficiency_dict[input_size]['other_blocks'][key]

        total_stats += self.efficiency_dict[input_size]['mobile_inverted_blocks'][0][(3, 1)]
        return total_stats
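# Usage sketch for FLOPsTable.predict_efficiency. The spec layout is taken
# from the assertions above: 20 per-block kernel sizes and expand ratios, 5
# per-stage depths, and a resolution list 'r'. The concrete values are just
# one valid choice from the search space the table is built over:
#
#     table = FLOPsTable(pred_type='flops', load_efficiency_table='local_lut.npy')
#     sample = {
#         'ks': [3] * 20,        # kernel size per block, from {3, 5, 7}
#         'e': [4] * 20,         # expand ratio per block, from {3, 4, 6}
#         'd': [2, 3, 3, 4, 4],  # active depth per stage, at most 4
#         'r': [224],            # input resolution, from {160, ..., 224}
#     }
#     mflops = table.predict_efficiency(sample)   # sum of per-block LUT entries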
@@ -0,0 +1,241 @@
import os.path as osp
import numpy as np
import math
from tqdm import tqdm

import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
from torchvision import transforms, datasets

from ofa.utils import AverageMeter, accuracy
from ofa.model_zoo import ofa_specialized
from ofa.imagenet_classification.elastic_nn.utils import set_running_statistics


def evaluate_ofa_subnet(ofa_net, path, net_config, data_loader, batch_size, device='cuda:0'):
    assert 'ks' in net_config and 'd' in net_config and 'e' in net_config
    assert len(net_config['ks']) == 20 and len(net_config['e']) == 20 and len(net_config['d']) == 5
    ofa_net.set_active_subnet(ks=net_config['ks'], d=net_config['d'], e=net_config['e'])
    subnet = ofa_net.get_active_subnet().to(device)
    calib_bn(subnet, path, net_config['r'][0], batch_size)
    top1 = validate(subnet, path, net_config['r'][0], data_loader, batch_size, device)
    return top1


def calib_bn(net, path, image_size, batch_size, num_images=2000):
    # print('Creating dataloader for resetting BN running statistics...')
    dataset = datasets.ImageFolder(
        osp.join(path, 'train'),
        transforms.Compose([
            transforms.RandomResizedCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=32. / 255., saturation=0.5),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])
    )
    chosen_indexes = np.random.choice(list(range(len(dataset))), num_images)
    sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sub_sampler,
        batch_size=batch_size,
        num_workers=16,
        pin_memory=True,
        drop_last=False,
    )
    # print('Resetting BN running statistics (this may take 10-20 seconds)...')
    set_running_statistics(net, data_loader)


def validate(net, path, image_size, data_loader, batch_size=100, device='cuda:0'):
    if 'cuda' in device:
        net = torch.nn.DataParallel(net).to(device)
    else:
        net = net.to(device)

    data_loader.dataset.transform = transforms.Compose([
        transforms.Resize(int(math.ceil(image_size / 0.875))),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ])

    cudnn.benchmark = True
    criterion = nn.CrossEntropyLoss().to(device)

    net.eval()
    net = net.to(device)
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        with tqdm(total=len(data_loader), desc='Validate') as t:
            for i, (images, labels) in enumerate(data_loader):
                images, labels = images.to(device), labels.to(device)
                # compute output
                output = net(images)
                loss = criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))

                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0].item(), images.size(0))
                top5.update(acc5[0].item(), images.size(0))
                t.set_postfix({
                    'loss': losses.avg,
                    'top1': top1.avg,
                    'top5': top5.avg,
                    'img_size': images.size(2),
                })
                t.update(1)

    print('Results: loss=%.5f,\t top1=%.1f,\t top5=%.1f' % (losses.avg, top1.avg, top5.avg))
    return top1.avg


def evaluate_ofa_specialized(path, data_loader, batch_size=100, device='cuda:0'):
    def select_platform_name():
        valid_platform_name = [
            'pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti', 'v100', 'tx2', 'cpu', 'flops'
        ]

        print("Please select a hardware platform from ('pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti', 'v100', 'tx2', 'cpu', 'flops')!\n")

        while True:
            platform_name = input()
            platform_name = platform_name.lower()
            if platform_name in valid_platform_name:
                return platform_name
            print("Platform name is invalid! Please select in ('pixel1', 'pixel2', 'note10', 'note8', 's7edge', 'lg-g8', '1080ti', 'v100', 'tx2', 'cpu', 'flops')!\n")

    def select_netid(platform_name):
        platform_efficiency_map = {
            'pixel1': {
                143: 'pixel1_lat@143ms_top1@80.1_finetune@75',
                132: 'pixel1_lat@132ms_top1@79.8_finetune@75',
                79: 'pixel1_lat@79ms_top1@78.7_finetune@75',
                58: 'pixel1_lat@58ms_top1@76.9_finetune@75',
                40: 'pixel1_lat@40ms_top1@74.9_finetune@25',
                28: 'pixel1_lat@28ms_top1@73.3_finetune@25',
                20: 'pixel1_lat@20ms_top1@71.4_finetune@25',
            },

            'pixel2': {
                62: 'pixel2_lat@62ms_top1@75.8_finetune@25',
                50: 'pixel2_lat@50ms_top1@74.7_finetune@25',
                35: 'pixel2_lat@35ms_top1@73.4_finetune@25',
                25: 'pixel2_lat@25ms_top1@71.5_finetune@25',
            },

            'note10': {
                64: 'note10_lat@64ms_top1@80.2_finetune@75',
                50: 'note10_lat@50ms_top1@79.7_finetune@75',
                41: 'note10_lat@41ms_top1@79.3_finetune@75',
                30: 'note10_lat@30ms_top1@78.4_finetune@75',
                22: 'note10_lat@22ms_top1@76.6_finetune@25',
                16: 'note10_lat@16ms_top1@75.5_finetune@25',
                11: 'note10_lat@11ms_top1@73.6_finetune@25',
                8: 'note10_lat@8ms_top1@71.4_finetune@25',
            },

            'note8': {
                65: 'note8_lat@65ms_top1@76.1_finetune@25',
                49: 'note8_lat@49ms_top1@74.9_finetune@25',
                31: 'note8_lat@31ms_top1@72.8_finetune@25',
                22: 'note8_lat@22ms_top1@70.4_finetune@25',
            },

            's7edge': {
                88: 's7edge_lat@88ms_top1@76.3_finetune@25',
                58: 's7edge_lat@58ms_top1@74.7_finetune@25',
                41: 's7edge_lat@41ms_top1@73.1_finetune@25',
                29: 's7edge_lat@29ms_top1@70.5_finetune@25',
            },

            'lg-g8': {
                24: 'LG-G8_lat@24ms_top1@76.4_finetune@25',
                16: 'LG-G8_lat@16ms_top1@74.7_finetune@25',
                11: 'LG-G8_lat@11ms_top1@73.0_finetune@25',
                8: 'LG-G8_lat@8ms_top1@71.1_finetune@25',
            },

            '1080ti': {
                27: '1080ti_gpu64@27ms_top1@76.4_finetune@25',
                22: '1080ti_gpu64@22ms_top1@75.3_finetune@25',
                15: '1080ti_gpu64@15ms_top1@73.8_finetune@25',
                12: '1080ti_gpu64@12ms_top1@72.6_finetune@25',
            },

            'v100': {
                11: 'v100_gpu64@11ms_top1@76.1_finetune@25',
                9: 'v100_gpu64@9ms_top1@75.3_finetune@25',
                6: 'v100_gpu64@6ms_top1@73.0_finetune@25',
                5: 'v100_gpu64@5ms_top1@71.6_finetune@25',
            },

            'tx2': {
                96: 'tx2_gpu16@96ms_top1@75.8_finetune@25',
                80: 'tx2_gpu16@80ms_top1@75.4_finetune@25',
                47: 'tx2_gpu16@47ms_top1@72.9_finetune@25',
                35: 'tx2_gpu16@35ms_top1@70.3_finetune@25',
            },

            'cpu': {
                17: 'cpu_lat@17ms_top1@75.7_finetune@25',
                15: 'cpu_lat@15ms_top1@74.6_finetune@25',
                11: 'cpu_lat@11ms_top1@72.0_finetune@25',
                10: 'cpu_lat@10ms_top1@71.1_finetune@25',
            },

            'flops': {
                595: 'flops@595M_top1@80.0_finetune@75',
                482: 'flops@482M_top1@79.6_finetune@75',
                389: 'flops@389M_top1@79.1_finetune@75',
            }
        }

        sub_efficiency_map = platform_efficiency_map[platform_name]
        if platform_name != 'flops':
            print("Now, please specify a latency constraint for model specialization among", sorted(list(sub_efficiency_map.keys())), 'ms. (Please just input the number.) \n')
        else:
            print("Now, please specify a FLOPs constraint for model specialization among", sorted(list(sub_efficiency_map.keys())), 'MFLOPs. (Please just input the number.) \n')

        while True:
            efficiency_constraint = input()
            if not efficiency_constraint.isdigit():
                print('Sorry, please input an integer! \n')
                continue
            efficiency_constraint = int(efficiency_constraint)
            if efficiency_constraint not in sub_efficiency_map:
                print('Sorry, please choose a value from: ', sorted(list(sub_efficiency_map.keys())), '.\n')
                continue
            return sub_efficiency_map[efficiency_constraint]

    platform_name = select_platform_name()
    net_id = select_netid(platform_name)

    net, image_size = ofa_specialized(net_id=net_id, pretrained=True)

    validate(net, path, image_size, data_loader, batch_size, device)

    return net_id
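# Usage sketch for evaluate_ofa_subnet (hedged: `ofa_net`, `imagenet_path`,
# and `data_loader` are assumed to come from the OFA model zoo and an
# ImageNet loader built elsewhere; only the call signature is taken from the
# code above):
#
#     net_config = {'ks': [3] * 20, 'e': [4] * 20, 'd': [2, 3, 3, 4, 4], 'r': [224]}
#     top1 = evaluate_ofa_subnet(ofa_net, imagenet_path, net_config,
#                                data_loader, batch_size=100, device='cuda:0')
#
# The helper activates the subnet, recalibrates its BN statistics on 2000
# training images, then reports top-1 accuracy on the validation loader.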
@@ -0,0 +1,164 @@
import yaml
from ofa.utils import download_url


class LatencyEstimator(object):

    def __init__(self, local_dir='~/.hancai/latency_tools/',
                 url='https://hanlab.mit.edu/files/proxylessNAS/LatencyTools/mobile_trim.yaml'):
        if url.startswith('http'):
            fname = download_url(url, local_dir, overwrite=True)
        else:
            fname = url

        with open(fname, 'r') as fp:
            self.lut = yaml.safe_load(fp)

    @staticmethod
    def repr_shape(shape):
        if isinstance(shape, (list, tuple)):
            return 'x'.join(str(_) for _ in shape)
        elif isinstance(shape, str):
            return shape
        else:
            raise TypeError('invalid shape: %s' % str(shape))

    def query(self, l_type: str, input_shape, output_shape, mid=None, ks=None, stride=None, id_skip=None,
              se=None, h_swish=None):
        infos = [l_type, 'input:%s' % self.repr_shape(input_shape), 'output:%s' % self.repr_shape(output_shape), ]

        if l_type in ('expanded_conv',):
            assert None not in (mid, ks, stride, id_skip, se, h_swish)
            infos += ['expand:%d' % mid, 'kernel:%d' % ks, 'stride:%d' % stride, 'idskip:%d' % id_skip,
                      'se:%d' % se, 'hs:%d' % h_swish]
        key = '-'.join(infos)
        return self.lut[key]['mean']

    def predict_network_latency(self, net, image_size=224):
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, net.first_conv.out_channels]
        )
        # blocks
        fsize = (image_size + 1) // 2
        for block in net.blocks:
            mb_conv = block.mobile_inverted_conv
            shortcut = block.shortcut

            if mb_conv is None:
                continue
            if shortcut is None:
                idskip = 0
            else:
                idskip = 1
            out_fz = int((fsize - 1) / mb_conv.stride + 1)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, mb_conv.in_channels], [out_fz, out_fz, mb_conv.out_channels],
                mid=mb_conv.depth_conv.conv.in_channels, ks=mb_conv.kernel_size, stride=mb_conv.stride, id_skip=idskip,
                se=1 if mb_conv.use_se else 0, h_swish=1 if mb_conv.act_func == 'h_swish' else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, net.final_expand_layer.in_channels],
            [fsize, fsize, net.final_expand_layer.out_channels],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, net.final_expand_layer.out_channels],
            [1, 1, net.final_expand_layer.out_channels],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, net.feature_mix_layer.in_channels],
            [1, 1, net.feature_mix_layer.out_channels]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, net.classifier.in_features], [net.classifier.out_features]
        )
        return predicted_latency

    def predict_network_latency_given_spec(self, spec):
        image_size = spec['r'][0]
        predicted_latency = 0
        # first conv
        predicted_latency += self.query(
            'Conv', [image_size, image_size, 3],
            [(image_size + 1) // 2, (image_size + 1) // 2, 24]
        )
        # blocks
        fsize = (image_size + 1) // 2
        # first block
        predicted_latency += self.query(
            'expanded_conv', [fsize, fsize, 24], [fsize, fsize, 24],
            mid=24, ks=3, stride=1, id_skip=1, se=0, h_swish=0,
        )
        in_channel = 24
        stride_stages = [2, 2, 2, 1, 2]
        width_stages = [32, 48, 96, 136, 192]
        act_stages = ['relu', 'relu', 'h_swish', 'h_swish', 'h_swish']
        se_stages = [False, True, False, True, True]
        for i in range(20):
            stage = i // 4
            depth_max = spec['d'][stage]
            depth = i % 4 + 1
            if depth > depth_max:
                continue
            ks, e = spec['ks'][i], spec['e'][i]
            if i % 4 == 0:
                stride = stride_stages[stage]
                idskip = 0
            else:
                stride = 1
                idskip = 1
            out_channel = width_stages[stage]
            out_fz = int((fsize - 1) / stride + 1)

            mid_channel = round(in_channel * e)
            block_latency = self.query(
                'expanded_conv', [fsize, fsize, in_channel], [out_fz, out_fz, out_channel],
                mid=mid_channel, ks=ks, stride=stride, id_skip=idskip,
                se=1 if se_stages[stage] else 0, h_swish=1 if act_stages[stage] == 'h_swish' else 0,
            )
            predicted_latency += block_latency
            fsize = out_fz
            in_channel = out_channel
        # final expand layer
        predicted_latency += self.query(
            'Conv_1', [fsize, fsize, 192],
            [fsize, fsize, 1152],
        )
        # global average pooling
        predicted_latency += self.query(
            'AvgPool2D', [fsize, fsize, 1152],
            [1, 1, 1152],
        )
        # feature mix layer
        predicted_latency += self.query(
            'Conv_2', [1, 1, 1152],
            [1, 1, 1536]
        )
        # classifier
        predicted_latency += self.query(
            'Logits', [1, 1, 1536], [1000]
        )
        return predicted_latency


class LatencyTable:
    def __init__(self, device='note10', resolutions=(160, 176, 192, 208, 224)):
        self.latency_tables = {}

        for image_size in resolutions:
            self.latency_tables[image_size] = LatencyEstimator(
                url='https://hanlab.mit.edu/files/OnceForAll/tutorial/latency_table@%s/%d_lookup_table.yaml' % (
                    device, image_size)
            )
            print('Built latency table for image size: %d.' % image_size)

    def predict_efficiency(self, spec: dict):
        return self.latency_tables[spec['r'][0]].predict_network_latency_given_spec(spec)
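# Usage sketch for LatencyTable: one LatencyEstimator per resolution is
# downloaded, and predict_efficiency dispatches on spec['r'][0]. The spec
# format matches predict_network_latency_given_spec above:
#
#     lut = LatencyTable(device='note10')
#     spec = {'ks': [3] * 20, 'e': [4] * 20, 'd': [2, 3, 3, 4, 4], 'r': [192]}
#     latency_ms = lut.predict_efficiency(spec)   # summed per-layer LUT entries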
@@ -0,0 +1,10 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

from .pytorch_modules import *
from .pytorch_utils import *
from .my_modules import *
from .flops_counter import *
from .common_tools import *
from .my_dataloader import *
@@ -0,0 +1,284 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import numpy as np
import os
import sys
import torch

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve

__all__ = [
    'sort_dict', 'get_same_padding',
    'get_split_list', 'list_sum', 'list_mean', 'list_join',
    'subset_mean', 'sub_filter_start_end', 'min_divisible_value', 'val2list',
    'download_url',
    'write_log', 'pairwise_accuracy', 'accuracy',
    'AverageMeter', 'MultiClassAverageMeter',
    'DistributedMetric', 'DistributedTensor',
]


def sort_dict(src_dict, reverse=False, return_dict=True):
    output = sorted(src_dict.items(), key=lambda x: x[1], reverse=reverse)
    if return_dict:
        return dict(output)
    else:
        return output


def get_same_padding(kernel_size):
    if isinstance(kernel_size, tuple):
        assert len(kernel_size) == 2, 'invalid kernel size: %s' % kernel_size
        p1 = get_same_padding(kernel_size[0])
        p2 = get_same_padding(kernel_size[1])
        return p1, p2
    assert isinstance(kernel_size, int), 'kernel size should be either `int` or `tuple`'
    assert kernel_size % 2 > 0, 'kernel size should be odd number'
    return kernel_size // 2


def get_split_list(in_dim, child_num, accumulate=False):
    in_dim_list = [in_dim // child_num] * child_num
    for _i in range(in_dim % child_num):
        in_dim_list[_i] += 1
    if accumulate:
        for i in range(1, child_num):
            in_dim_list[i] += in_dim_list[i - 1]
    return in_dim_list


def list_sum(x):
    return x[0] if len(x) == 1 else x[0] + list_sum(x[1:])


def list_mean(x):
    return list_sum(x) / len(x)


def list_join(val_list, sep='\t'):
    return sep.join([str(val) for val in val_list])


def subset_mean(val_list, sub_indexes):
    sub_indexes = val2list(sub_indexes, 1)
    return list_mean([val_list[idx] for idx in sub_indexes])


def sub_filter_start_end(kernel_size, sub_kernel_size):
    center = kernel_size // 2
    dev = sub_kernel_size // 2
    start, end = center - dev, center + dev + 1
    assert end - start == sub_kernel_size
    return start, end
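# Worked example for the two helpers above: a 7x7 kernel needs padding 3 to
# keep "same" spatial size, and the centered 3x3 sub-filter of a 7x7 kernel
# spans indices [2, 5):
#
#     get_same_padding(7)           # -> 3
#     sub_filter_start_end(7, 3)    # center=3, dev=1 -> (2, 5)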
def min_divisible_value(n1, v1):
    """ make sure v1 is divisible by n1, otherwise decrease v1 """
    if v1 >= n1:
        return n1
    while n1 % v1 != 0:
        v1 -= 1
    return v1


def val2list(val, repeat_time=1):
    if isinstance(val, list) or isinstance(val, np.ndarray):
        return val
    elif isinstance(val, tuple):
        return list(val)
    else:
        return [val for _ in range(repeat_time)]


def download_url(url, model_dir='~/.torch/', overwrite=False):
    target_dir = url.split('/')[-1]
    model_dir = os.path.expanduser(model_dir)
    try:
        if not os.path.exists(model_dir):
            os.makedirs(model_dir)
        model_dir = os.path.join(model_dir, target_dir)
        cached_file = model_dir
        if not os.path.exists(cached_file) or overwrite:
            sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
            urlretrieve(url, cached_file)
        return cached_file
    except Exception as e:
        # remove lock file so download can be executed next time.
        os.remove(os.path.join(model_dir, 'download.lock'))
        sys.stderr.write('Failed to download from url %s' % url + '\n' + str(e) + '\n')
        return None


def write_log(logs_path, log_str, prefix='valid', should_print=True, mode='a'):
    if not os.path.exists(logs_path):
        os.makedirs(logs_path, exist_ok=True)
    """ prefix: valid, train, test """
    if prefix in ['valid', 'test']:
        with open(os.path.join(logs_path, 'valid_console.txt'), mode) as fout:
            fout.write(log_str + '\n')
            fout.flush()
    if prefix in ['valid', 'test', 'train']:
        with open(os.path.join(logs_path, 'train_console.txt'), mode) as fout:
            if prefix in ['valid', 'test']:
                fout.write('=' * 10)
            fout.write(log_str + '\n')
            fout.flush()
    else:
        with open(os.path.join(logs_path, '%s.txt' % prefix), mode) as fout:
            fout.write(log_str + '\n')
            fout.flush()
    if should_print:
        print(log_str)


def pairwise_accuracy(la, lb, n_samples=200000):
    n = len(la)
    assert n == len(lb)
    total = 0
    count = 0
    for _ in range(n_samples):
        i = np.random.randint(n)
        j = np.random.randint(n)
        while i == j:
            j = np.random.randint(n)
        if la[i] >= la[j] and lb[i] >= lb[j]:
            count += 1
        if la[i] < la[j] and lb[i] < lb[j]:
            count += 1
        total += 1
    return float(count) / total


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
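# Minimal sketch of the two metrics above. accuracy() expects logits of shape
# (batch, classes); pairwise_accuracy() measures how often two score lists
# order a random pair of items the same way (a rank-agreement estimate used
# when comparing accuracy predictors against ground truth):
#
#     logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
#     labels = torch.tensor([1, 0])
#     (top1,) = accuracy(logits, labels, topk=(1,))   # -> tensor([100.])
#
#     pairwise_accuracy([1, 2, 3], [10, 20, 30])      # -> 1.0 (same ordering)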
class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class MultiClassAverageMeter:
    """ Multi Binary Classification Tasks """

    def __init__(self, num_classes, balanced=False, **kwargs):
        super(MultiClassAverageMeter, self).__init__()
        self.num_classes = num_classes
        self.balanced = balanced

        self.counts = []
        for k in range(self.num_classes):
            self.counts.append(np.ndarray((2, 2), dtype=np.float32))

        self.reset()

    def reset(self):
        for k in range(self.num_classes):
            self.counts[k].fill(0)

    def add(self, outputs, targets):
        outputs = outputs.data.cpu().numpy()
        targets = targets.data.cpu().numpy()

        for k in range(self.num_classes):
            output = np.argmax(outputs[:, k, :], axis=1)
            target = targets[:, k]

            x = output + 2 * target
            bincount = np.bincount(x.astype(np.int32), minlength=2 ** 2)

            self.counts[k] += bincount.reshape((2, 2))

    def value(self):
        mean = 0
        for k in range(self.num_classes):
            if self.balanced:
                value = np.mean((self.counts[k] / np.maximum(np.sum(self.counts[k], axis=1), 1)[:, None]).diagonal())
            else:
                value = np.sum(self.counts[k].diagonal()) / np.maximum(np.sum(self.counts[k]), 1)

            mean += value / self.num_classes * 100.
        return mean


class DistributedMetric(object):
    """
    Horovod: average metrics from distributed training.
    """

    def __init__(self, name):
        self.name = name
        self.sum = torch.zeros(1)[0]
        self.count = torch.zeros(1)[0]

    def update(self, val, delta_n=1):
        import horovod.torch as hvd
        val *= delta_n
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.count += delta_n

    @property
    def avg(self):
        return self.sum / self.count


class DistributedTensor(object):

    def __init__(self, name):
        self.name = name
        self.sum = None
        self.count = torch.zeros(1)[0]
        self.synced = False

    def update(self, val, delta_n=1):
        val *= delta_n
        if self.sum is None:
            self.sum = val.detach()
        else:
            self.sum += val.detach()
        self.count += delta_n

    @property
    def avg(self):
        import horovod.torch as hvd
        if not self.synced:
            self.sum = hvd.allreduce(self.sum, name=self.name)
            self.synced = True
        return self.sum / self.count
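# Sketch of the confusion-count encoding used by MultiClassAverageMeter.add:
# for each binary task, x = prediction + 2 * target maps the four
# (target, prediction) outcomes to 0..3, so bincount(x).reshape(2, 2) is a
# confusion matrix with rows indexed by target and columns by prediction:
#
#     x = 0 -> target 0, predicted 0 (true negative)
#     x = 1 -> target 0, predicted 1 (false positive)
#     x = 2 -> target 1, predicted 0 (false negative)
#     x = 3 -> target 1, predicted 1 (true positive)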
@@ -0,0 +1,97 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch
import torch.nn as nn

from .my_modules import MyConv2d

__all__ = ['profile']


def count_convNd(m, _, y):
    cin = m.in_channels

    kernel_ops = m.weight.size()[2] * m.weight.size()[3]
    ops_per_element = kernel_ops
    output_elements = y.nelement()

    # cout x oW x oH
    total_ops = cin * output_elements * ops_per_element // m.groups
    m.total_ops = torch.zeros(1).fill_(total_ops)


def count_linear(m, _, __):
    total_ops = m.in_features * m.out_features

    m.total_ops = torch.zeros(1).fill_(total_ops)


register_hooks = {
    nn.Conv1d: count_convNd,
    nn.Conv2d: count_convNd,
    nn.Conv3d: count_convNd,
    MyConv2d: count_convNd,
    ######################################
    nn.Linear: count_linear,
    ######################################
    nn.Dropout: None,
    nn.Dropout2d: None,
    nn.Dropout3d: None,
    nn.BatchNorm2d: None,
}


def profile(model, input_size, custom_ops=None):
    handler_collection = []
    custom_ops = {} if custom_ops is None else custom_ops

    def add_hooks(m_):
        if len(list(m_.children())) > 0:
            return

        m_.register_buffer('total_ops', torch.zeros(1))
        m_.register_buffer('total_params', torch.zeros(1))

        for p in m_.parameters():
            m_.total_params += torch.zeros(1).fill_(p.numel())

        m_type = type(m_)
        fn = None

        if m_type in custom_ops:
            fn = custom_ops[m_type]
        elif m_type in register_hooks:
            fn = register_hooks[m_type]

        if fn is not None:
            _handler = m_.register_forward_hook(fn)
            handler_collection.append(_handler)

    original_device = next(model.parameters()).device
    training = model.training

    model.eval()
    model.apply(add_hooks)

    x = torch.zeros(input_size).to(original_device)
    with torch.no_grad():
        model(x)

    total_ops = 0
    total_params = 0
    for m in model.modules():
        if len(list(m.children())) > 0:  # skip for non-leaf module
            continue
        total_ops += m.total_ops
        total_params += m.total_params

    total_ops = total_ops.item()
    total_params = total_params.item()

    model.train(training).to(original_device)
    for handler in handler_collection:
        handler.remove()

    return total_ops, total_params
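# Usage sketch for profile(): hooks count multiply-accumulates for conv and
# linear layers only (BN and dropout are registered as None above), so the
# result is a MAC count rather than a full FLOP count:
#
#     model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Flatten(),
#                           nn.Linear(8 * 32 * 32, 10))
#     total_ops, total_params = profile(model, (1, 3, 32, 32))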
@@ -0,0 +1,727 @@
######################################################################################
# Copyright (c) Han Cai, Once for All, ICLR 2020 [GitHub OFA]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli
from collections import OrderedDict
from ofa_local.utils import get_same_padding, min_divisible_value, SEModule, ShuffleLayer
from ofa_local.utils import MyNetwork, MyModule
from ofa_local.utils import build_activation, make_divisible

__all__ = [
    'set_layer_from_config',
    'ConvLayer', 'IdentityLayer', 'LinearLayer', 'MultiHeadLinearLayer', 'ZeroLayer', 'MBConvLayer',
    'ResidualBlock', 'ResNetBottleneckBlock',
]


class DropBlock(nn.Module):
    def __init__(self, block_size):
        super(DropBlock, self).__init__()

        self.block_size = block_size

    def forward(self, x, gamma):
        # shape: (bsize, channels, height, width)

        if self.training:
            batch_size, channels, height, width = x.shape

            bernoulli = Bernoulli(gamma)
            mask = bernoulli.sample(
                (batch_size, channels, height - (self.block_size - 1), width - (self.block_size - 1))).cuda()
            block_mask = self._compute_block_mask(mask)
            countM = block_mask.size()[0] * block_mask.size()[1] * block_mask.size()[2] * block_mask.size()[3]
            count_ones = block_mask.sum()

            return block_mask * x * (countM / count_ones)
        else:
            return x

    def _compute_block_mask(self, mask):
        left_padding = int((self.block_size - 1) / 2)
        right_padding = int(self.block_size / 2)

        batch_size, channels, height, width = mask.shape
        non_zero_idxs = mask.nonzero()
        nr_blocks = non_zero_idxs.shape[0]

        offsets = torch.stack(
            [
                torch.arange(self.block_size).view(-1, 1).expand(self.block_size, self.block_size).reshape(-1),
                # - left_padding,
                torch.arange(self.block_size).repeat(self.block_size),  # - left_padding
            ]
        ).t().cuda()
        offsets = torch.cat((torch.zeros(self.block_size ** 2, 2).cuda().long(), offsets.long()), 1)

        if nr_blocks > 0:
            non_zero_idxs = non_zero_idxs.repeat(self.block_size ** 2, 1)
            offsets = offsets.repeat(nr_blocks, 1).view(-1, 4)
            offsets = offsets.long()

            block_idxs = non_zero_idxs + offsets
            # block_idxs += left_padding
            padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))
            padded_mask[block_idxs[:, 0], block_idxs[:, 1], block_idxs[:, 2], block_idxs[:, 3]] = 1.
        else:
            padded_mask = F.pad(mask, (left_padding, right_padding, left_padding, right_padding))

        block_mask = 1 - padded_mask  # [:height, :width]
        return block_mask
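# Worked example for DropBlock: gamma is chosen so the expected fraction of
# dropped activations matches a target drop rate. With the formula used in
# ResidualBlock.forward below (keep rate 0.9, block size 3, 8x8 feature map):
#
#     gamma = (1 - 0.9) / 3 ** 2 * 8 ** 2 / (8 - 3 + 1) ** 2   # ~ 0.0198
#     block = DropBlock(block_size=3)
#     out = block(torch.randn(2, 4, 8, 8).cuda(), gamma)
#
# Note the mask is created with .cuda(), so inputs must live on the GPU, and
# blocks are only dropped in training mode.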
def set_layer_from_config(layer_config):
    if layer_config is None:
        return None

    name2layer = {
        ConvLayer.__name__: ConvLayer,
        IdentityLayer.__name__: IdentityLayer,
        LinearLayer.__name__: LinearLayer,
        MultiHeadLinearLayer.__name__: MultiHeadLinearLayer,
        ZeroLayer.__name__: ZeroLayer,
        MBConvLayer.__name__: MBConvLayer,
        'MBInvertedConvLayer': MBConvLayer,
        ##########################################################
        ResidualBlock.__name__: ResidualBlock,
        ResNetBottleneckBlock.__name__: ResNetBottleneckBlock,
    }

    layer_name = layer_config.pop('name')
    layer = name2layer[layer_name]
    return layer.build_from_config(layer_config)
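# Usage sketch for set_layer_from_config: the dict mirrors a layer's `config`
# property, with 'name' selecting the class. For example, the ConvLayer
# defined below round-trips through its own config:
#
#     conv = ConvLayer(16, 32, kernel_size=3, stride=2, act_func='relu')
#     rebuilt = set_layer_from_config(dict(conv.config))   # same architecture
#
# set_layer_from_config pops 'name' from the dict it is given, so pass a copy
# if the config is reused.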
class My2DLayer(MyModule):

    def __init__(self, in_channels, out_channels,
                 use_bn=True, act_func='relu', dropout_rate=0, ops_order='weight_bn_act'):
        super(My2DLayer, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.use_bn = use_bn
        self.act_func = act_func
        self.dropout_rate = dropout_rate
        self.ops_order = ops_order

        """ modules """
        modules = {}
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                modules['bn'] = nn.BatchNorm2d(in_channels)
            else:
                modules['bn'] = nn.BatchNorm2d(out_channels)
        else:
            modules['bn'] = None
        # activation
        modules['act'] = build_activation(self.act_func, self.ops_list[0] != 'act' and self.use_bn)
        # dropout
        if self.dropout_rate > 0:
            modules['dropout'] = nn.Dropout2d(self.dropout_rate, inplace=True)
        else:
            modules['dropout'] = None
        # weight
        modules['weight'] = self.weight_op()

        # add modules
        for op in self.ops_list:
            if modules[op] is None:
                continue
            elif op == 'weight':
                # dropout before weight operation
                if modules['dropout'] is not None:
                    self.add_module('dropout', modules['dropout'])
                for key in modules['weight']:
                    self.add_module(key, modules['weight'][key])
            else:
                self.add_module(op, modules[op])

    @property
    def ops_list(self):
        return self.ops_order.split('_')

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == 'bn':
                return True
            elif op == 'weight':
                return False
        raise ValueError('Invalid ops_order: %s' % self.ops_order)

    def weight_op(self):
        raise NotImplementedError

    """ Methods defined in MyModule """

    def forward(self, x):
        # similar to nn.Sequential
        for module in self._modules.values():
            x = module(x)
        return x

    @property
    def module_str(self):
        raise NotImplementedError

    @property
    def config(self):
        return {
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'use_bn': self.use_bn,
            'act_func': self.act_func,
            'dropout_rate': self.dropout_rate,
            'ops_order': self.ops_order,
        }

    @staticmethod
    def build_from_config(config):
        raise NotImplementedError


class ConvLayer(My2DLayer):

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, dilation=1, groups=1, bias=False, has_shuffle=False, use_se=False,
                 use_bn=True, act_func='relu', dropout_rate=0, ops_order='weight_bn_act'):
        # default normal 3x3_Conv with bn and relu
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.has_shuffle = has_shuffle
        self.use_se = use_se

        super(ConvLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)
        if self.use_se:
            self.add_module('se', SEModule(self.out_channels))

    def weight_op(self):
        padding = get_same_padding(self.kernel_size)
        if isinstance(padding, int):
            padding *= self.dilation
        else:
            # tuple paddings are immutable, so build a new pair
            padding = [padding[0] * self.dilation, padding[1] * self.dilation]

        weight_dict = OrderedDict({
            'conv': nn.Conv2d(
                self.in_channels, self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=padding,
                dilation=self.dilation, groups=min_divisible_value(self.in_channels, self.groups), bias=self.bias
            )
        })
        if self.has_shuffle and self.groups > 1:
            weight_dict['shuffle'] = ShuffleLayer(self.groups)

        return weight_dict

    @property
    def module_str(self):
        if isinstance(self.kernel_size, int):
            kernel_size = (self.kernel_size, self.kernel_size)
        else:
            kernel_size = self.kernel_size
        if self.groups == 1:
            if self.dilation > 1:
                conv_str = '%dx%d_DilatedConv' % (kernel_size[0], kernel_size[1])
            else:
                conv_str = '%dx%d_Conv' % (kernel_size[0], kernel_size[1])
        else:
            if self.dilation > 1:
                conv_str = '%dx%d_DilatedGroupConv' % (kernel_size[0], kernel_size[1])
            else:
                conv_str = '%dx%d_GroupConv' % (kernel_size[0], kernel_size[1])
        conv_str += '_O%d' % self.out_channels
        if self.use_se:
            conv_str = 'SE_' + conv_str
        conv_str += '_' + self.act_func.upper()
        if self.use_bn:
            if isinstance(self.bn, nn.GroupNorm):
                conv_str += '_GN%d' % self.bn.num_groups
            elif isinstance(self.bn, nn.BatchNorm2d):
                conv_str += '_BN'
        return conv_str

    @property
    def config(self):
        return {
            'name': ConvLayer.__name__,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'dilation': self.dilation,
            'groups': self.groups,
            'bias': self.bias,
            'has_shuffle': self.has_shuffle,
            'use_se': self.use_se,
            **super(ConvLayer, self).config
        }

    @staticmethod
    def build_from_config(config):
        return ConvLayer(**config)


class IdentityLayer(My2DLayer):

    def __init__(self, in_channels, out_channels,
                 use_bn=False, act_func=None, dropout_rate=0, ops_order='weight_bn_act'):
        super(IdentityLayer, self).__init__(in_channels, out_channels, use_bn, act_func, dropout_rate, ops_order)

    def weight_op(self):
        return None

    @property
    def module_str(self):
        return 'Identity'

    @property
    def config(self):
        return {
            'name': IdentityLayer.__name__,
            **super(IdentityLayer, self).config,
        }

    @staticmethod
    def build_from_config(config):
        return IdentityLayer(**config)


class LinearLayer(MyModule):

    def __init__(self, in_features, out_features, bias=True,
                 use_bn=False, act_func=None, dropout_rate=0, ops_order='weight_bn_act'):
        super(LinearLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias

        self.use_bn = use_bn
        self.act_func = act_func
        self.dropout_rate = dropout_rate
        self.ops_order = ops_order

        """ modules """
        modules = {}
        # batch norm
        if self.use_bn:
            if self.bn_before_weight:
                modules['bn'] = nn.BatchNorm1d(in_features)
            else:
                modules['bn'] = nn.BatchNorm1d(out_features)
        else:
            modules['bn'] = None
        # activation
        modules['act'] = build_activation(self.act_func, self.ops_list[0] != 'act')
        # dropout
        if self.dropout_rate > 0:
            modules['dropout'] = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            modules['dropout'] = None
        # linear
        modules['weight'] = {'linear': nn.Linear(self.in_features, self.out_features, self.bias)}

        # add modules
        for op in self.ops_list:
            if modules[op] is None:
                continue
            elif op == 'weight':
                if modules['dropout'] is not None:
                    self.add_module('dropout', modules['dropout'])
                for key in modules['weight']:
                    self.add_module(key, modules['weight'][key])
            else:
                self.add_module(op, modules[op])

    @property
    def ops_list(self):
        return self.ops_order.split('_')

    @property
    def bn_before_weight(self):
        for op in self.ops_list:
            if op == 'bn':
                return True
            elif op == 'weight':
                return False
        raise ValueError('Invalid ops_order: %s' % self.ops_order)

    def forward(self, x):
        for module in self._modules.values():
            x = module(x)
        return x

    @property
    def module_str(self):
        return '%dx%d_Linear' % (self.in_features, self.out_features)

    @property
    def config(self):
        return {
            'name': LinearLayer.__name__,
            'in_features': self.in_features,
            'out_features': self.out_features,
            'bias': self.bias,
            'use_bn': self.use_bn,
            'act_func': self.act_func,
            'dropout_rate': self.dropout_rate,
            'ops_order': self.ops_order,
        }

    @staticmethod
    def build_from_config(config):
        return LinearLayer(**config)


class MultiHeadLinearLayer(MyModule):

    def __init__(self, in_features, out_features, num_heads=1, bias=True, dropout_rate=0):
        super(MultiHeadLinearLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_heads = num_heads

        self.bias = bias
        self.dropout_rate = dropout_rate

        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=True)
        else:
            self.dropout = None

        self.layers = nn.ModuleList()
        for k in range(num_heads):
            layer = nn.Linear(in_features, out_features, self.bias)
            self.layers.append(layer)

    def forward(self, inputs):
        if self.dropout is not None:
            inputs = self.dropout(inputs)

        outputs = []
        for layer in self.layers:
            output = layer.forward(inputs)
            outputs.append(output)

        outputs = torch.stack(outputs, dim=1)
        return outputs

    @property
    def module_str(self):
        return self.__repr__()

    @property
    def config(self):
        return {
            'name': MultiHeadLinearLayer.__name__,
            'in_features': self.in_features,
            'out_features': self.out_features,
            'num_heads': self.num_heads,
            'bias': self.bias,
            'dropout_rate': self.dropout_rate,
        }

    @staticmethod
    def build_from_config(config):
        return MultiHeadLinearLayer(**config)

    def __repr__(self):
        return 'MultiHeadLinear(in_features=%d, out_features=%d, num_heads=%d, bias=%s, dropout_rate=%s)' % (
            self.in_features, self.out_features, self.num_heads, self.bias, self.dropout_rate
        )


class ZeroLayer(MyModule):

    def __init__(self):
        super(ZeroLayer, self).__init__()

    def forward(self, x):
        raise ValueError

    @property
    def module_str(self):
        return 'Zero'

    @property
    def config(self):
        return {
            'name': ZeroLayer.__name__,
        }

    @staticmethod
    def build_from_config(config):
        return ZeroLayer()


class MBConvLayer(MyModule):

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, expand_ratio=6, mid_channels=None, act_func='relu6', use_se=False,
                 groups=None):
        super(MBConvLayer, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels
        self.act_func = act_func
        self.use_se = use_se
        self.groups = groups

        if self.mid_channels is None:
            feature_dim = round(self.in_channels * self.expand_ratio)
        else:
            feature_dim = self.mid_channels

        if self.expand_ratio == 1:
            self.inverted_bottleneck = None
        else:
            self.inverted_bottleneck = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(feature_dim)),
                ('act', build_activation(self.act_func, inplace=True)),
            ]))

        pad = get_same_padding(self.kernel_size)
        groups = feature_dim if self.groups is None else min_divisible_value(feature_dim, self.groups)
        depth_conv_modules = [
            ('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=groups, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim)),
            ('act', build_activation(self.act_func, inplace=True))
        ]
        if self.use_se:
            depth_conv_modules.append(('se', SEModule(feature_dim)))
        self.depth_conv = nn.Sequential(OrderedDict(depth_conv_modules))

        self.point_linear = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, out_channels, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(out_channels)),
        ]))

    def forward(self, x):
        if self.inverted_bottleneck:
            x = self.inverted_bottleneck(x)
        x = self.depth_conv(x)
        x = self.point_linear(x)
        return x

    @property
    def module_str(self):
        if self.mid_channels is None:
            expand_ratio = self.expand_ratio
        else:
            expand_ratio = self.mid_channels // self.in_channels
        layer_str = '%dx%d_MBConv%d_%s' % (self.kernel_size, self.kernel_size, expand_ratio, self.act_func.upper())
        if self.use_se:
            layer_str = 'SE_' + layer_str
        layer_str += '_O%d' % self.out_channels
        if self.groups is not None:
            layer_str += '_G%d' % self.groups
        if isinstance(self.point_linear.bn, nn.GroupNorm):
            layer_str += '_GN%d' % self.point_linear.bn.num_groups
        elif isinstance(self.point_linear.bn, nn.BatchNorm2d):
            layer_str += '_BN'

        return layer_str

    @property
    def config(self):
        return {
            'name': MBConvLayer.__name__,
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'expand_ratio': self.expand_ratio,
            'mid_channels': self.mid_channels,
            'act_func': self.act_func,
            'use_se': self.use_se,
            'groups': self.groups,
        }

    @staticmethod
    def build_from_config(config):
        return MBConvLayer(**config)
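# Usage sketch for MBConvLayer: expand_ratio widens the input channels for
# the depthwise stage, then a 1x1 projection maps to out_channels. With the
# values below the internal width is round(16 * 6) = 96:
#
#     block = MBConvLayer(16, 32, kernel_size=5, stride=2, expand_ratio=6,
#                         act_func='h_swish', use_se=True)
#     y = block(torch.randn(1, 16, 56, 56))   # -> shape (1, 32, 28, 28)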
|
||||
class ResidualBlock(MyModule):

    def __init__(self, conv, shortcut, dropout_rate=0, dropblock=False, block_size=1):
        super(ResidualBlock, self).__init__()

        self.conv = conv
        self.shortcut = shortcut
        # hayeon
        self.num_batches_tracked = 0
        self.dropout_rate = dropout_rate
        self.dropblock = dropblock
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        # hayeon
        self.num_batches_tracked += 1

        if self.conv is None or isinstance(self.conv, ZeroLayer):
            res = x
        elif self.shortcut is None or isinstance(self.shortcut, ZeroLayer):
            res = self.conv(x)
        else:
            res = self.conv(x) + self.shortcut(x)

        # hayeon
        if self.dropout_rate > 0:
            if self.dropblock:
                feat_size = res.size()[2]
                keep_rate = max(1.0 - self.dropout_rate / (20 * 2000) * self.num_batches_tracked,
                                1.0 - self.dropout_rate)
                gamma = (1 - keep_rate) / self.block_size ** 2 * feat_size ** 2 / (feat_size - self.block_size + 1) ** 2
                res = self.DropBlock(res, gamma=gamma)
            else:
                res = F.dropout(res, p=self.dropout_rate, training=self.training, inplace=True)
        return res

    @property
    def module_str(self):
        return '(%s, %s)' % (
            self.conv.module_str if self.conv is not None else None,
            self.shortcut.module_str if self.shortcut is not None else None
        )

    @property
    def config(self):
        return {
            'name': ResidualBlock.__name__,
            'conv': self.conv.config if self.conv is not None else None,
            'shortcut': self.shortcut.config if self.shortcut is not None else None,
        }

    @staticmethod
    def build_from_config(config):
        conv_config = config['conv'] if 'conv' in config else config['mobile_inverted_conv']
        conv = set_layer_from_config(conv_config)
        shortcut = set_layer_from_config(config['shortcut'])
        return ResidualBlock(conv, shortcut)

    @property
    def mobile_inverted_conv(self):
        return self.conv
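
# Illustrative sketch (not part of the original commit): the DropBlock branch
# above linearly anneals keep_rate from 1.0 down to (1 - dropout_rate) over
# 20 * 2000 tracked batches, then converts it into the per-position drop
# probability `gamma` from the DropBlock paper. The numbers below are examples.
def _demo_dropblock_gamma(dropout_rate=0.1, block_size=3, feat_size=8, num_batches_tracked=10000):
    keep_rate = max(1.0 - dropout_rate / (20 * 2000) * num_batches_tracked, 1.0 - dropout_rate)
    gamma = (1 - keep_rate) / block_size ** 2 * feat_size ** 2 / (feat_size - block_size + 1) ** 2
    # e.g., a quarter of the way through the schedule: keep_rate = 0.975, gamma ~ 0.0049
    return keep_rate, gamma
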
class ResNetBottleneckBlock(MyModule):

    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, expand_ratio=0.25, mid_channels=None, act_func='relu', groups=1,
                 downsample_mode='avgpool_conv'):
        super(ResNetBottleneckBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = kernel_size
        self.stride = stride
        self.expand_ratio = expand_ratio
        self.mid_channels = mid_channels
        self.act_func = act_func
        self.groups = groups

        self.downsample_mode = downsample_mode

        if self.mid_channels is None:
            feature_dim = round(self.out_channels * self.expand_ratio)
        else:
            feature_dim = self.mid_channels

        feature_dim = make_divisible(feature_dim, MyNetwork.CHANNEL_DIVISIBLE)
        self.mid_channels = feature_dim

        # build modules
        self.conv1 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(self.in_channels, feature_dim, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim)),
            ('act', build_activation(self.act_func, inplace=True)),
        ]))

        pad = get_same_padding(self.kernel_size)
        self.conv2 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, feature_dim, kernel_size, stride, pad, groups=groups, bias=False)),
            ('bn', nn.BatchNorm2d(feature_dim)),
            ('act', build_activation(self.act_func, inplace=True))
        ]))

        self.conv3 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(feature_dim, self.out_channels, 1, 1, 0, bias=False)),
            ('bn', nn.BatchNorm2d(self.out_channels)),
        ]))

        if stride == 1 and in_channels == out_channels:
            self.downsample = IdentityLayer(in_channels, out_channels)
        elif self.downsample_mode == 'conv':
            self.downsample = nn.Sequential(OrderedDict([
                ('conv', nn.Conv2d(in_channels, out_channels, 1, stride, 0, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels)),
            ]))
        elif self.downsample_mode == 'avgpool_conv':
            self.downsample = nn.Sequential(OrderedDict([
                ('avg_pool', nn.AvgPool2d(kernel_size=stride, stride=stride, padding=0, ceil_mode=True)),
                ('conv', nn.Conv2d(in_channels, out_channels, 1, 1, 0, bias=False)),
                ('bn', nn.BatchNorm2d(out_channels)),
            ]))
        else:
            raise NotImplementedError

        self.final_act = build_activation(self.act_func, inplace=True)

    def forward(self, x):
        residual = self.downsample(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        x = x + residual
        x = self.final_act(x)
        return x

    @property
    def module_str(self):
        return '(%s, %s)' % (
            '%dx%d_BottleneckConv_%d->%d->%d_S%d_G%d' % (
                self.kernel_size, self.kernel_size, self.in_channels, self.mid_channels, self.out_channels,
                self.stride, self.groups
            ),
            'Identity' if isinstance(self.downsample, IdentityLayer) else self.downsample_mode,
        )

    @property
    def config(self):
        return {
            'name': ResNetBottleneckBlock.__name__,
            'in_channels': self.in_channels,
            'out_channels': self.out_channels,
            'kernel_size': self.kernel_size,
            'stride': self.stride,
            'expand_ratio': self.expand_ratio,
            'mid_channels': self.mid_channels,
            'act_func': self.act_func,
            'groups': self.groups,
            'downsample_mode': self.downsample_mode,
        }

    @staticmethod
    def build_from_config(config):
        return ResNetBottleneckBlock(**config)
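
# Illustrative sketch (not part of the original commit): when stride > 1 or the
# channel counts differ, the residual path goes through the configured
# downsample branch; 'avgpool_conv' pools first so the 1x1 conv stays stride-1.
def _demo_bottleneck_downsample():
    import torch
    block = ResNetBottleneckBlock(in_channels=64, out_channels=128, stride=2,
                                  downsample_mode='avgpool_conv')
    x = torch.randn(1, 64, 32, 32)
    # both the main path and the residual path land on (1, 128, 16, 16)
    assert block(x).shape == (1, 128, 16, 16)
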
@@ -0,0 +1,4 @@
from .my_data_loader import *
from .my_data_worker import *
from .my_distributed_sampler import *
from .my_random_resize_crop import *
@@ -0,0 +1,962 @@
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter.

To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
in `./_utils/worker.py`.
"""

import threading
import itertools
import warnings
import multiprocessing as python_multiprocessing
import torch
import torch.multiprocessing as multiprocessing
from torch._utils import ExceptionWrapper
from torch._six import queue, string_classes
from torch.utils.data.dataset import IterableDataset
from torch.utils.data import Sampler, SequentialSampler, RandomSampler, BatchSampler
from torch.utils.data import _utils

from .my_data_worker import worker_loop

__all__ = ['MyDataLoader']

get_worker_info = _utils.worker.get_worker_info

# This function used to be defined in this file. However, it was moved to
# _utils/collate.py. Although it is rather hard to access this from user land
# (one has to explicitly directly `import torch.utils.data.dataloader`), there
# probably is user code out there using it. This aliasing maintains BC in this
# aspect.
default_collate = _utils.collate.default_collate


class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)


class _InfiniteConstantSampler(Sampler):
    r"""Analogous to ``itertools.repeat(None, None)``.
    Used as sampler for :class:`~torch.utils.data.IterableDataset`.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self):
        super(_InfiniteConstantSampler, self).__init__(None)

    def __iter__(self):
        while True:
            yield None
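
# Illustrative sketch (not part of the original commit): the sampler above
# never raises StopIteration, so iteration length is governed entirely by the
# IterableDataset itself; take a bounded slice to inspect it.
def _demo_infinite_sampler():
    import itertools
    sampler = _InfiniteConstantSampler()
    # five dummy indices, all None by construction
    assert list(itertools.islice(sampler, 5)) == [None] * 5
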
class MyDataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides an iterable over
    the given dataset.

    The :class:`~torch.utils.data.DataLoader` supports both map-style and
    iterable-style datasets with single- or multi-process loading, customizing
    loading order and optional automatic batching (collation) and memory pinning.

    See :py:mod:`torch.utils.data` documentation page for more details.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
        sampler (Sampler, optional): defines the strategy to draw samples from
            the dataset. If specified, :attr:`shuffle` must be ``False``.
        batch_sampler (Sampler, optional): like :attr:`sampler`, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
        num_workers (int, optional): how many subprocesses to use for data
            loading. ``0`` means that the data will be loaded in the main process.
            (default: ``0``)
        collate_fn (callable, optional): merges a list of samples to form a
            mini-batch of Tensor(s). Used when using batched loading from a
            map-style dataset.
        pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
            into CUDA pinned memory before returning them. If your data elements
            are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
            see the example below.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)


    .. warning:: If the ``spawn`` start method is used, :attr:`worker_init_fn`
                 cannot be an unpicklable object, e.g., a lambda function. See
                 :ref:`multiprocessing-best-practices` on more details related
                 to multiprocessing in PyTorch.

    .. note:: The ``len(dataloader)`` heuristic is based on the length of the sampler used.
              When :attr:`dataset` is an :class:`~torch.utils.data.IterableDataset`,
              ``len(dataset)`` (if implemented) is returned instead, regardless
              of multi-process loading configurations, because PyTorch trusts
              user :attr:`dataset` code to correctly handle multi-process
              loading and avoid duplicate data. See `Dataset Types`_ for more
              details on these two types of datasets and how
              :class:`~torch.utils.data.IterableDataset` interacts with `Multi-process data loading`_.
    """

    __initialized = False

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None, multiprocessing_context=None):
        torch._C._log_api_usage_once("python.data_loader")

        if num_workers < 0:
            raise ValueError('num_workers option should be non-negative; '
                             'use num_workers=0 to disable multiprocessing.')

        if timeout < 0:
            raise ValueError('timeout option should be non-negative')

        self.dataset = dataset
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.timeout = timeout
        self.worker_init_fn = worker_init_fn
        self.multiprocessing_context = multiprocessing_context

        # Arg-check dataset-related options before checking samplers because we
        # want to tell users that iterable-style datasets are incompatible with
        # custom samplers first, so that they don't learn that this combo doesn't
        # work after spending time fixing the custom sampler errors.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iterable
            # NOTE [ Custom Samplers and `IterableDataset` ]
            #
            # `IterableDataset` does not support custom `batch_sampler` or
            # `sampler` since the key is irrelevant (unless we support
            # generator-style dataset one day...).
            #
            # For `sampler`, we always create a dummy sampler. This is an
            # infinite sampler even when the dataset may have an implemented
            # finite `__len__` because in multi-process data loading, naive
            # settings will return duplicated data (which may be desired), and
            # thus using a sampler with length matching that of dataset will
            # cause data loss (you may have duplicates of the first couple
            # batches, but never see anything afterwards). Therefore,
            # `IterableDataset` always uses an infinite sampler, an instance of
            # `_InfiniteConstantSampler` defined above.
            #
            # A custom `batch_sampler` essentially only controls the batch size.
            # However, it is unclear how useful it would be since an iterable-style
            # dataset can handle that within itself. Moreover, it is pointless
            # in multi-process data loading as the assignment order of batches
            # to workers is an implementation detail so users can not control
            # how to batchify each worker's iterable. Thus, we disable this
            # option. If this turns out to be useful in future, we can re-enable
            # this, and support custom samplers that specify the assignments to
            # specific workers.
            if shuffle is not False:
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "shuffle option, but got shuffle={}".format(shuffle))
            elif sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "sampler option, but got sampler={}".format(sampler))
            elif batch_sampler is not None:
                # See NOTE [ Custom Samplers and IterableDataset ]
                raise ValueError(
                    "DataLoader with IterableDataset: expected unspecified "
                    "batch_sampler option, but got batch_sampler={}".format(batch_sampler))
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if batch_sampler is not None:
            # auto_collation with custom batch_sampler
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            batch_size = None
            drop_last = False
        elif batch_size is None:
            # no auto_collation
            if shuffle or drop_last:
                raise ValueError('batch_size=None option disables auto-batching '
                                 'and is mutually exclusive with '
                                 'shuffle, and drop_last')

        if sampler is None:  # give default samplers
            if self._dataset_kind == _DatasetKind.Iterable:
                # See NOTE [ Custom Samplers and IterableDataset ]
                sampler = _InfiniteConstantSampler()
            else:  # map-style
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler

        if collate_fn is None:
            if self._auto_collation:
                collate_fn = _utils.collate.default_collate
            else:
                collate_fn = _utils.collate.default_convert

        self.collate_fn = collate_fn
        self.__initialized = True
        self._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]

    @property
    def multiprocessing_context(self):
        return self.__multiprocessing_context

    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        if multiprocessing_context is not None:
            if self.num_workers > 0:
                if not multiprocessing._supports_context:
                    raise ValueError('multiprocessing_context relies on Python >= 3.4, with '
                                     'support for different start methods')

                if isinstance(multiprocessing_context, string_classes):
                    valid_start_methods = multiprocessing.get_all_start_methods()
                    if multiprocessing_context not in valid_start_methods:
                        raise ValueError(
                            ('multiprocessing_context option '
                             'should specify a valid start method in {}, but got '
                             'multiprocessing_context={}').format(valid_start_methods, multiprocessing_context))
                    multiprocessing_context = multiprocessing.get_context(multiprocessing_context)

                if not isinstance(multiprocessing_context, python_multiprocessing.context.BaseContext):
                    raise ValueError(('multiprocessing_context option should be a valid context '
                                      'object or a string specifying the start method, but got '
                                      'multiprocessing_context={}').format(multiprocessing_context))
            else:
                raise ValueError(('multiprocessing_context can only be used with '
                                  'multi-process loading (num_workers > 0), but got '
                                  'num_workers={}').format(self.num_workers))

        self.__multiprocessing_context = multiprocessing_context

    def __setattr__(self, attr, val):
        if self.__initialized and attr in ('batch_size', 'batch_sampler', 'sampler', 'drop_last', 'dataset'):
            raise ValueError('{} attribute should not be set after {} is '
                             'initialized'.format(attr, self.__class__.__name__))

        super(MyDataLoader, self).__setattr__(attr, val)

    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

    @property
    def _auto_collation(self):
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler

    def __len__(self):
        if self._dataset_kind == _DatasetKind.Iterable:
            # NOTE [ IterableDataset and __len__ ]
            #
            # For `IterableDataset`, `__len__` could be inaccurate when one naively
            # does multi-processing data loading, since the samples will be duplicated.
            # However, no real use case should be actually using that behavior, so
            # it should count as a user error. We should generally trust user
            # code to do the proper thing (e.g., configure each replica differently
            # in `__iter__`), and give us the correct `__len__` if they choose to
            # implement it (this will still throw if the dataset does not implement
            # a `__len__`).
            #
            # To provide a further warning, we track if `__len__` was called on the
            # `DataLoader`, save the returned value in `self._len_called`, and warn
            # if the iterator ends up yielding more than this number of samples.
            length = self._IterableDataset_len_called = len(self.dataset)
            return length
        else:
            return len(self._index_sampler)
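
# Illustrative sketch (not part of the original commit): MyDataLoader keeps the
# stock DataLoader interface, so basic usage is unchanged. TensorDataset here
# is only a stand-in for the project's real datasets.
def _demo_my_data_loader():
    import torch
    from torch.utils.data import TensorDataset
    dataset = TensorDataset(torch.randn(10, 3), torch.arange(10))
    loader = MyDataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)
    for inputs, targets in loader:
        # batches of 4, 4, and 2 samples (drop_last defaults to False)
        assert inputs.shape[0] == targets.shape[0]
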
class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
        self._num_yielded = 0

    def __iter__(self):
        return self

    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration

    def _next_data(self):
        raise NotImplementedError

    def __next__(self):
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data

    next = __next__  # Python 2 compatibility

    def __len__(self):
        return len(self._index_sampler)

    def __getstate__(self):
        # TODO: add limited pickling support for sharing an iterator
        # across multiple threads for HOGWILD.
        # Probably the best way to do this is by moving the sample pushing
        # to a separate thread and then just sharing the data queue
        # but signalling the end is tricky without a non-blocking API
        raise NotImplementedError("{} cannot be pickled".format(self.__class__.__name__))
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
    r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""

    # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
    #
    # Preliminary:
    #
    # Our data model looks like this (queues are indicated with curly brackets):
    #
    #                main process                              ||
    #                     |                                    ||
    #               {index_queue}                              ||
    #                     |                                    ||
    #              worker processes                            ||     DATA
    #                     |                                    ||
    #            {worker_result_queue}                         ||     FLOW
    #                     |                                    ||
    #      pin_memory_thread of main process                   ||   DIRECTION
    #                     |                                    ||
    #               {data_queue}                               ||
    #                     |                                    ||
    #                data output                               \/
    #
    # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
    #      `pin_memory=False`.
    #
    #
    # Terminating multiprocessing logic requires very careful design. In
    # particular, we need to make sure that
    #
    #   1. The iterator gracefully exits the workers when its last reference is
    #      gone or it is depleted.
    #
    #      In this case, the workers should be gracefully exited because the
    #      main process may still need to continue to run, and we want cleaning
    #      up code in the workers to be executed (e.g., releasing GPU memory).
    #      Naturally, we implement the shutdown logic in `__del__` of
    #      DataLoaderIterator.
    #
    #      We delay the discussion on the logic in this case until later.
    #
    #   2. The iterator exits the workers when the loader process and/or worker
    #      processes exits normally or with error.
    #
    #      We set all workers and `pin_memory_thread` to have `daemon=True`.
    #
    #      You may ask, why can't we make the workers non-daemonic, and
    #      gracefully exit using the same logic as we have in `__del__` when the
    #      iterator gets deleted (see 1 above)?
    #
    #      First of all, `__del__` is **not** guaranteed to be called when
    #      interpreter exits. Even if it is called, by the time it executes,
    #      many Python core library resources may already be freed, and even
    #      simple things like acquiring an internal lock of a queue may hang.
    #      Therefore, in this case, we actually need to prevent `__del__` from
    #      being executed, and rely on the automatic termination of daemonic
    #      children. Thus, we register an `atexit` hook that sets a global flag
    #      `_utils.python_exit_status`. Since `atexit` hooks are executed in the
    #      reverse order of registration, we are guaranteed that this flag is
    #      set before library resources we use are freed. (Hooks freeing those
    #      resources are registered at importing the Python core libraries at
    #      the top of this file.) So in `__del__`, we check if
    #      `_utils.python_exit_status` is set or `None` (freed), and perform
    #      no-op if so.
    #
    #      Another problem with `__del__` is also related to the library cleanup
    #      calls. When a process ends, it shuts all of its daemonic children
    #      down with a SIGTERM (instead of joining them without a timeout).
    #      Similarly for threads, but by a different mechanism. This fact,
    #      together with a few implementation details of multiprocessing, forces
    #      us to make workers daemonic. All of our problems arise when a
    #      DataLoader is used in a subprocess, and are caused by multiprocessing
    #      code which looks more or less like this:
    #
    #          try:
    #              your_function_using_a_dataloader()
    #          finally:
    #              multiprocessing.util._exit_function()
    #
    #      The joining/termination mentioned above happens inside
    #      `_exit_function()`. Now, if `your_function_using_a_dataloader()`
    #      throws, the stack trace stored in the exception will prevent the
    #      frame which uses `DataLoaderIter` to be freed. If the frame has any
    #      reference to the `DataLoaderIter` (e.g., in a method of the iter),
    #      its `__del__`, which starts the shutdown procedure, will not be
    #      called. That, in turn, means that workers aren't notified. Attempting
    #      to join in `_exit_function` will then result in a hang.
    #
    #      For context, `_exit_function` is also registered as an `atexit` call.
    #      So it is unclear to me (@ssnl) why this is needed in a finally block.
    #      The code dates back to 2008 and there is no comment on the original
    #      PEP 371 or patch https://bugs.python.org/issue3050 (containing both
    #      the finally block and the `atexit` registration) that explains this.
    #
    #      Another choice is to just shutdown workers with logic in 1 above
    #      whenever we see an error in `next`. This isn't ideal because
    #        a. It prevents users from using try-catch to resume data loading.
    #        b. It doesn't prevent hanging if users have references to the
    #           iterator.
    #
    #   3. All processes exit if any of them die unexpectedly by fatal signals.
    #
    #      As shown above, the workers are set as daemonic children of the main
    #      process. However, automatic cleaning-up of such child processes only
    #      happens if the parent process exits gracefully (e.g., not via fatal
    #      signals like SIGKILL). So we must ensure that each process will exit
    #      even if the process that should send/receive data to/from it was
    #      killed, i.e.,
    #
    #        a. A process won't hang when getting from a queue.
    #
    #           Even with carefully designed data dependencies (i.e., a `put()`
    #           always corresponding to a `get()`), hanging on `get()` can still
    #           happen when data in queue is corrupted (e.g., due to
    #           `cancel_join_thread` or unexpected exit).
    #
    #           For child exit, we set a timeout whenever we try to get data
    #           from `data_queue`, and check the workers' status on each timeout
    #           and error.
    #           See `_DataLoaderiter._get_batch()` and
    #           `_DataLoaderiter._try_get_data()` for details.
    #
    #           Additionally, for child exit on non-Windows platforms, we also
    #           register a SIGCHLD handler (which is not supported on Windows) on
    #           the main process, which checks if any of the workers fail in the
    #           (Python) handler. This is more efficient and faster in detecting
    #           worker failures, compared to only using the above mechanism.
    #           See `DataLoader.cpp` and `_utils/signal_handling.py` for details.
    #
    #           For `.get()` calls where the sender(s) is not the workers, we
    #           guard them with timeouts, and check the status of the sender
    #           when timeout happens:
    #             + in the workers, the `_utils.worker.ManagerWatchdog` class
    #               checks the status of the main process.
    #             + if `pin_memory=True`, when getting from `pin_memory_thread`,
    #               check `pin_memory_thread` status periodically until `.get()`
    #               returns or see that `pin_memory_thread` died.
    #
    #        b. A process won't hang when putting into a queue;
    #
    #           We use `mp.Queue` which has a separate background thread to put
    #           objects from an unbounded buffer array. The background thread is
    #           daemonic and usually automatically joined when the process
    #           exits.
    #
    #           However, in case that the receiver has ended abruptly while
    #           reading from the pipe, the join will hang forever. Therefore,
    #           for both `worker_result_queue` (worker -> main process/pin_memory_thread)
    #           and each `index_queue` (main process -> worker), we use
    #           `q.cancel_join_thread()` in sender process before any `q.put` to
    #           prevent this automatic join.
    #
    #           Moreover, having all queues called `cancel_join_thread` makes
    #           implementing graceful shutdown logic in `__del__` much easier.
    #           It won't need to get from any queue, which would also need to be
    #           guarded by periodic status checks.
    #
    #           Nonetheless, `cancel_join_thread` must only be called when the
    #           queue is **not** going to be read from or written to by another
    #           process, because it may hold onto a lock or leave corrupted data
    #           in the queue, leading other readers/writers to hang.
    #
    #           `pin_memory_thread`'s `data_queue` is a `queue.Queue` that does
    #           a blocking `put` if the queue is full. So there is no above
    #           problem, but we do need to wrap the `put` in a loop that breaks
    #           not only upon success, but also when the main process stops
    #           reading, i.e., is shutting down.
    #
    #
    # Now let's get back to 1:
    #   how we gracefully exit the workers when the last reference to the
    #   iterator is gone.
    #
    # To achieve this, we implement the following logic along with the design
    # choices mentioned above:
    #
    # `workers_done_event`:
    #   A `multiprocessing.Event` shared among the main process and all worker
    #   processes. This is used to signal the workers that the iterator is
    #   shutting down. After it is set, they will not send processed data to
    #   queues anymore, and only wait for the final `None` before exiting.
    #   `done_event` isn't strictly needed. I.e., we can just check for `None`
    #   from the input queue, but it allows us to skip wasting resources
    #   processing data if we are already shutting down.
    #
    # `pin_memory_thread_done_event`:
    #   A `threading.Event` for a similar purpose to that of
    #   `workers_done_event`, but is for the `pin_memory_thread`. The reason
    #   that separate events are needed is that `pin_memory_thread` reads from
    #   the output queue of the workers. But the workers, upon seeing that
    #   `workers_done_event` is set, only want to see the final `None`, and are
    #   not required to flush all data in the output queue (e.g., it may call
    #   `cancel_join_thread` on that queue if its `IterableDataset` iterator
    #   happens to exhaust coincidentally, which is out of the control of the
    #   main process). Thus, since we will exit `pin_memory_thread` before the
    #   workers (see below), two separate events are used.
    #
    # NOTE: In short, the protocol is that the main process will set these
    #       `done_event`s, then send the corresponding processes/threads a `None`,
    #       and that they may exit at any time after receiving the `None`.
    #
    # NOTE: Using `None` as the final signal is valid, since normal data will
    #       always be a 2-tuple with the 1st element being the index of the data
    #       transferred (different from dataset index/key), and the 2nd being
    #       either the dataset key or the data sample (depending on which part
    #       of the data model the queue is at).
    #
    # [ worker processes ]
    #   While loader process is alive:
    #     Get from `index_queue`.
    #       If get anything else,
    #          Check `workers_done_event`.
    #            If set, continue to next iteration
    #                    i.e., keep getting until see the `None`, then exit.
    #            Otherwise, process data:
    #                If is fetching from an `IterableDataset` and the iterator
    #                    is exhausted, send an `_IterableDatasetStopIteration`
    #                    object to signal iteration end. The main process, upon
    #                    receiving such an object, will send `None` to this
    #                    worker and not use the corresponding `index_queue`
    #                    anymore.
    #       If timed out,
    #          No matter whether `workers_done_event` is set (still need to see
    #          the `None`) or not, must continue to next iteration.
    #   (outside loop)
    #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
    #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
    #                                          main process won't read from it;
    #                                          other workers will also call
    #                                          `cancel_join_thread`.)
    #
    # [ pin_memory_thread ]
    #   # No need to check main thread. If this thread is alive, the main loader
    #   # thread must be alive, because this thread is set as daemonic.
    #   While `pin_memory_thread_done_event` is not set:
    #     Get from `index_queue`.
    #       If timed out, continue to get in the next iteration.
    #       Otherwise, process data.
    #       While `pin_memory_thread_done_event` is not set:
    #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
    #         If timed out, continue to put in the next iteration.
    #         Otherwise, break, i.e., continuing to the out loop.
    #
    #   NOTE: we don't check the status of the main thread because
    #           1. if the process is killed by fatal signal, `pin_memory_thread`
    #              ends.
    #           2. in other cases, either the cleaning-up in __del__ or the
    #              automatic exit of daemonic thread will take care of it.
    #              This won't busy-wait either because `.get(timeout)` does not
    #              busy-wait.
    #
    # [ main process ]
    #   In the DataLoader Iter's `__del__`
    #     b. Exit `pin_memory_thread`
    #          i.   Set `pin_memory_thread_done_event`.
    #          ii.  Put `None` in `worker_result_queue`.
    #          iii. Join the `pin_memory_thread`.
    #          iv.  `worker_result_queue.cancel_join_thread()`.
    #
    #     c. Exit the workers.
    #          i.   Set `workers_done_event`.
    #          ii.  Put `None` in each worker's `index_queue`.
    #          iii. Join the workers.
    #          iv.  Call `.cancel_join_thread()` on each worker's `index_queue`.
    #
    #        NOTE: (c) is better placed after (b) because it may leave corrupted
    #              data in `worker_result_queue`, which `pin_memory_thread`
    #              reads from, in which case the `pin_memory_thread` can only
    #              exit after timing out, which is slow. Nonetheless, the same
    #              thing happens if a worker is killed by signal at unfortunate
    #              times, but in other cases, we are better off having a
    #              non-corrupted `worker_result_queue` for `pin_memory_thread`.
    #
    #   NOTE: If `pin_memory=False`, there is no `pin_memory_thread` and (b)
    #         can be omitted
    #
    # NB: `done_event`s isn't strictly needed. E.g., we can just check for
    #     `None` from `index_queue`, but it allows us to skip wasting resources
    #     processing indices already in `index_queue` if we are already shutting
    #     down.

    def __init__(self, loader):
        super(_MultiProcessingDataLoaderIter, self).__init__(loader)

        assert self._num_workers > 0

        if loader.multiprocessing_context is None:
            multiprocessing_context = multiprocessing
        else:
            multiprocessing_context = loader.multiprocessing_context

        self._worker_init_fn = loader.worker_init_fn
        self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
        self._worker_result_queue = multiprocessing_context.Queue()
        self._worker_pids_set = False
        self._shutdown = False
        self._send_idx = 0  # idx of the next task to be sent to workers
        self._rcvd_idx = 0  # idx of the next task to be returned in __next__
        # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
        # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
        #                  \ (worker_id, data)   if data is already fetched (out-of-order)
        self._task_info = {}
        self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
        self._workers_done_event = multiprocessing_context.Event()

        self._index_queues = []
        self._workers = []
        # A list of booleans representing whether each worker still has work to
        # do, i.e., not having exhausted its iterable dataset object. It always
        # contains all `True`s if not using an iterable-style dataset
        # (i.e., if kind != Iterable).
        self._workers_status = []
        for i in range(self._num_workers):
            index_queue = multiprocessing_context.Queue()
            # index_queue.cancel_join_thread()
            w = multiprocessing_context.Process(
                target=worker_loop,
                args=(self._dataset_kind, self._dataset, index_queue,
                      self._worker_result_queue, self._workers_done_event,
                      self._auto_collation, self._collate_fn, self._drop_last,
                      self._base_seed + i, self._worker_init_fn, i, self._num_workers))
            w.daemon = True
            # NB: Process.start() actually takes some time as it needs to
            #     start a process and pass the arguments over via a pipe.
            #     Therefore, we only add a worker to self._workers list after
            #     it started, so that we do not call .join() if program dies
            #     before it starts, and __del__ tries to join but will get:
            #     AssertionError: can only join a started process.
            w.start()
            self._index_queues.append(index_queue)
            self._workers.append(w)
            self._workers_status.append(True)

        if self._pin_memory:
            self._pin_memory_thread_done_event = threading.Event()
            self._data_queue = queue.Queue()
            pin_memory_thread = threading.Thread(
                target=_utils.pin_memory._pin_memory_loop,
                args=(self._worker_result_queue, self._data_queue,
                      torch.cuda.current_device(),
                      self._pin_memory_thread_done_event))
            pin_memory_thread.daemon = True
            pin_memory_thread.start()
            # Similar to workers (see comment above), we only register
            # pin_memory_thread once it is started.
            self._pin_memory_thread = pin_memory_thread
        else:
            self._data_queue = self._worker_result_queue

        _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))
        _utils.signal_handling._set_SIGCHLD_handler()
        self._worker_pids_set = True

        # prime the prefetch loop
        for _ in range(2 * self._num_workers):
            self._try_put_index()

    def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
        # Tries to fetch data from `self._data_queue` once for a given timeout.
        # This can also be used as inner loop of fetching without timeout, with
        # the sender status as the loop condition.
        #
        # This raises a `RuntimeError` if any worker died unexpectedly. This error
        # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
        # (only for non-Windows platforms), or the manual check below on errors
        # and timeouts.
        #
        # Returns a 2-tuple:
        #   (bool: whether successfully get data, any: data if successful else None)
        try:
            data = self._data_queue.get(timeout=timeout)
            return (True, data)
        except Exception as e:
            # At timeout and error, we manually check whether any worker has
            # failed. Note that this is the only mechanism for Windows to detect
            # worker failures.
            failed_workers = []
            for worker_id, w in enumerate(self._workers):
                if self._workers_status[worker_id] and not w.is_alive():
                    failed_workers.append(w)
                    self._shutdown_worker(worker_id)
            if len(failed_workers) > 0:
                pids_str = ', '.join(str(w.pid) for w in failed_workers)
                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str))
            if isinstance(e, queue.Empty):
                return (False, None)
            raise

    def _get_data(self):
        # Fetches data from `self._data_queue`.
        #
        # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
        # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
        # in a loop. This is the only mechanism to detect worker failures for
        # Windows. For other platforms, a SIGCHLD handler is also used for
        # worker failure detection.
        #
        # If `pin_memory=True`, we also need to check whether `pin_memory_thread`
        # has died at timeouts.
        if self._timeout > 0:
            success, data = self._try_get_data(self._timeout)
            if success:
                return data
            else:
                raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
        elif self._pin_memory:
            while self._pin_memory_thread.is_alive():
                success, data = self._try_get_data()
                if success:
                    return data
            else:
                # while condition is false, i.e., pin_memory_thread died.
                raise RuntimeError('Pin memory thread exited unexpectedly')
            # In this case, `self._data_queue` is a `queue.Queue`. But we don't
            # need to call `.task_done()` because we don't use `.join()`.
        else:
            while True:
                success, data = self._try_get_data()
                if success:
                    return data

    def _next_data(self):
        while True:
            # If the worker responsible for `self._rcvd_idx` has already ended
            # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
            # we try to advance `self._rcvd_idx` to find the next valid index.
            #
            # This part needs to run in the loop because both the `self._get_data()`
            # call and `_IterableDatasetStopIteration` check below can mark
            # extra worker(s) as dead.
            while self._rcvd_idx < self._send_idx:
                info = self._task_info[self._rcvd_idx]
                worker_id = info[0]
                if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                    break
                del self._task_info[self._rcvd_idx]
                self._rcvd_idx += 1
            else:
                # no valid `self._rcvd_idx` is found (i.e., didn't break)
                self._shutdown_workers()
                raise StopIteration

            # Now `self._rcvd_idx` is the batch index we want to fetch

            # Check if the next sample has already been generated
            if len(self._task_info[self._rcvd_idx]) == 2:
                data = self._task_info.pop(self._rcvd_idx)[1]
                return self._process_data(data)

            assert not self._shutdown and self._tasks_outstanding > 0
            idx, data = self._get_data()
            self._tasks_outstanding -= 1

            if self._dataset_kind == _DatasetKind.Iterable:
                # Check for _IterableDatasetStopIteration
                if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                    self._shutdown_worker(data.worker_id)
                    self._try_put_index()
                    continue

            if idx != self._rcvd_idx:
                # store out-of-order samples
                self._task_info[idx] += (data,)
            else:
                del self._task_info[idx]
                return self._process_data(data)

    def _try_put_index(self):
        assert self._tasks_outstanding < 2 * self._num_workers
        try:
            index = self._next_index()
        except StopIteration:
            return
        for _ in range(self._num_workers):  # find the next active worker, if any
            worker_queue_idx = next(self._worker_queue_idx_cycle)
            if self._workers_status[worker_queue_idx]:
                break
        else:
            # not found (i.e., didn't break)
            return

        self._index_queues[worker_queue_idx].put((self._send_idx, index))
        self._task_info[self._send_idx] = (worker_queue_idx,)
        self._tasks_outstanding += 1
        self._send_idx += 1

    def _process_data(self, data):
        self._rcvd_idx += 1
        self._try_put_index()
        if isinstance(data, ExceptionWrapper):
            data.reraise()
        return data

    def _shutdown_worker(self, worker_id):
        # Mark a worker as having finished its work and dead, e.g., due to
        # exhausting an `IterableDataset`. This should be used only when this
        # `_MultiProcessingDataLoaderIter` is going to continue running.

        assert self._workers_status[worker_id]

        # Signal termination to that specific worker.
        q = self._index_queues[worker_id]
        # Indicate that no more data will be put on this queue by the current
        # process.
        q.put(None)

        # Note that we don't actually join the worker here, nor do we remove the
        # worker's pid from C side struct because (1) joining may be slow, and
        # (2) since we don't join, the worker may still raise error, and we
        # prefer capturing those, rather than ignoring them, even though they
        # are raised after the worker has finished its job.
        # Joining is deferred to `_shutdown_workers`, which is called when
        # all workers finish their jobs (e.g., `IterableDataset` replicas) or
        # when this iterator is garbage collected.
        self._workers_status[worker_id] = False

    def _shutdown_workers(self):
        # Called when shutting down this `_MultiProcessingDataLoaderIter`.
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on
        # the logic of this function.
        python_exit_status = _utils.python_exit_status
        if python_exit_status is True or python_exit_status is None:
            # See (2) of the note. If Python is shutting down, do no-op.
            return
        # Normal exit when last reference is gone / iterator is depleted.
        # See (1) and the second half of the note.
        if not self._shutdown:
            self._shutdown = True
            try:
                # Exit `pin_memory_thread` first because exiting workers may leave
                # corrupted data in `worker_result_queue` which `pin_memory_thread`
                # reads from.
                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.close()

                # Exit workers now.
                self._workers_done_event.set()
                for worker_id in range(len(self._workers)):
                    # Get number of workers from `len(self._workers)` instead of
                    # `self._num_workers` in case we error before starting all
                    # workers.
                    if self._workers_status[worker_id]:
                        self._shutdown_worker(worker_id)
                for w in self._workers:
                    w.join()
                for q in self._index_queues:
                    q.cancel_join_thread()
                    q.close()
            finally:
                # Even though all this function does is putting into queues that
                # we have called `cancel_join_thread` on, weird things can
                # happen when a worker is killed by a signal, e.g., hanging in
                # `Event.set()`. So we need to guard this with SIGCHLD handler,
                # and remove pids from the C side data structure only at the
                # end.
                #
                # FIXME: Unfortunately, for Windows, we are missing a worker
                #        error detection mechanism here in this function, as it
                #        doesn't provide a SIGCHLD handler.
                if self._worker_pids_set:
                    _utils.signal_handling._remove_worker_pids(id(self))
                    self._worker_pids_set = False

    def __del__(self):
        self._shutdown_workers()
@@ -0,0 +1,207 @@
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import random
import os
from collections import namedtuple
from torch._six import queue
from torch._utils import ExceptionWrapper
from torch.utils.data._utils import signal_handling, MP_STATUS_CHECK_INTERVAL, IS_WINDOWS

from .my_random_resize_crop import MyRandomResizedCrop

__all__ = ['worker_loop']

if IS_WINDOWS:
    import ctypes
    from ctypes.wintypes import DWORD, BOOL, HANDLE


    # On Windows, the parent ID of the worker process remains unchanged when the manager process
    # is gone, and the only way to check it through OS is to let the worker have a process handle
    # of the manager and ask if the process status has changed.
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()

            self.kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
            self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD)
            self.kernel32.OpenProcess.restype = HANDLE
            self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD)
            self.kernel32.WaitForSingleObject.restype = DWORD

            # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx
            SYNCHRONIZE = 0x00100000
            self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid)

            if not self.manager_handle:
                raise ctypes.WinError(ctypes.get_last_error())

            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx
                self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0
            return not self.manager_dead
else:
    class ManagerWatchdog(object):
        def __init__(self):
            self.manager_pid = os.getppid()
            self.manager_dead = False

        def is_alive(self):
            if not self.manager_dead:
                self.manager_dead = os.getppid() != self.manager_pid
            return not self.manager_dead

_worker_info = None

class WorkerInfo(object):
    __initialized = False

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)
        self.__initialized = True

    def __setattr__(self, key, val):
        if self.__initialized:
            raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
        return super(WorkerInfo, self).__setattr__(key, val)

def get_worker_info():
    r"""Returns the information about the current
    :class:`~torch.utils.data.DataLoader` iterator worker process.

    When called in a worker, this returns an object guaranteed to have the
    following attributes:

    * :attr:`id`: the current worker id.
    * :attr:`num_workers`: the total number of workers.
    * :attr:`seed`: the random seed set for the current worker. This value is
      determined by main process RNG and the worker id. See
      :class:`~torch.utils.data.DataLoader`'s documentation for more details.
    * :attr:`dataset`: the copy of the dataset object in **this** process. Note
      that this will be a different object in a different process than the one
      in the main process.

    When called in the main process, this returns ``None``.

    .. note::
       When used in a :attr:`worker_init_fn` passed over to
       :class:`~torch.utils.data.DataLoader`, this method can be useful to
       set up each worker process differently, for instance, using ``worker_id``
       to configure the ``dataset`` object to only read a specific fraction of a
       sharded dataset, or use ``seed`` to seed other libraries used in dataset
       code (e.g., NumPy).
    """
    return _worker_info
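
# Illustrative sketch (not part of the original commit): a worker_init_fn that
# uses get_worker_info() to shard an iterable dataset across workers. The
# dataset attributes (`start`, `end`) are hypothetical.
def _demo_worker_init_fn(worker_id):
    info = get_worker_info()  # None in the main process, WorkerInfo in a worker
    if info is not None:
        per_worker = (info.dataset.end - info.dataset.start) // info.num_workers
        start = info.dataset.start + info.id * per_worker
        # narrow this replica's copy of the dataset to its own shard
        info.dataset.end = min(start + per_worker, info.dataset.end)
        info.dataset.start = start
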
r"""Dummy class used to signal the end of an IterableDataset"""
|
||||
_IterableDatasetStopIteration = namedtuple('_IterableDatasetStopIteration', ['worker_id'])
|
||||
|
||||
|
||||
def worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
|
||||
auto_collation, collate_fn, drop_last, seed, init_fn, worker_id,
|
||||
num_workers):
|
||||
# See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
|
||||
# logic of this function.
|
||||
|
||||
try:
|
||||
# Intialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
|
||||
# module's handlers are executed after Python returns from C low-level
|
||||
# handlers, likely when the same fatal signal had already happened
|
||||
# again.
|
||||
# https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
|
||||
signal_handling._set_worker_signal_handlers()
|
||||
|
||||
torch.set_num_threads(1)
|
||||
random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
global _worker_info
|
||||
_worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
|
||||
seed=seed, dataset=dataset)
|
||||
|
||||
from torch.utils.data import _DatasetKind
|
||||
|
||||
init_exception = None
|
||||
|
||||
try:
|
||||
if init_fn is not None:
|
||||
init_fn(worker_id)
|
||||
|
||||
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
|
||||
except Exception:
|
||||
init_exception = ExceptionWrapper(
|
||||
where="in DataLoader worker process {}".format(worker_id))
|
||||
|
||||
# When using Iterable mode, some worker can exit earlier than others due
|
||||
# to the IterableDataset behaving differently for different workers.
|
||||
# When such things happen, an `_IterableDatasetStopIteration` object is
|
||||
# sent over to the main process with the ID of this worker, so that the
|
||||
# main process won't send more tasks to this worker, and will send
|
||||
# `None` to this worker to properly exit it.
|
||||
#
|
||||
# Note that we cannot set `done_event` from a worker as it is shared
|
||||
# among all processes. Instead, we set the `iteration_end` flag to
|
||||
# signify that the iterator is exhausted. When either `done_event` or
|
||||
# `iteration_end` is set, we skip all processing step and just wait for
|
||||
# `None`.
|
||||
iteration_end = False
|
||||
|
||||
watchdog = ManagerWatchdog()
|
||||
|
||||
while watchdog.is_alive():
|
||||
try:
|
||||
r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
|
||||
except queue.Empty:
|
||||
continue
|
||||
if r is None:
|
||||
# Received the final signal
|
||||
assert done_event.is_set() or iteration_end
|
||||
break
|
||||
elif done_event.is_set() or iteration_end:
|
||||
# `done_event` is set. But I haven't received the final signal
|
||||
# (None) yet. I will keep continuing until get it, and skip the
|
||||
# processing steps.
|
||||
continue
|
||||
idx, index = r
|
||||
""" Added """
|
||||
MyRandomResizedCrop.sample_image_size(idx)
|
||||
""" Added """
|
||||
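            # The `""" Added """` block above hooks elastic image-resolution sampling
            # into this otherwise stock PyTorch worker loop: sample_image_size seeds
            # its RNG from the task index `idx` (see MyRandomResizedCrop below), so
            # every worker that fetches batch `idx` resizes to the same ACTIVE_SIZE.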
            if init_exception is not None:
                data = init_exception
                init_exception = None
            else:
                try:
                    data = fetcher.fetch(index)
                except Exception as e:
                    if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable:
                        data = _IterableDatasetStopIteration(worker_id)
                        # Set `iteration_end`
                        #   (1) to save future `next(...)` calls, and
                        #   (2) to avoid sending multiple `_IterableDatasetStopIteration`s.
                        iteration_end = True
                    else:
                        # It is important that we don't store exc_info in a variable.
                        # `ExceptionWrapper` does the correct thing.
                        # See NOTE [ Python Traceback Reference Cycle Problem ]
                        data = ExceptionWrapper(
                            where="in DataLoader worker process {}".format(worker_id))
            data_queue.put((idx, data))
            del data, idx, index, r  # save memory
    except KeyboardInterrupt:
        # Main process will raise KeyboardInterrupt anyways.
        pass
    if done_event.is_set():
        data_queue.cancel_join_thread()
        data_queue.close()
@@ -0,0 +1,69 @@
import math
import torch
from torch.utils.data.distributed import DistributedSampler

__all__ = ['MyDistributedSampler', 'WeightedDistributedSampler']


class MyDistributedSampler(DistributedSampler):
    """ Allow Subset Sampler in Distributed Training """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True,
                 sub_index_list=None):
        super(MyDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle)
        self.sub_index_list = sub_index_list  # numpy

        self.num_samples = int(math.ceil(len(self.sub_index_list) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        print('Use MyDistributedSampler: %d, %d' % (self.num_samples, self.total_size))

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.sub_index_list), generator=g).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        indices = self.sub_index_list[indices].tolist()
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
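# Example of the rank-strided subsampling above: with num_replicas=2 and
# total_size=6, rank 0 draws shuffled positions [0, 2, 4] and rank 1 draws
# [1, 3, 5]; the padded index list is partitioned without overlap, and each
# replica sees exactly num_samples items per epoch.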

class WeightedDistributedSampler(DistributedSampler):
    """ Allow Weighted Random Sampling in Distributed Training """

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True,
                 weights=None, replacement=True):
        super(WeightedDistributedSampler, self).__init__(dataset, num_replicas, rank, shuffle)

        self.weights = torch.as_tensor(weights, dtype=torch.double) if weights is not None else None
        self.replacement = replacement
        print('Use WeightedDistributedSampler')

    def __iter__(self):
        if self.weights is None:
            return super(WeightedDistributedSampler, self).__iter__()
        else:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            if self.shuffle:
                # original: indices = torch.randperm(len(self.dataset), generator=g).tolist()
                indices = torch.multinomial(self.weights, len(self.dataset), self.replacement, generator=g).tolist()
            else:
                indices = list(range(len(self.dataset)))

            # add extra samples to make it evenly divisible
            indices += indices[:(self.total_size - len(indices))]
            assert len(indices) == self.total_size

            # subsample
            indices = indices[self.rank:self.total_size:self.num_replicas]
            assert len(indices) == self.num_samples

            return iter(indices)
@@ -0,0 +1,136 @@
import time
import random
import math
import os
from PIL import Image

import torchvision.transforms.functional as F
import torchvision.transforms as transforms

__all__ = ['MyRandomResizedCrop', 'MyResizeRandomCrop', 'MyResize']

_pil_interpolation_to_str = {
    Image.NEAREST: 'PIL.Image.NEAREST',
    Image.BILINEAR: 'PIL.Image.BILINEAR',
    Image.BICUBIC: 'PIL.Image.BICUBIC',
    Image.LANCZOS: 'PIL.Image.LANCZOS',
    Image.HAMMING: 'PIL.Image.HAMMING',
    Image.BOX: 'PIL.Image.BOX',
}


class MyRandomResizedCrop(transforms.RandomResizedCrop):
    ACTIVE_SIZE = 224
    IMAGE_SIZE_LIST = [224]
    IMAGE_SIZE_SEG = 4

    CONTINUOUS = False
    SYNC_DISTRIBUTED = True

    EPOCH = 0
    BATCH = 0

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR):
        if not isinstance(size, int):
            size = size[0]
        super(MyRandomResizedCrop, self).__init__(size, scale, ratio, interpolation)

    def __call__(self, img):
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        return F.resized_crop(
            img, i, j, h, w, (MyRandomResizedCrop.ACTIVE_SIZE, MyRandomResizedCrop.ACTIVE_SIZE), self.interpolation
        )

    @staticmethod
    def get_candidate_image_size():
        if MyRandomResizedCrop.CONTINUOUS:
            min_size = min(MyRandomResizedCrop.IMAGE_SIZE_LIST)
            max_size = max(MyRandomResizedCrop.IMAGE_SIZE_LIST)
            candidate_sizes = []
            for i in range(min_size, max_size + 1):
                if i % MyRandomResizedCrop.IMAGE_SIZE_SEG == 0:
                    candidate_sizes.append(i)
        else:
            candidate_sizes = MyRandomResizedCrop.IMAGE_SIZE_LIST

        relative_probs = None
        return candidate_sizes, relative_probs

    @staticmethod
    def sample_image_size(batch_id=None):
        if batch_id is None:
            batch_id = MyRandomResizedCrop.BATCH
        if MyRandomResizedCrop.SYNC_DISTRIBUTED:
            _seed = int('%d%.3d' % (batch_id, MyRandomResizedCrop.EPOCH))
        else:
            _seed = os.getpid() + time.time()
        random.seed(_seed)
        candidate_sizes, relative_probs = MyRandomResizedCrop.get_candidate_image_size()
        MyRandomResizedCrop.ACTIVE_SIZE = random.choices(candidate_sizes, weights=relative_probs)[0]
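    # When SYNC_DISTRIBUTED is True, the seed '%d%.3d' % (batch_id, EPOCH) is a
    # pure function of (batch, epoch), so every process and DataLoader worker
    # samples the same ACTIVE_SIZE for a given batch; the pid + time seed in the
    # other branch de-synchronizes the choice on purpose.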

    def __repr__(self):
        interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(MyRandomResizedCrop.IMAGE_SIZE_LIST)
        if MyRandomResizedCrop.CONTINUOUS:
            format_string += '@continuous'
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string


class MyResizeRandomCrop(object):

    def __init__(self, interpolation=Image.BILINEAR,
                 use_padding=False, pad_if_needed=False, fill=0, padding_mode='constant'):
        # resize
        self.interpolation = interpolation
        # random crop
        self.use_padding = use_padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        crop_size = MyRandomResizedCrop.ACTIVE_SIZE

        if not self.use_padding:
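            # 0.875 is the conventional crop-to-resize ratio from ImageNet
            # preprocessing (224 / 256), so the short side keeps the standard
            # margin around the subsequent random crop.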
            resize_size = int(math.ceil(crop_size / 0.875))
            img = F.resize(img, resize_size, self.interpolation)
        else:
            img = F.resize(img, crop_size, self.interpolation)
            padding_size = crop_size // 8
            img = F.pad(img, padding_size, self.fill, self.padding_mode)

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < crop_size:
            img = F.pad(img, (crop_size - img.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < crop_size:
            img = F.pad(img, (0, crop_size - img.size[1]), self.fill, self.padding_mode)

        i, j, h, w = transforms.RandomCrop.get_params(img, (crop_size, crop_size))
        return F.crop(img, i, j, h, w)

    def __repr__(self):
        return 'MyResizeRandomCrop(size=%s%s, interpolation=%s, use_padding=%s, fill=%s)' % (
            MyRandomResizedCrop.IMAGE_SIZE_LIST, '@continuous' if MyRandomResizedCrop.CONTINUOUS else '',
            _pil_interpolation_to_str[self.interpolation], self.use_padding, self.fill,
        )


class MyResize(object):

    def __init__(self, interpolation=Image.BILINEAR):
        self.interpolation = interpolation

    def __call__(self, img):
        target_size = MyRandomResizedCrop.ACTIVE_SIZE
        img = F.resize(img, target_size, self.interpolation)
        return img

    def __repr__(self):
        return 'MyResize(size=%s%s, interpolation=%s)' % (
            MyRandomResizedCrop.IMAGE_SIZE_LIST, '@continuous' if MyRandomResizedCrop.CONTINUOUS else '',
            _pil_interpolation_to_str[self.interpolation]
        )
@@ -0,0 +1,238 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import math
import torch.nn as nn
import torch.nn.functional as F

from .common_tools import min_divisible_value

__all__ = ['MyModule', 'MyNetwork', 'init_models', 'set_bn_param', 'get_bn_param', 'replace_bn_with_gn',
           'MyConv2d', 'replace_conv2d_with_my_conv2d']


def set_bn_param(net, momentum, eps, gn_channel_per_group=None, ws_eps=None, **kwargs):
    replace_bn_with_gn(net, gn_channel_per_group)

    for m in net.modules():
        if type(m) in [nn.BatchNorm1d, nn.BatchNorm2d]:
            m.momentum = momentum
            m.eps = eps
        elif isinstance(m, nn.GroupNorm):
            m.eps = eps

    replace_conv2d_with_my_conv2d(net, ws_eps)
    return


def get_bn_param(net):
    ws_eps = None
    for m in net.modules():
        if isinstance(m, MyConv2d):
            ws_eps = m.WS_EPS
            break
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            return {
                'momentum': m.momentum,
                'eps': m.eps,
                'ws_eps': ws_eps,
            }
        elif isinstance(m, nn.GroupNorm):
            return {
                'momentum': None,
                'eps': m.eps,
                'gn_channel_per_group': m.num_channels // m.num_groups,
                'ws_eps': ws_eps,
            }
    return None


def replace_bn_with_gn(model, gn_channel_per_group):
    if gn_channel_per_group is None:
        return

    for m in model.modules():
        to_replace_dict = {}
        for name, sub_m in m.named_children():
            if isinstance(sub_m, nn.BatchNorm2d):
                num_groups = sub_m.num_features // min_divisible_value(sub_m.num_features, gn_channel_per_group)
                gn_m = nn.GroupNorm(num_groups=num_groups, num_channels=sub_m.num_features, eps=sub_m.eps, affine=True)

                # load weight
                gn_m.weight.data.copy_(sub_m.weight.data)
                gn_m.bias.data.copy_(sub_m.bias.data)
                # load requires_grad
                gn_m.weight.requires_grad = sub_m.weight.requires_grad
                gn_m.bias.requires_grad = sub_m.bias.requires_grad

                to_replace_dict[name] = gn_m
        m._modules.update(to_replace_dict)


def replace_conv2d_with_my_conv2d(net, ws_eps=None):
    if ws_eps is None:
        return

    for m in net.modules():
        to_update_dict = {}
        for name, sub_module in m.named_children():
            if isinstance(sub_module, nn.Conv2d) and sub_module.bias is None:
                # only replace conv2d layers that are followed by normalization layers (i.e., no bias)
                to_update_dict[name] = sub_module
        for name, sub_module in to_update_dict.items():
            m._modules[name] = MyConv2d(
                sub_module.in_channels, sub_module.out_channels, sub_module.kernel_size, sub_module.stride,
                sub_module.padding, sub_module.dilation, sub_module.groups, sub_module.bias,
            )
            # load weight
            m._modules[name].load_state_dict(sub_module.state_dict())
            # load requires_grad
            m._modules[name].weight.requires_grad = sub_module.weight.requires_grad
            if sub_module.bias is not None:
                m._modules[name].bias.requires_grad = sub_module.bias.requires_grad
    # set ws_eps
    for m in net.modules():
        if isinstance(m, MyConv2d):
            m.WS_EPS = ws_eps


def init_models(net, model_init='he_fout'):
    """
    Conv2d,
    BatchNorm2d, BatchNorm1d, GroupNorm
    Linear,
    """
    if isinstance(net, list):
        for sub_net in net:
            init_models(sub_net, model_init)
        return
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            if model_init == 'he_fout':
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif model_init == 'he_fin':
                n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            else:
                raise NotImplementedError
            if m.bias is not None:
                m.bias.data.zero_()
        elif type(m) in [nn.BatchNorm2d, nn.BatchNorm1d, nn.GroupNorm]:
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            stdv = 1. / math.sqrt(m.weight.size(1))
            m.weight.data.uniform_(-stdv, stdv)
            if m.bias is not None:
                m.bias.data.zero_()


class MyConv2d(nn.Conv2d):
    """
    Conv2d with Weight Standardization
    https://github.com/joe-siyuan-qiao/WeightStandardization
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(MyConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        self.WS_EPS = None

    def weight_standardization(self, weight):
        if self.WS_EPS is not None:
            weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            weight = weight - weight_mean
            std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + self.WS_EPS
            weight = weight / std.expand_as(weight)
        return weight
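    # Weight Standardization: each output filter is shifted to zero mean and
    # scaled to (roughly) unit std before the convolution; WS_EPS keeps the
    # division stable for near-constant filters. The raw self.weight is left
    # untouched, so the standardization is re-applied on every forward pass.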

    def forward(self, x):
        if self.WS_EPS is None:
            return super(MyConv2d, self).forward(x)
        else:
            return F.conv2d(x, self.weight_standardization(self.weight), self.bias,
                            self.stride, self.padding, self.dilation, self.groups)

    def __repr__(self):
        return super(MyConv2d, self).__repr__()[:-1] + ', ws_eps=%s)' % self.WS_EPS


class MyModule(nn.Module):

    def forward(self, x):
        raise NotImplementedError

    @property
    def module_str(self):
        raise NotImplementedError

    @property
    def config(self):
        raise NotImplementedError

    @staticmethod
    def build_from_config(config):
        raise NotImplementedError


class MyNetwork(MyModule):
    CHANNEL_DIVISIBLE = 8

    def forward(self, x):
        raise NotImplementedError

    @property
    def module_str(self):
        raise NotImplementedError

    @property
    def config(self):
        raise NotImplementedError

    @staticmethod
    def build_from_config(config):
        raise NotImplementedError

    def zero_last_gamma(self):
        raise NotImplementedError

    @property
    def grouped_block_index(self):
        raise NotImplementedError

    """ implemented methods """

    def set_bn_param(self, momentum, eps, gn_channel_per_group=None, **kwargs):
        set_bn_param(self, momentum, eps, gn_channel_per_group, **kwargs)

    def get_bn_param(self):
        return get_bn_param(self)

    def get_parameters(self, keys=None, mode='include'):
        if keys is None:
            for name, param in self.named_parameters():
                if param.requires_grad:
                    yield param
        elif mode == 'include':
            for name, param in self.named_parameters():
                flag = False
                for key in keys:
                    if key in name:
                        flag = True
                        break
                if flag and param.requires_grad:
                    yield param
        elif mode == 'exclude':
            for name, param in self.named_parameters():
                flag = True
                for key in keys:
                    if key in name:
                        flag = False
                        break
                if flag and param.requires_grad:
                    yield param
        else:
            raise ValueError('do not support: %s' % mode)

    def weight_parameters(self):
        return self.get_parameters()
@@ -0,0 +1,154 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from .my_modules import MyNetwork

__all__ = [
    'make_divisible', 'build_activation', 'ShuffleLayer', 'MyGlobalAvgPool2d', 'Hswish', 'Hsigmoid', 'SEModule',
    'MultiHeadCrossEntropyLoss'
]


def make_divisible(v, divisor, min_val=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8.
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_val:
    :return:
    """
    if min_val is None:
        min_val = divisor
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
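# For example, make_divisible(30, 8) rounds to 32, and make_divisible(17, 8)
# rounds to 16; 16 is kept because it is within 10% of 17, otherwise the result
# would be bumped up by one extra divisor.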

def build_activation(act_func, inplace=True):
    if act_func == 'relu':
        return nn.ReLU(inplace=inplace)
    elif act_func == 'relu6':
        return nn.ReLU6(inplace=inplace)
    elif act_func == 'tanh':
        return nn.Tanh()
    elif act_func == 'sigmoid':
        return nn.Sigmoid()
    elif act_func == 'h_swish':
        return Hswish(inplace=inplace)
    elif act_func == 'h_sigmoid':
        return Hsigmoid(inplace=inplace)
    elif act_func is None or act_func == 'none':
        return None
    else:
        raise ValueError('do not support: %s' % act_func)


class ShuffleLayer(nn.Module):

    def __init__(self, groups):
        super(ShuffleLayer, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch_size, num_channels, height, width = x.size()
        channels_per_group = num_channels // self.groups
        # reshape
        x = x.view(batch_size, self.groups, channels_per_group, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        # flatten
        x = x.view(batch_size, -1, height, width)
        return x

    def __repr__(self):
        return 'ShuffleLayer(groups=%d)' % self.groups


class MyGlobalAvgPool2d(nn.Module):

    def __init__(self, keep_dim=True):
        super(MyGlobalAvgPool2d, self).__init__()
        self.keep_dim = keep_dim

    def forward(self, x):
        return x.mean(3, keepdim=self.keep_dim).mean(2, keepdim=self.keep_dim)

    def __repr__(self):
        return 'MyGlobalAvgPool2d(keep_dim=%s)' % self.keep_dim


class Hswish(nn.Module):

    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return x * F.relu6(x + 3., inplace=self.inplace) / 6.

    def __repr__(self):
        return 'Hswish()'


class Hsigmoid(nn.Module):

    def __init__(self, inplace=True):
        super(Hsigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3., inplace=self.inplace) / 6.

    def __repr__(self):
        return 'Hsigmoid()'


class SEModule(nn.Module):
    REDUCTION = 4

    def __init__(self, channel, reduction=None):
        super(SEModule, self).__init__()

        self.channel = channel
        self.reduction = SEModule.REDUCTION if reduction is None else reduction

        num_mid = make_divisible(self.channel // self.reduction, divisor=MyNetwork.CHANNEL_DIVISIBLE)

        self.fc = nn.Sequential(OrderedDict([
            ('reduce', nn.Conv2d(self.channel, num_mid, 1, 1, 0, bias=True)),
            ('relu', nn.ReLU(inplace=True)),
            ('expand', nn.Conv2d(num_mid, self.channel, 1, 1, 0, bias=True)),
            ('h_sigmoid', Hsigmoid(inplace=True)),
        ]))

    def forward(self, x):
        y = x.mean(3, keepdim=True).mean(2, keepdim=True)
        y = self.fc(y)
        return x * y

    def __repr__(self):
        return 'SE(channel=%d, reduction=%d)' % (self.channel, self.reduction)
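# SEModule implements squeeze-and-excitation: global average pooling squeezes
# each channel to a scalar, the reduce/expand 1x1 convolutions learn channel-wise
# gates, and the h-sigmoid output rescales the input feature map channel by channel.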

class MultiHeadCrossEntropyLoss(nn.Module):

    def forward(self, outputs, targets):
        assert outputs.dim() == 3, outputs
        assert targets.dim() == 2, targets

        assert outputs.size(1) == targets.size(1), (outputs, targets)
        num_heads = targets.size(1)

        loss = 0
        for k in range(num_heads):
            loss += F.cross_entropy(outputs[:, k, :], targets[:, k]) / num_heads
        return loss
@@ -0,0 +1,218 @@
# Once for All: Train One Network and Specialize it for Efficient Deployment
# Han Cai, Chuang Gan, Tianzhe Wang, Zhekai Zhang, Song Han
# International Conference on Learning Representations (ICLR), 2020.

import math
import copy
import time
import torch
import torch.nn as nn

__all__ = [
    'mix_images', 'mix_labels',
    'label_smooth', 'cross_entropy_loss_with_soft_target', 'cross_entropy_with_label_smoothing',
    'clean_num_batch_tracked', 'rm_bn_from_net',
    'get_net_device', 'count_parameters', 'count_net_flops', 'measure_net_latency', 'get_net_info',
    'build_optimizer', 'calc_learning_rate',
]


""" Mixup """


def mix_images(images, lam):
    flipped_images = torch.flip(images, dims=[0])  # flip along the batch dimension
    return lam * images + (1 - lam) * flipped_images


def mix_labels(target, lam, n_classes, label_smoothing=0.1):
    onehot_target = label_smooth(target, n_classes, label_smoothing)
    flipped_target = torch.flip(onehot_target, dims=[0])
    return lam * onehot_target + (1 - lam) * flipped_target
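# Mixup pairs each sample with its batch-reversed counterpart and blends both the
# images and the (smoothed) one-hot labels with the same coefficient lam, so the
# target stays consistent with the mixed input.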

""" Label smooth """


def label_smooth(target, n_classes: int, label_smoothing=0.1):
    # convert to one-hot
    batch_size = target.size(0)
    target = torch.unsqueeze(target, 1)
    soft_target = torch.zeros((batch_size, n_classes), device=target.device)
    soft_target.scatter_(1, target, 1)
    # label smoothing
    soft_target = soft_target * (1 - label_smoothing) + label_smoothing / n_classes
    return soft_target
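# With n_classes=10 and label_smoothing=0.1, the true class gets probability
# 0.9 + 0.1 / 10 = 0.91 and every other class gets 0.01, so each row still sums to 1.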

def cross_entropy_loss_with_soft_target(pred, soft_target):
    logsoftmax = nn.LogSoftmax(dim=1)
    return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))


def cross_entropy_with_label_smoothing(pred, target, label_smoothing=0.1):
    soft_target = label_smooth(target, pred.size(1), label_smoothing)
    return cross_entropy_loss_with_soft_target(pred, soft_target)


""" BN related """


def clean_num_batch_tracked(net):
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            if m.num_batches_tracked is not None:
                m.num_batches_tracked.zero_()


def rm_bn_from_net(net):
    for m in net.modules():
        if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
            m.forward = lambda x: x


""" Network profiling """


def get_net_device(net):
    return net.parameters().__next__().device


def count_parameters(net):
    total_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return total_params


def count_net_flops(net, data_shape=(1, 3, 224, 224)):
    from .flops_counter import profile
    if isinstance(net, nn.DataParallel):
        net = net.module

    flop, _ = profile(copy.deepcopy(net), data_shape)
    return flop


def measure_net_latency(net, l_type='gpu8', fast=True, input_shape=(3, 224, 224), clean=False):
    if isinstance(net, nn.DataParallel):
        net = net.module

    # remove bn from graph
    rm_bn_from_net(net)

    # return `ms`
    if 'gpu' in l_type:
        l_type, batch_size = l_type[:3], int(l_type[3:])
    else:
        batch_size = 1

    data_shape = [batch_size] + list(input_shape)
    if l_type == 'cpu':
        if fast:
            n_warmup = 5
            n_sample = 10
        else:
            n_warmup = 50
            n_sample = 50
        if get_net_device(net) != torch.device('cpu'):
            if not clean:
                print('move net to cpu for measuring cpu latency')
            net = copy.deepcopy(net).cpu()
    elif l_type == 'gpu':
        if fast:
            n_warmup = 5
            n_sample = 10
        else:
            n_warmup = 50
            n_sample = 50
    else:
        raise NotImplementedError
    images = torch.zeros(data_shape, device=get_net_device(net))

    measured_latency = {'warmup': [], 'sample': []}
    net.eval()
    with torch.no_grad():
        for i in range(n_warmup):
            inner_start_time = time.time()
            net(images)
            used_time = (time.time() - inner_start_time) * 1e3  # ms
            measured_latency['warmup'].append(used_time)
            if not clean:
                print('Warmup %d: %.3f' % (i, used_time))
        outer_start_time = time.time()
        for i in range(n_sample):
            net(images)
        total_time = (time.time() - outer_start_time) * 1e3  # ms
        measured_latency['sample'].append((total_time, n_sample))
    return total_time / n_sample, measured_latency
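# Usage note: measure_net_latency(net, 'gpu8') times a batch of 8 on the net's
# current device. Timings use wall-clock time.time() without
# torch.cuda.synchronize(), so on GPU they include kernel-queueing effects and
# should be read as a coarse estimate rather than exact device latency.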

def get_net_info(net, input_shape=(3, 224, 224), measure_latency=None, print_info=True):
    net_info = {}
    if isinstance(net, nn.DataParallel):
        net = net.module

    # parameters
    net_info['params'] = count_parameters(net) / 1e6

    # flops
    net_info['flops'] = count_net_flops(net, [1] + list(input_shape)) / 1e6

    # latencies
    latency_types = [] if measure_latency is None else measure_latency.split('#')
    for l_type in latency_types:
        latency, measured_latency = measure_net_latency(net, l_type, fast=False, input_shape=input_shape)
        net_info['%s latency' % l_type] = {
            'val': latency,
            'hist': measured_latency
        }

    if print_info:
        print(net)
        print('Total training params: %.2fM' % (net_info['params']))
        print('Total FLOPs: %.2fM' % (net_info['flops']))
        for l_type in latency_types:
            print('Estimated %s latency: %.3fms' % (l_type, net_info['%s latency' % l_type]['val']))

    return net_info


""" optimizer """


def build_optimizer(net_params, opt_type, opt_param, init_lr, weight_decay, no_decay_keys, seperate=1.0):
    # enc_list, dec_list = [], []
    # for name, param in model.named_parameters():
    #     if ('setenc' in name) or ('fc1' in name) or ('fc2' in name):
    #         enc_list.append(param)
    #     else:
    #         dec_list.append(param)
    # optimizer = optim.Adam([{'params': dec_list, 'lr': args.dec_lr},
    #                         {'params': enc_list, 'lr': args.enc_lr}], lr=1e-4)
    if no_decay_keys is not None:
        assert isinstance(net_params, list) and len(net_params) == 2
        net_params = [
            {'params': net_params[0], 'weight_decay': weight_decay},
            {'params': net_params[1], 'weight_decay': 0},
        ]
    elif seperate != 1.0:
        net_params = [{'params': net_params[0], 'weight_decay': weight_decay, 'lr': init_lr * seperate},
                      {'params': net_params[1], 'weight_decay': weight_decay, 'lr': init_lr}]
    else:
        net_params = [{'params': net_params, 'weight_decay': weight_decay}]

    if opt_type == 'sgd':
        opt_param = {} if opt_param is None else opt_param
        momentum, nesterov = opt_param.get('momentum', 0.9), opt_param.get('nesterov', True)
        optimizer = torch.optim.SGD(net_params, init_lr, momentum=momentum, nesterov=nesterov)
    elif opt_type == 'adam':
        optimizer = torch.optim.Adam(net_params, init_lr)
    else:
        raise NotImplementedError
    return optimizer


""" learning rate schedule """


def calc_learning_rate(epoch, init_lr, n_epochs, batch=0,
                       nBatch=None, lr_schedule_type='cosine', optimizer=None):
    if lr_schedule_type == 'cosine':
        t_total = n_epochs * nBatch
        t_cur = epoch * nBatch + batch
        lr = 0.5 * init_lr * (1 + math.cos(math.pi * t_cur / t_total))
    elif lr_schedule_type == 'reduce':
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
    elif lr_schedule_type is None:
        lr = init_lr
    else:
        raise ValueError('do not support: %s' % lr_schedule_type)
    return lr
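# Cosine schedule example: with init_lr=0.1 the learning rate starts at 0.1
# (t_cur=0), decays to 0.5 * 0.1 * (1 + cos(pi/2)) = 0.05 halfway through
# training, and approaches 0 at the final batch.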
@@ -0,0 +1,43 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
import argparse


def str2bool(v):
    return v.lower() in ['t', 'true']


def get_parser():
    parser = argparse.ArgumentParser()
    # general settings
    parser.add_argument('--seed', type=int, default=333)
    parser.add_argument('--gpu', type=str, default='0', help='set visible gpus')
    parser.add_argument('--model_name', type=str, default=None, choices=['generator', 'predictor', 'train_arch'])
    parser.add_argument('--save-path', type=str, default='results', help='the path of the save directory')
    parser.add_argument('--data-path', type=str, default='data', help='the path of the data directory')
    parser.add_argument('--save-epoch', type=int, default=20, help='how many epochs to wait each time to save model states')
    parser.add_argument('--max-epoch', type=int, default=400, help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size for generator')
    parser.add_argument('--graph-data-name', default='ofa_mbv3', help='graph dataset name')
    parser.add_argument('--nvt', type=int, default=27, help='number of different node types, 21 for ofa_mbv3 without in/out node')
    # set encoder
    parser.add_argument('--num-sample', type=int, default=20, help='the number of images as input for set encoder')
    # graph encoder
    parser.add_argument('--hs', type=int, default=56, help='hidden size of GRUs')
    parser.add_argument('--nz', type=int, default=56, help='the number of dimensions of latent vectors z')
    # test
    parser.add_argument('--test', action='store_true', default=False, help='turn on test mode')
    parser.add_argument('--load-epoch', type=int, default=20, help='checkpoint epoch loaded for meta-test')
    parser.add_argument('--data-name', type=str, default=None, help='meta-test dataset name')
    parser.add_argument('--num-class', type=int, default=None, help='the number of classes of the dataset')
    parser.add_argument('--num-gen-arch', type=int, default=200, help='the number of candidate architectures generated by the generator')
    parser.add_argument('--train-arch', type=str2bool, default=True, help='whether to train the searched architecture')
    # database
    parser.add_argument('--index', type=int, default=None, help='the process number when creating DB')
    parser.add_argument('--imgnet', type=str, default=None, help='the path of ImageNet')
    parser.add_argument('--collect', action='store_true', default=False, help='whether to collect data when creating DB')

    args = parser.parse_args()

    return args
@@ -0,0 +1,6 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from .predictor import Predictor
from .predictor_model import PredictorModel
@@ -0,0 +1,172 @@
###########################################################################################
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
###########################################################################################
from __future__ import print_function
import torch
import os
import random
from tqdm import tqdm
import numpy as np
import time
import shutil

from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.stats import pearsonr

from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import load_graph_config, decode_ofa_mbv3_to_igraph
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import Log, get_log
from transfer_nag_lib.MetaD2A_mobilenetV3.metad2a_utils import load_model, save_model

from transfer_nag_lib.MetaD2A_mobilenetV3.loader import get_meta_train_loader
from .predictor_model import PredictorModel
from all_path import *


class Predictor:
    def __init__(self, args):
        self.args = args
        self.batch_size = args.batch_size
        self.data_path = args.data_path
        self.num_sample = args.num_sample
        self.max_epoch = args.max_epoch
        self.save_epoch = args.save_epoch
        self.model_path = UNNOISE_META_PREDICTOR_CKPT_PATH  # MODEL_METAD2A_PATH_OFA
        self.save_path = args.save_path
        self.model_name = 'predictor'
        self.test = args.test
        self.device = torch.device("cuda:0")
        self.max_corr_dict = {'corr': -1, 'epoch': -1}
        self.train_arch = args.train_arch

        graph_config = load_graph_config(
            args.graph_data_name, args.nvt, args.data_path)

        self.model = PredictorModel(args, graph_config)
        self.model.to(self.device)

        if self.test:
            self.data_name = args.data_name
            self.num_class = args.num_class
            self.load_epoch = args.load_epoch
            load_model(self.model, self.model_path, load_max_pt='ckpt_max_corr.pt')
            self.model.to(self.device)
        else:
            self.optimizer = optim.Adam(self.model.parameters(), lr=1e-4)
            self.scheduler = ReduceLROnPlateau(self.optimizer, 'min',
                                               factor=0.1, patience=10, verbose=True)
            self.mtrloader = get_meta_train_loader(
                self.batch_size, self.data_path, self.num_sample, is_pred=True)

            self.acc_mean = self.mtrloader.dataset.mean
            self.acc_std = self.mtrloader.dataset.std

            self.mtrlog = Log(self.args, open(os.path.join(
                self.save_path, self.model_name, 'meta_train_predictor.log'), 'w'))
            self.mtrlog.print_args()

    def forward(self, x, arch):
        D_mu = self.model.set_encode(x.unsqueeze(0).to(self.device)).unsqueeze(0)
        G_mu = self.model.graph_encode(arch[0])
        y_pred = self.model.predict(D_mu, G_mu)
        return y_pred

    def meta_train(self):
        sttime = time.time()
        for epoch in range(1, self.max_epoch + 1):
            self.mtrlog.ep_sttime = time.time()
            loss, corr = self.meta_train_epoch(epoch)
            self.scheduler.step(loss)
            self.mtrlog.print_pred_log(loss, corr, 'train', epoch)
            valoss, vacorr = self.meta_validation(epoch)
            if self.max_corr_dict['corr'] < vacorr:
                self.max_corr_dict['corr'] = vacorr
                self.max_corr_dict['epoch'] = epoch
                self.max_corr_dict['loss'] = valoss
                save_model(epoch, self.model, self.model_path, max_corr=True)

            self.mtrlog.print_pred_log(
                valoss, vacorr, 'valid', max_corr_dict=self.max_corr_dict)

            if epoch % self.save_epoch == 0:
                save_model(epoch, self.model, self.model_path)

        self.mtrlog.save_time_log()
        self.mtrlog.max_corr_log(self.max_corr_dict)

    def meta_train_epoch(self, epoch):
        self.model.to(self.device)
        self.model.train()

        self.mtrloader.dataset.set_mode('train')

        dlen = len(self.mtrloader.dataset)
        trloss = 0
        y_all, y_pred_all = [], []
        pbar = tqdm(self.mtrloader)

        for batch in pbar:
            batch_loss = 0
            y_batch, y_pred_batch = [], []
            self.optimizer.zero_grad()
            for x, g, acc in batch:
                y_pred = self.forward(x, decode_ofa_mbv3_to_igraph(g))

                y = acc.to(self.device)
                batch_loss += self.model.mseloss(y_pred, y)

                y = y.squeeze().tolist()
                y_pred = y_pred.squeeze().tolist()

                y_batch.append(y)
                y_pred_batch.append(y_pred)
                y_all.append(y)
                y_pred_all.append(y_pred)

            batch_loss.backward()
            trloss += float(batch_loss)
            self.optimizer.step()
            pbar.set_description(get_log(
                epoch, batch_loss, y_pred_batch, y_batch, self.acc_std, self.acc_mean))

        return trloss / dlen, pearsonr(np.array(y_all),
                                       np.array(y_pred_all))[0]

    def meta_validation(self, epoch):
        self.model.to(self.device)
        self.model.eval()

        valoss = 0
        self.mtrloader.dataset.set_mode('valid')
        dlen = len(self.mtrloader.dataset)
        y_all, y_pred_all = [], []
        pbar = tqdm(self.mtrloader)

        with torch.no_grad():
            for batch in pbar:
                batch_loss = 0
                y_batch, y_pred_batch = [], []

                for x, g, acc in batch:
                    y_pred = self.forward(x, decode_ofa_mbv3_to_igraph(g))

                    y = acc.to(self.device)
                    batch_loss += self.model.mseloss(y_pred, y)

                    y = y.squeeze().tolist()
                    y_pred = y_pred.squeeze().tolist()

                    y_batch.append(y)
                    y_pred_batch.append(y_pred)
                    y_all.append(y)
                    y_pred_all.append(y_pred)

                valoss += float(batch_loss)
                pbar.set_description(get_log(
                    epoch, batch_loss, y_pred_batch, y_batch, self.acc_std, self.acc_mean, tag='val'))
        return valoss / dlen, pearsonr(np.array(y_all),
                                       np.array(y_pred_all))[0]
@@ -0,0 +1,241 @@
######################################################################################
# Copyright (c) muhanzhang, D-VAE, NeurIPS 2019 [GitHub D-VAE]
# Modified by Hayeon Lee, Eunyoung Hyung, MetaD2A, ICLR2021, 2021. 03 [GitHub MetaD2A]
######################################################################################
import torch
from torch import nn
from transfer_nag_lib.MetaD2A_mobilenetV3.set_encoder.setenc_models import SetPool


class PredictorModel(nn.Module):
    def __init__(self, args, graph_config):
        super(PredictorModel, self).__init__()
        self.max_n = graph_config['max_n']  # maximum number of vertices
        self.nvt = graph_config['num_vertex_type']  # number of vertex types
        self.START_TYPE = graph_config['START_TYPE']
        self.END_TYPE = graph_config['END_TYPE']
        self.hs = args.hs  # hidden state size of each vertex
        self.nz = args.nz  # size of latent representation z
        self.gs = args.hs  # size of graph state
        self.bidir = True  # whether to use bidirectional encoding
        self.vid = True
        self.device = None
        self.input_type = 'DG'
        self.num_sample = args.num_sample

        if self.vid:
            self.vs = self.hs + self.max_n  # vertex state size = hidden state + vid
        else:
            self.vs = self.hs

        # 0. encoding-related
        self.grue_forward = nn.GRUCell(self.nvt, self.hs)  # encoder GRU
        self.grue_backward = nn.GRUCell(self.nvt, self.hs)  # backward encoder GRU
        self.fc1 = nn.Linear(self.gs, self.nz)  # latent mean
        self.fc2 = nn.Linear(self.gs, self.nz)  # latent logvar

        # 2. gate-related
        self.gate_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.gate_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs),
            nn.Sigmoid()
        )
        self.mapper_forward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )  # disable bias to ensure padded zeros also mapped to zeros
        self.mapper_backward = nn.Sequential(
            nn.Linear(self.vs, self.hs, bias=False),
        )

        # 3. bidir-related, to unify sizes
        if self.bidir:
            self.hv_unify = nn.Sequential(
                nn.Linear(self.hs * 2, self.hs),
            )
            self.hg_unify = nn.Sequential(
                nn.Linear(self.gs * 2, self.gs),
            )

        # 4. other
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.logsoftmax1 = nn.LogSoftmax(1)

        # 6. predictor
        self.intra_setpool = SetPool(dim_input=512,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.inter_setpool = SetPool(dim_input=self.nz,
                                     num_outputs=1,
                                     dim_output=self.nz,
                                     dim_hidden=self.nz,
                                     mode='sabPF')
        self.set_fc = nn.Sequential(
            nn.Linear(512, self.nz),
            nn.ReLU())

        input_dim = 0
        if 'D' in self.input_type:
            input_dim += self.nz
        if 'G' in self.input_type:
            input_dim += self.nz

        self.pred_fc = nn.Sequential(
            nn.Linear(input_dim, self.hs),
            nn.Tanh(),
            nn.Linear(self.hs, 1)
        )
        self.mseloss = nn.MSELoss(reduction='sum')

    def predict(self, D_mu, G_mu):
        input_vec = []
        if 'D' in self.input_type:
            input_vec.append(D_mu)
        if 'G' in self.input_type:
            input_vec.append(G_mu)
        input_vec = torch.cat(input_vec, dim=1)
        return self.pred_fc(input_vec)

    def get_device(self):
        if self.device is None:
            self.device = next(self.parameters()).device
        return self.device

    def _get_zeros(self, n, length):
        return torch.zeros(n, length).to(self.get_device())  # get a zero hidden state

    def _get_zero_hidden(self, n=1):
        return self._get_zeros(n, self.hs)  # get a zero hidden state

    def _one_hot(self, idx, length):
        if type(idx) in [list, range]:
            if idx == []:
                return None
            idx = torch.LongTensor(idx).unsqueeze(0).t()
            x = torch.zeros((len(idx), length)).scatter_(1, idx, 1).to(self.get_device())
        else:
            idx = torch.LongTensor([idx]).unsqueeze(0)
            x = torch.zeros((1, length)).scatter_(1, idx, 1).to(self.get_device())
        return x

    def _gated(self, h, gate, mapper):
        return gate(h) * mapper(h)

    def _collate_fn(self, G):
        return [g.copy() for g in G]

    def _propagate_to(self, G, v, propagator, H=None, reverse=False, gate=None, mapper=None):
        # propagate messages to vertex index v for all graphs in G
        # return the new messages (states) at v
        G = [g for g in G if g.vcount() > v]
        if len(G) == 0:
            return
        if H is not None:
            idx = [i for i, g in enumerate(G) if g.vcount() > v]
            H = H[idx]
        v_types = [g.vs[v]['type'] for g in G]
        X = self._one_hot(v_types, self.nvt)
        if reverse:
            H_name = 'H_backward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.successors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.successors(v), self.max_n) for g in G]
            gate, mapper = self.gate_backward, self.mapper_backward
        else:
            H_name = 'H_forward'  # name of the hidden states attribute
            H_pred = [[g.vs[x][H_name] for x in g.predecessors(v)] for g in G]
            if self.vid:
                vids = [self._one_hot(g.predecessors(v), self.max_n) for g in G]
            if gate is None:
                gate, mapper = self.gate_forward, self.mapper_forward
        if self.vid:
            H_pred = [[torch.cat([x[i], y[i:i + 1]], 1) for i in range(len(x))] for x, y in zip(H_pred, vids)]
        # if h is not provided, use gated sum of v's predecessors' states as the input hidden state
        if H is None:
            max_n_pred = max([len(x) for x in H_pred])  # maximum number of predecessors
            if max_n_pred == 0:
                H = self._get_zero_hidden(len(G))
            else:
                H_pred = [torch.cat(h_pred +
                                    [self._get_zeros(max_n_pred - len(h_pred), self.vs)], 0).unsqueeze(0)
                          for h_pred in H_pred]  # pad all to same length
                H_pred = torch.cat(H_pred, 0)  # batch * max_n_pred * vs
                H = self._gated(H_pred, gate, mapper).sum(1)  # batch * hs
        Hv = propagator(X, H)
        for i, g in enumerate(G):
            g.vs[v][H_name] = Hv[i:i + 1]
        return Hv

    def _propagate_from(self, G, v, propagator, H0=None, reverse=False):
        # perform a series of propagation_to steps starting from v following a topo order
        # assume the original vertex indices are in a topological order
        if reverse:
            prop_order = range(v, -1, -1)
        else:
            prop_order = range(v, self.max_n)
        Hv = self._propagate_to(G, v, propagator, H0, reverse=reverse)  # the initial vertex
        for v_ in prop_order[1:]:
            self._propagate_to(G, v_, propagator, reverse=reverse)
        return Hv

    def _get_graph_state(self, G, decode=False):
        # get the graph states
        # when decoding, use the last generated vertex's state as the graph state
        # when encoding, use the ending vertex state or unify the starting and ending vertex states
        Hg = []
        for g in G:
            hg = g.vs[g.vcount() - 1]['H_forward']
            if self.bidir and not decode:  # decoding never uses backward propagation
                hg_b = g.vs[0]['H_backward']
                hg = torch.cat([hg, hg_b], 1)
            Hg.append(hg)
        Hg = torch.cat(Hg, 0)
        if self.bidir and not decode:
            Hg = self.hg_unify(Hg)
        return Hg

    def set_encode(self, X):
        proto_batch = []
        for x in X:
            cls_protos = self.intra_setpool(
                x.view(-1, self.num_sample, 512)).squeeze(1)
            proto_batch.append(
                self.inter_setpool(cls_protos.unsqueeze(0)))
        v = torch.stack(proto_batch).squeeze()
        return v
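    # set_encode is a two-level set pooling: intra_setpool collapses the
    # num_sample 512-d image features of each class into a class prototype, and
    # inter_setpool collapses the class prototypes into a single dataset
    # embedding, making the encoding permutation-invariant at both levels.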

    def graph_encode(self, G):
        # encode graphs G into latent vectors
        if type(G) != list:
            G = [G]
        self._propagate_from(G, 0, self.grue_forward, H0=self._get_zero_hidden(len(G)),
                             reverse=False)
        if self.bidir:
            self._propagate_from(G, self.max_n - 1, self.grue_backward,
                                 H0=self._get_zero_hidden(len(G)), reverse=True)
        Hg = self._get_graph_state(G)
        mu = self.fc1(Hg)
        # logvar = self.fc2(Hg)
        return mu  # , logvar
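    # The encoder runs the forward GRU along the topological order (and the
    # backward GRU in reverse when bidir is set), reads off the unified graph
    # state, and projects it with fc1 to the latent mean; the variational logvar
    # head (fc2) is kept but appears unused here, so the predictor encodes
    # architectures deterministically.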

    def reparameterize(self, mu, logvar, eps_scale=0.01):
        # return z ~ N(mu, std)
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = torch.randn_like(std) * eps_scale
            return eps.mul(std).add_(mu)
        else:
            return mu