39 lines
1.8 KiB
Python
39 lines
1.8 KiB
Python
|
###########################################################################################
|
||
|
# Copyright (c) Hayeon Lee, Eunyoung Hyung [GitHub MetaD2A], 2021
|
||
|
# Rapid Neural Architecture Search by Learning to Generate Graphs from Datasets, ICLR 2021
|
||
|
###########################################################################################
|
||
|
from .setenc_modules import *
|
||
|
|
||
|
|
||
|
class SetPool(nn.Module):
    """Set encoder followed by a pooling decoder.

    The encoder is two stacked attention blocks (``SAB`` or ``ISAB``, which
    come from ``setenc_modules``) mapping ``dim_input`` features to
    ``dim_hidden``. The decoder always starts with a ``PMA`` block that is
    given ``num_outputs``. Depending on ``mode``, the ``PMA`` block may be
    followed by more attention blocks and a final ``nn.Linear`` projection
    to ``dim_output``.

    Args:
        dim_input: Feature size of each input set element.
        num_outputs: Passed to ``PMA`` (presumably the number of pooled
            output rows / seed vectors -- confirm in setenc_modules).
        dim_output: Output size of the final ``nn.Linear``. Unused when the
            decoder has no Linear layer (``mode`` contains ``'P'`` but not
            ``'PF'``).
        num_inds: Passed to each ``ISAB`` (presumably the number of inducing
            points). Unused when the encoder is built from ``SAB``.
        dim_hidden: Hidden feature size shared by every attention block.
        num_heads: Number of attention heads passed to every block.
        ln: Forwarded as ``ln=`` to every attention block (presumably
            toggles layer normalization -- confirm in setenc_modules).
        mode: Configuration string matched by (case-sensitive) substring.
            ``None`` is treated as ``""``.

            * contains ``'sab'``: the encoder uses ``SAB``; otherwise it
              uses ``ISAB``.
            * contains ``'PF'``: the decoder is ``PMA`` + ``Linear``.
            * otherwise contains ``'P'``: the decoder is ``PMA`` only.
            * otherwise: the decoder is ``PMA`` + ``SAB`` + ``SAB`` + ``Linear``.

            The original notes list the modes ``""``, ``"sm"``, ``"sab"``
            and ``"sabsm"``. The ``"sm"`` part is not inspected by this
            class.
    """

    def __init__(self, dim_input, num_outputs, dim_output,
                 num_inds=32, dim_hidden=128, num_heads=4, ln=False, mode=None):
        super().__init__()
        # The default mode=None previously crashed on `'sab' in None`
        # (TypeError). Treat it as the empty mode instead.
        if mode is None:
            mode = ''

        # Encoder: two stacked attention blocks, dim_input -> dim_hidden.
        # The original notes give [32, 400, 128] as an example output shape.
        if 'sab' in mode:
            self.enc = nn.Sequential(
                SAB(dim_input, dim_hidden, num_heads, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln))
        else:
            self.enc = nn.Sequential(
                ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
                ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))

        # Decoder. 'PF' must be tested before 'P' because 'PF' contains 'P'.
        if 'PF' in mode:
            # Example output shape from the original notes: [32, 1, 501].
            self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln),
                nn.Linear(dim_hidden, dim_output))
        elif 'P' in mode:
            # No Linear projection here, so dim_output is not used.
            self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln))
        else:
            # Example output shape from the original notes: [32, 1, 501].
            self.dec = nn.Sequential(
                PMA(dim_hidden, num_heads, num_outputs, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
                nn.Linear(dim_hidden, dim_output))

    def forward(self, X):
        """Encode the set batch ``X`` and pool it with the decoder.

        Args:
            X: Input batch, presumably shaped (batch, set_size, dim_input)
                -- TODO confirm against the attention modules.

        Returns:
            ``self.dec(self.enc(X))``.
        """
        return self.dec(self.enc(X))
|