Update models
This commit is contained in:
parent
e637cddc39
commit
b51320dfb1
@ -1 +1 @@
|
|||||||
Subproject commit 253378a44e88a9fcff17d23b589e2d4832f587aa
|
Subproject commit 968930e85f4958d16dfc2c5740c02f5c91745b97
|
@ -105,11 +105,11 @@ to download this repo with submodules.
|
|||||||
|
|
||||||
If you find that this project helps your research, please consider citing the related paper:
|
If you find that this project helps your research, please consider citing the related paper:
|
||||||
```
|
```
|
||||||
@article{dong2020autohas,
|
@inproceedings{dong2021autohas,
|
||||||
title={{AutoHAS}: Efficient Hyperparameter and Architecture Search},
|
title={{AutoHAS}: Efficient Hyperparameter and Architecture Search},
|
||||||
author={Dong, Xuanyi and Tan, Mingxing and Yu, Adams Wei and Peng, Daiyi and Gabrys, Bogdan and Le, Quoc V},
|
author={Dong, Xuanyi and Tan, Mingxing and Yu, Adams Wei and Peng, Daiyi and Gabrys, Bogdan and Le, Quoc V},
|
||||||
journal={arXiv preprint arXiv:2006.03656},
|
booktitle = {International Conference on Learning Representations (ICLR) Workshop on Neural Architecture Search},
|
||||||
year={2020}
|
year={2021}
|
||||||
}
|
}
|
||||||
@article{dong2021nats,
|
@article{dong2021nats,
|
||||||
title = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
|
title = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
|
||||||
|
@ -62,13 +62,13 @@
|
|||||||
<tr> <!-- (6-th row) -->
|
<tr> <!-- (6-th row) -->
|
||||||
<td align="center" valign="middle"> NATS-Bench </td>
|
<td align="center" valign="middle"> NATS-Bench </td>
|
||||||
<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
|
<td align="center" valign="middle"> <a href="https://xuanyidong.com/assets/projects/NATS-Bench"> NATS-Bench: Benchmarking NAS Algorithms for Architecture Topology and Size</a> </td>
|
||||||
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NATS-Bench.md">NATS-Bench.md</a> </td>
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr> <!-- (7-th row) -->
|
<tr> <!-- (7-th row) -->
|
||||||
<td align="center" valign="middle"> ... </td>
|
<td align="center" valign="middle"> ... </td>
|
||||||
<td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
|
<td align="center" valign="middle"> ENAS / REA / REINFORCE / BOHB </td>
|
||||||
<td align="center" valign="middle"> Please check the original papers. </td>
|
<td align="center" valign="middle"> Please check the original papers. </td>
|
||||||
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NATS-Bench.md">NATS-Bench.md</a> </td>
|
<td align="center" valign="middle"> <a href="https://github.com/D-X-Y/AutoDL-Projects/tree/main/docs/NAS-Bench-201.md">NAS-Bench-201.md</a> <a href="https://github.com/D-X-Y/NATS-Bench">NATS-Bench.md</a> </td>
|
||||||
</tr>
|
</tr>
|
||||||
<tr> <!-- (start second block) -->
|
<tr> <!-- (start second block) -->
|
||||||
<td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
|
<td rowspan="1" align="center" valign="middle" halign="middle"> HPO </td>
|
||||||
@ -98,6 +98,12 @@ Some methods use knowledge distillation (KD), which require pre-trained models.
|
|||||||
|
|
||||||
如果您发现该项目对您的科研或工程有帮助,请考虑引用下列的某些文献:
|
如果您发现该项目对您的科研或工程有帮助,请考虑引用下列的某些文献:
|
||||||
```
|
```
|
||||||
|
@inproceedings{dong2021autohas,
|
||||||
|
title={{AutoHAS}: Efficient Hyperparameter and Architecture Search},
|
||||||
|
author={Dong, Xuanyi and Tan, Mingxing and Yu, Adams Wei and Peng, Daiyi and Gabrys, Bogdan and Le, Quoc V},
|
||||||
|
booktitle = {International Conference on Learning Representations (ICLR) Workshop on Neural Architecture Search},
|
||||||
|
year={2021}
|
||||||
|
}
|
||||||
@article{dong2021nats,
|
@article{dong2021nats,
|
||||||
title = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
|
title = {{NATS-Bench}: Benchmarking NAS Algorithms for Architecture Topology and Size},
|
||||||
author = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
|
author = {Dong, Xuanyi and Liu, Lu and Musial, Katarzyna and Gabrys, Bogdan},
|
||||||
|
@ -67,14 +67,26 @@ def extend_transformer_settings(alg2configs, name):
|
|||||||
return alg2configs
|
return alg2configs
|
||||||
|
|
||||||
|
|
||||||
def remove_PortAnaRecord(alg2configs):
|
def refresh_record(alg2configs):
|
||||||
alg2configs = copy.deepcopy(alg2configs)
|
alg2configs = copy.deepcopy(alg2configs)
|
||||||
for key, config in alg2configs.items():
|
for key, config in alg2configs.items():
|
||||||
xlist = config["task"]["record"]
|
xlist = config["task"]["record"]
|
||||||
new_list = []
|
new_list = []
|
||||||
for x in xlist:
|
for x in xlist:
|
||||||
if x["class"] != "PortAnaRecord":
|
# remove PortAnaRecord and SignalMseRecord
|
||||||
|
if x["class"] != "PortAnaRecord" and x["class"] != "SignalMseRecord":
|
||||||
new_list.append(x)
|
new_list.append(x)
|
||||||
|
## add MultiSegRecord
|
||||||
|
new_list.append(
|
||||||
|
{
|
||||||
|
"class": "MultiSegRecord",
|
||||||
|
"module_path": "qlib.contrib.workflow",
|
||||||
|
"generate_kwargs": {
|
||||||
|
"segments": {"train": "train", "valid": "valid", "test": "test"},
|
||||||
|
"save": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
config["task"]["record"] = new_list
|
config["task"]["record"] = new_list
|
||||||
return alg2configs
|
return alg2configs
|
||||||
|
|
||||||
@ -117,7 +129,7 @@ def retrieve_configs():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
alg2configs = extend_transformer_settings(alg2configs, "TSF")
|
||||||
alg2configs = remove_PortAnaRecord(alg2configs)
|
alg2configs = refresh_record(alg2configs)
|
||||||
print(
|
print(
|
||||||
"There are {:} algorithms : {:}".format(
|
"There are {:} algorithms : {:}".format(
|
||||||
len(alg2configs), list(alg2configs.keys())
|
len(alg2configs), list(alg2configs.keys())
|
||||||
|
@ -99,7 +99,12 @@ def run_exp(
|
|||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
try:
|
try:
|
||||||
model = R.load_object(model_obj_name)
|
if hasattr(model, "to"): # Recoverable model
|
||||||
|
device = model.device
|
||||||
|
model = R.load_object(model_obj_name)
|
||||||
|
model.to(device)
|
||||||
|
else:
|
||||||
|
model = R.load_object(model_obj_name)
|
||||||
logger.info("[Find existing object from {:}]".format(model_obj_name))
|
logger.info("[Find existing object from {:}]".format(model_obj_name))
|
||||||
except OSError:
|
except OSError:
|
||||||
R.log_params(**flatten_dict(task_config))
|
R.log_params(**flatten_dict(task_config))
|
||||||
@ -112,16 +117,29 @@ def run_exp(
|
|||||||
recorder_root_dir, "model-ckps"
|
recorder_root_dir, "model-ckps"
|
||||||
)
|
)
|
||||||
model.fit(**model_fit_kwargs)
|
model.fit(**model_fit_kwargs)
|
||||||
R.save_objects(**{model_obj_name: model})
|
# remove model to CPU for saving
|
||||||
except:
|
if hasattr(model, "to"):
|
||||||
raise ValueError("Something wrong.")
|
model.to("cpu")
|
||||||
|
R.save_objects(**{model_obj_name: model})
|
||||||
|
model.to()
|
||||||
|
else:
|
||||||
|
R.save_objects(**{model_obj_name: model})
|
||||||
|
except Exception as e:
|
||||||
|
import pdb
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
raise ValueError("Something wrong: {:}".format(e))
|
||||||
# Get the recorder
|
# Get the recorder
|
||||||
recorder = R.get_recorder()
|
recorder = R.get_recorder()
|
||||||
|
|
||||||
# Generate records: prediction, backtest, and analysis
|
# Generate records: prediction, backtest, and analysis
|
||||||
for record in task_config["record"]:
|
for record in task_config["record"]:
|
||||||
record = deepcopy(record)
|
record = deepcopy(record)
|
||||||
if record["class"] == "SignalRecord":
|
if record["class"] == "MultiSegRecord":
|
||||||
|
record["kwargs"] = dict(model=model, dataset=dataset, recorder=recorder)
|
||||||
|
sr = init_instance_by_config(record)
|
||||||
|
sr.generate(**record["generate_kwargs"])
|
||||||
|
elif record["class"] == "SignalRecord":
|
||||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||||
record["kwargs"].update(srconf)
|
record["kwargs"].update(srconf)
|
||||||
sr = init_instance_by_config(record)
|
sr = init_instance_by_config(record)
|
||||||
|
@ -112,6 +112,12 @@ class QuantTransformer(Model):
|
|||||||
def use_gpu(self):
|
def use_gpu(self):
|
||||||
return self.device != torch.device("cpu")
|
return self.device != torch.device("cpu")
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
if device is None:
|
||||||
|
self.model.to(self.device)
|
||||||
|
else:
|
||||||
|
self.model.to("cpu")
|
||||||
|
|
||||||
def loss_fn(self, pred, label):
|
def loss_fn(self, pred, label):
|
||||||
mask = ~torch.isnan(label)
|
mask = ~torch.isnan(label)
|
||||||
if self.opt_config["loss"] == "mse":
|
if self.opt_config["loss"] == "mse":
|
||||||
|
29
scripts/trade/tsf-all.sh
Normal file
29
scripts/trade/tsf-all.sh
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# bash scripts/trade/tsf-all.sh 0 csi300 0
|
||||||
|
# bash scripts/trade/tsf-all.sh 0 csi300 0.1
|
||||||
|
# bash scripts/trade/tsf-all.sh 1 all
|
||||||
|
#
|
||||||
|
set -e
|
||||||
|
echo script name: $0
|
||||||
|
echo $# arguments
|
||||||
|
|
||||||
|
if [ "$#" -ne 3 ] ;then
|
||||||
|
echo "Input illegal number of parameters " $#
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
gpu=$1
|
||||||
|
market=$2
|
||||||
|
drop=$3
|
||||||
|
|
||||||
|
channels="6 12 24 32 48 64"
|
||||||
|
depths="1 2 3 4 5 6"
|
||||||
|
|
||||||
|
for channel in ${channels}
|
||||||
|
do
|
||||||
|
for depth in ${depths}
|
||||||
|
do
|
||||||
|
python exps/trading/baselines.py --alg TSF-${depth}x${channel}-d${drop} --gpu ${gpu} --market ${market}
|
||||||
|
done
|
||||||
|
done
|
Loading…
Reference in New Issue
Block a user