Fix bugs
This commit is contained in:
parent
8358d71cdf
commit
d3371296a7
@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
|
|||||||
|
|
||||||
def find_min(cur, others):
|
def find_min(cur, others):
|
||||||
if cur is None:
|
if cur is None:
|
||||||
return float(others.min())
|
return float(others)
|
||||||
else:
|
else:
|
||||||
return float(min(cur, others.min()))
|
return float(min(cur, others))
|
||||||
|
|
||||||
|
|
||||||
def find_max(cur, others):
|
def find_max(cur, others):
|
||||||
if cur is None:
|
if cur is None:
|
||||||
return float(others.max())
|
return float(others.max())
|
||||||
else:
|
else:
|
||||||
return float(max(cur, others.max()))
|
return float(max(cur, others))
|
||||||
|
|
||||||
|
|
||||||
def compare_cl(save_dir):
|
def compare_cl(save_dir):
|
||||||
save_dir = Path(str(save_dir))
|
save_dir = Path(str(save_dir))
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
dynamic_env, function = create_example_v1(
|
dynamic_env, function = create_example_v1(
|
||||||
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||||
|
timestamp_config=None,
|
||||||
num_per_task=1000,
|
num_per_task=1000,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -104,13 +105,11 @@ def compare_cl(save_dir):
|
|||||||
current_data["lfna_xaxis_all"] = xaxis_all
|
current_data["lfna_xaxis_all"] = xaxis_all
|
||||||
current_data["lfna_yaxis_all"] = yaxis_all
|
current_data["lfna_yaxis_all"] = yaxis_all
|
||||||
|
|
||||||
import pdb
|
|
||||||
|
|
||||||
pdb.set_trace()
|
|
||||||
|
|
||||||
# compute cl-min
|
# compute cl-min
|
||||||
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
|
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
|
||||||
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
|
cl_xaxis_max = (
|
||||||
|
find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1
|
||||||
|
)
|
||||||
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
|
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
|
||||||
|
|
||||||
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
|
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
|
||||||
@ -142,8 +141,8 @@ def compare_cl(save_dir):
|
|||||||
"yaxis": cl_yaxis_all,
|
"yaxis": cl_yaxis_all,
|
||||||
"color": "r",
|
"color": "r",
|
||||||
"s": 10,
|
"s": 10,
|
||||||
"xlim": (-6, 6 + timestamp * 0.2),
|
"xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)),
|
||||||
"ylim": (-40, 40),
|
"ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)),
|
||||||
"alpha": 0.99,
|
"alpha": 0.99,
|
||||||
"label": "Continual Learning",
|
"label": "Continual Learning",
|
||||||
}
|
}
|
||||||
@ -151,10 +150,10 @@ def compare_cl(save_dir):
|
|||||||
|
|
||||||
draw_multi_fig(
|
draw_multi_fig(
|
||||||
save_dir,
|
save_dir,
|
||||||
timestamp,
|
idx,
|
||||||
scatter_list,
|
scatter_list,
|
||||||
wh=(2000, 1300),
|
wh=(2000, 1300),
|
||||||
fig_title="Timestamp={:03d}".format(timestamp),
|
fig_title="Timestamp={:03d}".format(idx),
|
||||||
)
|
)
|
||||||
print("Save all figures into {:}".format(save_dir))
|
print("Save all figures into {:}".format(save_dir))
|
||||||
save_dir = save_dir.resolve()
|
save_dir = save_dir.resolve()
|
||||||
|
@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
|
|||||||
dataset = np.random.multivariate_normal(
|
dataset = np.random.multivariate_normal(
|
||||||
mean_list, cov_matrix, size=self._num_per_task
|
mean_list, cov_matrix, size=self._num_per_task
|
||||||
)
|
)
|
||||||
return index, torch.Tensor(dataset)
|
return timestamp, torch.Tensor(dataset)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._timestamp_generator)
|
return len(self._timestamp_generator)
|
||||||
|
@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
|
|||||||
|
|
||||||
|
|
||||||
def create_example_v1(
|
def create_example_v1(
|
||||||
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
timestamp_config=None,
|
||||||
num_per_task=5000,
|
num_per_task=5000,
|
||||||
):
|
):
|
||||||
|
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||||
mean_generator = ComposedSinFunc()
|
mean_generator = ComposedSinFunc()
|
||||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||||
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||||
#####################################################
|
#####################################################
|
||||||
|
# To be finished.
|
||||||
|
#
|
||||||
import os, sys, time, torch
|
import os, sys, time, torch
|
||||||
from typing import import Optional, Text, Callable
|
from typing import Optional, Text, Callable
|
||||||
|
|
||||||
# modules in AutoDL
|
# modules in AutoDL
|
||||||
from log_utils import AverageMeter
|
from log_utils import AverageMeter
|
||||||
@ -60,9 +62,10 @@ def procedure(
|
|||||||
network,
|
network,
|
||||||
criterion,
|
criterion,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
eval_metric,
|
||||||
mode: Text,
|
mode: Text,
|
||||||
print_freq: int = 100,
|
print_freq: int = 100,
|
||||||
logger_fn: Callable = None
|
logger_fn: Callable = None,
|
||||||
):
|
):
|
||||||
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
|
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
|
||||||
if mode.lower() == "train":
|
if mode.lower() == "train":
|
||||||
@ -90,7 +93,7 @@ def procedure(
|
|||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
# record
|
# record
|
||||||
metrics =
|
metrics = eval_metric(logits.data, targets.data)
|
||||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||||
losses.update(loss.item(), inputs.size(0))
|
losses.update(loss.item(), inputs.size(0))
|
||||||
top1.update(prec1.item(), inputs.size(0))
|
top1.update(prec1.item(), inputs.size(0))
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
#####################################################
|
#####################################################
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
|
|
||||||
def obtain_accuracy(output, target, topk=(1,)):
|
def obtain_accuracy(output, target, topk=(1,)):
|
||||||
"""Computes the precision@k for the specified values of k"""
|
"""Computes the precision@k for the specified values of k"""
|
||||||
maxk = max(topk)
|
maxk = max(topk)
|
||||||
@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationMetric(abc.ABC):
|
class EvaluationMetric(abc.ABC):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._total_metrics = 0
|
self._total_metrics = 0
|
||||||
|
|
||||||
|
110
notebooks/LFNA/synthetic-data.ipynb
Normal file
110
notebooks/LFNA/synthetic-data.ipynb
Normal file
File diff suppressed because one or more lines are too long
137
notebooks/LFNA/synthetic-env.ipynb
Normal file
137
notebooks/LFNA/synthetic-env.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user