File size: 7,187 Bytes
9545fea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""
数据集格式验证脚本
用于验证 train_loader 加载的 input 和 target 格式
特别是验证 target[0] 是否为 [image_idx, class_id, x_center, y_center, width, height]
"""
import os
import sys
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(BASE_DIR)

import torch
import torchvision.transforms as transforms
from lib.config import cfg
import lib.dataset as dataset
from lib.utils import DataLoaderX

def check_dataset_format():
    """验证数据集加载格式"""
    
    print("="*80)
    print("开始验证数据集加载格式...")
    print("="*80)
    
    # 数据预处理
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
    
    # 创建训练数据集
    print("\n1. 创建数据集...")
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg=cfg,
        is_train=True,
        inputsize=cfg.MODEL.IMAGE_SIZE,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    print(f"   数据集类型: {cfg.DATASET.DATASET}")
    print(f"   数据集大小: {len(train_dataset)}")
    
    # 打印类别数量
    if hasattr(train_dataset, 'names'):
        print(f"   数据集类别: {train_dataset.names}")
        print(f"   类别数量: {len(train_dataset.names)}")
    else:
        print("   数据集没有 names 属性")
    # 打印类别数量
    if hasattr(train_dataset, "names"):
        print(f"   数据集类别数量: {len(train_dataset.names)}")
    else:
        print("   数据集不包含 names 属性,无法统计类别数量。")
    
    # 创建 DataLoader
    print("\n2. 创建 DataLoader...")
    train_loader = DataLoaderX(
        train_dataset,
        batch_size=4,  # 使用小 batch_size 方便查看
        shuffle=False,
        num_workers=0,  # Windows 上使用 0
        pin_memory=False,
        collate_fn=dataset.AutoDriveDataset.collate_fn
    )
    print(f"   Batch size: ")
    print(f"   Total batches: {len(train_loader)}")
    
    # 获取第一个 batch
    print("\n3. 加载第一个 batch...")
    for i, (input, target, paths, shapes) in enumerate(train_loader):
        print("\n" + "="*80)
        print(f"Batch {i} 数据格式分析:")
        print("="*80)
        
        # 分析 input
        print("\n[INPUT - 图像数据]")
        print(f"  类型: {type(input)}")
        print(f"  形状: {input.shape}")
        print(f"  dtype: {input.dtype}")
        print(f"  值范围: [{input.min():.3f}, {input.max():.3f}]")
        
        # 分析 target
        print("\n[TARGET - 标注数据]")
        print(f"  类型: {type(target)}")
        print(f"  长度: {len(target)} (包含 3 个元素: det, da_seg, ll_seg)")
        
        # target[0] - 检测标签 (最重要)
        print(f"\n  target[0] - 检测标签 (Detection Labels):")
        print(f"    类型: {type(target[0])}")
        print(f"    形状: {target[0].shape}")
        print(f"    dtype: {target[0].dtype}")
        print(f"    说明: [N, 6] 其中 N 是所有图片的目标总数,6 维度为:")
        print(f"          [image_idx, class_id, x_center, y_center, width, height]")
        
        # 打印前几个样本
        if target[0].shape[0] > 0:
            print(f"\n    前 5 个目标样本:")
            print(f"    {'索引':<6} {'img_idx':<10} {'class_id':<10} {'x_center':<12} {'y_center':<12} {'width':<12} {'height':<12}")
            print(f"    {'-'*76}")
            for idx in range(min(5, target[0].shape[0])):
                obj = target[0][idx]
                print(f"    {idx:<6} {obj[0].item():<10.0f} {obj[1].item():<10.0f} {obj[2].item():<12.6f} {obj[3].item():<12.6f} {obj[4].item():<12.6f} {obj[5].item():<12.6f}")
            
            # 验证归一化
            print(f"\n    验证坐标是否归一化到 [0, 1]:")
            xywh_data = target[0][:, 2:]  # 提取 xywh 坐标
            print(f"      x_center 范围: [{xywh_data[:, 0].min():.6f}, {xywh_data[:, 0].max():.6f}]")
            print(f"      y_center 范围: [{xywh_data[:, 1].min():.6f}, {xywh_data[:, 1].max():.6f}]")
            print(f"      width    范围: [{xywh_data[:, 2].min():.6f}, {xywh_data[:, 2].max():.6f}]")
            print(f"      height   范围: [{xywh_data[:, 3].min():.6f}, {xywh_data[:, 3].max():.6f}]")
            
            # 检查是否归一化
            is_normalized = (xywh_data >= 0).all() and (xywh_data <= 1).all()
            if is_normalized:
                print(f"      ✓ 坐标已归一化到 [0, 1]")
            else:
                print(f"      ✗ 警告: 坐标未完全归一化!")
            
            # 统计每张图片的目标数量
            print(f"\n    每张图片的目标数量:")
            for img_idx in range(input.shape[0]):
                count = (target[0][:, 0] == img_idx).sum().item()
                print(f"      图片 {img_idx}: {count} 个目标")
        else:
            print(f"    (该 batch 没有检测目标)")
        
        # target[1] - 驾驶区域分割标签
        print(f"\n  target[1] - 驾驶区域分割标签 (Drivable Area Segmentation):")
        print(f"    类型: {type(target[1])}")
        print(f"    形状: {target[1].shape}")
        print(f"    dtype: {target[1].dtype}")
        print(f"    值范围: [{target[1].min():.3f}, {target[1].max():.3f}]")
        print(f"    说明: [batch_size, num_classes, H, W]")
        
        # target[2] - 车道线分割标签
        print(f"\n  target[2] - 车道线分割标签 (Lane Line Segmentation):")
        print(f"    类型: {type(target[2])}")
        print(f"    形状: {target[2].shape}")
        print(f"    dtype: {target[2].dtype}")
        print(f"    值范围: [{target[2].min():.3f}, {target[2].max():.3f}]")
        print(f"    说明: [batch_size, num_classes, H, W]")
        
        # 分析 paths
        print(f"\n[PATHS - 图像路径]")
        print(f"  类型: {type(paths)}")
        print(f"  长度: {len(paths)}")
        if len(paths) > 0:
            print(f"  示例路径:")
            for idx, path in enumerate(paths):
                print(f"    [{idx}] {path}")
        
        # 分析 shapes
        print(f"\n[SHAPES - 图像尺寸信息]")
        print(f"  类型: {type(shapes)}")
        print(f"  长度: {len(shapes)}")
        if len(shapes) > 0:
            print(f"  示例 (原始尺寸, ((缩放比例), (padding))):")
            for idx, shape in enumerate(shapes[:2]):  # 只显示前2个
                print(f"    [{idx}] {shape}")
        
        print("\n" + "="*80)
        print("验证结论:")
        print("="*80)
        print("✓ target[0] 格式为: [image_idx, class_id, x_center, y_center, width, height]")
        print("✓ xywh 坐标已归一化到 [0, 1]")
        print("✓ image_idx 用于区分 batch 中不同图片的目标")
        print("✓ class_id 表示目标类别")
        print("="*80)
        
        # 只查看第一个 batch
        break
    
    print("\n验证完成!")


if __name__ == '__main__':
    check_dataset_format()