ek-5 commited on
Commit
d8360cc
·
verified ·
1 Parent(s): 18d635b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -38
app.py CHANGED
@@ -8,33 +8,28 @@ from PIL import Image
8
  import uvicorn
9
 
10
  # --- 1. إعداد التطبيق والموديلات ---
11
- app = FastAPI(title="YOLO + GIT Large Captioning API")
12
 
13
- # تحديد الجهاز
14
  device = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- # مسار الموديل الخاص بك
17
  MY_MODEL_PATH = 'best.pt'
18
 
19
- print(f"🔄 جاري تحميل الموديلات على جهاز: {device}... يرجى الانتظار")
20
 
21
- # تحميل موديل YOLO
22
  try:
23
  detection_model = YOLO(MY_MODEL_PATH)
24
- print("✅ تم تحميل موديل YOLO الخاص بك بنجاح")
25
  except Exception as e:
26
- print(f"⚠️ فشل تحميل {MY_MODEL_PATH}، سيتم استخدام الموديل الافتراضي: {e}")
27
  detection_model = YOLO("yolov8n.pt")
28
 
29
- # --- التغيير هنا: استخدام microsoft/git-large ---
30
  model_name = "microsoft/git-large"
31
  processor = AutoProcessor.from_pretrained(model_name)
32
  caption_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
33
- print(f"✅ تم تحميل موديل {model_name} بنجاح")
34
 
35
  @app.get("/")
36
  def home():
37
- return {"status": "Online", "model": "GIT-Large", "instruction": "Add /docs to the URL to test"}
38
 
39
  # --- 2. وظيفة المعالجة ---
40
 
@@ -43,7 +38,6 @@ async def analyze_image(file: UploadFile = File(...)):
43
  data = await file.read()
44
  original_image = Image.open(io.BytesIO(data)).convert("RGB")
45
 
46
- # 1. الكشف باستخدام YOLO
47
  results = detection_model(original_image, conf=0.25)
48
  integrated_results = []
49
 
@@ -53,43 +47,48 @@ async def analyze_image(file: UploadFile = File(...)):
53
  conf_score = float(box.conf[0])
54
  coords = box.xyxy[0].tolist()
55
 
56
- # 2. عملية القص (Cropping)
57
- # إضافة هامش بسيط (Padding) للقص يحسن أحياناً من وصف الموديل
58
- cropped_img = original_image.crop((coords[0], coords[1], coords[2], coords[3]))
59
-
60
- # 3. وصف الجزء المقصوص عبر موديل GIT Large
61
- inputs = processor(images=cropped_img, return_tensors="pt").to(device)
 
 
 
 
 
 
 
62
 
63
- # ضبط البارامترات للحصول على أفضل وصف من نسخة Large
64
  generated_ids = caption_model.generate(
65
- pixel_values=inputs.pixel_values,
66
- max_length=50,
67
- num_beams=4 # استخدام beam search يحسن الجودة في نسخة Large
 
 
68
  )
69
- detailed_desc = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
 
 
70
 
71
  integrated_results.append({
72
  "object_id": i + 1,
73
  "label": label,
74
  "confidence": f"{conf_score:.2f}",
75
- "description": detailed_desc
76
  })
77
 
78
- # إذا لم يتم اكتشاف أجسام، وصف الصورة كاملة
79
  if not integrated_results:
80
  inputs = processor(images=original_image, return_tensors="pt").to(device)
81
  generated_ids = caption_model.generate(pixel_values=inputs.pixel_values, max_length=50)
82
- general_desc = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
83
- return {
84
- "message": "No specific objects detected. General description provided.",
85
- "general_description": general_desc
86
- }
87
-
88
- return {
89
- "detected_count": len(integrated_results),
90
- "results": integrated_results
91
- }
92
-
93
- # --- 3. تشغيل السيرفر (تصحيح الشرطات السفلية) ---
94
- if name == "__main__":
95
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
8
  import uvicorn
9
 
10
  # --- 1. إعداد التطبيق والموديلات ---
11
+ app = FastAPI(title="YOLO + GIT Large (Color & Shape) API")
12
 
 
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
14
  MY_MODEL_PATH = 'best.pt'
15
 
16
+ print(f"🔄 جاري تحميل الموديلات على: {device}...")
17
 
 
18
  try:
19
  detection_model = YOLO(MY_MODEL_PATH)
20
+ print("✅ تم تحميل YOLO بنجاح")
21
  except Exception as e:
22
+ print(f"⚠️ فشل تحميل الموديل الخاص، استخدام الافتراضي: {e}")
23
  detection_model = YOLO("yolov8n.pt")
24
 
25
+ # تحميل موديل GIT-large
26
  model_name = "microsoft/git-large"
27
  processor = AutoProcessor.from_pretrained(model_name)
28
  caption_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
 
29
 
30
  @app.get("/")
31
  def home():
32
+ return {"status": "Online", "mode": "Detailed Description (Color/Shape)"}
33
 
34
  # --- 2. وظيفة المعالجة ---
35
 
 
38
  data = await file.read()
39
  original_image = Image.open(io.BytesIO(data)).convert("RGB")
40
 
 
41
  results = detection_model(original_image, conf=0.25)
42
  integrated_results = []
43
 
 
47
  conf_score = float(box.conf[0])
48
  coords = box.xyxy[0].tolist()
49
 
50
+ # قص العنصر مع إضافة هامش بسيط (10 بكسل) لرؤية الألوان المحيطة والحواف بشكل أفضل
51
+ pad = 10
52
+ cropped_img = original_image.crop((
53
+ max(0, coords[0]-pad),
54
+ max(0, coords[1]-pad),
55
+ min(original_image.width, coords[2]+pad),
56
+ min(original_image.height, coords[3]+pad)
57
+ ))
58
+
59
+ # --- التعديل الجوهري هنا: توجيه الموديل لوصف اللون والشكل ---
60
+ # نضع نصاً توجيهياً (Prompt) ليقوم الموديل بتكملته
61
+ prompt = f"a photo of a {label}, describing its color and shape:"
62
+ inputs = processor(images=cropped_img, text=prompt, return_tensors="pt").to(device)
63
 
 
64
  generated_ids = caption_model.generate(
65
+ pixel_values=inputs.pixel_values,
66
+ input_ids=inputs.input_ids, # تمرير البرومبت للموديل
67
+ max_length=60,
68
+ num_beams=5,
69
+ repetition_penalty=1.2 # لمنع تكرار الكلمات
70
  )
71
+
72
+ # فك التشفير
73
+ full_desc = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
74
+
75
+ # تنظيف الوصف (إزالة البرومبت من البداية إذا ظهر)
76
+ clean_desc = full_desc.replace(prompt, "").strip()
77
 
78
  integrated_results.append({
79
  "object_id": i + 1,
80
  "label": label,
81
  "confidence": f"{conf_score:.2f}",
82
+ "visual_description": clean_desc
83
  })
84
 
 
85
  if not integrated_results:
86
  inputs = processor(images=original_image, return_tensors="pt").to(device)
87
  generated_ids = caption_model.generate(pixel_values=inputs.pixel_values, max_length=50)
88
+ return {"message": "No objects detected", "general_description": processor.batch_decode(generated_ids, skip_special_tokens=True)[0]}
89
+
90
+ return {"detected_count": len(integrated_results), "results": integrated_results}
91
+
92
+ # --- 3. تشغيل السيرفر (تم تصحيح الـ Syntax) ---
93
+ if __name__ == "__main__":
 
 
 
 
 
 
 
94
  uvicorn.run(app, host="0.0.0.0", port=7860)