In [1]:
from nas_201_api import NASBench201API as API

In [2]:
api = API('./NAS-Bench-201-v1_1-096897.pth', verbose=False)

In [3]:
num = len(api)
for i, arch_str in enumerate(api):
    print('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))

    0/15625 : |avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|
    1/15625 : |nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
    2/15625 : |avg_pool_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|
    3/15625 : |avg_pool_3x3~0|+|skip_connect~0|none~1|+|none~0|none~1|skip_connect~2|
    4/15625 : |skip_connect~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|skip_connect~1|nor_conv_1x1~2|
    5/15625 : |nor_conv_1x1~0|+|skip_connect~0|nor_conv_1x1~1|+|nor_conv_3x3~0|none~1|avg_pool_3x3~2|
    6/15625 : |nor_conv_3x3~0|+|none~0|none~1|+|avg_pool_3x3~0|nor_conv_3x3~1|nor_conv_3x3~2|
    7/15625 : |none~0|+|skip_connect~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|
    8/15625 : |nor_conv_1x1~0|+|avg_pool_3x3~0|skip_connect~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
    9/15625 : |avg_pool_3x3~0|+|none~0|none~1|+|skip_connect~0|none~1|nor_

In [4]:
api.show(1)
api.show(2)

info = api.query_meta_info_by_index(1)
res_metrics = info.get_metrics('cifar10', 'train')
cost_metrics = info.get_compute_costs('cifar100') 

>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> 012 epochs >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
datasets : ['cifar10-valid', 'cifar10', 'cifar100', 'ImageNet16-120'], extra-info : None
cifar10-valid  FLOP=113.95 M, Params=0.802 MB, latency=16.85 ms.
cifar10-valid  train : [loss = 0.382, top1 = 86.97%], valid : [loss = 0.514, top1 = 82.83%]
cifar10        FLOP=113.95 M, Params=0.802 MB, latency=16.85 ms.
cifar10        train : [loss = 0.243, top1 = 91.69%], test  : [loss = 0.362, top1 = 88.22%]
cifar100       FLOP=113.96 M, Params=0.808 MB, latency=15.36 ms.
cifar100       train : [loss = 1.271, top1 = 63.76%], valid : [loss = 1.495, top1 = 57.80%], test : [loss = 1.478, top1 = 58.26%]
ImageNet16-120 FLOP= 28.50 M, Params=0.810 MB, latency=13.77 ms.
ImageNet16-120 train : [loss = 2.548, top1 = 35.41%], valid : [loss = 2.580, top1 = 35.43%], test : [loss = 2.611, top1 = 33.80%]
>>>>>>>>>>>>>>>>>>

In [5]:
results = api.query_by_index(1, 'cifar100')
print('There are {:} trials for this architecture [{:}] on CIFAR-100'.format(len(results), api[1]))

for seed, result in results.items():
    print('Latency : {:}'.format(result.get_latency()))
    print('Train Info : {:}'.format(result.get_train()))
    print('Valid Info : {:}'.format(result.get_eval('x-valid')))
    print('Test  Info : {:}'.format(result.get_eval('x-test')))
    print('')
    print('Train Info [10-th epoch]: {:}'.format(result.get_train(10)))

There are 1 trials for this architecture [|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|] on CIFAR-100
Latency : 0.015363384176183631
Train Info : {'iepoch': 11, 'loss': 1.2711473697662354, 'accuracy': 63.756, 'cur_time': 21.69641375541687, 'all_time': 260.35696506500244}
Valid Info : {'iepoch': 11, 'loss': 1.495258326148987, 'accuracy': 57.79999996948242, 'cur_time': 0.7508397953850883, 'all_time': 9.01007754462106}
Test  Info : {'iepoch': 11, 'loss': 1.477725588607788, 'accuracy': 58.25999995727539, 'cur_time': 0.7508397953850883, 'all_time': 9.01007754462106}

Train Info [10-th epoch]: {'iepoch': 10, 'loss': 1.3365458668136596, 'accuracy': 61.568, 'cur_time': 21.69641375541687, 'all_time': 238.66055130958557}


In [6]:
import sys
sys.path.append('../') 

import os
import os.path as osp
import pathlib
import json

import torch
import torch.nn.functional as F
from rdkit import Chem, RDLogger
from rdkit.Chem.rdchem import BondType as BT
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split


In [7]:
def compute_meta_graph(root, source_name, train_index, test_index):
    # initialize the periodic table
    # 118 elements + 1 for *
    # Initializes arrays to count the number of atoms per molecule, bond types, valencies, and transition probabilities between atom types.
    pt = Chem.GetPeriodicTable()
    atom_name_list = []
    atom_count_list = []
    for i in range(2, 119):
        atom_name_list.append(pt.GetElementSymbol(i))
        atom_count_list.append(0)
    atom_name_list.append('*')
    atom_count_list.append(0)
    n_atoms_per_mol = [0] * 500
    bond_count_list = [0, 0, 0, 0, 0]
    bond_type_to_index =  {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
    valencies = [0] * 500
    tansition_E = np.zeros((118, 118, 5))
    
    # Load the data from the source file
    filename = f'{source_name}.csv.gz'
    df = pd.read_csv(f'{root}/{filename}')
    all_index = list(range(len(df)))
    print(all_index)
    print(test_index)
    non_test_index = list(set(all_index) - set(test_index))
    df = df.iloc[non_test_index]
    print(df.head())
    print(df['smiles'].tolist()[:5])
    # extract the smiles from the dataframe
    tot_smiles = df['smiles'].tolist()

    n_atom_list = []
    n_bond_list = []
    for i, sms in enumerate(tot_smiles):
        try:
            mol = Chem.MolFromSmiles(sms)
        except:
            continue

        n_atom = mol.GetNumHeavyAtoms()
        n_bond = mol.GetNumBonds()
        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)

        n_atoms_per_mol[n_atom] += 1
        cur_atom_count_arr = np.zeros(118)
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            if symbol == 'H':
                continue
            elif symbol == '*':
                atom_count_list[-1] += 1
                cur_atom_count_arr[-1] += 1
            else:
                atom_count_list[atom.GetAtomicNum()-2] += 1
                cur_atom_count_arr[atom.GetAtomicNum()-2] += 1
                try:
                    valencies[int(atom.GetExplicitValence())] += 1
                except:
                    print('src', source_name,'int(atom.GetExplicitValence())', int(atom.GetExplicitValence()))
        
        tansition_E_temp = np.zeros((118, 118, 5))
        for bond in mol.GetBonds():
            start_atom, end_atom = bond.GetBeginAtom(), bond.GetEndAtom()
            if start_atom.GetSymbol() == 'H' or end_atom.GetSymbol() == 'H':
                continue
            
            if start_atom.GetSymbol() == '*':
                start_index = 117
            else:
                start_index = start_atom.GetAtomicNum() - 2
            if end_atom.GetSymbol() == '*':
                end_index = 117
            else:
                end_index = end_atom.GetAtomicNum() - 2

            bond_type = bond.GetBondType()
            bond_index = bond_type_to_index[bond_type]
            bond_count_list[bond_index] += 2

            # Update the transition matrix
            # The transition matrix is symmetric, so we update both directions
            # We also update the temporary transition matrix to check for errors
            # in the atom count
            
            tansition_E[start_index, end_index, bond_index] += 2
            tansition_E[end_index, start_index, bond_index] += 2
            tansition_E_temp[start_index, end_index, bond_index] += 2
            tansition_E_temp[end_index, start_index, bond_index] += 2

        bond_count_list[0] += n_atom * (n_atom - 1) - n_bond * 2
        cur_tot_bond = cur_atom_count_arr.reshape(-1,1) * cur_atom_count_arr.reshape(1,-1) * 2 # 118 * 118
        cur_tot_bond = cur_tot_bond - np.diag(cur_atom_count_arr) * 2 # 118 * 118
        tansition_E[:, :, 0] += cur_tot_bond - tansition_E_temp.sum(axis=-1)
        assert (cur_tot_bond > tansition_E_temp.sum(axis=-1)).sum() >= 0, f'i:{i}, sms:{sms}'
    
    n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)
    n_atoms_per_mol = n_atoms_per_mol.tolist()[:51]

    atom_count_list = np.array(atom_count_list) / np.sum(atom_count_list)
    print('processed meta info: ------', filename, '------')
    print('len atom_count_list', len(atom_count_list))
    print('len atom_name_list', len(atom_name_list))
    active_atoms = np.array(atom_name_list)[atom_count_list > 0]
    active_atoms = active_atoms.tolist()
    atom_count_list = atom_count_list.tolist()

    bond_count_list = np.array(bond_count_list) / np.sum(bond_count_list)
    bond_count_list = bond_count_list.tolist()
    valencies = np.array(valencies) / np.sum(valencies)
    valencies = valencies.tolist()

    no_edge = np.sum(tansition_E, axis=-1) == 0
    first_elt = tansition_E[:, :, 0]
    first_elt[no_edge] = 1
    tansition_E[:, :, 0] = first_elt

    tansition_E = tansition_E / np.sum(tansition_E, axis=-1, keepdims=True)
    
    meta_dict = {
        'source': source_name, 
        'num_graph': len(n_atom_list), 
        'n_atoms_per_mol_dist': n_atoms_per_mol,
        'max_node': max(n_atom_list), 
        'max_bond': max(n_bond_list), 
        'atom_type_dist': atom_count_list,
        'bond_type_dist': bond_count_list,
        'valencies': valencies,
        'active_atoms': active_atoms,
        'num_atom_type': len(active_atoms),
        'transition_E': tansition_E.tolist(),
        }

    with open(f'{root}/{source_name}.meta.json', "w") as f:
        json.dump(meta_dict, f)
    
    return meta_dict

