object-detection / visualize.py
wanlinxing's picture
Upload 3 files
2af19bc verified
Raw
History Blame Contribute Delete
9.59 kB
"""
检测结果可视化模块
在图片上绘制检测框、类别标签和置信度
"""
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
from io import BytesIO
def draw_detections_cv2(image, detections, class_names, colors):
    """
    Draw detection boxes and "name: score" labels on a BGR image with OpenCV.

    For Chinese class names prefer the PIL renderers (OpenCV's Hershey fonts only
    cover ASCII).

    Args:
        image: numpy array (H, W, 3) in BGR channel order.
        detections: iterable of [x1, y1, x2, y2, score, class_id].
        class_names: list of class name strings, indexed by class id.
        colors: list of (R, G, B) tuples, indexed by class id. They are converted to
            BGR before drawing so a class gets the same colour as in the PIL renderers.
    Returns:
        A copy of ``image`` with boxes and labels drawn (BGR).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 2
    result = image.copy()
    for x1, y1, x2, y2, score, cls_id in detections:
        x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
        cls_id = int(cls_id)

        # Reject negative ids too: Python indexing would silently wrap around.
        if 0 <= cls_id < len(colors):
            r, g, b = (int(c) for c in colors[cls_id][:3])
            color = (b, g, r)  # palette is RGB, the image is BGR
        else:
            color = (0, 255, 0)
        class_name = class_names[cls_id] if 0 <= cls_id < len(class_names) else f"cls_{cls_id}"
        label = f"{class_name}: {score:.2f}"

        # Bounding box.
        cv2.rectangle(result, (x1, y1), (x2, y2), color, 2)

        # Label background: above the box normally; if that would leave the image
        # (box touching the top edge), place it just inside the box instead.
        (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
        top = y1 - label_h - 10
        if top < 0:
            top = y1
        bottom = top + label_h + 10
        cv2.rectangle(result, (x1, top), (x1 + label_w, bottom), color, -1)

        # Label text (baseline 5px above the bottom of its background).
        cv2.putText(result, label, (x1, bottom - 5), font, font_scale, (255, 255, 255), thickness)
    return result
def draw_detections_pil(image, detections, class_names, colors, font_path=None):
    """
    Draw detection boxes and "name: score" labels with PIL (supports CJK text given a suitable font).

    Args:
        image: numpy array (H, W, 3) RGB.
        detections: iterable of [x1, y1, x2, y2, score, class_id].
        class_names: list of class name strings, indexed by class id.
        colors: list of (R, G, B) tuples, indexed by class id.
        font_path: optional path to a TrueType font (e.g. a Chinese font). When not
            given, "arial.ttf" is tried; PIL's built-in font is the final fallback.
    Returns:
        RGB numpy array with the boxes drawn.
    """
    result = Image.fromarray(image.astype(np.uint8))
    draw = ImageDraw.Draw(result)

    # Only the 12px label font is needed. Fall back to PIL's default font when the
    # TrueType font cannot be loaded (missing/unreadable file, or no FreeType support).
    try:
        small_font = ImageFont.truetype(font_path or "arial.ttf", 12)
    except (OSError, ImportError):
        small_font = ImageFont.load_default()

    for det in detections:
        x1, y1, x2, y2, score, cls_id = det
        cls_id = int(cls_id)
        # Reject negative ids too: Python indexing would silently wrap around.
        if 0 <= cls_id < len(colors):
            color = tuple(int(c) for c in colors[cls_id])
        else:
            color = (0, 255, 0)
        class_name = class_names[cls_id] if 0 <= cls_id < len(class_names) else f"cls_{cls_id}"
        label = f"{class_name}: {score:.2f}"

        # Bounding box: two 1px outlines, the second grown outward by 1px, i.e. a 2px box.
        for offset in range(2):
            draw.rectangle(
                [x1 - offset, y1 - offset, x2 + offset, y2 + offset],
                outline=color,
                width=1,
            )

        # Label background and text.
        try:
            left, top, right, bottom = draw.textbbox((0, 0), label, font=small_font)
        except Exception:
            # textbbox may be unavailable (older Pillow / some bitmap fonts): use an estimate.
            left, top, right, bottom = 0, 0, 60, 14
        text_w = right - left
        text_h = bottom - top
        label_y = max(0, y1 - text_h - 6)
        draw.rectangle([x1, label_y, x1 + text_w + 4, label_y + text_h + 4], fill=color)
        # Shift by the bbox origin so the glyphs themselves (not the pen position)
        # start 2px inside the background and stay within it.
        draw.text((x1 + 2 - left, label_y + 2 - top), label, fill=(255, 255, 255), font=small_font)
    return np.array(result)
def _load_label_font(size=12):
    """Return the first loadable Chinese-capable system font at ``size``, else PIL's default font."""
    candidate_paths = (
        # Windows Chinese fonts
        "C:/Windows/Fonts/msyh.ttc",
        "C:/Windows/Fonts/simhei.ttf",
        "C:/Windows/Fonts/msyhbd.ttc",
        # Linux Noto CJK
        "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc",
    )
    for font_path in candidate_paths:
        try:
            return ImageFont.truetype(font_path, size)
        except (OSError, ImportError):
            # Missing/unreadable font file, or Pillow built without FreeType.
            continue
    return ImageFont.load_default()


