diffusionNAG/MobileNetV3/configs/tr_scorenet_ofa.py
2024-03-15 14:38:51 +00:00

142 lines
3.8 KiB
Python

"""Configuration for training the score network (task 'tr_scorenet') on the OFA dataset."""
import ml_collections
import torch
def get_config():
    """Build the configuration for the 'tr_scorenet' task on the 'ofa' data.

    Returns:
        An ``ml_collections.ConfigDict`` holding the general run settings at
        the top level plus the nested sections ``training``, ``sampling``,
        ``eval``, ``data``, ``model`` (which itself contains
        ``graph_encoder``), ``optim`` and ``log``.
    """
    cfg = ml_collections.ConfigDict()

    # ---- general ----
    cfg.resume = False
    cfg.resume_ckpt_path = './exp'
    cfg.folder_name = 'tr_scorenet'
    cfg.task = 'tr_scorenet'
    cfg.exp_name = None
    cfg.model_type = 'sde'

    # ---- training ----
    train_cfg = ml_collections.ConfigDict()
    cfg.training = train_cfg
    train_cfg.sde = 'vesde'
    train_cfg.continuous = True
    train_cfg.reduce_mean = True
    train_cfg.batch_size = 256
    train_cfg.eval_batch_size = 1000
    train_cfg.n_iters = 1000000
    train_cfg.snapshot_freq = 10000
    train_cfg.log_freq = 200
    train_cfg.eval_freq = 10000
    # Store additional checkpoints for preemption.
    train_cfg.snapshot_freq_for_preemption = 5000
    # Produce samples at each snapshot.
    train_cfg.snapshot_sampling = True
    train_cfg.likelihood_weighting = False

    # ---- sampling ----
    sample_cfg = ml_collections.ConfigDict()
    cfg.sampling = sample_cfg
    sample_cfg.method = 'pc'
    sample_cfg.predictor = 'euler_maruyama'
    sample_cfg.corrector = 'none'
    sample_cfg.rtol = 1e-5
    sample_cfg.atol = 1e-5
    sample_cfg.ode_method = 'dopri5'  # alternative noted by the author: 'rk4'
    sample_cfg.ode_step = 0.01
    sample_cfg.n_steps_each = 1
    sample_cfg.noise_removal = True
    sample_cfg.probability_flow = False
    sample_cfg.snr = 0.16
    sample_cfg.vis_row = 4
    sample_cfg.vis_col = 4
    sample_cfg.alpha = 0.5
    sample_cfg.qtype = 'threshold'

    # ---- evaluation (stored under the key ``eval``) ----
    eval_cfg = ml_collections.ConfigDict()
    cfg.eval = eval_cfg
    eval_cfg.begin_ckpt = 5
    eval_cfg.end_ckpt = 20
    eval_cfg.batch_size = 1024
    eval_cfg.enable_sampling = True
    eval_cfg.num_samples = 1024
    eval_cfg.mmd_distance = 'RBF'
    eval_cfg.max_subgraph = False
    eval_cfg.save_graph = False

    # ---- data ----
    data_cfg = ml_collections.ConfigDict()
    cfg.data = data_cfg
    data_cfg.centered = True
    data_cfg.dequantization = False
    data_cfg.root = './data/ofa/data_score_model/ofa_database_500000.pt'
    data_cfg.name = 'ofa'
    data_cfg.split_ratio = 0.9
    data_cfg.dataset_idx = 'random'
    data_cfg.max_node = 20
    data_cfg.n_vocab = 9  # author's note lists 10 as an alternative
    data_cfg.START_TYPE = 0
    data_cfg.END_TYPE = 1
    data_cfg.num_graphs = 100000
    data_cfg.num_channels = 1
    data_cfg.except_inout = False
    data_cfg.triu_adj = True
    data_cfg.connect_prev = False
    data_cfg.label_list = None
    data_cfg.tg_dataset = None
    data_cfg.node_rule_type = 2
    # aug_mask
    data_cfg.aug_mask_algo = 'none'

    # ---- model ----
    model_cfg = ml_collections.ConfigDict()
    cfg.model = model_cfg
    model_cfg.name = 'CATE'
    model_cfg.ema_rate = 0.9999
    model_cfg.normalization = 'GroupNorm'
    model_cfg.nonlinearity = 'swish'
    model_cfg.nf = 128
    model_cfg.num_gnn_layers = 4
    model_cfg.size_cond = False
    model_cfg.embedding_type = 'positional'
    model_cfg.rw_depth = 16
    model_cfg.graph_layer = 'PosTransLayer'
    model_cfg.edge_th = -1.0
    model_cfg.heads = 8
    model_cfg.attn_clamp = False
    model_cfg.num_scales = 1000
    model_cfg.sigma_min = 0.1
    model_cfg.sigma_max = 1.0
    model_cfg.dropout = 0.1
    model_cfg.pos_enc_type = 2

    # ---- graph encoder (nested under ``model``) ----
    encoder_cfg = ml_collections.ConfigDict()
    cfg.model.graph_encoder = encoder_cfg
    encoder_cfg.n_layers = 12
    encoder_cfg.d_model = 64
    encoder_cfg.n_head = 8
    encoder_cfg.d_ff = 128
    encoder_cfg.dropout = 0.1
    encoder_cfg.n_vocab = 9  # author's note lists 10 and 30 as alternatives

    # ---- optimization ----
    optim_cfg = ml_collections.ConfigDict()
    cfg.optim = optim_cfg
    optim_cfg.weight_decay = 0
    optim_cfg.optimizer = 'Adam'
    optim_cfg.lr = 2e-5
    optim_cfg.beta1 = 0.9
    optim_cfg.eps = 1e-8
    optim_cfg.warmup = 1000
    optim_cfg.grad_clip = 1.0

    cfg.seed = 42
    # Use the first CUDA device when CUDA is available, otherwise the CPU.
    cfg.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # ---- log ----
    log_cfg = ml_collections.ConfigDict()
    cfg.log = log_cfg
    log_cfg.use_wandb = True
    log_cfg.wandb_project_name = 'DiffusionNAG'
    log_cfg.log_valid_sample_prop = False
    log_cfg.num_graphs_to_visualize = 20

    return cfg