Reformulate Q-Transformer
@@ -4,7 +4,7 @@
 # Refer to:
 # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.ipynb
 # - https://github.com/microsoft/qlib/blob/main/examples/workflow_by_code.py
-# python exps/trading/workflow_tt.py --market all --gpu 1
+# python exps/trading/workflow_tt.py --gpu 1 --market csi300
 #####################################################
 import sys, argparse
 from pathlib import Path
@@ -63,7 +63,8 @@ def main(xargs):
         "class": "QuantTransformer",
         "module_path": "trade_models",
         "kwargs": {
-            "loss": "mse",
+            "net_config": None,
+            "opt_config": None,
             "GPU": "0",
             "metric": "loss",
         },
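
For context, a minimal sketch (not part of the commit) of the model_config block as it reads after this hunk; any keys not shown in the diff are omitted. Passing None for net_config / opt_config defers to the module-level defaults defined in trade_models (second file below), and the config is consumed through the task_config passed to run_exp in the next hunk.

    model_config = {
        "class": "QuantTransformer",
        "module_path": "trade_models",
        "kwargs": {
            "net_config": None,  # None -> default_net_config inside QuantTransformer
            "opt_config": None,  # None -> default_opt_config inside QuantTransformer
            "GPU": "0",
            "metric": "loss",
        },
    }
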
@@ -107,20 +108,23 @@ def main(xargs):
     provider_uri = "~/.qlib/qlib_data/cn_data"
     qlib.init(provider_uri=provider_uri, region=REG_CN)
 
+    save_dir = "{:}-{:}".format(xargs.save_dir, xargs.market)
     dataset = init_instance_by_config(dataset_config)
     for irun in range(xargs.times):
         xmodel_config = model_config.copy()
-        xmodel_config = update_gpu(xmodel_config, xags.gpu)
-        task = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
-        run_exp(task_config, dataset, "Transformer", "recorder-{:02d}-{:02d}".format(irun, xargs.times), xargs.save_dir)
+        xmodel_config = update_gpu(xmodel_config, xargs.gpu)
+        task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config)
+        run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir)
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser("Vanilla Transformable Transformer")
-    parser.add_argument("--save_dir", type=str, default="./outputs/tt-ml-runs", help="The checkpoint directory.")
+    parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.")
+    parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.")
     parser.add_argument("--times", type=int, default=10, help="The repeated run times.")
     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.")
-    parser.add_argument("--market", type=str, default="csi300", help="The market indicator.")
+    parser.add_argument("--market", type=str, default="all", help="The market indicator.")
     args = parser.parse_args()
 
     main(args)
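
As a quick aside (not from the commit), the per-market save directory introduced above is derived from the new CLI defaults like this:

    # with the new defaults: --save_dir ./outputs/vtt-runs and --market all
    save_dir_flag, market = "./outputs/vtt-runs", "all"
    save_dir = "{:}-{:}".format(save_dir_flag, market)
    print(save_dir)  # -> ./outputs/vtt-runs-all
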
@@ -25,6 +25,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
+import torch.utils.data as th_data
 
 import layers as xlayers
 from utils import count_parameters
@@ -34,78 +35,38 @@ from qlib.data.dataset import DatasetH
 from qlib.data.dataset.handler import DataHandlerLP
 
 
+default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)
+
+default_opt_config = dict(epochs=200, lr=0.001, batch_size=2000, early_stop=20, loss="mse", optimizer="adam")
+
+
 class QuantTransformer(Model):
-  """Transformer-based Quant Model
-
-  """
-
-  def __init__(
-    self,
-    d_feat=6,
-    hidden_size=48,
-    depth=5,
-    pos_dropout=0.1,
-    n_epochs=200,
-    lr=0.001,
-    metric="",
-    batch_size=2000,
-    early_stop=20,
-    loss="mse",
-    optimizer="adam",
-    GPU=0,
-    seed=None,
-    **kwargs
-  ):
+    """Transformer-based Quant Model"""
+
+    def __init__(self, net_config=None, opt_config=None, metric="", GPU=0, seed=None, **kwargs):
         # Set logger.
         self.logger = get_module_logger("QuantTransformer")
         self.logger.info("QuantTransformer pytorch version...")
 
         # set hyper-parameters.
-    self.d_feat = d_feat
-    self.hidden_size = hidden_size
-    self.depth = depth
-    self.pos_dropout = pos_dropout
-    self.n_epochs = n_epochs
-    self.lr = lr
+        self.net_config = net_config or default_net_config
+        self.opt_config = opt_config or default_opt_config
         self.metric = metric
-    self.batch_size = batch_size
-    self.early_stop = early_stop
-    self.optimizer = optimizer.lower()
-    self.loss = loss
-    self.device = torch.device("cuda:{:}".format(GPU) if torch.cuda.is_available() else "cpu")
-    self.use_gpu = torch.cuda.is_available()
+        self.device = torch.device("cuda:{:}".format(GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
         self.seed = seed
 
         self.logger.info(
             "Transformer parameters setting:"
-      "\nd_feat : {}"
-      "\nhidden_size : {}"
-      "\ndepth : {}"
-      "\ndropout : {}"
-      "\nn_epochs : {}"
-      "\nlr : {}"
-      "\nmetric : {}"
-      "\nbatch_size : {}"
-      "\nearly_stop : {}"
-      "\noptimizer : {}"
-      "\nloss_type : {}"
-      "\nvisible_GPU : {}"
-      "\nuse_GPU : {}"
-      "\nseed : {}".format(
-        d_feat,
-        hidden_size,
-        depth,
-        pos_dropout,
-        n_epochs,
-        lr,
-        metric,
-        batch_size,
-        early_stop,
-        optimizer.lower(),
-        loss,
-        GPU,
-        self.use_gpu,
-        seed,
+            "\nnet_config : {:}"
+            "\nopt_config : {:}"
+            "\nmetric     : {:}"
+            "\ndevice     : {:}"
+            "\nseed       : {:}".format(
+                self.net_config,
+                self.opt_config,
+                self.metric,
+                self.device,
+                self.seed,
             )
         )
 
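
A short sketch (not part of the commit) of the fallback behaviour introduced here: a None config falls back to the module-level default, while a user-supplied dict replaces the default wholesale rather than being merged, so it has to carry every key the model reads (d_feat, hidden_size, depth, pos_drop).

    default_net_config = dict(d_feat=6, hidden_size=48, depth=5, pos_drop=0.1)

    assert (None or default_net_config) == default_net_config        # None -> use the defaults
    custom = dict(d_feat=6, hidden_size=64, depth=3, pos_drop=0.1)    # example override, must be complete
    assert (custom or default_net_config) == custom                   # a dict is used as-is, not merged
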
@@ -113,94 +74,76 @@ class QuantTransformer(Model):
             np.random.seed(self.seed)
             torch.manual_seed(self.seed)
 
-    self.model = TransformerModel(d_feat=self.d_feat,
-                                  embed_dim=self.hidden_size,
-                                  depth=self.depth,
-                                  pos_dropout=pos_dropout)
-    self.logger.info('model: {:}'.format(self.model))
-    self.logger.info('model size: {:.3f} MB'.format(count_parameters(self.model)))
-
-    if optimizer.lower() == "adam":
-      self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
-    elif optimizer.lower() == "gd":
-      self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.lr)
+        self.model = TransformerModel(
+            d_feat=self.net_config["d_feat"],
+            embed_dim=self.net_config["hidden_size"],
+            depth=self.net_config["depth"],
+            pos_drop=self.net_config["pos_drop"],
+        )
+        self.logger.info("model: {:}".format(self.model))
+        self.logger.info("model size: {:.3f} MB".format(count_parameters(self.model)))
+
+        if self.opt_config["optimizer"] == "adam":
+            self.train_optimizer = optim.Adam(self.model.parameters(), lr=self.opt_config["lr"])
+        elif self.opt_config["optimizer"] == "gd":
+            self.train_optimizer = optim.SGD(self.model.parameters(), lr=self.opt_config["lr"])
         else:
             raise NotImplementedError("optimizer {:} is not supported!".format(optimizer))
 
         self.fitted = False
         self.model.to(self.device)
 
+    @property
+    def use_gpu(self):
+        return self.device != torch.device("cpu")
+
     def loss_fn(self, pred, label):
         mask = ~torch.isnan(label)
-    if self.loss == "mse":
+        if self.opt_config["loss"] == "mse":
             return F.mse_loss(pred[mask], label[mask])
         else:
             raise ValueError("unknown loss `{:}`".format(self.loss))
 
     def metric_fn(self, pred, label):
 
         mask = torch.isfinite(label)
 
         if self.metric == "" or self.metric == "loss":
             return -self.loss_fn(pred[mask], label[mask])
         else:
             raise ValueError("unknown metric `{:}`".format(self.metric))
 
-  def train_epoch(self, x_train, y_train):
-
-    x_train_values = x_train.values
-    y_train_values = np.squeeze(y_train.values)
-
-    self.model.train()
-
-    indices = np.arange(len(x_train_values))
-    np.random.shuffle(indices)
-
-    for i in range(len(indices))[:: self.batch_size]:
-
-      if len(indices) - i < self.batch_size:
-        break
-
-      feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
-      label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
-
-      pred = self.model(feature)
-      loss = self.loss_fn(pred, label)
-
-      self.train_optimizer.zero_grad()
-      loss.backward()
-      torch.nn.utils.clip_grad_value_(self.model.parameters(), 3.0)
-      self.train_optimizer.step()
-
-  def test_epoch(self, data_x, data_y):
-
-    # prepare training data
-    x_values = data_x.values
-    y_values = np.squeeze(data_y.values)
-
-    self.model.eval()
-
-    scores = []
-    losses = []
-
-    indices = np.arange(len(x_values))
-
-    for i in range(len(indices))[:: self.batch_size]:
-
-      if len(indices) - i < self.batch_size:
-        break
-
-      feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
-      label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
-
-      pred = self.model(feature)
-      loss = self.loss_fn(pred, label)
-      losses.append(loss.item())
-
-      score = self.metric_fn(pred, label)
-      scores.append(score.item())
-
-    return np.mean(losses), np.mean(scores)
+    def train_epoch(self, xloader, model, loss_fn, optimizer):
+        model.train()
+        scores, losses = [], []
+        for ibatch, (feats, labels) in enumerate(xloader):
+            feats = feats.to(self.device, non_blocking=True)
+            labels = labels.to(self.device, non_blocking=True)
+            # forward the network
+            preds = model(feats)
+            loss = loss_fn(preds, labels)
+            with torch.no_grad():
+                score = self.metric_fn(preds, labels)
+                losses.append(loss.item())
+                scores.append(score.item())
+            # optimize the network
+            optimizer.zero_grad()
+            loss.backward()
+            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
+            optimizer.step()
+        return np.mean(losses), np.mean(scores)
+
+    def test_epoch(self, xloader, model, loss_fn, metric_fn):
+        model.eval()
+        scores, losses = [], []
+        with torch.no_grad():
+            for ibatch, (feats, labels) in enumerate(xloader):
+                feats = feats.to(self.device, non_blocking=True)
+                labels = labels.to(self.device, non_blocking=True)
+                # forward the network
+                preds = model(feats)
+                loss = loss_fn(preds, labels)
+                score = self.metric_fn(preds, labels)
+                losses.append(loss.item())
+                scores.append(score.item())
+        return np.mean(losses), np.mean(scores)
 
     def fit(
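
A self-contained sketch (not from the commit; the sizes are placeholders) of the batch contract the rewritten train_epoch / test_epoch assume: the loader yields (features, labels) pairs of float tensors, which is what the TensorDataset built in fit() below provides.

    import numpy as np
    import torch
    import torch.utils.data as th_data

    feats = torch.from_numpy(np.random.randn(100, 360).astype(np.float32))   # 100 samples, placeholder feature width
    labels = torch.from_numpy(np.random.randn(100).astype(np.float32))
    loader = th_data.DataLoader(th_data.TensorDataset(feats, labels), batch_size=32, shuffle=False)

    for ibatch, (batch_feats, batch_labels) in enumerate(loader):
        print(ibatch, batch_feats.shape, batch_labels.shape)  # e.g. 0 torch.Size([32, 360]) torch.Size([32])
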
@@ -210,57 +153,81 @@
         verbose=True,
         save_path=None,
     ):
+        def _prepare_dataset(df_data):
+            return th_data.TensorDataset(
+                torch.from_numpy(df_data["feature"].values).float(),
+                torch.from_numpy(df_data["label"].values).squeeze().float(),
+            )
+
         df_train, df_valid, df_test = dataset.prepare(
             ["train", "valid", "test"],
             col_set=["feature", "label"],
             data_key=DataHandlerLP.DK_L,
         )
 
-    x_train, y_train = df_train["feature"], df_train["label"]
-    x_valid, y_valid = df_valid["feature"], df_valid["label"]
+        train_dataset, valid_dataset, test_dataset = (
+            _prepare_dataset(df_train),
+            _prepare_dataset(df_valid),
+            _prepare_dataset(df_test),
+        )
+
+        train_loader = th_data.DataLoader(
+            train_dataset, batch_size=self.opt_config["batch_size"], shuffle=True, drop_last=False, pin_memory=True
+        )
+        valid_loader = th_data.DataLoader(
+            valid_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
+        )
+        test_loader = th_data.DataLoader(
+            test_dataset, batch_size=self.opt_config["batch_size"], shuffle=False, drop_last=False, pin_memory=True
+        )
 
         if save_path == None:
             save_path = create_save_path(save_path)
-    stop_steps = 0
+        stop_steps, best_score, best_epoch = 0, -np.inf, 0
         train_loss = 0
-    best_score = -np.inf
-    best_epoch = 0
         evals_result["train"] = []
         evals_result["valid"] = []
 
         # train
-    self.logger.info("training...")
-    self.fitted = True
-
-    for step in range(self.n_epochs):
-      self.logger.info("Epoch%d:", step)
-      self.logger.info("training...")
-      self.train_epoch(x_train, y_train)
-      self.logger.info("evaluating...")
-      train_loss, train_score = self.test_epoch(x_train, y_train)
-      val_loss, val_score = self.test_epoch(x_valid, y_valid)
-      self.logger.info("train %.6f, valid %.6f" % (train_score, val_score))
-      evals_result["train"].append(train_score)
-      evals_result["valid"].append(val_score)
-
-      if val_score > best_score:
-        best_score = val_score
-        stop_steps = 0
-        best_epoch = step
+        self.logger.info("Fit procedure for [{:}] with save path={:}".format(self.__class__.__name__, save_path))
+
+        def _internal_test():
+            train_loss, train_score = self.test_epoch(train_loader, self.model, self.loss_fn, self.metric_fn)
+            valid_loss, valid_score = self.test_epoch(valid_loader, self.model, self.loss_fn, self.metric_fn)
+            test_loss, test_score = self.test_epoch(test_loader, self.model, self.loss_fn, self.metric_fn)
+            xstr = "train-score={:.6f}, valid-score={:.6f}, test-score={:.6f}".format(
+                train_score, valid_score, test_score
+            )
+            return dict(train=train_score, valid=valid_score, test=test_score), xstr
+
+        _, eval_str = _internal_test()
+        self.logger.info("Before Training: {:}".format(eval_str))
+        for iepoch in range(self.opt_config["epochs"]):
+            self.logger.info("Epoch={:03d}/{:03d} ::==>>".format(iepoch, self.opt_config["epochs"]))
+
+            train_loss, train_score = self.train_epoch(train_loader, self.model, self.loss_fn, self.train_optimizer)
+            self.logger.info("Training :: loss={:.6f}, score={:.6f}".format(train_loss, train_score))
+
+            eval_score_dict, eval_str = _internal_test()
+            self.logger.info("Evaluating :: {:}".format(eval_str))
+            evals_result["train"].append(eval_score_dict["train"])
+            evals_result["valid"].append(eval_score_dict["valid"])
+
+            if eval_score_dict["valid"] > best_score:
+                stop_steps, best_epoch, best_score = 0, iepoch, eval_score_dict["valid"]
                 best_param = copy.deepcopy(self.model.state_dict())
             else:
                 stop_steps += 1
-        if stop_steps >= self.early_stop:
-          self.logger.info("early stop")
+                if stop_steps >= self.opt_config["early_stop"]:
+                    self.logger.info("early stop at {:}-th epoch, where the best is @{:}".format(iepoch, best_epoch))
                     break
 
-    self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
+        self.logger.info("The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch))
         self.model.load_state_dict(best_param)
         torch.save(best_param, save_path)
 
         if self.use_gpu:
             torch.cuda.empty_cache()
+        self.fitted = True
 
     def predict(self, dataset):
 
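
A pure-Python sketch (not part of the commit) of the early-stopping bookkeeping used in the loop above: the counter resets whenever the validation score improves, and training stops once opt_config["early_stop"] epochs pass without an improvement.

    import numpy as np

    early_stop = 20
    stop_steps, best_score, best_epoch = 0, -np.inf, 0
    for iepoch, valid_score in enumerate([0.10, 0.20, 0.15, 0.18]):  # placeholder validation scores
        if valid_score > best_score:
            stop_steps, best_epoch, best_score = 0, iepoch, valid_score
        else:
            stop_steps += 1
            if stop_steps >= early_stop:
                break
    print(best_score, best_epoch)  # -> 0.2 1
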
@@ -298,8 +265,7 @@ class QuantTransformer(Model):
 
 
 class Attention(nn.Module):
-
-  def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
+    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
         super(Attention, self).__init__()
         self.num_heads = num_heads
         head_dim = dim // num_heads
@@ -327,16 +293,26 @@ class Attention(nn.Module):
 
 
 class Block(nn.Module):
-
-  def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,
-               attn_drop=0., mlp_drop=0., drop_path=0.,
-               act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.0,
+        mlp_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
         super(Block, self).__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(
-      dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=mlp_drop)
+            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=mlp_drop
+        )
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
-    self.drop_path = xlayers.DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.drop_path = xlayers.DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = xlayers.MLP(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=mlp_drop)
@@ -348,7 +324,6 @@ class Block(nn.Module):
 
 
 class SimpleEmbed(nn.Module):
-
     def __init__(self, d_feat, embed_dim):
         super(SimpleEmbed, self).__init__()
         self.d_feat = d_feat
@@ -363,16 +338,21 @@ class SimpleEmbed(nn.Module):
 
 
 class TransformerModel(nn.Module):
-
-  def __init__(self,
+    def __init__(
+        self,
         d_feat: int,
         embed_dim: int = 64,
         depth: int = 4,
         num_heads: int = 4,
-        mlp_ratio: float = 4.,
+        mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         qk_scale: Optional[float] = None,
-        pos_dropout=0., mlp_drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None):
+        pos_drop=0.0,
+        mlp_drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.0,
+        norm_layer=None,
+    ):
         """
         Args:
           d_feat (int, tuple): input image size
@@ -382,7 +362,7 @@ class TransformerModel(nn.Module):
          mlp_ratio (int): ratio of mlp hidden dim to embedding dim
          qkv_bias (bool): enable bias for qkv if True
          qk_scale (float): override default qk scale of head_dim ** -0.5 if set
-      pos_dropout (float): dropout rate for the positional embedding
+          pos_drop (float): dropout rate for the positional embedding
          mlp_drop_rate (float): the dropout rate for MLP layers in a block
          attn_drop_rate (float): attention dropout rate
          drop_path_rate (float): stochastic depth rate
@@ -396,25 +376,36 @@ class TransformerModel(nn.Module):
         self.input_embed = SimpleEmbed(d_feat, embed_dim=embed_dim)
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-    self.pos_embed = xlayers.PositionalEncoder(d_model=embed_dim, max_seq_len=65, dropout=pos_dropout)
+        self.pos_embed = xlayers.PositionalEncoder(d_model=embed_dim, max_seq_len=65, dropout=pos_drop)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-    self.blocks = nn.ModuleList([
-      Block(
-        dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-        attn_drop=attn_drop_rate, mlp_drop=mlp_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
-      for i in range(depth)])
+        self.blocks = nn.ModuleList(
+            [
+                Block(
+                    dim=embed_dim,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    qkv_bias=qkv_bias,
+                    qk_scale=qk_scale,
+                    attn_drop=attn_drop_rate,
+                    mlp_drop=mlp_drop_rate,
+                    drop_path=dpr[i],
+                    norm_layer=norm_layer,
+                )
+                for i in range(depth)
+            ]
+        )
         self.norm = norm_layer(embed_dim)
 
         # regression head
         self.head = nn.Linear(self.num_features, 1)
 
-    xlayers.trunc_normal_(self.cls_token, std=.02)
+        xlayers.trunc_normal_(self.cls_token, std=0.02)
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
-      xlayers.trunc_normal_(m.weight, std=.02)
+            xlayers.trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
         elif isinstance(m, nn.LayerNorm):
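
Finally, a small sketch (not from the commit; the values are examples) of the stochastic depth decay rule kept unchanged in this hunk: the per-block drop-path probability rises linearly from 0 to drop_path_rate.

    import torch

    drop_path_rate, depth = 0.2, 5  # example values; the model defaults to drop_path_rate=0.0
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    print(dpr)  # approximately [0.0, 0.05, 0.1, 0.15, 0.2]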