diff --git a/lib/trade_models/quant_transformer.py b/lib/trade_models/quant_transformer.py index 4f5cdd3..fec729e 100644 --- a/lib/trade_models/quant_transformer.py +++ b/lib/trade_models/quant_transformer.py @@ -298,7 +298,7 @@ class QuantTransformer(Model): results_dict=results_dict, start_epoch=iepoch + 1, ) - torch.save(save_info, ckp_path, map_location="cpu") + torch.save(save_info, ckp_path) self.logger.info( "The best score: {:.6f} @ {:02d}-th epoch".format(best_score, best_epoch) )