File size: 12,254 Bytes
747451d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 | # /*---------------------------------------------------------------------------------------------
# * Copyright (c) 2025 STMicroelectronics.
# * All rights reserved.
# *
# * This software is licensed under terms that can be found in the LICENSE file in
# * the root directory of this software component.
# * If no LICENSE file comes with this software, it is provided AS-IS.
# *--------------------------------------------------------------------------------------------*/
import collections
import fnmatch
import texttable
import os
# for common registries
from common.registries.dataset_registry import DATASET_WRAPPER_REGISTRY
from common.registries.model_registry import MODEL_WRAPPER_REGISTRY
from common.registries.trainer_registry import TRAINER_WRAPPER_REGISTRY
from common.registries.quantizer_registry import QUANTIZER_WRAPPER_REGISTRY
from common.registries.evaluator_registry import EVALUATOR_WRAPPER_REGISTRY
from common.registries.predictor_registry import PREDICTOR_WRAPPER_REGISTRY
from common.utils import LOGGER
# for tensorflow based image classification
import image_classification.tf.wrappers.datasets # this is done so that registeration happens before get_dataloader is called
import image_classification.tf.wrappers.models # this is done so that registeration happens before get_model is called
import image_classification.tf.wrappers.training # this is done so that registeration happens before get_trainer is called
import image_classification.tf.wrappers.quantization # this is done so that registeration happens before get_quantizer is called
import image_classification.tf.wrappers.evaluation # this is done so that registeration happens before get_evaluator is called
import image_classification.tf.wrappers.prediction # this is done so that registeration happens before get_predictor is called
# for pytorch based image classification
import image_classification.pt.wrappers.datasets # this is done so that registeration happens before get_dataloader is called
import image_classification.pt.wrappers.models # this is done so that registeration happens before get_model is called
import image_classification.pt.wrappers.training # this is done so that registeration happens before get_trainer is called
import image_classification.pt.wrappers.quantization # this is done so that registeration happens before get_quantizer is called
import image_classification.pt.wrappers.evaluation # this is done so that registeration happens before get_evaluator is called
import image_classification.pt.wrappers.prediction          # this is done so that registration happens before get_predictor is called
# for tensorflow based object detection
import object_detection.tf.wrappers.models.standard_models
import object_detection.tf.wrappers.models.custom_models
import object_detection.tf.wrappers.datasets # this is done so that registeration happens before get_dataloader is called
import object_detection.tf.wrappers.prediction # this is done so that registeration happens before get_predictor is called
import object_detection.tf.wrappers.evaluation # this is done so that registeration happens before get_evaluator is called
import object_detection.tf.wrappers.quantization # this is done so that registeration happens before get_quantizer is called
import object_detection.tf.wrappers.training # this is done so that registeration happens before get_trainer is called
# for pytorch based object detection
import object_detection.pt.wrappers.models
import object_detection.pt.wrappers.datasets
import object_detection.pt.wrappers.training
import object_detection.pt.wrappers.evaluation
from common.model_utils.tf_model_loader import load_model_from_path
from pathlib import Path
import torch
import tensorflow as tf
import onnxruntime
__all__ = ['get_dataloaders',
"get_model",
"get_trainer",
"get_quantizer",
"get_evaluator",
"get_predictor",
"list_models",
"list_models_by_dataset",]
def get_dataloaders(cfg):
    """
    Build the data-split dict via the dataset wrapper registered for the
    configured (framework, dataset_name, use_case) triple.

    Returns a dict of the form::

        {
            'train': train_data_loader,
            'valid': valid_data_loader,
            'quantization': quantization_data_loader,
            'test': test_data_loader,
            'predict': predict_data_loader,
        }

    When no dataset is configured (name is "<unnamed>", e.g. for
    benchmarking/deployment-only runs) every split is None.
    """
    if cfg.dataset.dataset_name == "<unnamed>":
        # Nothing to load: return the same split keys, all empty.
        return {split: None for split in ('train', 'valid', 'quantization', 'test', 'predict')}
    builder = DATASET_WRAPPER_REGISTRY.get(framework=cfg.model.framework,
                                           dataset_name=cfg.dataset.dataset_name,
                                           use_case=cfg.use_case)
    # The registered wrapper takes the full config object.
    return builder(cfg)
