PinkAlpaca committed on
Commit
f2e5081
·
verified ·
1 Parent(s): a87331a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +21 -122
main.py CHANGED
@@ -2,141 +2,40 @@ import os
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import uvicorn
5
- import requests
6
- from google.auth.transport.requests import Request
7
- from google.oauth2 import id_token
8
 
9
  app = FastAPI()
10
 
11
- # Define the Gemini API endpoint for the primary model
12
- primary_url = "https://us-east4-aiplatform.googleapis.com/v1/projects/gen-lang-client-0770444709/locations/us-east4/publishers/google/models/gemini-1.5-flash-latest:generateContent"
 
 
 
13
 
14
# Request-body schema for the generation endpoint.
class Item(BaseModel):
    """Payload accepted by the "/" endpoint.

    NOTE(review): fields annotated ``str = None`` / ``list = None`` rely on
    pydantic v1 implicitly treating them as Optional; pydantic v2 rejects
    this form — confirm the pinned pydantic version.
    """

    input: str = None              # current user message
    system_prompt: str = None      # instruction rendered as an [INST] block
    system_output: str = None      # canned assistant reply paired with the prompt
    history: list = None           # (user, assistant) turns of the active chat
    templates: list = None         # archived conversations, each a flat alternating list
    temperature: float = 0.0       # sampling temperature (floored to 1e-2 downstream)
    max_new_tokens: int = 1048     # generation budget forwarded as maxOutputTokens
    top_p: float = 0.15            # nucleus-sampling cutoff
    repetition_penalty: float = 1.0
    key: str = None                # unused in this file — presumably an API key; verify callers
27
# Function to obtain an OAuth 2.0 access token for the Vertex AI REST API.
def get_access_token():
    """Return an OAuth2 access token usable as a Bearer credential.

    BUG FIX: the original fetched an OpenID Connect *ID token* with
    ``id_token.fetch_id_token`` (audience derived from the model URL).
    Vertex AI ``generateContent`` endpoints authenticate with an OAuth2
    *access token* scoped to cloud-platform, so an ID token yields 401s.
    Use Application Default Credentials instead.
    """
    import google.auth  # local import; file already depends on google-auth

    credentials, _project = google.auth.default(
        scopes=["https://www.googleapis.com/auth/cloud-platform"]
    )
    # Refresh to materialize a bearer token on the credentials object.
    credentials.refresh(Request())
    return credentials.token
34
# Function to build the JSON payload returned to the client.
def generate_response_json(item, output, tokens, model_name):
    """Assemble the response body echoing the request settings.

    Args:
        item: the request ``Item`` whose settings are echoed back as strings.
        output: raw model text.
        tokens: total token count reported by the API.
        model_name: identifier recorded under ``response.name``.

    Returns:
        dict with a ``settings`` echo section and a ``response`` section.
    """
    return {
        "settings": {
            # Each setting is stringified; "" stands in for unset values.
            "input": item.input if item.input is not None else "",
            "system prompt": item.system_prompt if item.system_prompt is not None else "",
            "system output": item.system_output if item.system_output is not None else "",
            "temperature": f"{item.temperature}" if item.temperature is not None else "",
            "max new tokens": f"{item.max_new_tokens}" if item.max_new_tokens is not None else "",
            "top p": f"{item.top_p}" if item.top_p is not None else "",
            "repetition penalty": f"{item.repetition_penalty}" if item.repetition_penalty is not None else "",
            "do sample": "True",
            "seed": "42"
        },
        "response": {
            # BUG FIX: the original chained six strip calls
            # (.strip().lstrip('\n').rstrip('\n').lstrip('').rstrip('').strip());
            # lstrip('')/rstrip('') are no-ops and the rest collapse to strip().
            "output": output.strip(),
            "unstripped": output,
            "tokens": tokens,
            "model": "primary",
            "name": model_name
        }
    }
