| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| from typing import Dict, List |
| from munch import DefaultMunch |
| from omegaconf import DictConfig |
| from pathlib import Path |
| import re |
| import numpy as np |
| import requests |
| from hydra.core.hydra_config import HydraConfig |
|
|
|
|
# Config value -> ASPECT_RATIO_* identifier for the `aspect_ratio` option.
# NOTE(review): the identifiers look like names consumed by generated
# deployment code; the consumer is not visible in this file.
aspect_ratio_dict = dict(
    fit="ASPECT_RATIO_FIT",
    crop="ASPECT_RATIO_CROP",
    padding="ASPECT_RATIO_PADDING",
    full_screen="ASPECT_RATIO_FULLSCREEN",
)

# Config value -> COLOR_* identifier for the color mode.
# NOTE(review): "n6" presumably refers to the STM32N6 target -- confirm.
color_mode_n6_dict = dict(
    rgb="COLOR_RGB",
    bgr="COLOR_BGR",
)
| |
|
|
|
|
def download_file(url: str, local_path: str):
    """
    Downloads a file from the given URL and saves it to the specified local path.

    The download is best-effort: HTTP and other errors are printed rather than
    raised, so callers must not assume `local_path` exists afterwards. If the
    transfer fails after `local_path` was opened for writing, the partially
    written file is removed so that it cannot be mistaken for a valid download.

    args:
        url (str): URL of the file to download
        local_path (str): Local path where the file should be saved
    """
    # True only while local_path holds an incomplete download.
    partial = False
    try:
        # The response is used as a context manager so the streamed connection
        # is released even when writing the file fails.
        with requests.get(url, stream=True, timeout=20) as response:
            response.raise_for_status()
            with open(local_path, 'wb') as file:
                partial = True
                for chunk in response.iter_content(chunk_size=8192):
                    # iter_content can yield empty keep-alive chunks; skip them.
                    if chunk:
                        file.write(chunk)
            partial = False

        print(f"[INFO] : Pretrained model file downloaded successfully and saved as :\n\t{local_path}")

    except requests.exceptions.HTTPError as http_err:
        print(f"[ERROR] : HTTP error occurred :\n\t{http_err}")
    except Exception as err:
        print(f"[ERROR] : An error occurred while downloading the pretrained model \n\t: {err}")

    if partial:
        # Best-effort cleanup of a truncated file; failure to delete is not fatal.
        try:
            os.remove(local_path)
        except OSError:
            pass
|
|
|
|
def check_attributes(cfg: Dict,
                     expected: List[str] = None,
                     optional: List[str] = None,
                     section: str = None) -> None:
    '''
    Checks that all the expected attributes are present in the configuration dictionary and that
    there are no unknown attributes. Optional attributes may also be present.
    args:
        cfg (dict): The configuration dictionary (a DefaultMunch), or None when the
            section is absent.
        expected (list, optional): A list of expected (required) attributes.
            None (the default) is treated as an empty list.
        optional (list, optional): A list of optional attributes.
            None (the default) is treated as an empty list.
        section (str, optional): The name of the config file section to check. Defaults to None.
    raises:
        ValueError: If cfg is not a DefaultMunch, if an unknown or unsupported attribute is found,
            or if a required attribute is missing or has no value.
    '''
    # None defaults avoid a shared mutable default and let callers omit either list
    # (previously `expected=None` crashed on `in` / iteration).
    expected = [] if expected is None else expected
    optional = [] if optional is None else optional

    if section is not None:
        message = "\nPlease check the '{}' section of your configuration file.".format(section)
    else:
        message = "\nPlease check your configuration file."

    if cfg is not None:
        if not isinstance(cfg, DefaultMunch):
            raise ValueError("Expecting an attribute. Received {}{}".format(cfg, message))

        for attr in cfg.keys():
            if attr not in expected and attr not in optional:
                raise ValueError("\nUnknown or unsupported attribute. Received `{}`{}".format(attr, message))

        used = list(cfg.keys())
    else:
        used = []

    for attr in expected:
        if attr not in used:
            raise ValueError("\nMissing `{}` attribute{}".format(attr, message))
        if cfg[attr] is None:
            raise ValueError("\nMissing a value for attribute `{}`{}".format(attr, message))
