PinkAlpaca commited on
Commit
72bedc9
·
verified ·
1 Parent(s): 6a35b35

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +6 -4
main.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  import uvicorn
@@ -7,15 +8,16 @@ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
7
  app = FastAPI()
8
 
9
  # Define the cache directory path within your project
10
- cache_dir = './cache'
11
 
12
  # Create the cache directory if it doesn't exist
13
- if not os.path.exists(cache_dir):
14
- os.makedirs(cache_dir, exist_ok=True)
15
 
16
  # Set the TRANSFORMERS_CACHE environment variable to the cache directory
17
  os.environ['TRANSFORMERS_CACHE'] = cache_dir
18
 
 
 
19
  # Ensure your Hugging Face token is set as an environment variable
20
  huggingface_token = os.environ.get("TOKEN")
21
  if not huggingface_token:
@@ -61,4 +63,4 @@ async def generate_text(item: Item):
61
  raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
62
 
63
  if __name__ == "__main__":
64
- uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
  import os
2
+ from pathlib import Path
3
  from fastapi import FastAPI, HTTPException
4
  from pydantic import BaseModel
5
  import uvicorn
 
8
  app = FastAPI()
9
 
10
  # Define the cache directory path within your project
11
+ cache_dir = str(Path(__file__).parent.resolve() / 'cache')
12
 
13
  # Create the cache directory if it doesn't exist
14
+ os.makedirs(cache_dir, exist_ok=True)
 
15
 
16
  # Set the TRANSFORMERS_CACHE environment variable to the cache directory
17
  os.environ['TRANSFORMERS_CACHE'] = cache_dir
18
 
19
+ print(f"Transformers cache directory: {os.environ['TRANSFORMERS_CACHE']}")
20
+
21
  # Ensure your Hugging Face token is set as an environment variable
22
  huggingface_token = os.environ.get("TOKEN")
23
  if not huggingface_token:
 
63
  raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
64
 
65
  if __name__ == "__main__":
66
+ uvicorn.run(app, host="0.0.0.0", port=8000)