File size: 2,125 Bytes
48d18a7
df74ec2
72bedc9
df74ec2
 
282048d
ff02ce5
df74ec2
 
 
22631cc
72bedc9
22631cc
 
72bedc9
22631cc
 
 
 
72bedc9
 
8370716
 
 
 
996af39
22631cc
ff02ce5
22631cc
 
ff02ce5
 
 
 
 
 
 
 
 
 
df74ec2
996af39
df74ec2
f2e5081
ff02ce5
 
282048d
df74ec2
 
f2e5081
df74ec2
f2e5081
 
 
 
 
 
ff02ce5
282048d
df74ec2
f2e5081
df74ec2
 
75601c4
df74ec2
a87331a
72bedc9
48d18a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
'''
import os
from pathlib import Path
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM

app = FastAPI()

# Model checkpoint served by this application.
MODEL_ID = "google/gemma-2-2b-it"

# Keep downloaded model files in a `cache` directory next to this file.
cache_dir = str(Path(__file__).parent.resolve() / 'cache')

# Create the cache directory if it doesn't exist
os.makedirs(cache_dir, exist_ok=True)

# Export the location for anything that reads the variable later (child
# processes, lazily imported libraries). NOTE: `transformers` has already been
# imported at the top of this file, so setting the variable here is too late
# for it to affect that library's default cache path; the directory is
# therefore also passed explicitly via `cache_dir=` when loading below.
os.environ['TRANSFORMERS_CACHE'] = cache_dir

print(f"Transformers cache directory: {os.environ['TRANSFORMERS_CACHE']}")

# The Hugging Face access token must be supplied through the TOKEN variable;
# fail at startup rather than on the first request if it is missing.
huggingface_token = os.environ.get("TOKEN")
if not huggingface_token:
    raise ValueError("TOKEN environment variable is not set.")

# Load the tokenizer and model once at import time and wrap them in a
# text-generation pipeline shared by all requests.
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=huggingface_token,
        cache_dir=cache_dir,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=huggingface_token,
        cache_dir=cache_dir,
    )

    # Initialize the pipeline
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0  # Assuming you're using a GPU, otherwise set to -1 for CPU
    )
except Exception as e:
    # Chain the original exception so the root cause stays in the traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e

# Data model for the request body
class Item(BaseModel):
    """Request body accepted by the text-generation endpoint."""
    # Text handed to the generator as the prompt; the endpoint rejects an
    # empty string.
    prompt: str
    # Forwarded to the generator as `temperature`; 0.7 when omitted.
    temperature: float = 0.7
    # Forwarded to the generator as `max_new_tokens`; 128 when omitted.
    max_new_tokens: int = 128

# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item):
    """Generate a completion for ``item.prompt`` with the shared pipeline.

    Args:
        item: Request body carrying the prompt, the sampling temperature and
            the maximum number of new tokens to generate.

    Returns:
        A dict with a single ``generated_text`` key holding the text of the
        first sequence returned by the pipeline.

    Raises:
        HTTPException: 400 when the prompt is empty; 500 when the pipeline
            call (or reading its result) fails.
    """
    # Validate before entering the try block: raised inside it, this 400
    # would be caught by the broad handler below and reported as a 500.
    if not item.prompt:
        raise HTTPException(status_code=400, detail="`prompt` field is required")

    # Temperature only influences generation when sampling is enabled, so turn
    # sampling on for positive temperatures. A non-positive temperature is
    # treated as a request for deterministic (greedy) decoding.
    generation_kwargs = {"max_new_tokens": item.max_new_tokens}
    if item.temperature > 0:
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = item.temperature
    else:
        generation_kwargs["do_sample"] = False

    try:
        output = generator(item.prompt, **generation_kwargs)
        return {"generated_text": output[0]['generated_text']}
    except Exception as e:
        # Surface any generation failure as a 500, keeping the original cause.
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

if __name__ == "__main__":
    # When executed as a script, serve the app with uvicorn on every network
    # interface at port 8000.
    bind_host = "0.0.0.0"
    bind_port = 8000
    uvicorn.run(app, host=bind_host, port=bind_port)
'''