def draw_detections_pil_enhanced(image, detections, class_names, colors):
    """
    Enhanced PIL renderer: drop shadow, translucent fill, corner accents and labels.

    All shapes are drawn on a transparent RGBA overlay that is alpha-composited onto
    the image. Label text is drawn only after compositing, so the semi-opaque label
    background cannot end up on top of it.

    Args:
        image: numpy array (H, W, 3) RGB.
        detections: iterable of [x1, y1, x2, y2, score, class_id]. Boxes are clipped
            to the image; boxes that become empty are skipped.
        class_names: list of class name strings, indexed by class id.
        colors: list of (R, G, B) tuples, indexed by class id.
    Returns:
        RGB numpy array with the styled boxes drawn.
    """
    h, w = image.shape[:2]
    base = Image.fromarray(image.astype(np.uint8))
    overlay = Image.new('RGBA', (w, h), (0, 0, 0, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    small_font = _load_label_font(12)
    pending_labels = []  # ((x, y), text) pairs, drawn after compositing

    # Ascending score order: higher-confidence boxes are drawn later, i.e. on top.
    for det in sorted(detections, key=lambda d: d[4]):
        x1, y1, x2, y2, score, cls_id = det
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(w, int(x2)), min(h, int(y2))
        cls_id = int(cls_id)
        if x2 <= x1 or y2 <= y1:
            continue
        # Reject negative ids (Python indexing would wrap around). int() also keeps
        # ``c - 40`` below from wrapping if the palette holds numpy uint8 values.
        if 0 <= cls_id < len(colors):
            color = tuple(int(c) for c in colors[cls_id])
        else:
            color = (0, 255, 0)
        class_name = class_names[cls_id] if 0 <= cls_id < len(class_names) else f"cls_{cls_id}"

        # Higher confidence -> thicker outline and longer shadow.
        if score >= 0.7:
            line_width, shadow_offset = 4, 3
        elif score >= 0.4:
            line_width, shadow_offset = 3, 2
        else:
            line_width, shadow_offset = 2, 1

        # Shadow: the outline shifted down-right by shadow_offset (drawn once;
        # redrawing the identical rectangle changes nothing).
        overlay_draw.rectangle(
            [x1 + shadow_offset, y1 + shadow_offset, x2 + shadow_offset, y2 + shadow_offset],
            outline=(0, 0, 0, 40),
            width=line_width + 2,
        )

        # Translucent fill, slightly stronger for high-confidence boxes.
        fill_alpha = 35 if score >= 0.7 else 25
        overlay_draw.rectangle([x1, y1, x2, y2], fill=color + (fill_alpha,))

        # Main box: a darker, wider outer edge with the class colour drawn over it.
        dark_color = tuple(max(0, c - 40) for c in color)
        overlay_draw.rectangle([x1, y1, x2, y2], outline=dark_color + (220,), width=line_width + 2)
        overlay_draw.rectangle([x1, y1, x2, y2], outline=color + (255,), width=line_width)

        # Corner accents: an L at each corner, at most 15px and 1/4 of the box side.
        corner_len = min(15, (x2 - x1) // 4, (y2 - y1) // 4)
        for cx, cy, sx, sy in ((x1, y1, 1, 1), (x2, y1, -1, 1), (x1, y2, 1, -1), (x2, y2, -1, -1)):
            overlay_draw.line(
                [(cx, cy), (cx + corner_len * sx, cy), (cx, cy), (cx, cy + corner_len * sy)],
                fill=color + (230,),
                width=line_width + 1,
            )

        # Label "name NN%".
        label_text = f"{class_name} {score:.0%}"
        try:
            left, top, right, bottom = overlay_draw.textbbox((0, 0), label_text, font=small_font)
            text_w, text_h = right - left, bottom - top
        except Exception:
            # textbbox may be unavailable (older Pillow / some bitmap fonts): estimate.
            text_w, text_h = len(label_text) * 7, 14
        padding = 6
        label_bg_w = text_w + padding * 2
        label_bg_h = text_h + padding

        # Put the label above the box when there is room; otherwise just inside its top edge.
        if y1 > label_bg_h + 8:
            label_y = y1 - label_bg_h - 2
            label_bg_y1, label_bg_y2 = label_y, y1
        else:
            label_y = y1 + 2
            label_bg_y1, label_bg_y2 = y1, y1 + label_bg_h
        label_x = x1 + 2

        overlay_draw.rectangle(
            [label_x, label_bg_y1, label_x + label_bg_w, label_bg_y2],
            fill=color + (200,),
        )
        pending_labels.append(((label_x + padding, label_y + padding // 2), label_text))

    composited = Image.alpha_composite(base.convert('RGBA'), overlay).convert('RGB')
    # Text is drawn after compositing: if drawn on the base image first, the
    # 200-alpha label background in the overlay would be composited over it.
    text_draw = ImageDraw.Draw(composited)
    for xy, text in pending_labels:
        text_draw.text(xy, text, fill=(255, 255, 255), font=small_font)
    return np.array(composited)
def draw_detections(image, detections, class_names, colors):
    """
    Draw detections using the enhanced PIL renderer, falling back to the basic one.

    Args:
        image: numpy array (H, W, 3) RGB.
        detections: iterable of [x1, y1, x2, y2, score, class_id].
        class_names: list of class name strings, indexed by class id.
        colors: list of (R, G, B) tuples, indexed by class id.
    Returns:
        RGB numpy array with the detections drawn.
    """
    import logging

    try:
        return draw_detections_pil_enhanced(image, detections, class_names, colors)
    except Exception:
        # Deliberate best-effort: any failure in the styled renderer degrades to the
        # plain one, but record why so the failure is not invisible.
        logging.getLogger(__name__).warning(
            "Enhanced detection rendering failed; falling back to basic PIL renderer",
            exc_info=True,
        )
        return draw_detections_pil(image, detections, class_names, colors)
def create_summary_text(detections, class_names):
    """
    Build a (Chinese) text summary of how many detections each class has.

    Args:
        detections: list or numpy array of [x1, y1, x2, y2, score, class_id] rows;
            None or empty means nothing was detected.
        class_names: list of class name strings, indexed by class id.
    Returns:
        Summary string: a total line followed by one line per class, in order of
        first appearance. Ids outside ``class_names`` (including negative ids) are
        reported as "未知_<id>".
    """
    from collections import Counter

    # len() instead of truthiness: `not array` raises ValueError for numpy arrays
    # with more than one element.
    if detections is None or len(detections) == 0:
        return "未检测到任何目标。"

    def _class_name(cls_id):
        # Negative ids are rejected explicitly; Python indexing would wrap around.
        return class_names[cls_id] if 0 <= cls_id < len(class_names) else f"未知_{cls_id}"

    # Counter keeps first-insertion order, so classes are listed as first seen.
    class_counts = Counter(_class_name(int(det[5])) for det in detections)
    lines = [f"共检测到 {len(detections)} 个目标:"]
    lines.extend(f" • {name}: {count} 个" for name, count in class_counts.items())
    return "\n".join(lines)