| |
|
|
def collect_callback_args(name,
                          args=None,
                          message=None) -> str:
    '''
    Builds the source text of a callback's keyword-argument list, e.g.
    `(monitor=r"val_loss", patience=5)`.
    String values are emitted as raw string literals r"..." unless they start
    with "lambda ", in which case they are inserted verbatim as code. All other
    values are inserted using their str() form.
    NOTE(review): a string value containing a double quote or ending with a
    backslash produces invalid source text with the r"..." form.
    args:
        name (str): callback name, used in the error message.
        args (dict, optional): mapping of argument names to values. Empty or None gives "()".
        message (str, optional): text appended to the error message. None means no suffix.
    returns:
        str: the parenthesized argument list.
    raises:
        ValueError: if `args` is truthy but not a mapping.
    '''
    if message is None:
        # Avoid rendering the literal text "None" in the error message.
        message = ""
    if not args:
        return "()"
    # DefaultMunch is a dict subclass, so config mappings are accepted as before.
    if not isinstance(args, dict):
        raise ValueError(f"\nInvalid syntax for `{name}` callback arguments{message}")

    parts = []
    for key, value in args.items():
        if isinstance(value, str) and not value.startswith("lambda "):
            parts.append(f'{key}=r"{value}"')
        else:
            parts.append(f'{key}={value}')
    return "(" + ", ".join(parts) + ")"
|
|
|
|
def get_random_seed(cfg: DictConfig = None):
    '''
    Returns a random seed based on the configuration file.
    args:
        cfg (DictConfig): The configuration object; its `general` section is read.
    returns:
        int or None: the `general.global_seed` value converted to int; None when the
        seed is the string "None" or a null value; 0 when `global_seed` is not set.
    '''
    if "global_seed" not in cfg.general:
        return 0
    seed = cfg.general["global_seed"]
    # Both a null value and the literal string "None" mean "no fixed seed";
    # int(None) would otherwise raise a TypeError.
    if seed is None or seed == "None":
        return None
    return int(seed)
|
|
|
|
def replace_none_string(dico: dict) -> dict:
    '''
    Replaces, in place, every value equal to the string "None" with Python's None.
    All other values (including "none" in another case) are left untouched.
    args:
        dico (dict): any dictionary; it is modified in place.
    returns:
        dict: the same dictionary object that was passed in.
    '''
    keys_to_clear = [key for key, value in dico.items() if value == "None"]
    for key in keys_to_clear:
        dico[key] = None
    return dico
|
|
|
|
def expand_env_vars(string: str) -> str:
    '''
    Expands environment variables in a string if any. The syntax for variables
    is ${variable_name}. An error is thrown if a variable is found in the string
    but is not set.
    Substituted values are inserted literally: backslashes (e.g. Windows paths)
    are not interpreted, and values are not re-scanned for further variables.
    args:
        string (str): The string to expand environment variables in.
    returns:
        string (str): The original string with expanded variables.
    raises:
        OSError: if a referenced environment variable is not set.
    '''
    def _lookup(match):
        # Returning the value from a function makes re.sub insert it verbatim;
        # passing it as a template string raised "bad escape" on backslashes.
        var_name = match.group(1)
        var_value = os.environ.get(var_name)
        if var_value is None:
            raise OSError("\nCould not find an environment variable named `{}`\n"
                          "Please check your configuration file.".format(var_name))
        return var_value

    return re.sub(r'\$\{(\w+)\}', _lookup, string)
|
|
|
|
def postprocess_config_dict(config: DictConfig, replace_none_string=False) -> None:
    '''
    The YAML loader outputs some attribute values as strings although they
    are different Python types. This function walks the config dictionary
    tree (nested dicts, including dict subclasses; lists are not traversed)
    and restores some of these types, including booleans, integers,
    floats and tuples.
    For example:
        - "True" is converted to boolean True
        - "1./255" is converted to a float (equal to 0.00392156)
        - "(128, 128, 3)" is converted to a tuple.
    The function also replaces environment variables that appear in strings
    with their values. Strings starting with "lambda " are kept as strings, with
    newlines removed and runs of spaces collapsed.
    The dictionary is modified in place.
    args:
        config (DictConfig): dictionary containing the entire configuration file.
        replace_none_string (bool): if True, string values equal to "none"
            (case-insensitive) are replaced with None.
    '''
    for k in config.keys():
        v = config[k]
        if isinstance(v, dict):
            postprocess_config_dict(v, replace_none_string=replace_none_string)
        elif isinstance(v, str):
            if replace_none_string and v.lower() == "none":
                config[k] = None
                continue

            v_exp = expand_env_vars(v)
            if v_exp != v:
                config[k] = v_exp
                v = v_exp
            if v[:7] == "lambda ":
                # Keep lambdas as source text, normalized to a single line.
                v = re.sub("\n", "", v)
                config[k] = re.sub(" +", " ", v)
            else:
                # SECURITY NOTE: eval() executes arbitrary code contained in the
                # configuration file; only use with trusted configuration files.
                try:
                    v_eval = eval(v)
                except Exception:
                    # Not a Python expression: keep the original string.
                    v_eval = v
                if isinstance(v_eval, (bool, int, float, tuple)):
                    config[k] = v_eval
