| | import os |
| |
|
| | import click |
| | from huggingface_hub import HfApi |
| | from loguru import logger |
| |
|
| | from src import config |
| | from src import data |
| | from src import loss |
| | from src import models |
| | from src import tokenizer as tk |
| | from src import vision_model |
| | from src import utils |
| | from src.lightning_module import LightningModule |
| |
|
| |
|
def _to_debug_repo_id(repo_id: str) -> str:
    """Return the debug variant of a Hugging Face Hub repo id.

    The ``debug-`` prefix is applied to the repository name while the
    namespace is kept: ``"user/model"`` -> ``"user/debug-model"``. A repo id
    without a namespace (``"model"``) becomes ``"debug-model"`` instead of
    raising ``IndexError``.
    """
    namespace, sep, name = repo_id.partition("/")
    if not sep:
        # No "/" at all: the whole id is the repository name.
        return f"debug-{namespace}"
    return f"{namespace}/debug-{name}"


def _upload_model_to_hub(
    vision_encoder: models.TinyCLIPVisionEncoder,
    text_encoder: models.TinyCLIPTextEncoder,
    debug: bool = False,
):
    """Save both encoders locally, then upload ``config.MODEL_PATH`` to the Hub.

    Args:
        vision_encoder: Trained vision encoder, saved to ``config.VISION_MODEL_PATH``.
        text_encoder: Trained text encoder, saved to ``config.TEXT_MODEL_PATH``.
        debug: If True, push to a ``debug-``-prefixed variant of
            ``config.REPO_ID`` instead of the real repository.
    """
    # NOTE(review): the upload below sends config.MODEL_PATH, so the two save
    # paths are presumably inside it — confirm in src.config.
    vision_encoder.save_pretrained(
        str(config.VISION_MODEL_PATH),
        safe_serialization=True,
    )
    text_encoder.save_pretrained(
        str(config.TEXT_MODEL_PATH),
        safe_serialization=True,
    )

    repo_id = _to_debug_repo_id(config.REPO_ID) if debug else config.REPO_ID
    common_hf_api_params = {
        "repo_id": repo_id,
        "repo_type": "model",
    }

    api = HfApi()
    if not api.repo_exists(**common_hf_api_params):
        logger.info(f"Creating repo {repo_id} on Hugging Face Hub.")
        # exist_ok=True: the repo may be created between the check above and
        # this call (another process/run); that must not abort the upload.
        api.create_repo(**common_hf_api_params, exist_ok=True)
    logger.info(f"Uploading models in {config.MODEL_PATH} to {repo_id}.")
    api.upload_folder(
        folder_path=config.MODEL_PATH,
        **common_hf_api_params,
    )
| |
|
| |
|
@click.group()
def cli():
    # Root click group: it does no work itself and only dispatches to the
    # subcommands registered on it (see the `add_command` call at module level).
    # Kept comment-only because click would show a docstring as `--help` text.
    ...
| |
|
| |
|
@click.command()
@click.option(
    "--trainer-config-json",
    required=False,
    default="{}",
    type=str,
    help="JSON string parsed into config.TrainerConfig.",
)
def train(trainer_config_json: str):
    """Train the TinyCLIP vision and text encoders, then upload them to the Hugging Face Hub."""
    # Fail fast: the upload only happens after training. huggingface_hub
    # ignores an empty/blank HF_TOKEN, so a mere presence check would let
    # the run finish and then fail to authenticate.
    if not os.environ.get("HF_TOKEN", "").strip():
        raise ValueError("Please set the HF_TOKEN environment variable.")

    trainer_config = config.TrainerConfig.model_validate_json(trainer_config_json)
    model_config = trainer_config._model_config

    transform = vision_model.get_vision_transform(model_config.vision_config)
    tokenizer = tk.Tokenizer(model_config.text_config)
    train_dl, valid_dl = data.get_dataset(
        transform=transform, tokenizer=tokenizer, hyper_parameters=trainer_config
    )

    vision_encoder = models.TinyCLIPVisionEncoder(config=model_config.vision_config)
    text_encoder = models.TinyCLIPTextEncoder(config=model_config.text_config)

    lightning_module = LightningModule(
        vision_encoder=vision_encoder,
        text_encoder=text_encoder,
        loss_fn=loss.get_loss(model_config.loss_type),
        hyper_parameters=trainer_config,
        len_train_dl=len(train_dl),
    )

    trainer = utils.get_trainer(trainer_config)
    trainer.fit(lightning_module, train_dl, valid_dl)

    # Under multi-process strategies (e.g. Lightning DDP) every rank reaches
    # this point; only global rank zero should write files and push to the
    # Hub. Assumes utils.get_trainer returns a Lightning Trainer — trainers
    # without `is_global_zero` are treated as single-process.
    if getattr(trainer, "is_global_zero", True):
        _upload_model_to_hub(vision_encoder, text_encoder, debug=trainer_config.debug)
| |
|
| |
|
# Attach `train` to the root group under the name click derived for it.
cli.add_command(train, name=train.name)

if __name__ == "__main__":
    # Script entry point: click parses sys.argv and dispatches the subcommand.
    cli()
| |
|