Fix CUDA memory issues
This commit is contained in:
		| @@ -15,8 +15,8 @@ | ||||
| # python exps/trading/baselines.py --alg TabNet     # | ||||
| #                                                   # | ||||
| # python exps/trading/baselines.py --alg Transformer# | ||||
| # python exps/trading/baselines.py --alg TSF        # | ||||
| # python exps/trading/baselines.py --alg TSF-4x64-d0 | ||||
| # python exps/trading/baselines.py --alg TSF          | ||||
| # python exps/trading/baselines.py --alg TSF-4x64-drop0_0 | ||||
| ##################################################### | ||||
| import sys | ||||
| import copy | ||||
| @@ -40,10 +40,11 @@ from qlib.workflow import R | ||||
| from qlib.utils import flatten_dict | ||||
|  | ||||
|  | ||||
| def to_pos_drop(config, value): | ||||
| def to_drop(config, pos_drop, other_drop): | ||||
|     config = copy.deepcopy(config) | ||||
|     net = config["task"]["model"]["kwargs"]["net_config"] | ||||
|     net["pos_drop"] = value | ||||
|     net["pos_drop"] = pos_drop | ||||
|     net["other_drop"] = other_drop | ||||
|     return config | ||||
|  | ||||
|  | ||||
| @@ -59,11 +60,12 @@ def to_layer(config, embed_dim, depth): | ||||
| def extend_transformer_settings(alg2configs, name): | ||||
|     config = copy.deepcopy(alg2configs[name]) | ||||
|     for i in range(1, 7): | ||||
|         for j in [6, 12, 24, 32, 48, 64]: | ||||
|             for k in [0, 0.1]: | ||||
|                 alg2configs[name + "-{:}x{:}-d{:}".format(i, j, k)] = to_layer( | ||||
|                     to_pos_drop(config, k), j, i | ||||
|                 ) | ||||
|         for j in (6, 12, 24, 32, 48, 64): | ||||
|             for k1 in (0, 0.1, 0.2): | ||||
|                 for k2 in (0, 0.1): | ||||
|                     alg2configs[ | ||||
|                         name + "-{:}x{:}-drop{:}_{:}".format(i, j, k1, k2) | ||||
|                     ] = to_layer(to_drop(config, k1, k2), j, i) | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user