SaniaE committed on
Commit
ed741f2
·
verified ·
1 Parent(s): 8f6f2d9

added app.py

Browse files
Files changed (1) hide show
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import random
4
+ from PIL import Image
5
+ from fastapi import FastAPI, UploadFile, File, Query
6
+ from fastapi.middleware.cors import CORSMiddleware
7
+ from huggingface_hub import snapshot_download, login
8
+ from transformers import (
9
+ BlipProcessor, BlipForConditionalGeneration,
10
+ ViTImageProcessor, AutoProcessor, AutoModelForCausalLM
11
+ )
12
+ from sentence_transformers import SentenceTransformer, util
13
+
14
app = FastAPI()

# Configuration
REPO_ID = "SaniaE/Image_Captioning_Ensemble"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Registry filled at startup: model name -> {"model": ..., "processor": ...}.
MODELS = {}
# CLIP sentence-transformer used by the /ui-tester and /ui-search routes;
# stays None until the startup hook has run.
SEARCH_MODEL = None

# Gallery state read by the /ui-search route. Nothing in this module builds
# a gallery yet, so these start empty; without these definitions the route
# would fail with a NameError instead of reporting "Gallery not initialized".
IMAGE_GALLERY_EMBEDDINGS = None
IMAGE_PATHS = []

# Per-model loading configuration, keyed by the model's short name.
#   subfolder        directory inside the downloaded repo holding the checkpoint
#   processor        processor class (or, for "vit", a pair of classes)
#   pretrained_path  hub id(s) the processor(s) are loaded from
#   inference_model  class whose from_pretrained() loads the checkpoint
MODEL_SETTINGS = {
    "blip": {
        "subfolder": "blip",
        "processor": BlipProcessor,
        "pretrained_path": "Salesforce/blip-image-captioning-large",
        "inference_model": BlipForConditionalGeneration,
    },
    # "vit" is the only entry with two processors: the first prepares the
    # image, the second decodes the generated token ids back to text.
    "vit": {
        "subfolder": "vit",
        "processor": [ViTImageProcessor, AutoProcessor],
        "pretrained_path": ["nlpconnect/vit-gpt2-image-captioning", "microsoft/git-large"],
        "inference_model": AutoModelForCausalLM,
    },
    "git": {
        "subfolder": "git",
        "processor": AutoProcessor,
        "pretrained_path": "microsoft/git-base",
        "inference_model": AutoModelForCausalLM,
    },
}
43
+
44
@app.on_event("startup")
async def startup_event():
    """Fetch the ensemble checkpoints and load every model at app startup.

    Fills the module-level ``MODELS`` registry with one
    ``{"model": ..., "processor": ...}`` entry per key in ``MODEL_SETTINGS``
    and assigns the CLIP sentence-transformer to ``SEARCH_MODEL``.
    """
    global MODELS, SEARCH_MODEL

    # Step 1: log in only when HF_TOKEN is set, then pull the repo snapshot.
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)

    print(f"Downloading ensemble models from {REPO_ID}...")
    # The entire repo is materialised under the local "weights" directory.
    weights_root = snapshot_download(repo_id=REPO_ID, token=hf_token, local_dir="weights")

    # Step 2: build each model from its sub-directory of the snapshot.
    for model_name, settings in MODEL_SETTINGS.items():
        checkpoint_dir = os.path.join(weights_root, settings["subfolder"])
        model_cls = settings["inference_model"]
        hub_source = settings["pretrained_path"]
        processor_cls = settings["processor"]

        print(f"Loading {model_name} from {checkpoint_dir}...")
        loaded_model = model_cls.from_pretrained(checkpoint_dir)
        loaded_model = loaded_model.to(DEVICE)

        if model_name != "vit":
            loaded_processor = processor_cls.from_pretrained(hub_source)
        else:
            # "vit" is configured with a pair: (image processor, text processor).
            image_side = processor_cls[0].from_pretrained(hub_source[0])
            text_side = processor_cls[1].from_pretrained(hub_source[1])
            loaded_processor = (image_side, text_side)

        MODELS[model_name] = {"model": loaded_model, "processor": loaded_processor}

    SEARCH_MODEL = SentenceTransformer('clip-ViT-B-32')
    print("Ensemble is live!")
