目标检测

识别一张图片中的所有物体,比如多个狗或者猫,并且还要用方框标注出每个物体的位置

一个边缘框可以用四个数字来定义:
左上x,右上y, 右下x和右下y (注意,一个一个图片的左上角为原点)
左上x,右上y,宽和高

目标检测数据集:
一个图片中可能有多个类,所以一般用CSV文件来存
一行表示一个物体
所以一张图片可能需要多行来描述
每一行的数据包括:
图片名文件,物体类别和边缘框

常用的目标检测数据集:
COCO数据集 :cocodataset.org 包含了80个常见类别,包含了大概330k图片,有1.5M物体

1 边缘框的实现

读入一张图片

%matplotlib inline

import torch
from d2l import torch as d2l

d2l.set_figsize()
img = d2l.plt.imread('../Jupyter/img/catdog.jpg')
d2l.plt.imshow(img);

Image

框坐标的转换:

# 定义在这两种表示之间进行转换的函数
def box_corner_to_center(boxes):
    """从(左上,右下)转换到(中间,宽度,高度)"""
    x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    cx = (x1 + x2) / 2  # 得到中间点的坐标
    cy = (y1 + y2) / 2

    w = x2 - x1
    h = y2 - y1

    boxes = torch.stack((cx, cy, w, h), axis = -1)
    return boxes

def box_center_to_corner(boxes):
    """从(中间,宽度,高度)转换到(左上,右下)"""
    # 图片左上角为零点,向下是y轴正方向
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    x1 = cx - 0.5 * w
    y1 = cy - 0.5 * h
    x2 = cx + 0.5 * w
    y2 = cy + 0.5 * h 

    boxes = torch.stack((x1, y1, x2, y2), axis = -1)
    return boxes

基于边缘框画出物体的位置:

# bbox是边界框的英文缩写
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]
boxes = torch.tensor((dog_bbox, cat_bbox))
box_center_to_corner(box_corner_to_center(boxes)) == boxes   # 测试一下转换函数

def bbox_to_rect(bbox, color):
    return d2l.plt.Rectangle(xy = (bbox[0], bbox[1]),
                            width = bbox[2]-bbox[0],
                            height = bbox[3] - bbox[1],
                            edgecolor = color,
                            linewidth = 2,
                            fill = False)

fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(cat_bbox, 'red'));

Image

2 目标检测数据集(手动构造了一个小的数据集)

下载和读取数据集

#包含所有图像和CSV标签文件的香蕉检测数据集可以直接从互联网下载。
%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l

#@save
d2l.DATA_HUB['banana-detection'] = (
    d2l.DATA_URL + 'banana-detection.zip',
    '5de26c8fce5ccdea9f91267273464dc968d20d72')


# 读取数据集(这里的方法不常用,将所有图片读到内存里面,因为图片比较少)
def read_data_bananas(is_train=True):
    """读取香蕉检测数据集中的图像和标签"""
    # 下载并解压数据集,返回数据集根目录
    data_dir = d2l.download_extract('banana-detection') 
    
    # 根据is_train参数选择训练集或验证集的label.csv文件
    csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else 
                                   'bananas_val', 'label.csv') 
    
    csv_data = pd.read_csv(csv_fname)
    csv_data = csv_data.set_index('img_name')
    images, targets = [], []
    for img_name, target in csv_data.iterrows():
        # 从images子目录中读取每张图片,使用torchvision.io.read_image加载为张量。
        images.append(torchvision.io.read_image(
            os.path.join(data_dir, 'bananas_train' if is_train else
                        'bananas_val', 'images', f'{img_name}')
        ))
        # 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
        # 其中所有图像都具有相同的香蕉类(索引为0)
        targets.append(list(target))
    return images, torch.tensor(targets).unsqueeze(1)/256

一个数据集自定义Dataset实例

class BananasDataset(torch.utils.data.Dataset):
    """一个用于加载香蕉检测数据集的自定义数据集"""
    def __init__(self, is_train):
        self.features, self.labels = read_data_bananas(is_train)
        print('read ' + str(len(self.features)) + (f' training examples' if
                    is_train else f' validation examples'))

    # 读取第i个样品
    def __getitem__(self, idx):
        return (self.features[idx].float(), self.labels[idx])

    # 返回读取的数据有多长
    def __len__(self):
        return len(self.features)

基于构建的实例,加载数据集,构建为迭代器

def load_data_bananas(batch_size):
    """加载香蕉检测数据集"""
    train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
                batch_size, shuffle=True)
    val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
                batch_size)
    return train_iter, val_iter

打印一下

batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
# 数据集中只有一个香蕉,因此标签只有一个种类[32,1,5]

输出:
read 1000 training examples
read 100 validation examples
(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))

显示一下边框:

imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
    d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])

Image