Fix black errors
This commit is contained in:
		| @@ -13,6 +13,8 @@ | ||||
| # python exps/trading/baselines.py --alg LightGBM   # | ||||
| # python exps/trading/baselines.py --alg DoubleE    # | ||||
| # python exps/trading/baselines.py --alg TabNet     # | ||||
| #                                                   # | ||||
| # python exps/trading/baselines.py --alg Transformer# | ||||
| ##################################################### | ||||
| import sys | ||||
| import argparse | ||||
|   | ||||
| @@ -27,8 +27,10 @@ import torch.utils.data as th_data | ||||
|  | ||||
| from log_utils import AverageMeter | ||||
| from utils import count_parameters | ||||
| from trade_models.transformers import DEFAULT_NET_CONFIG | ||||
| from trade_models.transformers import get_transformer | ||||
|  | ||||
| from xlayers import super_core | ||||
| from .transformers import DEFAULT_NET_CONFIG | ||||
| from .transformers import get_transformer | ||||
|  | ||||
|  | ||||
| from qlib.model.base import Model | ||||
| @@ -90,6 +92,7 @@ class QuantTransformer(Model): | ||||
|                 torch.cuda.manual_seed_all(self.seed) | ||||
|  | ||||
|         self.model = get_transformer(self.net_config) | ||||
|         self.model.set_super_run_type(super_core.SuperRunMode.FullModel) | ||||
|         self.logger.info("model: {:}".format(self.model)) | ||||
|         self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model))) | ||||
|  | ||||
|   | ||||
| @@ -17,7 +17,7 @@ from xlayers import trunc_normal_ | ||||
| from xlayers import super_core | ||||
|  | ||||
|  | ||||
| __all__ = ["DefaultSearchSpace"] | ||||
| __all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"] | ||||
|  | ||||
|  | ||||
| def _get_mul_specs(candidates, num): | ||||
| @@ -41,6 +41,7 @@ def _assert_types(x, expected_types): | ||||
|         ) | ||||
|  | ||||
|  | ||||
| DEFAULT_NET_CONFIG = None | ||||
| _default_max_depth = 5 | ||||
| DefaultSearchSpace = dict( | ||||
|     d_feat=6, | ||||
| @@ -163,7 +164,9 @@ class SuperTransformer(super_core.SuperModule): | ||||
|         else: | ||||
|             stem_dim = spaces.get_determined_value(self._stem_dim) | ||||
|         cls_tokens = self.cls_token.expand(batch, -1, -1) | ||||
|         cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True) | ||||
|         cls_tokens = F.interpolate( | ||||
|             cls_tokens, size=(stem_dim), mode="linear", align_corners=True | ||||
|         ) | ||||
|         feats_w_ct = torch.cat((cls_tokens, feats), dim=1) | ||||
|         feats_w_tp = self.pos_embed(feats_w_ct) | ||||
|         xfeats = self.backbone(feats_w_tp) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user