| import os |
| from typing import Optional, List, Dict |
| from pydantic import Field |
|
|
| from ..models.base_model import BaseLLM, LLMOutputParser |
| from .action import Action, ActionInput, ActionOutput |
| from ..prompts.code_extraction import CODE_EXTRACTION |
|
|
|
|
class CodeExtractionInput(ActionInput):
    """Inputs accepted by the CodeExtraction action.

    Bundles the raw text to scan for code blocks, the directory the
    extracted files are written to, and an optional project folder name.
    """

    code_string: str = Field(
        description="The string containing code blocks to extract",
    )
    target_directory: str = Field(
        description="The directory path where extracted code files will be saved",
    )
    # Optional: defaults to None when the caller does not name a project folder.
    project_name: Optional[str] = Field(
        default=None,
        description="Optional name for the project folder",
    )
|
|
|
|
class CodeExtractionOutput(ActionOutput):
    """Result produced by the CodeExtraction action.

    Carries the files that were written, the detected main file (if any)
    and an error message when the action could not complete.
    """

    # Required field: callers must always supply the mapping (possibly empty).
    extracted_files: Dict[str, str] = Field(
        description="Map of filename to file path of saved files",
    )
    main_file: Optional[str] = Field(
        default=None,
        description="Path to the main file if identified",
    )
    error: Optional[str] = Field(
        default=None,
        description="Error message if any operation failed",
    )
|
|
|
|
class CodeBlockInfo(LLMOutputParser):
    """A single code block as described by the LLM.

    Each block records its programming language, a suggested filename and
    the code itself.
    """

    language: str = Field(
        description="Programming language of the code block",
    )
    filename: str = Field(
        description="Suggested filename for the code block",
    )
    content: str = Field(
        description="The actual code content",
    )
|
|
|
|
class CodeBlockList(LLMOutputParser):
    """Container for every code block extracted from a piece of text."""

    code_blocks: List[CodeBlockInfo] = Field(
        description="List of code blocks",
    )