|
|
|
|
def check_config_attributes(cfg: DictConfig,
                            specs: Dict = None,
                            section: str = None) -> None:
    '''
    This function checks that the attributes used in a given section
    of the configuration file comply with specified requirements.
    args:
        cfg (DictConfig): dictionary containing the configuration file section to check
        specs (Dict): dictionary specifying the requirements for attribute usage in the section.
            Recognized keys, all optional:
              - "legal": attributes allowed in the section (any other attribute is an error)
              - "all": attributes that must be present with a non-None value
              - "one_or_more": at least one of these must be present with a non-None value
            None (the default) means there is nothing to check.
        section (str): name of the section; "top_level" selects top-level wording in messages.
    raises:
        ValueError: if a requirement is violated.
    '''
    # Plain .get() lookups return None for absent keys, like attribute access on a
    # DefaultMunch did, and also work when specs is None.
    specs = dict(specs) if specs else {}
    legal = specs.get("legal")
    required = specs.get("all")
    one_or_more = specs.get("one_or_more")

    if section == "top_level":
        message = "\nPlease check the top-level of your configuration file."
    else:
        message = f"\nPlease check the '{section}' section of your configuration file."

    if legal:
        for attr in cfg.keys():
            if attr not in legal:
                raise ValueError(f"\nUnknown attribute `{attr}`{message}")

    if required:
        for attr in required:
            if attr not in cfg:
                if section == "top_level":
                    raise ValueError(f"\nMissing `{attr}` section{message}")
                raise ValueError(f"\nMissing `{attr}` attribute{message}")
            if cfg[attr] is None:
                if section == "top_level":
                    raise ValueError(f"\nMissing body of `{attr}` section{message}")
                raise ValueError(f"\nExpecting a value for `{attr}` attribute{message}")

    if one_or_more:
        if not any(attr in cfg and cfg[attr] is not None for attr in one_or_more):
            raise ValueError(f"\nMissing one or more attributes from {one_or_more}{message}")
|
|
|
|
def check_model_file_extension(ml_path, mode, mode_groups, field_name):
    """
    Validates a model file path against the current operation mode.

    The path must be non-empty. Its (lower-cased) extension must be one of those
    accepted for `mode`, and it must point to an existing file. Modes that belong
    to none of the groups below skip the extension check but are still checked
    for existence.

    Args:
        ml_path (str): Path to the model file to check.
        mode (str): The current operation mode (e.g., 'training', 'quantization', etc.).
        mode_groups (Any): Object whose `training` and `quantization` attributes are
            collections of mode names.
        field_name (str): The name of the config field being checked (for error messages).

    Raises:
        ValueError: If the path is empty or its extension is not allowed for the current mode.
        FileNotFoundError: If the file does not exist at the given path.
    """
    head = f"\nExpecting `{field_name}` attribute to be set to a path to "
    tail = "\nPlease check the 'model' section of your configuration file."
    if not ml_path:
        raise ValueError(head + "a valid file path" + tail)

    extension = Path(ml_path).suffix.lower()
    keras_formats = [".h5", ".keras"]
    # First matching group wins, so a mode in several groups gets the
    # training restriction before the quantization one.
    if mode in mode_groups.training:
        accepted = keras_formats
    elif mode in mode_groups.quantization:
        accepted = keras_formats + [".onnx"]
    elif mode in ("evaluation", "prediction", "benchmarking", "deployment"):
        accepted = keras_formats + [".tflite", ".onnx"]
    else:
        accepted = None

    if accepted is not None and extension not in accepted:
        raise ValueError(head + ", ".join(accepted) + tail)

    if not os.path.isfile(ml_path):
        raise FileNotFoundError(
            f"\nUnable to find file {ml_path}\nPlease check the '{field_name}' attribute in your configuration file"
        )
| |
|
|
def _stedgeai_version_from_path(path_to_stedgeai):
    # Returns the path component 4 levels up from the executable (stored as the tool "version"), or None when unavailable.
    if not path_to_stedgeai:
        return None
    parts = Path(path_to_stedgeai).parts
    return parts[-4] if len(parts) >= 4 else None


