Update the sync data v1
parent da2575cc6c
commit 3ee0d348af

.github/workflows/basic_test.yml (vendored, 6 changes)
@@ -56,12 +56,16 @@ jobs:
           python -m pytest ./tests/test_basic_space.py -s
         shell: bash

-      - name: Test Synthetic Data
+      - name: Test Math
         run: |
           python -m pip install pytest numpy
           python -m pip install parameterized
           python -m pip install torch torchvision
           python --version
           python -m pytest ./tests/test_math*.py -s
+        shell: bash
+
+      - name: Test Synthetic Data
+        run: |
           python -m pytest ./tests/test_synthetic*.py -s
         shell: bash
@@ -222,7 +222,7 @@ def pretrain_v2(base_model, meta_model, criterion, xenv, args, logger):


 def main(args):
-    logger, env_info, model_kwargs = lfna_setup(args)
+    logger, model_kwargs = lfna_setup(args)
     train_env = get_synthetic_env(mode="train", version=args.env_version)
     valid_env = get_synthetic_env(mode="valid", version=args.env_version)
     all_env = get_synthetic_env(mode=None, version=args.env_version)

@@ -11,33 +11,6 @@ from xautodl.datasets.synthetic_core import get_synthetic_env
 def lfna_setup(args):
     prepare_seed(args.rand_seed)
     logger = prepare_logger(args)
-
-    cache_path = (
-        logger.path(None) / ".." / "env-{:}-info.pth".format(args.env_version)
-    ).resolve()
-    if cache_path.exists():
-        env_info = torch.load(cache_path)
-    else:
-        env_info = dict()
-        dynamic_env = get_synthetic_env(version=args.env_version)
-        env_info["total"] = len(dynamic_env)
-        for idx, (timestamp, (_allx, _ally)) in enumerate(tqdm(dynamic_env)):
-            env_info["{:}-timestamp".format(idx)] = timestamp
-            env_info["{:}-x".format(idx)] = _allx
-            env_info["{:}-y".format(idx)] = _ally
-        env_info["dynamic_env"] = dynamic_env
-        torch.save(env_info, cache_path)
-
-    """
-    model_kwargs = dict(
-        config=dict(model_type="simple_mlp"),
-        input_dim=1,
-        output_dim=1,
-        hidden_dim=args.hidden_dim,
-        act_cls="leaky_relu",
-        norm_cls="identity",
-    )
-    """
     model_kwargs = dict(
         config=dict(model_type="norm_mlp"),
         input_dim=1,

@@ -46,7 +19,7 @@ def lfna_setup(args):
         act_cls="gelu",
         norm_cls="layer_norm_1d",
     )
-    return logger, env_info, model_kwargs
+    return logger, model_kwargs


 def train_model(model, dataset, lr, epochs):
@@ -20,14 +20,13 @@ matplotlib.use("agg")
 import matplotlib.pyplot as plt
 import matplotlib.ticker as ticker

-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))

-from models.xcore import get_model
-from datasets.synthetic_core import get_synthetic_env
-from utils.temp_sync import optimize_fn, evaluate_fn
-from procedures.metric_utils import MSEMetric
+from xautodl.models.xcore import get_model
+from xautodl.datasets.synthetic_core import get_synthetic_env
+from xautodl.procedures.metric_utils import MSEMetric


 def plot_scatter(cur_ax, xs, ys, color, alpha, linewidths, label=None):

@@ -181,10 +180,17 @@ def compare_cl(save_dir):

 def visualize_env(save_dir, version):
     save_dir = Path(str(save_dir))
-    save_dir.mkdir(parents=True, exist_ok=True)
+    for substr in ("pdf", "png"):
+        sub_save_dir = save_dir / substr
+        sub_save_dir.mkdir(parents=True, exist_ok=True)

     dynamic_env = get_synthetic_env(version=version)
-    min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
+    # min_t, max_t = dynamic_env.min_timestamp, dynamic_env.max_timestamp
+    allxs, allys = [], []
+    for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
+        allxs.append(allx)
+        allys.append(ally)
+    allxs, allys = torch.cat(allxs).view(-1), torch.cat(allys).view(-1)
     for idx, (timestamp, (allx, ally)) in enumerate(tqdm(dynamic_env, ncols=50)):
         dpi, width, height = 30, 1800, 1400
         figsize = width / float(dpi), height / float(dpi)

@@ -201,21 +207,18 @@ def visualize_env(save_dir, version):
             tick.label.set_rotation(10)
         for tick in cur_ax.yaxis.get_major_ticks():
             tick.label.set_fontsize(LabelSize - font_gap)
-        if version == "v1":
-            cur_ax.set_xlim(-2, 2)
-            cur_ax.set_ylim(-8, 8)
-        elif version == "v2":
-            cur_ax.set_xlim(-10, 10)
-            cur_ax.set_ylim(-60, 60)
+        cur_ax.set_xlim(round(allxs.min().item(), 1), round(allxs.max().item(), 1))
+        cur_ax.set_ylim(round(allys.min().item(), 1), round(allys.max().item(), 1))
         cur_ax.legend(loc=1, fontsize=LegendFontsize)

-        save_path = save_dir / "v{:}-{:05d}".format(version, idx)
-        fig.savefig(str(save_path) + ".pdf", dpi=dpi, bbox_inches="tight", format="pdf")
-        fig.savefig(str(save_path) + ".png", dpi=dpi, bbox_inches="tight", format="png")
+        pdf_save_path = save_dir / "pdf" / "v{:}-{:05d}.pdf".format(version, idx)
+        fig.savefig(str(pdf_save_path), dpi=dpi, bbox_inches="tight", format="pdf")
+        png_save_path = save_dir / "png" / "v{:}-{:05d}.png".format(version, idx)
+        fig.savefig(str(png_save_path), dpi=dpi, bbox_inches="tight", format="png")
         plt.close("all")
     save_dir = save_dir.resolve()
     base_cmd = "ffmpeg -y -i {xdir}/v{version}-%05d.png -vf scale=1800:1400 -pix_fmt yuv420p -vb 5000k".format(
-        xdir=save_dir, version=version
+        xdir=save_dir / "png", version=version
     )
     print(base_cmd)
     os.system("{:} {xdir}/env-{ver}.mp4".format(base_cmd, xdir=save_dir, ver=version))

@@ -371,7 +374,7 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()

-    # visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
+    visualize_env(os.path.join(args.save_dir, "vis-env"), "v1")
     # visualize_env(os.path.join(args.save_dir, "vis-env"), "v2")
-    compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
+    # compare_algs(os.path.join(args.save_dir, "compare-alg"), args.env_version)
     # compare_cl(os.path.join(args.save_dir, "compare-cl"))
@@ -13,7 +13,10 @@ from xautodl.config_utils import dict2config

 # NAS-Bench-201 related module or function
 from xautodl.models import CellStructure, get_cell_based_tiny_net
-from xautodl.procedures import bench_pure_evaluate as pure_evaluate, get_nas_bench_loaders
+from xautodl.procedures import (
+    bench_pure_evaluate as pure_evaluate,
+    get_nas_bench_loaders,
+)
 from nas_201_api import NASBench201API, ArchResults, ResultsCount

 api = NASBench201API(
exps/experimental/test-dynamic.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.08 #
+#####################################################
+# python test-dynamic.py
+#####################################################
+import sys
+from pathlib import Path
+
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
+if str(lib_dir) not in sys.path:
+    sys.path.insert(0, str(lib_dir))
+
+from xautodl.datasets.math_core import ConstantFunc
+from xautodl.datasets.math_core import GaussianDGenerator
+
+mean_generator = ConstantFunc(0)
+cov_generator = ConstantFunc(1)
+
+generator = GaussianDGenerator([mean_generator], [[cov_generator]], (-1, 1))
+generator(0, 10)
@@ -19,9 +19,11 @@ import seaborn as sns
 matplotlib.use("agg")
 import matplotlib.pyplot as plt

-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
+lib_dir = (Path(__file__).parent / ".." / "..").resolve()
+print("LIB-DIR: {:}".format(lib_dir))
 if str(lib_dir) not in sys.path:
     sys.path.insert(0, str(lib_dir))

 from log_utils import time_string
 from nats_bench import create
 from models import get_cell_based_tiny_net
@@ -3,11 +3,7 @@ from copy import deepcopy
 import torchvision.models as models
 from pathlib import Path

-lib_dir = (Path(__file__).parent / ".." / ".." / "lib").resolve()
-if str(lib_dir) not in sys.path:
-    sys.path.insert(0, str(lib_dir))
-
-from utils import weight_watcher
+from xautodl.utils import weight_watcher


 def main():
@@ -17,10 +17,10 @@ from .math_base_funcs import QuarticFunc
 class ConstantFunc(FitFunc):
     """The constant function: f(x) = c."""

-    def __init__(self, constant=None):
+    def __init__(self, constant=None, xstr="x"):
         param = dict()
         param[0] = constant
-        super(ConstantFunc, self).__init__(0, None, param)
+        super(ConstantFunc, self).__init__(0, None, param, xstr)

     def __call__(self, x):
         self.check_valid()

@@ -37,6 +37,34 @@ class ConstantFunc(FitFunc):


 class ComposedSinFunc(FitFunc):
+    """The composed sin function that outputs:
+    f(x) = a * sin( b*x ) + c
+    """
+
+    def __init__(self, params, xstr="x"):
+        super(ComposedSinFunc, self).__init__(3, None, params, xstr)
+
+    def __call__(self, x):
+        self.check_valid()
+        a = self._params[0]
+        b = self._params[1]
+        c = self._params[2]
+        return a * math.sin(b * x) + c
+
+    def _getitem(self, x, weights):
+        raise NotImplementedError
+
+    def __repr__(self):
+        return "{name}({a} * sin({b} * {x}) + {c})".format(
+            name=self.__class__.__name__,
+            a=self._params[0],
+            b=self._params[1],
+            c=self._params[2],
+            x=self.xstr,
+        )
+
+
+class ComposedSinFuncV2(FitFunc):
     """The composed sin function that outputs:
     f(x) = amplitude-scale-of(x) * sin( period-phase-shift-of(x) )
     - the amplitude scale is a quadratic function of x

@@ -44,7 +72,7 @@ class ComposedSinFunc(FitFunc):
     """

     def __init__(self, **kwargs):
-        super(ComposedSinFunc, self).__init__(0, None)
+        super(ComposedSinFuncV2, self).__init__(0, None)
         self.fit(**kwargs)

     def __call__(self, x):
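For reference, a minimal usage sketch of the new ComposedSinFunc, based only on the constructor and __call__ added in this hunk (the params dict maps {0: a, 1: b, 2: c} onto f(x) = a * sin(b * x) + c; the import path follows the xautodl package layout used elsewhere in this commit):

# Sketch only, not part of the commit.
from xautodl.datasets.math_core import ComposedSinFunc

fn = ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}, xstr="t")
print(fn)       # expected: ComposedSinFunc(2.0 * sin(1.0 * t) + 2.2)
print(fn(0.0))  # expected: 2.2, since sin(0) == 0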
@@ -5,15 +5,13 @@ import math
 import abc
 import copy
 import numpy as np
-from typing import Optional
 import torch
-import torch.utils.data as data


 class FitFunc(abc.ABC):
     """The fit function that outputs f(x) = a * x^2 + b * x + c."""

-    def __init__(self, freedom: int, list_of_points=None, params=None):
+    def __init__(self, freedom: int, list_of_points=None, params=None, xstr="x"):
         self._params = dict()
         for i in range(freedom):
             self._params[i] = None

@@ -24,6 +22,7 @@ class FitFunc(abc.ABC):
             self.fit(list_of_points=list_of_points)
         if params is not None:
             self.set(params)
+        self._xstr = str(xstr)

     def set(self, params):
         self._params = copy.deepcopy(params)

@@ -33,6 +32,13 @@ class FitFunc(abc.ABC):
             if value is None:
                 raise ValueError("The {:} is None".format(key))

+    @property
+    def xstr(self):
+        return self._xstr
+
+    def reset_xstr(self, xstr):
+        self._xstr = str(xstr)
+
     @abc.abstractmethod
     def __call__(self, x):
         raise NotImplementedError

@@ -106,8 +112,8 @@ class FitFunc(abc.ABC):
 class LinearFunc(FitFunc):
     """The linear function that outputs f(x) = a * x + b."""

-    def __init__(self, list_of_points=None, params=None):
-        super(LinearFunc, self).__init__(2, list_of_points, params)
+    def __init__(self, list_of_points=None, params=None, xstr="x"):
+        super(LinearFunc, self).__init__(2, list_of_points, params, xstr)

     def __call__(self, x):
         self.check_valid()

@@ -117,18 +123,19 @@ class LinearFunc(FitFunc):
         return weights[0] * x + weights[1]

     def __repr__(self):
-        return "{name}({a} * x + {b})".format(
+        return "{name}({a} * {x} + {b})".format(
             name=self.__class__.__name__,
             a=self._params[0],
             b=self._params[1],
+            x=self.xstr,
         )


 class QuadraticFunc(FitFunc):
     """The quadratic function that outputs f(x) = a * x^2 + b * x + c."""

-    def __init__(self, list_of_points=None, params=None):
-        super(QuadraticFunc, self).__init__(3, list_of_points, params)
+    def __init__(self, list_of_points=None, params=None, xstr="x"):
+        super(QuadraticFunc, self).__init__(3, list_of_points, params, xstr)

     def __call__(self, x):
         self.check_valid()

@@ -138,11 +145,12 @@ class QuadraticFunc(FitFunc):
         return weights[0] * x * x + weights[1] * x + weights[2]

     def __repr__(self):
-        return "{name}({a} * x^2 + {b} * x + {c})".format(
+        return "{name}({a} * {x}^2 + {b} * {x} + {c})".format(
             name=self.__class__.__name__,
             a=self._params[0],
             b=self._params[1],
             c=self._params[2],
+            x=self.xstr,
         )


@@ -165,12 +173,13 @@ class CubicFunc(FitFunc):
         return weights[0] * x ** 3 + weights[1] * x ** 2 + weights[2] * x + weights[3]

     def __repr__(self):
-        return "{name}({a} * x^3 + {b} * x^2 + {c} * x + {d})".format(
+        return "{name}({a} * {x}^3 + {b} * {x}^2 + {c} * {x} + {d})".format(
             name=self.__class__.__name__,
             a=self._params[0],
             b=self._params[1],
             c=self._params[2],
             d=self._params[3],
+            x=self.xstr,
         )


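A small illustration of what threading xstr through FitFunc buys, sketched from the LinearFunc code above (not part of the commit): the symbol used by __repr__ becomes configurable, which is what later lets the dynamic functions describe their coefficients as functions of t. QuadraticFunc and CubicFunc gain the same substitution in their __repr__.

# Sketch only; assumes LinearFunc is exported from math_core as the other hunks do.
from xautodl.datasets.math_core import LinearFunc

line = LinearFunc(params={0: 2.0, 1: 1.0})            # default xstr="x"
print(line)        # expected: LinearFunc(2.0 * x + 1.0)
print(line(3.0))   # 2.0 * 3.0 + 1.0 = 7.0

tline = LinearFunc(params={0: 2.0, 1: 1.0}, xstr="t")
print(tline)       # expected: LinearFunc(2.0 * t + 1.0)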
@@ -6,3 +6,4 @@ from .math_dynamic_funcs import DynamicLinearFunc
 from .math_dynamic_funcs import DynamicQuadraticFunc
 from .math_adv_funcs import ConstantFunc
 from .math_adv_funcs import ComposedSinFunc
+from .math_dynamic_generator import GaussianDGenerator
@@ -15,20 +15,19 @@ from .math_base_funcs import FitFunc
 class DynamicFunc(FitFunc):
     """The dynamic quadratic function, where each param is a function."""

-    def __init__(self, freedom: int, params=None):
-        super(DynamicFunc, self).__init__(freedom, None, params)
-        self._timestamp = None
+    def __init__(self, freedom: int, params=None, xstr="x"):
+        if params is not None:
+            for param in params:
+                param.reset_xstr("t") if isinstance(param, FitFunc) else None
+        super(DynamicFunc, self).__init__(freedom, None, params, xstr)

-    def __call__(self, x, timestamp=None):
+    def __call__(self, x, timestamp):
         raise NotImplementedError

     def _getitem(self, x, weights):
         raise NotImplementedError

-    def set_timestamp(self, timestamp):
-        self._timestamp = timestamp
-
-    def noise_call(self, x, timestamp=None, std=0.1):
+    def noise_call(self, x, timestamp, std):
         clean_y = self.__call__(x, timestamp)
         if isinstance(clean_y, np.ndarray):
             noise_y = clean_y + np.random.normal(scale=std, size=clean_y.shape)

@@ -42,13 +41,10 @@ class DynamicLinearFunc(DynamicFunc):
     The a and b is a function of timestamp.
     """

-    def __init__(self, params=None):
-        super(DynamicLinearFunc, self).__init__(3, params)
+    def __init__(self, params=None, xstr="x"):
+        super(DynamicLinearFunc, self).__init__(3, params, xstr)

-    def __call__(self, x, timestamp=None):
-        self.check_valid()
-        if timestamp is None:
-            timestamp = self._timestamp
+    def __call__(self, x, timestamp):
         a = self._params[0](timestamp)
         b = self._params[1](timestamp)
         convert_fn = lambda x: x[-1] if isinstance(x, (tuple, list)) else x

@@ -56,11 +52,11 @@ class DynamicLinearFunc(DynamicFunc):
         return a * x + b

     def __repr__(self):
-        return "{name}({a} * x + {b}, timestamp={timestamp})".format(
+        return "{name}({a} * {x} + {b})".format(
             name=self.__class__.__name__,
             a=self._params[0],
             b=self._params[1],
-            timestamp=self._timestamp,
+            x=self.xstr,
         )


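A hedged sketch of the reworked DynamicLinearFunc, combining the pieces above: each coefficient is itself a function that gets evaluated at the timestamp, and timestamp is now a required argument of __call__.

# Sketch only, not part of the commit.
from xautodl.datasets.math_core import ConstantFunc, ComposedSinFunc, DynamicLinearFunc

fn = DynamicLinearFunc(
    params={
        0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),  # a(t) = 2.0 * sin(t) + 2.2
        1: ConstantFunc(0.9),                                 # b(t) = 0.9
    }
)
# y = a(t) * x + b(t); at t = 0: a = 2.2, b = 0.9
print(fn(1.5, timestamp=0.0))  # expected: 2.2 * 1.5 + 0.9 = 4.2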
xautodl/datasets/math_dynamic_generator.py (new file, 58 lines)
@@ -0,0 +1,58 @@
+#####################################################
+# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.03 #
+#####################################################
+import abc
+import numpy as np
+
+
+def assert_list_tuple(x):
+    assert isinstance(x, (list, tuple))
+    return len(x)
+
+
+class DynamicGenerator(abc.ABC):
+    """The dynamic quadratic function, where each param is a function."""
+
+    def __init__(self):
+        self._ndim = None
+
+    def __call__(self, time, num):
+        raise NotImplementedError
+
+
+class GaussianDGenerator(DynamicGenerator):
+    def __init__(self, mean_functors, cov_functors, trunc=(-1, 1)):
+        super(GaussianDGenerator, self).__init__()
+        self._ndim = assert_list_tuple(mean_functors)
+        assert self._ndim == len(
+            cov_functors
+        ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
+        assert_list_tuple(cov_functors)
+        for cov_functor in cov_functors:
+            assert self._ndim == assert_list_tuple(
+                cov_functor
+            ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
+        assert (
+            isinstance(trunc, (list, tuple)) and len(trunc) == 2 and trunc[0] < trunc[1]
+        )
+        self._mean_functors = mean_functors
+        self._cov_functors = cov_functors
+        if trunc is not None:
+            assert assert_list_tuple(trunc) == 2 and trunc[0] < trunc[1]
+        self._trunc = trunc
+
+    def __call__(self, time, num):
+        mean_list = [functor(time) for functor in self._mean_functors]
+        cov_matrix = [
+            [abs(cov_gen(time)) for cov_gen in cov_functor]
+            for cov_functor in self._cov_functors
+        ]
+        values = np.random.multivariate_normal(mean_list, cov_matrix, size=num)
+        if self._trunc is not None:
+            np.clip(values, self._trunc[0], self._trunc[1], out=values)
+        return values
+
+    def __repr__(self):
+        return "{name}({ndim} dims, trunc={trunc})".format(
+            name=self.__class__.__name__, ndim=self._ndim, trunc=self._trunc
+        )
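A short sketch of how the new generator is exercised (this mirrors the exps/experimental/test-dynamic.py script added above): the mean and covariance entries are functors of time, and the returned array has shape (num, ndim), clipped to the trunc range.

# Sketch only, not part of the commit.
import numpy as np
from xautodl.datasets.math_core import ConstantFunc, GaussianDGenerator

mean_generator = ConstantFunc(0)
cov_generator = ConstantFunc(1)
generator = GaussianDGenerator([mean_generator], [[cov_generator]], trunc=(-1, 1))

samples = generator(time=0.0, num=10)  # np.ndarray with shape (10, 1)
assert samples.shape == (10, 1)
assert np.all(samples >= -1) and np.all(samples <= 1)
print(generator)                       # expected: GaussianDGenerator(1 dims, trunc=(-1, 1))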
@@ -1,13 +1,14 @@
 #####################################################
 # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.05 #
 #####################################################
+import math
 from .synthetic_utils import TimeStamp
-from .synthetic_env import EnvSampler
 from .synthetic_env import SyntheticDEnv
 from .math_core import LinearFunc
 from .math_core import DynamicLinearFunc
 from .math_core import DynamicQuadraticFunc
 from .math_core import ConstantFunc, ComposedSinFunc
+from .math_core import GaussianDGenerator

 __all__ = ["TimeStamp", "SyntheticDEnv", "get_synthetic_env"]

@@ -17,42 +18,21 @@ def get_synthetic_env(total_timestamp=1000, num_per_task=1000, mode=None, versio
     if version == "v1":
         mean_generator = ConstantFunc(0)
         std_generator = ConstantFunc(1)
-    elif version == "v2":
-        mean_generator = ComposedSinFunc()
-        std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=1.5)
-    else:
-        raise ValueError("Unknown version: {:}".format(version))
-    dynamic_env = SyntheticDEnv(
-        [mean_generator],
-        [[std_generator]],
-        num_per_task=num_per_task,
-        timestamp_config=dict(
-            min_timestamp=-0.5, max_timestamp=1.5, num=total_timestamp, mode=mode
-        ),
-    )
-    if version == "v1":
-        function = DynamicLinearFunc()
-        function_param = dict()
-        function_param[0] = ComposedSinFunc(
-            amplitude_scale=ConstantFunc(3.0),
-            num_sin_phase=9,
-            sin_speed_use_power=False,
-        )
-        function_param[1] = ConstantFunc(constant=0.9)
-    elif version == "v2":
-        function = DynamicQuadraticFunc()
-        function_param = dict()
-        function_param[0] = ComposedSinFunc(
-            num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
-        )
-        function_param[1] = ConstantFunc(constant=0.9)
-        function_param[2] = ComposedSinFunc(
-            num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
-        )
+        data_generator = GaussianDGenerator(
+            [mean_generator], [[std_generator]], (-2, 2)
+        )
+        time_generator = TimeStamp(
+            min_timestamp=0, max_timestamp=math.pi * 6, num=total_timestamp, mode=mode
+        )
+        oracle_map = DynamicLinearFunc(
+            params={
+                0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
+                1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
+            }
+        )
+        dynamic_env = SyntheticDEnv(
+            data_generator, oracle_map, time_generator, num_per_task
+        )
     else:
         raise ValueError("Unknown version: {:}".format(version))
-
-    function.set(function_param)
-    # dynamic_env.set_oracle_map(copy.deepcopy(function))
-    dynamic_env.set_oracle_map(function)
     return dynamic_env
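To see what the rewired "v1" pipeline produces end to end, a hedged usage sketch built from this hunk together with the SyntheticDEnv changes below: every item is a (timestamp, (x, y)) pair, where x comes from the GaussianDGenerator and y = oracle_map.noise_call(x, t, noise).

# Sketch only, not part of the commit.
from xautodl.datasets.synthetic_core import get_synthetic_env

env = get_synthetic_env(total_timestamp=50, num_per_task=128, mode="train", version="v1")
for timestamp, (allx, ally) in env:
    # timestamp: torch.Tensor of shape (1,); allx and ally: torch.Tensor of shape (128, 1)
    print(timestamp.item(), allx.shape, ally.shape)
    break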
@@ -1,15 +1,9 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
 import math
 import random
-import numpy as np
 from typing import List, Optional, Dict
 import torch
 import torch.utils.data as data

-from .synthetic_utils import TimeStamp
-

 def is_list_tuple(x):
     return isinstance(x, (tuple, list))

@@ -38,46 +32,33 @@ class SyntheticDEnv(data.Dataset):

     def __init__(
         self,
-        mean_functors: List[data.Dataset],
-        cov_functors: List[List[data.Dataset]],
+        data_generator,
+        oracle_map,
+        time_generator,
         num_per_task: int = 5000,
-        timestamp_config: Optional[Dict] = None,
-        mode: Optional[str] = None,
-        timestamp_noise_scale: float = 0.3,
+        noise: float = 0.1,
     ):
-        self._ndim = len(mean_functors)
-        assert self._ndim == len(
-            cov_functors
-        ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functors))
-        for cov_functor in cov_functors:
-            assert self._ndim == len(
-                cov_functor
-            ), "length does not match {:} vs. {:}".format(self._ndim, len(cov_functor))
+        self._data_generator = data_generator
+        self._time_generator = time_generator
+        self._oracle_map = oracle_map
         self._num_per_task = num_per_task
-        if timestamp_config is None:
-            timestamp_config = dict(mode=mode)
-        elif "mode" not in timestamp_config:
-            timestamp_config["mode"] = mode
-
-        self._timestamp_generator = TimeStamp(**timestamp_config)
-        self._timestamp_noise_scale = timestamp_noise_scale
-
-        self._mean_functors = mean_functors
-        self._cov_functors = cov_functors
-
-        self._oracle_map = None
+        self._noise = noise

     @property
     def min_timestamp(self):
-        return self._timestamp_generator.min_timestamp
+        return self._time_generator.min_timestamp

     @property
     def max_timestamp(self):
-        return self._timestamp_generator.max_timestamp
+        return self._time_generator.max_timestamp

     @property
-    def timestamp_interval(self):
-        return self._timestamp_generator.interval
+    def time_interval(self):
+        return self._time_generator.interval
+
+    @property
+    def mode(self):
+        return self._time_generator.mode

     def random_timestamp(self, min_timestamp=None, max_timestamp=None):
         if min_timestamp is None:

@@ -89,16 +70,13 @@ class SyntheticDEnv(data.Dataset):
     def get_timestamp(self, index):
         if index is None:
             timestamps = []
-            for index in range(len(self._timestamp_generator)):
-                timestamps.append(self._timestamp_generator[index][1])
+            for index in range(len(self._time_generator)):
+                timestamps.append(self._time_generator[index][1])
             return tuple(timestamps)
         else:
-            index, timestamp = self._timestamp_generator[index]
+            index, timestamp = self._time_generator[index]
             return timestamp

-    def set_oracle_map(self, functor):
-        self._oracle_map = functor
-
     def __iter__(self):
         self._iter_num = 0
         return self

@@ -111,7 +89,7 @@ class SyntheticDEnv(data.Dataset):

     def __getitem__(self, index):
         assert 0 <= index < len(self), "{:} is not in [0, {:})".format(index, len(self))
-        index, timestamp = self._timestamp_generator[index]
+        index, timestamp = self._time_generator[index]
         return self.__call__(timestamp)

     def seq_call(self, timestamps):

@@ -122,52 +100,24 @@ class SyntheticDEnv(data.Dataset):
         return zip_sequence(xdata)

     def __call__(self, timestamp):
-        mean_list = [functor(timestamp) for functor in self._mean_functors]
-        cov_matrix = [
-            [abs(cov_gen(timestamp)) for cov_gen in cov_functor]
-            for cov_functor in self._cov_functors
-        ]
-
-        dataset = np.random.multivariate_normal(
-            mean_list, cov_matrix, size=self._num_per_task
-        )
-        if self._oracle_map is None:
-            return torch.Tensor([timestamp]), torch.Tensor(dataset)
-        else:
-            targets = self._oracle_map.noise_call(dataset, timestamp)
-            return torch.Tensor([timestamp]), (
-                torch.Tensor(dataset),
-                torch.Tensor(targets),
-            )
+        dataset = self._data_generator(timestamp, self._num_per_task)
+        targets = self._oracle_map.noise_call(dataset, timestamp, self._noise)
+        return torch.Tensor([timestamp]), (
+            torch.Tensor(dataset),
+            torch.Tensor(targets),
+        )

     def __len__(self):
-        return len(self._timestamp_generator)
+        return len(self._time_generator)

     def __repr__(self):
         return "{name}({cur_num:}/{total} elements, ndim={ndim}, num_per_task={num_per_task}, range=[{xrange_min:.5f}~{xrange_max:.5f}], mode={mode})".format(
             name=self.__class__.__name__,
             cur_num=len(self),
-            total=len(self._timestamp_generator),
+            total=len(self._time_generator),
             ndim=self._ndim,
             num_per_task=self._num_per_task,
             xrange_min=self.min_timestamp,
             xrange_max=self.max_timestamp,
-            mode=self._timestamp_generator.mode,
+            mode=self.mode,
         )

-
-class EnvSampler:
-    def __init__(self, env, batch, enlarge):
-        indexes = list(range(len(env)))
-        self._indexes = indexes * enlarge
-        self._batch = batch
-        self._iterations = len(self._indexes) // self._batch
-
-    def __iter__(self):
-        random.shuffle(self._indexes)
-        for it in range(self._iterations):
-            indexes = self._indexes[it * self._batch : (it + 1) * self._batch]
-            yield indexes
-
-    def __len__(self):
-        return self._iterations
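For clarity, a hedged sketch of the new SyntheticDEnv contract in isolation (the three collaborators are wired exactly as get_synthetic_env does above; the local variable names are mine):

# Sketch only, not part of the commit.
import math
from xautodl.datasets.math_core import ConstantFunc, ComposedSinFunc, DynamicLinearFunc, GaussianDGenerator
from xautodl.datasets.synthetic_utils import TimeStamp
from xautodl.datasets.synthetic_env import SyntheticDEnv

data_generator = GaussianDGenerator([ConstantFunc(0)], [[ConstantFunc(1)]], (-2, 2))
oracle_map = DynamicLinearFunc(
    params={
        0: ComposedSinFunc(params={0: 2.0, 1: 1.0, 2: 2.2}),
        1: ComposedSinFunc(params={0: 1.5, 1: 0.4, 2: 2.2}),
    }
)
time_generator = TimeStamp(min_timestamp=0, max_timestamp=math.pi * 6, num=100, mode=None)

env = SyntheticDEnv(data_generator, oracle_map, time_generator, num_per_task=64, noise=0.1)
timestamp, (x, y) = env[0]  # x and y are torch.Tensor objects of shape (64, 1)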
@@ -1,72 +0,0 @@
-#####################################################
-# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
-#####################################################
-import copy
-
-from .math_dynamic_funcs import DynamicLinearFunc, DynamicQuadraticFunc
-from .math_adv_funcs import ConstantFunc, ComposedSinFunc
-from .synthetic_env import SyntheticDEnv
-
-
-def create_example(timestamp_config=None, num_per_task=5000, indicator="v1"):
-    if indicator == "v1":
-        return create_example_v1(timestamp_config, num_per_task)
-    elif indicator == "v2":
-        return create_example_v2(timestamp_config, num_per_task)
-    else:
-        raise ValueError("Unkonwn indicator: {:}".format(indicator))
-
-
-def create_example_v1(
-    timestamp_config=None,
-    num_per_task=5000,
-):
-    mean_generator = ComposedSinFunc()
-    std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
-
-    dynamic_env = SyntheticDEnv(
-        [mean_generator],
-        [[std_generator]],
-        num_per_task=num_per_task,
-        timestamp_config=timestamp_config,
-    )
-
-    function = DynamicQuadraticFunc()
-    function_param = dict()
-    function_param[0] = ComposedSinFunc(
-        num_sin_phase=4, phase_shift=1.0, max_amplitude=1.0
-    )
-    function_param[1] = ConstantFunc(constant=0.9)
-    function_param[2] = ComposedSinFunc(
-        num_sin_phase=5, phase_shift=0.4, max_amplitude=0.9
-    )
-    function.set(function_param)
-
-    dynamic_env.set_oracle_map(copy.deepcopy(function))
-    return dynamic_env, function
-
-
-def create_example_v2(
-    timestamp_config=None,
-    num_per_task=5000,
-):
-    mean_generator = ConstantFunc(0)
-    std_generator = ConstantFunc(1)
-
-    dynamic_env = SyntheticDEnv(
-        [mean_generator],
-        [[std_generator]],
-        num_per_task=num_per_task,
-        timestamp_config=timestamp_config,
-    )
-
-    function = DynamicLinearFunc()
-    function_param = dict()
-    function_param[0] = ComposedSinFunc(
-        amplitude_scale=ConstantFunc(1.0), period_phase_shift=ConstantFunc(1.0)
-    )
-    function_param[1] = ConstantFunc(constant=0.9)
-    function.set(function_param)
-
-    dynamic_env.set_oracle_map(copy.deepcopy(function))
-    return dynamic_env, function
@@ -13,11 +13,11 @@ class UnifiedSplit:
     """A class to unify the split strategy."""

     def __init__(self, total_num, mode):
-        # Training Set 60%
-        num_of_train = int(total_num * 0.6)
-        # Validation Set 20%
-        num_of_valid = int(total_num * 0.2)
-        # Test Set 20%
+        # Training Set 65%
+        num_of_train = int(total_num * 0.65)
+        # Validation Set 05%
+        num_of_valid = int(total_num * 0.05)
+        # Test Set 30%
         num_of_set = total_num - num_of_train - num_of_valid
         all_indexes = list(range(total_num))
         if mode is None:

@@ -28,6 +28,8 @@ class UnifiedSplit:
             self._indexes = all_indexes[num_of_train : num_of_train + num_of_valid]
         elif mode.lower() in ("test", "testing"):
             self._indexes = all_indexes[num_of_train + num_of_valid :]
+        elif mode.lower() in ("trainval", "trainvalidation"):
+            self._indexes = all_indexes[: num_of_train + num_of_valid]
         else:
             raise ValueError("Unkonwn mode of {:}".format(mode))
         self._all_indexes = all_indexes
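A quick numeric check of the new split for a hypothetical total_num of 1000 (a sketch, not part of the commit): 65% train, 5% validation, the remaining 30% test, and the new "trainval" mode takes the first 70%.

total_num = 1000
num_of_train = int(total_num * 0.65)                   # 650
num_of_valid = int(total_num * 0.05)                   # 50
num_of_test = total_num - num_of_train - num_of_valid  # 300
num_of_trainval = num_of_train + num_of_valid          # 700, the new "trainval" slice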