79
+
80
@app.post("/generate")
async def generate_endpoint(
    file: UploadFile = File(...),
    temp: float = Query(0.8),
    top_k: int = Query(100),
    top_p: float = Query(0.9)
):
    """Generate five sampled captions for an uploaded image.

    Each of the five slots is filled by a model drawn at random (with
    replacement) from the loaded ensemble, so one model may appear several
    times.

    Args:
        file: Uploaded image; anything PIL cannot decode is rejected.
        temp: Sampling temperature, must be > 0.
        top_k: Top-k cutoff, must be >= 0.
        top_p: Nucleus-sampling mass, must lie in [0, 1].

    Returns:
        ``{"captions": [...], "mix": [...]}`` where ``mix`` lists the model
        name used for the caption at the same index.

    Raises:
        HTTPException: 503 when no models are loaded, 422 for out-of-range
            sampling parameters, 400 when the upload is not a readable image.
    """
    # Local import keeps this handler self-contained with respect to the
    # module's import list.
    from fastapi import HTTPException

    if not MODELS:
        # random.choices() raises IndexError on an empty population.
        raise HTTPException(status_code=503, detail="Models are not loaded yet")

    if temp <= 0 or top_k < 0 or not 0 <= top_p <= 1:
        # Reject values that generate() would refuse, as a client error
        # rather than an unhandled server error.
        raise HTTPException(
            status_code=422,
            detail="temp must be > 0, top_k must be >= 0 and top_p must be in [0, 1]",
        )

    try:
        image = Image.open(file.file).convert("RGB")
    except OSError as exc:
        # PIL's UnidentifiedImageError and truncated-file errors are OSErrors.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc

    # Randomly select which models to use for the 5 slots
    model_selection = random.choices(list(MODELS), k=5)

    # NOTE(review): model.generate() is blocking work inside an async handler
    # and will stall the event loop while it runs — consider offloading.
    pixel_cache = {}  # preprocess the image once per distinct model
    captions = []
    for m_name in model_selection:
        m_data = MODELS[m_name]
        processor = m_data["processor"]
        if m_name == "vit":
            # "vit" stores a pair: (image processor, text processor).
            image_proc, text_proc = processor
        else:
            image_proc = text_proc = processor

        if m_name not in pixel_cache:
            encoded = image_proc(images=image, return_tensors="pt")
            pixel_cache[m_name] = encoded.pixel_values.to(DEVICE)

        gen_ids = m_data["model"].generate(
            pixel_values=pixel_cache[m_name], max_length=200, do_sample=True,
            temperature=temp, top_k=top_k, top_p=top_p
        )
        caption = text_proc.batch_decode(gen_ids, skip_special_tokens=True)[0]
        captions.append(caption.strip())

    return {"captions": captions, "mix": model_selection}
118
+
119
@app.post("/ui-tester")
async def ui_tester(file: UploadFile = File(...), description: str = Query(...)):
    """Match a user description against an image using CLIP embeddings.

    Args:
        file: Uploaded image; anything PIL cannot decode is rejected.
        description: Text to compare with the image.

    Returns:
        A dict with ``match_score`` (cosine similarity rounded to 4 places),
        ``is_match`` (score above 0.25) and ``status`` ("High correlation"
        when the score is above 0.3, otherwise "Low correlation").

    Raises:
        HTTPException: 503 when the search model is not loaded yet, 400 when
            the upload is not a readable image.
    """
    # Local import keeps this handler self-contained with respect to the
    # module's import list.
    from fastapi import HTTPException

    if SEARCH_MODEL is None:
        # SEARCH_MODEL is only assigned by the startup hook.
        raise HTTPException(status_code=503, detail="Search model is not loaded yet")

    try:
        image = Image.open(file.file).convert("RGB")
    except OSError as exc:
        # PIL's UnidentifiedImageError and truncated-file errors are OSErrors.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from exc

    img_emb = SEARCH_MODEL.encode(image)
    txt_emb = SEARCH_MODEL.encode(description)

    # Calculate cosine similarity
    score = util.cos_sim(img_emb, txt_emb).item()

    return {
        "match_score": round(score, 4),
        "is_match": score > 0.25,  # Threshold can be adjusted
        # The status label uses a stricter cutoff than is_match.
        "status": "High correlation" if score > 0.3 else "Low correlation"
    }
135
+
136
@app.get("/ui-search")
async def ui_search(description: str = Query(...)):
    """Return the top image matches from a gallery for a text description.

    Args:
        description: Free-text query to embed and search with.

    Returns:
        ``{"results": [{"image_path": ..., "score": ...}, ...]}`` with at most
        three hits, or ``{"error": "Gallery not initialized"}`` when no
        gallery embeddings/paths are available.
    """
    # The gallery globals may not be defined in this module at all, so look
    # them up defensively rather than risk a NameError at request time.
    gallery_embeddings = globals().get("IMAGE_GALLERY_EMBEDDINGS")
    gallery_paths = globals().get("IMAGE_PATHS")

    # len() is used instead of truthiness: the truth value of a multi-element
    # tensor/array is ambiguous and raises.
    if gallery_embeddings is None or gallery_paths is None or len(gallery_embeddings) == 0:
        return {"error": "Gallery not initialized"}

    query_emb = SEARCH_MODEL.encode(description)
    hits = util.semantic_search(query_emb, gallery_embeddings, top_k=3)

    # semantic_search returns one hit list per query; there is a single query.
    results = [
        {
            "image_path": gallery_paths[hit['corpus_id']],
            "score": round(hit['score'], 4),
        }
        for hit in hits[0]
    ]

    return {"results": results}