def parse_tools_section(cfg: DictConfig,
                        operation_mode: str,
                        hardware_type: str = "MCU") -> None:
    '''
    parses the 'tools' section of the configuration file
    args:
        cfg (DictConfig): 'tools' section of configuration file
        operation_mode (str): service or operation mode used
        hardware_type (str): type of hardware targetted
    raises:
        ValueError: if attributes are invalid, if the local `stedgeai` executable cannot be
            found (when not running on cloud), or if `path_to_cubeIDE` is not an existing
            file in MCU deployment mode.
    '''
    required = []
    if hardware_type == "MCU" and operation_mode not in ("evaluation", "prediction"):
        required += ["path_to_cubeIDE"]

    if cfg.stedgeai:
        legal = ["stedgeai", "path_to_cubeIDE"]
        check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="tools")

        legal = ["optimization", "on_cloud", "path_to_stedgeai"]
        check_config_attributes(cfg.stedgeai,
                                specs={"legal": legal, "all": []}, section="tools.stedgeai")

        stedgeai_path = cfg.stedgeai.path_to_stedgeai
        if not cfg.stedgeai.on_cloud:
            # os.path.isfile(None) raises TypeError, so test for a missing path explicitly.
            if not stedgeai_path or not os.path.isfile(stedgeai_path):
                print(stedgeai_path)
                raise ValueError("Path for `stedgeai.exe` does not exist.\n"
                                 "Please check the cfg.tools.stedgeai section!")

        # With on_cloud and no local path, Path(None) used to raise TypeError here.
        version = _stedgeai_version_from_path(stedgeai_path)

        # NOTE: `stm32ai` is the same object as `stedgeai` (an alias, not a copy),
        # so the keys set below are visible under both names.
        cfg["stm32ai"] = cfg.stedgeai
        cfg.stm32ai["optimization"] = cfg.stedgeai.optimization if cfg.stedgeai.optimization else "balanced"
        cfg.stm32ai["on_cloud"] = cfg.stedgeai.on_cloud
        cfg.stm32ai["path_to_stm32ai"] = stedgeai_path if stedgeai_path else None
        cfg.stm32ai["version"] = version
        cfg.stedgeai["version"] = version

    if hardware_type == "MCU":
        if operation_mode == "deployment" and (not cfg.path_to_cubeIDE
                                               or not os.path.isfile(cfg.path_to_cubeIDE)):
            raise ValueError("Path for `path_to_cubeIDE` does not exist.\n"
                             "Please check the cfg.tools section!")
|
|
|
|
def parse_benchmarking_section(cfg: DictConfig) -> None:
    '''
    Validates the 'benchmarking' section: `board` is the only allowed
    attribute, and it must be present with a value.
    args:
        cfg (DictConfig): 'benchmarking' section of the configuration file
    '''
    board_only = ["board"]
    check_config_attributes(cfg,
                            specs=dict(legal=board_only, all=board_only),
                            section="benchmarking")
|
|
|
|
def parse_quantization_section(cfg: DictConfig,
                               legal: List) -> None:
    '''
    parses the 'quantization' section of the configuration file: checks attribute
    usage, fills in defaults and validates attribute values.
    args:
        cfg (DictConfig): 'quantization' section of the configuration file
        legal (List): UC specific usable attributes
    raises:
        ValueError: if an attribute is unknown or missing, or if quantizer, granularity,
            optimize, target_opset or quantization_type has an unsupported value.
    '''
    # Every legal attribute is required except these, which are optional
    # (several of them receive defaults below).
    optional = ["export_dir", "granularity", "optimize", "target_opset", "operating_mode",
                "onnx_quant_parameters", "op_types_to_quantize", "onnx_extra_options",
                "iterative_quant_parameters"]
    required = [x for x in legal if x not in optional]
    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="quantization")

    # Default values
    if not cfg.export_dir:
        cfg.export_dir = "quantized_models"
    if not cfg.granularity:
        cfg.granularity = "per_channel"
    cfg.optimize = cfg.optimize if cfg.optimize is not None else False
    cfg.target_opset = cfg.target_opset if cfg.target_opset is not None else 17
    cfg.operating_mode = cfg.operating_mode if cfg.operating_mode else 'default'

    if cfg.quantizer.lower() not in ["tflite_converter", "onnx_quantizer"]:
        raise ValueError(f"\nUnknown or unsupported quantizer. Received `{cfg.quantizer}`\n"
                         "Supported quantizers are : TFlite_converter or Onnx_quantizer\n"
                         "Please check the 'quantization.quantizer' attribute in your configuration file.")

    if cfg.granularity not in ['per_channel', 'per_tensor']:
        raise ValueError(f"\nUnknown or unsupported granularity value. Received `{cfg.granularity}`\n"
                         "Supported granularity: 'per_channel', or 'per_tensor'\n"
                         "Please check the 'quantization.granularity' attribute in your configuration file.")

    if cfg.optimize not in [True, False]:
        raise ValueError(f"\nUnknown or unsupported optimize value. Received `{cfg.optimize}`\n"
                         "Supported optimize parameters: 'True', or 'False'\n"
                         "Please check the 'quantization.optimize' attribute in your configuration file.")

    if not isinstance(cfg.target_opset, int):
        # Report the offending target_opset value (the message used to show cfg.optimize).
        raise ValueError(f"\nUnknown or unsupported target_opset value. Received `{cfg.target_opset}`\n"
                         "Supported target_opset parameters: 'int' up to latest onnx_opset\n"
                         "Please check the 'quantization.target_opset' attribute in your configuration file.")

    if cfg.quantization_type.lower() not in ["ptq"]:
        raise ValueError(f"\nUnknown or unsupported quantization type. Received `{cfg.quantization_type}`\n"
                         "Supported type: PTQ\n"
                         "Please check the 'quantization.quantization_type' attribute in your configuration file.")
