73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
"""Evaluate trained score network"""
|
|
|
|
import ml_collections
|
|
import torch
|
|
|
|
from all_path import SCORENET_CKPT_PATH
|
|
|
|
def _new_section(**fields):
    """Return a new ``ConfigDict`` holding ``fields`` in keyword order."""
    section = ml_collections.ConfigDict()
    for key, value in fields.items():
        section[key] = value
    return section


def get_config():
    """Assemble the configuration for evaluating a trained score network.

    Returns:
        ml_collections.ConfigDict: general run settings at the top level, plus
        the ``training``, ``sampling``, ``eval``, ``data`` and ``model``
        sub-configs.
    """
    # Use the first CUDA device when one is available, otherwise the CPU.
    device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # general
    config = _new_section(
        folder_name='test',
        model_type='scorenet',
        task='eval_scorenet',
        exp_name=None,
        seed=42,
        device=torch.device(device_name),
        resume=False,
        scorenet_ckpt_path=SCORENET_CKPT_PATH,
    )

    # training
    config.training = _new_section(
        sde='vesde',
        continuous=True,
        reduce_mean=True,
        noised=True,
    )

    # sampling
    config.sampling = _new_section(
        method='pc',
        predictor='euler_maruyama',
        corrector='langevin',
        n_steps_each=1,
        noise_removal=True,
        probability_flow=False,
        snr=0.16,
    )

    # evaluation
    config.eval = _new_section(
        batch_size=256,
        enable_sampling=True,
        num_samples=256,
    )

    # data
    config.data = _new_section(
        centered=True,
        dequantization=False,
        # NOTE(review): relative path, so where it resolves depends on the
        # caller's working directory.
        root='../data/transfer_nag/nasbench201_info.pt',
        name='NASBench201',
        split_ratio=1.0,
        dataset_idx='random',  # 'sorted' | 'random'
        max_node=8,
        n_vocab=7,  # number of operations
        START_TYPE=0,
        END_TYPE=1,
        num_graphs=15625,
        num_channels=1,
        label_list=['test-acc'],
        tg_dataset='cifar10',
        # aug_mask
        aug_mask_algo='floyd',  # 'long_range' | 'floyd'
    )

    # model
    config.model = _new_section(
        num_scales=1000,
        beta_min=0.1,
        beta_max=5.0,
        sigma_min=0.1,
        sigma_max=5.0,
    )

    return config
|