This commit is contained in:
D-X-Y 2021-04-26 21:44:03 +08:00
parent 8358d71cdf
commit d3371296a7
10 changed files with 270 additions and 264 deletions

View File

@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
def find_min(cur, others): def find_min(cur, others):
if cur is None: if cur is None:
return float(others.min()) return float(others)
else: else:
return float(min(cur, others.min())) return float(min(cur, others))
def find_max(cur, others): def find_max(cur, others):
if cur is None: if cur is None:
return float(others.max()) return float(others.max())
else: else:
return float(max(cur, others.max())) return float(max(cur, others))
def compare_cl(save_dir): def compare_cl(save_dir):
save_dir = Path(str(save_dir)) save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env, function = create_example_v1( dynamic_env, function = create_example_v1(
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0), # timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
timestamp_config=None,
num_per_task=1000, num_per_task=1000,
) )
@ -104,13 +105,11 @@ def compare_cl(save_dir):
current_data["lfna_xaxis_all"] = xaxis_all current_data["lfna_xaxis_all"] = xaxis_all
current_data["lfna_yaxis_all"] = yaxis_all current_data["lfna_yaxis_all"] = yaxis_all
import pdb
pdb.set_trace()
# compute cl-min # compute cl-min
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all) cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1 cl_xaxis_max = (
find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1
)
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05) cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all) cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
@ -142,8 +141,8 @@ def compare_cl(save_dir):
"yaxis": cl_yaxis_all, "yaxis": cl_yaxis_all,
"color": "r", "color": "r",
"s": 10, "s": 10,
"xlim": (-6, 6 + timestamp * 0.2), "xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)),
"ylim": (-40, 40), "ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)),
"alpha": 0.99, "alpha": 0.99,
"label": "Continual Learning", "label": "Continual Learning",
} }
@ -151,10 +150,10 @@ def compare_cl(save_dir):
draw_multi_fig( draw_multi_fig(
save_dir, save_dir,
timestamp, idx,
scatter_list, scatter_list,
wh=(2000, 1300), wh=(2000, 1300),
fig_title="Timestamp={:03d}".format(timestamp), fig_title="Timestamp={:03d}".format(idx),
) )
print("Save all figures into {:}".format(save_dir)) print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve() save_dir = save_dir.resolve()

View File

@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
dataset = np.random.multivariate_normal( dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task mean_list, cov_matrix, size=self._num_per_task
) )
return index, torch.Tensor(dataset) return timestamp, torch.Tensor(dataset)
def __len__(self): def __len__(self):
return len(self._timestamp_generator) return len(self._timestamp_generator)

View File

@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
def create_example_v1( def create_example_v1(
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0), timestamp_config=None,
num_per_task=5000, num_per_task=5000,
): ):
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
mean_generator = ComposedSinFunc() mean_generator = ComposedSinFunc()
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5) std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)

View File

@ -1,8 +1,10 @@
##################################################### #####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 # # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
##################################################### #####################################################
# To be finished.
#
import os, sys, time, torch import os, sys, time, torch
from typing import import Optional, Text, Callable from typing import Optional, Text, Callable
# modules in AutoDL # modules in AutoDL
from log_utils import AverageMeter from log_utils import AverageMeter
@ -60,9 +62,10 @@ def procedure(
network, network,
criterion, criterion,
optimizer, optimizer,
eval_metric,
mode: Text, mode: Text,
print_freq: int = 100, print_freq: int = 100,
logger_fn: Callable = None logger_fn: Callable = None,
): ):
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter() data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
if mode.lower() == "train": if mode.lower() == "train":
@ -90,7 +93,7 @@ def procedure(
optimizer.step() optimizer.step()
# record # record
metrics = metrics = eval_metric(logits.data, targets.data)
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5)) prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0)) losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0)) top1.update(prec1.item(), inputs.size(0))

View File

@ -3,6 +3,7 @@
##################################################### #####################################################
import abc import abc
def obtain_accuracy(output, target, topk=(1,)): def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k""" """Computes the precision@k for the specified values of k"""
maxk = max(topk) maxk = max(topk)
@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
class EvaluationMetric(abc.ABC): class EvaluationMetric(abc.ABC):
def __init__(self): def __init__(self):
self._total_metrics = 0 self._total_metrics = 0

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long