def get_model(cfg):
    """
    Build (or load) the model described by the configuration.

    If ``cfg.model.model_path`` points at a serialized model file
    (.keras/.h5/.tflite/.onnx) it is loaded directly via
    ``load_model_from_path``. Otherwise (no path, or e.g. a pt checkpoint)
    the builder registered for (model_name, use_case, framework) is invoked;
    for TF models the freshly built model is additionally saved under
    ``<output_dir>/<saved_models_dir>/<model_name>/<model_name>.keras`` and
    both ``model.model_path`` and ``cfg.model.model_path`` are updated to
    that location.

    :param cfg: full configuration object
    :return: the model instance
    """
    allowed_exts = (".keras", ".h5", ".tflite", ".onnx")
    model_path = getattr(cfg.model, "model_path", None)
    if model_path:
        _, ext = os.path.splitext(model_path)
        if ext.lower() in allowed_exts:
            LOGGER.info(f"Loading model from {model_path}")
            return load_model_from_path(cfg, model_path)

    # Covers cases where there is no model_path provided, or pt checkpoints.
    model_func = MODEL_WRAPPER_REGISTRY.get(model_name=cfg.model.model_name.lower(),
                                            use_case=cfg.use_case,
                                            framework=cfg.model.framework)
    # The registered function expects a single cfg object.
    model = model_func(cfg)

    saved_model_dir = os.path.join(cfg.output_dir, cfg.general.saved_models_dir)
    os.makedirs(saved_model_dir, exist_ok=True)
    if cfg.model.framework == 'tf':
        saved_model = Path(saved_model_dir, cfg.model.model_name, cfg.model.model_name + ".keras")
        # saved_model_dir already exists; only the per-model subdir may be missing.
        saved_model.parent.mkdir(exist_ok=True)
        model.save(saved_model)
        setattr(model, 'model_path', saved_model)
        cfg.model.model_path = saved_model
    return model
def get_trainer(dataloaders, model, cfg):
    """
    Look up the trainer registered for the configured trainer name,
    framework and use case, and instantiate it.
    """
    wrapper = TRAINER_WRAPPER_REGISTRY.get(trainer_name=cfg.training.trainer_name,
                                           framework=cfg.model.framework,
                                           use_case=cfg.use_case)
    # Trainer wrappers receive the dataloaders, the model and the full config.
    return wrapper(dataloaders=dataloaders, model=model, cfg=cfg)
def get_quantizer(dataloaders, model, cfg):
    """
    Look up the quantizer registered for the configured quantizer name
    (lower-cased), framework and use case, and instantiate it.
    """
    wrapper = QUANTIZER_WRAPPER_REGISTRY.get(quantizer_name=cfg.quantization.quantizer.lower(),
                                             framework=cfg.model.framework,
                                             use_case=cfg.use_case)
    # Quantizer wrappers receive the dataloaders, the model and the full config.
    return wrapper(dataloaders=dataloaders, model=model, cfg=cfg)
def get_evaluator(dataloaders, model, cfg):
    """
    Return an evaluator instance from the registry, dispatching on the
    model's runtime type.

    Keras models, TFLite interpreters, ONNX Runtime sessions and torch
    modules each map to their own registered evaluator; torch models are
    further dispatched on the configured model name ("ssd" / "yolod").
    """
    if isinstance(model, tf.keras.Model):
        name = "keras_evaluator"
    elif 'Interpreter' in str(type(model)):
        # TFLite interpreters are detected by type name, not isinstance.
        name = "tflite_evaluator"
    elif isinstance(model, onnxruntime.InferenceSession):
        name = "onnx_evaluator"
    elif isinstance(model, torch.nn.Module):
        # NOTE: name-based dispatch will not work if the user provides a
        # custom model-name format.
        lowered = cfg.model.model_name.lower()
        if "ssd" in lowered:
            name = "ssd"
        elif "yolod" in lowered:
            name = "yolod"
        else:
            name = "torch_evaluator"
    else:
        raise TypeError("Unsupported model type for evaluation")
    wrapper = EVALUATOR_WRAPPER_REGISTRY.get(
        evaluator_name=name,
        framework=cfg.model.framework,
        use_case=cfg.use_case,
    )
    return wrapper(dataloaders=dataloaders, model=model, cfg=cfg)
