目标检测
识别一张图片中的所有物体,比如多个狗或者猫,并且还要用方框标注出每个物体的位置
一个边缘框可以用四个数字来定义:
左上x,右上y, 右下x和右下y (注意,一个一个图片的左上角为原点)
左上x,右上y,宽和高
目标检测数据集:
一个图片中可能有多个类,所以一般用CSV文件来存
一行表示一个物体
所以一张图片可能需要多行来描述
每一行的数据包括:
图片名文件,物体类别和边缘框
常用的目标检测数据集:
COCO数据集 :cocodataset.org 包含了80个常见类别,包含了大概330k图片,有1.5M物体
1 边缘框的实现
读入一张图片
%matplotlib inline
import torch
from d2l import torch as d2l
d2l.set_figsize()
img = d2l.plt.imread('../Jupyter/img/catdog.jpg')
d2l.plt.imshow(img);
框坐标的转换:
# 定义在这两种表示之间进行转换的函数
def box_corner_to_center(boxes):
"""从(左上,右下)转换到(中间,宽度,高度)"""
x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
cx = (x1 + x2) / 2 # 得到中间点的坐标
cy = (y1 + y2) / 2
w = x2 - x1
h = y2 - y1
boxes = torch.stack((cx, cy, w, h), axis = -1)
return boxes
def box_center_to_corner(boxes):
"""从(中间,宽度,高度)转换到(左上,右下)"""
# 图片左上角为零点,向下是y轴正方向
cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
x1 = cx - 0.5 * w
y1 = cy - 0.5 * h
x2 = cx + 0.5 * w
y2 = cy + 0.5 * h
boxes = torch.stack((x1, y1, x2, y2), axis = -1)
return boxes
基于边缘框画出物体的位置:
# bbox是边界框的英文缩写
dog_bbox, cat_bbox = [60.0, 45.0, 378.0, 516.0], [400.0, 112.0, 655.0, 493.0]
boxes = torch.tensor((dog_bbox, cat_bbox))
box_center_to_corner(box_corner_to_center(boxes)) == boxes # 测试一下转换函数
def bbox_to_rect(bbox, color):
return d2l.plt.Rectangle(xy = (bbox[0], bbox[1]),
width = bbox[2]-bbox[0],
height = bbox[3] - bbox[1],
edgecolor = color,
linewidth = 2,
fill = False)
fig = d2l.plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(dog_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(cat_bbox, 'red'));
2 目标检测数据集(手动构造了一个小的数据集)
下载和读取数据集
#包含所有图像和CSV标签文件的香蕉检测数据集可以直接从互联网下载。
%matplotlib inline
import os
import pandas as pd
import torch
import torchvision
from d2l import torch as d2l
#@save
d2l.DATA_HUB['banana-detection'] = (
d2l.DATA_URL + 'banana-detection.zip',
'5de26c8fce5ccdea9f91267273464dc968d20d72')
# 读取数据集(这里的方法不常用,将所有图片读到内存里面,因为图片比较少)
def read_data_bananas(is_train=True):
"""读取香蕉检测数据集中的图像和标签"""
# 下载并解压数据集,返回数据集根目录
data_dir = d2l.download_extract('banana-detection')
# 根据is_train参数选择训练集或验证集的label.csv文件
csv_fname = os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'label.csv')
csv_data = pd.read_csv(csv_fname)
csv_data = csv_data.set_index('img_name')
images, targets = [], []
for img_name, target in csv_data.iterrows():
# 从images子目录中读取每张图片,使用torchvision.io.read_image加载为张量。
images.append(torchvision.io.read_image(
os.path.join(data_dir, 'bananas_train' if is_train else
'bananas_val', 'images', f'{img_name}')
))
# 这里的target包含(类别,左上角x,左上角y,右下角x,右下角y),
# 其中所有图像都具有相同的香蕉类(索引为0)
targets.append(list(target))
return images, torch.tensor(targets).unsqueeze(1)/256
一个数据集自定义Dataset实例
class BananasDataset(torch.utils.data.Dataset):
"""一个用于加载香蕉检测数据集的自定义数据集"""
def __init__(self, is_train):
self.features, self.labels = read_data_bananas(is_train)
print('read ' + str(len(self.features)) + (f' training examples' if
is_train else f' validation examples'))
# 读取第i个样品
def __getitem__(self, idx):
return (self.features[idx].float(), self.labels[idx])
# 返回读取的数据有多长
def __len__(self):
return len(self.features)
基于构建的实例,加载数据集,构建为迭代器
def load_data_bananas(batch_size):
"""加载香蕉检测数据集"""
train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),
batch_size, shuffle=True)
val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),
batch_size)
return train_iter, val_iter
打印一下
batch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
batch[0].shape, batch[1].shape
# 数据集中只有一个香蕉,因此标签只有一个种类[32,1,5]
输出:
read 1000 training examples
read 100 validation examples
(torch.Size([32, 3, 256, 256]), torch.Size([32, 1, 5]))
显示一下边框:
imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = d2l.show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):
d2l.show_bboxes(ax, [label[0][1:5] * edge_size], colors=['w'])