This commit is contained in:
D-X-Y 2021-04-26 21:44:03 +08:00
parent 8358d71cdf
commit d3371296a7
10 changed files with 270 additions and 264 deletions

View File

@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
def find_min(cur, others):
if cur is None:
return float(others.min())
return float(others)
else:
return float(min(cur, others.min()))
return float(min(cur, others))
def find_max(cur, others):
if cur is None:
return float(others.max())
else:
return float(max(cur, others.max()))
return float(max(cur, others))
def compare_cl(save_dir):
save_dir = Path(str(save_dir))
save_dir.mkdir(parents=True, exist_ok=True)
dynamic_env, function = create_example_v1(
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
timestamp_config=None,
num_per_task=1000,
)
@ -104,13 +105,11 @@ def compare_cl(save_dir):
current_data["lfna_xaxis_all"] = xaxis_all
current_data["lfna_yaxis_all"] = yaxis_all
import pdb
pdb.set_trace()
# compute cl-min
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
cl_xaxis_max = (
find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1
)
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
@ -142,8 +141,8 @@ def compare_cl(save_dir):
"yaxis": cl_yaxis_all,
"color": "r",
"s": 10,
"xlim": (-6, 6 + timestamp * 0.2),
"ylim": (-40, 40),
"xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)),
"ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)),
"alpha": 0.99,
"label": "Continual Learning",
}
@ -151,10 +150,10 @@ def compare_cl(save_dir):
draw_multi_fig(
save_dir,
timestamp,
idx,
scatter_list,
wh=(2000, 1300),
fig_title="Timestamp={:03d}".format(timestamp),
fig_title="Timestamp={:03d}".format(idx),
)
print("Save all figures into {:}".format(save_dir))
save_dir = save_dir.resolve()

View File

@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
dataset = np.random.multivariate_normal(
mean_list, cov_matrix, size=self._num_per_task
)
return index, torch.Tensor(dataset)
return timestamp, torch.Tensor(dataset)
def __len__(self):
return len(self._timestamp_generator)

View File

@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
def create_example_v1(
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
timestamp_config=None,
num_per_task=5000,
):
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
mean_generator = ComposedSinFunc()
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)

View File

@ -1,8 +1,10 @@
#####################################################
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
#####################################################
# To be finished.
#
import os, sys, time, torch
from typing import import Optional, Text, Callable
from typing import Optional, Text, Callable
# modules in AutoDL
from log_utils import AverageMeter
@ -60,9 +62,10 @@ def procedure(
network,
criterion,
optimizer,
eval_metric,
mode: Text,
print_freq: int = 100,
logger_fn: Callable = None
logger_fn: Callable = None,
):
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
if mode.lower() == "train":
@ -90,7 +93,7 @@ def procedure(
optimizer.step()
# record
metrics =
metrics = eval_metric(logits.data, targets.data)
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
losses.update(loss.item(), inputs.size(0))
top1.update(prec1.item(), inputs.size(0))

View File

@ -3,6 +3,7 @@
#####################################################
import abc
def obtain_accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
maxk = max(topk)
@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
class EvaluationMetric(abc.ABC):
def __init__(self):
self._total_metrics = 0

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long