|
|
""" |
|
|
数据集格式验证脚本 |
|
|
用于验证 train_loader 加载的 input 和 target 格式 |
|
|
特别是验证 target[0] 是否为 [image_idx, class_id, x_center, y_center, width, height] |
|
|
""" |
|
|
import os |
|
|
import sys |
|
|
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) |
|
|
sys.path.append(BASE_DIR) |
|
|
|
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
from lib.config import cfg |
|
|
import lib.dataset as dataset |
|
|
from lib.utils import DataLoaderX |
|
|
|
|
|
def check_dataset_format(): |
|
|
"""验证数据集加载格式""" |
|
|
|
|
|
print("="*80) |
|
|
print("开始验证数据集加载格式...") |
|
|
print("="*80) |
|
|
|
|
|
|
|
|
normalize = transforms.Normalize( |
|
|
mean=[0.485, 0.456, 0.406], |
|
|
std=[0.229, 0.224, 0.225] |
|
|
) |
|
|
|
|
|
|
|
|
print("\n1. 创建数据集...") |
|
|
train_dataset = eval('dataset.' + cfg.DATASET.DATASET)( |
|
|
cfg=cfg, |
|
|
is_train=True, |
|
|
inputsize=cfg.MODEL.IMAGE_SIZE, |
|
|
transform=transforms.Compose([ |
|
|
transforms.ToTensor(), |
|
|
normalize, |
|
|
]) |
|
|
) |
|
|
print(f" 数据集类型: {cfg.DATASET.DATASET}") |
|
|
print(f" 数据集大小: {len(train_dataset)}") |
|
|
|
|
|
|
|
|
if hasattr(train_dataset, 'names'): |
|
|
print(f" 数据集类别: {train_dataset.names}") |
|
|
print(f" 类别数量: {len(train_dataset.names)}") |
|
|
else: |
|
|
print(" 数据集没有 names 属性") |
|
|
|
|
|
if hasattr(train_dataset, "names"): |
|
|
print(f" 数据集类别数量: {len(train_dataset.names)}") |
|
|
else: |
|
|
print(" 数据集不包含 names 属性,无法统计类别数量。") |
|
|
|
|
|
|
|
|
print("\n2. 创建 DataLoader...") |
|
|
train_loader = DataLoaderX( |
|
|
train_dataset, |
|
|
batch_size=4, |
|
|
shuffle=False, |
|
|
num_workers=0, |
|
|
pin_memory=False, |
|
|
collate_fn=dataset.AutoDriveDataset.collate_fn |
|
|
) |
|
|
print(f" Batch size: ") |
|
|
print(f" Total batches: {len(train_loader)}") |
|
|
|
|
|
|
|
|
print("\n3. 加载第一个 batch...") |
|
|
for i, (input, target, paths, shapes) in enumerate(train_loader): |
|
|
print("\n" + "="*80) |
|
|
print(f"Batch {i} 数据格式分析:") |
|
|
print("="*80) |
|
|
|
|
|
|
|
|
print("\n[INPUT - 图像数据]") |
|
|
print(f" 类型: {type(input)}") |
|
|
print(f" 形状: {input.shape}") |
|
|
print(f" dtype: {input.dtype}") |
|
|
print(f" 值范围: [{input.min():.3f}, {input.max():.3f}]") |
|
|
|
|
|
|
|
|
print("\n[TARGET - 标注数据]") |
|
|
print(f" 类型: {type(target)}") |
|
|
print(f" 长度: {len(target)} (包含 3 个元素: det, da_seg, ll_seg)") |
|
|
|
|
|
|
|
|
print(f"\n target[0] - 检测标签 (Detection Labels):") |
|
|
print(f" 类型: {type(target[0])}") |
|
|
print(f" 形状: {target[0].shape}") |
|
|
print(f" dtype: {target[0].dtype}") |
|
|
print(f" 说明: [N, 6] 其中 N 是所有图片的目标总数,6 维度为:") |
|
|
print(f" [image_idx, class_id, x_center, y_center, width, height]") |
|
|
|
|
|
|
|
|
if target[0].shape[0] > 0: |
|
|
print(f"\n 前 5 个目标样本:") |
|
|
print(f" {'索引':<6} {'img_idx':<10} {'class_id':<10} {'x_center':<12} {'y_center':<12} {'width':<12} {'height':<12}") |
|
|
print(f" {'-'*76}") |
|
|
for idx in range(min(5, target[0].shape[0])): |
|
|
obj = target[0][idx] |
|
|
print(f" {idx:<6} {obj[0].item():<10.0f} {obj[1].item():<10.0f} {obj[2].item():<12.6f} {obj[3].item():<12.6f} {obj[4].item():<12.6f} {obj[5].item():<12.6f}") |
|
|
|
|
|
|
|
|
print(f"\n 验证坐标是否归一化到 [0, 1]:") |
|
|
xywh_data = target[0][:, 2:] |
|
|
print(f" x_center 范围: [{xywh_data[:, 0].min():.6f}, {xywh_data[:, 0].max():.6f}]") |
|
|
print(f" y_center 范围: [{xywh_data[:, 1].min():.6f}, {xywh_data[:, 1].max():.6f}]") |
|
|
print(f" width 范围: [{xywh_data[:, 2].min():.6f}, {xywh_data[:, 2].max():.6f}]") |
|
|
print(f" height 范围: [{xywh_data[:, 3].min():.6f}, {xywh_data[:, 3].max():.6f}]") |
|
|
|
|
|
|
|
|
is_normalized = (xywh_data >= 0).all() and (xywh_data <= 1).all() |
|
|
if is_normalized: |
|
|
print(f" ✓ 坐标已归一化到 [0, 1]") |
|
|
else: |
|
|
print(f" ✗ 警告: 坐标未完全归一化!") |
|
|
|
|
|
|
|
|
print(f"\n 每张图片的目标数量:") |
|
|
for img_idx in range(input.shape[0]): |
|
|
count = (target[0][:, 0] == img_idx).sum().item() |
|
|
print(f" 图片 {img_idx}: {count} 个目标") |
|
|
else: |
|
|
print(f" (该 batch 没有检测目标)") |
|
|
|
|
|
|
|
|
print(f"\n target[1] - 驾驶区域分割标签 (Drivable Area Segmentation):") |
|
|
print(f" 类型: {type(target[1])}") |
|
|
print(f" 形状: {target[1].shape}") |
|
|
print(f" dtype: {target[1].dtype}") |
|
|
print(f" 值范围: [{target[1].min():.3f}, {target[1].max():.3f}]") |
|
|
print(f" 说明: [batch_size, num_classes, H, W]") |
|
|
|
|
|
|
|
|
print(f"\n target[2] - 车道线分割标签 (Lane Line Segmentation):") |
|
|
print(f" 类型: {type(target[2])}") |
|
|
print(f" 形状: {target[2].shape}") |
|
|
print(f" dtype: {target[2].dtype}") |
|
|
print(f" 值范围: [{target[2].min():.3f}, {target[2].max():.3f}]") |
|
|
print(f" 说明: [batch_size, num_classes, H, W]") |
|
|
|
|
|
|
|
|
print(f"\n[PATHS - 图像路径]") |
|
|
print(f" 类型: {type(paths)}") |
|
|
print(f" 长度: {len(paths)}") |
|
|
if len(paths) > 0: |
|
|
print(f" 示例路径:") |
|
|
for idx, path in enumerate(paths): |
|
|
print(f" [{idx}] {path}") |
|
|
|
|
|
|
|
|
print(f"\n[SHAPES - 图像尺寸信息]") |
|
|
print(f" 类型: {type(shapes)}") |
|
|
print(f" 长度: {len(shapes)}") |
|
|
if len(shapes) > 0: |
|
|
print(f" 示例 (原始尺寸, ((缩放比例), (padding))):") |
|
|
for idx, shape in enumerate(shapes[:2]): |
|
|
print(f" [{idx}] {shape}") |
|
|
|
|
|
print("\n" + "="*80) |
|
|
print("验证结论:") |
|
|
print("="*80) |
|
|
print("✓ target[0] 格式为: [image_idx, class_id, x_center, y_center, width, height]") |
|
|
print("✓ xywh 坐标已归一化到 [0, 1]") |
|
|
print("✓ image_idx 用于区分 batch 中不同图片的目标") |
|
|
print("✓ class_id 表示目标类别") |
|
|
print("="*80) |
|
|
|
|
|
|
|
|
break |
|
|
|
|
|
print("\n验证完成!") |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
check_dataset_format() |
|
|
|