Add visualize codes for Q
This commit is contained in:
parent
e777f38233
commit
0e2dd13762
@ -20,87 +20,7 @@ import qlib
|
|||||||
from qlib.config import REG_CN
|
from qlib.config import REG_CN
|
||||||
from qlib.workflow import R
|
from qlib.workflow import R
|
||||||
|
|
||||||
|
from utils.qlib_utils import QResult
|
||||||
class QResult:
|
|
||||||
"""A class to maintain the results of a qlib experiment."""
|
|
||||||
|
|
||||||
def __init__(self, name):
|
|
||||||
self._result = defaultdict(list)
|
|
||||||
self._name = name
|
|
||||||
self._recorder_paths = []
|
|
||||||
|
|
||||||
def append(self, key, value):
|
|
||||||
self._result[key].append(value)
|
|
||||||
|
|
||||||
def append_path(self, xpath):
|
|
||||||
self._recorder_paths.append(xpath)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return self._name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def paths(self):
|
|
||||||
return self._recorder_paths
|
|
||||||
|
|
||||||
@property
|
|
||||||
def result(self):
|
|
||||||
return self._result
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self._result)
|
|
||||||
|
|
||||||
def update(self, metrics, filter_keys=None):
|
|
||||||
for key, value in metrics.items():
|
|
||||||
if filter_keys is not None and key in filter_keys:
|
|
||||||
key = filter_keys[key]
|
|
||||||
elif filter_keys is not None:
|
|
||||||
continue
|
|
||||||
self.append(key, value)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def full_str(xstr, space):
|
|
||||||
xformat = "{:" + str(space) + "s}"
|
|
||||||
return xformat.format(str(xstr))
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def merge_dict(dict_list):
|
|
||||||
new_dict = dict()
|
|
||||||
for xkey in dict_list[0].keys():
|
|
||||||
values = [x for xdict in dict_list for x in xdict[xkey]]
|
|
||||||
new_dict[xkey] = values
|
|
||||||
return new_dict
|
|
||||||
|
|
||||||
def info(
|
|
||||||
self,
|
|
||||||
keys: List[Text],
|
|
||||||
separate: Text = "& ",
|
|
||||||
space: int = 20,
|
|
||||||
verbose: bool = True,
|
|
||||||
):
|
|
||||||
avaliable_keys = []
|
|
||||||
for key in keys:
|
|
||||||
if key not in self.result:
|
|
||||||
print("There are invalid key [{:}].".format(key))
|
|
||||||
else:
|
|
||||||
avaliable_keys.append(key)
|
|
||||||
head_str = separate.join([self.full_str(x, space) for x in avaliable_keys])
|
|
||||||
values = []
|
|
||||||
for key in avaliable_keys:
|
|
||||||
if "IR" in key:
|
|
||||||
current_values = [x * 100 for x in self._result[key]]
|
|
||||||
else:
|
|
||||||
current_values = self._result[key]
|
|
||||||
mean = np.mean(current_values)
|
|
||||||
std = np.std(current_values)
|
|
||||||
# values.append("{:.4f} $\pm$ {:.4f}".format(mean, std))
|
|
||||||
values.append("{:.2f} $\pm$ {:.2f}".format(mean, std))
|
|
||||||
value_str = separate.join([self.full_str(x, space) for x in values])
|
|
||||||
if verbose:
|
|
||||||
print(head_str)
|
|
||||||
print(value_str)
|
|
||||||
return head_str, value_str
|
|
||||||
|
|
||||||
|
|
||||||
def compare_results(
|
def compare_results(
|
||||||
heads, values, names, space=10, separate="& ", verbose=True, sort_key=False
|
heads, values, names, space=10, separate="& ", verbose=True, sort_key=False
|
||||||
@ -149,7 +69,7 @@ def query_info(save_dir, verbose, name_filter, key_map):
|
|||||||
for idx, (key, experiment) in enumerate(experiments.items()):
|
for idx, (key, experiment) in enumerate(experiments.items()):
|
||||||
if experiment.id == "0":
|
if experiment.id == "0":
|
||||||
continue
|
continue
|
||||||
if name_filter is not None and re.match(name_filter, experiment.name) is None:
|
if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:
|
||||||
continue
|
continue
|
||||||
recorders = experiment.list_recorders()
|
recorders = experiment.list_recorders()
|
||||||
recorders, not_finished = filter_finished(recorders)
|
recorders, not_finished = filter_finished(recorders)
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
|
|
||||||
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter
|
from log_utils import AverageMeter
|
||||||
from log_utils import time_string
|
from log_utils import time_string
|
||||||
from utils import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
def basic_train(
|
def basic_train(
|
||||||
|
14
lib/procedures/eval_funcs.py
Normal file
14
lib/procedures/eval_funcs.py
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
def obtain_accuracy(output, target, topk=(1,)):
|
||||||
|
"""Computes the precision@k for the specified values of k"""
|
||||||
|
maxk = max(topk)
|
||||||
|
batch_size = target.size(0)
|
||||||
|
|
||||||
|
_, pred = output.topk(maxk, 1, True, True)
|
||||||
|
pred = pred.t()
|
||||||
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for k in topk:
|
||||||
|
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||||
|
res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
|
return res
|
@ -3,12 +3,14 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
import os, time, copy, torch, pathlib
|
import os, time, copy, torch, pathlib
|
||||||
|
|
||||||
|
# modules in AutoDL
|
||||||
import datasets
|
import datasets
|
||||||
from config_utils import load_config
|
from config_utils import load_config
|
||||||
from procedures import prepare_seed, get_optim_scheduler
|
from procedures import prepare_seed, get_optim_scheduler
|
||||||
from utils import get_model_infos, obtain_accuracy
|
|
||||||
from log_utils import AverageMeter, time_string, convert_secs2time
|
from log_utils import AverageMeter, time_string, convert_secs2time
|
||||||
from models import get_cell_based_tiny_net
|
from models import get_cell_based_tiny_net
|
||||||
|
from utils import get_model_infos
|
||||||
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
|
__all__ = ["evaluate_for_seed", "pure_evaluate", "get_nas_bench_loaders"]
|
||||||
|
@ -3,9 +3,10 @@
|
|||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
from log_utils import AverageMeter, time_string
|
from log_utils import AverageMeter, time_string
|
||||||
from utils import obtain_accuracy
|
|
||||||
from models import change_key
|
from models import change_key
|
||||||
|
|
||||||
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||||
expected_flop = torch.mean(expected_flop)
|
expected_flop = torch.mean(expected_flop)
|
||||||
|
@ -2,9 +2,11 @@
|
|||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
|
||||||
##################################################
|
##################################################
|
||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
|
|
||||||
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter, time_string
|
from log_utils import AverageMeter, time_string
|
||||||
from utils import obtain_accuracy
|
|
||||||
from models import change_key
|
from models import change_key
|
||||||
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
def get_flop_loss(expected_flop, flop_cur, flop_need, flop_tolerant):
|
||||||
|
@ -4,9 +4,9 @@
|
|||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# our modules
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter, time_string
|
from log_utils import AverageMeter, time_string
|
||||||
from utils import obtain_accuracy
|
from .eval_funcs import obtain_accuracy
|
||||||
|
|
||||||
|
|
||||||
def simple_KD_train(
|
def simple_KD_train(
|
||||||
|
@ -1,3 +1,9 @@
|
|||||||
|
#####################################################
|
||||||
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019.01 #
|
||||||
|
#####################################################
|
||||||
|
# This directory contains some ad-hoc functions, classes, etc.
|
||||||
|
# It will be re-formulated in the future.
|
||||||
|
#####################################################
|
||||||
from .evaluation_utils import obtain_accuracy
|
from .evaluation_utils import obtain_accuracy
|
||||||
from .gpu_manager import GPUManager
|
from .gpu_manager import GPUManager
|
||||||
from .flop_benchmark import get_model_infos, count_parameters, count_parameters_in_MB
|
from .flop_benchmark import get_model_infos, count_parameters, count_parameters_in_MB
|
||||||
|
@ -76,10 +76,12 @@ def rotate2affine(degree):
|
|||||||
|
|
||||||
# shape is a tuple [H, W]
|
# shape is a tuple [H, W]
|
||||||
def normalize_points(shape, points):
|
def normalize_points(shape, points):
|
||||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
|
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||||
shape
|
shape
|
||||||
)
|
) == 2, "invalid shape : {:}".format(shape)
|
||||||
assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), "points are wrong : {:}".format(points.shape)
|
assert isinstance(points, torch.Tensor) and (
|
||||||
|
points.shape[0] == 2
|
||||||
|
), "points are wrong : {:}".format(points.shape)
|
||||||
(H, W), points = shape, points.clone()
|
(H, W), points = shape, points.clone()
|
||||||
points[0, :] = normalize_L(points[0, :], W)
|
points[0, :] = normalize_L(points[0, :], W)
|
||||||
points[1, :] = normalize_L(points[1, :], H)
|
points[1, :] = normalize_L(points[1, :], H)
|
||||||
@ -88,10 +90,12 @@ def normalize_points(shape, points):
|
|||||||
|
|
||||||
# shape is a tuple [H, W]
|
# shape is a tuple [H, W]
|
||||||
def normalize_points_batch(shape, points):
|
def normalize_points_batch(shape, points):
|
||||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
|
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||||
shape
|
shape
|
||||||
)
|
) == 2, "invalid shape : {:}".format(shape)
|
||||||
assert isinstance(points, torch.Tensor) and (points.size(-1) == 2), "points are wrong : {:}".format(points.shape)
|
assert isinstance(points, torch.Tensor) and (
|
||||||
|
points.size(-1) == 2
|
||||||
|
), "points are wrong : {:}".format(points.shape)
|
||||||
(H, W), points = shape, points.clone()
|
(H, W), points = shape, points.clone()
|
||||||
x = normalize_L(points[..., 0], W)
|
x = normalize_L(points[..., 0], W)
|
||||||
y = normalize_L(points[..., 1], H)
|
y = normalize_L(points[..., 1], H)
|
||||||
@ -100,10 +104,12 @@ def normalize_points_batch(shape, points):
|
|||||||
|
|
||||||
# shape is a tuple [H, W]
|
# shape is a tuple [H, W]
|
||||||
def denormalize_points(shape, points):
|
def denormalize_points(shape, points):
|
||||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
|
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||||
shape
|
shape
|
||||||
)
|
) == 2, "invalid shape : {:}".format(shape)
|
||||||
assert isinstance(points, torch.Tensor) and (points.shape[0] == 2), "points are wrong : {:}".format(points.shape)
|
assert isinstance(points, torch.Tensor) and (
|
||||||
|
points.shape[0] == 2
|
||||||
|
), "points are wrong : {:}".format(points.shape)
|
||||||
(H, W), points = shape, points.clone()
|
(H, W), points = shape, points.clone()
|
||||||
points[0, :] = denormalize_L(points[0, :], W)
|
points[0, :] = denormalize_L(points[0, :], W)
|
||||||
points[1, :] = denormalize_L(points[1, :], H)
|
points[1, :] = denormalize_L(points[1, :], H)
|
||||||
@ -112,10 +118,12 @@ def denormalize_points(shape, points):
|
|||||||
|
|
||||||
# shape is a tuple [H, W]
|
# shape is a tuple [H, W]
|
||||||
def denormalize_points_batch(shape, points):
|
def denormalize_points_batch(shape, points):
|
||||||
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(shape) == 2, "invalid shape : {:}".format(
|
assert (isinstance(shape, tuple) or isinstance(shape, list)) and len(
|
||||||
shape
|
shape
|
||||||
)
|
) == 2, "invalid shape : {:}".format(shape)
|
||||||
assert isinstance(points, torch.Tensor) and (points.shape[-1] == 2), "points are wrong : {:}".format(points.shape)
|
assert isinstance(points, torch.Tensor) and (
|
||||||
|
points.shape[-1] == 2
|
||||||
|
), "points are wrong : {:}".format(points.shape)
|
||||||
(H, W), points = shape, points.clone()
|
(H, W), points = shape, points.clone()
|
||||||
x = denormalize_L(points[..., 0], W)
|
x = denormalize_L(points[..., 0], W)
|
||||||
y = denormalize_L(points[..., 1], H)
|
y = denormalize_L(points[..., 1], H)
|
||||||
@ -145,5 +153,7 @@ def affine2image(image, theta, shape):
|
|||||||
theta = theta[:2, :].unsqueeze(0)
|
theta = theta[:2, :].unsqueeze(0)
|
||||||
grid_size = torch.Size([1, C, shape[0], shape[1]])
|
grid_size = torch.Size([1, C, shape[0], shape[1]])
|
||||||
grid = F.affine_grid(theta, grid_size)
|
grid = F.affine_grid(theta, grid_size)
|
||||||
affI = F.grid_sample(image.unsqueeze(0), grid, mode="bilinear", padding_mode="border")
|
affI = F.grid_sample(
|
||||||
|
image.unsqueeze(0), grid, mode="bilinear", padding_mode="border"
|
||||||
|
)
|
||||||
return affI.squeeze(0)
|
return affI.squeeze(0)
|
||||||
|
@ -48,7 +48,11 @@ def get_model_infos(model, shape):
|
|||||||
if hasattr(model, "auxiliary_param"):
|
if hasattr(model, "auxiliary_param"):
|
||||||
aux_params = count_parameters_in_MB(model.auxiliary_param())
|
aux_params = count_parameters_in_MB(model.auxiliary_param())
|
||||||
print("The auxiliary params of this model is : {:}".format(aux_params))
|
print("The auxiliary params of this model is : {:}".format(aux_params))
|
||||||
print("We remove the auxiliary params from the total params ({:}) when counting".format(Param))
|
print(
|
||||||
|
"We remove the auxiliary params from the total params ({:}) when counting".format(
|
||||||
|
Param
|
||||||
|
)
|
||||||
|
)
|
||||||
Param = Param - aux_params
|
Param = Param - aux_params
|
||||||
|
|
||||||
# print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
# print_log('FLOPs : {:} MB'.format(FLOPs), log)
|
||||||
@ -92,7 +96,9 @@ def pool_flops_counter_hook(pool_module, inputs, output):
|
|||||||
out_C, output_height, output_width = output.shape[1:]
|
out_C, output_height, output_width = output.shape[1:]
|
||||||
assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())
|
assert out_C == inputs[0].size(1), "{:} vs. {:}".format(out_C, inputs[0].size())
|
||||||
|
|
||||||
overall_flops = batch_size * out_C * output_height * output_width * kernel_size * kernel_size
|
overall_flops = (
|
||||||
|
batch_size * out_C * output_height * output_width * kernel_size * kernel_size
|
||||||
|
)
|
||||||
pool_module.__flops__ += overall_flops
|
pool_module.__flops__ += overall_flops
|
||||||
|
|
||||||
|
|
||||||
@ -104,7 +110,9 @@ def self_calculate_flops_counter_hook(self_module, inputs, output):
|
|||||||
def fc_flops_counter_hook(fc_module, inputs, output):
|
def fc_flops_counter_hook(fc_module, inputs, output):
|
||||||
batch_size = inputs[0].size(0)
|
batch_size = inputs[0].size(0)
|
||||||
xin, xout = fc_module.in_features, fc_module.out_features
|
xin, xout = fc_module.in_features, fc_module.out_features
|
||||||
assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(xin, xout)
|
assert xin == inputs[0].size(1) and xout == output.size(1), "IO=({:}, {:})".format(
|
||||||
|
xin, xout
|
||||||
|
)
|
||||||
overall_flops = batch_size * xin * xout
|
overall_flops = batch_size * xin * xout
|
||||||
if fc_module.bias is not None:
|
if fc_module.bias is not None:
|
||||||
overall_flops += batch_size * xout
|
overall_flops += batch_size * xout
|
||||||
@ -136,7 +144,9 @@ def conv2d_flops_counter_hook(conv_module, inputs, output):
|
|||||||
in_channels = conv_module.in_channels
|
in_channels = conv_module.in_channels
|
||||||
out_channels = conv_module.out_channels
|
out_channels = conv_module.out_channels
|
||||||
groups = conv_module.groups
|
groups = conv_module.groups
|
||||||
conv_per_position_flops = kernel_height * kernel_width * in_channels * out_channels / groups
|
conv_per_position_flops = (
|
||||||
|
kernel_height * kernel_width * in_channels * out_channels / groups
|
||||||
|
)
|
||||||
|
|
||||||
active_elements_count = batch_size * output_height * output_width
|
active_elements_count = batch_size * output_height * output_width
|
||||||
overall_flops = conv_per_position_flops * active_elements_count
|
overall_flops = conv_per_position_flops * active_elements_count
|
||||||
@ -184,7 +194,9 @@ def add_flops_counter_hook_function(module):
|
|||||||
if not hasattr(module, "__flops_handle__"):
|
if not hasattr(module, "__flops_handle__"):
|
||||||
handle = module.register_forward_hook(fc_flops_counter_hook)
|
handle = module.register_forward_hook(fc_flops_counter_hook)
|
||||||
module.__flops_handle__ = handle
|
module.__flops_handle__ = handle
|
||||||
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(module, torch.nn.MaxPool2d):
|
elif isinstance(module, torch.nn.AvgPool2d) or isinstance(
|
||||||
|
module, torch.nn.MaxPool2d
|
||||||
|
):
|
||||||
if not hasattr(module, "__flops_handle__"):
|
if not hasattr(module, "__flops_handle__"):
|
||||||
handle = module.register_forward_hook(pool_flops_counter_hook)
|
handle = module.register_forward_hook(pool_flops_counter_hook)
|
||||||
module.__flops_handle__ = handle
|
module.__flops_handle__ = handle
|
||||||
|
@ -2,7 +2,15 @@ import os
|
|||||||
|
|
||||||
|
|
||||||
class GPUManager:
|
class GPUManager:
|
||||||
queries = ("index", "gpu_name", "memory.free", "memory.used", "memory.total", "power.draw", "power.limit")
|
queries = (
|
||||||
|
"index",
|
||||||
|
"gpu_name",
|
||||||
|
"memory.free",
|
||||||
|
"memory.used",
|
||||||
|
"memory.total",
|
||||||
|
"power.draw",
|
||||||
|
"power.limit",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
all_gpus = self.query_gpu(False)
|
all_gpus = self.query_gpu(False)
|
||||||
@ -28,7 +36,9 @@ class GPUManager:
|
|||||||
find = False
|
find = False
|
||||||
for gpu in all_gpus:
|
for gpu in all_gpus:
|
||||||
if gpu["index"] == CUDA_VISIBLE_DEVICE:
|
if gpu["index"] == CUDA_VISIBLE_DEVICE:
|
||||||
assert not find, "Duplicate cuda device index : {}".format(CUDA_VISIBLE_DEVICE)
|
assert not find, "Duplicate cuda device index : {}".format(
|
||||||
|
CUDA_VISIBLE_DEVICE
|
||||||
|
)
|
||||||
find = True
|
find = True
|
||||||
selected_gpus.append(gpu.copy())
|
selected_gpus.append(gpu.copy())
|
||||||
selected_gpus[-1]["index"] = "{}".format(idx)
|
selected_gpus[-1]["index"] = "{}".format(idx)
|
||||||
@ -52,7 +62,9 @@ class GPUManager:
|
|||||||
|
|
||||||
def select_by_memory(self, numbers=1):
|
def select_by_memory(self, numbers=1):
|
||||||
all_gpus = self.query_gpu(False)
|
all_gpus = self.query_gpu(False)
|
||||||
assert numbers <= len(all_gpus), "Require {} gpus more than you have".format(numbers)
|
assert numbers <= len(all_gpus), "Require {} gpus more than you have".format(
|
||||||
|
numbers
|
||||||
|
)
|
||||||
alls = []
|
alls = []
|
||||||
for idx, gpu in enumerate(all_gpus):
|
for idx, gpu in enumerate(all_gpus):
|
||||||
free_memory = gpu["memory.free"]
|
free_memory = gpu["memory.free"]
|
||||||
|
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
# from utils import obtain_accuracy
|
# modules in AutoDL
|
||||||
from models import CellStructure
|
from models import CellStructure
|
||||||
from log_utils import time_string
|
from log_utils import time_string
|
||||||
|
|
||||||
@ -56,11 +56,20 @@ def evaluate_one_shot(model, xloader, api, cal_mode, seed=111):
|
|||||||
correct = (preds == targets.cuda()).float()
|
correct = (preds == targets.cuda()).float()
|
||||||
accuracies.append(correct.mean().item())
|
accuracies.append(correct.mean().item())
|
||||||
if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
|
if idx != 0 and (idx % 500 == 0 or idx + 1 == len(archs)):
|
||||||
cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[0, 1]
|
cor_accs_valid = np.corrcoef(accuracies, gt_accs_10_valid[: idx + 1])[
|
||||||
cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[0, 1]
|
0, 1
|
||||||
|
]
|
||||||
|
cor_accs_test = np.corrcoef(accuracies, gt_accs_10_test[: idx + 1])[
|
||||||
|
0, 1
|
||||||
|
]
|
||||||
print(
|
print(
|
||||||
"{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format(
|
"{:} {:05d}/{:05d} mode={:5s}, correlation : accs={:.5f} for CIFAR-10 valid, {:.5f} for CIFAR-10 test.".format(
|
||||||
time_string(), idx, len(archs), "Train" if cal_mode else "Eval", cor_accs_valid, cor_accs_test
|
time_string(),
|
||||||
|
idx,
|
||||||
|
len(archs),
|
||||||
|
"Train" if cal_mode else "Eval",
|
||||||
|
cor_accs_valid,
|
||||||
|
cor_accs_test,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
model.load_state_dict(weights)
|
model.load_state_dict(weights)
|
||||||
|
101
lib/utils/qlib_utils.py
Normal file
101
lib/utils/qlib_utils.py
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
import numpy as np
|
||||||
|
from typing import List, Text
|
||||||
|
from collections import defaultdict, OrderedDict
|
||||||
|
|
||||||
|
|
||||||
|
class QResult:
|
||||||
|
"""A class to maintain the results of a qlib experiment."""
|
||||||
|
|
||||||
|
def __init__(self, name):
|
||||||
|
self._result = defaultdict(list)
|
||||||
|
self._name = name
|
||||||
|
self._recorder_paths = []
|
||||||
|
|
||||||
|
def append(self, key, value):
|
||||||
|
self._result[key].append(value)
|
||||||
|
|
||||||
|
def append_path(self, xpath):
|
||||||
|
self._recorder_paths.append(xpath)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self):
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def paths(self):
|
||||||
|
return self._recorder_paths
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self):
|
||||||
|
return self._result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def keys(self):
|
||||||
|
return list(self._result.keys())
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._result)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "{name}({xname}, {num} metrics)".format(
|
||||||
|
name=self.__class__.__name__, xname=self.name, num=len(self.result)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if key not in self._result:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid key {:}, please use one of {:}".format(key, self.keys)
|
||||||
|
)
|
||||||
|
values = self._result[key]
|
||||||
|
return float(np.mean(values))
|
||||||
|
|
||||||
|
def update(self, metrics, filter_keys=None):
|
||||||
|
for key, value in metrics.items():
|
||||||
|
if filter_keys is not None and key in filter_keys:
|
||||||
|
key = filter_keys[key]
|
||||||
|
elif filter_keys is not None:
|
||||||
|
continue
|
||||||
|
self.append(key, value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def full_str(xstr, space):
|
||||||
|
xformat = "{:" + str(space) + "s}"
|
||||||
|
return xformat.format(str(xstr))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def merge_dict(dict_list):
|
||||||
|
new_dict = dict()
|
||||||
|
for xkey in dict_list[0].keys():
|
||||||
|
values = [x for xdict in dict_list for x in xdict[xkey]]
|
||||||
|
new_dict[xkey] = values
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
def info(
|
||||||
|
self,
|
||||||
|
keys: List[Text],
|
||||||
|
separate: Text = "& ",
|
||||||
|
space: int = 20,
|
||||||
|
verbose: bool = True,
|
||||||
|
):
|
||||||
|
avaliable_keys = []
|
||||||
|
for key in keys:
|
||||||
|
if key not in self.result:
|
||||||
|
print("There are invalid key [{:}].".format(key))
|
||||||
|
else:
|
||||||
|
avaliable_keys.append(key)
|
||||||
|
head_str = separate.join([self.full_str(x, space) for x in avaliable_keys])
|
||||||
|
values = []
|
||||||
|
for key in avaliable_keys:
|
||||||
|
if "IR" in key:
|
||||||
|
current_values = [x * 100 for x in self._result[key]]
|
||||||
|
else:
|
||||||
|
current_values = self._result[key]
|
||||||
|
mean = np.mean(current_values)
|
||||||
|
std = np.std(current_values)
|
||||||
|
# values.append("{:.4f} $\pm$ {:.4f}".format(mean, std))
|
||||||
|
values.append("{:.2f} $\pm$ {:.2f}".format(mean, std))
|
||||||
|
value_str = separate.join([self.full_str(x, space) for x in values])
|
||||||
|
if verbose:
|
||||||
|
print(head_str)
|
||||||
|
print(value_str)
|
||||||
|
return head_str, value_str
|
@ -8,10 +8,14 @@ def split_str2indexes(string: str, max_check: int, length_limit=5):
|
|||||||
if len(srange) != 2:
|
if len(srange) != 2:
|
||||||
raise ValueError("invalid srange : {:}".format(srange))
|
raise ValueError("invalid srange : {:}".format(srange))
|
||||||
if length_limit is not None:
|
if length_limit is not None:
|
||||||
assert len(srange[0]) == len(srange[1]) == length_limit, "invalid srange : {:}".format(srange)
|
assert (
|
||||||
|
len(srange[0]) == len(srange[1]) == length_limit
|
||||||
|
), "invalid srange : {:}".format(srange)
|
||||||
srange = (int(srange[0]), int(srange[1]))
|
srange = (int(srange[0]), int(srange[1]))
|
||||||
if not (0 <= srange[0] <= srange[1] < max_check):
|
if not (0 <= srange[0] <= srange[1] < max_check):
|
||||||
raise ValueError("{:} vs {:} vs {:}".format(srange[0], srange[1], max_check))
|
raise ValueError(
|
||||||
|
"{:} vs {:} vs {:}".format(srange[0], srange[1], max_check)
|
||||||
|
)
|
||||||
for i in range(srange[0], srange[1] + 1):
|
for i in range(srange[0], srange[1] + 1):
|
||||||
indexes.add(i)
|
indexes.add(i)
|
||||||
return indexes
|
return indexes
|
||||||
|
@ -21,7 +21,11 @@ def get_conv2D_Wmats(tensor: np.ndarray) -> List[np.ndarray]:
|
|||||||
"""
|
"""
|
||||||
mats = []
|
mats = []
|
||||||
N, M, imax, jmax = tensor.shape
|
N, M, imax, jmax = tensor.shape
|
||||||
assert N + M >= imax + jmax, "invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)".format(N, M, imax, jmax)
|
assert (
|
||||||
|
N + M >= imax + jmax
|
||||||
|
), "invalid tensor shape detected: {}x{} (NxM), {}x{} (i,j)".format(
|
||||||
|
N, M, imax, jmax
|
||||||
|
)
|
||||||
for i in range(imax):
|
for i in range(imax):
|
||||||
for j in range(jmax):
|
for j in range(jmax):
|
||||||
w = tensor[:, :, i, j]
|
w = tensor[:, :, i, j]
|
||||||
@ -58,7 +62,17 @@ def glorot_norm_fix(w, n, m, rf_size):
|
|||||||
return w
|
return w
|
||||||
|
|
||||||
|
|
||||||
def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix):
|
def analyze_weights(
|
||||||
|
weights,
|
||||||
|
min_size,
|
||||||
|
max_size,
|
||||||
|
alphas,
|
||||||
|
lognorms,
|
||||||
|
spectralnorms,
|
||||||
|
softranks,
|
||||||
|
normalize,
|
||||||
|
glorot_fix,
|
||||||
|
):
|
||||||
results = OrderedDict()
|
results = OrderedDict()
|
||||||
count = len(weights)
|
count = len(weights)
|
||||||
if count == 0:
|
if count == 0:
|
||||||
@ -94,12 +108,16 @@ def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms
|
|||||||
lambda0 = None
|
lambda0 = None
|
||||||
|
|
||||||
if M < min_size:
|
if M < min_size:
|
||||||
summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(i + 1, count, M, N, min_size)
|
summary = "Weight matrix {}/{} ({},{}): Skipping: too small (<{})".format(
|
||||||
|
i + 1, count, M, N, min_size
|
||||||
|
)
|
||||||
cur_res["summary"] = summary
|
cur_res["summary"] = summary
|
||||||
continue
|
continue
|
||||||
elif max_size > 0 and M > max_size:
|
elif max_size > 0 and M > max_size:
|
||||||
summary = "Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(
|
summary = (
|
||||||
i + 1, count, M, N, max_size
|
"Weight matrix {}/{} ({},{}): Skipping: too big (testing) (>{})".format(
|
||||||
|
i + 1, count, M, N, max_size
|
||||||
|
)
|
||||||
)
|
)
|
||||||
cur_res["summary"] = summary
|
cur_res["summary"] = summary
|
||||||
continue
|
continue
|
||||||
@ -153,7 +171,9 @@ def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms
|
|||||||
cur_res["lognormX"] = lognormX
|
cur_res["lognormX"] = lognormX
|
||||||
|
|
||||||
summary.append(
|
summary.append(
|
||||||
"Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(i + 1, count, M, N, lognorm, lognormX)
|
"Weight matrix {}/{} ({},{}): LogNorm: {} ; LogNormX: {}".format(
|
||||||
|
i + 1, count, M, N, lognorm, lognormX
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if softranks:
|
if softranks:
|
||||||
@ -163,8 +183,10 @@ def analyze_weights(weights, min_size, max_size, alphas, lognorms, spectralnorms
|
|||||||
cur_res["softrank"] = softrank
|
cur_res["softrank"] = softrank
|
||||||
cur_res["softranklog"] = softranklog
|
cur_res["softranklog"] = softranklog
|
||||||
cur_res["softranklogratio"] = softranklogratio
|
cur_res["softranklogratio"] = softranklogratio
|
||||||
summary += "{}. Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(
|
summary += (
|
||||||
summary, softrank, softranklog, softranklogratio
|
"{}. Softrank: {}. Softrank log: {}. Softrank log ratio: {}".format(
|
||||||
|
summary, softrank, softranklog, softranklogratio
|
||||||
|
)
|
||||||
)
|
)
|
||||||
cur_res["summary"] = "\n".join(summary)
|
cur_res["summary"] = "\n".join(summary)
|
||||||
return results
|
return results
|
||||||
@ -209,7 +231,17 @@ def compute_details(results):
|
|||||||
metrics_stats.append("{}_compound_avg".format(metric))
|
metrics_stats.append("{}_compound_avg".format(metric))
|
||||||
|
|
||||||
columns = (
|
columns = (
|
||||||
["layer_id", "layer_type", "N", "M", "layer_count", "slice", "slice_count", "level", "comment"]
|
[
|
||||||
|
"layer_id",
|
||||||
|
"layer_type",
|
||||||
|
"N",
|
||||||
|
"M",
|
||||||
|
"layer_count",
|
||||||
|
"slice",
|
||||||
|
"slice_count",
|
||||||
|
"level",
|
||||||
|
"comment",
|
||||||
|
]
|
||||||
+ [*metrics]
|
+ [*metrics]
|
||||||
+ metrics_stats
|
+ metrics_stats
|
||||||
)
|
)
|
||||||
@ -351,7 +383,15 @@ def analyze(
|
|||||||
else:
|
else:
|
||||||
weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
|
weights = get_conv2D_Wmats(module.weight.cpu().detach().numpy())
|
||||||
results = analyze_weights(
|
results = analyze_weights(
|
||||||
weights, min_size, max_size, alphas, lognorms, spectralnorms, softranks, normalize, glorot_fix
|
weights,
|
||||||
|
min_size,
|
||||||
|
max_size,
|
||||||
|
alphas,
|
||||||
|
lognorms,
|
||||||
|
spectralnorms,
|
||||||
|
softranks,
|
||||||
|
normalize,
|
||||||
|
glorot_fix,
|
||||||
)
|
)
|
||||||
results["id"] = index
|
results["id"] = index
|
||||||
results["type"] = type(module)
|
results["type"] = type(module)
|
||||||
|
310
notebooks/TOT/ES-Model.ipynb
Normal file
310
notebooks/TOT/ES-Model.ipynb
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"id": "afraid-minutes",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The root path: /Users/xuanyidong/Desktop/AutoDL-Projects\n",
|
||||||
|
"The library path: /Users/xuanyidong/Desktop/AutoDL-Projects/lib\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[61765:MainThread](2021-04-11 21:23:06,638) INFO - qlib.Initialization - [config.py:276] - default_conf: client.\n",
|
||||||
|
"[61765:MainThread](2021-04-11 21:23:06,641) WARNING - qlib.Initialization - [config.py:291] - redis connection failed(host=127.0.0.1 port=6379), cache will not be used!\n",
|
||||||
|
"[61765:MainThread](2021-04-11 21:23:06,643) INFO - qlib.Initialization - [__init__.py:46] - qlib successfully initialized based on client settings.\n",
|
||||||
|
"[61765:MainThread](2021-04-11 21:23:06,644) INFO - qlib.Initialization - [__init__.py:47] - data_path=/Users/xuanyidong/.qlib/qlib_data/cn_data\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"#\n",
|
||||||
|
"# Exhaustive Search Results\n",
|
||||||
|
"#\n",
|
||||||
|
"import os\n",
|
||||||
|
"import re\n",
|
||||||
|
"import sys\n",
|
||||||
|
"import qlib\n",
|
||||||
|
"import pprint\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"from pathlib import Path\n",
|
||||||
|
"\n",
|
||||||
|
"__file__ = os.path.dirname(os.path.realpath(\"__file__\"))\n",
|
||||||
|
"root_dir = (Path(__file__).parent / \"..\").resolve()\n",
|
||||||
|
"lib_dir = (root_dir / \"lib\").resolve()\n",
|
||||||
|
"print(\"The root path: {:}\".format(root_dir))\n",
|
||||||
|
"print(\"The library path: {:}\".format(lib_dir))\n",
|
||||||
|
"assert lib_dir.exists(), \"{:} does not exist\".format(lib_dir)\n",
|
||||||
|
"if str(lib_dir) not in sys.path:\n",
|
||||||
|
" sys.path.insert(0, str(lib_dir))\n",
|
||||||
|
"\n",
|
||||||
|
"import qlib\n",
|
||||||
|
"from qlib import config as qconfig\n",
|
||||||
|
"from qlib.workflow import R\n",
|
||||||
|
"qlib.init(provider_uri='~/.qlib/qlib_data/cn_data', region=qconfig.REG_CN)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"id": "hidden-exemption",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from utils.qlib_utils import QResult"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"id": "continental-drain",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def filter_finished(recorders):\n",
|
||||||
|
" returned_recorders = dict()\n",
|
||||||
|
" not_finished = 0\n",
|
||||||
|
" for key, recorder in recorders.items():\n",
|
||||||
|
" if recorder.status == \"FINISHED\":\n",
|
||||||
|
" returned_recorders[key] = recorder\n",
|
||||||
|
" else:\n",
|
||||||
|
" not_finished += 1\n",
|
||||||
|
" return returned_recorders, not_finished\n",
|
||||||
|
"\n",
|
||||||
|
"def query_info(save_dir, verbose, name_filter, key_map):\n",
|
||||||
|
" if isinstance(save_dir, list):\n",
|
||||||
|
" results = []\n",
|
||||||
|
" for x in save_dir:\n",
|
||||||
|
" x = query_info(x, verbose, name_filter, key_map)\n",
|
||||||
|
" results.extend(x)\n",
|
||||||
|
" return results\n",
|
||||||
|
" # Here, the save_dir must be a string\n",
|
||||||
|
" R.set_uri(str(save_dir))\n",
|
||||||
|
" experiments = R.list_experiments()\n",
|
||||||
|
"\n",
|
||||||
|
" if verbose:\n",
|
||||||
|
" print(\"There are {:} experiments.\".format(len(experiments)))\n",
|
||||||
|
" qresults = []\n",
|
||||||
|
" for idx, (key, experiment) in enumerate(experiments.items()):\n",
|
||||||
|
" if experiment.id == \"0\":\n",
|
||||||
|
" continue\n",
|
||||||
|
" if name_filter is not None and re.fullmatch(name_filter, experiment.name) is None:\n",
|
||||||
|
" continue\n",
|
||||||
|
" recorders = experiment.list_recorders()\n",
|
||||||
|
" recorders, not_finished = filter_finished(recorders)\n",
|
||||||
|
" if verbose:\n",
|
||||||
|
" print(\n",
|
||||||
|
" \"====>>>> {:02d}/{:02d}-th experiment {:9s} has {:02d}/{:02d} finished recorders.\".format(\n",
|
||||||
|
" idx + 1,\n",
|
||||||
|
" len(experiments),\n",
|
||||||
|
" experiment.name,\n",
|
||||||
|
" len(recorders),\n",
|
||||||
|
" len(recorders) + not_finished,\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" result = QResult(experiment.name)\n",
|
||||||
|
" for recorder_id, recorder in recorders.items():\n",
|
||||||
|
" result.update(recorder.list_metrics(), key_map)\n",
|
||||||
|
" result.append_path(\n",
|
||||||
|
" os.path.join(recorder.uri, recorder.experiment_id, recorder.id)\n",
|
||||||
|
" )\n",
|
||||||
|
" if not len(result):\n",
|
||||||
|
" print(\"There are no valid recorders for {:}\".format(experiment))\n",
|
||||||
|
" continue\n",
|
||||||
|
" else:\n",
|
||||||
|
" if verbose:\n",
|
||||||
|
" print(\n",
|
||||||
|
" \"There are {:} valid recorders for {:}\".format(\n",
|
||||||
|
" len(recorders), experiment.name\n",
|
||||||
|
" )\n",
|
||||||
|
" )\n",
|
||||||
|
" qresults.append(result)\n",
|
||||||
|
" return qresults"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"id": "filled-multiple",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[61765:MainThread](2021-04-11 21:23:07,182) INFO - qlib.workflow - [expm.py:290] - <mlflow.tracking.client.MlflowClient object at 0x7fabbfe8aeb0>\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[PosixPath('/Users/xuanyidong/Desktop/AutoDL-Projects/outputs/qlib-baselines-csi300')]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"paths = [root_dir / 'outputs' / 'qlib-baselines-csi300']\n",
|
||||||
|
"paths = [path.resolve() for path in paths]\n",
|
||||||
|
"print(paths)\n",
|
||||||
|
"\n",
|
||||||
|
"key_map = dict()\n",
|
||||||
|
"for xset in (\"train\", \"valid\", \"test\"):\n",
|
||||||
|
" key_map[\"{:}-mean-IC\".format(xset)] = \"IC ({:})\".format(xset)\n",
|
||||||
|
" key_map[\"{:}-mean-ICIR\".format(xset)] = \"ICIR ({:})\".format(xset)\n",
|
||||||
|
"qresults = query_info(paths, False, 'TSF-.*-drop0_0', key_map)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"id": "intimate-approval",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import matplotlib\n",
|
||||||
|
"from matplotlib import cm\n",
|
||||||
|
"matplotlib.use(\"agg\")\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import matplotlib.ticker as ticker"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 40,
|
||||||
|
"id": "supreme-basis",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def vis_depth_channel(qresults, save_path):\n",
|
||||||
|
" save_dir = (save_path / '..').resolve()\n",
|
||||||
|
" save_dir.mkdir(parents=True, exist_ok=True)\n",
|
||||||
|
" print('There are {:} qlib-results'.format(len(qresults)))\n",
|
||||||
|
" \n",
|
||||||
|
" dpi, width, height = 200, 4000, 2000\n",
|
||||||
|
" figsize = width / float(dpi), height / float(dpi)\n",
|
||||||
|
" LabelSize, LegendFontsize = 22, 12\n",
|
||||||
|
" font_gap = 5\n",
|
||||||
|
" \n",
|
||||||
|
" fig = plt.figure(figsize=figsize)\n",
|
||||||
|
" # fig, axs = plt.subplots(1, 2, figsize=figsize, projection='3d')\n",
|
||||||
|
" \n",
|
||||||
|
" def plot_ax(cur_ax, train_or_test):\n",
|
||||||
|
" depths, channels = [], []\n",
|
||||||
|
" ic_values, xmaps = [], dict()\n",
|
||||||
|
" for qresult in qresults:\n",
|
||||||
|
" name = qresult.name.split('-')[1]\n",
|
||||||
|
" depths.append(float(name.split('x')[0]))\n",
|
||||||
|
" channels.append(float(name.split('x')[1]))\n",
|
||||||
|
" if train_or_test:\n",
|
||||||
|
" ic_values.append(qresult['ICIR (train)'] * 100)\n",
|
||||||
|
" else:\n",
|
||||||
|
" ic_values.append(qresult['ICIR (valid)'] * 100)\n",
|
||||||
|
" xmaps[(depths[-1], channels[-1])] = ic_values[-1]\n",
|
||||||
|
" # cur_ax.scatter(depths, channels, ic_values, marker='o', c=\"tab:orange\")\n",
|
||||||
|
" raw_depths = np.arange(1, 9, dtype=np.int32)\n",
|
||||||
|
" raw_channels = np.array([6, 12, 24, 32, 48, 64], dtype=np.int32)\n",
|
||||||
|
" depths, channels = np.meshgrid(raw_depths, raw_channels)\n",
|
||||||
|
" ic_values = np.sin(depths) # initialize\n",
|
||||||
|
" # print(ic_values.shape)\n",
|
||||||
|
" num_x, num_y = ic_values.shape\n",
|
||||||
|
" for i in range(num_x):\n",
|
||||||
|
" for j in range(num_y):\n",
|
||||||
|
" xkey = (int(depths[i][j]), int(channels[i][j]))\n",
|
||||||
|
" if xkey not in xmaps:\n",
|
||||||
|
" raise ValueError(\"Did not find {:}\".format(xkey))\n",
|
||||||
|
" ic_values[i][j] = xmaps[xkey]\n",
|
||||||
|
" #print(sorted(list(xmaps.keys())))\n",
|
||||||
|
" #surf = cur_ax.plot_surface(\n",
|
||||||
|
" # np.array(depths), np.array(channels), np.array(ic_values),\n",
|
||||||
|
" # cmap=cm.coolwarm, linewidth=0, antialiased=False)\n",
|
||||||
|
" surf = cur_ax.plot_surface(\n",
|
||||||
|
" depths, channels, ic_values,\n",
|
||||||
|
" cmap=cm.Spectral, linewidth=0.2, antialiased=True)\n",
|
||||||
|
" cur_ax.set_xticks(raw_depths)\n",
|
||||||
|
" cur_ax.set_yticks(raw_channels)\n",
|
||||||
|
" cur_ax.set_xlabel(\"#depth\", fontsize=LabelSize)\n",
|
||||||
|
" cur_ax.set_ylabel(\"#channels\", fontsize=LabelSize)\n",
|
||||||
|
" cur_ax.set_zlabel(\"{:} IC (%)\".format('training' if train_or_test else 'validation'), fontsize=LabelSize)\n",
|
||||||
|
" for tick in cur_ax.xaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" for tick in cur_ax.yaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" for tick in cur_ax.zaxis.get_major_ticks():\n",
|
||||||
|
" tick.label.set_fontsize(LabelSize - font_gap)\n",
|
||||||
|
" # Add a color bar which maps values to colors.\n",
|
||||||
|
"# cax = fig.add_axes([cur_ax.get_position().x1 + 0.01,\n",
|
||||||
|
"# cur_ax.get_position().y0,\n",
|
||||||
|
"# 0.01,\n",
|
||||||
|
"# cur_ax.get_position().height * 0.9])\n",
|
||||||
|
" # fig.colorbar(surf, cax=cax)\n",
|
||||||
|
" # fig.colorbar(surf, shrink=0.5, aspect=5)\n",
|
||||||
|
" # import pdb; pdb.set_trace()\n",
|
||||||
|
" # ax1.legend(loc=4, fontsize=LegendFontsize)\n",
|
||||||
|
" ax = fig.add_subplot(1, 2, 1, projection='3d')\n",
|
||||||
|
" plot_ax(ax, True)\n",
|
||||||
|
" ax = fig.add_subplot(1, 2, 2, projection='3d')\n",
|
||||||
|
" plot_ax(ax, False)\n",
|
||||||
|
" # fig.tight_layout()\n",
|
||||||
|
" plt.subplots_adjust(wspace=0.05)#, hspace=0.4)\n",
|
||||||
|
" fig.savefig(save_path, dpi=dpi, bbox_inches=\"tight\", format=\"pdf\")\n",
|
||||||
|
" plt.close(\"all\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 41,
|
||||||
|
"id": "shared-envelope",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The Desktop is at: /Users/xuanyidong/Desktop\n",
|
||||||
|
"There are 48 qlib-results\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Visualization\n",
|
||||||
|
"home_dir = Path.home()\n",
|
||||||
|
"desktop_dir = home_dir / 'Desktop'\n",
|
||||||
|
"print('The Desktop is at: {:}'.format(desktop_dir))\n",
|
||||||
|
"\n",
|
||||||
|
"vis_depth_channel(qresults, desktop_dir / 'es_csi300_d_vs_c.pdf')"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.8.8"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user