diff --git a/exps/trading/baselines.py b/exps/trading/baselines.py index 21efdb5..854b96e 100644 --- a/exps/trading/baselines.py +++ b/exps/trading/baselines.py @@ -13,6 +13,8 @@ # python exps/trading/baselines.py --alg LightGBM # # python exps/trading/baselines.py --alg DoubleE # # python exps/trading/baselines.py --alg TabNet # +# # +# python exps/trading/baselines.py --alg Transformer# ##################################################### import sys import argparse diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index 9f8cf8d..6cdda4b 100644 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -27,8 +27,10 @@ import torch.utils.data as th_data from log_utils import AverageMeter from utils import count_parameters -from trade_models.transformers import DEFAULT_NET_CONFIG -from trade_models.transformers import get_transformer + +from xlayers import super_core +from .transformers import DEFAULT_NET_CONFIG +from .transformers import get_transformer from qlib.model.base import Model @@ -90,6 +92,7 @@ class QuantTransformer(Model): torch.cuda.manual_seed_all(self.seed) self.model = get_transformer(self.net_config) + self.model.set_super_run_type(super_core.SuperRunMode.FullModel) self.logger.info("model: {:}".format(self.model)) self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model))) diff --git a/lib/trade_models/transformers.py b/lib/trade_models/transformers.py index 7070182..591a462 100644 --- a/lib/trade_models/transformers.py +++ b/lib/trade_models/transformers.py @@ -17,7 +17,7 @@ from xlayers import trunc_normal_ from xlayers import super_core -__all__ = ["DefaultSearchSpace"] +__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"] def _get_mul_specs(candidates, num): @@ -41,6 +41,7 @@ def _assert_types(x, expected_types): ) +DEFAULT_NET_CONFIG = None _default_max_depth = 5 DefaultSearchSpace = dict( d_feat=6, @@ -163,7 +164,9 @@ class SuperTransformer(super_core.SuperModule): else: stem_dim = spaces.get_determined_value(self._stem_dim) cls_tokens = self.cls_token.expand(batch, -1, -1) - cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True) + cls_tokens = F.interpolate( + cls_tokens, size=(stem_dim), mode="linear", align_corners=True + ) feats_w_ct = torch.cat((cls_tokens, feats), dim=1) feats_w_tp = self.pos_embed(feats_w_ct) xfeats = self.backbone(feats_w_tp)