xautodl/tests/test_super_transformer.py
2021-05-19 13:13:11 +08:00

37 lines
1.3 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
# pytest ./tests/test_super_model.py -s #
#####################################################
import unittest
import torch
from xautodl.xlayers.super_core import SuperRunMode
from xautodl.trade_models import get_transformer
class TestSuperTransformer(unittest.TestCase):
"""Test the super transformer."""
def test_super_transformer(self):
model = get_transformer(None)
model.apply_verbose(False)
print(model)
inputs = torch.rand(10, 360)
print("Input shape: {:}".format(inputs.shape))
outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), (10,))
abstract_space = model.abstract_search_space
abstract_space.clean_last()
abstract_child = abstract_space.random(reuse_last=True)
print("The abstract searc space:\n{:}".format(abstract_space))
print("The abstract child program:\n{:}".format(abstract_child))
model.set_super_run_type(SuperRunMode.Candidate)
model.apply_candidate(abstract_child)
outputs = model(inputs)
self.assertEqual(tuple(outputs.shape), (10,))