def cleanup_code(
    code: str,
    language_type: str = None,
    dataset: str = None,
    issft: bool = False,
    stop_words=None,
):
    """Clean up generated code by optionally extracting a fenced block and truncating at stop words.

    Args:
        code: The raw generated code.
        language_type: Language identifier, compared case-insensitively. ``"python"`` and
            ``"ts"`` get language-specific stop words; any other value (including ``None``)
            only truncates at the caller-supplied ``stop_words``.
        dataset: Not used by this function; kept for backward compatibility with callers.
        issft: Python only. When true, first extract the body of a markdown
            `` ```python `` fenced block before truncating.
        stop_words: Optional iterable of strings. The code is cut at the earliest
            occurrence of any of them. Defaults to no extra stop words.

    Returns:
        The cleaned code string.
    """
    # Copy into a fresh list: avoids a shared mutable default, and lets the "ts"
    # branch concatenate even when the caller passes a tuple.
    stop_words = list(stop_words) if stop_words is not None else []
    # Previously a None language_type crashed on .lower(); treat it as "generic".
    language = (language_type or "").lower()

    if language == "python":
        if issft:
            code = _clean_python_code_for_sft(code)
        # NOTE(review): caller-supplied stop_words are discarded for Python and replaced
        # with this fixed list (preserved from the original behavior) — confirm whether
        # merging them is desired.
        stop_words = ["\ndef", "\nclass", "\nif", "\n#", "\nprint"]
    elif language == "ts":
        # Extra top-level TypeScript statements that mark the end of the generated body.
        stop_words = stop_words + [
            "\nexport",
            "\nimport",
            "\nexport default",
            "\nimport default",
            "\nconsole.log",
        ]

    return _truncate_code_at_stopwords(code, stop_words)
|
|
| def _clean_python_code_for_sft(code): |
| code = code.replace("\r", "") |
| if "```python" in code: |
| code_start_idx = code.index("```python") |
| code = code[code_start_idx:].replace("```python", "").strip() |
| end_idx = code.find("```") if "```" in code else len(code) |
| code = code[:end_idx].strip() |
|
|
| return code |
|
|
| def _truncate_code_at_stopwords(code, stop_words): |
| min_stop_idx = len(code) |
| for stop_word in stop_words: |
| stop_index = code.find(stop_word) |
| if 0 <= stop_index < min_stop_idx: |
| min_stop_idx = stop_index |
| return code[:min_stop_idx] |