Add unit tests for super-linear
This commit is contained in:
		
							
								
								
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -32,7 +32,7 @@ jobs: | |||||||
|           echo $PWD ; ls |           echo $PWD ; ls | ||||||
|           python -m black ./exps -l 88 --check --diff --verbose |           python -m black ./exps -l 88 --check --diff --verbose | ||||||
|           python -m black ./tests -l 88 --check --diff --verbose |           python -m black ./tests -l 88 --check --diff --verbose | ||||||
|           python -m black ./lib/layers -l 88 --check --diff --verbose |           python -m black ./lib/xlayers -l 88 --check --diff --verbose | ||||||
|           python -m black ./lib/spaces -l 88 --check --diff --verbose |           python -m black ./lib/spaces -l 88 --check --diff --verbose | ||||||
|           python -m black ./lib/trade_models -l 88 --check --diff --verbose |           python -m black ./lib/trade_models -l 88 --check --diff --verbose | ||||||
|  |  | ||||||
|   | |||||||
| @@ -11,5 +11,6 @@ from .basic_space import Space | |||||||
| from .basic_space import VirtualNode | from .basic_space import VirtualNode | ||||||
| from .basic_op import has_categorical | from .basic_op import has_categorical | ||||||
| from .basic_op import has_continuous | from .basic_op import has_continuous | ||||||
|  | from .basic_op import is_determined | ||||||
| from .basic_op import get_min | from .basic_op import get_min | ||||||
| from .basic_op import get_max | from .basic_op import get_max | ||||||
|   | |||||||
| @@ -19,6 +19,13 @@ def has_continuous(space_or_value, x): | |||||||
|         return abs(space_or_value - x) <= _EPS |         return abs(space_or_value - x) <= _EPS | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def is_determined(space_or_value): | ||||||
|  |     if isinstance(space_or_value, Space): | ||||||
|  |         return space_or_value.determined | ||||||
|  |     else: | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_max(space_or_value): | def get_max(space_or_value): | ||||||
|     if isinstance(space_or_value, Integer): |     if isinstance(space_or_value, Integer): | ||||||
|         return max(space_or_value.candidates) |         return max(space_or_value.candidates) | ||||||
|   | |||||||
| @@ -30,10 +30,13 @@ class Space(metaclass=abc.ABCMeta): | |||||||
|     def determined(self): |     def determined(self): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractproperty | ||||||
|     def __repr__(self): |     def xrepr(self, indent=0): | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return self.xrepr() | ||||||
|  |  | ||||||
|     @abc.abstractmethod |     @abc.abstractmethod | ||||||
|     def has(self, x): |     def has(self, x): | ||||||
|         """Check whether x is in this search space.""" |         """Check whether x is in this search space.""" | ||||||
| @@ -58,15 +61,28 @@ class VirtualNode(Space): | |||||||
|  |  | ||||||
|     def has(self, x): |     def has(self, x): | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             if isinstance(value, Space) and value.has(x): |             if value.has(x): | ||||||
|                 return True |                 return True | ||||||
|         return False |         return False | ||||||
|  |  | ||||||
|     def __repr__(self): |     def append(self, key, value): | ||||||
|         strs = [self.__class__.__name__ + "("] |         if not isinstance(value, Space): | ||||||
|         indent = " " * 4 |             raise ValueError("Invalid type of value: {:}".format(type(value))) | ||||||
|  |         self._attributes[key] = value | ||||||
|  |  | ||||||
|  |     def determined(self): | ||||||
|         for key, value in self._attributes.items(): |         for key, value in self._attributes.items(): | ||||||
|             strs.append(indent + strs(value)) |             if not value.determined(x): | ||||||
|  |                 return False | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def random(self, recursion=True): | ||||||
|  |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |     def xrepr(self, indent=0): | ||||||
|  |         strs = [self.__class__.__name__ + "("] | ||||||
|  |         for key, value in self._attributes.items(): | ||||||
|  |             strs.append(value.xrepr(indent + 2) + ",") | ||||||
|         strs.append(")") |         strs.append(")") | ||||||
|         return "\n".join(strs) |         return "\n".join(strs) | ||||||
|  |  | ||||||
| @@ -104,10 +120,11 @@ class Categorical(Space): | |||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._candidates) |         return len(self._candidates) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def xrepr(self, indent=0): | ||||||
|         return "{name:}(candidates={cs:}, default_index={default:})".format( |         xrepr = "{name:}(candidates={cs:}, default_index={default:})".format( | ||||||
|             name=self.__class__.__name__, cs=self._candidates, default=self._default |             name=self.__class__.__name__, cs=self._candidates, default=self._default | ||||||
|         ) |         ) | ||||||
|  |         return " " * indent + xrepr | ||||||
|  |  | ||||||
|     def has(self, x): |     def has(self, x): | ||||||
|         super().has(x) |         super().has(x) | ||||||
| @@ -143,13 +160,14 @@ class Integer(Categorical): | |||||||
|             default = data.index(default) |             default = data.index(default) | ||||||
|         super(Integer, self).__init__(*data, default=default) |         super(Integer, self).__init__(*data, default=default) | ||||||
|  |  | ||||||
|     def __repr__(self): |     def xrepr(self, indent=0): | ||||||
|         return "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default={default:})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             lower=self._raw_lower, |             lower=self._raw_lower, | ||||||
|             upper=self._raw_upper, |             upper=self._raw_upper, | ||||||
|             default=self._raw_default, |             default=self._raw_default, | ||||||
|         ) |         ) | ||||||
|  |         return " " * indent + xrepr | ||||||
|  |  | ||||||
|  |  | ||||||
| np_float_types = (np.float16, np.float32, np.float64) | np_float_types = (np.float16, np.float32, np.float64) | ||||||
| @@ -198,14 +216,15 @@ class Continuous(Space): | |||||||
|     def determined(self): |     def determined(self): | ||||||
|         return abs(self.lower - self.upper) <= self._eps |         return abs(self.lower - self.upper) <= self._eps | ||||||
|  |  | ||||||
|     def __repr__(self): |     def xrepr(self, indent=0): | ||||||
|         return "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( |         xrepr = "{name:}(lower={lower:}, upper={upper:}, default_value={default:}, log_scale={log:})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             lower=self._lower, |             lower=self._lower, | ||||||
|             upper=self._upper, |             upper=self._upper, | ||||||
|             default=self._default, |             default=self._default, | ||||||
|             log=self._log_scale, |             log=self._log_scale, | ||||||
|         ) |         ) | ||||||
|  |         return " " * indent + xrepr | ||||||
|  |  | ||||||
|     def convert(self, x): |     def convert(self, x): | ||||||
|         if isinstance(x, np_float_types) and x.size == 1: |         if isinstance(x, np_float_types) and x.size == 1: | ||||||
|   | |||||||
| @@ -30,7 +30,7 @@ class SuperLinear(SuperModule): | |||||||
|         self._in_features = in_features |         self._in_features = in_features | ||||||
|         self._out_features = out_features |         self._out_features = out_features | ||||||
|         self._bias = bias |         self._bias = bias | ||||||
|  |         # weights to be optimized | ||||||
|         self._super_weight = torch.nn.Parameter( |         self._super_weight = torch.nn.Parameter( | ||||||
|             torch.Tensor(self.out_features, self.in_features) |             torch.Tensor(self.out_features, self.in_features) | ||||||
|         ) |         ) | ||||||
| @@ -53,7 +53,14 @@ class SuperLinear(SuperModule): | |||||||
|         return spaces.has_categorical(self._bias, True) |         return spaces.has_categorical(self._bias, True) | ||||||
|  |  | ||||||
|     def abstract_search_space(self): |     def abstract_search_space(self): | ||||||
|         print('-') |         root_node = spaces.VirtualNode(id(self)) | ||||||
|  |         if not spaces.is_determined(self._in_features): | ||||||
|  |           root_node.append("_in_features", self._in_features) | ||||||
|  |         if not spaces.is_determined(self._out_features): | ||||||
|  |           root_node.append("_out_features", self._out_features) | ||||||
|  |         if not spaces.is_determined(self._bias): | ||||||
|  |           root_node.append("_bias", self._bias) | ||||||
|  |         return root_node | ||||||
|  |  | ||||||
|     def reset_parameters(self) -> None: |     def reset_parameters(self) -> None: | ||||||
|         nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) |         nn.init.kaiming_uniform_(self._super_weight, a=math.sqrt(5)) | ||||||
|   | |||||||
| @@ -14,7 +14,7 @@ if str(lib_dir) not in sys.path: | |||||||
| from spaces import Categorical | from spaces import Categorical | ||||||
| from spaces import Continuous | from spaces import Continuous | ||||||
| from spaces import Integer | from spaces import Integer | ||||||
| from spaces import Integer | from spaces import is_determined | ||||||
| from spaces import get_min | from spaces import get_min | ||||||
| from spaces import get_max | from spaces import get_max | ||||||
|  |  | ||||||
| @@ -92,3 +92,7 @@ class TestBasicSpace(unittest.TestCase): | |||||||
|         print(nested_space) |         print(nested_space) | ||||||
|         for i in range(1, 13): |         for i in range(1, 13): | ||||||
|             self.assertTrue(nested_space.has(i)) |             self.assertTrue(nested_space.has(i)) | ||||||
|  |  | ||||||
|  |         # Test Simple Op | ||||||
|  |         self.assertTrue(is_determined(1)) | ||||||
|  |         self.assertFalse(is_determined(nested_space)) | ||||||
|   | |||||||
| @@ -26,3 +26,5 @@ class TestSuperLinear(unittest.TestCase): | |||||||
|         bias = spaces.Categorical(True, False) |         bias = spaces.Categorical(True, False) | ||||||
|         model = super_core.SuperLinear(10, out_features, bias=bias) |         model = super_core.SuperLinear(10, out_features, bias=bias) | ||||||
|         print(model) |         print(model) | ||||||
|  |         print(model.super_run_type) | ||||||
|  |         print(model.abstract_search_space()) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user