57
# Function to call the Vertex AI Gemini generateContent endpoint.
def call_gemini_api(url, input_text, generate_kwargs):
    """POST a single-turn prompt to ``url`` and return the parsed JSON.

    Args:
        url: full generateContent endpoint URL.
        input_text: the assembled prompt text (one user turn).
        generate_kwargs: dict with temperature / max_new_tokens / top_p /
            repetition_penalty keys.

    Raises:
        HTTPException: mirroring the upstream HTTP status on failure,
            or 500 for any other error (network, JSON decode, ...).
    """
    access_token = get_access_token()
    headers = {
        "Authorization": f"Bearer {access_token}",
        "Content-Type": "application/json"
    }
    data = {
        "contents": [{"role": "user", "parts": [{"text": input_text}]}],
        "generationConfig": {
            "temperature": generate_kwargs['temperature'],
            "maxOutputTokens": generate_kwargs['max_new_tokens'],
            "topP": generate_kwargs['top_p'],
            # NOTE(review): "repetitionPenalty" is not a documented
            # generationConfig field for the Gemini API (the documented knobs
            # are frequencyPenalty/presencePenalty) — confirm it is accepted.
            "repetitionPenalty": generate_kwargs['repetition_penalty'],
        }
    }
    try:
        # BUG FIX: requests has no default timeout; without one a stalled
        # upstream call would hang this worker forever.
        response = requests.post(url, headers=headers, json=data, timeout=60)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.HTTPError as http_err:
        # raise_for_status() is the only raiser here, so response is bound.
        raise HTTPException(status_code=response.status_code,
                            detail=f"HTTP error occurred: {http_err}")
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"An error occurred: {err}")
82
# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item = None):
    """Build an [INST]-style transcript from the request and relay it upstream.

    The prompt is assembled from (in order): optional system prompt/output,
    optional archived conversations (``templates``), the active ``history``,
    and finally ``item.input``.

    Raises:
        HTTPException 400: when the body is missing or both ``input`` and
            ``system_prompt`` are missing/empty.
        HTTPException 500: for any other failure.
    """
    try:
        if item is None:
            raise HTTPException(status_code=400, detail="JSON body is required.")
        # BUG FIX: the original test was
        #   a is None and b is None or a == "" and b == ""
        # which, by `and`/`or` precedence, accepted mixed cases such as
        # input=None with system_prompt="".  Reject whenever BOTH are
        # missing or empty.
        if not item.input and not item.system_prompt:
            raise HTTPException(status_code=400, detail="Parameter `input` or `system prompt` is required.")

        input_ = ""
        if item.system_prompt is not None and item.system_output is not None:
            input_ = f"<s>[INST] {item.system_prompt} [/INST] {item.system_output}</s>"
        elif item.system_prompt is not None:
            input_ = f"<s>[INST] {item.system_prompt} [/INST]</s>"
        elif item.system_output is not None:
            input_ = f"<s>{item.system_output}</s>"

        # Archived conversations: each template is a flat alternating
        # [user, assistant, user, assistant, ...] list.
        if item.templates is not None:
            for num, template in enumerate(item.templates, start=1):
                input_ += f"\n<s>[INST] Beginning of archived conversation {num} [/INST]</s>"
                for i in range(0, len(template), 2):
                    input_ += f"\n<s>[INST] {template[i]} [/INST]"
                    input_ += f"\n{template[i + 1]}</s>"
                input_ += f"\n<s>[INST] End of archived conversation {num} [/INST]</s>"

        input_ += f"\n<s>[INST] Beginning of active conversation [/INST]</s>"
        if item.history is not None:
            for input_text, output_text in item.history:
                input_ += f"\n<s>[INST] {input_text} [/INST]"
                input_ += f"\n{output_text}"
        input_ += f"\n<s>[INST] {item.input} [/INST]"

        temperature = float(item.temperature)
        if temperature < 1e-2:
            temperature = 1e-2  # clamp: upstream rejects a zero temperature
        top_p = float(item.top_p)

        generate_kwargs = dict(
            temperature=temperature,
            max_new_tokens=item.max_new_tokens,
            top_p=top_p,
            repetition_penalty=item.repetition_penalty,
            do_sample=True,
            seed=42,
        )

        # (removed a dead `tokens = 0` — overwritten unconditionally below)
        response_data = call_gemini_api(primary_url, input_, generate_kwargs)
        output = response_data['candidates'][0]['content']['parts'][0]['text']
        tokens = response_data['usageMetadata']['totalTokenCount']

        return generate_response_json(item, output, tokens, primary_url)

    except HTTPException as http_error:
        # Pass our own 4xx/5xx through untouched.
        raise http_error

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
140
 
141
  if __name__ == "__main__":
142
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import uvicorn
5
from transformers import pipeline, TextGenerationPipeline

app = FastAPI()

# Load the generation model once at import time via Hugging Face's pipeline.
# NOTE(review): "google/gemini-1.5-flash" does not appear to be a model hosted
# on the Hugging Face Hub (Gemini weights are not public), so this pipeline()
# call would fail at startup — confirm the intended repo id (e.g. a Gemma
# checkpoint such as "google/gemma-2-2b-it").
generator: TextGenerationPipeline = pipeline(
    "text-generation",
    model="google/gemini-1.5-flash"  # Or the specific Gemini model you have access to
)
14
 
15
# Schema of the JSON body accepted by the "/" endpoint.
class Item(BaseModel):
    """A single text-generation request."""

    prompt: str                 # required input text
    temperature: float = 0.7    # sampling temperature
    max_new_tokens: int = 128   # generation budget
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
# Endpoint for generating text
@app.post("/")
async def generate_text(item: Item):
    """Generate a completion for ``item.prompt`` using the module pipeline.

    Raises:
        HTTPException 400: when ``prompt`` is empty.
        HTTPException 500: when generation fails.
    """
    try:
        if not item.prompt:
            raise HTTPException(status_code=400, detail="`prompt` field is required")

        # Call the model through Hugging Face's pipeline.
        # BUG FIX: the field is named max_new_tokens but was forwarded as
        # ``max_length``, which counts the prompt tokens too and silently
        # truncates generation; pass it under its real meaning.
        output = generator(
            item.prompt,
            temperature=item.temperature,
            max_new_tokens=item.max_new_tokens,
        )

        return {"generated_text": output[0]['generated_text']}

    # BUG FIX: without this clause the HTTPException(400) raised above was
    # caught by the broad handler below and re-raised as a 500.
    except HTTPException:
        raise

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
39
 
40
  if __name__ == "__main__":
41
+ uvicorn.run(app, host="0.0.0.0", port=8000)