PinkAlpaca committed on
Commit
26e30fa
·
verified ·
1 Parent(s): 47384ec

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +8 -10
main.py CHANGED
@@ -6,17 +6,18 @@ from transformers import pipeline, TextGenerationPipeline
6
 
7
  app = FastAPI()
8
 
9
- # Get Gemini API key from environment variable
10
- gemini_api_key = os.environ.get("GEMINI_API_KEY")
11
- if not gemini_api_key:
12
- raise ValueError("GEMINI_API_KEY environment variable is not set.")
13
 
14
  # Load the Gemini model using Hugging Face's pipeline
15
- # Make sure to use a model you have access to
16
  generator: TextGenerationPipeline = pipeline(
17
  "text-generation",
18
- model="google/gemma-2-2b-it", # Replace if needed
19
- ) # IMPORTANT: **DO NOT** set `use_auth_token` here
 
20
 
21
  # Data model for the request body
22
  class Item(BaseModel):
@@ -31,14 +32,11 @@ async def generate_text(item: Item):
31
  if not item.prompt:
32
  raise HTTPException(status_code=400, detail="`prompt` field is required")
33
 
34
- # Set API key in the headers BEFORE calling the pipeline
35
- generator.model.config.use_auth_token = gemini_api_key # Set the API key here
36
  output = generator(
37
  item.prompt,
38
  temperature=item.temperature,
39
  max_length=item.max_new_tokens,
40
  )
41
- generator.model.config.use_auth_token = None # Reset after use
42
 
43
  return {"generated_text": output[0]['generated_text']}
44
 
 
6
 
7
  app = FastAPI()
8
 
9
+ # Get Gemini API key from environment variable (NOT USED FOR HUGGING FACE AUTH)
10
+ # gemini_api_key = os.environ.get("GEMINI_API_KEY")
11
+ # if not gemini_api_key:
12
+ # raise ValueError("GEMINI_API_KEY environment variable is not set.")
13
 
14
  # Load the Gemini model using Hugging Face's pipeline
15
+ # Make sure to use a model you have access to and is available on Hugging Face
16
  generator: TextGenerationPipeline = pipeline(
17
  "text-generation",
18
+ model="google/gemma-2-2b-it", # The model you want to use
19
+ use_auth_token= os.environ.get("TOKEN") # The Hugging Face token you got after login
20
+ )
21
 
22
  # Data model for the request body
23
  class Item(BaseModel):
 
32
  if not item.prompt:
33
  raise HTTPException(status_code=400, detail="`prompt` field is required")
34
 
 
 
35
  output = generator(
36
  item.prompt,
37
  temperature=item.temperature,
38
  max_length=item.max_new_tokens,
39
  )
 
40
 
41
  return {"generated_text": output[0]['generated_text']}
42