Update tests for torch/cuda
This commit is contained in:
		
							
								
								
									
										1
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.github/workflows/basic_test.yml
									
									
									
									
										vendored
									
									
								
							| @@ -37,6 +37,7 @@ jobs: | |||||||
|           python -m black ./lib/trade_models -l 88 --check --diff --verbose |           python -m black ./lib/trade_models -l 88 --check --diff --verbose | ||||||
|           python -m black ./lib/procedures -l 88 --check --diff --verbose |           python -m black ./lib/procedures -l 88 --check --diff --verbose | ||||||
|           python -m black ./lib/config_utils -l 88 --check --diff --verbose |           python -m black ./lib/config_utils -l 88 --check --diff --verbose | ||||||
|  |           python -m black ./lib/log_utils -l 88 --check --diff --verbose | ||||||
|  |  | ||||||
|       - name: Test Search Space |       - name: Test Search Space | ||||||
|         run: | |         run: | | ||||||
|   | |||||||
 Submodule .latent-data/qlib updated: 968930e85f...70c84cbc77
									
								
							| @@ -141,26 +141,25 @@ def retrieve_configs(): | |||||||
|     return alg2configs |     return alg2configs | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(xargs, config): | def main(alg_name, market, config, times, save_dir, gpu): | ||||||
|  |  | ||||||
|     pprint("Run {:}".format(xargs.alg)) |     pprint("Run {:}".format(alg_name)) | ||||||
|     config = update_market(config, xargs.market) |     config = update_market(config, market) | ||||||
|     config = update_gpu(config, xargs.gpu) |     config = update_gpu(config, gpu) | ||||||
|  |  | ||||||
|     qlib.init(**config.get("qlib_init")) |     qlib.init(**config.get("qlib_init")) | ||||||
|     dataset_config = config.get("task").get("dataset") |     dataset_config = config.get("task").get("dataset") | ||||||
|     dataset = init_instance_by_config(dataset_config) |     dataset = init_instance_by_config(dataset_config) | ||||||
|     pprint("args: {:}".format(xargs)) |  | ||||||
|     pprint(dataset_config) |     pprint(dataset_config) | ||||||
|     pprint(dataset) |     pprint(dataset) | ||||||
|  |  | ||||||
|     for irun in range(xargs.times): |     for irun in range(times): | ||||||
|         run_exp( |         run_exp( | ||||||
|             config.get("task"), |             config.get("task"), | ||||||
|             dataset, |             dataset, | ||||||
|             xargs.alg, |             alg_name, | ||||||
|             "recorder-{:02d}-{:02d}".format(irun, xargs.times), |             "recorder-{:02d}-{:02d}".format(irun, times), | ||||||
|             "{:}-{:}".format(xargs.save_dir, xargs.market), |             "{:}-{:}".format(save_dir, market), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -203,6 +202,13 @@ if __name__ == "__main__": | |||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     if len(args.alg) == 1: |     if len(args.alg) == 1: | ||||||
|         main(args, alg2configs[args.alg[0]]) |         main( | ||||||
|  |             args.alg[0], | ||||||
|  |             args.market, | ||||||
|  |             alg2configs[args.alg[0]], | ||||||
|  |             args.times, | ||||||
|  |             args.save_dir, | ||||||
|  |             args.gpu, | ||||||
|  |         ) | ||||||
|     else: |     else: | ||||||
|         print("-") |         print("-") | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # general config related functions | # general config related functions | ||||||
| from .config_utils import load_config, dict2config, configure2str | from .config_utils import load_config, dict2config, configure2str | ||||||
|  |  | ||||||
| # the args setting for different experiments | # the args setting for different experiments | ||||||
| from .basic_args import obtain_basic_args | from .basic_args import obtain_basic_args | ||||||
| from .attention_args import obtain_attention_args | from .attention_args import obtain_attention_args | ||||||
|   | |||||||
| @@ -3,6 +3,14 @@ | |||||||
| ################################################## | ################################################## | ||||||
| # every package does not rely on pytorch or tensorflow | # every package does not rely on pytorch or tensorflow | ||||||
| # I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib | # I tried to list all dependency here: os, sys, time, numpy, (possibly) matplotlib | ||||||
|  | ################################################## | ||||||
| from .logger import Logger, PrintLogger | from .logger import Logger, PrintLogger | ||||||
| from .meter import AverageMeter | from .meter import AverageMeter | ||||||
| from .time_utils   import time_for_file, time_string, time_string_short, time_print, convert_secs2time | from .time_utils import ( | ||||||
|  |     time_for_file, | ||||||
|  |     time_string, | ||||||
|  |     time_string_short, | ||||||
|  |     time_print, | ||||||
|  |     convert_secs2time, | ||||||
|  | ) | ||||||
|  | from .pickle_wrap import pickle_save, pickle_load | ||||||
|   | |||||||
| @@ -4,45 +4,48 @@ | |||||||
| from pathlib import Path | from pathlib import Path | ||||||
| import importlib, warnings | import importlib, warnings | ||||||
| import os, sys, time, numpy as np | import os, sys, time, numpy as np | ||||||
|  |  | ||||||
| if sys.version_info.major == 2:  # Python 2.x | if sys.version_info.major == 2:  # Python 2.x | ||||||
|     from StringIO import StringIO as BIO |     from StringIO import StringIO as BIO | ||||||
| else:  # Python 3.x | else:  # Python 3.x | ||||||
|     from io import BytesIO as BIO |     from io import BytesIO as BIO | ||||||
|  |  | ||||||
| if importlib.util.find_spec('tensorflow'): | if importlib.util.find_spec("tensorflow"): | ||||||
|     import tensorflow as tf |     import tensorflow as tf | ||||||
|  |  | ||||||
|  |  | ||||||
| class PrintLogger(object): | class PrintLogger(object): | ||||||
|    |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         """Create a summary writer logging to log_dir.""" |         """Create a summary writer logging to log_dir.""" | ||||||
|     self.name = 'PrintLogger' |         self.name = "PrintLogger" | ||||||
|  |  | ||||||
|     def log(self, string): |     def log(self, string): | ||||||
|     print (string) |         print(string) | ||||||
|  |  | ||||||
|     def close(self): |     def close(self): | ||||||
|     print ('-'*30 + ' close printer ' + '-'*30) |         print("-" * 30 + " close printer " + "-" * 30) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Logger(object): | class Logger(object): | ||||||
|    |  | ||||||
|     def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False): |     def __init__(self, log_dir, seed, create_model_dir=True, use_tf=False): | ||||||
|         """Create a summary writer logging to log_dir.""" |         """Create a summary writer logging to log_dir.""" | ||||||
|         self.seed = int(seed) |         self.seed = int(seed) | ||||||
|         self.log_dir = Path(log_dir) |         self.log_dir = Path(log_dir) | ||||||
|     self.model_dir = Path(log_dir) / 'checkpoint' |         self.model_dir = Path(log_dir) / "checkpoint" | ||||||
|     self.log_dir.mkdir  (parents=True, exist_ok=True) |         self.log_dir.mkdir(parents=True, exist_ok=True) | ||||||
|         if create_model_dir: |         if create_model_dir: | ||||||
|             self.model_dir.mkdir(parents=True, exist_ok=True) |             self.model_dir.mkdir(parents=True, exist_ok=True) | ||||||
|     #self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True) |         # self.meta_dir.mkdir(mode=0o775, parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|         self.use_tf = bool(use_tf) |         self.use_tf = bool(use_tf) | ||||||
|     self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h', time.gmtime(time.time()) ))) |         self.tensorboard_dir = self.log_dir / ( | ||||||
|     #self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) ))) |             "tensorboard-{:}".format(time.strftime("%d-%h", time.gmtime(time.time()))) | ||||||
|     self.logger_path = self.log_dir / 'seed-{:}-T-{:}.log'.format(self.seed, time.strftime('%d-%h-at-%H-%M-%S', time.gmtime(time.time()))) |         ) | ||||||
|     self.logger_file = open(self.logger_path, 'w') |         # self.tensorboard_dir = self.log_dir / ('tensorboard-{:}'.format(time.strftime( '%d-%h-at-%H:%M:%S', time.gmtime(time.time()) ))) | ||||||
|  |         self.logger_path = self.log_dir / "seed-{:}-T-{:}.log".format( | ||||||
|  |             self.seed, time.strftime("%d-%h-at-%H-%M-%S", time.gmtime(time.time())) | ||||||
|  |         ) | ||||||
|  |         self.logger_file = open(self.logger_path, "w") | ||||||
|  |  | ||||||
|         if self.use_tf: |         if self.use_tf: | ||||||
|             self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True) |             self.tensorboard_dir.mkdir(mode=0o775, parents=True, exist_ok=True) | ||||||
| @@ -51,15 +54,22 @@ class Logger(object): | |||||||
|             self.writer = None |             self.writer = None | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|     return ('{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})'.format(name=self.__class__.__name__, **self.__dict__)) |         return "{name}(dir={log_dir}, use-tf={use_tf}, writer={writer})".format( | ||||||
|  |             name=self.__class__.__name__, **self.__dict__ | ||||||
|  |         ) | ||||||
|  |  | ||||||
|     def path(self, mode): |     def path(self, mode): | ||||||
|     valids = ('model', 'best', 'info', 'log') |         valids = ("model", "best", "info", "log") | ||||||
|     if   mode == 'model': return self.model_dir / 'seed-{:}-basic.pth'.format(self.seed) |         if mode == "model": | ||||||
|     elif mode == 'best' : return self.model_dir / 'seed-{:}-best.pth'.format(self.seed) |             return self.model_dir / "seed-{:}-basic.pth".format(self.seed) | ||||||
|     elif mode == 'info' : return self.log_dir / 'seed-{:}-last-info.pth'.format(self.seed) |         elif mode == "best": | ||||||
|     elif mode == 'log'  : return self.log_dir |             return self.model_dir / "seed-{:}-best.pth".format(self.seed) | ||||||
|     else: raise TypeError('Unknow mode = {:}, valid modes = {:}'.format(mode, valids)) |         elif mode == "info": | ||||||
|  |             return self.log_dir / "seed-{:}-last-info.pth".format(self.seed) | ||||||
|  |         elif mode == "log": | ||||||
|  |             return self.log_dir | ||||||
|  |         else: | ||||||
|  |             raise TypeError("Unknow mode = {:}, valid modes = {:}".format(mode, valids)) | ||||||
|  |  | ||||||
|     def extract_log(self): |     def extract_log(self): | ||||||
|         return self.logger_file |         return self.logger_file | ||||||
| @@ -71,31 +81,37 @@ class Logger(object): | |||||||
|  |  | ||||||
|     def log(self, string, save=True, stdout=False): |     def log(self, string, save=True, stdout=False): | ||||||
|         if stdout: |         if stdout: | ||||||
|       sys.stdout.write(string); sys.stdout.flush() |             sys.stdout.write(string) | ||||||
|  |             sys.stdout.flush() | ||||||
|         else: |         else: | ||||||
|       print (string) |             print(string) | ||||||
|         if save: |         if save: | ||||||
|       self.logger_file.write('{:}\n'.format(string)) |             self.logger_file.write("{:}\n".format(string)) | ||||||
|             self.logger_file.flush() |             self.logger_file.flush() | ||||||
|  |  | ||||||
|     def scalar_summary(self, tags, values, step): |     def scalar_summary(self, tags, values, step): | ||||||
|         """Log a scalar variable.""" |         """Log a scalar variable.""" | ||||||
|         if not self.use_tf: |         if not self.use_tf: | ||||||
|       warnings.warn('Do set use-tensorflow installed but call scalar_summary') |             warnings.warn("Do set use-tensorflow installed but call scalar_summary") | ||||||
|         else: |         else: | ||||||
|       assert isinstance(tags, list) == isinstance(values, list), 'Type : {:} vs {:}'.format(type(tags), type(values)) |             assert isinstance(tags, list) == isinstance( | ||||||
|  |                 values, list | ||||||
|  |             ), "Type : {:} vs {:}".format(type(tags), type(values)) | ||||||
|             if not isinstance(tags, list): |             if not isinstance(tags, list): | ||||||
|                 tags, values = [tags], [values] |                 tags, values = [tags], [values] | ||||||
|             for tag, value in zip(tags, values): |             for tag, value in zip(tags, values): | ||||||
|         summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) |                 summary = tf.Summary( | ||||||
|  |                     value=[tf.Summary.Value(tag=tag, simple_value=value)] | ||||||
|  |                 ) | ||||||
|                 self.writer.add_summary(summary, step) |                 self.writer.add_summary(summary, step) | ||||||
|                 self.writer.flush() |                 self.writer.flush() | ||||||
|  |  | ||||||
|     def image_summary(self, tag, images, step): |     def image_summary(self, tag, images, step): | ||||||
|         """Log a list of images.""" |         """Log a list of images.""" | ||||||
|         import scipy |         import scipy | ||||||
|  |  | ||||||
|         if not self.use_tf: |         if not self.use_tf: | ||||||
|       warnings.warn('Do set use-tensorflow installed but call scalar_summary') |             warnings.warn("Do set use-tensorflow installed but call scalar_summary") | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         img_summaries = [] |         img_summaries = [] | ||||||
| @@ -108,11 +124,15 @@ class Logger(object): | |||||||
|             scipy.misc.toimage(img).save(s, format="png") |             scipy.misc.toimage(img).save(s, format="png") | ||||||
|  |  | ||||||
|             # Create an Image object |             # Create an Image object | ||||||
|       img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), |             img_sum = tf.Summary.Image( | ||||||
|  |                 encoded_image_string=s.getvalue(), | ||||||
|                 height=img.shape[0], |                 height=img.shape[0], | ||||||
|                      width=img.shape[1]) |                 width=img.shape[1], | ||||||
|  |             ) | ||||||
|             # Create a Summary value |             # Create a Summary value | ||||||
|       img_summaries.append(tf.Summary.Value(tag='{}/{}'.format(tag, i), image=img_sum)) |             img_summaries.append( | ||||||
|  |                 tf.Summary.Value(tag="{}/{}".format(tag, i), image=img_sum) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|         # Create and write Summary |         # Create and write Summary | ||||||
|         summary = tf.Summary(value=img_summaries) |         summary = tf.Summary(value=img_summaries) | ||||||
| @@ -121,7 +141,8 @@ class Logger(object): | |||||||
|  |  | ||||||
|     def histo_summary(self, tag, values, step, bins=1000): |     def histo_summary(self, tag, values, step, bins=1000): | ||||||
|         """Log a histogram of the tensor of values.""" |         """Log a histogram of the tensor of values.""" | ||||||
|     if not self.use_tf: raise ValueError('Do not have tensorflow') |         if not self.use_tf: | ||||||
|  |             raise ValueError("Do not have tensorflow") | ||||||
|         import tensorflow as tf |         import tensorflow as tf | ||||||
|  |  | ||||||
|         # Create a histogram using numpy |         # Create a histogram using numpy | ||||||
| @@ -133,7 +154,7 @@ class Logger(object): | |||||||
|         hist.max = float(np.max(values)) |         hist.max = float(np.max(values)) | ||||||
|         hist.num = int(np.prod(values.shape)) |         hist.num = int(np.prod(values.shape)) | ||||||
|         hist.sum = float(np.sum(values)) |         hist.sum = float(np.sum(values)) | ||||||
|     hist.sum_squares = float(np.sum(values**2)) |         hist.sum_squares = float(np.sum(values ** 2)) | ||||||
|  |  | ||||||
|         # Drop the start of the first bin |         # Drop the start of the first bin | ||||||
|         bin_edges = bin_edges[1:] |         bin_edges = bin_edges[1:] | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ import numpy as np | |||||||
|  |  | ||||||
| class AverageMeter(object): | class AverageMeter(object): | ||||||
|     """Computes and stores the average and current value""" |     """Computes and stores the average and current value""" | ||||||
|  |  | ||||||
|     def __init__(self): |     def __init__(self): | ||||||
|         self.reset() |         self.reset() | ||||||
|  |  | ||||||
| @@ -19,42 +20,60 @@ class AverageMeter(object): | |||||||
|         self.avg = self.sum / self.count |         self.avg = self.sum / self.count | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|     return ('{name}(val={val}, avg={avg}, count={count})'.format(name=self.__class__.__name__, **self.__dict__)) |         return "{name}(val={val}, avg={avg}, count={count})".format( | ||||||
|  |             name=self.__class__.__name__, **self.__dict__ | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class RecorderMeter(object): | class RecorderMeter(object): | ||||||
|     """Computes and stores the minimum loss value and its epoch index""" |     """Computes and stores the minimum loss value and its epoch index""" | ||||||
|  |  | ||||||
|     def __init__(self, total_epoch): |     def __init__(self, total_epoch): | ||||||
|         self.reset(total_epoch) |         self.reset(total_epoch) | ||||||
|  |  | ||||||
|     def reset(self, total_epoch): |     def reset(self, total_epoch): | ||||||
|     assert total_epoch > 0, 'total_epoch should be greater than 0 vs {:}'.format(total_epoch) |         assert total_epoch > 0, "total_epoch should be greater than 0 vs {:}".format( | ||||||
|  |             total_epoch | ||||||
|  |         ) | ||||||
|         self.total_epoch = total_epoch |         self.total_epoch = total_epoch | ||||||
|         self.current_epoch = 0 |         self.current_epoch = 0 | ||||||
|     self.epoch_losses  = np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] |         self.epoch_losses = np.zeros( | ||||||
|  |             (self.total_epoch, 2), dtype=np.float32 | ||||||
|  |         )  # [epoch, train/val] | ||||||
|         self.epoch_losses = self.epoch_losses - 1 |         self.epoch_losses = self.epoch_losses - 1 | ||||||
|     self.epoch_accuracy= np.zeros((self.total_epoch, 2), dtype=np.float32) # [epoch, train/val] |         self.epoch_accuracy = np.zeros( | ||||||
|     self.epoch_accuracy= self.epoch_accuracy |             (self.total_epoch, 2), dtype=np.float32 | ||||||
|  |         )  # [epoch, train/val] | ||||||
|  |         self.epoch_accuracy = self.epoch_accuracy | ||||||
|  |  | ||||||
|     def update(self, idx, train_loss, train_acc, val_loss, val_acc): |     def update(self, idx, train_loss, train_acc, val_loss, val_acc): | ||||||
|     assert idx >= 0 and idx < self.total_epoch, 'total_epoch : {} , but update with the {} index'.format(self.total_epoch, idx) |         assert ( | ||||||
|     self.epoch_losses  [idx, 0] = train_loss |             idx >= 0 and idx < self.total_epoch | ||||||
|     self.epoch_losses  [idx, 1] = val_loss |         ), "total_epoch : {} , but update with the {} index".format( | ||||||
|  |             self.total_epoch, idx | ||||||
|  |         ) | ||||||
|  |         self.epoch_losses[idx, 0] = train_loss | ||||||
|  |         self.epoch_losses[idx, 1] = val_loss | ||||||
|         self.epoch_accuracy[idx, 0] = train_acc |         self.epoch_accuracy[idx, 0] = train_acc | ||||||
|         self.epoch_accuracy[idx, 1] = val_acc |         self.epoch_accuracy[idx, 1] = val_acc | ||||||
|         self.current_epoch = idx + 1 |         self.current_epoch = idx + 1 | ||||||
|         return self.max_accuracy(False) == self.epoch_accuracy[idx, 1] |         return self.max_accuracy(False) == self.epoch_accuracy[idx, 1] | ||||||
|  |  | ||||||
|     def max_accuracy(self, istrain): |     def max_accuracy(self, istrain): | ||||||
|     if self.current_epoch <= 0: return 0 |         if self.current_epoch <= 0: | ||||||
|     if istrain: return self.epoch_accuracy[:self.current_epoch, 0].max() |             return 0 | ||||||
|     else:       return self.epoch_accuracy[:self.current_epoch, 1].max() |         if istrain: | ||||||
|  |             return self.epoch_accuracy[: self.current_epoch, 0].max() | ||||||
|  |         else: | ||||||
|  |             return self.epoch_accuracy[: self.current_epoch, 1].max() | ||||||
|  |  | ||||||
|     def plot_curve(self, save_path): |     def plot_curve(self, save_path): | ||||||
|         import matplotlib |         import matplotlib | ||||||
|     matplotlib.use('agg') |  | ||||||
|  |         matplotlib.use("agg") | ||||||
|         import matplotlib.pyplot as plt |         import matplotlib.pyplot as plt | ||||||
|     title = 'the accuracy/loss curve of train/val' |  | ||||||
|  |         title = "the accuracy/loss curve of train/val" | ||||||
|         dpi = 100 |         dpi = 100 | ||||||
|         width, height = 1600, 1000 |         width, height = 1600, 1000 | ||||||
|         legend_fontsize = 10 |         legend_fontsize = 10 | ||||||
| @@ -72,27 +91,30 @@ class RecorderMeter(object): | |||||||
|         plt.yticks(np.arange(0, 100 + interval_y, interval_y)) |         plt.yticks(np.arange(0, 100 + interval_y, interval_y)) | ||||||
|         plt.grid() |         plt.grid() | ||||||
|         plt.title(title, fontsize=20) |         plt.title(title, fontsize=20) | ||||||
|     plt.xlabel('the training epoch', fontsize=16) |         plt.xlabel("the training epoch", fontsize=16) | ||||||
|     plt.ylabel('accuracy', fontsize=16) |         plt.ylabel("accuracy", fontsize=16) | ||||||
|  |  | ||||||
|         y_axis[:] = self.epoch_accuracy[:, 0] |         y_axis[:] = self.epoch_accuracy[:, 0] | ||||||
|     plt.plot(x_axis, y_axis, color='g', linestyle='-', label='train-accuracy', lw=2) |         plt.plot(x_axis, y_axis, color="g", linestyle="-", label="train-accuracy", lw=2) | ||||||
|         plt.legend(loc=4, fontsize=legend_fontsize) |         plt.legend(loc=4, fontsize=legend_fontsize) | ||||||
|  |  | ||||||
|         y_axis[:] = self.epoch_accuracy[:, 1] |         y_axis[:] = self.epoch_accuracy[:, 1] | ||||||
|     plt.plot(x_axis, y_axis, color='y', linestyle='-', label='valid-accuracy', lw=2) |         plt.plot(x_axis, y_axis, color="y", linestyle="-", label="valid-accuracy", lw=2) | ||||||
|         plt.legend(loc=4, fontsize=legend_fontsize) |         plt.legend(loc=4, fontsize=legend_fontsize) | ||||||
|  |  | ||||||
|      |  | ||||||
|         y_axis[:] = self.epoch_losses[:, 0] |         y_axis[:] = self.epoch_losses[:, 0] | ||||||
|     plt.plot(x_axis, y_axis*50, color='g', linestyle=':', label='train-loss-x50', lw=2) |         plt.plot( | ||||||
|  |             x_axis, y_axis * 50, color="g", linestyle=":", label="train-loss-x50", lw=2 | ||||||
|  |         ) | ||||||
|         plt.legend(loc=4, fontsize=legend_fontsize) |         plt.legend(loc=4, fontsize=legend_fontsize) | ||||||
|  |  | ||||||
|         y_axis[:] = self.epoch_losses[:, 1] |         y_axis[:] = self.epoch_losses[:, 1] | ||||||
|     plt.plot(x_axis, y_axis*50, color='y', linestyle=':', label='valid-loss-x50', lw=2) |         plt.plot( | ||||||
|  |             x_axis, y_axis * 50, color="y", linestyle=":", label="valid-loss-x50", lw=2 | ||||||
|  |         ) | ||||||
|         plt.legend(loc=4, fontsize=legend_fontsize) |         plt.legend(loc=4, fontsize=legend_fontsize) | ||||||
|  |  | ||||||
|         if save_path is not None: |         if save_path is not None: | ||||||
|       fig.savefig(save_path, dpi=dpi, bbox_inches='tight') |             fig.savefig(save_path, dpi=dpi, bbox_inches="tight") | ||||||
|       print ('---- save figure {} into {}'.format(title, save_path)) |             print("---- save figure {} into {}".format(title, save_path)) | ||||||
|         plt.close(fig) |         plt.close(fig) | ||||||
|   | |||||||
							
								
								
									
										21
									
								
								lib/log_utils/pickle_wrap.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										21
									
								
								lib/log_utils/pickle_wrap.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,21 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 # | ||||||