|
|
|
|
def parse_evaluation_section(cfg: DictConfig,
                             legal: List) -> None:
    '''
    Validates the 'evaluation' section (no attribute is required) and fills in
    default values for attributes left unset.
    args:
        cfg (DictConfig): 'evaluation' section of the configuration file
        legal (List): UC specific usable attributes
    '''
    check_config_attributes(cfg, specs={"legal": legal, "all": []}, section="evaluation")

    # Any falsy value (missing, None, empty, False) is replaced by its default,
    # in this order.
    defaults = {
        "gen_npy_input": False,
        "gen_npy_output": False,
        "profile": "profile_O3",
        "input_type": "uint8",
        "output_type": "int8",
        "input_chpos": "chlast",
        "output_chpos": "chlast",
        "target": "host",
    }
    for attr_name, default_value in defaults.items():
        if not getattr(cfg, attr_name):
            setattr(cfg, attr_name, default_value)
|
|
|
|
def parse_top_level(cfg: DictConfig,
                    mode_groups: DictConfig = None,
                    mode_choices: List = None,
                    legal: List = None) -> None:
    '''
    Validates the top level of the configuration file: `operation_mode` must be
    set to one of `mode_choices`, and the sections needed by that mode must be
    present (only `legal` sections are allowed).
    args:
        cfg (DictConfig): configuration dictionary containing entire configuration file
        mode_groups (DictConfig): configuration about a given mode (its `training`,
            `quantization`, `benchmarking`, `deployment` and `compression` attributes
            list the modes of each group)
        mode_choices (List): currently supported modes
        legal (List): UC specific usable attributes
    raises:
        ValueError: if operation_mode is missing/unknown or a required section is missing.
    '''
    top_msg = "\nPlease check the top-level of your configuration file."
    supported = f"Supported modes: {mode_choices}{top_msg}"

    if "operation_mode" not in cfg:
        raise ValueError("\nMissing `operation_mode` attribute\n" + supported)
    mode = cfg.operation_mode
    if mode is None:
        raise ValueError("\nExpecting a value for `operation_mode` attribute\n" + supported)
    if mode not in mode_choices:
        raise ValueError(f"\nUnknown value for `operation_mode` attribute. Received {mode}\n" + supported)

    needed = ["mlflow"]
    if mode != "benchmarking":
        needed.append("preprocessing")
        if "feature_extraction" in legal:
            needed.append("feature_extraction")
    if mode not in ("prediction", "quantization", "benchmarking", "deployment", "chain_qb", "chain_qd"):
        needed.append("dataset")

    # Sections contributed by each mode group, in the order they are checked.
    group_sections = (
        ("training", ("training",)),
        ("quantization", ("quantization",)),
        ("benchmarking", ("benchmarking", "tools")),
        ("deployment", ("deployment", "tools")),
        ("compression", ("compression", "training")),
    )
    for group_name, sections in group_sections:
        if mode in getattr(mode_groups, group_name):
            needed.extend(sections)

    check_config_attributes(cfg, specs={"legal": legal, "all": needed}, section="top_level")
