import asyncio
import concurrent.futures
import contextvars
import copy
import inspect
from pathlib import Path
from typing import Awaitable, Callable, Dict, List, Union

import cloudpickle
import dspy
import ujson
from dspy.utils.saving import get_dependency_versions
from pydantic import Field

from ...core.logging import logger
from ...optimizers.engine.registry import ParamRegistry
from ...prompts.template import PromptTemplate
|
|
|
|
class PromptTuningModule(dspy.Module):
    """
    A prompt tuning module that manages interactions between predictors,
    parameter registry, and program functions.

    This module coordinates prompt optimization through:
    1. Maintaining one ``dspy.Predict`` per tunable prompt, so DSPy optimizers
       can edit its instructions and demos
    2. Synchronizing optimized parameters back to the program via the registry
    3. Executing the program with updated parameters

    Parameters
    ----------
    program : callable
        The main program function to execute. Can be either synchronous or
        asynchronous (any callable whose result is awaitable is awaited).
        Must return a two-element ``(output, execution_data)`` tuple/list:
        ``output`` is what :meth:`forward` returns, and ``execution_data`` is a
        dict whose keys are matched against the signatures' input/output field
        names to build the DSPy trace.
    signature_dict : Dict[str, dspy.Signature]
        A mapping of task names to their corresponding DSPy signatures.
        Each signature defines the input/output structure for a specific task.
    registry : ParamRegistry
        A registry that maintains tunable parameters shared between
        predictors and the program.
    signature_name2register_name : Dict[str, str]
        Mapping from each signature's class name (``signature.__name__``) to the
        registry entry that signature tunes.
    """

    @classmethod
    def from_registry(
        cls,
        program: Union[Callable[..., dict], Callable[..., Awaitable[dict]]],
        registry: ParamRegistry,
    ) -> "PromptTuningModule":
        """
        Factory method to create a PromptTuningModule from a registry and program.

        This method:
        1. Creates signatures for each field in the registry
        2. Initializes a PromptTuningModule with the program and signatures
        3. Sets up predictors for each signature

        Parameters
        ----------
        program : callable
            The main program function to execute; see the class docstring for
            the required return shape.
        registry : ParamRegistry
            Registry containing tunable parameters

        Returns
        -------
        PromptTuningModule
            A configured PromptTuningModule instance

        Examples
        --------
        >>> registry = ParamRegistry()
        >>> registry.register("task1", "What is {topic}?")
        >>> registry.register("task2", PromptTemplate(system="You are helpful.", user="{query}"))
        >>> def my_program(**kwargs):
        ...     return {"result": "done"}, {}
        >>> module = PromptTuningModule.from_registry(my_program, registry)
        """
        # Imported lazily (kept from the original), presumably to avoid an import cycle.
        from .signature_utils import signature_from_registry

        signature_dict, signature_name2register_name = signature_from_registry(
            registry=registry,
        )
        return cls(
            program=program,
            signature_dict=signature_dict,
            registry=registry,
            signature_name2register_name=signature_name2register_name,
        )

    def __init__(
        self,
        program: Union[Callable[..., dict], Callable[..., Awaitable[dict]]],
        signature_dict: Dict[str, dspy.Signature],
        registry: ParamRegistry,
        signature_name2register_name: Dict[str, str],
    ):
        """
        Initialize a PromptTuningModule instance.

        Parameters
        ----------
        program : callable
            The main program function to execute
        signature_dict : Dict[str, dspy.Signature]
            Mapping of task names to signatures; one ``dspy.Predict`` is created
            per entry, in iteration order.
        registry : ParamRegistry
            Parameter registry
        signature_name2register_name : Dict[str, str]
            Mapping of signature names to register names

        Raises
        ------
        ValueError
            If ``signature_dict.items()`` yields the same name twice.
        """
        super().__init__()
        self.program = program
        self.predicts = []

        seen = set()
        for name, signature in signature_dict.items():
            if name in seen:
                raise ValueError(f"Duplicate name {name} in signature_dict")
            seen.add(name)
            self.predicts.append(dspy.Predict(signature, name=name))
        self.registry = registry
        self.signature_name2register_name = signature_name2register_name

    def reset(self):
        """
        Reset the registry and re-seed every predictor from it.

        After ``registry.reset()``, each predictor's instructions (and, for
        ``PromptTemplate`` entries, its demos) are reloaded from the registry
        entry mapped to its signature name. Entries of any other type are
        skipped with a warning.

        Returns
        -------
        PromptTuningModule
            ``self``, to allow chaining.
        """
        self.registry.reset()

        for predict in self.predicts:
            register_name = self.signature_name2register_name[predict.signature.__name__]
            register_element = self.registry.get(register_name)

            if isinstance(register_element, PromptTemplate):
                predict.signature.instructions = register_element.instruction
                # Copy so in-place edits to predict.demos cannot mutate the registry's template.
                predict.demos = list(register_element.demonstrations or [])
            elif isinstance(register_element, str):
                predict.signature.instructions = register_element
                predict.demos = []
            else:
                logger.warning(f"Unsupported register element type: {type(register_element)}")

        return self

    def escape_braces(self, text):
        """
        Escape every brace in the text for ``str.format``-style templates.

        Each ``{`` becomes ``{{`` and each ``}`` becomes ``}}``. For balanced
        input this is exactly what the previous recursive implementation
        produced. Unlike that implementation, it never drops the text that
        follows an unmatched ``}``. It also never invents a closing brace for an
        unmatched ``{``.

        Parameters
        ----------
        text : str
            Text that needs escaping

        Returns
        -------
        str
            Escaped text
        """
        return text.replace("{", "{{").replace("}", "}}")

    def _validate_prompt(self, prompt: str, input_names: List[str], verbose: bool = True) -> str:
        """
        Make sure a string prompt references every required input as a placeholder.

        Any input name not present as ``{name}`` is appended to the prompt as
        ``name: {name}``. Afterwards every other brace is escaped, so only the
        declared input placeholders remain as format fields. Braces that were
        already escaped (``{{``/``}}``) are kept as they were.

        Parameters
        ----------
        prompt : str
            The prompt to validate
        input_names : List[str]
            List of required input names
        verbose : bool, optional
            Whether to log (at debug level) the modifications that were made,
            defaults to True

        Returns
        -------
        str
            Validated and potentially modified prompt
        """
        modified_messages = []
        missing_required_inputs = [name for name in input_names if f"{{{name}}}" not in prompt]
        if missing_required_inputs:
            input_values = "\n\n".join([f"{name}: {{{name}}}" for name in missing_required_inputs])
            prompt += f"\n\nThe followings are some required input values: \n{input_values}"
            modified_messages.append(f"added missing inputs: {', '.join(missing_required_inputs)}")

        # Escape all braces, then re-open the placeholders of declared inputs.
        prompt = self.escape_braces(prompt)
        for name in input_names:
            prompt = prompt.replace(f"{{{{{name}}}}}", f"{{{name}}}")
        # Braces that were already escaped got doubled again above; collapse them back.
        prompt = prompt.replace(r"{{{{", r"{{").replace(r"}}}}", r"}}")

        if verbose and modified_messages:
            logger.debug(f"Prompt validation: {'; '.join(modified_messages)}")
        return prompt

    def get_field_type(self, field: Field) -> str:
        """
        Get the DSPy field type (e.g. ``'input'`` / ``'output'``) of a signature field.

        Parameters
        ----------
        field : Field
            The field to get type from

        Returns
        -------
        str or None
            The value stored under ``'__dspy_field_type'`` in the field's
            ``json_schema_extra``. Returns None when it is missing or falsy, or
            when ``json_schema_extra`` is not a dict.
        """
        extra = getattr(field, "json_schema_extra", None)
        if not isinstance(extra, dict):
            return None
        return extra.get("__dspy_field_type") or None

    def _field_names(self, signature, field_type: str) -> List[str]:
        """Return the names of ``signature`` fields whose DSPy field type equals ``field_type``."""
        return [name for name, field in signature.fields.items() if self.get_field_type(field) == field_type]

    def is_prompt_template(self, register_name: str) -> bool:
        """
        Check if the register name is a prompt template.

        Parameters
        ----------
        register_name : str
            The register name to check

        Returns
        -------
        bool
            Whether the registry value for ``register_name`` is a ``PromptTemplate``
        """
        return isinstance(self.registry.get(register_name), PromptTemplate)

    def get_demos(self, demos: list) -> List[dict]:
        """
        Normalize predictor demos to plain values.

        ``dspy.Example`` items are converted with ``toDict()``; any other item is
        passed through unchanged.

        Parameters
        ----------
        demos : list
            The demos attached to a predictor.

        Returns
        -------
        List[dict]
            A new list with the normalized demos.
        """
        return [demo.toDict() if isinstance(demo, dspy.Example) else demo for demo in demos]

    def _inject_demos_to_string(self, instruction: str, demos: List[dict], input_names: List[str], output_names: List[str]) -> str:
        """
        Append formatted few-shot examples to a string instruction.

        Each demo is rendered as its input fields followed by its output fields.
        A missing field is rendered as ``"Not provided"``. Demo values have their
        braces escaped so they survive ``str.format``.

        Parameters
        ----------
        instruction : str
            The (already validated) instruction text.
        demos : List[dict]
            Demos as returned by :meth:`get_demos`.
        input_names, output_names : List[str]
            Signature field names to render for each demo.

        Returns
        -------
        str
            ``instruction`` unchanged when there are no demos, otherwise the
            instruction followed by the rendered examples.
        """
        if not demos:
            return instruction

        def format_demo(demo: dict) -> str:
            demo_str = "Inputs:\n"
            inputs = {name: demo.get(name, "Not provided") for name in input_names}
            demo_str += "\n".join([f"{name}:\n{self.escape_braces(str(value))}" for name, value in inputs.items()])
            demo_str += "\n\nOutputs:\n"
            outputs = {name: demo.get(name, "Not provided") for name in output_names}
            demo_str += "\n".join([f"{name}:\n{self.escape_braces(str(value))}" for name, value in outputs.items()])
            return demo_str

        demos_string = "\n\n".join([f"Example {i+1}:\n{format_demo(demo)}" for i, demo in enumerate(demos)])
        return f"{instruction}\n\nThe following are some examples:\n{demos_string}"

    def sync_predict_inputs_to_program(self):
        """
        Synchronize current instructions and demos from all predictors back to the registry.

        For ``PromptTemplate`` entries, the template's ``instruction`` and
        ``demonstrations`` are overwritten in place and stored back. For string
        entries, the instruction is validated with :meth:`_validate_prompt`, the
        demos are rendered into it, and the resulting string is stored.

        Note: Values in predictors take precedence as they may contain
        optimized values from recent tuning iterations.
        """
        for predict in self.predicts:
            signature = predict.signature
            instruction = signature.instructions
            demos = predict.demos

            input_names = self._field_names(signature, "input")
            output_names = self._field_names(signature, "output")

            # NOTE(review): relies on the predictor's signature keeping its class name;
            # verify this holds for optimizers that rebuild signatures.
            register_name = self.signature_name2register_name[signature.__name__]

            if self.is_prompt_template(register_name):
                prompt_template: PromptTemplate = self.registry.get(register_name)
                prompt_template.instruction = instruction
                prompt_template.demonstrations = self.get_demos(demos)
                self.registry.set(register_name, prompt_template)
            else:
                instruction = self._validate_prompt(instruction, input_names)
                prompt = self._inject_demos_to_string(instruction, self.get_demos(demos), input_names, output_names)
                self.registry.set(register_name, prompt)

    def constrcut_trace(self, execution_data: dict) -> List[tuple]:
        """
        Construct the DSPy trace of one program execution.

        Parameters
        ----------
        execution_data : dict
            Values recorded during the run, keyed by signature field name.
            ``None`` is treated as an empty dict.

        Returns
        -------
        List[tuple]
            One ``(predict, input_dict, output_dict)`` tuple per predictor,
            containing the fields found in ``execution_data``. Missing fields
            are logged as warnings and omitted.
        """
        if execution_data is None:
            execution_data = {}

        trace: List[tuple] = []
        for predict in self.predicts:
            input_names = self._field_names(predict.signature, "input")
            output_names = self._field_names(predict.signature, "output")

            for name in input_names:
                if name not in execution_data:
                    logger.warning(f"Input {name} not found in execution data")
            for name in output_names:
                if name not in execution_data:
                    logger.warning(f"Output {name} not found in execution data")

            input_dict = {name: execution_data[name] for name in input_names if name in execution_data}
            output_dict = {name: execution_data[name] for name in output_names if name in execution_data}
            trace.append((predict, input_dict, output_dict))
        return trace

    @staticmethod
    def _run_awaitable(awaitable):
        """
        Drive an awaitable to completion from synchronous code and return its result.

        With no event loop running in this thread, ``asyncio.run`` is used
        directly. If a loop is already running (e.g. a notebook or an async
        caller), ``asyncio.run`` would raise. In that case the awaitable runs on
        a fresh loop in a worker thread, with the current ``contextvars`` copied
        over.
        """
        async def _await():
            return await awaitable

        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(_await())

        context = contextvars.copy_context()
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            return executor.submit(context.run, asyncio.run, _await()).result()

    def forward(self, **kwargs) -> dict:
        """
        Execute the program with synchronized parameters and optional inputs.

        This method:
        1. Synchronizes optimized prompts back to the program via the registry
        2. Executes the program, awaiting its result if it is awaitable
        3. Appends the constructed trace to ``dspy.settings.trace`` when tracing is active

        Parameters
        ----------
        **kwargs : dict
            Optional keyword arguments to pass to the program function

        Returns
        -------
        Any
            The ``output`` element of the program's ``(output, execution_data)`` result

        Raises
        ------
        ValueError
            If the program does not return a two-element tuple/list
        """
        self.sync_predict_inputs_to_program()

        result = self.program(**kwargs)
        # Covers async functions as well as async __call__ objects and partials.
        if inspect.isawaitable(result):
            result = self._run_awaitable(result)

        if not isinstance(result, (tuple, list)) or len(result) != 2:
            raise ValueError(
                "The program must return an (output, execution_data) pair, "
                f"but returned {type(result).__name__}"
            )
        output, execution_data = result

        trace = self.constrcut_trace(execution_data)
        if dspy.settings.trace is not None:
            dspy.settings.trace.extend(trace)

        return output

    def deepcopy(self):
        """
        Copy the module, sharing ``program`` and ``registry`` with the original.

        The program is called as-is by :meth:`forward` and is never copied.
        The registry is shared as well, so that :meth:`sync_predict_inputs_to_program`
        on the copy still writes where that shared program reads. A deep-copied
        registry would be disconnected from it.

        Everything else is deep-copied in one pass when possible. If that
        fails, attributes are copied one by one. Each attribute falls back to
        ``Module.deepcopy``, then ``copy.deepcopy``, then ``copy.copy``, and
        finally to sharing the original value.

        Returns
        -------
        PromptTuningModule
            The copied module
        """
        shared = {id(self.program): self.program, id(self.registry): self.registry}
        try:
            # Pre-seeding the memo makes deepcopy reuse these objects wherever they appear.
            return copy.deepcopy(self, dict(shared))
        except Exception as exc:
            logger.debug(f"Whole-module deepcopy failed ({exc}); copying attributes individually.")

        new_instance = self.__class__.__new__(self.__class__)
        for attr, value in self.__dict__.items():
            if id(value) in shared:
                setattr(new_instance, attr, value)
            elif isinstance(value, dspy.Module):
                setattr(new_instance, attr, value.deepcopy())
            else:
                try:
                    setattr(new_instance, attr, copy.deepcopy(value, dict(shared)))
                except Exception:
                    try:
                        setattr(new_instance, attr, copy.copy(value))
                    except Exception:
                        setattr(new_instance, attr, value)
        return new_instance

    def save(self, path, save_program=False):
        """Save the module.

        If the wrapped program has its own ``save`` method, saving is delegated
        to it with ``str(path)``. Otherwise there are two modes:
        - `save_program=False`: Save only the state of the module to a json or pickle file, based on the value of
          the file extension.
        - `save_program=True`: Save the whole module to a directory via cloudpickle, which contains both the state and
          architecture of the model. The directory is created if needed.

        We also save the dependency versions, so that the loaded model can check if there is a version mismatch on
        critical dependencies or DSPy version.

        Args:
            path (str): Path to the saved state file, which should be a .json or .pkl file when `save_program=False`,
                and a directory when `save_program=True`.
            save_program (bool): If True, save the whole module to a directory via cloudpickle, otherwise only save
                the state.

        Raises:
            ValueError: If the path has the wrong form for the chosen mode.
            NotADirectoryError: If `save_program=True` and `path` is an existing file.
            RuntimeError: If serialization fails.
        """
        metadata = {"dependency_versions": get_dependency_versions()}
        path = Path(path)

        # A file target needs its parent directory to exist.
        if not path.is_dir() and not path.parent.exists():
            path.parent.mkdir(parents=True, exist_ok=True)

        if hasattr(self.program, "save"):
            self.program.save(str(path))
            return

        if save_program:
            if path.suffix:
                raise ValueError(
                    f"`path` must point to a directory without a suffix when `save_program=True`, but received: {path}"
                )
            if path.exists() and not path.is_dir():
                raise NotADirectoryError(f"The path '{path}' exists but is not a directory.")
            # The target directory itself may not exist yet; the writes below need it.
            path.mkdir(parents=True, exist_ok=True)

            try:
                with open(path / "program.pkl", "wb") as f:
                    cloudpickle.dump(self, f)
            except Exception as e:
                raise RuntimeError(
                    f"Saving failed with error: {e}. Please remove the non-picklable attributes from your DSPy program, "
                    "or consider using state-only saving by setting `save_program=False`."
                ) from e
            with open(path / "metadata.json", "w") as f:
                ujson.dump(metadata, f, indent=4)
            return

        state = self.dump_state()
        state["metadata"] = metadata
        if path.suffix == ".json":
            try:
                with open(path, "w") as f:
                    f.write(ujson.dumps(state, indent=4))
            except Exception as e:
                raise RuntimeError(
                    f"Failed to save state to {path} with error: {e}. Your DSPy program may contain non "
                    "json-serializable objects, please consider saving the state in .pkl by using `path` ending "
                    "with `.pkl`, or saving the whole program by setting `save_program=True`."
                ) from e
        elif path.suffix == ".pkl":
            with open(path, "wb") as f:
                cloudpickle.dump(state, f)
        else:
            raise ValueError(f"`path` must end with `.json` or `.pkl` when `save_program=False`, but received: {path}")

    def load(self, path):
        """Load the saved module. You may also want to check out dspy.load, if you want to
        load an entire program, not just the state for an existing program.

        If the wrapped program has its own ``load`` method, loading is delegated
        to it. Otherwise the state is read and applied, and the result is
        synchronized to the registry. A mismatch between saved and current
        dependency versions is logged as a warning. State saved without
        version metadata is accepted.

        Args:
            path (str): Path to the saved state file, which should be a .json or a .pkl file

        Raises:
            ValueError: If the path does not end with `.json` or `.pkl`.
        """
        path = Path(path)

        if hasattr(self.program, "load"):
            self.program.load(str(path))
            return

        if path.suffix == ".json":
            with open(path) as f:
                state = ujson.loads(f.read())
        elif path.suffix == ".pkl":
            # SECURITY: unpickling can execute arbitrary code; only load .pkl files from trusted sources.
            with open(path, "rb") as f:
                state = cloudpickle.load(f)
        else:
            raise ValueError(f"`path` must end with `.json` or `.pkl`, but received: {path}")

        dependency_versions = get_dependency_versions()
        saved_dependency_versions = (state.get("metadata") or {}).get("dependency_versions") or {}
        for key, saved_version in saved_dependency_versions.items():
            current_version = dependency_versions.get(key)
            if current_version != saved_version:
                logger.warning(
                    f"There is a mismatch of {key} version between saved model and current environment. "
                    f"You saved with `{key}=={saved_version}`, but now you have "
                    f"`{key}=={current_version}`. This might cause errors or performance downgrade "
                    "on the loaded model, please consider loading the model in the same environment as the "
                    "saving environment."
                )
        self.load_state(state)
        self.sync_predict_inputs_to_program()
|
|
| |