|  | ##################################################### | ||||||
|  | import pickle | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def pickle_save(obj, path): | ||||||
|  |     file_path = Path(path) | ||||||
|  |     file_dir = file_path.parent | ||||||
|  |     file_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |     with file_path.open("wb") as f: | ||||||
|  |         pickle.dump(obj, f) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def pickle_load(path): | ||||||
|  |     if not Path(path).exists(): | ||||||
|  |         raise ValueError("{:} does not exists".format(path)) | ||||||
|  |     with Path(path).open("rb") as f: | ||||||
|  |         data = pickle.load(f) | ||||||
|  |     return data | ||||||
| @@ -4,39 +4,46 @@ | |||||||
| import time, sys | import time, sys | ||||||
| import numpy as np | import numpy as np | ||||||
|  |  | ||||||
|  |  | ||||||
| def time_for_file(): | def time_for_file(): | ||||||
|   ISOTIMEFORMAT='%d-%h-at-%H-%M-%S' |     ISOTIMEFORMAT = "%d-%h-at-%H-%M-%S" | ||||||
|   return '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) |     return "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||||
|  |  | ||||||
|  |  | ||||||
| def time_string(): | def time_string(): | ||||||
|   ISOTIMEFORMAT='%Y-%m-%d %X' |     ISOTIMEFORMAT = "%Y-%m-%d %X" | ||||||
|   string = '[{:}]'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) |     string = "[{:}]".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||||
|     return string |     return string | ||||||
|  |  | ||||||
|  |  | ||||||
| def time_string_short(): | def time_string_short(): | ||||||
|   ISOTIMEFORMAT='%Y%m%d' |     ISOTIMEFORMAT = "%Y%m%d" | ||||||
|   string = '{:}'.format(time.strftime( ISOTIMEFORMAT, time.gmtime(time.time()) )) |     string = "{:}".format(time.strftime(ISOTIMEFORMAT, time.gmtime(time.time()))) | ||||||
|     return string |     return string | ||||||
|  |  | ||||||
|  |  | ||||||
| def time_print(string, is_print=True): | def time_print(string, is_print=True): | ||||||
|   if (is_print): |     if is_print: | ||||||
|     print('{} : {}'.format(time_string(), string)) |         print("{} : {}".format(time_string(), string)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def convert_secs2time(epoch_time, return_str=False): | def convert_secs2time(epoch_time, return_str=False): | ||||||
|     need_hour = int(epoch_time / 3600) |     need_hour = int(epoch_time / 3600) | ||||||
|   need_mins = int((epoch_time - 3600*need_hour) / 60)   |     need_mins = int((epoch_time - 3600 * need_hour) / 60) | ||||||
|   need_secs = int(epoch_time - 3600*need_hour - 60*need_mins) |     need_secs = int(epoch_time - 3600 * need_hour - 60 * need_mins) | ||||||
|     if return_str: |     if return_str: | ||||||
|     str = '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs) |         str = "[{:02d}:{:02d}:{:02d}]".format(need_hour, need_mins, need_secs) | ||||||
|         return str |         return str | ||||||
|     else: |     else: | ||||||
|         return need_hour, need_mins, need_secs |         return need_hour, need_mins, need_secs | ||||||
|  |  | ||||||
|  |  | ||||||
| def print_log(print_string, log): | def print_log(print_string, log): | ||||||
|   #if isinstance(log, Logger): log.log('{:}'.format(print_string)) |     # if isinstance(log, Logger): log.log('{:}'.format(print_string)) | ||||||
|   if hasattr(log, 'log'): log.log('{:}'.format(print_string)) |     if hasattr(log, "log"): | ||||||
|  |         log.log("{:}".format(print_string)) | ||||||
|     else: |     else: | ||||||
|         print("{:}".format(print_string)) |         print("{:}".format(print_string)) | ||||||
|         if log is not None: |         if log is not None: | ||||||
|       log.write('{:}\n'.format(print_string)) |             log.write("{:}\n".format(print_string)) | ||||||
|             log.flush() |             log.flush() | ||||||
|   | |||||||
| @@ -9,15 +9,19 @@ def count_parameters_in_MB(model): | |||||||
|  |  | ||||||
| def count_parameters(model_or_parameters, unit="mb"): | def count_parameters(model_or_parameters, unit="mb"): | ||||||
|     if isinstance(model_or_parameters, nn.Module): |     if isinstance(model_or_parameters, nn.Module): | ||||||
|         counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters()) |         counts = sum(np.prod(v.size()) for v in model_or_parameters.parameters()) | ||||||
|  |     elif isinstance(models_or_parameters, nn.Parameter): | ||||||
|  |         counts = models_or_parameters.numel() | ||||||
|  |     elif isinstance(models_or_parameters, (list, tuple)): | ||||||
|  |         counts = sum(count_parameters(x, None) for x in models_or_parameters) | ||||||
|     else: |     else: | ||||||
|         counts = np.sum(np.prod(v.size()) for v in model_or_parameters) |         counts = sum(np.prod(v.size()) for v in model_or_parameters) | ||||||
|     if unit.lower() == "mb": |     if unit.lower() == "kb" or unit.lower() == "k": | ||||||
|         counts /= 1e6 |         counts /= 2 ** 10  # changed from 1e3 to 2^10 | ||||||
|     elif unit.lower() == "kb": |     elif unit.lower() == "mb" or unit.lower() == "m": | ||||||
|         counts /= 1e3 |         counts /= 2 ** 20  # changed from 1e6 to 2^20 | ||||||
|     elif unit.lower() == "gb": |     elif unit.lower() == "gb" or unit.lower() == "g": | ||||||
|         counts /= 1e9 |         counts /= 2 ** 30  # changed from 1e9 to 2^30 | ||||||
|     elif unit is not None: |     elif unit is not None: | ||||||
|         raise ValueError("Unknow unit: {:}".format(unit)) |         raise ValueError("Unknow unit: {:}".format(unit)) | ||||||
|     return counts |     return counts | ||||||
|   | |||||||
							
								
								
									
										4
									
								
								tests/test_torch.sh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								tests/test_torch.sh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | # bash ./tests/test_torch.sh | ||||||
