Update tests for torch/cuda
This commit is contained in:
		| @@ -141,26 +141,25 @@ def retrieve_configs(): | ||||
|     return alg2configs | ||||
|  | ||||
|  | ||||
| def main(xargs, config): | ||||
| def main(alg_name, market, config, times, save_dir, gpu): | ||||
|  | ||||
|     pprint("Run {:}".format(xargs.alg)) | ||||
|     config = update_market(config, xargs.market) | ||||
|     config = update_gpu(config, xargs.gpu) | ||||
|     pprint("Run {:}".format(alg_name)) | ||||
|     config = update_market(config, market) | ||||
|     config = update_gpu(config, gpu) | ||||
|  | ||||
|     qlib.init(**config.get("qlib_init")) | ||||
|     dataset_config = config.get("task").get("dataset") | ||||
|     dataset = init_instance_by_config(dataset_config) | ||||
|     pprint("args: {:}".format(xargs)) | ||||
|     pprint(dataset_config) | ||||
|     pprint(dataset) | ||||
|  | ||||
|     for irun in range(xargs.times): | ||||
|     for irun in range(times): | ||||
|         run_exp( | ||||
|             config.get("task"), | ||||
|             dataset, | ||||
|             xargs.alg, | ||||
|             "recorder-{:02d}-{:02d}".format(irun, xargs.times), | ||||
|             "{:}-{:}".format(xargs.save_dir, xargs.market), | ||||
|             alg_name, | ||||
|             "recorder-{:02d}-{:02d}".format(irun, times), | ||||
|             "{:}-{:}".format(save_dir, market), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| @@ -203,6 +202,13 @@ if __name__ == "__main__": | ||||
|     args = parser.parse_args() | ||||
|  | ||||
|     if len(args.alg) == 1: | ||||
|         main(args, alg2configs[args.alg[0]]) | ||||
|         main( | ||||
|             args.alg[0], | ||||
|             args.market, | ||||
|             alg2configs[args.alg[0]], | ||||
|             args.times, | ||||
|             args.save_dir, | ||||
|             args.gpu, | ||||
|         ) | ||||
|     else: | ||||
|         print("-") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user