Add int search space
This commit is contained in:
		| @@ -65,7 +65,11 @@ def retrieve_configs(): | ||||
|         path = config_dir / name | ||||
|         assert path.exists(), "{:} does not exist.".format(path) | ||||
|         alg2paths[alg] = str(path) | ||||
|         print("The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format(idx, len(alg2names), alg, path)) | ||||
|         print( | ||||
|             "The {:02d}/{:02d}-th baseline algorithm is {:9s} ({:}).".format( | ||||
|                 idx, len(alg2names), alg, path | ||||
|             ) | ||||
|         ) | ||||
|     return alg2paths | ||||
|  | ||||
|  | ||||
| @@ -100,13 +104,30 @@ if __name__ == "__main__": | ||||
|     alg2paths = retrieve_configs() | ||||
|  | ||||
|     parser = argparse.ArgumentParser("Baselines") | ||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/qlib-baselines", help="The checkpoint directory.") | ||||
|     parser.add_argument( | ||||
|         "--market", type=str, default="all", choices=["csi100", "csi300", "all"], help="The market indicator." | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/qlib-baselines", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--market", | ||||
|         type=str, | ||||
|         default="all", | ||||
|         choices=["csi100", "csi300", "all"], | ||||
|         help="The market indicator.", | ||||
|     ) | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||
|     parser.add_argument("--alg", type=str, choices=list(alg2paths.keys()), required=True, help="The algorithm name.") | ||||
|     parser.add_argument( | ||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--alg", | ||||
|         type=str, | ||||
|         choices=list(alg2paths.keys()), | ||||
|         required=True, | ||||
|         help="The algorithm name.", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     main(args, alg2paths[args.alg]) | ||||
|   | ||||
| @@ -55,7 +55,13 @@ class QResult: | ||||
|             new_dict[xkey] = values | ||||
|         return new_dict | ||||
|  | ||||
|     def info(self, keys: List[Text], separate: Text = "& ", space: int = 25, verbose: bool = True): | ||||
|     def info( | ||||
|         self, | ||||
|         keys: List[Text], | ||||
|         separate: Text = "& ", | ||||
|         space: int = 25, | ||||
|         verbose: bool = True, | ||||
|     ): | ||||
|         avaliable_keys = [] | ||||
|         for key in keys: | ||||
|             if key not in self.result: | ||||
| @@ -89,7 +95,10 @@ def compare_results(heads, values, names, space=10, verbose=True, sort_key=False | ||||
|     if verbose: | ||||
|         print(info_str_dict["head"]) | ||||
|         if sort_key: | ||||
|             lines = sorted(list(zip(values, info_str_dict["lines"])), key=lambda x: float(x[0].split(" ")[0])) | ||||
|             lines = sorted( | ||||
|                 list(zip(values, info_str_dict["lines"])), | ||||
|                 key=lambda x: float(x[0].split(" ")[0]), | ||||
|             ) | ||||
|             lines = [x[1] for x in lines] | ||||
|         else: | ||||
|             lines = info_str_dict["lines"] | ||||
| @@ -136,7 +145,11 @@ def query_info(save_dir, verbose): | ||||
|         if verbose: | ||||
|             print( | ||||
|                 "====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.".format( | ||||
|                     idx + 1, len(experiments), experiment.name, len(recorders), len(recorders) + not_finished | ||||
|                     idx + 1, | ||||
|                     len(experiments), | ||||
|                     experiment.name, | ||||
|                     len(recorders), | ||||
|                     len(recorders) + not_finished, | ||||
|                 ) | ||||
|             ) | ||||
|         result = QResult() | ||||
| @@ -149,7 +162,9 @@ def query_info(save_dir, verbose): | ||||
|         head_strs.append(head_str) | ||||
|         value_strs.append(value_str) | ||||
|         names.append(experiment.name) | ||||
|     info_str_dict = compare_results(head_strs, value_strs, names, space=10, verbose=verbose) | ||||
|     info_str_dict = compare_results( | ||||
|         head_strs, value_strs, names, space=10, verbose=verbose | ||||
|     ) | ||||
|     info_value_dict = dict(heads=head_strs, values=value_strs, names=names) | ||||
|     return info_str_dict, info_value_dict | ||||
|  | ||||
| @@ -169,9 +184,18 @@ if __name__ == "__main__": | ||||
|             raise argparse.ArgumentTypeError("Boolean value expected.") | ||||
|  | ||||
|     parser.add_argument( | ||||
|         "--save_dir", type=str, nargs="+", default=["./outputs/qlib-baselines"], help="The checkpoint directory." | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         nargs="+", | ||||
|         default=["./outputs/qlib-baselines"], | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--verbose", | ||||
|         type=str2bool, | ||||
|         default=False, | ||||
|         help="Print detailed log information or not.", | ||||
|     ) | ||||
|     parser.add_argument("--verbose", type=str2bool, default=False, help="Print detailed log information or not.") | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     print("Show results of {:}".format(args.save_dir)) | ||||
| @@ -184,4 +208,11 @@ if __name__ == "__main__": | ||||
|         _, info_dict = query_info(save_dir, args.verbose) | ||||
|         all_info_dict.append(info_dict) | ||||
|     info_dict = QResult.merge_dict(all_info_dict) | ||||
|     compare_results(info_dict["heads"], info_dict["values"], info_dict["names"], space=10, verbose=True, sort_key=True) | ||||
|     compare_results( | ||||
|         info_dict["heads"], | ||||
|         info_dict["values"], | ||||
|         info_dict["names"], | ||||
|         space=10, | ||||
|         verbose=True, | ||||
|         sort_key=True, | ||||
|     ) | ||||
|   | ||||
| @@ -39,7 +39,10 @@ def main(xargs): | ||||
|                     "fit_end_time": "2014-12-31", | ||||
|                     "instruments": xargs.market, | ||||
|                     "infer_processors": [ | ||||
|                         {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature", "clip_outlier": True}}, | ||||
|                         { | ||||
|                             "class": "RobustZScoreNorm", | ||||
|                             "kwargs": {"fields_group": "feature", "clip_outlier": True}, | ||||
|                         }, | ||||
|                         {"class": "Fillna", "kwargs": {"fields_group": "feature"}}, | ||||
|                     ], | ||||
|                     "learn_processors": [ | ||||
| @@ -90,7 +93,11 @@ def main(xargs): | ||||
|     } | ||||
|  | ||||
|     record_config = [ | ||||
|         {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp", "kwargs": dict()}, | ||||
|         { | ||||
|             "class": "SignalRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
|             "kwargs": dict(), | ||||
|         }, | ||||
|         { | ||||
|             "class": "SigAnaRecord", | ||||
|             "module_path": "qlib.workflow.record_temp", | ||||
| @@ -111,18 +118,37 @@ def main(xargs): | ||||
|     for irun in range(xargs.times): | ||||
|         xmodel_config = model_config.copy() | ||||
|         xmodel_config = update_gpu(xmodel_config, xargs.gpu) | ||||
|         task_config = dict(model=xmodel_config, dataset=dataset_config, record=record_config) | ||||
|         task_config = dict( | ||||
|             model=xmodel_config, dataset=dataset_config, record=record_config | ||||
|         ) | ||||
|  | ||||
|         run_exp(task_config, dataset, xargs.name, "recorder-{:02d}-{:02d}".format(irun, xargs.times), save_dir) | ||||
|         run_exp( | ||||
|             task_config, | ||||
|             dataset, | ||||
|             xargs.name, | ||||
|             "recorder-{:02d}-{:02d}".format(irun, xargs.times), | ||||
|             save_dir, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser("Vanilla Transformable Transformer") | ||||
|     parser.add_argument("--save_dir", type=str, default="./outputs/vtt-runs", help="The checkpoint directory.") | ||||
|     parser.add_argument("--name", type=str, default="Transformer", help="The experiment name.") | ||||
|     parser.add_argument( | ||||
|         "--save_dir", | ||||
|         type=str, | ||||
|         default="./outputs/vtt-runs", | ||||
|         help="The checkpoint directory.", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--name", type=str, default="Transformer", help="The experiment name." | ||||
|     ) | ||||
|     parser.add_argument("--times", type=int, default=10, help="The repeated run times.") | ||||
|     parser.add_argument("--gpu", type=int, default=0, help="The GPU ID used for train / test.") | ||||
|     parser.add_argument("--market", type=str, default="all", help="The market indicator.") | ||||
|     parser.add_argument( | ||||
|         "--gpu", type=int, default=0, help="The GPU ID used for train / test." | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--market", type=str, default="all", help="The market indicator." | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     main(args) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user