| """ |
| Utilities for working with the local dataset cache. |
| This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp |
| Copyright by the AllenNLP authors. |
| """ |
| from __future__ import (absolute_import, division, print_function, unicode_literals) |
|
|
import sys
import json
import logging
import os
import shutil
import tempfile
import fnmatch
from functools import wraps
from hashlib import sha256
from io import open

import boto3
import requests
from botocore.exceptions import ClientError
from tqdm import tqdm

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                                   Path.home() / '.pytorch_pretrained_bert'))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
                                              os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))

CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

logger = logging.getLogger(__name__)


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


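# Illustrative use of the naming scheme above (the URL and ETag are made-up
# values; the digests are simply what sha256 yields for those strings):
#
#     url_to_filename("https://example.com/weights.bin", etag='"abc"')
#     # -> sha256(url).hexdigest() + '.' + sha256(etag).hexdigest()

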
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


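# Round-trip sketch: url_to_filename() and filename_to_url() are inverses,
# provided the cached file and its '.json' sidecar exist (names hypothetical):
#
#     filename = url_to_filename(url, etag)
#     url, etag = filename_to_url(filename)  # reads back <cache>/<filename>.json

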
def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


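# cached_path() is the usual entry point for callers. A hedged example: the
# URLs below are placeholders, and a cache miss requires network access:
#
#     local_path = cached_path("https://example.com/vocab.txt")
#     local_path = cached_path("s3://my-bucket/vocab.txt", cache_dir="/tmp/cache")
#     local_path = cached_path("/tmp/vocab.txt")  # returned as-is if it exists

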
def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


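# Example (bucket and key are hypothetical):
#
#     split_s3_path("s3://my-bucket/models/bert.bin")
#     # -> ("my-bucket", "models/bert.bin")

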
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


@s3_request
def s3_etag(url):
    """Check ETag on S3 object."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_object = s3_resource.Object(bucket_name, s3_path)
    return s3_object.e_tag


@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3."""
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file):
    """Stream a file over HTTP(S) into `temp_file`, with a progress bar."""
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


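# Sketch of direct use (normally reached via get_from_cache; the URL is a
# placeholder):
#
#     with tempfile.NamedTemporaryFile() as tmp:
#         http_get("https://example.com/file.bin", tmp)
#         tmp.flush()  # bytes live at tmp.name until the context exits

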
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        try:
            response = requests.head(url, allow_redirects=True)
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except EnvironmentError:
            etag = None

    if sys.version_info[0] == 2 and etag is not None:
        etag = etag.decode('utf-8')
    filename = url_to_filename(url, etag)

    # Get cache path to put the file.
    cache_path = os.path.join(cache_dir, filename)

    # If we don't have a connection (etag is None) and can't identify the file,
    # try to get the last downloaded one.
    if not os.path.exists(cache_path) and etag is None:
        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + '.*')
        matching_files = list(filter(lambda s: not s.endswith('.json'), matching_files))
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if not os.path.exists(cache_path):
        # Download to a temporary file, then copy to the cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # We are copying the file before closing it, so flush to avoid truncation.
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start.
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w') as meta_file:
                output_string = json.dumps(meta)
                if sys.version_info[0] == 2 and isinstance(output_string, str):
                    # In Python 2, json.dumps() returns a byte str, but the
                    # io.open() handle above expects unicode.
                    output_string = unicode(output_string, 'utf-8')
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


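# Example: for a hypothetical 'stopwords.txt' containing one word per line,
# read_set_from_file("stopwords.txt") returns something like {'a', 'an', 'the'}.

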
def get_file_extension(path, dot=True, lower=True):
    """Return the extension of `path`, optionally without the dot and lowercased."""
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
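

# A minimal smoke test for the helpers above, run only when this module is
# executed directly. The URL is a placeholder, not a real hosted file; the
# cached_path() call is left commented out because it needs network access.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    example_url = "https://example.com/some-file.bin"  # hypothetical URL
    print(url_to_filename(example_url, etag='"0123abcd"'))
    print(get_file_extension("archive/model.TAR.GZ"))  # -> '.gz'
    # print(cached_path(example_url))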