|
|
|
|
def parse_general_section(cfg: DictConfig,
                          mode: str = None,
                          mode_groups: str = None,
                          legal: List = None,
                          required: List = None,
                          output_dir: str = '') -> None:
    '''
    parses the general section of configuration file: checks attribute usage and
    fills in default values.
    args:
        cfg (DictConfig): configuration dictionary
        mode (str): operation mode (not used by this function)
        mode_groups (str): operation mode group (not used by this function)
        legal (List): UC specific usable attributes
        required (List): UC specific required attributes
        output_dir (str): output directory for the current run (not used by this function)
    '''
    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="general")

    # Default values
    if not cfg.project_name:
        cfg.project_name = "<unnamed>"
    if not cfg.logs_dir:
        cfg.logs_dir = "logs"
    if not cfg.saved_models_dir:
        cfg.saved_models_dir = "saved_models"
    cfg.deterministic_ops = cfg.deterministic_ops if cfg.deterministic_ops is not None else False
    cfg.display_figures = cfg.display_figures if cfg.display_figures is not None else True
    # Only an unset/empty/"None" seed falls back to the default; an explicit seed
    # of 0 is a valid choice and must be kept (`not 0` used to replace it).
    if cfg.global_seed in (None, "", "None"):
        cfg.global_seed = 123

    if not cfg.num_threads_tflite:
        cfg.num_threads_tflite = 1
| |
|
|
|
|
def parse_random_periodic_resizing(cfg, output_stride):
    '''
    Validates the 'random_periodic_resizing' section and returns its image sizes.
    args:
        cfg: section with `period` and `image_sizes` attributes; `image_sizes` is a
            sequence of sizes, each made of two values.
        output_stride: a single network stride or a sequence of strides; both values
            of every image size must be a multiple of each stride.
    returns:
        list: the image sizes as a list of two-element int lists.
    raises:
        ValueError: if an attribute is missing, `image_sizes` is malformed, or a
            size is not a multiple of a stride.
    '''
    message = "Please check the 'random_periodic_resizing' section of your configuration file."
    if "period" not in cfg:
        raise ValueError(f"\nMissing `period` argument\n{message}")
    if "image_sizes" not in cfg:
        raise ValueError(f"\nMissing `image_sizes` argument\n{message}")

    # Rebuild `image_sizes` as Python source text and evaluate it: list/tuple
    # entries become tuples, other entries are re-parsed from their str() form.
    # SECURITY NOTE: eval() executes whatever the config values render to;
    # only use with trusted configuration files.
    entries = []
    for size in cfg.image_sizes:
        if isinstance(size, (list, tuple)):
            entries.append("(" + ",".join(str(x) for x in size) + ")")
        else:
            entries.append(str(size))
    sizes_str = "[" + ",".join(entries) + "]"

    sizes_message = "\nInvalid syntax for `image_sizes` argument\n"
    try:
        random_sizes = np.array(eval(sizes_str), dtype=np.int32)
    except Exception as err:
        raise ValueError(sizes_message + message) from err

    # A flat list such as [128, 256] gives a 1-D array; reject it explicitly
    # instead of failing with an IndexError on shape[1].
    if random_sizes.ndim != 2 or random_sizes.shape[1] != 2:
        raise ValueError(sizes_message + message)

    # `stride` (not `os`) avoids shadowing the os module inside this function.
    strides = [output_stride] if np.shape(output_stride) == () else output_stride
    for size in random_sizes:
        for stride in strides:
            if (size[0] % stride != 0) or (size[1] % stride != 0):
                raise ValueError(
                    f"Image sizes must be multiples of the network stride.\n"
                    f"Network stride: {stride}\n"
                    f"Invalid image size: {size}\n"
                    f"{message}")

    return random_sizes.tolist()
| |
|
|
def parse_compression_section(cfg: DictConfig,
                              legal: List) -> None:
    '''
    parses the 'compression' section of the configuration file: checks attribute
    usage, fills in defaults and validates attribute values.
    args:
        cfg (DictConfig): 'compression' section of the configuration file
        legal (List): UC specific usable attributes
    raises:
        ValueError: if an attribute is unknown or missing, if `factor` is not a float,
            or if `strong_optimization` is not a boolean.
    '''
    # Every legal attribute is required except `factor` and `strong_optimization`.
    required = [x for x in legal if x not in ["factor", "strong_optimization"]]
    check_config_attributes(cfg, specs={"legal": legal, "all": required}, section="compression")

    # Default values (any falsy value is replaced)
    cfg.factor = cfg.factor if cfg.factor else 0.5
    cfg.strong_optimization = cfg.strong_optimization if cfg.strong_optimization else False

    if not isinstance(cfg.factor, float):
        raise ValueError(f"\nUnknown or unsupported factor value. Received `{cfg.factor}`\n"
                         "Supported factor parameters: 'float'\n"
                         "Please check the 'compression.factor' attribute in your configuration file.")

    if cfg.strong_optimization not in [True, False]:
        # The message used to say "optimize parameters" (copied from the quantization section).
        raise ValueError(f"\nUnknown or unsupported strong_optimization value. Received `{cfg.strong_optimization}`\n"
                         "Supported strong_optimization parameters: 'True', or 'False'\n"
                         "Please check the 'compression.strong_optimization' attribute in your configuration file.")
