import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from .config import settings
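
# The settings object imported above is assumed to come from a config module
# along these lines (a hypothetical sketch; the field names are taken from
# their usage below, everything else is an assumption):
#
#     # config.py
#     import os
#     from dataclasses import dataclass
#
#     @dataclass(frozen=True)
#     class Settings:
#         DEVICE: str = os.environ.get("DEVICE", "cpu")
#         MODEL_ID_GPU: str = "..."  # e.g. a larger instruction-tuned model
#         MODEL_ID_CPU: str = "..."  # e.g. a small model that fits in RAM
#         HF_TOKEN: str | None = os.environ.get("HF_TOKEN")
#
#     settings = Settings()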

class SyllabusGenerator:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        # Use the GPU only if it is both configured and actually available.
        self.is_gpu = settings.DEVICE == "cuda" and torch.cuda.is_available()
        self.load_model()
    def load_model(self):
        model_id = settings.MODEL_ID_GPU if self.is_gpu else settings.MODEL_ID_CPU
        print(f"Loading model: {model_id} on {'GPU' if self.is_gpu else 'CPU'}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=settings.HF_TOKEN)

        load_kwargs = {
            "torch_dtype": torch.bfloat16 if self.is_gpu else torch.float32,
            "token": settings.HF_TOKEN,
        }
        if self.is_gpu:
            load_kwargs["device_map"] = "auto"
            # Quantize to 4-bit on GPU (requires bitsandbytes). Passing a bare
            # load_in_4bit=True kwarg to from_pretrained is deprecated in recent
            # transformers releases; BitsAndBytesConfig is the supported route.
            load_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        # On CPU we rely on standard full-precision loading. accelerate's
        # device_map="auto" also works on CPU, but plain loading is more
        # predictable for smaller models when no offloading is needed.

        self.model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        if not self.is_gpu:
            self.model.to("cpu")
    def generate(self, prompt: str, context: str = "", max_tokens: int = 1000):
        full_prompt = f"""
System: You are an expert academic curriculum designer.
Use the following context to generate a detailed syllabus.
Context: {context}
Question: {prompt}
Syllabus Output:
"""
        # Move inputs to wherever the model ended up (covers both the
        # device_map="auto" GPU path and the explicit CPU path).
        inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                # Most causal-LM tokenizers define no pad token; fall back to
                # EOS to avoid the warning generate() otherwise emits.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens so the echoed prompt is not
        # returned as part of the syllabus.
        new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)

# Module-level singleton: the model is loaded once, at import time.
generator = SyllabusGenerator()
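
# A minimal usage sketch (the course topic and context string below are
# hypothetical; in the real app the context would presumably come from an
# upstream retrieval step):
if __name__ == "__main__":
    syllabus = generator.generate(
        prompt="Create a 12-week undergraduate syllabus for an introductory "
               "machine learning course.",
        context="Students have completed linear algebra and Python basics.",
        max_tokens=800,
    )
    print(syllabus)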