|
|
|
|
class CodeExtraction(Action):
    """
    An action that extracts and organizes code blocks from text.

    This action uses an LLM to analyze text containing code blocks, extract them,
    suggest appropriate filenames, and save them to a specified directory. It can
    also identify which file is likely the main entry point based on heuristics.

    Attributes:
        name: The name of the action.
        description: A description of what the action does.
        prompt: The prompt template used by the action.
        inputs_format: The expected format of inputs to this action.
        outputs_format: The format of the action's output.
    """

    def __init__(self, **kwargs):
        """Initialise the action, filling unset options from ``CODE_EXTRACTION``.

        Args:
            **kwargs: Optional overrides for ``name``, ``description``,
                ``prompt``, ``inputs_format`` and ``outputs_format``; anything
                else is forwarded unchanged to the parent constructor.
        """
        # An explicitly passed value (even a falsy one) wins for these three keys.
        name = kwargs.pop("name") if "name" in kwargs else CODE_EXTRACTION["name"]
        description = kwargs.pop("description") if "description" in kwargs else CODE_EXTRACTION["description"]
        prompt = kwargs.pop("prompt") if "prompt" in kwargs else CODE_EXTRACTION["prompt"]

        # For the formats a missing *or* falsy value falls back to the defaults.
        inputs_format = kwargs.pop("inputs_format", None) or CodeExtractionInput
        outputs_format = kwargs.pop("outputs_format", None) or CodeExtractionOutput
        super().__init__(
            name=name,
            description=description,
            prompt=prompt,
            inputs_format=inputs_format,
            outputs_format=outputs_format,
            **kwargs,
        )

    @staticmethod
    def _file_contains(path: str, markers) -> bool:
        """Return True if the UTF-8 text file at ``path`` contains any marker.

        Unreadable or non-UTF-8 files count as "no match" so that one bad file
        cannot abort main-file detection.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                content = f.read()
        except (OSError, UnicodeError):
            return False
        return any(marker in content for marker in markers)

    @staticmethod
    def _sanitize_filename(filename) -> str:
        """Turn a suggested filename into a safe path relative to the target directory.

        Backslashes are treated as separators; empty, ``.`` and ``..``
        components, leading slashes and drive prefixes are dropped, so the
        result can never be absolute or climb out of the directory it is
        joined onto. Falls back to ``unknown.txt`` when nothing usable remains.
        """
        fallback = "unknown.txt"
        if not isinstance(filename, str):
            return fallback

        parts = [
            part
            for part in filename.strip().replace("\\", "/").split("/")
            # splitdrive() only reports a drive on Windows (e.g. "C:"); such a
            # component would make os.path.join discard the target directory.
            if part not in ("", ".", "..") and not os.path.splitdrive(part)[0]
        ]
        if not parts:
            return fallback
        return os.path.join(*parts)

    @staticmethod
    def _pack_result(output, prompt, return_prompt: bool):
        """Return ``(output, prompt)`` when ``return_prompt`` is set, else ``output``."""
        if return_prompt:
            return output, prompt
        return output

    def identify_main_file(self, saved_files: Dict[str, str]) -> Optional[str]:
        """Identify the main file from the saved files based on content and file type.

        The lookup order is:

        1. A well-known entry-point filename (``index.html``, ``main.py``, ...)
           matched against the dictionary keys, first exactly and then by
           basename so files saved in sub-directories are recognised too.
        2. The first ``.html`` file.
        3. A ``.py`` file containing an ``if __name__ == "__main__"`` guard,
           otherwise the first ``.py`` file.
        4. A ``.java`` file containing ``public static void main``, otherwise
           the first ``.java`` file.
        5. The first ``.js`` file.
        6. The first file of any kind.

        Args:
            saved_files: Dictionary mapping filenames to their full paths

        Returns:
            Path to the main file if found, None otherwise
        """
        if not saved_files:
            return None

        main_file_priorities = (
            # Web
            "index.html",
            # Python
            "main.py",
            "app.py",
            # JavaScript
            "index.js",
            "main.js",
            "app.js",
            # Java
            "Main.java",
            # C / C++
            "main.cpp",
            "main.c",
            # Go
            "main.go",
            # PHP / C#
            "index.php",
            "Program.cs",
        )

        # Exact key matches take precedence over basename matches.
        for candidate in main_file_priorities:
            if candidate in saved_files:
                return saved_files[candidate]

        # Second pass: match on the basename for files stored in sub-directories.
        by_basename: Dict[str, str] = {}
        for key, path in saved_files.items():
            basename = os.path.basename(key.replace("\\", "/"))
            by_basename.setdefault(basename, path)
        for candidate in main_file_priorities:
            if candidate in by_basename:
                return by_basename[candidate]

        html_paths = [path for key, path in saved_files.items() if key.endswith('.html')]
        if html_paths:
            return html_paths[0]

        py_paths = [path for key, path in saved_files.items() if key.endswith('.py')]
        if py_paths:
            guards = ("if __name__ == '__main__'", 'if __name__ == "__main__"')
            for path in py_paths:
                if self._file_contains(path, guards):
                    return path
            return py_paths[0]

        java_paths = [path for key, path in saved_files.items() if key.endswith('.java')]
        if java_paths:
            for path in java_paths:
                if self._file_contains(path, ("public static void main",)):
                    return path
            return java_paths[0]

        js_paths = [path for key, path in saved_files.items() if key.endswith('.js')]
        if js_paths:
            return js_paths[0]

        # Nothing recognisable: fall back to the first saved file.
        return next(iter(saved_files.values()))

    def save_code_blocks(self, code_blocks: List[Dict], target_directory: str) -> Dict[str, str]:
        """Save code blocks to files in the target directory.

        Creates the target directory if it doesn't exist and saves each code block
        to a file with an appropriate name, handling filename conflicts. Blocks
        whose content is empty or whitespace-only are skipped.

        Suggested filenames are sanitized before use: absolute paths and ``..``
        components are neutralised so nothing is written outside
        ``target_directory``, and sub-directories named in a filename are
        created on demand.

        Args:
            code_blocks: List of dictionaries containing code block information
                (the ``filename`` and ``content`` keys are read)
            target_directory: Directory path where files should be saved

        Returns:
            Dictionary mapping filenames (relative to ``target_directory``) to
            their full paths
        """
        os.makedirs(target_directory, exist_ok=True)
        saved_files: Dict[str, str] = {}

        for block in code_blocks or []:
            # ``or ""`` also covers an explicit None value.
            content = block.get("content") or ""
            if not content.strip():
                continue

            filename = self._sanitize_filename(block.get("filename", "unknown.txt"))

            # Resolve duplicates as name_1.ext, name_2.ext, ...; splitext keeps
            # dots inside directory names out of the extension logic.
            stem, extension = os.path.splitext(filename)
            counter = 1
            while filename in saved_files:
                filename = f"{stem}_{counter}{extension}"
                counter += 1

            file_path = os.path.join(target_directory, filename)
            # The suggested filename may contain sub-directories.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(content)

            saved_files[filename] = file_path

        return saved_files

    def execute(self, llm: Optional[BaseLLM] = None, inputs: Optional[dict] = None, sys_msg: Optional[str]=None, return_prompt: bool = False, **kwargs) -> CodeExtractionOutput:
        """Execute the CodeExtraction action.

        Extracts code blocks from the provided text using the specified LLM,
        saves them to the target directory, and identifies the main file.
        Failures are reported through the ``error`` field of the output rather
        than raised.

        Args:
            llm: The LLM to use for code extraction
            inputs: Dictionary containing:
                - code_string: The string with code blocks to extract
                - target_directory: Where to save the files
                - project_name: Optional project folder name
            sys_msg: Optional system message override for the LLM
            return_prompt: Whether to return the prompt along with the result
            **kwargs (Any): Additional keyword arguments

        Returns:
            CodeExtractionOutput with extracted file information. When
            ``return_prompt`` is True a ``(CodeExtractionOutput, prompt)``
            tuple is returned on every path, including failures; ``prompt``
            is None if the failure happened before the prompt was built.
        """
        # Stays None until the prompt has actually been formatted.
        prompt = None

        def fail(message: str):
            output = CodeExtractionOutput(extracted_files={}, error=message)
            return self._pack_result(output, prompt, return_prompt)

        if not llm:
            return fail("CodeExtraction action requires an LLM.")

        if not inputs:
            return fail("CodeExtraction action received invalid `inputs`: None or empty.")

        code_string = inputs.get("code_string", "")
        target_directory = inputs.get("target_directory", "")
        project_name = inputs.get("project_name", None)

        if not code_string:
            return fail("No code string provided.")

        if not target_directory:
            return fail("No target directory provided.")

        # Files go into <target_directory>/<project_name> when a name is given.
        if project_name:
            project_dir = os.path.join(target_directory, project_name)
        else:
            project_dir = target_directory

        try:
            # Format once and reuse for both the LLM call and the return value.
            prompt = self.prompt.format(code_string=code_string)
            system_message = CODE_EXTRACTION["system_prompt"] if sys_msg is None else sys_msg

            llm_response: CodeBlockList = llm.generate(
                prompt=prompt,
                system_message=system_message,
                parser=CodeBlockList,
                parse_mode="json"
            )
            code_blocks = llm_response.get_structured_data().get("code_blocks", [])

            saved_files = self.save_code_blocks(code_blocks, project_dir)
            main_file = self.identify_main_file(saved_files)

            result = CodeExtractionOutput(
                extracted_files=saved_files,
                main_file=main_file
            )
            return self._pack_result(result, prompt, return_prompt)

        except Exception as e:
            # Boundary handler: any failure is surfaced via the output's
            # ``error`` field instead of propagating to the caller.
            return fail(f"Error extracting code: {str(e)}")