'''
import os
from pathlib import Path

# Define the cache directory path within your project
cache_dir = str(Path(__file__).parent.resolve() / 'cache')
# Create the cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)
# Set TRANSFORMERS_CACHE *before* importing transformers: the library reads
# this environment variable at import time, so setting it afterwards has no effect
os.environ['TRANSFORMERS_CACHE'] = cache_dir
print(f"Transformers cache directory: {os.environ['TRANSFORMERS_CACHE']}")

import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

app = FastAPI()
# Ensure your Hugging Face token is set as an environment variable
huggingface_token = os.environ.get("TOKEN")
if not huggingface_token:
    raise ValueError("TOKEN environment variable is not set.")
# Load the tokenizer and model using Hugging Face's library with the token
try:
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", token=huggingface_token)
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", token=huggingface_token)
    # Initialize the pipeline on GPU if one is available, otherwise on CPU
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1,
    )
except Exception as e:
    raise RuntimeError(f"Failed to load model: {e}")
# Data model for the request body
class Item(BaseModel):
    prompt: str
    temperature: float = 0.7
    max_new_tokens: int = 128
# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item):
    # Validate before the try block so the 400 is not re-wrapped as a 500
    # by the generic exception handler below
    if not item.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required")
    try:
        output = generator(
            item.prompt,
            do_sample=True,  # temperature only takes effect when sampling is enabled
            temperature=item.temperature,
            max_new_tokens=item.max_new_tokens,
        )
        return {"generated_text": output[0]['generated_text']}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
'''
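
Once the server is running, the endpoint can be exercised with a small client. Below is a minimal sketch using the `requests` library (an assumption; any HTTP client works), targeting the host and port configured above; the prompt text is purely illustrative:

'''
import requests

# Send a generation request to the FastAPI endpoint defined above
response = requests.post(
    "http://localhost:8000/",
    json={
        "prompt": "Write a haiku about the sea.",  # illustrative prompt
        "temperature": 0.7,
        "max_new_tokens": 64,
    },
)
response.raise_for_status()
print(response.json()["generated_text"])
'''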