def get_predictor(dataloaders, model, cfg):
    """
    Return a predictor instance from the registry, dispatching on the
    model's runtime type (Keras model, TFLite interpreter or ONNX Runtime
    session).
    """
    if isinstance(model, tf.keras.Model):
        name = "keras_predictor"
    elif 'Interpreter' in str(type(model)):
        # TFLite interpreters are detected by type name, not isinstance.
        name = "tflite_predictor"
    elif isinstance(model, onnxruntime.InferenceSession):
        name = "onnx_predictor"
    else:
        raise TypeError("Unsupported model type for predictor")
    wrapper = PREDICTOR_WRAPPER_REGISTRY.get(
        predictor_name=name,
        framework=cfg.model.framework,
        use_case=cfg.use_case
    )
    # Predictor wrappers receive the dataloaders, the model and the full config.
    return wrapper(dataloaders=dataloaders, model=model, cfg=cfg)
def list_models(
    filter_string='',
    match_all=True,
    print_table=True,
    with_checkpoint=False,
):
    """
    List registered models matching the given text filter(s).

    Registry keys are flattened to "model_name_use_case_framework" strings
    and each filter is applied as a substring/glob match against them.

    :param filter_string: a string, or list/tuple of strings, containing a
        model name, use case, framework, or a combined
        "model_name_use_case_framework" pattern to filter by
    :param match_all: if True a model must match every filter (intersection);
        if False, matching any single filter is enough (union)
    :param print_table: if True, log a table of per-model-name counts and
        return the matched keys; if False, just return the list
    :param with_checkpoint: if True, search only models that have a
        pretrained checkpoint registered
    :return: list of matching registry model keys
    """
    if with_checkpoint:
        all_model_keys = MODEL_WRAPPER_REGISTRY.pretrained_models.keys()
    else:
        all_model_keys = MODEL_WRAPPER_REGISTRY.registry_dict.keys()
    all_models = {
        f"{mk.model_name}_{mk.use_case}_{mk.framework}": mk
        for mk in all_model_keys
    }
    include_filters = (
        filter_string if isinstance(filter_string, (tuple, list)) else [filter_string]
    )
    # One matched set per filter, kept even when empty so match_all can
    # correctly produce an empty intersection.
    matched_sets = [
        set(fnmatch.filter(all_models.keys(), f'*{keyword}*'))
        for keyword in include_filters
    ]
    if not matched_sets:
        # Empty filter list: set.intersection()/union() with no args would
        # raise, and nothing can match anyway.
        models = set()
    elif match_all:
        # If ANY matched set is empty, the intersection is empty.
        if any(not s for s in matched_sets):
            models = set()
        else:
            models = set.intersection(*matched_sets)
    else:
        # match-any behavior.
        models = set().union(*matched_sets)
    found_model_keys = [all_models[key] for key in sorted(models)]
    if not print_table:
        return found_model_keys
    # Build a table with counts per model name (dataset_name intentionally
    # omitted from the table).
    model_counts = collections.Counter(mk.model_name for mk in found_model_keys)
    table = texttable.Texttable()
    rows = [['Model name', 'Count']]
    for model_name, count in model_counts.items():
        rows.append([model_name, count])
    table.add_rows(rows)
    LOGGER.info(table.draw())
    return found_model_keys
def list_models_by_dataset(dataset_name, with_checkpoint=False):
    """
    Return the names of registered models whose dataset matches
    *dataset_name*, optionally restricted to checkpointed models.
    """
    candidates = list_models(dataset_name, print_table=False, with_checkpoint=with_checkpoint)
    return [mk.model_name for mk in candidates if mk.dataset_name == dataset_name]
|