Fix bugs
This commit is contained in:
parent
8358d71cdf
commit
d3371296a7
@ -66,23 +66,24 @@ def draw_multi_fig(save_dir, timestamp, scatter_list, wh, fig_title=None):
|
||||
|
||||
def find_min(cur, others):
|
||||
if cur is None:
|
||||
return float(others.min())
|
||||
return float(others)
|
||||
else:
|
||||
return float(min(cur, others.min()))
|
||||
return float(min(cur, others))
|
||||
|
||||
|
||||
def find_max(cur, others):
|
||||
if cur is None:
|
||||
return float(others.max())
|
||||
else:
|
||||
return float(max(cur, others.max()))
|
||||
return float(max(cur, others))
|
||||
|
||||
|
||||
def compare_cl(save_dir):
|
||||
save_dir = Path(str(save_dir))
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
dynamic_env, function = create_example_v1(
|
||||
timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||
# timestamp_config=dict(num=200, min_timestamp=-1, max_timestamp=1.0),
|
||||
timestamp_config=None,
|
||||
num_per_task=1000,
|
||||
)
|
||||
|
||||
@ -104,13 +105,11 @@ def compare_cl(save_dir):
|
||||
current_data["lfna_xaxis_all"] = xaxis_all
|
||||
current_data["lfna_yaxis_all"] = yaxis_all
|
||||
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
# compute cl-min
|
||||
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all)
|
||||
cl_xaxis_max = find_max(cl_xaxis_max, xaxis_all) + idx * 0.1
|
||||
cl_xaxis_min = find_min(cl_xaxis_min, xaxis_all.mean() - xaxis_all.std())
|
||||
cl_xaxis_max = (
|
||||
find_max(cl_xaxis_max, xaxis_all.mean() + xaxis_all.std()) + idx * 0.1
|
||||
)
|
||||
cl_xaxis_all = np.arange(cl_xaxis_min, cl_xaxis_max, step=0.05)
|
||||
|
||||
cl_yaxis_all = cl_function.noise_call(cl_xaxis_all)
|
||||
@ -142,8 +141,8 @@ def compare_cl(save_dir):
|
||||
"yaxis": cl_yaxis_all,
|
||||
"color": "r",
|
||||
"s": 10,
|
||||
"xlim": (-6, 6 + timestamp * 0.2),
|
||||
"ylim": (-40, 40),
|
||||
"xlim": (round(cl_xaxis_all.min(), 1), round(cl_xaxis_all.max(), 1)),
|
||||
"ylim": (round(cl_xaxis_all.min(), 1), round(cl_yaxis_all.max(), 1)),
|
||||
"alpha": 0.99,
|
||||
"label": "Continual Learning",
|
||||
}
|
||||
@ -151,10 +150,10 @@ def compare_cl(save_dir):
|
||||
|
||||
draw_multi_fig(
|
||||
save_dir,
|
||||
timestamp,
|
||||
idx,
|
||||
scatter_list,
|
||||
wh=(2000, 1300),
|
||||
fig_title="Timestamp={:03d}".format(timestamp),
|
||||
fig_title="Timestamp={:03d}".format(idx),
|
||||
)
|
||||
print("Save all figures into {:}".format(save_dir))
|
||||
save_dir = save_dir.resolve()
|
||||
|
@ -63,7 +63,7 @@ class SyntheticDEnv(data.Dataset):
|
||||
dataset = np.random.multivariate_normal(
|
||||
mean_list, cov_matrix, size=self._num_per_task
|
||||
)
|
||||
return index, torch.Tensor(dataset)
|
||||
return timestamp, torch.Tensor(dataset)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._timestamp_generator)
|
||||
|
@ -8,9 +8,10 @@ from .synthetic_env import SyntheticDEnv
|
||||
|
||||
|
||||
def create_example_v1(
|
||||
timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
timestamp_config=None,
|
||||
num_per_task=5000,
|
||||
):
|
||||
# timestamp_config=dict(num=100, min_timestamp=0.0, max_timestamp=1.0),
|
||||
mean_generator = ComposedSinFunc()
|
||||
std_generator = ComposedSinFunc(min_amplitude=0.5, max_amplitude=0.5)
|
||||
|
||||
|
@ -1,8 +1,10 @@
|
||||
#####################################################
|
||||
# Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2020.04 #
|
||||
#####################################################
|
||||
# To be finished.
|
||||
#
|
||||
import os, sys, time, torch
|
||||
from typing import import Optional, Text, Callable
|
||||
from typing import Optional, Text, Callable
|
||||
|
||||
# modules in AutoDL
|
||||
from log_utils import AverageMeter
|
||||
@ -60,9 +62,10 @@ def procedure(
|
||||
network,
|
||||
criterion,
|
||||
optimizer,
|
||||
eval_metric,
|
||||
mode: Text,
|
||||
print_freq: int = 100,
|
||||
logger_fn: Callable = None
|
||||
logger_fn: Callable = None,
|
||||
):
|
||||
data_time, batch_time, losses = AverageMeter(), AverageMeter(), AverageMeter()
|
||||
if mode.lower() == "train":
|
||||
@ -90,7 +93,7 @@ def procedure(
|
||||
optimizer.step()
|
||||
|
||||
# record
|
||||
metrics =
|
||||
metrics = eval_metric(logits.data, targets.data)
|
||||
prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
|
||||
losses.update(loss.item(), inputs.size(0))
|
||||
top1.update(prec1.item(), inputs.size(0))
|
||||
|
@ -3,6 +3,7 @@
|
||||
#####################################################
|
||||
import abc
|
||||
|
||||
|
||||
def obtain_accuracy(output, target, topk=(1,)):
|
||||
"""Computes the precision@k for the specified values of k"""
|
||||
maxk = max(topk)
|
||||
@ -20,7 +21,6 @@ def obtain_accuracy(output, target, topk=(1,)):
|
||||
|
||||
|
||||
class EvaluationMetric(abc.ABC):
|
||||
|
||||
def __init__(self):
|
||||
self._total_metrics = 0
|
||||
|
||||
|
110
notebooks/LFNA/synthetic-data.ipynb
Normal file
110
notebooks/LFNA/synthetic-data.ipynb
Normal file
File diff suppressed because one or more lines are too long
137
notebooks/LFNA/synthetic-env.ipynb
Normal file
137
notebooks/LFNA/synthetic-env.ipynb
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user