import sys
sys.path.append('../') 

from nas_201_api import NASBench201API as API

import os
import os.path as osp
import pathlib
import json
import random

import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import rdchem
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes
from naswot.score_networks import get_nasbench201_idx_score
from naswot import nasspace
from naswot import datasets as dt

import networkx as nx


# RDKit bond type -> integer bond label.
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}

# NAS-Bench-201 operation -> stand-in chemical element symbol, so that a cell
# can be built as an RDKit molecule (see DataModule.create_molecule_from_graph).
op_to_atom = {
    'input': 'Si',          # Silicon for input
    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution
    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution
    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling
    'skip_connect': 'P',   # Phosphorus for skip connection
    'none': 'S',           # Sulfur for no operation
    'output': 'He'         # Helium for output
}

# Operation -> 0-based integer node label.
op_type = {
    'input': 0,
    'nor_conv_1x1': 1,
    'nor_conv_3x3': 2,
    'avg_pool_3x3': 3,
    'skip_connect': 4,
    'none': 5,
    'output': 6,
}

# Inverse of ``op_type``: position in this list == the label in ``op_type``.
num_to_op = ['input', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'none', 'output']

class DataModule(AbstractDataModule):
    """Data module serving NAS-Bench-201 cell graphs.

    ``prepare_data`` loads the processed :class:`Dataset`, splits it
    60/20/20 into train/val/test (``random_state=42``) and builds one
    ``DataLoader`` per split. Helper methods convert the stored graphs into
    RDKit molecules / SMILES using the ``op_to_atom`` element mapping.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        """Load the dataset, split it and build the train/val/test loaders.

        Sets ``root_path``, ``dataset``, the ``*_index`` split lists, the
        ``train_dataset``/``val_dataset``/``test_dataset`` subsets, the three
        loaders and ``training_iterations`` on ``self``.
        """
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)  # nasbench-201
        # The data root is configurable via ``cfg.dataset.base_path``; fall
        # back to the original hard-coded location so existing configs work.
        base_path = getattr(self.cfg.dataset, 'base_path', None) or '/nfs/data3/hanzhang/nasbenchDiT'
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size
        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        # NOTE: Dataset.__init__ currently replaces ``source`` with a fixed
        # absolute checkpoint path, so this value does not select the file.
        source = './NAS-Bench-201-v1_1-096897.pth'
        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)
        self.dataset = dataset

        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)
        self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
            train_index, val_index, test_index, unlabeled_index)

        train_index = torch.LongTensor(train_index)
        val_index = torch.LongTensor(val_index)
        test_index = torch.LongTensor(test_index)
        unlabeled_index = torch.LongTensor(unlabeled_index)
        # Unlabeled graphs (if any) are appended to the training split.
        if len(unlabeled_index) > 0:
            train_index = torch.cat([train_index, unlabeled_index], dim=0)

        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        print('dataset len', len(dataset), 'train len', len(train_dataset),
              'val len', len(val_dataset), 'test len', len(test_dataset))

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                       shuffle=True, pin_memory=pin_memory)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                     shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                                      shuffle=False, pin_memory=False)

        self.training_iterations = len(train_dataset) // batch_size

    def random_data_split(self, dataset):
        """Split all graph indices 60/20/20 into train/val/test.

        Every graph is treated as labeled, so ``unlabeled_index`` is always
        empty; it is returned to keep the split interface uniform with
        :meth:`fixed_split`.
        """
        labeled_len = len(dataset)
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(
            full_idx, full_idx, test_size=test_ratio, random_state=42)
        # Validation is carved out of the remaining 80%: 0.2 / 0.8 = 25% of it.
        train_index, val_index, _, _ = train_test_split(
            train_index, train_index, test_size=valid_ratio / (valid_ratio + train_ratio), random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index

    def fixed_split(self, dataset):
        """Use a fixed test set for task ``'O2-N2'``; split the rest 80/20.

        Raises:
            ValueError: for any other task name.
        """
        if self.task == 'O2-N2':
            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(
            full_idx, full_idx, test_size=1 - train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index))
        return train_index, val_index, test_index, []

    @staticmethod
    def parse_architecture_string(arch_str):
        """Parse a NAS-Bench-201 architecture string into nodes and adjacency.

        The string is a ``'+'``-separated list of steps, each a ``'|'``-delimited
        list of ``op~idx`` entries. The six ``idx`` values must match the fixed
        NAS-Bench-201 cell layout (``0; 0,1; 0,1,2``).

        Returns:
            ``(nodes, adj_mat)``: ``nodes`` is ``['input', <6 ops>, 'output']``
            and ``adj_mat`` is the fixed 8x8 cell connectivity matrix.

        Raises:
            ValueError: if an entry's input index does not match the layout or
                the string does not contain exactly six operations.
        """
        nodes = ['input']
        adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
                            [0, 0, 0, 1, 0, 1, 0, 0],
                            [0, 0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 1, 0],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 1],
                            [0, 0, 0, 0, 0, 0, 0, 0]])
        # Expected input-node index of each of the six operations, in order.
        steps_coding = ['0', '0', '1', '0', '1', '2']
        position = 0
        for step in arch_str.split('+'):
            for entry in step.strip('|').split('|'):
                op_name, idx = entry.split('~')
                # Explicit check instead of ``assert`` so validation survives ``python -O``.
                if position >= len(steps_coding) or idx != steps_coding[position]:
                    raise ValueError('Unexpected NAS-Bench-201 architecture string: {}'.format(arch_str))
                position += 1
                nodes.append(op_name)
        if position != len(steps_coding):
            raise ValueError('Unexpected NAS-Bench-201 architecture string: {}'.format(arch_str))
        nodes.append('output')
        return nodes, adj_mat

    def create_molecule_from_graph(self, graph):
        """Convert a graph with ``x`` and ``edge_index`` into an RDKit ``RWMol``.

        Each node's integer label in ``graph.x`` is mapped to an operation name
        and then to an element symbol via ``op_to_atom``. Nodes labelled ``0``
        are skipped. Every edge becomes a single bond.
        """
        # This encoding is 1-based (0 = skipped node) and differs from the
        # module-level ``op_type``/``num_to_op``, hence the distinct local name.
        idx_to_op = {
            1: 'nor_conv_1x1',
            2: 'nor_conv_3x3',
            3: 'avg_pool_3x3',
            4: 'skip_connect',
            5: 'output',
            6: 'none',
            7: 'input',
        }
        mol = Chem.RWMol()  # RWMol allows building the molecule step by step
        atom_indices = {}

        for i, op_tensor in enumerate(graph.x):
            label = op_tensor.item()
            if label == 0:
                continue
            atom = Chem.Atom(op_to_atom[idx_to_op[label]])
            atom_indices[i] = mol.AddAtom(atom)

        edges = graph.edge_index
        for k in range(edges.shape[1]):
            start = edges[0, k].item()
            end = edges[1, k].item()
            mol.AddBond(atom_indices[start], atom_indices[end], rdchem.BondType.SINGLE)

        return mol

    def get_train_graphs(self):
        """Return ``(train_graphs, test_graphs)`` as plain lists of graphs."""
        return list(self.train_dataset), list(self.test_dataset)

    def get_train_smiles(self):
        """Return RDKit SMILES strings for the train and test graphs."""
        def to_smiles(graphs):
            return [Chem.MolToSmiles(self.create_molecule_from_graph(graph)) for graph in graphs]

        return to_smiles(self.train_dataset), to_smiles(self.test_dataset)

    def get_data_split(self):
        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")

    def example_batch(self):
        """Return the first validation batch."""
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader




class DataModule_original(AbstractDataModule):
    """Molecule-oriented variant of the data module.

    Loads :class:`Dataset` with ``source=cfg.dataset.task_name`` and splits it
    with a fixed test set for two-part task names (e.g. ``'O2-N2'``) or
    randomly 60/20/20 otherwise. SMILES and data frames are read from
    ``<root_path>/raw/<task>.csv.gz``.
    """

    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        """Load the dataset, split it and build the train/val/test loaders."""
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size
        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        dataset = Dataset(source=self.task, root=root_path, target_prop=target, transform=None)

        # Two-part task names (e.g. 'O2-N2') use a fixed, hand-picked test split.
        if len(self.task.split('-')) == 2:
            split = self.fixed_split(dataset)
        else:
            split = self.random_data_split(dataset)
        train_index, val_index, test_index, unlabeled_index = split

        self.train_index, self.val_index, self.test_index, self.unlabeled_index = (
            train_index, val_index, test_index, unlabeled_index)
        train_index = torch.LongTensor(train_index)
        val_index = torch.LongTensor(val_index)
        test_index = torch.LongTensor(test_index)
        unlabeled_index = torch.LongTensor(unlabeled_index)
        # Unlabeled graphs (if any) are appended to the training split.
        if len(unlabeled_index) > 0:
            train_index = torch.cat([train_index, unlabeled_index], dim=0)

        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        print('dataset len', len(dataset), 'train len', len(train_dataset),
              'val len', len(val_dataset), 'test len', len(test_dataset))

        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                                       shuffle=True, pin_memory=pin_memory)
        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                                     shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                                      shuffle=False, pin_memory=False)

        self.training_iterations = len(train_dataset) // batch_size

    def random_data_split(self, dataset):
        """Split labeled indices 60/20/20; the NaN-labelled tail is returned as unlabeled.

        Rows whose first target is NaN are counted as unlabeled.
        """
        # NOTE(review): assumes all NaN-labelled rows are stored after every
        # labeled row, since unlabeled_index is taken as the trailing range.
        nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
        labeled_len = len(dataset) - nan_count
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(
            full_idx, full_idx, test_size=test_ratio, random_state=42)
        train_index, val_index, _, _ = train_test_split(
            train_index, train_index, test_size=valid_ratio / (valid_ratio + train_ratio), random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index

    def fixed_split(self, dataset):
        """Use a fixed test set for task ``'O2-N2'``; split the rest 80/20.

        Raises:
            ValueError: for any other task name.
        """
        if self.task == 'O2-N2':
            test_index = [42, 43, 92, 122, 197, 198, 251, 254, 257, 355, 511, 512, 549, 602, 603, 604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(
            full_idx, full_idx, test_size=1 - train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index),
              'test len', len(test_index))
        return train_index, val_index, test_index, []

    def get_train_smiles(self):
        """Return RDKit-canonicalised SMILES for the train and test rows of the CSV."""
        filename = f'{self.task}.csv.gz'
        df = pd.read_csv(f'{self.root_path}/raw/{filename}')
        df_test = df.iloc[self.test_index]
        df = df.iloc[self.train_index]
        smiles_list = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in df['smiles'].tolist()]
        smiles_list_test = [Chem.MolToSmiles(Chem.MolFromSmiles(smi)) for smi in df_test['smiles'].tolist()]
        return smiles_list, smiles_list_test

    def get_data_split(self):
        """Return the CSV rows of the ``(train, val, test)`` splits as data frames."""
        filename = f'{self.task}.csv.gz'
        df = pd.read_csv(f'{self.root_path}/raw/{filename}')
        df_val = df.iloc[self.val_index]
        df_test = df.iloc[self.test_index]
        df_train = df.iloc[self.train_index]
        return df_train, df_val, df_test

    def example_batch(self):
        """Return the first validation batch."""
        return next(iter(self.val_loader))

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader

def new_graphs_to_json(graphs, filename, output_path=None):
    """Compute meta statistics for integer-labelled NAS-Bench-201 graphs.

    Args:
        graphs: sized sequence of ``(adj, ops)`` pairs (accessed as
            ``graph[0]`` / ``graph[1]``). ``adj`` is a square adjacency matrix
            indexable as ``adj[i][j]``; ``ops`` holds integer node labels that
            are used directly as list/array indices.
        filename: label used only in the progress printout.
        output_path: destination of the JSON file. Defaults to the original
            hard-coded location.

    Returns:
        The meta dictionary that was written to ``output_path``.

    Raises:
        ValueError: if ``graphs`` is empty.
        AssertionError: if a graph records more edges between two node types
            than there are node pairs, which happens when ``adj`` holds both
            ``i->j`` and ``j->i``.
    """
    if output_path is None:
        output_path = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-meta.json'
    source_name = "nasbench-201"
    num_graph = len(graphs)
    if num_graph == 0:
        raise ValueError('new_graphs_to_json needs at least one graph')

    num_types = len(op_type) + 1
    # NOTE(review): node labels index these lists directly while the name
    # list is prefixed with '*'; confirm the label encoding (0-based op_type
    # vs. the 1-based encoding used elsewhere) matches this layout.
    node_name_list = ['*'] + list(op_type)
    node_count_list = [0] * num_types
    edge_count_list = [0, 0]  # [no edge, edge]
    valencies = [0] * num_types
    transition_E = np.zeros((num_types, num_types, 2))

    n_node_list = []
    n_edge_list = []

    for graph in graphs:
        ops = graph[1]
        adj = graph[0]
        n_node = len(ops)
        cur_node_count_arr = np.zeros(num_types)

        for node in ops:
            node_count_list[node] += 1
            cur_node_count_arr[node] += 1
            valencies[node] += 1

        transition_E_temp = np.zeros((num_types, num_types, 2))
        n_edge = 0  # directed edges actually present in ``adj``
        for i in range(n_node):
            for j in range(n_node):
                if i == j or adj[i][j] == 0:
                    continue
                start_index = ops[i]
                end_index = ops[j]
                bond_index = 1  # single edge type
                edge_count_list[bond_index] += 2
                n_edge += 1

                transition_E[start_index, end_index, bond_index] += 2
                transition_E[end_index, start_index, bond_index] += 2
                transition_E_temp[start_index, end_index, bond_index] += 2
                transition_E_temp[end_index, start_index, bond_index] += 2

        n_node_list.append(n_node)
        n_edge_list.append(n_edge)

        # Ordered node pairs not covered by an edge count as "no edge"; each
        # edge accounts for two ordered pairs, matching the += 2 above.
        edge_count_list[0] += n_node * (n_node - 1) - n_edge * 2
        cur_tot_edge = cur_node_count_arr.reshape(-1, 1) * cur_node_count_arr.reshape(1, -1) * 2
        cur_tot_edge = cur_tot_edge - np.diag(cur_node_count_arr) * 2
        connected = transition_E_temp.sum(axis=-1)
        assert np.all(cur_tot_edge >= connected), 'more edges recorded than node pairs available'
        transition_E[:, :, 0] += cur_tot_edge - connected

    # Graph-size histogram; sized so every node count fits, reported up to 50.
    size_hist = np.bincount(n_node_list, minlength=51)
    n_nodes_per_graph = (size_hist / size_hist.sum()).tolist()[:51]

    node_count_arr = np.array(node_count_list) / np.sum(node_count_list)
    print('processed meta info: ------', filename, '------')
    print('len node_count_list', len(node_count_arr))
    print('len node_name_list', len(node_name_list))
    active_nodes = np.array(node_name_list)[node_count_arr > 0].tolist()
    node_count_list = node_count_arr.tolist()

    edge_count_list = (np.array(edge_count_list) / np.sum(edge_count_list)).tolist()
    valencies = (np.array(valencies) / np.sum(valencies)).tolist()

    # Type pairs that never co-occur get a deterministic "no edge" transition.
    no_edge = np.sum(transition_E, axis=-1) == 0
    transition_E[no_edge, 0] = 1
    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)

    meta_dict = {
        'source': source_name,
        'num_graph': num_graph,
        'n_nodes_per_graph': n_nodes_per_graph,
        'max_n_nodes': max(n_node_list),
        'max_n_edges': max(n_edge_list),
        'node_type_list': node_count_list,
        'edge_type_list': edge_count_list,
        'valencies': valencies,
        'active_nodes': active_nodes,
        'num_active_nodes': len(active_nodes),
        'transition_E': transition_E.tolist(),
    }

    with open(output_path, 'w') as f:
        json.dump(meta_dict, f)

    return meta_dict




def graphs_to_json(graphs, filename):
    """Compute molecule-style meta statistics for NAS-Bench-201 cells.

    Each operation is mapped to a stand-in element via ``op_to_atom``. The
    statistics use the molecular layout of 118 atom-type slots: slots 0..116
    hold atomic numbers 2..118 and slot 117 holds ``'*'``.

    Args:
        graphs: sized sequence of ``(adj, ops)`` pairs (accessed as
            ``graph[0]`` / ``graph[1]``). ``adj`` is a square adjacency matrix
            indexable as ``adj[i][j]``; ``ops`` lists operation names (keys of
            ``op_to_atom``).
        filename: output prefix; the result is written to
            ``f'{filename}.meta.json'``.

    Returns:
        The meta dictionary that was written.

    Raises:
        ValueError: if ``graphs`` is empty.
        AssertionError: if a graph records more bonds between two atom types
            than there are atom pairs, which happens when ``adj`` holds both
            ``i->j`` and ``j->i``.
    """
    # Bond label of an edge, chosen by the operation of the edge's target
    # node. Local name avoids shadowing the module-level ``bonds``.
    op_to_bond_index = {
        'nor_conv_1x1': 1,
        'nor_conv_3x3': 2,
        'avg_pool_3x3': 3,
        'skip_connect': 4,
        'input': 7,
        'output': 5,
        'none': 6
    }
    # Edges touching these nodes are not recorded as bonds.
    non_bonding_ops = ('input', 'output', 'none')

    num_graph = len(graphs)
    if num_graph == 0:
        raise ValueError('graphs_to_json needs at least one graph')
    pt = Chem.GetPeriodicTable()
    atom_name_list = [pt.GetElementSymbol(i) for i in range(2, 119)] + ['*']
    atom_count_list = [0] * len(atom_name_list)
    n_atoms_per_mol = [0] * 500
    bond_count_list = [0] * 8
    valencies = [0] * 500
    transition_E = np.zeros((118, 118, 8))

    n_atom_list = []
    n_bond_list = []
    for graph in graphs:
        ops = graph[1]
        adj = graph[0]
        n_atom = len(ops)
        n_atoms_per_mol[n_atom] += 1
        cur_atom_count_arr = np.zeros(118)

        for op in ops:
            symbol = op_to_atom[op]
            if symbol == 'H':
                continue
            if symbol == '*':
                atom_count_list[-1] += 1
                cur_atom_count_arr[-1] += 1
                continue
            slot = pt.GetAtomicNumber(symbol) - 2
            atom_count_list[slot] += 1
            cur_atom_count_arr[slot] += 1
            valence = int(pt.GetDefaultValence(symbol))
            try:
                valencies[valence] += 1
            except IndexError:
                print('int(pt.GetDefaultValence(symbol))', valence)

        transition_E_temp = np.zeros((118, 118, 8))
        n_bond = 0  # bonds actually recorded for this graph
        for i in range(n_atom):
            for j in range(n_atom):
                if i == j or adj[i][j] == 0:
                    continue
                if ops[i] in non_bonding_ops or ops[j] in non_bonding_ops:
                    continue

                start_index = pt.GetAtomicNumber(op_to_atom[ops[i]]) - 2
                end_index = pt.GetAtomicNumber(op_to_atom[ops[j]]) - 2
                bond_index = op_to_bond_index[ops[j]]
                bond_count_list[bond_index] += 2
                n_bond += 1

                transition_E[start_index, end_index, bond_index] += 2
                transition_E[end_index, start_index, bond_index] += 2
                transition_E_temp[start_index, end_index, bond_index] += 2
                transition_E_temp[end_index, start_index, bond_index] += 2

        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)

        # Ordered atom pairs not covered by a recorded bond count as "no bond";
        # each bond accounts for two ordered pairs, matching the += 2 above.
        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
        cur_tot_bond = cur_atom_count_arr.reshape(-1, 1) * cur_atom_count_arr.reshape(1, -1) * 2
        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2
        bonded = transition_E_temp.sum(axis=-1)
        assert np.all(cur_tot_bond >= bonded), 'more bonds recorded than atom pairs available'
        transition_E[:, :, 0] += cur_tot_bond - bonded

    n_atoms_per_mol = (np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)).tolist()[:51]

    atom_count_arr = np.array(atom_count_list) / np.sum(atom_count_list)
    print('processed meta info: ------', filename, '------')
    print('len atom_count_list', len(atom_count_arr))
    print('len atom_name_list', len(atom_name_list))
    active_atoms = np.array(atom_name_list)[atom_count_arr > 0].tolist()
    atom_count_list = atom_count_arr.tolist()

    bond_count_list = (np.array(bond_count_list) / np.sum(bond_count_list)).tolist()
    valencies = (np.array(valencies) / np.sum(valencies)).tolist()

    # Type pairs that never co-occur get a deterministic "no bond" transition.
    no_edge = np.sum(transition_E, axis=-1) == 0
    transition_E[no_edge, 0] = 1
    transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)

    meta_dict = {
        'source': 'nasbench-201',
        'num_graph': num_graph,
        'n_atoms_per_mol_dist': n_atoms_per_mol,
        'max_node': max(n_atom_list),
        'max_bond': max(n_bond_list),
        'atom_type_dist': atom_count_list,
        'bond_type_dist': bond_count_list,
        'valencies': valencies,
        'active_nodes': active_atoms,
        'num_atom_type': len(active_atoms),
        'transition_E': transition_E.tolist(),
    }

    with open(f'{filename}.meta.json', 'w') as f:
        json.dump(meta_dict, f)
    return meta_dict
class Dataset(InMemoryDataset):
    """NAS-Bench-201 architectures as a PyG ``InMemoryDataset``.

    ``process`` parses every architecture string from a pre-exported JSON
    dump into a 12-node graph (8 cell nodes from
    ``parse_architecture_string`` plus 4 ``'none'`` padding nodes) and labels
    it from a precomputed SWAP score read from a CSV file.

    NOTE(review): the ``source`` constructor argument is ignored and replaced
    by the hard-coded ``_API_PATH``; the processed file is therefore
    ``f'{_API_PATH}.pt'``.
    """

    # Hard-coded input locations, kept as class attributes so they can be
    # overridden (e.g. in a subclass) without editing ``process``.
    _API_PATH = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/NAS-Bench-201-v1_1-096897.pth'
    _GRAPH_JSON_PATH = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'
    # Alternative: '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results_cifar100.csv'
    _SWAP_CSV_PATH = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/swap_results.csv'
    # Architectures whose SWAP score is strictly greater than this value are
    # labelled y=[1, 1]; all others are labelled y=[0, 0].
    _SWAP_THRESHOLD = 60000

    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        # NOTE(review): caller-supplied ``source`` is deliberately overridden.
        self.source = self._API_PATH
        super().__init__(root, transform, pre_transform, pre_filter)
        print(self.processed_paths[0])
        self.data, self.slices = torch.load(self.processed_paths[0])
        print('Dataset initialized')
        # Drop any singleton dimensions from the stored edge labels.
        self.data.edge_attr = self.data.edge_attr.squeeze()
        # Re-number samples 0..N-1 over the collated labels.
        self.data.idx = torch.arange(len(self.data.y))
        print(f"self.data={self.data}, self.slices={self.slices}")

    @property
    def raw_file_names(self):
        # Graphs are read from a pre-exported JSON file in ``process``; PyG has no raw files to check.
        return []

    @property
    def processed_file_names(self):
        # ``self.source`` is an absolute path, so ``processed_paths[0]`` resolves to it + '.pt'.
        return [f'{self.source}.pt']

    def process(self):
        """Convert every NAS-Bench-201 architecture into a labelled ``Data`` object.

        Steps:
          1. Read SWAP scores (first CSV column, header row skipped) into
             ``self.swap_scores``; row ``k`` is used as the score of graph ``k``.
          2. Load the graph list from ``_GRAPH_JSON_PATH`` and parse each
             ``arch_str`` with 4 padding nodes.
          3. Build one ``Data`` per architecture (see ``graph_to_graph_data``).
          4. Re-dump the graph list to ``nasbench-201-graph.json`` in the
             current working directory and save the collated data to
             ``self.processed_paths[0]``.
        """
        import csv

        # Only sizes the progress bar (number of architectures in NAS-Bench-201).
        len_data = 15625

        def check_valid_graph(nodes, edges):
            """Return True if ``(nodes, edges)`` looks like a well-formed cell DAG.

            Checks: square adjacency matching ``nodes``; ``'input'`` first and
            ``'output'`` last; no self-loops; inner nodes are known non-I/O
            ops; no edge into ``'input'`` or out of ``'output'`` (upper
            triangle only); at least one edge into the last node.
            """
            if len(nodes) != edges.shape[0] or len(nodes) != edges.shape[1]:
                return False
            if nodes[0] != 'input' or nodes[-1] != 'output':
                return False
            for i in range(0, len(nodes)):
                if edges[i][i] == 1:
                    return False
            for i in range(1, len(nodes) - 1):
                if nodes[i] not in op_type or nodes[i] == 'input' or nodes[i] == 'output':
                    return False
            for i in range(0, len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[j] == 'input':
                        return False
            for i in range(0, len(nodes)):
                for j in range(i, len(nodes)):
                    if edges[i, j] == 1 and nodes[i] == 'output':
                        return False
            flag = 0
            for i in range(0, len(nodes)):
                if edges[i, -1] == 1:
                    flag = 1
                    break
            if flag == 0:
                return False
            return True

        def generate_flex_adj_mat(ori_nodes, ori_edges, max_nodes=12, min_nodes=8, random_ratio=0.5):
            """Randomly extend an 8-node NAS-Bench-201 graph with extra op nodes.

            Keeps the original op nodes and their edges, inserts
            ``randint(min_nodes, max_nodes) - 8`` random ops before
            ``'output'``, adds random forward edges (probability
            ``random_ratio``) into the added nodes and output, then pads with
            ``'none'`` up to ``max_nodes``. Returns ``(edges, nodes)``.
            """
            ori_edges = np.array(ori_edges)
            nasbench_201_node_num = 8
            nodes_num = random.randint(min_nodes, max_nodes)
            add_num = nodes_num - nasbench_201_node_num
            add_nodes = []
            print(f'add_num: {add_num}')
            for _ in range(add_num):
                add_nodes.append(random.choice(num_to_op[1:-1]))
            print(f'ori_nodes[:-1]: {ori_nodes[:-1]}, add_nodes: {add_nodes}')
            print(f'len(ori_nodes[:-1]): {len(ori_nodes[:-1])}, len(add_nodes): {len(add_nodes)}')
            nodes = ori_nodes[:-1] + add_nodes + ['output']
            edges = np.zeros((nodes_num, nodes_num))
            # Copy all edges among the 7 original non-output nodes (0..6).
            # Copying only [:6, :6] would drop the edges 2->6 and 3->6, and
            # the random loop below never touches column 6.
            edges[:7, :7] = ori_edges[:7, :7]
            # Original edges into 'output' now point at the new last node.
            edges[0:8, -1] = ori_edges[0:8, -1]
            for i in range(0, nodes_num):
                for j in range(max(7, i + 1), nodes_num):
                    if random.random() < random_ratio:
                        edges[i, j] = 1
            if nodes_num < max_nodes:
                edges = np.pad(edges, ((0, max_nodes - nodes_num), (0, max_nodes - nodes_num)), 'constant', constant_values=0)
                while len(nodes) < max_nodes:
                    nodes.append('none')
            print(f'edges size: {edges.shape}, nodes size: {len(nodes)}')
            return edges, nodes

        def graph_to_graph_data(graph, idx, train_loader, searchspace, args, device):
            """Convert ``(adj, ops)`` for architecture ``idx`` into a ``Data`` object.

            Every edge ``adj[s][e] == 1`` is stored in both directions with
            edge type 1. The label is ``[1, 1]`` when
            ``self.swap_scores[idx] > _SWAP_THRESHOLD`` and ``[0, 0]``
            otherwise.

            ``train_loader``, ``searchspace``, ``args`` and ``device`` are
            currently unused. They were needed for on-the-fly scoring via
            ``get_nasbench201_idx_score``, which is disabled in favour of the
            precomputed CSV.
            """
            adj, ops = graph[0], graph[1]
            x = torch.LongTensor([op_type[op] for op in ops])

            edges_list = []
            edge_type = []
            for start in range(len(ops)):
                for end in range(len(ops)):
                    if adj[start][end] == 1:
                        edges_list.append((start, end))
                        edge_type.append(1)
                        edges_list.append((end, start))
                        edge_type.append(1)

            edge_index = torch.tensor(edges_list, dtype=torch.long).t()
            edge_attr = torch.tensor(edge_type, dtype=torch.long)

            score = self.swap_scores[idx]
            print(score, idx)
            print(f'idx={idx}, y={score}')
            label = 1 if score > self._SWAP_THRESHOLD else 0
            y = torch.tensor([label, label], dtype=torch.float).view(1, -1)
            # Use the explicit ``idx`` argument rather than relying on the
            # enclosing loop counter captured by closure.
            return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)

        class Args:
            pass
        args = Args()
        args.trainval = True
        args.augtype = 'none'
        args.repeat = 1
        args.score = 'hook_logdet'
        args.sigma = 0.05
        args.nasspace = 'nasbench201'
        args.batch_size = 128
        args.GPU = '0'
        args.dataset = 'cifar10'
        args.api_loc = self._API_PATH
        args.data_loc = '../cifardata/'
        args.seed = 777
        args.init = ''
        args.save_loc = 'results'
        args.save_string = 'naswot'
        args.dropout = False
        args.maxofn = 1
        args.n_samples = 100
        args.n_runs = 500
        args.stem_out_channels = 16
        args.num_stacks = 3
        args.num_modules_per_stack = 3
        args.num_labels = 1
        # NOTE(review): searchspace/train_loader are only consumed by the
        # disabled on-the-fly scoring path, yet building them loads CIFAR
        # data; consider skipping them while labels come from the CSV.
        searchspace = nasspace.get_search_space(args)
        train_loader = dt.get_data(args.dataset, args.data_loc, args.trainval, args.batch_size, args.augtype, args.repeat, args)

        with open(self._SWAP_CSV_PATH, 'r') as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            self.swap_scores = [float(row[0]) for row in reader]
        device = torch.device('cuda:2')

        with open(self._GRAPH_JSON_PATH, 'r') as f:
            graph_list = json.load(f)

        data_list = []
        with tqdm(total=len_data) as pbar:
            for i, graph in enumerate(graph_list):
                print(f'iterate every graph in graph_list, here is {i}')
                ops, adj_matrix, ori_nodes, ori_adj = parse_architecture_string(graph['arch_str'], padding=4)
                data = graph_to_graph_data((adj_matrix, ops), idx=i, train_loader=train_loader,
                                           searchspace=searchspace, args=args, device=device)
                # Debug dumps for the first two graphs, written as 1.json and 2.json.
                dump_id = i + 1
                if dump_id < 3:
                    print(f"i={dump_id}, data={data}")
                    with open(f'{dump_id}.json', 'w') as f:
                        f.write(str(data.x))
                        f.write(str(data.edge_index))
                        f.write(str(data.edge_attr))
                data_list.append(data)
                pbar.update(1)

        # Ensure the re-dumped graph list is JSON-serialisable.
        for graph in graph_list:
            for key in ('adj_matrix', 'ops'):
                if isinstance(graph[key], np.ndarray):
                    graph[key] = graph[key].tolist()
        with open('nasbench-201-graph.json', 'w') as f:
            json.dump(graph_list, f)

        torch.save(self.collate(data_list), self.processed_paths[0])

class Dataset_origin(InMemoryDataset):
    """Molecule dataset built from ``<raw_dir>/<source>.csv.gz`` (original Graph-DiT loader).

    Each row's SMILES becomes a heavy-atom graph. ``y`` is
    ``[SA, SC, target1(, target2(, target3))]``, where the target column
    names come from ``target_prop`` split on ``'-'``.
    """

    def __init__(self, source, root, target_prop=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        self.source = source
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f'{self.source}.csv.gz']

    @property
    def processed_file_names(self):
        return [f'{self.source}.pt']

    def process(self):
        """Parse every CSV row into a ``Data`` object and save the collated result.

        Raises:
            ValueError: if RDKit cannot parse a row's SMILES string.
        """
        RDLogger.DisableLog('rdApp.*')
        data_path = osp.join(self.raw_dir, self.raw_file_names[0])
        data_df = pd.read_csv(data_path)

        def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None, idx=None):
            """Convert an RDKit ``mol`` into ``(Data, active_atom_symbols)``.

            Atoms with atomic number 1 are dropped. Node labels are
            ``atomic_num - 2``, and the dummy atom ``'*'`` maps to ``117``.
            Each bond between kept atoms is stored in both directions.
            Returns ``(None, None)`` when ``valid_atoms`` is given and an
            atom's symbol is not in it.
            """
            type_idx = []
            active_atoms = []
            # Original RDKit atom index -> position among kept (heavy) atoms.
            heavy_pos = {}
            for atom in mol.GetAtoms():
                if atom.GetAtomicNum() == 1:
                    continue
                if atom.GetSymbol() == '*':
                    type_idx.append(119 - 2)
                else:
                    type_idx.append(atom.GetAtomicNum() - 2)
                heavy_pos[atom.GetIdx()] = len(heavy_pos)
                active_atoms.append(atom.GetSymbol())
                if valid_atoms is not None and atom.GetSymbol() not in valid_atoms:
                    return None, None
            x = torch.LongTensor(type_idx)

            edges_list = []
            edge_type = []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                if start in heavy_pos and end in heavy_pos:
                    start_new, end_new = heavy_pos[start], heavy_pos[end]
                    bond_label = bonds[bond.GetBondType()]
                    edges_list.append((start_new, end_new))
                    edge_type.append(bond_label)
                    edges_list.append((end_new, start_new))
                    edge_type.append(bond_label)
            edge_index = torch.tensor(edges_list, dtype=torch.long).t()
            edge_attr = torch.tensor(edge_type, dtype=torch.long)

            if target3 is not None:
                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
            elif target2 is not None:
                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
            else:
                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)

            # ``idx`` is passed explicitly instead of reading the caller's
            # loop variable through the closure.
            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=idx)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            return data, active_atoms

        data_list = []
        len_data = len(data_df)
        with tqdm(total=len_data) as pbar:
            active_atoms = set()
            for i, (sms, df_row) in enumerate(data_df.iterrows()):
                # If the frame index is just the row position, read SMILES
                # from the 'smiles' column; otherwise the index value itself
                # is used as the SMILES string.
                if i == sms:
                    sms = df_row['smiles']
                mol = Chem.MolFromSmiles(sms, sanitize=False)
                if mol is None:
                    # MolFromSmiles returns None on parse failure; fail loudly
                    # instead of with an AttributeError inside mol_to_graph.
                    raise ValueError(f'row {i}: RDKit could not parse SMILES {sms!r}')
                target_names = self.target_prop.split('-')
                if len(target_names) == 2:
                    target1, target2 = target_names
                    data, cur_active_atoms = mol_to_graph(
                        mol, df_row['SA'], df_row['SC'], df_row[target1],
                        target2=df_row[target2], idx=i)
                elif len(target_names) == 3:
                    target1, target2, target3 = target_names
                    data, cur_active_atoms = mol_to_graph(
                        mol, df_row['SA'], df_row['SC'], df_row[target1],
                        target2=df_row[target2], target3=df_row[target3], idx=i)
                else:
                    data, cur_active_atoms = mol_to_graph(
                        mol, df_row['SA'], df_row['SC'], df_row[self.target_prop], idx=i)
                active_atoms.update(cur_active_atoms)
                data_list.append(data)
                pbar.update(1)

        torch.save(self.collate(data_list), self.processed_paths[0])

def parse_architecture_string(arch_str, padding=0):
    """Parse a NAS-Bench-201 ``arch_str`` into an operation-as-node graph.

    The string has three ``'+'``-separated stages holding 1, 2 and 3
    ``op~src`` tokens, for example::

        |nor_conv_3x3~0|+|nor_conv_1x1~0|avg_pool_3x3~1|+|skip_connect~0|none~1|nor_conv_3x3~2|

    The six ops become nodes 1..6 between ``'input'`` (node 0) and
    ``'output'`` (node 7). They are wired by the fixed 8x8 adjacency below,
    where ``adj[i][j] == 1`` is an edge from node ``i`` to node ``j``.

    Args:
        arch_str: NAS-Bench-201 architecture string.
        padding: number of ``'none'`` nodes with no edges to append to the
            returned padded graph (values <= 0 mean no padding).

    Returns:
        ``(nodes, adj_mat, ori_nodes, ori_adj_mat)``: the padded node list and
        adjacency matrix (nested lists of size ``8 + padding``), followed by
        independent unpadded 8-node copies.

    Raises:
        ValueError: if a token is not of the form ``op~src``, or the string
            does not contain exactly six ops with source indices
            ``0, 0, 1, 0, 1, 2`` in that order.
    """
    expected_sources = ['0', '0', '1', '0', '1', '2']
    base_adj = (
        (0, 1, 1, 0, 1, 0, 0, 0),
        (0, 0, 0, 1, 0, 1, 0, 0),
        (0, 0, 0, 0, 0, 0, 1, 0),
        (0, 0, 0, 0, 0, 0, 1, 0),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (0, 0, 0, 0, 0, 0, 0, 0),
    )

    ops = []
    sources = []
    for stage in arch_str.split('+'):
        for token in stage.strip('|').split('|'):
            op, src = token.split('~')
            ops.append(op)
            sources.append(src)
    # Explicit check (not ``assert``) so it also runs under ``python -O``;
    # it also rejects strings with fewer ops, which would otherwise yield a
    # node list shorter than the fixed 8x8 adjacency.
    if sources != expected_sources:
        raise ValueError(
            f'unexpected NAS-Bench-201 arch string {arch_str!r}: '
            f'source indices {sources}, expected {expected_sources}')

    nodes = ['input', *ops, 'output']
    ori_nodes = nodes.copy()
    ori_adj_mat = [list(row) for row in base_adj]

    pad = max(padding, 0)
    nodes.extend(['none'] * pad)
    adj_mat = [list(row) + [0] * pad for row in base_adj]
    adj_mat.extend([0] * len(nodes) for _ in range(pad))
    return nodes, adj_mat, ori_nodes, ori_adj_mat

def create_adj_matrix_and_ops(nodes, edges):
    """Build a dense 0/1 integer adjacency matrix for ``edges``.

    Args:
        nodes: node sequence; only its length is used to size the matrix.
        edges: iterable of ``(src, dst)`` index pairs.

    Returns:
        ``(adjacency, nodes)``, where ``adjacency[src, dst] == 1`` for every
        pair and ``nodes`` is the very object that was passed in.
    """
    size = len(nodes)
    adjacency = np.zeros((size, size), dtype=int)
    # Unpack each pair eagerly so malformed pairs fail the same way as a
    # per-pair loop would.
    pairs = [(src, dst) for src, dst in edges]
    if pairs:
        rows, cols = zip(*pairs)
        adjacency[list(rows), list(cols)] = 1
    return adjacency, nodes
class DataInfos(AbstractDatasetInfos):
    """Dataset-level statistics for the NAS-Bench-201 graph dataset.

    Loads ``(adjacency ndarray, op-type id list)`` pairs from a graph JSON
    dump. It passes them to ``new_graphs_to_json`` and copies the returned
    ``active_nodes``, ``max_n_nodes``, ``n_nodes_per_graph``,
    ``edge_type_list``, ``transition_E``, ``node_type_list`` and
    ``valencies`` entries onto the instance.
    """

    # Default graph dump; can be overridden with ``cfg.dataset.graph_json_path``.
    _DEFAULT_GRAPH_JSON = '/nfs/data3/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json'

    def __init__(self, datamodule, cfg, dataset):
        tasktype_dict = {
            'hiv_b': 'classification',
            'bace_b': 'classification',
            'bbbp_b': 'classification',
            'O2': 'regression',
            'N2': 'regression',
            'CO2': 'regression',
        }
        task_name = cfg.dataset.task_name
        self.task = task_name
        self.task_type = tasktype_dict.get(task_name, "regression")
        self.ensure_connected = cfg.model.ensure_connected

        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        graph_json = getattr(cfg.dataset, 'graph_json_path', None) or self._DEFAULT_GRAPH_JSON

        def read_adj_ops_from_json(filename):
            """Return ``[(np.array(adj_matrix), [op_type[op] ...])]`` for every JSON item."""
            with open(filename, 'r') as json_file:
                data = json.load(json_file)
            return [
                (np.array(item['adj_matrix']), [op_type[op] for op in item['ops']])
                for item in data
            ]

        graphs = read_adj_ops_from_json(graph_json)

        # Preview the first graphs; slicing avoids IndexError on short files.
        for i, graph in enumerate(graphs[:5]):
            print(f'graph {i} : {graph}')

        meta_dict = new_graphs_to_json(graphs, 'nasbench-201')
        self.base_path = base_path
        self.active_nodes = meta_dict['active_nodes']
        self.max_n_nodes = meta_dict['max_n_nodes']
        self.original_max_n_nodes = meta_dict['max_n_nodes']
        self.n_nodes = torch.Tensor(meta_dict['n_nodes_per_graph'])
        self.edge_types = torch.Tensor(meta_dict['edge_type_list'])
        self.transition_E = torch.Tensor(meta_dict['transition_E'])

        self.node_decoder = meta_dict['active_nodes']
        node_types = torch.Tensor(meta_dict['node_type_list'])
        # Keep only node types that actually occur (non-zero frequency).
        active_index = (node_types > 0).nonzero().squeeze()
        self.node_types = node_types[active_index]
        self.nodes_dist = DistributionNodes(self.n_nodes)
        self.active_index = active_index

        # Fixed-length valency vector of size 3 * max_n_nodes - 2, filled
        # from the start of meta_dict['valencies'] and zero-padded.
        val_len = 3 * self.original_max_n_nodes - 2
        meta_val = torch.Tensor(meta_dict['valencies'])
        self.valency_distribution = torch.zeros(val_len)
        val_len = min(val_len, len(meta_val))
        self.valency_distribution[:val_len] = meta_val[:val_len]
        self.y_prior = None
        self.train_ymin = []
        self.train_ymax = []



class DataInfos_origin(AbstractDatasetInfos):
    """Dataset-level statistics for the original molecule datasets.

    Statistics come from ``<base>/<datadir>/raw/<task_name>.meta.json`` when
    that file exists. Otherwise they are computed with ``compute_meta`` from
    ``datamodule.train_index`` / ``datamodule.test_index``.
    """

    def __init__(self, datamodule, cfg):
        # Known classification tasks; every other task name is regression.
        task_kinds = {name: 'classification' for name in ('hiv_b', 'bace_b', 'bbbp_b')}
        task_name = cfg.dataset.task_name
        self.task = task_name
        self.task_type = task_kinds.get(task_name, "regression")
        self.ensure_connected = cfg.model.ensure_connected

        datadir = cfg.dataset.datadir
        base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        data_root = os.path.join(base_path, datadir, 'raw')
        meta_filename = os.path.join(data_root, f'{task_name}.meta.json')

        if not os.path.exists(meta_filename):
            meta_dict = compute_meta(data_root, task_name, datamodule.train_index, datamodule.test_index)
        else:
            with open(meta_filename, 'r') as f:
                meta_dict = json.load(f)

        self.base_path = base_path
        atoms = meta_dict['active_atoms']
        self.active_atoms = atoms
        largest = meta_dict['max_node']
        self.max_n_nodes = largest
        self.original_max_n_nodes = largest
        self.n_nodes = torch.Tensor(meta_dict['n_atoms_per_mol_dist'])
        self.edge_types = torch.Tensor(meta_dict['bond_type_dist'])
        self.transition_E = torch.Tensor(meta_dict['transition_E'])

        self.atom_decoder = atoms
        type_dist = torch.Tensor(meta_dict['atom_type_dist'])
        # Positions of atom types with non-zero frequency.
        present = (type_dist > 0).nonzero().squeeze()
        self.node_types = type_dist[present]
        self.nodes_dist = DistributionNodes(self.n_nodes)
        self.active_index = present

        # Zero-padded valency vector of length 3 * max_node - 2, filled from
        # the start of meta_dict['valencies'].
        capacity = 3 * largest - 2
        valency_src = torch.Tensor(meta_dict['valencies'])
        self.valency_distribution = torch.zeros(capacity)
        n_copy = min(capacity, len(valency_src))
        self.valency_distribution[:n_copy] = valency_src[:n_copy]
        self.y_prior = None
        self.train_ymin = []
        self.train_ymax = []


def _atom_type_index(atom):
    """Map an RDKit atom to its slot in the 118-entry atom-type table.

    Slots 0..116 hold atomic numbers 2..118 (He..Og) at ``z - 2``; slot 117
    holds the '*' dummy atom. Callers skip hydrogen before calling this.
    """
    if atom.GetSymbol() == '*':
        return 117
    return atom.GetAtomicNum() - 2


def _normalize(counts):
    """Return ``counts`` as a float NumPy array divided by its total."""
    counts = np.asarray(counts, dtype=float)
    return counts / np.sum(counts)


def compute_meta(root, source_name, train_index, test_index, meta_path=None):
    """Compute graph statistics for a SMILES dataset and cache them as JSON.

    Reads ``{root}/{source_name}.csv.gz``, drops the rows whose positional
    index is in ``test_index``, parses the ``smiles`` column with RDKit and
    accumulates, over all parsable molecules:

    * ``n_atoms_per_mol_dist``: normalised histogram of ``GetNumHeavyAtoms()``,
      truncated to sizes 0..50;
    * ``atom_type_dist``: normalised counts over 118 slots (He..Og, then '*');
      hydrogen is ignored;
    * ``bond_type_dist``: 5 entries, index 0 = "no bond", 1-4 = single,
      double, triple, aromatic (ordered-pair counts);
    * ``valencies``: normalised histogram of explicit valence of non-'*' atoms;
    * ``transition_E``: 118 x 118 x 5 matrix of bond-type probabilities per
      ordered atom-type pair, normalised over the last axis;
    * ``max_node``, ``max_bond``, ``num_graph``, ``active_atoms``,
      ``num_atom_type``.

    Args:
        root: directory containing ``{source_name}.csv.gz``.
        source_name: dataset name; used for the CSV and default JSON filenames.
        train_index: unused; kept for call-site compatibility.
        test_index: iterable of positional row indices to exclude.
        meta_path: where to write the JSON. Defaults to
            ``{root}/{source_name}.meta.json``, which is the path the caller in
            this module checks before calling this function.

    Returns:
        The computed meta dict (also written to ``meta_path``).

    Raises:
        ValueError: if no SMILES outside ``test_index`` could be parsed.
        KeyError: on a bond type other than single/double/triple/aromatic.
    """
    # Atom-type slots: 117 elements (atomic numbers 2..118) + '*' = 118.
    pt = Chem.GetPeriodicTable()
    atom_name_list = [pt.GetElementSymbol(z) for z in range(2, 119)] + ['*']
    n_types = len(atom_name_list)
    atom_count_list = [0] * n_types
    n_atoms_per_mol = [0] * 500
    bond_count_list = [0] * 5  # slot 0 = no bond
    bond_type_to_index = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
    valencies = [0] * 500
    transition_E = np.zeros((n_types, n_types, 5))

    # Load the non-test split of the source file.
    filename = f'{source_name}.csv.gz'
    df = pd.read_csv(f'{root}/{filename}')
    non_test_index = list(set(range(len(df))) - set(test_index))
    df = df.iloc[non_test_index]
    tot_smiles = df['smiles'].tolist()

    n_atom_list = []
    n_bond_list = []
    for i, sms in enumerate(tot_smiles):
        try:
            mol = Chem.MolFromSmiles(sms)
        except Exception:
            # Non-string input (e.g. a missing/NaN cell) raises instead of parsing.
            continue
        if mol is None:
            # MolFromSmiles signals an unparsable SMILES by returning None.
            continue

        # NOTE(review): GetNumHeavyAtoms counts atoms with atomic number > 1,
        # so '*' dummy atoms are excluded from n_atom (but their bonds are in
        # n_bond). Confirm this is intended for '*'-containing datasets.
        n_atom = mol.GetNumHeavyAtoms()
        n_bond = mol.GetNumBonds()
        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)
        n_atoms_per_mol[n_atom] += 1

        cur_atom_count_arr = np.zeros(n_types)
        for atom in mol.GetAtoms():
            if atom.GetSymbol() == 'H':
                continue
            idx = _atom_type_index(atom)
            atom_count_list[idx] += 1
            cur_atom_count_arr[idx] += 1
            if atom.GetSymbol() != '*':
                valence = int(atom.GetExplicitValence())
                try:
                    valencies[valence] += 1
                except IndexError:
                    print('src', source_name,'int(atom.GetExplicitValence())', valence)

        # Per-molecule bond counts, used to derive the "no bond" pair counts.
        mol_pairs = np.zeros((n_types, n_types, 5))
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtom(), bond.GetEndAtom()
            if begin.GetSymbol() == 'H' or end.GetSymbol() == 'H':
                continue
            a, b = _atom_type_index(begin), _atom_type_index(end)
            bond_index = bond_type_to_index[bond.GetBondType()]
            bond_count_list[bond_index] += 2
            # Count both ordered pairs so the matrix stays symmetric
            # (a bond between two atoms of the same type adds 4 to its cell).
            for x, y in ((a, b), (b, a)):
                transition_E[x, y, bond_index] += 2
                mol_pairs[x, y, bond_index] += 2

        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
        # Every ordered pair of distinct non-H atoms is weighted 2 (matching the
        # += 2 per bond direction above): 2*c_i*c_j off the diagonal and
        # 2*c_i*(c_i - 1) on it.
        all_pairs = 2 * np.outer(cur_atom_count_arr, cur_atom_count_arr) - 2 * np.diag(cur_atom_count_arr)
        bonded_pairs = mol_pairs.sum(axis=-1)
        # Bonded pairs can never outnumber all pairs of the same type combination.
        assert (all_pairs >= bonded_pairs).all(), f'i:{i}, sms:{sms}'
        transition_E[:, :, 0] += all_pairs - bonded_pairs

    if not n_atom_list:
        raise ValueError(f'no parsable SMILES found in {root}/{filename} outside test_index')

    # NOTE: only sizes 0..50 are kept; probability mass of larger molecules is dropped.
    n_atoms_per_mol = _normalize(n_atoms_per_mol).tolist()[:51]

    atom_count_arr = _normalize(atom_count_list)
    print('processed meta info: ------', filename, '------')
    print('len atom_count_list', len(atom_count_arr))
    print('len atom_name_list', len(atom_name_list))
    active_atoms = np.array(atom_name_list)[atom_count_arr > 0].tolist()
    atom_count_list = atom_count_arr.tolist()

    bond_count_list = _normalize(bond_count_list).tolist()
    valencies = _normalize(valencies).tolist()

    # Type pairs never observed get "no bond" probability 1 instead of 0/0.
    no_edge = transition_E.sum(axis=-1) == 0
    transition_E[no_edge, 0] = 1
    transition_E = transition_E / transition_E.sum(axis=-1, keepdims=True)

    meta_dict = {
        'source': source_name,
        'num_graph': len(n_atom_list),
        'n_atoms_per_mol_dist': n_atoms_per_mol,
        'max_node': max(n_atom_list),
        'max_bond': max(n_bond_list),
        'atom_type_dist': atom_count_list,
        'bond_type_dist': bond_count_list,
        'valencies': valencies,
        'active_atoms': active_atoms,
        'num_atom_type': len(active_atoms),
        'transition_E': transition_E.tolist(),
        }

    # Write next to the raw data (previously a hard-coded absolute path that
    # the caller never read back, so the cache was never reused).
    if meta_path is None:
        meta_path = os.path.join(root, f'{source_name}.meta.json')
    with open(meta_path, "w") as f:
        json.dump(meta_dict, f)

    return meta_dict


if __name__ == "__main__":
    # Script entry: build the NAS-Bench dataset from a hard-coded root,
    # targeting the 'Class' property with no transform.
    dataset = Dataset(
        source='nasbench',
        root='/nfs/data3/hanzhang/nasbenchDiT/graph-dit',
        target_prop='Class',
        transform=None,
    )