78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||
|
#
|
||
|
# This source code is licensed under the MIT license found in the
|
||
|
# LICENSE file in the root directory of this source tree.
|
||
|
|
||
|
"""IO utilities (adapted from Detectron)"""
|
||
|
|
||
|
import logging
|
||
|
import os
|
||
|
import re
|
||
|
import sys
|
||
|
from urllib import request as urlrequest
|
||
|
|
||
|
|
||
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
|
||
|
# URL prefix that cache_url() requires of every URL it downloads; the prefix is
# swapped for the cache directory to form the local path of the cached file.
_PYCLS_BASE_URL = "https://dl.fbaipublicfiles.com/pycls"
|
||
|
|
||
|
|
||
|
def cache_url(url_or_file, cache_dir):
    """Download the file specified by the URL to the cache_dir and return the path to
    the cached file. If the argument is not a URL, simply return it as is.

    Args:
        url_or_file: Either an http(s) URL that starts with _PYCLS_BASE_URL, or
            any other string (e.g. a local file path).
        cache_dir: Directory that stands in for _PYCLS_BASE_URL; the part of the
            URL that follows the base URL is appended to it.

    Returns:
        url_or_file unchanged when it is not an http(s) URL; otherwise the path
        of the cached file, which is downloaded first if it does not exist yet.

    Raises:
        AssertionError: If the URL does not start with _PYCLS_BASE_URL.
    """
    is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None
    if not is_url:
        return url_or_file
    url = url_or_file
    if not url.startswith(_PYCLS_BASE_URL):
        # Raised explicitly instead of via `assert` so the check is not stripped
        # under `python -O`; the exception type stays the same for callers.
        err_str = "pycls only automatically caches URLs in the pycls S3 bucket: {}"
        raise AssertionError(err_str.format(_PYCLS_BASE_URL))
    # Swap only the leading base URL for cache_dir; str.replace would also
    # rewrite any later occurrence of the base URL inside the path.
    cache_file_path = cache_dir + url[len(_PYCLS_BASE_URL):]
    if os.path.exists(cache_file_path):
        return cache_file_path
    cache_file_dir = os.path.dirname(cache_file_path)
    if cache_file_dir:
        # exist_ok avoids failing when another process creates the directory
        # between an existence check and the makedirs call.
        os.makedirs(cache_file_dir, exist_ok=True)
    logger.info("Downloading remote file {} to {}".format(url, cache_file_path))
    download_url(url, cache_file_path)
    return cache_file_path
|
||
|
|
||
|
|
||
|
def _progress_bar(count, total):
|
||
|
"""Report download progress. Credit:
|
||
|
https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console/27871113
|
||
|
"""
|
||
|
bar_len = 60
|
||
|
filled_len = int(round(bar_len * count / float(total)))
|
||
|
percents = round(100.0 * count / float(total), 1)
|
||
|
bar = "=" * filled_len + "-" * (bar_len - filled_len)
|
||
|
sys.stdout.write(
|
||
|
" [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024)
|
||
|
)
|
||
|
sys.stdout.flush()
|
||
|
if count >= total:
|
||
|
sys.stdout.write("\n")
|
||
|
|
||
|
|
||
|
def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar):
    """Download url and write it to dst_file_path. Credit:
    https://stackoverflow.com/questions/2028517/python-urllib2-progress-hook

    Args:
        url: URL to fetch; it is opened with urllib.request.urlopen.
        dst_file_path: Path of the file to write (opened in "wb" mode).
        chunk_size: Maximum number of bytes read from the response per iteration.
        progress_hook: Optional callable invoked as
            progress_hook(bytes_so_far, total_size) after each chunk is read.
            It is skipped when the response carries no Content-Length header,
            because the total size is then unknown.

    Returns:
        The number of bytes downloaded.

    Raises:
        IOError: If the response announced a Content-Length and a different
            number of bytes was received.
    """
    req = urlrequest.Request(url)
    # The context manager closes the response even if reading or writing fails.
    with urlrequest.urlopen(req) as response:
        content_length = response.info().get("Content-Length")
        # The header may be absent, in which case the total size is unknown.
        total_size = int(content_length.strip()) if content_length else None
        bytes_so_far = 0
        f = open(dst_file_path, "wb")
        try:
            with f:
                while True:
                    chunk = response.read(chunk_size)
                    if not chunk:
                        break
                    bytes_so_far += len(chunk)
                    if progress_hook and total_size is not None:
                        progress_hook(bytes_so_far, total_size)
                    f.write(chunk)
            if total_size is not None and bytes_so_far != total_size:
                raise IOError(
                    "Incomplete download of {}: got {} of {} bytes".format(
                        url, bytes_so_far, total_size
                    )
                )
        except BaseException:
            # Do not leave a partial file behind: callers that treat the file's
            # existence as "already downloaded" would otherwise reuse it.
            try:
                os.remove(dst_file_path)
            except OSError:
                pass
            raise
    return bytes_so_far
|