---
library_name: gemma_torch
license: gemma
license_link: https://ai.google.dev/gemma/terms
pipeline_tag: text-generation
tags:
- pytorch
extra_gated_heading: Access CodeGemma on Hugging Face
extra_gated_prompt: To access CodeGemma on Hugging Face, you’re required to review
  and agree to Google’s usage license. To do this, please ensure you’re logged-in
  to Hugging Face and click below. Requests are processed immediately.
extra_gated_button_content: Acknowledge license
---

# CodeGemma Model Card

> [!IMPORTANT]
>
> This repository contains the CodeGemma 7B IT checkpoint for use with [Gemma PyTorch](https://github.com/google/gemma_pytorch). For the `transformers` implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-7b-it.

**Model Page**: [CodeGemma](https://ai.google.dev/gemma/docs/codegemma)

**Resources and Technical Documentation**:

* [Technical Report](https://goo.gle/codegemma)
* [Responsible Generative AI Toolkit](https://ai.google.dev/responsible)

**Terms of Use**: [Terms](https://www.kaggle.com/models/google/codegemma/license/consent/verify/huggingface?returnModelRepoId=google/codegemma-7b-it-pytorch)

**Authors**: Google

# Sample Usage

```python
import contextlib
import os

import torch

from gemma.config import get_config_for_2b, get_config_for_7b
from gemma.model import GemmaForCausalLM

VARIANT = "7b-it"
MACHINE_TYPE = "cpu"  # e.g. "cuda" to run on a GPU
# Local directory holding the downloaded checkpoint and tokenizer.model.
weights_dir = "codegemma-7b-it-pytorch"


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily sets the default torch dtype, restoring the previous one on exit."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


# Pick the config that matches the variant and point it at the tokenizer.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)
# Build the model in the config's dtype, load the weights, and switch to eval mode.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f"codegemma-{VARIANT}.pt")
    model.load_weights(ckpt_path)
    model = model.to(device).eval()

# Chat format: a user turn, then an open model turn for the model to complete.
PROMPT = """<start_of_turn>user
Write a Python function to calculate the nth fibonacci number.<end_of_turn>
<start_of_turn>model
"""

# generate() returns a string when given a single string prompt.
result = model.generate(
    PROMPT,
    device=device,
    output_len=100,
)
print(result)
```
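The sample prompt follows the Gemma chat-turn format: each turn opens with `<start_of_turn>` and a role, closes with `<end_of_turn>`, and the prompt ends with an open `model` turn that the model completes. The following is a minimal sketch of a helper that assembles this format for multi-turn conversations. `build_chat_prompt` is a hypothetical name, not part of Gemma PyTorch; it assumes the only roles are `user` and `model`, and, mirroring the sample prompt, it emits no explicit `<bos>` token.

```python
from typing import List, Tuple


def build_chat_prompt(turns: List[Tuple[str, str]]) -> str:
    """Builds a Gemma-style chat prompt from (role, text) pairs.

    Hypothetical helper: roles are assumed to be "user" or "model",
    matching the turn markers used in the sample prompt.
    """
    parts = []
    for role, text in turns:
        if role not in ("user", "model"):
            raise ValueError(f"unsupported role: {role!r}")
        # Each completed turn is closed with <end_of_turn> and a newline.
        parts.append(f"<start_of_turn>{role}\n{text}<end_of_turn>\n")
    # Leave an open model turn so generation continues as the model's reply.
    parts.append("<start_of_turn>model\n")
    return "".join(parts)


if __name__ == "__main__":
    prompt = build_chat_prompt([
        ("user", "Write a Python function to calculate the nth fibonacci number."),
    ])
    print(prompt)
```

For a single user turn this reproduces `PROMPT` from the sample exactly, so its return value can be passed as the first argument to `model.generate`; earlier model replies can be supplied as `("model", ...)` pairs to continue a conversation.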