|  |  | ||||||
|  | pytest ./tests/test_torch_gpu_bugs.py::test_create -s | ||||||
|  | CUDA_VISIBLE_DEVICES="" pytest ./tests/test_torch_gpu_bugs.py::test_load -s | ||||||
							
								
								
									
										43
									
								
								tests/test_torch_gpu_bugs.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								tests/test_torch_gpu_bugs.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | |||||||
|  | ##################################################### | ||||||
|  | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 # | ||||||
|  | ##################################################### | ||||||
|  | # pytest ./tests/test_torch_gpu_bugs.py::test_create | ||||||
|  | # | ||||||
|  | # CUDA_VISIBLE_DEVICES="" pytest ./tests/test_torch_gpu_bugs.py::test_load | ||||||
|  | ##################################################### | ||||||
|  | import os, sys, time, torch | ||||||
|  | import pickle | ||||||
|  | import tempfile | ||||||
|  | from pathlib import Path | ||||||
|  |  | ||||||
|  | lib_dir = (Path(__file__).parent / ".." / "lib").resolve() | ||||||
|  | print("library path: {:}".format(lib_dir)) | ||||||
|  | if str(lib_dir) not in sys.path: | ||||||
|  |     sys.path.insert(0, str(lib_dir)) | ||||||
|  |  | ||||||
|  | from trade_models.quant_transformer import QuantTransformer | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_create(): | ||||||
|  |     """Test the basic quant-model.""" | ||||||
|  |     if not torch.cuda.is_available(): | ||||||
|  |         return | ||||||
|  |     quant_model = QuantTransformer(GPU=0) | ||||||
|  |     temp_dir = lib_dir / ".." / "tests" / ".pytest_cache" | ||||||
|  |     temp_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |     temp_file = temp_dir / "quant-model.pkl" | ||||||
|  |     with temp_file.open("wb") as f: | ||||||
|  |         # quant_model.to(None) | ||||||
|  |         quant_model.to("cpu") | ||||||
|  |         # del quant_model.model | ||||||
|  |         # del quant_model.train_optimizer | ||||||
|  |         pickle.dump(quant_model, f) | ||||||
|  |     print("save into {:}".format(temp_file)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_load(): | ||||||
|  |     temp_file = lib_dir / ".." / "tests" / ".pytest_cache" / "quant-model.pkl" | ||||||
|  |     with temp_file.open("rb") as f: | ||||||
|  |         model = pickle.load(f) | ||||||
|  |         print(model.model) | ||||||
|  |         print(model.train_optimizer) | ||||||
		Reference in New Issue
	
	Block a user