In [8]:
bonds = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
class Dataset(InMemoryDataset):
    def __init__(self, source, root, target_prop=None,
                 transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        self.source = source
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return [f'{self.source}.csv.gz']
    
    @property
    def processed_file_names(self):
        return [f'{self.source}.pt']

    def process(self):
        RDLogger.DisableLog('rdApp.*')
        data_path = osp.join(self.raw_dir, self.raw_file_names[0])
        data_df = pd.read_csv(data_path)
       
        def mol_to_graph(mol, sa, sc, target, target2=None, target3=None, valid_atoms=None):
            type_idx = []
            heavy_atom_indices, active_atoms = [], []
            for atom in mol.GetAtoms():
                if atom.GetAtomicNum() != 1:
                    type_idx.append(119-2) if atom.GetSymbol() == '*' else type_idx.append(atom.GetAtomicNum()-2)
                    heavy_atom_indices.append(atom.GetIdx())
                    active_atoms.append(atom.GetSymbol())
                    if valid_atoms is not None:
                        if not atom.GetSymbol() in valid_atoms:
                            return None, None
            x = torch.LongTensor(type_idx)

            edges_list = []
            edge_type = []
            for bond in mol.GetBonds():
                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                if start in heavy_atom_indices and end in heavy_atom_indices:
                    start_new, end_new = heavy_atom_indices.index(start), heavy_atom_indices.index(end)
                    edges_list.append((start_new, end_new))
                    edge_type.append(bonds[bond.GetBondType()])
                    edges_list.append((end_new, start_new))
                    edge_type.append(bonds[bond.GetBondType()])
            edge_index = torch.tensor(edges_list, dtype=torch.long).t()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = edge_type

            if target3 is not None:
                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1,-1)
            elif target2 is not None:
                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1,-1)
            else:
                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1,-1)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, idx=i)
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            return data, active_atoms
        
        # Loop through every row in the DataFrame and apply the function
        data_list = []
        len_data = len(data_df)
        with tqdm(total=len_data) as pbar:
            # --- data processing start ---
            active_atoms = set()
            for i, (sms, df_row) in enumerate(data_df.iterrows()):
                if i == sms:
                    sms = df_row['smiles']
                mol = Chem.MolFromSmiles(sms, sanitize=False)
                if len(self.target_prop.split('-')) == 2:
                    target1, target2 = self.target_prop.split('-')
                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2])
                elif len(self.target_prop.split('-')) == 3:
                    target1, target2, target3 = self.target_prop.split('-')
                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[target1], target2=df_row[target2], target3=df_row[target3])
                else:
                    data, cur_active_atoms = mol_to_graph(mol, df_row['SA'], df_row['SC'], df_row[self.target_prop])
                active_atoms.update(cur_active_atoms)
                data_list.append(data)
                pbar.update(1)

        torch.save(self.collate(data_list), self.processed_paths[0])

