39 lines
1.5 KiB
Python
39 lines
1.5 KiB
Python
import os, sys, time, queue, torch
|
|
from pathlib import Path
|
|
|
|
lib_dir = (Path(__file__).parent / '..' / 'lib').resolve()
|
|
if str(lib_dir) not in sys.path: sys.path.insert(0, str(lib_dir))
|
|
|
|
from log_utils import time_string
|
|
from models import CellStructure
|
|
|
|
def get_unique_matrix(archs, consider_zero):
    """Group architectures that produce the same unique string.

    Args:
        archs: sequence of objects exposing ``to_unique_str(consider_zero)``.
        consider_zero: forwarded unchanged to every ``to_unique_str`` call.

    Returns:
        A tuple ``(sm_matrix, unique_ids, unique_num)`` where

        * ``sm_matrix`` is a ``len(archs) x len(archs)`` boolean tensor whose
          entry ``[i, j]`` is True exactly when ``archs[i]`` and ``archs[j]``
          yield equal unique strings (so the diagonal is always True);
        * ``unique_ids`` is a list giving the group index of every arch,
          numbered from 0 in order of first appearance;
        * ``unique_num`` is the number of distinct groups.
    """
    unique_strs = [arch.to_unique_str(consider_zero) for arch in archs]
    print('{:} create unique-string done'.format(time_string()))

    # Group by hashing instead of comparing every pair in a Python loop.
    # NOTE(review): assumes to_unique_str returns a hashable value (a str,
    # judging by its name) -- confirm against CellStructure.
    group_of = {}
    # len(group_of) is evaluated before setdefault inserts, so an unseen
    # string receives the next free id and a seen one keeps its first id.
    unique_ids = [group_of.setdefault(ustr, len(group_of)) for ustr in unique_strs]
    unique_num = len(group_of)

    # Two archs are "the same" iff they share a group id; a single broadcast
    # comparison yields the full boolean matrix (shape (0, 0) for no archs).
    id_vec = torch.tensor(unique_ids, dtype=torch.long)
    sm_matrix = id_vec.unsqueeze(1) == id_vec.unsqueeze(0)
    return sm_matrix, unique_ids, unique_num
|
|
|
|
def check_unique_arch(meta_file='./output/AA-NAS-BENCH-4/meta-node-4.pth'):
    """Print how many distinct architectures a meta file contains.

    Loads ``meta_file`` with ``torch.load``, converts every entry of its
    ``'archs'`` field through ``CellStructure.str2structure`` and prints the
    unique-architecture count reported by ``get_unique_matrix`` twice: first
    with ``consider_zero=False``, then with ``consider_zero=True``.

    Args:
        meta_file: path handed to ``torch.load``. Defaults to the location
            that used to be hard-coded, so calls without arguments behave
            exactly as before.
    """
    print('{:} start'.format(time_string()))
    # NOTE(review): torch.load relies on pickle-based deserialization; only
    # point this at checkpoint files from a trusted source.
    meta_info = torch.load(meta_file)
    archs = [CellStructure.str2structure(arch_str) for arch_str in meta_info['archs']]

    _, _, unique_num = get_unique_matrix(archs, False)
    print('{:} There are {:} unique architectures (not considering zero).'.format(time_string(), unique_num))

    _, _, unique_num = get_unique_matrix(archs, True)
    print('{:} There are {:} unique architectures (considering zero).'.format(time_string(), unique_num))
|
|
|
|
# Script entry point: run the uniqueness report only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    check_unique_arch()
|