Update GeMOSA v4
This commit is contained in:
		| @@ -2,9 +2,9 @@ | |||||||
| # Learning to Generate Model One Step Ahead         # | # Learning to Generate Model One Step Ahead         # | ||||||
| ##################################################### | ##################################################### | ||||||
| # python exps/GeMOSA/main.py --env_version v1 --workers 0 | # python exps/GeMOSA/main.py --env_version v1 --workers 0 | ||||||
| # python exps/GeMOSA/main.py --env_version v1 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256 | # python exps/GeMOSA/main.py --env_version v1 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v2 --device cuda --lr 0.002 --hidden_dim 16 --meta_batch 256 | # python exps/GeMOSA/main.py --env_version v2 --lr 0.002 --hidden_dim 16 --meta_batch 256 --device cuda | ||||||
| # python exps/GeMOSA/main.py --env_version v3 --device cuda --lr 0.002 --hidden_dim 32 --meta_batch 256 | # python exps/GeMOSA/main.py --env_version v3 --lr 0.002 --hidden_dim 32 --time_dim 32 --meta_batch 256 --device cuda | ||||||
| ##################################################### | ##################################################### | ||||||
| import sys, time, copy, torch, random, argparse | import sys, time, copy, torch, random, argparse | ||||||
| from tqdm import tqdm | from tqdm import tqdm | ||||||
|   | |||||||
| @@ -3,7 +3,8 @@ | |||||||
| ############################################################################ | ############################################################################ | ||||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | # python exps/GeMOSA/vis-synthetic.py --env_version v1                     # | ||||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | ||||||
| # python exps/GeMOSA/vis-synthetic.py --env_version v2                     # | # python exps/GeMOSA/vis-synthetic.py --env_version v3                     # | ||||||
|  | # python exps/GeMOSA/vis-synthetic.py --env_version v4                     # | ||||||
| ############################################################################ | ############################################################################ | ||||||
| import os, sys, copy, random | import os, sys, copy, random | ||||||
| import torch | import torch | ||||||
| @@ -31,8 +32,8 @@ from xautodl.procedures.metric_utils import MSEMetric | |||||||
|  |  | ||||||
|  |  | ||||||
| def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None): | ||||||
|     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths, label=label) |     cur_ax.scatter([-100], [-100], color=color, linewidths=linewidths[0], label=label) | ||||||
|     cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=1.5, label=None) |     cur_ax.scatter(xs, ys, color=color, alpha=alpha, linewidths=linewidths[1], label=None) | ||||||
|  |  | ||||||
|  |  | ||||||
| def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None): | ||||||
| @@ -186,15 +187,23 @@ def visualize_env(save_dir, version): | |||||||
|         sub_save_dir.mkdir(parents=True, exist_ok=True) |         sub_save_dir.mkdir(parents=True, exist_ok=True) | ||||||
|  |  | ||||||
|     dynamic_env = get_synthetic_env(version=version) |     dynamic_env = get_synthetic_env(version=version) | ||||||
|  |     print("env: {:}".format(dynamic_env)) | ||||||
|  |     print("oracle_map: {:}".format(dynamic_env.oracle_map)) | ||||||
|     allxs, allys = [], [] |     allxs, allys = [], [] | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         allxs.append(allx) |         allxs.append(allx) | ||||||
|         allys.append(ally) |         allys.append(ally) | ||||||
|     allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) |     if dynamic_env.meta_info['task'] == 'regression': | ||||||
|     print("env: {:}".format(dynamic_env)) |         allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1) | ||||||
|     print("oracle_map: {:}".format(dynamic_env.oracle_map)) |         print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) | ||||||
|     print("x - min={:.3f}, max={:.3f}".format(allxs.min().item(), allxs.max().item())) |         print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) | ||||||
|     print("y - min={:.3f}, max={:.3f}".format(allys.min().item(), allys.max().item())) |     elif dynamic_env.meta_info['task'] == 'classification': | ||||||
|  |         allxs = torch.cat(allxs) | ||||||
|  |         print("x[0] - min={:.3f}, max={:.3f}".format(allxs[:,0].min().item(), allxs[:,0].max().item())) | ||||||
|  |         print("x[1] - min={:.3f}, max={:.3f}".format(allxs[:,1].min().item(), allxs[:,1].max().item())) | ||||||
|  |     else: | ||||||
|  |         raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) | ||||||
|  |  | ||||||
|     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): |     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)): | ||||||
|         dpi, width, height = 30, 1800, 1400 |         dpi, width, height = 30, 1800, 1400 | ||||||
|         figsize = width / float(dpi), height / float(dpi) |         figsize = width / float(dpi), height / float(dpi) | ||||||
| @@ -202,19 +211,29 @@ def visualize_env(save_dir, version): | |||||||
|         fig = plt.figure(figsize=figsize) |         fig = plt.figure(figsize=figsize) | ||||||
|  |  | ||||||
|         cur_ax = fig.add_subplot(1, 1, 1) |         cur_ax = fig.add_subplot(1, 1, 1) | ||||||
|         allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() |         if dynamic_env.meta_info['task'] == 'regression': | ||||||
|         plot_scatter(cur_ax, allx, ally, "k", 0.99, 15, "timestamp={:05d}".format(idx)) |             allx, ally = allx[:, 0].numpy(), ally[:, 0].numpy() | ||||||
|  |             plot_scatter(cur_ax, allx, ally, "k", 0.99, (15, 1.5), "timestamp={:05d}".format(idx)) | ||||||
|  |             cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) | ||||||
|  |             cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) | ||||||
|  |         elif dynamic_env.meta_info['task'] == 'classification': | ||||||
|  |             positive, negative = ally == 1, ally == 0 | ||||||
|  |             # plot_scatter(cur_ax, [1], [1], "k", 0.1, 1, "timestamp={:05d}".format(idx)) | ||||||
|  |             plot_scatter(cur_ax, allx[positive,0], allx[positive,1], "r", 0.99, (20, 10), "positive") | ||||||
|  |             plot_scatter(cur_ax, allx[negative,0], allx[negative,1], "g", 0.99, (20, 10), "negative") | ||||||
|  |             cur_ax.set_xlim(round(allxs[:,0].min().item(), 1), round(allxs[:,0].max().item(), 1)) | ||||||
|  |             cur_ax.set_ylim(round(allxs[:,1].min().item(), 1), round(allxs[:,1].max().item(), 1)) | ||||||
|  |         else: | ||||||
|  |             raise ValueError("Unknown task".format(dynamic_env.meta_info['task'])) | ||||||
|  |  | ||||||
|         cur_ax.set_xlabel("X", fontsize=LabelSize) |         cur_ax.set_xlabel("X", fontsize=LabelSize) | ||||||
|         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) |         cur_ax.set_ylabel("Y", rotation=0, fontsize=LabelSize) | ||||||
|         for tick in cur_ax.xaxis.get_major_ticks(): |         for tick in cur_ax.xaxis.get_major_ticks(): | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |                 tick.label.set_fontsize(LabelSize - font_gap) | ||||||
|             tick.label.set_rotation(10) |                 tick.label.set_rotation(10) | ||||||
|         for tick in cur_ax.yaxis.get_major_ticks(): |         for tick in cur_ax.yaxis.get_major_ticks(): | ||||||
|             tick.label.set_fontsize(LabelSize - font_gap) |                 tick.label.set_fontsize(LabelSize - font_gap) | ||||||
|         cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1)) |         cur_ax.legend(loc=1, fontsize=LegendFontsize)    | ||||||
|         cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1)) |  | ||||||
|         cur_ax.legend(loc=1, fontsize=LegendFontsize) |  | ||||||
|  |  | ||||||
|         pdf_save_path = ( |         pdf_save_path = ( | ||||||
|             save_dir |             save_dir | ||||||
|             / "pdf-{:}".format(version) |             / "pdf-{:}".format(version) | ||||||
| @@ -237,7 +256,7 @@ def visualize_env(save_dir, version): | |||||||
|     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) |     os.system("{:} {xdir}/env-{ver}.webm".format(base_cmd, xdir=save_dir, ver=version)) | ||||||
|  |  | ||||||
|  |  | ||||||
| def compare_algs(save_dir, version, alg_dir="./outputs/lfna-synthetic"): | def compare_algs(save_dir, version, alg_dir="./outputs/GeMOSA-synthetic"): | ||||||
|     save_dir = Path(str(save_dir)) |     save_dir = Path(str(save_dir)) | ||||||
|     for substr in ("pdf", "png"): |     for substr in ("pdf", "png"): | ||||||
|         sub_save_dir = save_dir / substr |         sub_save_dir = save_dir / substr | ||||||
|   | |||||||
| @@ -10,5 +10,10 @@ from .math_static_funcs import ( | |||||||
|     ComposedSinSFunc, |     ComposedSinSFunc, | ||||||
|     ComposedCosSFunc, |     ComposedCosSFunc, | ||||||
| ) | ) | ||||||
| from .math_dynamic_funcs import LinearDFunc, QuadraticDFunc, SinQuadraticDFunc | from .math_dynamic_funcs import ( | ||||||
| from .math_dynamic_generator import GaussianDGenerator |     LinearDFunc, | ||||||
|  |     QuadraticDFunc, | ||||||
|  |     SinQuadraticDFunc, | ||||||
|  |     BinaryQuadraticDFunc, | ||||||
|  | ) | ||||||
|  | from .math_dynamic_generator import UniformDGenerator, GaussianDGenerator | ||||||
|   | |||||||
| @@ -20,7 +20,9 @@ class DynamicFunc(MathFunc): | |||||||
|  |  | ||||||
|     def noise_call(self, x, timestamp, std): |     def noise_call(self, x, timestamp, std): | ||||||
|         clean_y = self.__call__(x, timestamp) |         clean_y = self.__call__(x, timestamp) | ||||||
|         if isinstance(clean_y, np.ndarray): |         if std is None: | ||||||
|  |             noise_y = clean_y | ||||||
|  |         elif isinstance(clean_y, np.ndarray): | ||||||
|             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) |             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape) | ||||||
|         else: |         else: | ||||||
|             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) |             raise ValueError("Unkonwn type: {:}".format(type(clean_y))) | ||||||
| @@ -43,7 +45,7 @@ class LinearDFunc(DynamicFunc): | |||||||
|         return a * x + b |         return a * x + b | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * {x} + {b})".format( |         return "({a} * {x} + {b})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
| @@ -69,7 +71,7 @@ class QuadraticDFunc(DynamicFunc): | |||||||
|         return a * x * x + b * x + c |         return a * x * x + b * x + c | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( |         return "({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
| @@ -97,6 +99,39 @@ class SinQuadraticDFunc(DynamicFunc): | |||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( |         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|  |             name="Sin", | ||||||
|  |             a=self._params[0], | ||||||
|  |             b=self._params[1], | ||||||
|  |             c=self._params[2], | ||||||
|  |             x=self.xstr, | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class BinaryQuadraticDFunc(DynamicFunc): | ||||||
|  |     """The dynamic quadratic function that outputs f(x) = a * x[0]^2 + b * x[1] + c >= 0. | ||||||
|  |     The a, b, and c is a function of timestamp. | ||||||
|  |     """ | ||||||
|  |  | ||||||
|  |     def __init__(self, params=None): | ||||||
|  |         super(BinaryQuadraticDFunc, self).__init__(3, params) | ||||||
|  |  | ||||||
|  |     def __call__(self, x, timestamp): | ||||||
|  |         self.check_valid() | ||||||
|  |         a = self._params[0](timestamp) | ||||||
|  |         b = self._params[1](timestamp) | ||||||
|  |         c = self._params[2](timestamp) | ||||||
|  |         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x | ||||||
|  |         a, b, c = convert_fn(a), convert_fn(b), convert_fn(c) | ||||||
|  |         if isinstance(x, np.ndarray) and x.shape[-1] == 2: | ||||||
|  |             results = a * x[..., 0] * x[..., 0] + b * x[..., 1] + c | ||||||
|  |             return (results >= 0).astype(np.int) | ||||||
|  |         else: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "Either the type {:} or the shape is incorrect.".format(type(x)) | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "({a} * {x}[0]^2 + {b} * {x}[1] + {c} >= 0)".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|   | |||||||
| @@ -20,6 +20,37 @@ class DynamicGenerator(abc.ABC): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class UniformDGenerator(DynamicGenerator): | ||||||
|  |     """Generate data from the uniform distribution.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, l_functors, r_functors): | ||||||
|  |         super(UniformDGenerator, self).__init__() | ||||||
|  |         self._ndim = assert_list_tuple(l_functors) | ||||||
|  |         assert self._ndim == assert_list_tuple(r_functors) | ||||||
|  |         self._l_functors = l_functors | ||||||
|  |         self._r_functors = r_functors | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def ndim(self): | ||||||
|  |         return self._ndim | ||||||
|  |  | ||||||
|  |     def output_shape(self): | ||||||
|  |         return (self._ndim,) | ||||||
|  |  | ||||||
|  |     def __call__(self, time, num): | ||||||
|  |         l_list = [functor(time) for functor in self._l_functors] | ||||||
|  |         r_list = [functor(time) for functor in self._r_functors] | ||||||
|  |         values = [] | ||||||
|  |         for l, r in zip(l_list, r_list): | ||||||
|  |             values.append(np.random.uniform(low=l, high=r, size=num)) | ||||||
|  |         return np.stack(values, axis=-1) | ||||||
|  |  | ||||||
|  |     def __repr__(self): | ||||||
|  |         return "{name}({ndim} dims)".format( | ||||||
|  |             name=self.__class__.__name__, ndim=self._ndim | ||||||
|  |         ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class GaussianDGenerator(DynamicGenerator): | class GaussianDGenerator(DynamicGenerator): | ||||||
|     """Generate data from Gaussian distribution.""" |     """Generate data from Gaussian distribution.""" | ||||||
|  |  | ||||||
|   | |||||||
| @@ -47,7 +47,7 @@ class LinearSFunc(StaticFunc): | |||||||
|         return weights[0] * x + weights[1] |         return weights[0] * x + weights[1] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * {x} + {b})".format( |         return "({a} * {x} + {b})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
| @@ -69,7 +69,7 @@ class QuadraticSFunc(StaticFunc): | |||||||
|         return weights[0] * x * x + weights[1] * x + weights[2] |         return weights[0] * x * x + weights[1] * x + weights[2] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * {x}^2 + {b} * {x} + {c})".format( |         return "({a} * {x}^2 + {b} * {x} + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
| @@ -97,7 +97,7 @@ class CubicSFunc(StaticFunc): | |||||||
|         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] |         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3] | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( |         return "({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
| @@ -166,7 +166,7 @@ class ConstantFunc(StaticFunc): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a})".format(name=self.__class__.__name__, a=self._params[0]) |         return "{a}".format(name=self.__class__.__name__, a=self._params[0]) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ComposedSinSFunc(StaticFunc): | class ComposedSinSFunc(StaticFunc): | ||||||
| @@ -188,7 +188,7 @@ class ComposedSinSFunc(StaticFunc): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( |         return "({a} * sin({b} * {x}) + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
| @@ -216,7 +216,7 @@ class ComposedCosSFunc(StaticFunc): | |||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         return "{name}({a} * sin({b} * {x}) + {c})".format( |         return "({a} * sin({b} * {x}) + {c})".format( | ||||||
|             name=self.__class__.__name__, |             name=self.__class__.__name__, | ||||||
|             a=self._params[0], |             a=self._params[0], | ||||||
|             b=self._params[1], |             b=self._params[1], | ||||||
|   | |||||||
| @@ -3,13 +3,13 @@ from .synthetic_utils import TimeStamp | |||||||
| from .synthetic_env import SyntheticDEnv | from .synthetic_env import SyntheticDEnv | ||||||
| from .math_core import LinearSFunc | from .math_core import LinearSFunc | ||||||
| from .math_core import LinearDFunc | from .math_core import LinearDFunc | ||||||
| from .math_core import QuadraticDFunc, SinQuadraticDFunc | from .math_core import QuadraticDFunc, SinQuadraticDFunc, BinaryQuadraticDFunc | ||||||
| from .math_core import ( | from .math_core import ( | ||||||
|     ConstantFunc, |     ConstantFunc, | ||||||
|     ComposedSinSFunc as SinFunc, |     ComposedSinSFunc as SinFunc, | ||||||
|     ComposedCosSFunc as CosFunc, |     ComposedCosSFunc as CosFunc, | ||||||
| ) | ) | ||||||
| from .math_core import GaussianDGenerator | from .math_core import UniformDGenerator, GaussianDGenerator | ||||||
|  |  | ||||||
|  |  | ||||||
| __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"] | ||||||
| @@ -77,8 +77,21 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio | |||||||
|         ) |         ) | ||||||
|         dynamic_env.set_regression() |         dynamic_env.set_regression() | ||||||
|     elif version.lower() == "v4": |     elif version.lower() == "v4": | ||||||
|  |         l_generator = ConstantFunc(-2) | ||||||
|  |         r_generator = ConstantFunc(2) | ||||||
|  |         data_generator = UniformDGenerator([l_generator] * 2, [r_generator] * 2) | ||||||
|  |         time_generator = TimeStamp( | ||||||
|  |             min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode | ||||||
|  |         ) | ||||||
|  |         oracle_map = BinaryQuadraticDFunc( | ||||||
|  |             params={ | ||||||
|  |                 0: SinFunc(params={0: 1, 1: 3, 2: 0}),  # sin(3 * t) | ||||||
|  |                 1: CosFunc(params={0: 1, 1: 6, 2: 0}),  # cos(6 * t) | ||||||
|  |                 2: ConstantFunc(0), | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|         dynamic_env = SyntheticDEnv( |         dynamic_env = SyntheticDEnv( | ||||||
|             data_generator, oracle_map, time_generator, num_per_task, noise=0.05 |             data_generator, oracle_map, time_generator, num_per_task, noise=None | ||||||
|         ) |         ) | ||||||
|         dynamic_env.set_classification(2) |         dynamic_env.set_classification(2) | ||||||
|     else: |     else: | ||||||
|   | |||||||
| @@ -119,10 +119,15 @@ class SyntheticDEnv(data.Dataset): | |||||||
|     def __call__(self, timestamp): |     def __call__(self, timestamp): | ||||||
|         dataset = self._data_generator(timestamp, self._num_per_task) |         dataset = self._data_generator(timestamp, self._num_per_task) | ||||||
|         targets = self._oracle_map.noise_call(dataset, timestamp, self._noise) |         targets = self._oracle_map.noise_call(dataset, timestamp, self._noise) | ||||||
|         return torch.Tensor([timestamp]), ( |         if isinstance(dataset, np.ndarray): | ||||||
|             torch.Tensor(dataset), |             dataset = torch.from_numpy(dataset) | ||||||
|             torch.Tensor(targets), |         else: | ||||||
|         ) |             dataset = torch.Tensor(dataset) | ||||||
|  |         if isinstance(targets, np.ndarray): | ||||||
|  |             targets = torch.from_numpy(targets) | ||||||
|  |         else: | ||||||
|  |             targets = torch.Tensor(targets) | ||||||
|  |         return torch.Tensor([timestamp]), (dataset, targets) | ||||||
|  |  | ||||||
|     def __len__(self): |     def __len__(self): | ||||||
|         return len(self._time_generator) |         return len(self._time_generator) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user