In [9]:
def random_data_split(task, dataset):
    nan_count = torch.isnan(dataset.y[:, 0]).sum().item()
    labeled_len = len(dataset) - nan_count
    full_idx = list(range(labeled_len))
    train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
    train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
    train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
    unlabeled_index = list(range(labeled_len, len(dataset)))
    print(task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
    return train_index, val_index, test_index, unlabeled_index

In [10]:
data_root = './data'
task_name = 'O2-N2-CO2'
guidance_target='O2-N2-CO2'

dataset = Dataset(source=task_name, root=data_root, target_prop=guidance_target, transform=None)

In [11]:
random_index = random_data_split(task_name, dataset)

O2-N2-CO2  dataset len 553 train len 331 val len 111 test len 111 unlabeled len 0


In [12]:
data_dic = './data/raw'
meta_dict = compute_meta_graph(data_dic, task_name, None, random_index[2])

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221,

In [13]:
arch_index = 0
arch_info = api.query_meta_info_by_index(arch_index)

In [14]:
def parse_architecture_string(arch_str):
    print(arch_str)
    steps = arch_str.split('+')
    nodes = ['input']  # Start with input node
    edges = []
    for i, step in enumerate(steps):
        step = step.strip('|').split('|')
        for node in step:
            op, idx = node.split('~')
            edges.append((int(idx), i+1))  # i+1 because 0 is input node
            nodes.append(op)
    nodes.append('output')  # Add output node
    return nodes, edges

In [15]:
def create_adj_matrix_and_ops(nodes, edges):
    num_nodes = len(nodes)
    adj_matrix = np.zeros((num_nodes, num_nodes), dtype=int)
    for (src, dst) in edges:
        adj_matrix[src][dst] = 1
    return adj_matrix, nodes

In [16]:
nodes, edges = parse_architecture_string(arch_info.arch_str)
adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)

