PinkAlpaca committed on
Commit
a87331a
·
verified ·
1 Parent(s): c4aa949

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +29 -20
main.py CHANGED
@@ -2,12 +2,14 @@ import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import uvicorn
5
- import requests # Use requests for HTTP calls to the Gemini API
 
 
6
 
7
  app = FastAPI()
8
 
9
  # Define the Gemini API endpoint for the primary model
10
- primary_url = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash-latest:generateContent"
11
 
12
  # Define the data model for the request body
13
  class Item(BaseModel):
@@ -22,6 +24,13 @@ class Item(BaseModel):
22
  repetition_penalty: float = 1.0
23
  key: str = None
24
 
 
 
 
 
 
 
 
25
  # Function to generate the response JSON
26
  def generate_response_json(item, output, tokens, model_name):
27
  return {
@@ -47,18 +56,19 @@ def generate_response_json(item, output, tokens, model_name):
47
 
48
  # Function to call the Gemini API
49
  def call_gemini_api(url, input_text, generate_kwargs):
 
50
  headers = {
51
- "Authorization": f"Bearer {os.getenv('GEMINI_API_KEY')}", # Ensure the API key is set in the environment
52
  "Content-Type": "application/json"
53
  }
54
  data = {
55
- "prompt": input_text,
56
- "temperature": generate_kwargs['temperature'],
57
- "max_output_tokens": generate_kwargs['max_new_tokens'],
58
- "top_p": generate_kwargs['top_p'],
59
- "repetition_penalty": generate_kwargs['repetition_penalty'],
60
- "do_sample": generate_kwargs['do_sample'],
61
- "seed": generate_kwargs['seed'],
62
  }
63
  try:
64
  response = requests.post(url, headers=headers, json=data)
@@ -79,14 +89,14 @@ async def generate_text(item: Item = None):
79
  raise HTTPException(status_code=400, detail="Parameter `input` or `system prompt` is required.")
80
 
81
  input_ = ""
82
- if item.system_prompt != None and item.system_output != None:
83
  input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
84
- elif item.system_prompt != None:
85
  input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
86
- elif item.system_output != None:
87
  input_ = f"<s>{item.system_output}</s>"
88
 
89
- if item.templates != None:
90
  for num, template in enumerate(item.templates, start=1):
91
  input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
92
  for i in range(0, len(template), 2):
@@ -95,7 +105,7 @@ async def generate_text(item: Item = None):
95
  input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"
96
 
97
  input_ += f"\n<s>[INST] Beginning of active conversation [/INST]</s>"
98
- if item.history != None:
99
  for input_text, output_text in item.history:
100
  input_ += f"\n<s>[INST] {input_text} [/INST]"
101
  input_ += f"\n{output_text}"
@@ -117,8 +127,8 @@ async def generate_text(item: Item = None):
117
 
118
  tokens = 0
119
  response_data = call_gemini_api(primary_url, input_, generate_kwargs)
120
- output = response_data['responses'][0]['content']
121
- tokens = len(response_data['responses'][0]['tokens'])
122
 
123
  return generate_response_json(item, output, tokens, primary_url)
124
 
@@ -128,6 +138,5 @@ async def generate_text(item: Item = None):
128
  except Exception as e:
129
  raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
130
 
131
- if "KEY" in os.environ:
132
- if item.key != os.environ["KEY"]:
133
- raise HTTPException(status_code=401, detail="Valid key is required.")
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import uvicorn
5
+ import requests
6
+ from google.auth.transport.requests import Request
7
+ from google.oauth2 import id_token
8
 
9
  app = FastAPI()
10
 
11
  # Define the Gemini API endpoint for the primary model
12
+ primary_url = "https://us-east4-aiplatform.googleapis.com/v1/projects/gen-lang-client-0770444709/locations/us-east4/publishers/google/models/gemini-1.5-flash-latest:generateContent"
13
 
14
  # Define the data model for the request body
15
  class Item(BaseModel):
 
24
  repetition_penalty: float = 1.0
25
  key: str = None
26
 
27
# Function to obtain an OAuth 2.0 access token for the Vertex AI endpoint.
def get_access_token():
    """Return a bearer token accepted by the Vertex AI ``generateContent`` API.

    The endpoint authenticates with an OAuth 2.0 *access* token carrying the
    ``cloud-platform`` scope. The previous implementation fetched an OpenID
    Connect *ID* token via ``id_token.fetch_id_token`` with an audience hacked
    out of the URL (``primary_url.split("models")[0]``) — an ID token is an
    identity assertion, not an authorization grant, so the API rejects it
    with 401. Use Application Default Credentials and refresh them to get a
    real access token instead.

    Returns:
        str: a short-lived OAuth 2.0 access token for the ADC identity.

    Raises:
        google.auth.exceptions.DefaultCredentialsError: if no Application
            Default Credentials are configured in the environment.
    """
    # Local import: the module is only needed when the API is actually called,
    # and the file's top-level imports (transport Request) stay untouched.
    import google.auth

    credentials, _project = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    # refresh() performs the token exchange and populates credentials.token.
    credentials.refresh(Request())
    return credentials.token
33
+
34
  # Function to generate the response JSON
35
  def generate_response_json(item, output, tokens, model_name):
36
  return {
 
56
 
57
  # Function to call the Gemini API
58
  def call_gemini_api(url, input_text, generate_kwargs):
59
+ access_token = get_access_token()
60
  headers = {
61
+ "Authorization": f"Bearer {access_token}",
62
  "Content-Type": "application/json"
63
  }
64
  data = {
65
+ "contents": [{"role": "user", "parts": [{"text": input_text}]}],
66
+ "generationConfig": {
67
+ "temperature": generate_kwargs['temperature'],
68
+ "maxOutputTokens": generate_kwargs['max_new_tokens'],
69
+ "topP": generate_kwargs['top_p'],
70
+ "repetitionPenalty": generate_kwargs['repetition_penalty'],
71
+ }
72
  }
73
  try:
74
  response = requests.post(url, headers=headers, json=data)
 
89
  raise HTTPException(status_code=400, detail="Parameter `input` or `system prompt` is required.")
90
 
91
  input_ = ""
92
+ if item.system_prompt is not None and item.system_output is not None:
93
  input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
94
+ elif item.system_prompt is not None:
95
  input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
96
+ elif item.system_output is not None:
97
  input_ = f"<s>{item.system_output}</s>"
98
 
99
+ if item.templates is not None:
100
  for num, template in enumerate(item.templates, start=1):
101
  input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
102
  for i in range(0, len(template), 2):
 
105
  input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"
106
 
107
  input_ += f"\n<s>[INST] Beginning of active conversation [/INST]</s>"
108
+ if item.history is not None:
109
  for input_text, output_text in item.history:
110
  input_ += f"\n<s>[INST] {input_text} [/INST]"
111
  input_ += f"\n{output_text}"
 
127
 
128
  tokens = 0
129
  response_data = call_gemini_api(primary_url, input_, generate_kwargs)
130
+ output = response_data['candidates'][0]['content']['parts'][0]['text']
131
+ tokens = response_data['usageMetadata']['totalTokenCount']
132
 
133
  return generate_response_json(item, output, tokens, primary_url)
134
 
 
138
  except Exception as e:
139
  raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
140
 
141
+ if __name__ == "__main__":
142
+ uvicorn.run(app, host="0.0.0.0", port=8000)