Fix black formatting errors
In exps/trading/baselines.py, the header block of runnable baselines gains a Transformer entry:

@@ -13,6 +13,8 @@
 # python exps/trading/baselines.py --alg LightGBM   #
 # python exps/trading/baselines.py --alg DoubleE    #
 # python exps/trading/baselines.py --alg TabNet     #
+#                                                   #
+# python exps/trading/baselines.py --alg Transformer#
 #####################################################
 import sys
 import argparse
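The added header line advertises the new baseline; as with the other entries, the command it documents is:

    python exps/trading/baselines.py --alg Transformer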
The next two hunks are in the module that defines QuantTransformer (a sibling of transformers.py inside the trade_models package, per the old absolute import); its imports are fixed first:

@@ -27,8 +27,10 @@ import torch.utils.data as th_data
 
 from log_utils import AverageMeter
 from utils import count_parameters
-from trade_models.transformers import DEFAULT_NET_CONFIG
-from trade_models.transformers import get_transformer
+from xlayers import super_core
+from .transformers import DEFAULT_NET_CONFIG
+from .transformers import get_transformer
 
 
 from qlib.model.base import Model
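Two distinct fixes land in this import block: xlayers.super_core is newly imported for the run-mode call added below, and the transformer helpers switch from absolute to package-relative imports. The distinction only affects how the module is located; a minimal sketch, assuming this file lives inside the trade_models package next to transformers.py:

    # Absolute form (old): resolves only when the directory containing
    # trade_models/ is on sys.path.
    from trade_models.transformers import get_transformer

    # Relative form (new): resolves against the containing package itself,
    # so it keeps working if the package is installed or moved.
    from .transformers import get_transformer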
@@ -90,6 +92,7 @@ class QuantTransformer(Model):
                 torch.cuda.manual_seed_all(self.seed)
 
         self.model = get_transformer(self.net_config)
+        self.model.set_super_run_type(super_core.SuperRunMode.FullModel)
         self.logger.info("model: {:}".format(self.model))
         self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model)))
 
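Judging by the SuperTransformer hunk below, get_transformer returns a super-network module, and the added call evidently pins it to its full (non-searched) configuration before training. A minimal usage sketch built only from names that appear in this commit; the meaning of the None config is an assumption:

    from xlayers import super_core
    from trade_models.transformers import DEFAULT_NET_CONFIG, get_transformer

    # DEFAULT_NET_CONFIG is None (defined below); assumed to mean "use the defaults".
    model = get_transformer(DEFAULT_NET_CONFIG)
    # Run the whole supernet rather than a sampled sub-network.
    model.set_super_run_type(super_core.SuperRunMode.FullModel)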
The remaining hunks are in trade_models/transformers.py (the module named by the old absolute import). First, the public API is widened:

@@ -17,7 +17,7 @@ from xlayers import trunc_normal_
 from xlayers import super_core
 
 
-__all__ = ["DefaultSearchSpace"]
+__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]
 
 
 def _get_mul_specs(candidates, num):
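Note that __all__ only governs star imports; the plain from .transformers import ... lines above work because the names exist at module level (DEFAULT_NET_CONFIG is defined in the next hunk). With the widened list, a star import now exposes all three public names:

    from trade_models.transformers import *

    net = get_transformer(DEFAULT_NET_CONFIG)  # factory is now an exported name
    print(DefaultSearchSpace["d_feat"])        # 6, per the defaults in this file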
@@ -41,6 +41,7 @@ def _assert_types(x, expected_types):
         )
 
 
+DEFAULT_NET_CONFIG = None
 _default_max_depth = 5
 DefaultSearchSpace = dict(
     d_feat=6,
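This module-level definition is what makes from .transformers import DEFAULT_NET_CONFIG resolvable at all; the None value reads as a use-the-defaults sentinel that callers simply forward to get_transformer (the fallback behavior is an assumption, not shown in this excerpt):

    from trade_models.transformers import DEFAULT_NET_CONFIG
    assert DEFAULT_NET_CONFIG is None  # sentinel: build from the default search space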
@@ -163,7 +164,9 @@ class SuperTransformer(super_core.SuperModule):
         else:
             stem_dim = spaces.get_determined_value(self._stem_dim)
         cls_tokens = self.cls_token.expand(batch, -1, -1)
-        cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True)
+        cls_tokens = F.interpolate(
+            cls_tokens, size=(stem_dim), mode="linear", align_corners=True
+        )
         feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
         feats_w_tp = self.pos_embed(feats_w_ct)
         xfeats = self.backbone(feats_w_tp)
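This hunk is the formatting change that black itself enforces: the call is split across lines to satisfy the line-length limit, with no behavior change (size=(stem_dim) is identical to size=stem_dim, since the parentheses do not create a tuple). A self-contained shape check of the call, with arbitrary stand-in dimensions:

    import torch
    import torch.nn.functional as F

    batch, base_dim, stem_dim = 4, 32, 48
    cls_token = torch.zeros(1, 1, base_dim)       # stand-in for the model's learned cls_token
    cls_tokens = cls_token.expand(batch, -1, -1)  # (batch, 1, base_dim)
    cls_tokens = F.interpolate(
        cls_tokens, size=stem_dim, mode="linear", align_corners=True
    )
    assert cls_tokens.shape == (batch, 1, stem_dim)  # token widened to the searched stem_dim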