|
|
|
|
def parse_training_section(cfg: DictConfig,
                           legal: List = None) -> None:
    '''
    Validates the 'training' section: `batch_size`, `epochs` and `optimizer` are
    required, and only `legal` attributes are allowed. An optimizer given as a bare
    string is normalized into a DefaultMunch mapping {name: None}.
    args:
        cfg (DictConfig): 'training' section of the configuration file
        legal (List): usable attributes
    '''
    mandatory = ["batch_size", "epochs", "optimizer"]
    check_config_attributes(cfg, specs={"legal": legal, "all": mandatory}, section="training")

    optimizer = cfg.optimizer
    if type(optimizer) is str:
        cfg.optimizer = DefaultMunch.fromDict({optimizer: None})
|
|
def parse_model_section(cfg: DictConfig, mode: str, mode_groups, legal: list, required: list = None) -> None:
    """
    Checks and parses the root-level 'model' section of the config file.

    Steps:
      1. At least one model source (`model_name` or `model_path`) must be set;
         `input_shape` becomes required when `model_name` is set.
      2. Attributes are validated against `legal` and the required list.
      3. If `model_path` is an http(s) URL, the file is downloaded into the Hydra
         run output directory and `model_path` is pointed at the local copy.
      4. `cfg.framework` is set to "tf" or "torch" from the model file extension
         (or from a "_pt" suffix on `model_name` when no path is given); TF model
         files are validated with check_model_file_extension.
      5. For TF models, setting both model sources is an error.

    Args:
        cfg (DictConfig): The model configuration parameters.
        mode (str): The current operation mode.
        mode_groups: The mode groups object.
        legal (list): List of legal model attributes for this use case.
        required (list, optional): List of required model attributes. If None, no required fields are enforced.

    Returns:
        None
    """
    req = [] if required is None else list(required)
    if cfg.model_name:
        req.append("input_shape")

    model_sources = ["model_name", "model_path"]
    set_sources = [name for name in model_sources if getattr(cfg, name, None)]
    if not set_sources:
        raise ValueError(
            "\nExpecting one of the following model source attributes to be set: "
            f"{', '.join(model_sources)}\nPlease check your configuration file.")

    check_config_attributes(cfg, specs={"legal": legal, "all": req}, section="model")

    if cfg.model_path and cfg.model_path[:4].lower() == "http":
        from urllib.parse import urlparse
        print('[INFO] : A URL found for model.model_path variable!')
        url = cfg.model_path
        # Name the local copy after the URL *path* only, so query strings such as
        # "?raw=true" do not end up in the file name and hide its extension.
        file_name = os.path.basename(urlparse(url).path)
        output_dir = HydraConfig.get().runtime.output_dir
        model_dir = os.path.join(output_dir, os.path.splitext(file_name)[0])
        os.makedirs(model_dir, exist_ok=True)
        local_path = os.path.join(model_dir, file_name)
        download_file(url, local_path)
        cfg.model_path = local_path

    if cfg.model_path:
        file_extension = Path(cfg.model_path).suffix.lower()
        if file_extension in [".h5", ".keras", ".tflite"]:
            cfg.framework = "tf"
            check_model_file_extension(cfg.model_path, mode, mode_groups, "model_path")
        elif file_extension in [".pt", ".pth"]:
            cfg.framework = "torch"
        elif file_extension in [".onnx"] and not cfg.framework:
            # An ONNX file without an explicit framework is handled as a TF model.
            cfg.framework = "tf"
            check_model_file_extension(cfg.model_path, mode, mode_groups, "model_path")
    else:
        # Only model_name is set: a "_pt" suffix selects the torch framework.
        if cfg.model_name.endswith('_pt'):
            cfg.framework = "torch"
        else:
            cfg.framework = "tf"

    if cfg.framework == "tf" and len(set_sources) > 1:
        raise ValueError(
            "\nThe following model source attributes are mutually exclusive and more than one is set: "
            f"{', '.join(set_sources)}\nPlease check your configuration file.")
|
|
| |
def parse_prediction_section(cfg: DictConfig) -> None:
    '''
    Validates the 'prediction' section (no attribute is required) and fills in
    default values for attributes left unset.
    args:
        cfg (DictConfig): 'prediction' section of the configuration file
    '''
    allowed = ["seed", "target", "reid_distance_metric",
               "profile", "input_type", "output_type", "input_chpos", "output_chpos"]
    check_config_attributes(cfg, specs={"legal": allowed, "all": []}, section="prediction")

    # Any falsy value (missing, None, empty) is replaced by its default, in this order.
    defaults = {
        "profile": "profile_O3",
        "input_type": "uint8",
        "output_type": "int8",
        "input_chpos": "chlast",
        "output_chpos": "chlast",
        "target": "host",
    }
    for attr_name, default_value in defaults.items():
        if not getattr(cfg, attr_name):
            setattr(cfg, attr_name, default_value)
