114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
#####################################################
|
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
|
#####################################################
|
|
import torch
|
|
|
|
from itertools import islice
|
|
import operator
|
|
|
|
from collections import OrderedDict
|
|
from typing import Optional, Union, Callable, TypeVar, Iterator
|
|
|
|
import spaces
|
|
from .super_module import SuperModule
|
|
|
|
|
|
# Type variable bounded by SuperModule; used in SuperSequential's return
# annotations to stand for the (sub)type of the child modules it holds.
T = TypeVar("T", bound=SuperModule)
|
|
|
|
|
|
class SuperSequential(SuperModule):
    """A sequential container wrapped with 'Super' ability.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ordered dict of modules can also be
    passed in.

    To make it easier to understand, here is a small example::

        # Example of using SuperSequential
        model = SuperSequential(
            nn.Conv2d(1, 20, 5),
            nn.ReLU(),
            nn.Conv2d(20, 64, 5),
            nn.ReLU(),
        )

        # Example of using SuperSequential with OrderedDict
        model = SuperSequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 20, 5)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(20, 64, 5)),
            ('relu2', nn.ReLU()),
        ]))

    Modules passed positionally are registered under the keys "0", "1", ...;
    modules passed through an OrderedDict keep their dict keys.

    NOTE(review): ``self._modules`` / ``self.add_module`` are assumed to come
    from a ``torch.nn.Module`` ancestor of SuperModule — not visible here.
    """

    def __init__(self, *args):
        """Register the given child modules in order.

        Args:
          *args: either a single ``OrderedDict`` mapping names to modules, or
            any number of modules, which are registered as "0", "1", ....
        """
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # No type check needed: ``*args`` always arrives as a tuple.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:
        """Return the ``idx``-th item of ``iterator``.

        ``iterator`` is a view over ``self._modules`` (its keys or values), so
        ``len(self)`` bounds the index. Negative indices count from the end,
        as for lists.

        Raises:
          TypeError: if ``idx`` is not an integer-like object.
          IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError("index {} is out of range".format(idx))
        idx %= size  # map a negative index onto its non-negative position
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx) -> Union["SuperSequential", T]:
        """Return the child at ``idx``, or a new container for a slice.

        A sliced container keeps the original keys of the selected children.
        """
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: SuperModule) -> None:
        """Replace the child at position ``idx``, keeping its existing key."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Remove the child (or children, for a slice) at ``idx``.

        The remaining children keep their keys; nothing is renumbered.
        """
        if isinstance(idx, slice):
            # Materialize the keys first: deleting mutates ``self._modules``.
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    def __len__(self) -> int:
        """Return the number of registered child modules."""
        return len(self._modules)

    def __dir__(self):
        """Return the parent's ``dir()`` without all-digit names.

        All-digit names are the keys given to positionally added children.
        """
        return [key for key in super().__dir__() if not key.isdigit()]

    def __iter__(self) -> Iterator[SuperModule]:
        """Iterate over the child modules in registration order."""
        return iter(self._modules.values())

    @property
    def abstract_search_space(self):
        """Collect the undetermined search spaces of the child SuperModules.

        Children are keyed by ``str(position)`` in iteration order, not by
        their registered name. Non-SuperModule children and children whose
        space is already determined are skipped.
        """
        root_node = spaces.VirtualNode(id(self))
        for index, module in enumerate(self):
            if not isinstance(module, SuperModule):
                continue
            space = module.abstract_search_space
            if not spaces.is_determined(space):
                root_node.append(str(index), space)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        """Apply a sampled candidate to this container and its children.

        Uses the same ``str(position)`` keys as ``abstract_search_space``.
        Children with no entry in ``abstract_child`` are left untouched.
        """
        super().apply_candidate(abstract_child)
        for index, module in enumerate(self):
            key = str(index)
            if key in abstract_child:
                module.apply_candidate(abstract_child[key])

    def forward_candidate(self, input):
        """Delegate to ``forward_raw``; the container adds no extra logic."""
        return self.forward_raw(input)

    def forward_raw(self, input):
        """Feed ``input`` through every child in order and return the result."""
        output = input
        for module in self:
            output = module(output)
        return output
|