#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import builtins
import decimal
import logging
import os
import sys
import pycls.core.distributed as dist
import simplejson
from pycls.core.config import cfg
# Log record format: source filename and line number, then the message
_FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s"
# Name of the log file created in cfg.OUT_DIR when logging to a file
# (cfg.LOG_DEST = 'file'); also the file get_log_files looks for by default
_LOG_FILE = "stdout.log"
# Prefix that dump_log_data puts in front of each JSON record; load_log_data
# only parses lines that contain it
_TAG = "json_stats: "
# Key under which dump_log_data stores data_type in each JSON record;
# load_log_data groups records by the value of this key
_TYPE = "_type"
def _suppress_print():
    """Suppresses printing from the current process.

    Replaces ``builtins.print`` with a no-op, so every later ``print`` call in
    this process produces no output, including calls that pass ``sep``,
    ``end``, ``file`` or ``flush``.
    """

    def ignore(*_objects, **_kwargs):
        # Accept any positional/keyword arguments print() accepts, so calls
        # such as print(x, end="") or print(x, file=f) do not raise TypeError.
        pass

    builtins.print = ignore
def setup_logging():
    """Sets up the logging.

    On the master process (``dist.is_master_proc()``), this drops any handlers
    already attached to the root logger and then configures the root logger:
      - level INFO, format ``_FORMAT``;
      - output to ``sys.stdout`` when ``cfg.LOG_DEST == "stdout"``;
      - otherwise output to ``_LOG_FILE`` inside ``cfg.OUT_DIR``.
    On every other process, ``print`` is disabled via ``_suppress_print``.
    """
    # Enable logging only for the master process
    if dist.is_master_proc():
        # Clear the root logger to prevent any existing logging config
        # (e.g. set by another module) from messing with our setup
        logging.root.handlers = []
        # Construct logging configuration
        logging_config = {"level": logging.INFO, "format": _FORMAT}
        # Log either to stdout or to a file. logging.basicConfig rejects
        # receiving both "stream" and "filename", so exactly one is set.
        if cfg.LOG_DEST == "stdout":
            logging_config["stream"] = sys.stdout
        else:
            logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE)
        # Configure logging
        logging.basicConfig(**logging_config)
    else:
        # Non-master processes stay quiet
        _suppress_print()
def get_logger(name):
    """Return the logger registered under ``name`` via ``logging.getLogger``."""
    logger = logging.getLogger(name)
    return logger
def dump_log_data(data, data_type, prec=4):
    """Convert data (a dictionary) into a tagged json string for logging.

    The returned string is ``_TAG`` followed by the JSON encoding of ``data``
    with an extra ``_TYPE: data_type`` entry. Keys are sorted, and floats are
    rounded to ``prec`` digits after the decimal point (via
    ``float_to_decimal``). The caller's ``data`` dict is not modified.
    """
    # Copy instead of mutating, so the caller's dict does not gain a _TYPE key
    tagged = dict(data)
    tagged[_TYPE] = data_type
    tagged = float_to_decimal(tagged, prec)
    data_json = simplejson.dumps(tagged, sort_keys=True, use_decimal=True)
    return "{:s}{:s}".format(_TAG, data_json)
def float_to_decimal(data, prec=4):
    """Convert floats to decimals which allows for fixed width json.

    Recursively walks dicts, lists and plain tuples, and replaces every
    ``float`` with a ``decimal.Decimal`` formatted to exactly ``prec`` digits
    after the decimal point. All other values (ints, strings, namedtuples,
    ...) are returned unchanged.
    """
    if isinstance(data, dict):
        return {k: float_to_decimal(v, prec) for k, v in data.items()}
    if isinstance(data, list):
        return [float_to_decimal(v, prec) for v in data]
    # Only plain tuples: rebuilding a namedtuple as a tuple would change how
    # JSON encoders serialize it.
    if type(data) is tuple:
        return tuple(float_to_decimal(v, prec) for v in data)
    if isinstance(data, float):
        return decimal.Decimal("{:.{}f}".format(data, prec))
    return data
