File size: 1,623 Bytes
0d321f1
 
9a71624
 
4d6298c
0d321f1
0fc632b
 
 
9a71624
4d6298c
 
 
 
9a71624
 
 
0d321f1
 
 
9a71624
0d321f1
9a71624
0d321f1
9a71624
 
 
0d321f1
8d28be7
0d321f1
 
 
 
9a71624
 
 
 
0fc632b
 
 
 
 
9a71624
0fc632b
 
 
 
 
9a71624
 
 
 
 
 
 
 
 
0d321f1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import shutil
import logging
from huggingface_hub import snapshot_download
from config import Config

# Pin TensorFlow to the CPU and quieten its native logging, but only when the
# environment has not already set these variables. This executes at import
# time, before load_model() imports tensorflow, so the values take effect as
# long as tensorflow has not been imported elsewhere in the process first.
for _env_name, _env_default in (
    ("CUDA_VISIBLE_DEVICES", "-1"),
    ("TF_CPP_MIN_LOG_LEVEL", "2"),
):
    os.environ.setdefault(_env_name, _env_default)
del _env_name, _env_default

# Model location and credentials, all sourced from the project Config.
REPO_ID = Config.IMAGE_CLASSIFIER_REPO_ID  # repo passed to snapshot_download
MODEL_DIR = Config.IMAGE_CLASSIFIER_MODEL_DIR  # local copy of the repo
WEIGHTS_PATH = os.path.join(MODEL_DIR, Config.IMAGE_CLASSIFIER_WEIGHTS_FILE)
HF_TOKEN = Config.HF_TOKEN  # passed to snapshot_download as `token`

# Module-level cache for the loaded model; filled by load_model().
_model_img = None

def warmup():
    """Eagerly make the model files available and load the model into the cache.

    Always runs the download check, then stores the result of ``load_model()``
    in the module-level ``_model_img`` and logs that the model is ready.
    """
    global _model_img
    download_model_repo()
    loaded = load_model()
    _model_img = loaded
    logging.info("Image model is ready.")

def download_model_repo():
    """Ensure the image model repository is available under ``MODEL_DIR``.

    The download is skipped only when ``WEIGHTS_PATH`` already exists. The
    previous check looked only at whether ``MODEL_DIR`` existed. That meant an
    empty ``MODEL_DIR``, or one left behind by an interrupted copy, suppressed
    the download forever and made ``load_model`` fail on every start.

    Otherwise the repository is fetched with ``snapshot_download``,
    authenticated with ``HF_TOKEN``. Its contents are then copied into
    ``MODEL_DIR``, overwriting any files already present there.
    """
    # exists() rather than isfile(): it also covers weights stored as a
    # directory, and it behaves like the old directory check if the
    # configured weights filename is empty.
    if os.path.exists(WEIGHTS_PATH):
        logging.info("Image model already exists, skipping download.")
        return
    logging.info("Downloading image model from %s.", REPO_ID)
    snapshot_path = snapshot_download(repo_id=REPO_ID, token=HF_TOKEN)
    os.makedirs(MODEL_DIR, exist_ok=True)
    # copytree's default symlinks=False copies the contents of any symlinked
    # files in the snapshot, not the links themselves.
    shutil.copytree(snapshot_path, MODEL_DIR, dirs_exist_ok=True)

def load_model():
    """Load the Keras image model from ``WEIGHTS_PATH`` and cache it.

    If the module-level ``_model_img`` is already set, it is returned
    unchanged. Otherwise the following happens:

    - TensorFlow is imported lazily, so importing this module does not pull
      it in.
    - The model is deserialised under the ``/CPU:0`` device.
    - The model is stored in ``_model_img`` and returned.

    Status messages go through ``logging``, consistent with the rest of this
    module, instead of ``print``.

    Returns:
        The loaded ``tf.keras`` model.

    Errors raised by ``tf.keras.models.load_model`` (e.g. for a missing or
    unreadable path) propagate unchanged.
    """
    global _model_img
    if _model_img is not None:
        return _model_img

    # Imported here, not at module top, so the environment defaults set at
    # import time are in place before TensorFlow initialises.
    import tensorflow as tf

    class Cast(tf.keras.layers.Layer):
        """Layer that casts its inputs to float32.

        Supplied via ``custom_objects`` so a layer registered as "Cast" in
        the saved model can be deserialised.
        """

        def call(self, inputs):
            return tf.cast(inputs, tf.float32)

    logging.info("Loading image model on CPU.")
    with tf.device("/CPU:0"):
        _model_img = tf.keras.models.load_model(
            WEIGHTS_PATH, custom_objects={"Cast": Cast}
        )
    logging.info("Model input shape: %s", _model_img.input_shape)
    return _model_img

def get_model():
    """Return the cached image model, fetching and loading it on first use.

    On a cache miss, this ensures the model files are present via
    ``download_model_repo()``, then loads the model and stores it in the
    module-level ``_model_img``.
    """
    global _model_img
    if _model_img is not None:
        return _model_img
    download_model_repo()
    _model_img = load_model()
    return _model_img