This commit is contained in:
D-X-Y 2021-05-26 01:53:44 -07:00
parent 30fb8fad67
commit 299c8a085b
12 changed files with 137 additions and 115 deletions

View File

@ -1,14 +1,18 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2021.04 #
#####################################################
# python exps/LFNA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/LFNA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
# python exps/GeMOSA/basic-prev.py --env_version v1 --prev_time 5 --hidden_dim 16 --epochs 500 --init_lr 0.1
# python exps/GeMOSA/basic-prev.py --env_version v2 --hidden_dim 16 --epochs 1000 --init_lr 0.05
#####################################################
import sys, time, copy, torch, random, argparse
from tqdm import tqdm
from copy import deepcopy
from pathlib import Path
lib_dir = (Path(__file__).parent / ".." / "..").resolve()
if str(lib_dir) not in sys.path:
sys.path.insert(0, str(lib_dir))
from xautodl.procedures import (
prepare_seed,
prepare_logger,
@ -38,9 +42,9 @@ def subsample(historical_x, historical_y, maxn=10000):
def main(args):
logger, env_info, model_kwargs = lfna_setup(args)
logger, model_kwargs = lfna_setup(args)
w_container_per_epoch = dict()
w_containers = dict()
per_timestamp_time, start_time = AverageMeter(), time.time()
for idx in range(args.prev_time, env_info["total"]):
@ -111,7 +115,7 @@ def main(args):
save_path = logger.path(None) / "{:04d}-{:04d}.pth".format(
idx, env_info["total"]
)
w_container_per_epoch[idx] = model.get_w_container().no_grad_clone()
w_containers[idx] = model.get_w_container().no_grad_clone()
save_checkpoint(
{
"model_state_dict": model.state_dict(),
@ -127,7 +131,7 @@ def main(args):
start_time = time.time()
save_checkpoint(
{"w_container_per_epoch": w_container_per_epoch},
{"w_containers": w_containers},
logger.path(None) / "final-ckp.pth",
logger,
)

View File

@ -68,6 +68,8 @@ def main(args):
# build model
model = get_model(**model_kwargs)
model = model.to(args.device)
if idx == 0:
print(model)
# build optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr, amsgrad=True)
criterion = torch.nn.MSELoss()

View File

@ -16,7 +16,7 @@ def lfna_setup(args):
input_dim=1,
output_dim=1,
hidden_dims=[args.hidden_dim] * 2,
act_cls="gelu",
act_cls="relu",
norm_cls="layer_norm_1d",
)
return logger, model_kwargs

View File

@ -23,10 +23,12 @@ if str(lib_dir) not in sys.path:
import qlib
from qlib import config as qconfig
from qlib.workflow import R
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
from utils.qlib_utils import QResult
def filter_finished(recorders):
returned_recorders = dict()
not_finished = 0
@ -41,9 +43,10 @@ def filter_finished(recorders):
def add_to_dict(xdict, timestamp, value):
    """Store ``value`` in ``xdict`` under the calendar date of ``timestamp``.

    Args:
        xdict: mutable mapping that is updated in place.
        timestamp: object exposing ``.date()`` whose result supports
            ``strftime`` (e.g. ``datetime.datetime``).
        value: the entry to record for that date.

    Raises:
        ValueError: if an entry for the same "%Y-%m-%d" date already exists;
            the existing entry is left untouched.
    """
    # Keys are day-granular strings, so two timestamps on the same day collide.
    date = timestamp.date().strftime("%Y-%m-%d")
    if date in xdict:
        raise ValueError("This date [{:}] is already in the dict".format(date))
    xdict[date] = value
def query_info(save_dir, verbose, name_filter, key_map):
if isinstance(save_dir, list):
results = []
@ -61,7 +64,10 @@ def query_info(save_dir, verbose, name_filter, key_map):
for idx, (key, experiment) in enumerate(experiments.items()):
if experiment.id == "0":
continue
if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:
if (
name_filter is not None
and re.fullmatch(name_filter, experiment.name) is None
):
continue
recorders = experiment.list_recorders()
recorders, not_finished = filter_finished(recorders)
@ -77,10 +83,10 @@ def query_info(save_dir, verbose, name_filter, key_map):
)
result = QResult(experiment.name)
for recorder_id, recorder in recorders.items():
file_names = ['results-train.pkl', 'results-valid.pkl', 'results-test.pkl']
file_names = ["results-train.pkl", "results-valid.pkl", "results-test.pkl"]
date2IC = OrderedDict()
for file_name in file_names:
xtemp = recorder.load_object(file_name)['all-IC']
xtemp = recorder.load_object(file_name)["all-IC"]
timestamps, values = xtemp.index.tolist(), xtemp.tolist()
for timestamp, value in zip(timestamps, values):
add_to_dict(date2IC, timestamp, value)
@ -104,7 +110,7 @@ def query_info(save_dir, verbose, name_filter, key_map):
##
paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']
paths = [root_dir / "outputs" / "qlib-baselines-csi300"]
paths = [path.resolve() for path in paths]
print(paths)
@ -112,12 +118,12 @@ key_map = dict()
for xset in ("train", "valid", "test"):
key_map["{:}-mean-IC".format(xset)] = "IC ({:})".format(xset)
key_map["{:}-mean-ICIR".format(xset)] = "ICIR ({:})".format(xset)
qresults = query_info(paths, False, 'TSF-2x24-drop0_0s.*-.*-01', key_map)
print('Find {:} results'.format(len(qresults)))
qresults = query_info(paths, False, "TSF-2x24-drop0_0s.*-.*-01", key_map)
print("Find {:} results".format(len(qresults)))
times = []
for qresult in qresults:
times.append(qresult.name.split('0_0s')[-1])
times.append(qresult.name.split("0_0s")[-1])
print(times)
save_path = os.path.join(note_dir, 'temp-time-x.pth')
save_path = os.path.join(note_dir, "temp-time-x.pth")
torch.save(qresults, save_path)
print(save_path)

View File

@ -24,38 +24,38 @@ from qlib.model.base import Model
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)
qlib.init(provider_uri="~/.qlib/qlib_data/cn_data", region=qconfig.REG_CN)
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha360",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"handler": {
"class": "Alpha360",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
},
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
},
}
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}
pprint.pprint(dataset_config)
dataset = init_instance_by_config(dataset_config)
df_train, df_valid, df_test = dataset.prepare(
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
["train", "valid", "test"],
col_set=["feature", "label"],
data_key=DataHandlerLP.DK_L,
)
model = get_transformer(None)
print(model)
@ -72,4 +72,5 @@ label = labels[batch][mask]
loss = torch.nn.functional.mse_loss(pred, label)
from sklearn.metrics import mean_squared_error
mse_loss = mean_squared_error(pred.numpy(), label.numpy())

