Finalize example vis codes
This commit is contained in:
parent
77cab08d60
commit
5eb18e8adb
@ -1,7 +1,7 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.02 #
|
||||
############################################################################
|
||||
# CUDA_VISIBLE_DEVICES=0 python exps/LFNA/vis-synthetic.py #
|
||||
# python exps/LFNA/vis-synthetic.py #
|
||||
############################################################################
|
||||
import os, sys, copy, random
|
||||
import torch
|
||||
@ -83,7 +83,7 @@ def find_max(cur, others):
|
||||
def compare_cl(save_dir):
|
||||
save_dir = Path(str(save_dir))
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
dynamic_env, function = create_example_v1(
|
||||
dynamic_env, cl_function = create_example_v1(
|
||||
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||
timestamp_config=dict(num=200),
|
||||
num_per_task=1000,
|
||||
@ -91,7 +91,6 @@ def compare_cl(save_dir):
|
||||
|
||||
models = dict()
|
||||
|
||||
cl_function = copy.deepcopy(function)
|
||||
cl_function.set_timestamp(0)
|
||||
cl_xaxis_min = None
|
||||
cl_xaxis_max = None
|
||||
@ -99,23 +98,15 @@ def compare_cl(save_dir):
|
||||
all_data = OrderedDict()
|
||||
|
||||
for idx, (timestamp, dataset) in enumerate(tqdm(dynamic_env, ncols=50)):
|
||||
xaxis_all = dataset[:, 0].numpy()
|
||||
xaxis_all = dataset[0][:, 0].numpy()
|
||||
yaxis_all = dataset[1][:, 0].numpy()
|
||||
current_data = dict()
|
||||
|
||||
function.set_timestamp(timestamp)
|
||||
yaxis_all = function.noise_call(xaxis_all)
|
||||
current_data["lfna_xaxis_all"] = xaxis_all
|
||||
current_data["lfna_yaxis_all"] = yaxis_all
|
||||
|
||||
# compute cl-min
|
||||
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
|
||||
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std())
|
||||
"""
|
||||
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
|
||||
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
|
||||
current_data["cl_xaxis_all"] = cl_xaxis_all
|
||||
current_data["cl_yaxis_all"] = cl_yaxis_all
|
||||
"""
|
||||
all_data[timestamp] = current_data
|
||||
|
||||
global_cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.1)
|
||||
@ -170,10 +161,12 @@ def compare_cl(save_dir):
|
||||
xdir=save_dir
|
||||
)
|
||||
)
|
||||
video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(base_cmd, xdir=save_dir)
|
||||
video_cmd = "{:} -pix_fmt yuv420p {xdir}/compare-cl.mp4".format(
|
||||
base_cmd, xdir=save_dir
|
||||
)
|
||||
print(video_cmd + "\n")
|
||||
os.system(video_cmd)
|
||||
# os.system("{:} {xdir}/vis.webm".format(base_cmd, xdir=save_dir))
|
||||
os.system("{:} -pix_fmt yuv420p {xdir}/vis.webm".format(base_cmd, xdir=save_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -5,7 +5,8 @@ from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
|
||||
from .SearchDatasetWrap import SearchDataset
|
||||
|
||||
from .math_base_funcs import QuadraticFunc, CubicFunc, QuarticFunc
|
||||
from .math_adv_funcs import DynamicQuadraticFunc, ConstantFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc
|
||||
from .math_adv_funcs import ComposedSinFunc
|
||||
|
||||
from .synthetic_utils import TimeStamp
|
||||
|
@ -14,41 +14,6 @@ from .math_base_funcs import QuadraticFunc
|
||||
from .math_base_funcs import QuarticFunc
|
||||
|
||||
|
||||
class DynamicQuadraticFunc(FitFunc):
|
||||
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
||||
The a, b, and c is a function of timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, list_of_points=None):
|
||||
super(DynamicQuadraticFunc, self).__init__(3, list_of_points)
|
||||
self._timestamp = None
|
||||
|
||||
def __call__(self, x, timestamp=None):
|
||||
self.check_valid()
|
||||
if timestamp is None:
|
||||
timestamp = self._timestamp
|
||||
a = self._params[0](timestamp)
|
||||
b = self._params[1](timestamp)
|
||||
c = self._params[2](timestamp)
|
||||
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
|
||||
return a * x * x + b * x + c
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_timestamp(self, timestamp):
|
||||
self._timestamp = timestamp
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
)
|
||||
|
||||
|
||||
class ConstantFunc(FitFunc):
|
||||
"""The constant function: f(x) = c."""
|
||||
|
||||
|
@ -13,20 +13,20 @@ import torch.utils.data as data
|
||||
class FitFunc(abc.ABC):
|
||||
"""The fit function that outputs f(x) = a * x^2 + b * x + c."""
|
||||
|
||||
def __init__(self, freedom: int, list_of_points=None, _params=None):
|
||||
def __init__(self, freedom: int, list_of_points=None, params=None):
|
||||
self._params = dict()
|
||||
for i in range(freedom):
|
||||
self._params[i] = None
|
||||
self._freedom = freedom
|
||||
if list_of_points is not None and _params is not None:
|
||||
raise ValueError("list_of_points and _params can not be set simultaneously")
|
||||
if list_of_points is not None and params is not None:
|
||||
raise ValueError("list_of_points and params can not be set simultaneously")
|
||||
if list_of_points is not None:
|
||||
self.fit(list_of_points=list_of_points)
|
||||
if _params is not None:
|
||||
self.set(_params)
|
||||
if params is not None:
|
||||
self.set(params)
|
||||
|
||||
def set(self, _params):
|
||||
self._params = copy.deepcopy(_params)
|
||||
def set(self, params):
|
||||
self._params = copy.deepcopy(params)
|
||||
|
||||
def check_valid(self):
|
||||
for key, value in self._params.items():
|
||||
|
66
lib/datasets/math_dynamic_funcs.py
Normal file
66
lib/datasets/math_dynamic_funcs.py
Normal file
@ -0,0 +1,66 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
|
||||
#####################################################
|
||||
import math
|
||||
import abc
|
||||
import copy
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.utils.data as data
|
||||
|
||||
from .math_base_funcs import FitFunc
|
||||
|
||||
|
||||
class DynamicFunc(FitFunc):
|
||||
"""The dynamic quadratic function, where each param is a function."""
|
||||
|
||||
def __init__(self, freedom: int, params=None):
|
||||
super(DynamicFunc, self).__init__(freedom, None, params)
|
||||
self._timestamp = None
|
||||
|
||||
def __call__(self, x, timestamp=None):
|
||||
raise NotImplementedError
|
||||
|
||||
def _getitem(self, x, weights):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_timestamp(self, timestamp):
|
||||
self._timestamp = timestamp
|
||||
|
||||
def noise_call(self, x, timestamp=None, std=0.1):
|
||||
clean_y = self.__call__(x, timestamp)
|
||||
if isinstance(clean_y, np.ndarray):
|
||||
noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)
|
||||
else:
|
||||
raise ValueError("Unkonwn type: {:}".format(type(clean_y)))
|
||||
return noise_y
|
||||
|
||||
|
||||
class DynamicQuadraticFunc(DynamicFunc):
|
||||
"""The dynamic quadratic function that outputs f(x) = a * x^2 + b * x + c.
|
||||
The a, b, and c is a function of timestamp.
|
||||
"""
|
||||
|
||||
def __init__(self, params=None):
|
||||
super(DynamicQuadraticFunc, self).__init__(3, params)
|
||||
|
||||
def __call__(self, x, timestamp=None):
|
||||
self.check_valid()
|
||||
if timestamp is None:
|
||||
timestamp = self._timestamp
|
||||
a = self._params[0](timestamp)
|
||||
b = self._params[1](timestamp)
|
||||
c = self._params[2](timestamp)
|
||||
convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x
|
||||
a, b, c = convert_fn(a), convert_fn(b), convert_fn(c)
|
||||
return a * x * x + b * x + c
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({a} * x^2 + {b} * x + {c}, timestamp={timestamp})".format(
|
||||
name=self.__class__.__name__,
|
||||
a=self._params[0],
|
||||
b=self._params[1],
|
||||
c=self._params[2],
|
||||
timestamp=self._timestamp,
|
||||
)
|
@ -41,6 +41,11 @@ class SyntheticDEnv(data.Dataset):
|
||||
self._mean_functors = mean_functors
|
||||
self._cov_functors = cov_functors
|
||||
|
||||
self._oracle_map = None
|
||||
|
||||
def set_oracle_map(self, functor):
|
||||
self._oracle_map = functor
|
||||
|
||||
def __iter__(self):
|
||||
self._iter_num = 0
|
||||
return self
|
||||
@ -63,7 +68,11 @@ class SyntheticDEnv(data.Dataset):
|
||||
dataset = np.random.multivariate_normal(
|
||||
mean_list, cov_matrix, size=self._num_per_task
|
||||
)
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
if self._oracle_map is None:
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
else:
|
||||
targets = self._oracle_map.noise_call(dataset, timestamp)
|
||||
return timestamp, (torch.Tensor(dataset), torch.Tensor(targets))
|
||||
|
||||
def __len__(self):
|
||||
return len(self._timestamp_generator)
|
||||
|
@ -1,8 +1,9 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
|
||||
#####################################################
|
||||
import copy
|
||||
|
||||
from .math_adv_funcs import DynamicQuadraticFunc
|
||||
from .math_dynamic_funcs import DynamicQuadraticFunc
|
||||
from .math_adv_funcs import ConstantFunc, ComposedSinFunc
|
||||
from .synthetic_env import SyntheticDEnv
|
||||
|
||||
@ -11,7 +12,6 @@ def create_example_v1(
|
||||
timestamp_config=None,
|
||||
num_per_task=5000,
|
||||
):
|
||||
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||
|
||||
@ -32,4 +32,6 @@ def create_example_v1(
|
||||
num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
|
||||
)
|
||||
function.set(function_param)
|
||||
|
||||
dynamic_env.set_oracle_map(copy.deepcopy(function))
|
||||
return dynamic_env, function
|
||||
|
@ -6,3 +6,4 @@ black ./lib/datasets
|
||||
black ./lib/xlayers
|
||||
black ./exps/LFNA
|
||||
black ./exps/trading
|
||||
black ./lib/procedures
|
||||
|
Loading…
Reference in New Issue
Block a user