#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Logging."""

import builtins
import decimal
import logging
import os
import sys

import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg


# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"

# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"

# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "

# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"


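# Illustrative note (not in the original source): a metrics line written via
# logger.info(dump_log_data(...)) renders with both prefixes, e.g. (filename
# and line number here are hypothetical):
#   [trainer.py:  42]: json_stats: {"_type": "train_epoch", ...}

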
def _suppress_print():
    """Suppresses printing from the current process."""

    def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
        pass

    builtins.print = ignore


def setup_logging():
    """Sets up the logging."""
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        _suppress_print()


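# Example usage of setup_logging() (illustrative sketch; assumes cfg has been
# loaded and the distributed state is initialized so is_master_proc() works):
#   setup_logging()
#   logger = get_logger(__name__)
#   logger.info("training started")  # emitted only on the master process

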
def get_logger(name):
    """Retrieves the logger."""
    return logging.getLogger(name)


def dump_log_data(data, data_type, prec=4):
    """Convert data (a dictionary) into a tagged json string for logging."""
    data[_TYPE] = data_type
    data = float_to_decimal(data, prec)
    data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
    return "{:s}{:s}".format(_TAG, data_json)


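# Example of dump_log_data() (illustrative values; keys are sorted and floats
# are rounded to `prec` digits by float_to_decimal below):
#   dump_log_data({"epoch": "1/10", "loss": 0.123456}, "train_epoch")
#   -> 'json_stats: {"_type": "train_epoch", "epoch": "1/10", "loss": 0.1235}'

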
def float_to_decimal(data, prec=4):
    """Convert floats to decimals which allows for fixed width json."""
    if isinstance(data, dict):
        return {k: float_to_decimal(v, prec) for k, v in data.items()}
    if isinstance(data, float):
        return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
    else:
        return data


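# Example of float_to_decimal() (illustrative): floats in a dict become
# fixed-precision decimals, while non-float values pass through unchanged:
#   float_to_decimal({"loss": 0.123456, "epoch": "1/10"}, prec=2)
#   -> {"loss": Decimal("0.12"), "epoch": "1/10"}

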
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Get all log files in directory containing subdirs of trained models."""
    names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
    files = [os.path.join(log_dir, n, log_file) for n in names]
    f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
    files, names = zip(*f_n_ps) if f_n_ps else ([], [])
    return files, names


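# Example of get_log_files() (hypothetical directory layout): given
#   /logs/expA/stdout.log and /logs/expB/stdout.log,
# get_log_files("/logs") returns files and names as parallel sequences:
#   (("/logs/expA/stdout.log", "/logs/expB/stdout.log"), ("expA", "expB"))

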
def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index]."""
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        lines = f.readlines()
    # Extract and parse lines that start with _TAG and have a type specified
    lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
    lines = [simplejson.loads(l) for l in lines]
    lines = [l for l in lines if _TYPE in l and l[_TYPE] not in data_types_to_skip]
    # Generate data structure accessed by data[data_type][index][metric]
    data_types = [l[_TYPE] for l in lines]
    data = {t: [] for t in data_types}
    for t, line in zip(data_types, lines):
        del line[_TYPE]
        data[t].append(line)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data


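# Example of load_log_data() (illustrative): two "train_epoch" lines with
# metrics "epoch" and "loss" produce the metric-major structure:
#   data["train_epoch"]["epoch"] -> ["1/10", "2/10"]
#   data["train_epoch"]["loss"]  -> [0.1235, 0.0987]

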
def sort_log_data(data):
    """Sort each data[data_type][metric] by epoch or keep only first instance."""
    for t in data:
        if "epoch" in data[t]:
            assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
            data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
            data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
            epoch = data[t]["epoch_ind"]
            if "iter" in data[t]:
                assert "iter_ind" not in data[t] and "iter_max" not in data[t]
                data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
                data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
                itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
                epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
            for m in data[t]:
                data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
        else:
            data[t] = {m: d[0] for m, d in data[t].items()}
    return data
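

# Example of sort_log_data() (illustrative): entries logged out of order, e.g.
#   data["train_epoch"]["epoch"] == ["2/10", "1/10"],
# are re-sorted by the derived "epoch_ind" ([2, 1]), so afterwards
#   data["train_epoch"]["epoch"] -> ["1/10", "2/10"]
# while types without an "epoch" key keep only their first logged instance.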