# These imports are tricky because they use c++, do not move them
from tqdm import tqdm
import os, shutil
import warnings

import torch
import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer

import utils
from datasets import dataset
from diffusion_model import Graph_DiT
from metrics.molecular_metrics_train import TrainMolecularMetricsDiscrete
from metrics.molecular_metrics_train import TrainGraphMetricsDiscrete
from metrics.molecular_metrics_sampling import SamplingMolecularMetrics
from metrics.molecular_metrics_sampling import SamplingGraphMetrics
from analysis.visualization import MolecularVisualization
from analysis.visualization import GraphVisualization

warnings.filterwarnings("ignore", category=UserWarning)
torch.set_float32_matmul_precision("medium")


def remove_folder(folder):
    for filename in os.listdir(folder):
        file_path = os.path.join(folder, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except Exception as e:
            print("Failed to delete %s. Reason: %s" % (file_path, e))


def get_resume(cfg, model_kwargs):
    """Resumes a run. It loads previous config without allowing to update keys (used for testing)."""
    saved_cfg = cfg.copy()
    name = cfg.general.name + "_resume"
    resume = cfg.general.test_only
    batch_size = cfg.train.batch_size

    model = Graph_DiT.load_from_checkpoint(resume, **model_kwargs)
    cfg = model.cfg
    cfg.general.test_only = resume
    cfg.general.name = name
    cfg.train.batch_size = batch_size
    cfg = utils.update_config_with_new_keys(cfg, saved_cfg)
    return cfg, model


def get_resume_adaptive(cfg, model_kwargs):
    """Resumes a run. It loads previous config but allows to make some changes (used for resuming training)."""
    saved_cfg = cfg.copy()
    # Fetch path to this file to get base path
    current_path = os.path.dirname(os.path.realpath(__file__))
    root_dir = current_path.split("outputs")[0]
    resume_path = os.path.join(root_dir, cfg.general.resume)

    if cfg.model.type == "discrete":
        model = Graph_DiT.load_from_checkpoint(resume_path, **model_kwargs)
    else:
        raise NotImplementedError("Unknown model")

    new_cfg = model.cfg
    for category in cfg:
        for arg in cfg[category]:
            new_cfg[category][arg] = cfg[category][arg]

    new_cfg.general.resume = resume_path
    new_cfg.general.name = new_cfg.general.name + "_resume"
    new_cfg = utils.update_config_with_new_keys(new_cfg, saved_cfg)
    return new_cfg, model


@hydra.main(version_base="1.1", config_path="../configs", config_name="config")
def main(cfg: DictConfig):
    datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
    train_smiles, reference_smiles = datamodule.get_train_smiles()
    # train_graphs, reference_graphs = datamodule.get_train_graphs()

    # get input output dimensions
    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)
    # train_metrics = TrainGraphMetricsDiscrete(dataset_infos)

    sampling_metrics = SamplingMolecularMetrics(
        dataset_infos, train_smiles, reference_smiles
    )
    # sampling_metrics = SamplingGraphMetrics(
    #     dataset_infos, train_graphs, reference_graphs
    # )

    visualization_tools = MolecularVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        # "train_metrics": train_metrics,
        # "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
    }

    if cfg.general.test_only:
        # When testing, previous configuration is fully loaded
        cfg, _ = get_resume(cfg, model_kwargs)
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        # When resuming, we can override some parts of previous configuration
        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    model = Graph_DiT(cfg=cfg, **model_kwargs)
    trainer = Trainer(
        gradient_clip_val=cfg.train.clip_grad,
        # accelerator="gpu"
        # if torch.cuda.is_available() and cfg.general.gpus > 0
        # else "cpu",
        accelerator="cpu",
        devices=cfg.general.gpus
        if torch.cuda.is_available() and cfg.general.gpus > 0
        else None,
        max_epochs=cfg.train.n_epochs,
        enable_checkpointing=False,
        check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
        val_check_interval=cfg.train.val_check_interval,
        strategy="ddp" if cfg.general.gpus > 1 else "auto",
        enable_progress_bar=cfg.general.enable_progress_bar,
        callbacks=[],
        reload_dataloaders_every_n_epochs=0,
        logger=[],
    )

    if not cfg.general.test_only:
        trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
        if cfg.general.save_model:
            trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
        trainer.test(model, datamodule=datamodule)
    else:
        trainer.test(model, datamodule=datamodule, ckpt_path=cfg.general.test_only)


from accelerate import Accelerator
from accelerate.utils import set_seed, ProjectConfiguration


@hydra.main(version_base="1.1", config_path="../configs", config_name="config")
def test(cfg: DictConfig):
    os.environ["CUDA_VISIBLE_DEVICES"] = cfg.general.gpu_number
    accelerator_config = ProjectConfiguration(
        project_dir=os.path.join(cfg.general.log_dir, cfg.general.name),
        automatic_checkpoint_naming=True,
        total_limit=cfg.general.number_checkpoint_limit,
    )
    accelerator = Accelerator(
        mixed_precision='no',
        project_config=accelerator_config,
        gradient_accumulation_steps=cfg.train.gradient_accumulation_steps * cfg.train.n_epochs,
    )

    # Debug: confirm which devices are available
    print(f"Available GPUs: {torch.cuda.device_count()}")
    print(f"Using device: {accelerator.device}")

    set_seed(cfg.train.seed, device_specific=True)

    datamodule = dataset.DataModule(cfg)
    datamodule.prepare_data()
    dataset_infos = dataset.DataInfos(datamodule=datamodule, cfg=cfg, dataset=datamodule.dataset)
    train_graphs, reference_graphs = datamodule.get_train_graphs()

    dataset_infos.compute_input_output_dims(datamodule=datamodule)
    train_metrics = TrainGraphMetricsDiscrete(dataset_infos)
    sampling_metrics = SamplingGraphMetrics(
        dataset_infos, train_graphs, reference_graphs
    )
    visualization_tools = GraphVisualization(dataset_infos)

    model_kwargs = {
        "dataset_infos": dataset_infos,
        "train_metrics": train_metrics,
        "sampling_metrics": sampling_metrics,
        "visualization_tools": visualization_tools,
    }

    # Debug: confirm which devices are available
    print(f"Available GPUs: {torch.cuda.device_count()}")
    print(f"Using device: {accelerator.device}")

    if cfg.general.test_only:
        cfg, _ = get_resume(cfg, model_kwargs)
        os.chdir(cfg.general.test_only.split("checkpoints")[0])
    elif cfg.general.resume is not None:
        cfg, _ = get_resume_adaptive(cfg, model_kwargs)
        os.chdir(cfg.general.resume.split("checkpoints")[0])

    model = Graph_DiT(cfg=cfg, **model_kwargs)
    graph_dit_model = model

    inference_dtype = torch.float32
    graph_dit_model.to(accelerator.device, dtype=inference_dtype)

    # optional: freeze the model
    # graph_dit_model.model.requires_grad_(True)

    import torch.nn.functional as F

    optimizer = graph_dit_model.configure_optimizers()
    train_dataloader = accelerator.prepare(datamodule.train_dataloader())
    optimizer, graph_dit_model = accelerator.prepare(optimizer, graph_dit_model)

    # start training
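    # The manual loop below replicates what Trainer.fit() would otherwise do for
    # Graph_DiT: one-hot encode node and edge attributes, densify each batch with
    # utils.to_dense, apply the forward diffusion noise, run the denoiser, and take
    # an optimizer step on the discrete denoising loss.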
    for epoch in range(cfg.train.n_epochs):
        graph_dit_model.train()  # put the model in training mode
        print(f"Epoch {epoch}")
        for data in train_dataloader:  # fetch one batch from the dataloader
            data.to(accelerator.device)
            data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
            data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
            dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
            dense_data = dense_data.mask(node_mask)
            X, E = dense_data.X, dense_data.E
            noisy_data = graph_dit_model.apply_noise(X, E, data.y, node_mask)
            pred = graph_dit_model.forward(noisy_data)
            loss = graph_dit_model.train_loss(masked_pred_X=pred.X, masked_pred_E=pred.E, pred_y=pred.y,
                                              true_X=X, true_E=E, true_y=data.y, node_mask=node_mask,
                                              log=epoch % graph_dit_model.log_every_steps == 0)
            # print(f'training loss: {loss}, epoch: {self.current_epoch}, batch: {i}\n, pred type: {type(pred)}, pred.X shape: {type(pred.X)}, {pred.X.shape}, pred.E shape: {type(pred.E)}, {pred.E.shape}')
            graph_dit_model.train_metrics(masked_pred_X=pred.X, masked_pred_E=pred.E, true_X=X, true_E=E,
                                          log=epoch % graph_dit_model.log_every_steps == 0)
            graph_dit_model.log('loss', loss, batch_size=X.size(0), sync_dist=True)
            print(f"training loss: {loss}")
            with open("training-loss.csv", "a") as f:
                f.write(f"{loss}, {epoch}\n")

            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            # return {'loss': loss}

        if epoch % cfg.train.check_val_every_n_epoch == 0:
            # run a lightweight validation pass on the last batch of this epoch
            print("running validation")
            graph_dit_model.eval()
            graph_dit_model.on_validation_epoch_start()
            graph_dit_model.validation_step(data, epoch)
            graph_dit_model.on_validation_epoch_end()

    # start testing
    print("start testing")
    graph_dit_model.eval()
    test_dataloader = accelerator.prepare(datamodule.test_dataloader())
    for data in test_dataloader:
        data_x = F.one_hot(data.x, num_classes=12).float()[:, graph_dit_model.active_index]
        data_edge_attr = F.one_hot(data.edge_attr, num_classes=2).float()
        dense_data, node_mask = utils.to_dense(data_x, data.edge_index, data_edge_attr, data.batch, graph_dit_model.max_n_nodes)
        dense_data = dense_data.mask(node_mask)
        noisy_data = graph_dit_model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
        pred = graph_dit_model.forward(noisy_data)
        nll = graph_dit_model.compute_val_loss(pred, noisy_data, dense_data.X, dense_data.E, data.y, node_mask, test=True)
        graph_dit_model.test_y_collection.append(data.y)
        print(f'test loss: {nll}')

    # start sampling
    samples_left_to_generate = cfg.general.final_model_samples_to_generate
    samples_left_to_save = cfg.general.final_model_samples_to_save
    chains_left_to_save = cfg.general.final_model_chains_to_save

    samples, all_ys, batch_id = [], [], 0
    samples_with_log_probs = []

    test_y_collection = torch.cat(graph_dit_model.test_y_collection, dim=0)
    num_examples = test_y_collection.size(0)
    if cfg.general.final_model_samples_to_generate > num_examples:
        ratio = cfg.general.final_model_samples_to_generate // num_examples
        test_y_collection = test_y_collection.repeat(ratio + 1, 1)
        num_examples = test_y_collection.size(0)

    # Normal reward function
    from nas_201_api import NASBench201API as API
    api = API('/nfs/data3/hanzhang/nasbench201/graph_dit/NAS-Bench-201-v1_1-096897.pth')

    def graph_reward_fn(graphs, true_graphs=None, device=None, reward_model='swap'):
        rewards = []
        if reward_model == 'swap':
            import csv
            with open('/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv', 'r') as f:
                reader = csv.reader(f)
                header = next(reader)  # skip the header row
                data = [row for row in reader]
                swap_scores = [float(row[0]) for row in data]
            for graph in graphs:
                node_tensor = graph[0]
                node = node_tensor.cpu().numpy().tolist()

                def nodes_to_arch_str(nodes):
                    num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']
                    nodes_str = [num_to_op[n] for n in nodes]
                    arch_str = '|' + nodes_str[1] + '~0|+' + \
                               '|' + nodes_str[2] + '~0|' + nodes_str[3] + '~1|+' + \
                               '|' + nodes_str[4] + '~0|' + nodes_str[5] + '~1|' + nodes_str[6] + '~2|'
                    return arch_str

                arch_str = nodes_to_arch_str(node)
                reward = swap_scores[api.query_index_by_arch(arch_str)]
                rewards.append(reward)
        else:
            # fallback: constant reward for every graph
            for graph in graphs:
                rewards.append(1.0)
        return torch.tensor(rewards, dtype=torch.float32, requires_grad=True).unsqueeze(0).to(device)
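    # For reference, nodes_to_arch_str emits a NAS-Bench-201 style architecture string,
    # e.g. (operation names below are illustrative, not from a real run):
    #   |nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|none~0|avg_pool_3x3~1|nor_conv_3x3~2|
    # i.e. the three cell nodes, each listing its incoming operations and source-node index,
    # which api.query_index_by_arch maps to a row index into swap_results.csv.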
    while samples_left_to_generate > 0:
        print(f'samples left to generate: {samples_left_to_generate}/'
              f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
        bs = 1 * cfg.train.batch_size
        to_generate = min(samples_left_to_generate, bs)
        to_save = min(samples_left_to_save, bs)
        chains_save = min(chains_left_to_save, bs)
        # batch_y = test_y_collection[batch_id : batch_id + to_generate]
        batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)

        cur_sample, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save,
                                                             keep_chain=chains_save,
                                                             number_chain_steps=graph_dit_model.number_chain_steps)
        samples = samples + cur_sample
        reward = graph_reward_fn(cur_sample, device=graph_dit_model.device)
        samples_with_log_probs.append((cur_sample, log_probs, reward))

        all_ys.append(batch_y)
        batch_id += to_generate

        samples_left_to_save -= to_save
        samples_left_to_generate -= to_generate
        chains_left_to_save -= chains_save

    print("Computing final sampling metrics...")
    graph_dit_model.sampling_metrics.reset()
    graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch,
                                     graph_dit_model.val_counter, test=True)
    graph_dit_model.sampling_metrics.reset()
    print("Done.")

    # save samples
    print("Samples:")
    print(samples)

    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
    # samples, log_probs, rewards = samples_with_log_probs[perm]
    # samples = list(samples)
    # log_probs = list(log_probs)
    # for i in range(len(log_probs)):
    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
    # print(f'log_probs: {log_probs[:5]}')
    # print(f'log_probs: {log_probs[0].shape}')  # torch.Size([1])
    # rewards = list(rewards)
    # log_probs = torch.cat(log_probs, dim=0)
    # print(f'log_probs: {log_probs.shape}')  # torch.Size([1000, 1])
    # old_log_probs = log_probs.clone()
    # ===
    # old_log_probs = None
    # while samples_left_to_generate > 0:
    #     print(f'samples left to generate: {samples_left_to_generate}/'
    #           f'{cfg.general.final_model_samples_to_generate}', end='', flush=True)
    #     bs = 1 * cfg.train.batch_size
    #     to_generate = min(samples_left_to_generate, bs)
    #     to_save = min(samples_left_to_save, bs)
    #     chains_save = min(chains_left_to_save, bs)
    #     with accelerator.accumulate(graph_dit_model):
    #         batch_y = torch.ones(to_generate, graph_dit_model.ydim_output, device=graph_dit_model.device)
    #         new_samples, log_probs = graph_dit_model.sample_batch(batch_id, to_generate, batch_y, save_final=to_save, keep_chain=chains_save, number_chain_steps=graph_dit_model.number_chain_steps)
    #         samples = samples + new_samples
    #         reward = graph_reward_fn(new_samples, device=graph_dit_model.device)
    #         advantages = (reward - torch.mean(reward)) / (torch.std(reward) + 1e-6)
    #         if old_log_probs is None:
    #             old_log_probs = log_probs.clone()
    #         ratio = torch.exp(log_probs - old_log_probs)
    #         unclipped_loss = -advantages * ratio
    #         clipped_loss = -advantages * torch.clamp(ratio,
    #                                                  1.0 - cfg.ppo.clip_param,
    #                                                  1.0 + cfg.ppo.clip_param)
    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
    #         accelerator.backward(loss)
    #         optimizer.step()
    #         optimizer.zero_grad()
    #     samples_with_log_probs.append((new_samples, log_probs, reward))
    #     all_ys.append(batch_y)
    #     batch_id += to_generate
    #     samples_left_to_save -= to_save
    #     samples_left_to_generate -= to_generate
    #     chains_left_to_save -= chains_save
    #     # break
    # print(f"final Computing sampling metrics...")
    # graph_dit_model.sampling_metrics.reset()
    # graph_dit_model.sampling_metrics(samples, all_ys, graph_dit_model.name, graph_dit_model.current_epoch, graph_dit_model.val_counter, test=True)
    # graph_dit_model.sampling_metrics.reset()
    # print(f"Done.")
    # # save samples
    # print("Samples:")
    # print(samples)
    # perm = torch.randperm(len(samples_with_log_probs), device=accelerator.device)
    # samples, log_probs, rewards = samples_with_log_probs[perm]
    # samples = list(samples)
    # log_probs = list(log_probs)
    # for i in range(len(log_probs)):
    #     log_probs[i] = torch.sum(log_probs[i], dim=-1).unsqueeze(0)
    # print(f'log_probs: {log_probs[:5]}')
    # print(f'log_probs: {log_probs[0].shape}')  # torch.Size([1])
    # rewards = list(rewards)
    # log_probs = torch.cat(log_probs, dim=0)
    # print(f'log_probs: {log_probs.shape}')  # torch.Size([1000, 1])
    # old_log_probs = log_probs.clone()
    # # multi metrics range
    # # reward hacking hiking
    # for inner_epoch in range(cfg.train.n_epochs):
    #     # print(f'rewards: {rewards.shape}')  # torch.Size([1000])
    #     print(f'rewards: {rewards[:5]}')
    #     print(f'len rewards: {len(rewards)}')
    #     print(f'type rewards: {type(rewards)}')
    #     if len(rewards) > 1 and isinstance(rewards, list):
    #         rewards = torch.cat(rewards, dim=0)
    #     elif len(rewards) == 1 and isinstance(rewards, list):
    #         rewards = rewards[0]
    #     # print(f'rewards: {rewards.shape}')
    #     advantages = (rewards - torch.mean(rewards)) / (torch.std(rewards) + 1e-6)
    #     print(f'advantages: {advantages.shape}')
    #     with accelerator.accumulate(graph_dit_model):
    #         ratio = torch.exp(log_probs - old_log_probs)
    #         unclipped_loss = -advantages * ratio
    #         # z-score normalization
    #         clipped_loss = -advantages * torch.clamp(ratio,
    #                                                  1.0 - cfg.ppo.clip_param,
    #                                                  1.0 + cfg.ppo.clip_param)
    #         loss = torch.mean(torch.max(unclipped_loss, clipped_loss))
    #         accelerator.backward(loss)
    #         optimizer.step()
    #         optimizer.zero_grad()
    #         accelerator.log({"loss": loss.item(), "epoch": inner_epoch})
    #         print(f"loss: {loss.item()}, epoch: {inner_epoch}")

    # trainer = Trainer(
    #     gradient_clip_val=cfg.train.clip_grad,
    #     # accelerator="cpu",
    #     accelerator="gpu"
    #     if torch.cuda.is_available() and cfg.general.gpus > 0
    #     else "cpu",
    #     devices=[cfg.general.gpu_number]
    #     if torch.cuda.is_available() and cfg.general.gpus > 0
    #     else None,
    #     max_epochs=cfg.train.n_epochs,
    #     enable_checkpointing=False,
    #     check_val_every_n_epoch=cfg.train.check_val_every_n_epoch,
    #     val_check_interval=cfg.train.val_check_interval,
    #     strategy="ddp" if cfg.general.gpus > 1 else "auto",
    #     enable_progress_bar=cfg.general.enable_progress_bar,
    #     callbacks=[],
    #     reload_dataloaders_every_n_epochs=0,
    #     logger=[],
    # )
    # if not cfg.general.test_only:
    #     print("start testing fit method")
    #     trainer.fit(model, datamodule=datamodule, ckpt_path=cfg.general.resume)
    #     if cfg.general.save_model:
    #         trainer.save_checkpoint(f"checkpoints/{cfg.general.name}/last.ckpt")
    #     trainer.test(model, datamodule=datamodule)
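
# The commented-out blocks above experiment with PPO-style fine-tuning of the sampler.
# As a minimal, self-contained sketch of the clipped surrogate objective they compute
# (a reference illustration, not the project's implementation; `clip_param` plays the
# role of cfg.ppo.clip_param and all tensor shapes are assumptions):
def ppo_clipped_loss(log_probs, old_log_probs, rewards, clip_param=0.2):
    # Z-score normalize rewards into advantages, as in the commented code above.
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)
    # Probability ratio between the current policy and the one that generated the samples.
    ratio = torch.exp(log_probs - old_log_probs)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    # Pessimistic (element-wise max) combination, averaged over the batch.
    return torch.max(unclipped, clipped).mean()
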
# Only the Accelerate-based entry point test() is executed here; main() (the
# PyTorch Lightning Trainer path) is kept above for reference.
if __name__ == "__main__":
    test()
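
# Example invocation via Hydra overrides (placeholder values; the available keys are
# defined under ../configs, e.g. config.yaml, as referenced by the decorators above):
#   python <this_script>.py general.name=nas_debug train.n_epochs=10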