| | import json |
| | import os |
| | import urllib |
| | from tqdm import tqdm |
| |
|
| | from vlmo.config import config, _loss_names |
| | from vlmo.modules import VLMo |
| | from vlmo.transforms import keys_to_transforms |
| |
|
def _download(url: str, root: str) -> str:
    """Download ``url`` into the directory ``root`` and return the local file path.

    The local file is named after ``os.path.basename(url)``. If a regular file
    with that name already exists in ``root`` it is returned immediately
    without contacting the network (treated as a cache hit).

    The payload is streamed to ``<target>.part`` first and only renamed to the
    final name after the transfer completes. An interrupted download therefore
    never leaves a truncated file that a later call would mistake for a
    finished, cached one.

    Args:
        url: Remote location of the file.
        root: Cache directory; created if it does not exist.

    Returns:
        Path of the downloaded (or previously cached) file.

    Raises:
        RuntimeError: If the target path exists but is not a regular file.
        Errors raised by ``urllib.request.urlopen`` (network/HTTP failures)
        propagate unchanged.
    """
    # `import urllib` alone does not guarantee that the `urllib.request`
    # submodule is loaded, so import it explicitly here.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # Content-Length is optional (e.g. chunked transfer encoding); without
            # it, show a plain byte counter instead of crashing on int(None).
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(
                total=total, ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
            ) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
        # Same directory, so the rename is atomic: the final name only ever
        # refers to a complete file.
        os.replace(partial_target, download_target)
    finally:
        # The partial file is still present only if the transfer failed or
        # was interrupted; remove it so no stale fragment is left behind.
        if os.path.exists(partial_target):
            os.remove(partial_target)

    return download_target
| |
|
| |
|
def config_setting(custom_config: dict):
    """Return the default VLMo config with ``custom_config`` entries applied on top.

    Args:
        custom_config: Mapping of config keys to values. Each entry is assigned
            onto the object returned by ``config()``, overriding an existing
            key or adding a new one.

    Returns:
        The object produced by ``config()`` with the overrides applied.
    """
    # Call the imported `config` directly. The previous `eval("config")`
    # resolved the very same name, but hid the dependency from linters and
    # static analysis for no benefit.
    cfg = config()
    for key, value in custom_config.items():
        cfg[key] = value
    return cfg
| |
|
| |
|
def _resolve_load_path(model_config: dict) -> str:
    """Pop the checkpoint-location keys from ``model_config`` and return a local checkpoint path.

    Resolution order:
      1. ``model_path``, if given and it exists on disk;
      2. ``model_url``, downloaded into ``~/.cache/m2_encoder``;
      3. a ModelScope snapshot described by the ``modelscope`` dict, joined
         with ``model_file`` (required in this case).

    ``model_config`` is mutated: the consumed keys are removed so they are not
    forwarded into the model config.

    Raises:
        ValueError: If none of the three sources is available.
        KeyError: If the ModelScope branch is used without ``model_file``.
    """
    model_url = model_config.pop("model_url", None)
    model_path = model_config.pop("model_path", None)
    if model_path and os.path.exists(model_path):
        return model_path
    if model_url:
        return _download(model_url, os.path.expanduser("~/.cache/m2_encoder"))

    modelscope_cfg = model_config.pop("modelscope", None)
    if modelscope_cfg is None:
        # Fail with a clear message instead of importing modelscope and then
        # crashing on `snapshot_download(**None)`.
        hint = f" (model_path {model_path!r} does not exist)" if model_path else ""
        raise ValueError(
            "model_config must provide an existing 'model_path', a 'model_url', "
            f"or a 'modelscope' section{hint}"
        )

    from modelscope import snapshot_download

    model_dir = snapshot_download(**modelscope_cfg)
    return os.path.join(model_dir, model_config.pop("model_file"))


def load_from_config(model_config):
    """Build a VLMo model and its text/image preprocessors from a configuration.

    Args:
        model_config: Either a path (``str`` or ``os.PathLike``) to a JSON file,
            or a dict, holding the model options. The checkpoint location is
            taken from ``model_path`` / ``model_url`` / ``modelscope`` +
            ``model_file``. All remaining keys are applied on top of the
            defaults via ``config_setting``. A caller-supplied dict is not
            modified.

    Returns:
        ``(model, [txt_processor, img_processor])``: the ``VLMo`` model, the
        tokenizer returned by ``get_pretrained_tokenizer``, and the first
        transform built from ``cfg["val_transform_keys"]``.

    Raises:
        TypeError: If ``model_config`` is neither a path nor a dict.
        ValueError: If no usable checkpoint source is configured.
    """
    if isinstance(model_config, (str, os.PathLike)):
        # JSON is UTF-8 by specification; the context manager closes the file.
        with open(model_config, "r", encoding="utf-8") as fh:
            model_config = json.load(fh)
    elif isinstance(model_config, dict):
        # Work on a shallow copy. The location keys popped below would
        # otherwise be stripped from the caller's dict, breaking a second call
        # made with the same dict.
        model_config = dict(model_config)
    else:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise TypeError(
            "model_config must be a path to a JSON file or a dict, "
            f"got {type(model_config).__name__}"
        )

    load_path = _resolve_load_path(model_config)

    cfg = config_setting(model_config)
    cfg["load_path"] = load_path

    # The patch must be applied before the model is constructed.
    if cfg["flash_attn"]:
        from vlmo.utils.patch_utils import patch_torch_scale_with_flash_attn

        patch_torch_scale_with_flash_attn()

    model = VLMo(cfg)

    from vlmo.modules.vlmo_module import get_pretrained_tokenizer

    txt_processor = get_pretrained_tokenizer(cfg["tokenizer_type"], from_pretrained=cfg["tokenizer"])
    img_processor = keys_to_transforms(cfg["val_transform_keys"], size=cfg["image_size"])[0]

    return model, [txt_processor, img_processor]
| |
|