---
license: wtfpl
datasets:
- HuggingFaceH4/CodeAlpaca_20K
pipeline_tag: text-generation
thumbnail: https://huggingface.co/mrm8488/mamba-coder/resolve/main/mamba-coder-no-bg.png
language:
- en
- code
---

# Mamba-Coder
## MAMBA (2.8B) 🐍 fine-tuned on CodeAlpaca_20K for code generation

<div style="text-align:center;width:250px;height:250px;">
<img src="https://huggingface.co/mrm8488/mamba-coder/resolve/main/mamba-coder-no-bg.png" alt="mamba-coder logo">
</div>

## Base model info

Mamba is a new state space model architecture showing promising performance on information-dense data such as language modeling, where previous subquadratic models fall short of Transformers.
It is based on the line of progress on [structured state space models](https://github.com/state-spaces/s4),
with an efficient hardware-aware design and implementation in the spirit of [FlashAttention](https://github.com/Dao-AILab/flash-attention).
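
For intuition, the heart of a state space layer is a linear recurrence over a hidden state. The toy sketch below runs a plain (non-selective) discretized SSM scan in PyTorch; it is illustrative only, since the real Mamba layer makes the recurrence parameters input-dependent ("selective") and computes the scan with a fused, hardware-aware kernel rather than a Python loop.

```py
import torch

# Toy, non-selective diagonal SSM:  h_t = A_bar * h_{t-1} + B_bar * x_t,  y_t = <C, h_t>.
# Mamba makes B, C and the discretization step input-dependent ("selective")
# and replaces this Python loop with a fused hardware-aware scan kernel.
def ssm_scan(x, A_bar, B_bar, C):
    h = torch.zeros_like(A_bar)       # hidden state, one value per state dimension
    ys = []
    for x_t in x:                     # O(seq_len) recurrence instead of quadratic attention
        h = A_bar * h + B_bar * x_t   # elementwise state update (diagonal A)
        ys.append(torch.dot(C, h))    # linear readout of the state
    return torch.stack(ys)

# 10 scalar inputs, a 16-dimensional hidden state with decay factors in [0, 0.9)
y = ssm_scan(torch.randn(10), torch.rand(16) * 0.9, torch.randn(16), torch.randn(16))
print(y.shape)  # torch.Size([10])
```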

## Dataset info

[CodeAlpaca_20K](https://huggingface.co/datasets/HuggingFaceH4/CodeAlpaca_20K) contains 20K instruction-following examples used to fine-tune the original Code Alpaca model.
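
To take a quick look at the data, you can load it with the `datasets` library (a minimal sketch; inspect the printed features for the exact column names):

```py
from datasets import load_dataset

ds = load_dataset("HuggingFaceH4/CodeAlpaca_20K")
print(ds)              # splits, row counts and column names
print(ds["train"][0])  # one instruction-following example
```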

## Usage

Note that `causal-conv1d` and `mamba-ssm` build custom CUDA kernels, so an environment with a CUDA-enabled GPU is recommended.

```sh
pip install torch==2.1.0 transformers==4.35.0 causal-conv1d==1.0.0 mamba-ssm==1.0.1
```

```py
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

# Chat template borrowed from Zephyr, which the model was fine-tuned with
CHAT_TEMPLATE_ID = "HuggingFaceH4/zephyr-7b-beta"

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "mrm8488/mamba-coder"

eos_token = "<|endoftext|>"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = eos_token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(CHAT_TEMPLATE_ID).chat_template

model = MambaLMHeadModel.from_pretrained(
    model_name, device=device, dtype=torch.float16)

messages = []
prompt = "Write a bash script to remove .tmp files"
messages.append(dict(role="user", content=prompt))

# Render the conversation with the chat template and move it to the model device
input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)

# Keep only the assistant turn and strip the EOS token
decoded = tokenizer.batch_decode(out)
assistant_message = (
    decoded[0].split("<|assistant|>\n")[-1].replace(eos_token, "")
)

print(assistant_message)
```
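
Because the chat template is multi-turn, you can keep the conversation going by appending the assistant reply and a follow-up prompt to `messages` (a sketch reusing the variables from the snippet above):

```py
# Continue the chat: add the model's reply, then a follow-up user turn
messages.append(dict(role="assistant", content=assistant_message))
messages.append(dict(role="user", content="Now only remove .tmp files older than 7 days"))

input_ids = tokenizer.apply_chat_template(
    messages, return_tensors="pt", add_generation_prompt=True
).to(device)

out = model.generate(
    input_ids=input_ids,
    max_length=2000,
    temperature=0.9,
    top_p=0.7,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(out)[0].split("<|assistant|>\n")[-1].replace(eos_token, ""))
```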

## Gradio Demo

```sh
git clone https://github.com/mrm8488/mamba-chat.git
cd mamba-chat

pip install -r requirements.txt
pip install -q gradio==4.8.0

python app.py \
--model mrm8488/mamba-coder \
--share
```
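
Alternatively, a minimal standalone demo can be sketched with `gr.ChatInterface` (this is not the mamba-chat app itself, and it assumes `model`, `tokenizer`, `device` and `eos_token` are already set up as in the Usage snippet above):

```py
import gradio as gr

def respond(message, history):
    # Rebuild the conversation from Gradio's (user, assistant) history pairs
    messages = []
    for user_msg, bot_msg in history:
        messages.append(dict(role="user", content=user_msg))
        messages.append(dict(role="assistant", content=bot_msg))
    messages.append(dict(role="user", content=message))

    input_ids = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True
    ).to(device)
    out = model.generate(
        input_ids=input_ids,
        max_length=2000,
        temperature=0.9,
        top_p=0.7,
        eos_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.batch_decode(out)[0].split("<|assistant|>\n")[-1].replace(eos_token, "")

gr.ChatInterface(respond).launch(share=True)
```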

## Evaluations

Coming soon!


## Citation

```bibtex
@misc{manuel_romero_2024,
  author    = {{Manuel Romero}},
  title     = {mamba-coder (Revision 214a13a)},
  year      = 2024,
  url       = {https://huggingface.co/mrm8488/mamba-coder},
  doi       = {10.57967/hf/1673},
  publisher = {Hugging Face}
}
```


## Acknowledgments

Thanks to [mamba-chat](https://github.com/havenhq/mamba-chat/tree/main) for heavily inspiring our work.