|
|
| from ...optimizers.engine.registry import ParamRegistry |
| from typing import Any, Optional, List, Dict |
| |
|
|
class MiproRegistry(ParamRegistry):
    """
    Extended ParamRegistry that stores input_names / output_names (and
    per-name descriptions) for each optimizable field.

    ``track()`` keeps the base-class call shape ``(root_or_obj, path_or_attr,
    name=...)`` and adds keyword-only metadata arguments. It also accepts a
    list/tuple of entries for bulk registration.
    """

    # Maximum number of positional slots in a bulk-registration tuple:
    # (obj, attr, name, input_names, output_names, input_descs, output_descs).
    _TUPLE_FIELDS = 7

    def track(
        self,
        root_or_obj: Any,
        path_or_attr: Optional[str] = None,
        *,
        name: Optional[str] = None,
        input_names: Optional[List[str]] = None,
        output_names: Optional[List[str]] = None,
        input_descs: Optional[Dict[str, str]] = None,
        output_descs: Optional[Dict[str, str]] = None,
    ):
        """
        Register a field (or a batch of fields) and attach I/O metadata to it.

        Args:
            root_or_obj: Object passed to ``ParamRegistry.track``, or a
                list/tuple of bulk entries. Each bulk entry is either a dict of
                keyword arguments for ``track`` or a tuple
                ``(obj, attr[, name[, input_names[, output_names[, input_descs[, output_descs]]]]])``.
                Omitted trailing tuple items default to ``None``.
            path_or_attr: Forwarded to ``ParamRegistry.track``.
            name: Forwarded to ``ParamRegistry.track``; also used (falling back
                to ``path_or_attr``) as the key into ``self.fields``.
            input_names: Input names for the field. A copy is stored, and a
                bare string is treated as a single name.
            output_names: Output names for the field (same handling as
                ``input_names``).
            input_descs: Mapping of input name -> description (copy stored).
            output_descs: Mapping of output name -> description (copy stored).

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: A bulk tuple has fewer than 2 or more than 7 items.
            TypeError: A bulk entry is neither a dict nor a list/tuple.
        """
        if isinstance(root_or_obj, (list, tuple)):
            for item in root_or_obj:
                self._track_entry(item)
            return self

        super().track(root_or_obj, path_or_attr, name=name)

        # NOTE(review): assumes ParamRegistry.track registers the field under
        # `name`, falling back to `path_or_attr` — confirm against the base class.
        key = name or path_or_attr
        field = self.fields[key]
        # Store copies so later mutation of the caller's containers cannot
        # silently change the registry's metadata.
        field.input_names = self._copy_names(input_names)
        field.output_names = self._copy_names(output_names)
        field.input_descs = self._copy_descs(input_descs)
        field.output_descs = self._copy_descs(output_descs)
        return self

    def _track_entry(self, item: Any) -> None:
        """Register one bulk entry: a kwargs dict or a positional tuple/list."""
        if isinstance(item, dict):
            self.track(**item)
            return
        if not isinstance(item, (list, tuple)):
            # Previously such entries were silently skipped, leaving the field
            # unregistered without any signal to the caller.
            raise TypeError(
                f"Bulk track() entries must be dicts or tuples, got {type(item).__name__}"
            )
        if not 2 <= len(item) <= self._TUPLE_FIELDS:
            raise ValueError(
                "Each tuple must be (obj, attr[, name[, input_names[, output_names"
                "[, input_descs[, output_descs]]]]])"
            )
        # Pad omitted trailing slots with None so that a short tuple behaves as
        # if the corresponding keyword arguments were left at their defaults.
        padded = tuple(item) + (None,) * (self._TUPLE_FIELDS - len(item))
        obj, attr, entry_name, in_names, out_names, in_descs, out_descs = padded
        self.track(
            obj,
            attr,
            name=entry_name,
            input_names=in_names,
            output_names=out_names,
            input_descs=in_descs,
            output_descs=out_descs,
        )

    @staticmethod
    def _copy_names(names: Optional[List[str]]) -> List[str]:
        """Return a fresh list of names; a bare string counts as one name."""
        if not names:
            return []
        if isinstance(names, str):
            # list("abc") would split the name into characters.
            return [names]
        return list(names)

    @staticmethod
    def _copy_descs(descs: Optional[Dict[str, str]]) -> Dict[str, str]:
        """Return a fresh dict copy of a name -> description mapping."""
        return dict(descs) if descs else {}

    def get_input_names(self, name: str) -> List[str]:
        """Return the input_names for a registered field, or an empty list if not set."""
        return getattr(self.fields[name], "input_names", None) or []

    def get_output_names(self, name: str) -> List[str]:
        """Return the output_names for a registered field, or an empty list if not set."""
        return getattr(self.fields[name], "output_names", None) or []

    def get_input_desc_dict(self, name: str) -> Dict[str, str]:
        """Return the input_descs for a registered field, or an empty dict if not set."""
        # `or {}` also covers an attribute explicitly set to None, matching the
        # name getters above and keeping get_input_desc() safe.
        return getattr(self.fields[name], "input_descs", None) or {}

    def get_output_desc_dict(self, name: str) -> Dict[str, str]:
        """Return the output_descs for a registered field, or an empty dict if not set."""
        return getattr(self.fields[name], "output_descs", None) or {}

    def get_input_desc(self, name: str, input_name: str) -> str:
        """Return the description of one input of a field, or "" if not set."""
        return self.get_input_desc_dict(name).get(input_name, "")

    def get_output_desc(self, name: str, output_name: str) -> str:
        """Return the description of one output of a field, or "" if not set."""
        return self.get_output_desc_dict(name).get(output_name, "")

    def describe(self) -> Dict[str, Dict[str, Any]]:
        """
        Return a dict mapping each registered field name to its metadata.

        Each entry holds the field's current value (``field.get()``) and its
        input/output names and descriptions. Names are ``None`` when the
        attribute was never set; descriptions default to ``{}``.
        """
        result = {}
        for name, field in self.fields.items():
            result[name] = {
                "value": field.get(),
                "input_names": getattr(field, "input_names", None),
                "output_names": getattr(field, "output_names", None),
                "input_descs": getattr(field, "input_descs", {}),
                "output_descs": getattr(field, "output_descs", {}),
            }
        return result