def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE):
    """Collect log files of trained models, one per subdirectory of log_dir.

    Entries of ``log_dir`` are visited in sorted order. An entry is kept when
    ``name_filter`` is a substring of its name and ``log_dir/<name>/log_file``
    exists.

    Returns:
        (files, names): two parallel tuples of matching file paths and entry
        names, or two empty lists when nothing matches.
    """
    matches = []
    for entry in sorted(os.listdir(log_dir)):
        if name_filter not in entry:
            continue
        candidate = os.path.join(log_dir, entry, log_file)
        if os.path.exists(candidate):
            matches.append((candidate, entry))
    if not matches:
        return [], []
    paths, entries = zip(*matches)
    return paths, entries
def load_log_data(log_file, data_types_to_skip=()):
    """Loads log data into a dictionary of the form data[data_type][metric][index].

    Parsing rules:
      - Only lines containing ``_TAG`` are parsed; the text after the first
        ``_TAG`` is decoded as JSON.
      - A record is kept only if it is a JSON object with a ``_TYPE`` key
        whose value is not in ``data_types_to_skip``.
      - Records are grouped by that ``_TYPE`` value, and the ``_TYPE`` key is
        removed.

    Raises:
        AssertionError: if ``log_file`` does not exist, or if records sharing
            a type do not all have the same metric keys.
    """
    # Load log_file
    assert os.path.exists(log_file), "Log file not found: {}".format(log_file)
    with open(log_file, "r") as f:
        raw_lines = f.readlines()
    # Extract and parse lines that contain _TAG and have a type specified
    payloads = [ln[ln.find(_TAG) + len(_TAG) :] for ln in raw_lines if _TAG in ln]
    records = [simplejson.loads(p) for p in payloads]
    records = [
        r
        for r in records
        if isinstance(r, dict) and _TYPE in r and r[_TYPE] not in data_types_to_skip
    ]
    # Generate data structure accessed by data[data_type][index][metric]
    # (types keep the order in which they first appear in the log)
    data = {}
    for record in records:
        data_type = record.pop(_TYPE)
        data.setdefault(data_type, []).append(record)
    # Generate data structure accessed by data[data_type][metric][index]
    for t in data:
        metrics = sorted(data[t][0].keys())
        err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics)
        assert all(sorted(d.keys()) == metrics for d in data[t]), err_str
        data[t] = {m: [d[m] for d in data[t]] for m in metrics}
    return data
def sort_log_data(data):
    """Sort each data[data_type][metric] by epoch or keep only first instance.

    Data types with an "epoch" metric (values formatted "cur/max"):
      - "epoch_ind" and "epoch_max" are added;
      - if an "iter" metric ("cur/max") is also present, "iter_ind" and
        "iter_max" are added and the sort key becomes
        epoch_ind + (iter_ind - 1) / iter_max;
      - every metric list of the type is reordered by that key.
    Rows with equal keys keep their original relative order, so values of
    different metrics stay aligned row by row.

    Data types without "epoch" are collapsed to the first value of each metric.

    ``data`` is modified in place and also returned. All metric lists of a
    type are assumed to have the same length (as load_log_data produces).
    """
    for t in data:
        if "epoch" not in data[t]:
            # No epoch information: keep only the first logged instance
            data[t] = {m: d[0] for m, d in data[t].items()}
            continue
        stats = data[t]
        assert "epoch_ind" not in stats and "epoch_max" not in stats
        stats["epoch_ind"] = [int(e.split("/")[0]) for e in stats["epoch"]]
        stats["epoch_max"] = [int(e.split("/")[1]) for e in stats["epoch"]]
        epoch = stats["epoch_ind"]
        if "iter" in stats:
            assert "iter_ind" not in stats and "iter_max" not in stats
            stats["iter_ind"] = [int(i.split("/")[0]) for i in stats["iter"]]
            stats["iter_max"] = [int(i.split("/")[1]) for i in stats["iter"]]
            itr = zip(epoch, stats["iter_ind"], stats["iter_max"])
            epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr]
        # Compute one stable permutation keyed on epoch only, and apply it to
        # every metric. Sorting (epoch, value) pairs per metric would break
        # ties by comparing values. That can raise TypeError, and it reorders
        # each metric independently, misaligning rows across metrics.
        order = sorted(range(len(epoch)), key=epoch.__getitem__)
        for m in stats:
            stats[m] = [stats[m][i] for i in order]
    return data