View File

@ -37,7 +37,9 @@ def read(fname="README.md"):
# What packages are required for this module to be executed?
REQUIRED = ["numpy>=1.16.5,<=1.19.5"]
packages = find_packages(exclude=("tests", "scripts", "scripts-search", "lib*", "exps*"))
packages = find_packages(
exclude=("tests", "scripts", "scripts-search", "lib*", "exps*")
)
print("packages: {:}".format(packages))
setup(

View File

@ -64,65 +64,29 @@ class ComposedSinFunc(FitFunc):
)
class ComposedCosFunc(FitFunc):
    """The composed cos function that outputs:
      f(x) = a * cos( b*x ) + c

    The coefficients are read from ``self._params`` under the integer keys
    0 (a), 1 (b) and 2 (c).
    """

    def __init__(self, params, xstr="x"):
        # NOTE(review): the leading 3 presumably tells FitFunc how many
        # parameters to expect (a, b, c) -- confirm against FitFunc.__init__.
        super(ComposedCosFunc, self).__init__(3, None, params, xstr)

    def __call__(self, x):
        """Return a * cos(b * x) + c for a real number ``x``."""
        self.check_valid()
        a = self._params[0]
        b = self._params[1]
        c = self._params[2]
        return a * math.cos(b * x) + c

    def _getitem(self, x, weights):
        # Not provided for this function.
        raise NotImplementedError

    def __repr__(self):
        # Must name the same function that __call__ evaluates, i.e. cos.
        return "{name}({a} * cos({b} * {x}) + {c})".format(
            name=self.__class__.__name__,
            a=self._params[0],
            b=self._params[1],
            c=self._params[2],
            x=self.xstr,
        )

View File

@ -5,5 +5,5 @@ from .math_base_funcs import LinearFunc, QuadraticFunc, CubicFunc, QuarticFunc
from .math_dynamic_funcs import DynamicLinearFunc
from .math_dynamic_funcs import DynamicQuadraticFunc
from .math_adv_funcs import ConstantFunc
from .math_adv_funcs import ComposedSinFunc
from .math_adv_funcs import ComposedSinFunc, ComposedCosFunc
from .math_dynamic_generator import GaussianDGenerator

View File

