Fix black errors
This commit is contained in:
parent
b8c173eb76
commit
1acd1e9f9b
@ -13,6 +13,8 @@
|
|||||||
# python exps/trading/baselines.py --alg LightGBM #
|
# python exps/trading/baselines.py --alg LightGBM #
|
||||||
# python exps/trading/baselines.py --alg DoubleE #
|
# python exps/trading/baselines.py --alg DoubleE #
|
||||||
# python exps/trading/baselines.py --alg TabNet #
|
# python exps/trading/baselines.py --alg TabNet #
|
||||||
|
# #
|
||||||
|
# python exps/trading/baselines.py --alg Transformer#
|
||||||
#####################################################
|
#####################################################
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -27,8 +27,10 @@ import torch.utils.data as th_data
|
|||||||
|
|
||||||
from log_utils import AverageMeter
|
from log_utils import AverageMeter
|
||||||
from utils import count_parameters
|
from utils import count_parameters
|
||||||
from trade_models.transformers import DEFAULT_NET_CONFIG
|
|
||||||
from trade_models.transformers import get_transformer
|
from xlayers import super_core
|
||||||
|
from .transformers import DEFAULT_NET_CONFIG
|
||||||
|
from .transformers import get_transformer
|
||||||
|
|
||||||
|
|
||||||
from qlib.model.base import Model
|
from qlib.model.base import Model
|
||||||
@ -90,6 +92,7 @@ class QuantTransformer(Model):
|
|||||||
torch.cuda.manual_seed_all(self.seed)
|
torch.cuda.manual_seed_all(self.seed)
|
||||||
|
|
||||||
self.model = get_transformer(self.net_config)
|
self.model = get_transformer(self.net_config)
|
||||||
|
self.model.set_super_run_type(super_core.SuperRunMode.FullModel)
|
||||||
self.logger.info("model: {:}".format(self.model))
|
self.logger.info("model: {:}".format(self.model))
|
||||||
self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model)))
|
self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model)))
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from xlayers import trunc_normal_
|
|||||||
from xlayers import super_core
|
from xlayers import super_core
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["DefaultSearchSpace"]
|
__all__ = ["DefaultSearchSpace", "DEFAULT_NET_CONFIG", "get_transformer"]
|
||||||
|
|
||||||
|
|
||||||
def _get_mul_specs(candidates, num):
|
def _get_mul_specs(candidates, num):
|
||||||
@ -41,6 +41,7 @@ def _assert_types(x, expected_types):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_NET_CONFIG = None
|
||||||
_default_max_depth = 5
|
_default_max_depth = 5
|
||||||
DefaultSearchSpace = dict(
|
DefaultSearchSpace = dict(
|
||||||
d_feat=6,
|
d_feat=6,
|
||||||
@ -163,7 +164,9 @@ class SuperTransformer(super_core.SuperModule):
|
|||||||
else:
|
else:
|
||||||
stem_dim = spaces.get_determined_value(self._stem_dim)
|
stem_dim = spaces.get_determined_value(self._stem_dim)
|
||||||
cls_tokens = self.cls_token.expand(batch, -1, -1)
|
cls_tokens = self.cls_token.expand(batch, -1, -1)
|
||||||
cls_tokens = F.interpolate(cls_tokens, size=(stem_dim), mode="linear", align_corners=True)
|
cls_tokens = F.interpolate(
|
||||||
|
cls_tokens, size=(stem_dim), mode="linear", align_corners=True
|
||||||
|
)
|
||||||
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
|
feats_w_ct = torch.cat((cls_tokens, feats), dim=1)
|
||||||
feats_w_tp = self.pos_embed(feats_w_ct)
|
feats_w_tp = self.pos_embed(feats_w_ct)
|
||||||
xfeats = self.backbone(feats_w_tp)
|
xfeats = self.backbone(feats_w_tp)
|
||||||
|
Loading…
Reference in New Issue
Block a user