|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|


In [17]:

for i in range(1, 4):
    arch_index = i
    arch_info = api.query_meta_info_by_index(arch_index)
    nodes, edges = parse_architecture_string(arch_info.arch_str)
    adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)
    print("Adjacency Matrix:")
    print(adj_matrix)
    print("Operations List:")
    print(ops)


|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
Adjacency Matrix:
[[0 1 1 1 0 0 0 0]
 [0 0 1 1 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]]
Operations List:
['input', 'nor_conv_3x3', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'nor_conv_3x3', 'skip_connect', 'output']
|avg_pool_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|
Adjacency Matrix:
[[0 1 1 1 0 0 0 0]
 [0 0 1 1 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]]
Operations List:
['input', 'avg_pool_3x3', 'nor_conv_3x3', 'nor_conv_3x3', 'avg_pool_3x3', 'avg_pool_3x3', 'avg_pool_3x3', 'output']
|avg_pool_3x3~0|+|skip_connect~0|none~1|+|none~0|none~1|skip_connect~2|
Adjacency Matrix:
[[0 1 1 1 0 0 0 0]
 [0 0 1 1 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0

In [18]:
graphs = []
length = 15625
ops_type = {}
len_ops = set()
for i in range(length):
    arch_info = api.query_meta_info_by_index(i)
    nodes, edges = parse_architecture_string(arch_info.arch_str)
    adj_matrix, ops = create_adj_matrix_and_ops(nodes, edges)    
    if i < 5:
        print("Adjacency Matrix:")
        print(adj_matrix)
        print("Operations List:")
        print(ops)
    for op in ops:
        if op not in ops_type:
            ops_type[op] = len(ops_type)
    len_ops.add(len(ops))
    graphs.append((adj_matrix, ops))


|avg_pool_3x3~0|+|nor_conv_1x1~0|skip_connect~1|+|nor_conv_1x1~0|skip_connect~1|skip_connect~2|
Adjacency Matrix:
[[0 1 1 1 0 0 0 0]
 [0 0 1 1 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]]
Operations List:
['input', 'avg_pool_3x3', 'nor_conv_1x1', 'skip_connect', 'nor_conv_1x1', 'skip_connect', 'skip_connect', 'output']
|nor_conv_3x3~0|+|nor_conv_3x3~0|avg_pool_3x3~1|+|skip_connect~0|nor_conv_3x3~1|skip_connect~2|
Adjacency Matrix:
[[0 1 1 1 0 0 0 0]
 [0 0 1 1 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0]]
Operations List:
['input', 'nor_conv_3x3', 'nor_conv_3x3', 'avg_pool_3x3', 'skip_connect', 'nor_conv_3x3', 'skip_connect', 'output']
|avg_pool_3x3~0|+|nor_conv_3x3~0|nor_conv_3x3~1|+|avg_pool_3x3~0|avg_pool_3x3~1|avg_pool_3x3~2|
Adjacency Matrix:
[[0 1 1 1 0 0 0 0]
 [0 0 1 1 0 0 0 0]
 [0 0 0 1 0 0 0 0]
 [0 0 0 0 0 0 0 0]
 [0 0 0 0 

|none~0|+|nor_conv_3x3~0|none~1|+|skip_connect~0|none~1|nor_conv_3x3~2|
|none~0|+|skip_connect~0|nor_conv_3x3~1|+|nor_conv_3x3~0|none~1|none~2|
|skip_connect~0|+|avg_pool_3x3~0|none~1|+|nor_conv_1x1~0|nor_conv_3x3~1|avg_pool_3x3~2|
|none~0|+|none~0|skip_connect~1|+|skip_connect~0|nor_conv_1x1~1|skip_connect~2|
|nor_conv_3x3~0|+|avg_pool_3x3~0|nor_conv_3x3~1|+|none~0|avg_pool_3x3~1|avg_pool_3x3~2|
|nor_conv_3x3~0|+|none~0|none~1|+|none~0|none~1|nor_conv_1x1~2|
|nor_conv_3x3~0|+|skip_connect~0|none~1|+|avg_pool_3x3~0|none~1|avg_pool_3x3~2|
|skip_connect~0|+|none~0|skip_connect~1|+|skip_connect~0|nor_conv_1x1~1|avg_pool_3x3~2|
|nor_conv_3x3~0|+|avg_pool_3x3~0|skip_connect~1|+|nor_conv_1x1~0|nor_conv_3x3~1|none~2|
|nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|nor_conv_3x3~0|none~1|none~2|
|none~0|+|skip_connect~0|skip_connect~1|+|skip_connect~0|nor_conv_1x1~1|nor_conv_3x3~2|
|none~0|+|nor_conv_1x1~0|none~1|+|nor_conv_1x1~0|skip_connect~1|nor_conv_3x3~2|
|nor_conv_1x1~0|+|nor_conv_3x3~0

In [19]:
print(len(ops_type))
print(len(len_ops))
print(ops_type)
print(len_ops)

7
1
{'input': 0, 'avg_pool_3x3': 1, 'nor_conv_1x1': 2, 'skip_connect': 3, 'output': 4, 'nor_conv_3x3': 5, 'none': 6}
{8}


In [20]:
op_to_atom = {
    'input': 'Si',          # Hydrogen for input
    'nor_conv_1x1': 'C',   # Carbon for 1x1 convolution
    'nor_conv_3x3': 'N',   # Nitrogen for 3x3 convolution
    'avg_pool_3x3': 'O',   # Oxygen for 3x3 average pooling
    'skip_connect': 'P',   # Phosphorus for skip connection
    'none': 'S',           # Sulfur for no operation
    'output': 'He'         # Helium for output
}



In [21]:
def graphs_to_json(graphs, filename):
    bonds = {
        'nor_conv_1x1': 1,
        'nor_conv_3x3': 2,
        'avg_pool_3x3': 3,
        'skip_connect': 4,
        'input': 0,
        'output': 5,
        'none': 6
    }

    source_name = "nas-bench-201"
    num_graph = len(graphs)
    pt = Chem.GetPeriodicTable()
    atom_name_list = []
    atom_count_list = []
    for i in range(2, 119):
        atom_name_list.append(pt.GetElementSymbol(i))
        atom_count_list.append(0)
    atom_name_list.append('*')
    atom_count_list.append(0)
    n_atoms_per_mol = [0] * 500
    bond_count_list = [0, 0, 0, 0, 0]
    bond_type_to_index =  {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
    valencies = [0] * 500
    transition_E = np.zeros((118, 118, 5))

    n_atom_list = []
    n_bond_list = []
    # graphs = [(adj_matrix, ops), ...]
    for graph in graphs:
        ops = graph[1]
        n_atom = len(ops)
        n_bond = len(ops)
        n_atom_list.append(n_atom)
        n_bond_list.append(n_bond)

        n_atoms_per_mol[n_atom] += 1
        cur_atom_count_arr = np.zeros(118)

        for op in ops:
            symbol = op_to_atom[op]
            if symbol == 'H':
                continue
            elif symbol == '*':
                atom_count_list[-1] += 1
                cur_atom_count_arr[-1] += 1
            else:
                atom_count_list[pt.GetAtomicNumber(symbol)-2] += 1
                cur_atom_count_arr[pt.GetAtomicNumber(symbol)-2] += 1
                try:
                    valencies[int(pt.GetDefaultValence(symbol))] += 1
                except:
                    print('int(pt.GetDefaultValence(symbol))', int(pt.GetDefaultValence(symbol)))
        transition_E_temp = np.zeros((118, 118, 5))
        print(n_atom)
        for i in range(n_atom):
            for j in range(n_atom):
                if i == j:
                    continue
                start_atom, end_atom = i, j
                if ops[start_atom] == 'input' or ops[end_atom] == 'input':
                    continue
                if ops[start_atom] == 'output' or ops[end_atom] == 'output':
                    continue
                if ops[start_atom] == 'none' or ops[end_atom] == 'none':
                    continue
                
                start_index = pt.GetAtomicNumber(op_to_atom[ops[start_atom]]) - 2
                end_index = pt.GetAtomicNumber(op_to_atom[ops[end_atom]]) - 2
                print(start_index, end_index)
                bond_index = 1
                bond_count_list[bond_index] += 2
                
                transition_E[start_index, end_index, bond_index] += 2
                transition_E[end_index, start_index, bond_index] += 2
                transition_E_temp[start_index, end_index, bond_index] += 2
                transition_E_temp[end_index, start_index, bond_index] += 2

        n_atoms_per_mol = np.array(n_atoms_per_mol) / np.sum(n_atoms_per_mol)

        # transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
        no_edge = np.sum(transition_E, axis=-1) == 0
        first_elt = transition_E[:, :, 0]
        first_elt[no_edge] = 1
        transition_E[:, :, 0] = first_elt
        transition_E = transition_E / np.sum(transition_E, axis=-1, keepdims=True)
        meta_dict = {
            'source': 'nasbench-201',
            'num_graph': num_graph,
            'n_atoms_per_mol_dist': n_atoms_per_mol.tolist()[:51],
            'max_node': max(n_atom_list),
            'max_bond': max(n_bond_list),
            'atom_type_dist': atom_count_list,
            'bond_type_dist': bond_count_list,
            'valencies': valencies,
            'active_atoms': [atom_name_list[i] for i in range(118) if atom_count_list[i] > 0],
            'num_atom_type': len([atom_name_list[i] for i in range(118) if atom_count_list[i] > 0]),
            'transition_E': transition_E.tolist(),
        }

        with open(f'{filename}.meta.json', 'w') as f:
            json.dump(meta_dict, f)
        return meta_dict

        

In [22]:
graphs_to_json(graphs, 'nasbench-201')

8
6 4
6 13
6 4
6 13
6 13
4 6
4 13
4 4
4 13
4 13
13 6
13 4
13 4
13 13
13 13
4 6
4 4
4 13
4 13
4 13
13 6
13 4
13 13
13 4
13 13
13 6
13 4
13 13
13 4
13 13


{'source': 'nasbench-201',
 'num_graph': 15625,
 'n_atoms_per_mol_dist': [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 'max_node': 8,
 'max_bond': 1,
 'atom_type_dist': [1,
  0,
  0,
  0,
  2,
  0,
  1,
  0,
  0,
  0,
  0,
  0,
  1,
  3,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,


In [23]:
def gen_adj_matrix_and_ops(nasbench):
    i = 0
    epoch = 108

    for unique_hash in nasbench.hash_iterator():
        fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(unique_hash)

In [38]:
import torch
from torch_geometric.data import InMemoryDataset, Data
import os.path as osp
import pandas as pd
from tqdm import tqdm
import networkx as nx
import numpy as np

class Dataset(InMemoryDataset):
    def __init__(self, source, root, target_prop=None, transform=None, pre_transform=None, pre_filter=None):
        self.target_prop = target_prop
        source = './NAS-Bench-201-v1_1-096897.pth'
        self.source = source
        self.api = API(source)  # Initialize NAS-Bench-201 API
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []  # NAS-Bench-201 data is loaded via the API, no raw files needed
    
    @property
    def processed_file_names(self):
        return [f'{self.source}.pt']

    def process(self):
        def parse_architecture_string(arch_str):
            stages = arch_str.split('+')
            nodes = ['input']
            edges = []
            
            for stage in stages:
                operations = stage.strip('|').split('|')
                for op in operations:
                    operation, idx = op.split('~')
                    idx = int(idx)
                    edges.append((idx, len(nodes)))  # Add edge from idx to the new node
                    nodes.append(operation)
            nodes.append('output')  # Add the output node
            return nodes, edges

        def create_graph(nodes, edges):
            G = nx.DiGraph()
            for i, node in enumerate(nodes):
                G.add_node(i, label=node)
            G.add_edges_from(edges)
            return G

        def arch_to_graph(arch_str, sa, sc, target, target2=None, target3=None):
            nodes, edges = parse_architecture_string(arch_str)
            node_labels = [bonds[node] for node in nodes]  # Replace with appropriate encoding if necessary
            x = torch.LongTensor(node_labels)

            edges_list = [(start, end) for start, end in edges]
            edge_type = [bonds[nodes[end]] for start, end in edges]  # Example: using end node type as edge type
            edge_index = torch.tensor(edges_list, dtype=torch.long).t().contiguous()
            edge_type = torch.tensor(edge_type, dtype=torch.long)
            edge_attr = edge_type.view(-1, 1)

            if target3 is not None:
                y = torch.tensor([sa, sc, target, target2, target3], dtype=torch.float).view(1, -1)
            elif target2 is not None:
                y = torch.tensor([sa, sc, target, target2], dtype=torch.float).view(1, -1)
            else:
                y = torch.tensor([sa, sc, target], dtype=torch.float).view(1, -1)

            data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)
            return data, nodes

        bonds = {
            'nor_conv_1x1': 1,
            'nor_conv_3x3': 2,
            'avg_pool_3x3': 3,
            'skip_connect': 4,
            'input': 0,
            'output': 5,
            'none': 6
        }

        # Prepare to process NAS-Bench-201 data
        data_list = []
        len_data = len(self.api)  # Number of architectures
        with tqdm(total=len_data) as pbar:
            for arch_index in range(len_data):
                arch_info = self.api.query_meta_info_by_index(arch_index)
                arch_str = arch_info.arch_str
                sa = np.random.rand()  # Placeholder for synthetic accessibility
                sc = np.random.rand()  # Placeholder for substructure count
                target = np.random.rand()  # Placeholder for target value
                target2 = np.random.rand()  # Placeholder for second target value
                target3 = np.random.rand()  # Placeholder for third target value

                data, active_nodes = arch_to_graph(arch_str, sa, sc, target, target2, target3)
                data_list.append(data)
                pbar.update(1)

        torch.save(self.collate(data_list), self.processed_paths[0])


In [39]:
# dataset = Dataset(source='./NAS-Bench-201-v1_1-096897.pth', root='./data')

In [40]:
import os
import pathlib
import torch
from torch_geometric.data import DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd
from tqdm import tqdm
# import nas_bench_201 as nb201

import utils as utils
from datasets.abstract_dataset import AbstractDatasetInfos, AbstractDataModule
from diffusion.distributions import DistributionNodes

class DataModule(AbstractDataModule):
    def __init__(self, cfg):
        self.datadir = cfg.dataset.datadir
        self.task = cfg.dataset.task_name
        print("DataModule")
        print("task", self.task)
        print("datadir", self.datadir)
        super().__init__(cfg)

    def prepare_data(self) -> None:
        target = getattr(self.cfg.dataset, 'guidance_target', None)
        print("target", target)
        # try:
        #     base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]
        # except NameError:
        # base_path = pathlib.Path(os.getcwd()).parent[2]
        base_path = '/home/stud/hanzhang/Graph-Dit'
        root_path = os.path.join(base_path, self.datadir)
        self.root_path = root_path

        batch_size = self.cfg.train.batch_size
        
        num_workers = self.cfg.train.num_workers
        pin_memory = self.cfg.dataset.pin_memory

        # Load the dataset to the memory
        # Dataset has target property, root path, and transform
        source = './NAS-Bench-201-v1_1-096897.pth'
        dataset = Dataset(source=source, root=root_path, target_prop=target, transform=None)

        # if len(self.task.split('-')) == 2:
        #     train_index, val_index, test_index, unlabeled_index = self.fixed_split(dataset)
        # else:
        train_index, val_index, test_index, unlabeled_index = self.random_data_split(dataset)

        self.train_index, self.val_index, self.test_index, self.unlabeled_index = train_index, val_index, test_index, unlabeled_index
        train_index, val_index, test_index, unlabeled_index = torch.LongTensor(train_index), torch.LongTensor(val_index), torch.LongTensor(test_index), torch.LongTensor(unlabeled_index)
        if len(unlabeled_index) > 0:
            train_index = torch.cat([train_index, unlabeled_index], dim=0)
        
        train_dataset, val_dataset, test_dataset = dataset[train_index], dataset[val_index], dataset[test_index]
        self.train_dataset = train_dataset  
        print('train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        print('train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        print('dataset len', len(dataset), 'train len', len(train_dataset), 'val len', len(val_dataset), 'test len', len(test_dataset))
        self.train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True, pin_memory=pin_memory)

        self.val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)
        self.test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, pin_memory=False)

        training_iterations = len(train_dataset) // batch_size
        self.training_iterations = training_iterations
    
    def random_data_split(self, dataset):
        nan_count = torch.isnan(dataset.data.y[:, 0]).sum().item()
        labeled_len = len(dataset) - nan_count
        full_idx = list(range(labeled_len))
        train_ratio, valid_ratio, test_ratio = 0.6, 0.2, 0.2
        train_index, test_index, _, _ = train_test_split(full_idx, full_idx, test_size=test_ratio, random_state=42)
        train_index, val_index, _, _ = train_test_split(train_index, train_index, test_size=valid_ratio/(valid_ratio+train_ratio), random_state=42)
        unlabeled_index = list(range(labeled_len, len(dataset)))
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index), 'unlabeled len', len(unlabeled_index))
        return train_index, val_index, test_index, unlabeled_index
    
    def fixed_split(self, dataset):
        if self.task == 'O2-N2':
            test_index = [42,43,92,122,197,198,251,254,257,355,511,512,549,602,603,604]
        else:
            raise ValueError('Invalid task name: {}'.format(self.task))
        full_idx = list(range(len(dataset)))
        full_idx = list(set(full_idx) - set(test_index))
        train_ratio = 0.8
        train_index, val_index, _, _ = train_test_split(full_idx, full_idx, test_size=1-train_ratio, random_state=42)
        print(self.task, ' dataset len', len(dataset), 'train len', len(train_index), 'val len', len(val_index), 'test len', len(test_index))
        return train_index, val_index, test_index, []

    def get_train_smiles(self):
        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")

    def get_data_split(self):
        raise NotImplementedError("This method is not applicable for NAS-Bench-201 data.")

    def example_batch(self):
        return next(iter(self.val_loader))
    
    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader
    
    def test_dataloader(self):
        return self.test_loader


In [41]:
from omegaconf import DictConfig, OmegaConf
import argparse
import hydra

def parse_arg():
    parser = argparse.ArgumentParser(description='Diffusion')
    parser.add_argument('--config', type=str, default='config.yaml', help='config file')
    return parser.parse_args()

def task1(cfg: DictConfig):
    datamodule = DataModule(cfg=cfg)
    datamodule.prepare_data()

cfg = {
    'general':{
        'name': 'graph_dit',
        'wandb': 'disabled' ,
        'gpus': 1,
        'resume': 'null',
        'test_only': 'null',
        'sample_every_val': 2500,
        'samples_to_generate': 512,
        'samples_to_save': 3,
        'chains_to_save': 1,
        'log_every_steps': 50,
        'number_chain_steps': 8,
        'final_model_samples_to_generate': 10000,
        'final_model_samples_to_save': 20,
        'final_model_chains_to_save': 1,
        'enable_progress_bar': False,
        'save_model': True,
    },
    'model':{
            'type': 'discrete',
            'transition': 'marginal',
            'model': 'graph_dit',
            'diffusion_steps': 500,
            'diffusion_noise_schedule': 'cosine',
            'guide_scale': 2,
            'hidden_size': 1152,
            'depth': 6,
            'num_heads': 16,
            'mlp_ratio': 4,
            'drop_condition': 0.01,
            'lambda_train': [1, 10],  # node and edge training weight 
            'ensure_connected': True,
    },
    'train':{
            'n_epochs': 10000,
            'batch_size': 1200,
            'lr': 0.0002,
            'clip_grad': 'null',
            'num_workers': 0,
            'weight_decay': 0,
            'seed': 0,
            'val_check_interval': 'null',
            'check_val_every_n_epoch': 1,
    },
    'dataset':{
            'datadir': 'data',
            'task_name': 'nasbench-201',
            'guidance_target': 'nasbench-201',
            'pin_memory': False,
    },
}


In [42]:
cfg = OmegaConf.create(cfg)
task1(cfg)

DataModule
task nasbench-201
datadir data
target nasbench-201
try to create the NAS-Bench-201 api from ./NAS-Bench-201-v1_1-096897.pth




nasbench-201  dataset len 15625 train len 9375 val len 3125 test len 3125 unlabeled len 0
train len 9375 val len 3125 test len 3125
train len 9375 val len 3125 test len 3125
dataset len 15625 train len 9375 val len 3125 test len 3125
