| | import streamlit as st |
| | import torch |
| | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
| |
|
| | |
# Hugging Face Hub id (or local path) of the checkpoint to serve. StarCoder2
# checkpoints are published under the ``bigcode`` organisation
# (e.g. bigcode/starcoder2-3b); a bare "starcoder2" only resolves if a local
# directory with that name exists.
model_name = "bigcode/starcoder2-3b"


@st.cache_resource
def _load_model(name):
    """Load the model and tokenizer for *name*, cached across reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level without a cache would reload the weights on each
    click. ``st.cache_resource`` keeps a single shared instance per ``name``
    across reruns and sessions.

    Args:
        name (str): Hub id or local path passed to ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    # StarCoder2 is a decoder-only (causal) language model. Its config is
    # registered for AutoModelForCausalLM, not AutoModelForSeq2SeqLM, so the
    # seq2seq auto-class cannot load it.
    from transformers import AutoModelForCausalLM

    loaded_model = AutoModelForCausalLM.from_pretrained(name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model(model_name)
| |
|
def _generate(prompt, max_new_tokens, num_return_sequences=1, max_input_tokens=2048):
    """Run the module-level ``model`` on *prompt* and decode only the new tokens.

    Shared by code_complete, code_fix and text_to_code so that tokenisation
    and generation settings stay consistent across the three features.

    Args:
        prompt (str): Text passed to the tokenizer unchanged.
        max_new_tokens (int): Maximum number of tokens to generate, not
            counting the prompt.
        num_return_sequences (int, optional): Number of outputs to return.
            Values above 1 enable sampling, because greedy decoding would
            return the same sequence every time. Defaults to 1.
        max_input_tokens (int, optional): Prompts longer than this many
            tokens are cut from the left. This keeps the end of the prompt,
            which is the text the model has to continue. Defaults to 2048.

    Returns:
        list[str]: ``num_return_sequences`` decoded strings that contain only
        the generated text. The prompt is not echoed back.

    Raises:
        ValueError: If ``max_new_tokens`` or ``num_return_sequences`` is < 1.
    """
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")
    if num_return_sequences < 1:
        raise ValueError(
            f"num_return_sequences must be >= 1, got {num_return_sequences}"
        )

    # No padding: a single prompt needs none. Padding to max_length requires a
    # pad token, and for a decoder-only model it uses up the length budget, so
    # nothing new can be generated.
    encoded = tokenizer(prompt, return_attention_mask=True, return_tensors="pt")
    input_ids = encoded["input_ids"][:, -max_input_tokens:].to(model.device)
    attention_mask = encoded["attention_mask"][:, -max_input_tokens:].to(model.device)

    # Fall back to EOS when the tokenizer defines no pad token, so generate()
    # does not have to warn about and guess pad_token_id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    generate_kwargs = {
        # max_new_tokens bounds only the generated text; generate's max_length
        # would also count the prompt tokens.
        "max_new_tokens": max_new_tokens,
        "num_return_sequences": num_return_sequences,
        "pad_token_id": pad_token_id,
    }
    if num_return_sequences > 1:
        # Several distinct suggestions need sampling. Otherwise, leave
        # do_sample to the model's own generation config.
        generate_kwargs["do_sample"] = True

    output_ids = model.generate(input_ids, attention_mask=attention_mask, **generate_kwargs)

    # Decoder-only models return prompt + continuation. Drop the prompt so
    # callers get only new text, matching what encoder-decoder models return.
    if not model.config.is_encoder_decoder:
        output_ids = output_ids[:, input_ids.shape[-1]:]
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)


def code_complete(prompt, max_length=256, num_suggestions=1):
    """
    Generate code completion suggestions for the given prompt.

    Args:
        prompt (str): The incomplete code snippet.
        max_length (int, optional): The maximum number of tokens to generate.
            The prompt is not counted. Defaults to 256.
        num_suggestions (int, optional): How many completions to return.
            Values above 1 are sampled, so the suggestions can differ from
            each other. Defaults to 1.

    Returns:
        list: A list of code completion suggestions. Each one contains only
        the text generated after the prompt. The list is empty if the prompt
        is blank.
    """
    if not prompt.strip():
        return []
    return _generate(prompt, max_length, num_return_sequences=num_suggestions)


def code_fix(code, max_length=512):
    """
    Fix errors in the given code snippet.

    Args:
        code (str): The code snippet with errors.
        max_length (int, optional): The maximum number of tokens to generate.
            Defaults to 512.

    Returns:
        str: The model's output for the snippet. Blank input is returned
        unchanged.
    """
    if not code.strip():
        return code
    # NOTE(review): base StarCoder2 checkpoints are trained for plain code
    # completion, not instruction following. Raw code fed in here tends to be
    # continued rather than corrected. TODO: use an instruction-tuned
    # checkpoint and/or a fix-request prompt template.
    return _generate(code, max_length)[0]


def text_to_code(text, max_length=256):
    """
    Generate code from a natural language description.

    Args:
        text (str): The natural language description of the code.
        max_length (int, optional): The maximum number of tokens to generate.
            Defaults to 256.

    Returns:
        str: The generated code, or "" if the description is blank.
    """
    if not text.strip():
        return ""
    # NOTE(review): as with code_fix, a base completion model may continue the
    # prose instead of writing code. Consider a prompt template (e.g. the
    # description as a code comment) or an instruction-tuned checkpoint.
    return _generate(text, max_length)[0]
| |
|
| | |
st.title("Codebot")
st.write("Welcome to the Codebot! You can use this app to generate code completions, fix errors in your code, or generate code from a natural language description.")

# st.tabs creates every tab in one call and returns one container per label,
# in order. Streamlit has no singular st.tab().
code_completion_tab, code_fixing_tab, text_to_code_tab = st.tabs(
    ["Code Completion", "Code Fixing", "Text-to-Code"]
)

with code_completion_tab:
    st.write("Enter an incomplete code snippet:")
    # A text area rather than text_input, because code snippets span several
    # lines and st.text_input accepts only a single line.
    prompt_input = st.text_area("Prompt:", value="", height=200, key="completion_prompt")
    complete_clicked = st.button("Generate Completions", key="completion_button")

    if complete_clicked:
        if not prompt_input.strip():
            st.warning("Please enter a code snippet first.")
        else:
            with st.spinner("Generating completions..."):
                completions = code_complete(prompt_input)
            st.write("Code completions:")
            for number, completion in enumerate(completions, start=1):
                st.write(f"{number}.")
                # Render as code so indentation and symbols are not
                # interpreted as markdown.
                st.code(completion)

with code_fixing_tab:
    st.write("Enter a code snippet with errors:")
    code_input = st.text_area("Code:", height=300, key="fix_code")
    fix_clicked = st.button("Fix Errors", key="fix_button")

    if fix_clicked:
        if not code_input.strip():
            st.warning("Please enter some code to fix.")
        else:
            with st.spinner("Fixing code..."):
                corrected_code = code_fix(code_input)
            st.write("Corrected code:")
            st.code(corrected_code)

with text_to_code_tab:
    st.write("Enter a natural language description of the code:")
    text_input = st.text_input("Description:", value="", key="description")
    generate_code_clicked = st.button("Generate Code", key="text_to_code_button")

    if generate_code_clicked:
        if not text_input.strip():
            st.warning("Please enter a description first.")
        else:
            with st.spinner("Generating code..."):
                generated_code = text_to_code(text_input)
            st.write("Generated code:")
            st.code(generated_code)
| |
|
| | |
# Launch with:  streamlit run <path to this file>
#
# Streamlit executes this module top to bottom on every rerun, with __name__
# set to "__main__", so the UI code above is already the entry point.
# Streamlit has no st.run() function; the former
# `if __name__ == "__main__": st.run()` raised AttributeError at the end of
# every run.