diff --git a/graph_dit/metrics/property_metric.py b/graph_dit/metrics/property_metric.py index e0ff0b1..dbab4d6 100644 --- a/graph_dit/metrics/property_metric.py +++ b/graph_dit/metrics/property_metric.py @@ -15,6 +15,17 @@ from rdkit.Chem import AllChem from rdkit import DataStructs from rdkit.Chem import rdMolDescriptors rdBase.DisableLog('rdApp.error') +import json + +op_type = { + 'nor_conv_1x1': 1, + 'nor_conv_3x3': 2, + 'avg_pool_3x3': 3, + 'skip_connect': 4, + 'output': 5, + 'none': 6, + 'input': 7 +} task_to_colname = { 'hiv_b': 'HIV_active', @@ -32,8 +43,10 @@ tasktype_name = { 'O2': 'regression', 'N2': 'regression', 'CO2': 'regression', + 'nasbench201': 'regression', } + class TaskModel(): """Scores based on an ECFP classifier.""" def __init__(self, model_path, task_name): @@ -55,8 +68,47 @@ class TaskModel(): perfermance = self.train() dump(self.model, model_path) print('Oracle peformance: ', perfermance) - def train(self): + def read_adj_ops_from_json(filename): + with open(filename, 'r') as json_file: + data = json.load(json_file) + + adj_ops_pairs = [] + for item in data: + adj_matrix = np.array(item['adj_matrix']) + ops = item['ops'] + acc = item['train'][0]['accuracy'] + adj_ops_pairs.append((adj_matrix, ops, acc)) + + return adj_ops_pairs + def feature_from_adj_and_ops(adj, ops): + return np.concatenate([adj.flatten(), ops]) + filename = '/home/stud/hanzhang/nasbenchDiT/graph_dit/nasbench-201-graph.json' + graphs = read_adj_ops_from_json(filename) + adjs = [] + opss = [] + accs = [] + features = [] + for graph in graphs: + adj, ops, acc=graph + op_code = [op_type[op] for op in ops] + adjs.append(adj) + opss.append(op_code) + accs.append(acc) + features.append(feature_from_adj_and_ops(adj, op_code)) + features = np.array(features) + labels = np.array(accs) + + mask = ~np.isnan(labels) + labels = labels[mask] + features = features[mask] + self.model.fit(features, labels) + y_pred = self.model.predict(features) + perf = self.metric_func(labels, y_pred) + print(f'{self.task_name} performance: {perf}') + return perf + + def train__(self): data_path = os.path.dirname(self.model_path) data_path = os.path.join(os.path.dirname(self.model_path), '..', f'raw/{self.task_name}.csv.gz') df = pd.read_csv(data_path)