|
|
def parse_deployment_section(cfg: DictConfig,
                             legal: List = None,
                             legal_hw: List = None) -> None:
    '''
    Validates the 'deployment' section and its 'hardware_setup' subsection.
    Every attribute in `legal` is required in the deployment section. In
    hardware_setup, every attribute in `legal_hw` is required except
    'stlink_serial_number'.
    args:
        cfg (DictConfig): 'deployment' section of the configuration file
        legal (List): usable attributes
        legal_hw (List): usable attributes for the HW setup part
    '''
    check_config_attributes(cfg, specs={"legal": legal, "all": list(legal)}, section="deployment")

    hw_required = [attr for attr in legal_hw if attr != 'stlink_serial_number']
    check_config_attributes(cfg.hardware_setup,
                            specs={"legal": legal_hw, "all": hw_required},
                            section="deployment.hardware_setup")
|
|
|
|
def parse_mlflow_section(cfg: DictConfig) -> None:
    '''
    Validates the 'mlflow' section: `uri` is the only allowed attribute, and it
    must be present with a value.
    args:
        cfg (DictConfig): 'mlflow' section of the configuration
    '''
    uri_only = ["uri"]
    check_config_attributes(cfg,
                            specs=dict(legal=uri_only, all=uri_only),
                            section="mlflow")
|
|
|
|
def check_hardware_type(cfg: DictConfig,
                        mode_groups: DictConfig = None) -> None:
    '''
    Sets `cfg.hardware_type` to "MPU" or "MCU".
    It is "MPU" when the benchmarking board (benchmarking modes) or the deployment
    hardware_setup series (deployment modes) contains "STM32MP", and "MCU" otherwise.
    args:
        cfg (DictConfig): entire configuration; updated in place
        mode_groups (DictConfig): configuration about a given mode (its `benchmarking`
            and `deployment` attributes list the modes of each group)
    '''
    cfg["hardware_type"] = "MCU"

    mode = cfg.operation_mode
    if mode in mode_groups.benchmarking:
        target_name = cfg.benchmarking.board
    elif mode in mode_groups.deployment:
        target_name = cfg.deployment.hardware_setup.serie
    else:
        return

    if target_name is not None and "STM32MP" in target_name:
        cfg.hardware_type = "MPU"
|
|
|
|
def get_class_names_from_file(cfg: DictConfig) -> "List[str] | None":
    '''
    Reads class names from the file at `cfg.classes_file_path`, one per line,
    with surrounding whitespace stripped.
    args:
        cfg (DictConfig): configuration section holding `classes_file_path`
    returns:
        list of str: the class names, or None when `classes_file_path` is not set.
    '''
    if not cfg.classes_file_path:
        # Explicit result when no file is configured; previously this path hit an
        # unbound `class_names` (or returned implicitly, depending on nesting).
        return None
    with open(cfg.classes_file_path, 'r') as file:
        class_names = [line.strip() for line in file]
    return class_names
|
|
| from omegaconf import DictConfig, OmegaConf |
|
|
def flatten_config(cfg, preserve_keys=("class_map",)):
    """
    Flattens a nested configuration into a single-level dict.

    `cfg` may be an omegaconf DictConfig (converted without resolving
    interpolations), a DefaultMunch (converted with toDict()), or any mapping.
    Nested dicts are merged into the result: their keys are hoisted to the top
    level, and when the same key occurs more than once the value seen last in
    depth-first order wins. A dict stored under a key listed in `preserve_keys` is
    kept as a whole instead of being flattened. Lists are not traversed.

    Args:
        cfg: configuration to flatten.
        preserve_keys (iterable of str): keys whose dict values are kept intact.

    Returns:
        dict: the flattened configuration.
    """
    if isinstance(cfg, DictConfig):
        tree = OmegaConf.to_container(cfg, resolve=False)
    elif "DefaultMunch" in str(type(cfg)):
        tree = cfg.toDict()
    else:
        tree = dict(cfg)

    keep_whole = set(preserve_keys)

    def _leaves(node):
        # Depth-first traversal yielding (key, value) pairs in visit order.
        for key, value in node.items():
            if isinstance(value, dict) and key not in keep_whole:
                yield from _leaves(value)
            else:
                yield key, value

    # dict() over pairs keeps the last value for repeated keys, like successive assignment.
    return dict(_leaves(tree))