naswot/pycls/core/logging.py
Jack Turner b74255e1f3 v2
2021-02-26 16:12:51 +00:00

139 lines
5.0 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Logging."""
import builtins
import decimal
import logging
import os
import sys
import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg
# Show filename and line number in logs
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
# Log file name (for cfg.LOG_DEST = 'file')
_LOG_FILE = "stdout.log"
# Data output with dump_log_data(data, data_type) will be tagged w/ this
_TAG = "json_stats: "
# Data output with dump_log_data(data, data_type) will have data[_TYPE]=data_type
_TYPE = "_type"
def _suppress_print():
"""Suppresses printing from the current process."""
def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False):
pass
builtins.print = ignore
def setup_logging():
"""Sets up the logging."""
# Enable logging only for the master process
if dist.is_master_proc():
# Clear the root logger to prevent any existing logging config
# (e.g. set by another module) from messing with our setup
logging.root.handlers = []
# Construct logging configuration
logging_config = {"level": logging.INFO, "format": _FORMAT}
# Log either to stdout or to a file
if cfg.LOG_DEST == "stdout":
logging_config["stream"] = sys.stdout
else:
logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
# Configure logging
logging.basicConfig(**logging_config)
else:
_suppress_print()
def get_logger(name):
"""Retrieves the logger."""
return logging.getLogger(name)
def dump_log_data(data, data_type, prec=4):
"""Covert data (a dictionary) into tagged json string for logging."""
data[_TYPE] = data_type
data = float_to_decimal(data, prec)
data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True)
return "{:s}{:s}".format(_TAG, data_json)
def float_to_decimal(data, prec=4):
"""Convert floats to decimals which allows for fixed width json."""
if isinstance(data, dict):
return {k: float_to_decimal(v, prec) for k, v in data.items()}
if isinstance(data, float):
return decimal.Decimal(("{:." + str(prec) + "f}").format(data))
else:
return data
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
"""Get all log files in directory containing subdirs of trained models."""
names = [n for n in sorted(os.listdir(log_dir)) if name_filter in n]
files = [os.path.join(log_dir, n, log_file) for n in names]
f_n_ps = [(f, n) for (f, n) in zip(files, names) if os.path.exists(f)]
files, names = zip(*f_n_ps) if f_n_ps else ([], [])
return files, names
def load_log_data(log_file, data_types_to_skip=()):
"""Loads log data into a dictionary of the form data[data_type][metric][index]."""
# Load log_file
assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
with open(log_file, "r") as f:
lines = f.readlines()
# Extract and parse lines that start with _TAG and have a type specified
lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l]
lines = [simplejson.loads(l) for l in lines]
lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip]
# Generate data structure accessed by data[data_type][index][metric]
data_types = [l[_TYPE] for l in lines]
data = {t: [] for t in data_types}
for t, line in zip(data_types, lines):
del line[_TYPE]
data[t].append(line)
# Generate data structure accessed by data[data_type][metric][index]
for t in data:
metrics = sorted(data[t][0].keys())
err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
data[t] = {m: [d[m] for d in data[t]] for m in metrics}
return data
def sort_log_data(data):
"""Sort each data[data_type][metric] by epoch or keep only first instance."""
for t in data:
if "epoch" in data[t]:
assert "epoch_ind" not in data[t] and "epoch_max" not in data[t]
data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]]
data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]]
epoch = data[t]["epoch_ind"]
if "iter" in data[t]:
assert "iter_ind" not in data[t] and "iter_max" not in data[t]
data[t]["iter_ind"] = [int(i.split("/")[0]) for i in data[t]["iter"]]
data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]]
itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"])
epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
for m in data[t]:
data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))]
else:
data[t] = {m: d[0] for m, d in data[t].items()}
return data