xautodl/lib/xlayers/super_container.py

114 lines
3.9 KiB
Python

#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
#####################################################
import torch
from itertools import islice
import operator
from collections import OrderedDict
from typing import Optional, Union, Callable, TypeVar, Iterator
import spaces
from .super_module import SuperModule
# Type variable used in SuperSequential's indexing annotations
# (`_get_item_by_idx`, `__getitem__`); bound to SuperModule for type checkers.
T = TypeVar("T", bound=SuperModule)
class SuperSequential(SuperModule):
    """A sequential container wrapped with 'Super' ability.

    Modules are registered in the order they are passed to the constructor
    and are named "0", "1", ... Alternatively, a single ``OrderedDict`` of
    ``name -> module`` can be passed, in which case its keys become the
    module names.

    To make it easier to understand, here is a small example::

        # Example of using SuperSequential
        model = SuperSequential(
            nn.Conv2d(1, 20, 5),
            nn.ReLU(),
            nn.Conv2d(20, 64, 5),
            nn.ReLU()
        )

        # Example of using SuperSequential with OrderedDict
        model = SuperSequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 20, 5)),
            ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2d(20, 64, 5)),
            ('relu2', nn.ReLU())
        ]))
    """

    def __init__(self, *args):
        super(SuperSequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            # Named construction: keep the caller-supplied keys.
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            # Positional construction. ``args`` is always a tuple (it comes
            # from ``*args``), so no type check is needed here.
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> Union[str, T]:
        """Return the ``idx``-th item of ``iterator``.

        ``iterator`` is a view over ``self._modules`` (its keys or values), so
        the result is either a module name or a module. Negative indices count
        from the end, as with Python sequences.

        Raises:
            TypeError: if ``idx`` is not integer-like (from ``operator.index``).
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError("index {} is out of range".format(idx))
        # Normalise a negative index to its non-negative position.
        idx %= size
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx) -> Union["SuperSequential", T]:
        """Return one module by position, or a new container for a slice.

        A slice yields a new instance of this class that holds the selected
        (name, module) pairs and keeps their original names.
        """
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: SuperModule) -> None:
        """Replace the module at position ``idx``, reusing its existing name."""
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        """Delete the module(s) at position ``idx`` (an int or a slice).

        The names of the remaining modules are not renumbered.
        """
        if isinstance(idx, slice):
            # Materialise the keys first: deleting mutates ``self._modules``.
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)

    def __len__(self) -> int:
        """Return the number of registered child modules."""
        return len(self._modules)

    def __dir__(self):
        """Hide purely numeric attribute names (auto-generated module names)."""
        keys = super(SuperSequential, self).__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def __iter__(self) -> Iterator[SuperModule]:
        """Iterate over the child modules in registration order."""
        return iter(self._modules.values())

    @property
    def abstract_search_space(self):
        """Collect the not-yet-determined search spaces of the children.

        Returns a ``spaces.VirtualNode`` created with ``id(self)``. For each
        child that is a ``SuperModule`` and whose ``abstract_search_space`` is
        not determined (per ``spaces.is_determined``), the child's space is
        appended under the child's positional index converted to ``str``.
        Children that are not ``SuperModule`` instances are skipped.
        """
        root_node = spaces.VirtualNode(id(self))
        for index, module in enumerate(self):
            if not isinstance(module, SuperModule):
                continue
            space = module.abstract_search_space
            if not spaces.is_determined(space):
                root_node.append(str(index), space)
        return root_node

    def apply_candidate(self, abstract_child: spaces.VirtualNode):
        """Apply a sampled candidate to this container and its children.

        Calls the parent implementation first. Then, for every child whose
        positional index (as ``str``) is present in ``abstract_child``, it
        forwards the matching sub-node to that child's ``apply_candidate``.
        The keys match those produced by ``abstract_search_space``.
        """
        super(SuperSequential, self).apply_candidate(abstract_child)
        for index, module in enumerate(self):
            if str(index) in abstract_child:
                module.apply_candidate(abstract_child[str(index)])

    def forward_candidate(self, input):
        """Run the candidate forward pass; identical to ``forward_raw`` here."""
        return self.forward_raw(input)

    def forward_raw(self, input):
        """Feed ``input`` through each child module in order; return the result."""
        for module in self:
            input = module(input)
        return input