@ -4,7 +4,11 @@ from .synthetic_env import SyntheticDEnv
from .math_core import LinearFunc
from .math_core import DynamicLinearFunc
from .math_core import DynamicQuadraticFunc
from .math_core import ConstantFunc, ComposedSinFunc as SinFunc
from .math_core import (
ConstantFunc,
ComposedSinFunc as SinFunc,
ComposedCosFunc as CosFunc,
)
from .math_core import GaussianDGenerator
@ -50,6 +54,25 @@ def get_synthetic_env(total_timestamp=1600, num_per_task=1000, mode=None, versio
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
elif version.lower() == "v3":
mean_generator = SinFunc(params={0: 1, 1: 1, 2: 0}) # sin(t)
std_generator = CosFunc(params={0: 0.5, 1: 1, 2: 1}) # 0.5 cos(t) + 1
data_generator = GaussianDGenerator(
[mean_generator], [[std_generator]], (-2, 2)
)
time_generator = TimeStamp(
min_timestamp=0, max_timestamp=max_time, num=total_timestamp, mode=mode
)
oracle_map = DynamicQuadraticFunc(
params={
0: LinearFunc(params={0: 0.1, 1: 0}), # 0.1 * t
1: SinFunc(params={0: 1, 1: 1, 2: 0}), # sin(t)
2: ConstantFunc(0),
}
)
dynamic_env = SyntheticDEnv(
data_generator, oracle_map, time_generator, num_per_task
)
else:
raise ValueError("Unknown version: {:}".format(version))
return dynamic_env

View File

@ -39,9 +39,9 @@ def get_model(config: Dict[Text, Any], **kwargs):
norm_cls = super_name2norm[kwargs["norm_cls"]]
sub_layers, last_dim = [], kwargs["input_dim"]
for i, hidden_dim in enumerate(kwargs["hidden_dims"]):
sub_layers.append(SuperLinear(last_dim, hidden_dim))
if hidden_dim > 1:
sub_layers.append(norm_cls(hidden_dim, elementwise_affine=False))
sub_layers.append(SuperLinear(last_dim, hidden_dim))
sub_layers.append(act_cls())
last_dim = hidden_dim
sub_layers.append(SuperLinear(last_dim, kwargs["output_dim"]))

View File

@ -1,5 +1,5 @@
# Performance-Aware Template Network for One-Shot Neural Architecture Search
from .CifarNet import NetworkCIFAR as CifarNet
from .ImageNet import NetworkImageNet as ImageNet
from .CifarNet import NetworkCIFAR as CifarNet
from .ImageNet import NetworkImageNet as ImageNet
from .genotypes import Networks
from .genotypes import build_genotype_from_dict

View File

@ -8,24 +8,44 @@
import os, torch
def obtain_nas_infer_model(config, extra_model_path=None):
    """Build the NAS inference network described by ``config``.

    Args:
        config: object with attributes ``arch``, ``genotype``, ``dataset``,
            ``ichannel``, ``layers``, ``auxiliary``, ``class_num`` and, for
            the "cifar" dataset, ``stem_multi``.
        extra_model_path: checkpoint file used to recover the genotype when
            ``config.genotype`` is None. The checkpoint must provide the keys
            "epoch" and "genotypes".

    Returns:
        A ``CifarNet`` or ``ImageNet`` instance, depending on
        ``config.dataset``.

    Raises:
        ValueError: if ``config.arch`` is not "dxys", if ``config.genotype``
            is None and ``extra_model_path`` is missing or not an existing
            file, or if ``config.dataset`` is neither "cifar" nor "imagenet".
    """
    if config.arch != "dxys":
        raise ValueError("invalid nas arch type : {:}".format(config.arch))

    # Without a named genotype the checkpoint is the only source of the
    # architecture, so a missing path (None) is just as invalid as a path
    # that does not point to a file. Checked before the imports to fail fast.
    if config.genotype is None and (
        extra_model_path is None or not os.path.isfile(extra_model_path)
    ):
        raise ValueError(
            "When genotype in confiig is None, extra_model_path must be set as a path instead of {:}".format(
                extra_model_path
            )
        )

    from .DXYs import CifarNet, ImageNet, Networks
    from .DXYs import build_genotype_from_dict

    if config.genotype is None:
        # NOTE(review): torch.load unpickles the file; only pass checkpoints
        # from a trusted source.
        xdata = torch.load(extra_model_path)
        current_epoch = xdata["epoch"]
        # The genotype of the most recently finished epoch is stored at
        # index ``epoch - 1``.
        genotype_dict = xdata["genotypes"][current_epoch - 1]
        genotype = build_genotype_from_dict(genotype_dict)
    else:
        genotype = Networks[config.genotype]

    if config.dataset == "cifar":
        return CifarNet(
            config.ichannel,
            config.layers,
            config.stem_multi,
            config.auxiliary,
            genotype,
            config.class_num,
        )
    elif config.dataset == "imagenet":
        return ImageNet(
            config.ichannel,
            config.layers,
            config.auxiliary,
            genotype,
            config.class_num,
        )
    else:
        raise ValueError("invalid dataset : {:}".format(config.dataset))