"Computer Vision: Digital Image Processing with PyTorch" -- Object Detection with Torchvision Pretrained Models
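Object detection means finding the objects present in an image: the model must classify each object and also estimate where it is located in the image. This post works through the related concepts and then puts them into practice.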
Common object detection terms
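Receptive field: after an image passes through different convolution kernels and cascaded convolutional layers, each element of the output feature map corresponds to a specific region of the original image.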
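As a rough illustration of how that region grows with depth, the sketch below applies the usual recurrence for a plain chain of convolution or pooling layers: each layer adds (kernel_size - 1) times the current step between neighbouring feature-map elements. The function name `receptive_field` and the layer lists are my own illustrative choices, and padding and dilation are ignored.

```python
def receptive_field(layers):
    """Receptive field size, in input pixels, after a stack of layers.

    `layers` is a list of (kernel_size, stride) pairs, first layer first.
    Dilation and padding are ignored in this sketch.
    """
    rf = 1    # one input pixel sees only itself
    jump = 1  # distance, in input pixels, between adjacent feature-map elements
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump  # each layer widens the view by (k - 1) jumps
        jump *= stride                  # striding spreads later elements further apart
    return rf


# Three stacked 3x3 convolutions with stride 1
print(receptive_field([(3, 1), (3, 1), (3, 1)]))
# A 7x7 stride-2 convolution followed by a 3x3 stride-2 layer
print(receptive_field([(7, 2), (3, 2)]))
```

Stacking small kernels therefore enlarges the receptive field linearly, while strides make every later layer contribute more.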
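Anchor box (Anchor): in object detection, an object is localized with a rectangular box that has explicit coordinates and marks the region the object occupies. Strictly speaking, anchors are the predefined reference boxes from which the predicted bounding boxes are regressed.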
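Intersection over Union (IoU): in object detection, the area of the intersection of two boxes divided by the area of their union. It is commonly used to describe how accurately a box is localized.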
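A minimal sketch of that ratio for two axis-aligned boxes, assuming the (left x, top y, right x, bottom y) layout that the detection code later in this post also uses. The helper name `iou_xyxy` is hypothetical; for tensors, torchvision provides `torchvision.ops.box_iou`.

```python
def iou_xyxy(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2) with x1 <= x2 and y1 <= y2."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Width and height of the overlap; clamp at 0 when the boxes do not intersect
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - inter

    return inter / union if union > 0 else 0.0


# Two 2x2 boxes overlapping in a 1x1 square: intersection 1, union 4 + 4 - 1 = 7
print(iou_xyxy((0, 0, 2, 2), (1, 1, 3, 3)))
```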
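AP50, AP75, AP95, and mAP (mean Average Precision): metrics that describe a model's detection performance. In APxx, the number is the IoU threshold (as a percentage) that a detection must reach against a ground-truth box to count as correct; mAP is the mean of AP over all classes.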
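To make the AP idea concrete, here is a simplified sketch for a single class. It assumes the detections have already been sorted by descending score and matched against the ground truth at a chosen IoU threshold (0.5 for AP50, for example), so each detection is just a true-positive or false-positive flag. It uses all-point interpolation of the precision-recall curve; the official COCO evaluation differs in its details, and the function name `average_precision` is my own.

```python
def average_precision(is_true_positive, num_ground_truth):
    """Area under the interpolated precision-recall curve for one class.

    `is_true_positive` lists one boolean per detection, ordered by
    descending confidence score. `num_ground_truth` is the number of
    ground-truth objects of this class.
    """
    if num_ground_truth == 0 or not is_true_positive:
        return 0.0

    precisions, recalls = [], []
    tp = fp = 0
    for flag in is_true_positive:
        if flag:
            tp += 1
        else:
            fp += 1
        precisions.append(tp / (tp + fp))
        recalls.append(tp / num_ground_truth)

    # Interpolate: precision at a recall level is the best precision
    # reached at that recall or any higher recall
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])

    # Sum the rectangles under the stepwise curve
    ap = 0.0
    previous_recall = 0.0
    for precision, recall in zip(precisions, recalls):
        ap += (recall - previous_recall) * precision
        previous_recall = recall
    return ap


# Three detections (hit, miss, hit) against two ground-truth objects
print(average_precision([True, False, True], num_ground_truth=2))
```

Averaging this value over all classes gives mAP at that IoU threshold.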
Pretrained models in Torchvision
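Torchvision provides several mature object detection models, including Faster R-CNN, FCOS, RetinaNet, SSD, and SSDlite. Each model comes in one or more improved variants that differ in speed and accuracy. All of them ship with parameters trained on the COCO dataset, which can be loaded when the model is created, so the model can be used for object detection right away. All detection models live in the torchvision.models.detection subpackage.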
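The raw output of these pretrained models (boxes, labels, and scores for each image) is not easy to read directly, so some postprocessing is needed to present it well. The postprocessing mainly consists of: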
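① Filter out detections whose 'score' confidence is low, typically using 0.5 or 0.7 as the threshold;
② Convert the class indices in the detections into class names so they are easier to read and understand;
③ Draw the detections on the image to present them visually.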
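Steps ① and ② can be sketched in plain Python as follows. The shortened `CLASS_NAMES` list, the `postprocess` helper, and the sample boxes and scores are all made up for illustration; they are not output from a real model.

```python
# Shortened, illustrative label list (index = class id)
CLASS_NAMES = ['__background__', 'person', 'bicycle', 'car']


def postprocess(boxes, labels, scores, threshold=0.7):
    """Keep confident detections and attach a readable 'name(score)' label."""
    kept = []
    for box, label, score in zip(boxes, labels, scores):
        if score > threshold:                          # step 1: score filtering
            name = f'{CLASS_NAMES[label]}({score:.2f})'  # step 2: index -> name
            kept.append((box, name))
    return kept


detections = postprocess(
    boxes=[[10, 20, 110, 220], [5, 5, 50, 50], [200, 40, 400, 160]],
    labels=[1, 2, 3],
    scores=[0.93, 0.41, 0.78],
)
for box, name in detections:
    print(name, box)
```

Step ③ then only needs the surviving boxes and their label strings.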
Below is the complete object detection code using a Torchvision pretrained model:
import torchvision.transforms.functional as F
from torchvision.models import detection
import numpy as np
from torchvision.io import read_image,ImageReadMode
import torch as tc
import visdom
from PIL import Image, ImageDraw,ImageFont
COCO_INSTANCE_CATEGORY_NAMES = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]
colormap=[[ 0, 0, 0], # 0=background
[128, 0, 0], # 1=aeroplane
[ 0, 128, 0], #2=bicycle
[128, 128, 0], #3=bird
[ 0, 0, 128], #4=boat
[128, 0, 128], #5=bottle
[ 0, 128, 128], # 6=bus
[128, 128, 128], #7=car
[ 64, 0, 0], #8=cat
[192, 0, 0], #9=chair
[ 64, 128, 0], #10=cow
[192, 128, 0], #11=dining table
[ 64, 0, 128], #12=dog
[192, 0, 128], #13=horse
[ 64, 128, 128], #14=motorbike
[192, 128, 128], #15=person
[ 0, 64, 0], #16=potted plant
[128, 64, 0], #17=sheep
[ 0, 192, 0], #18=sofa
[128, 192, 0], #19=train
[ 0, 64, 128], #20=tv/monitor
[128, 64, 128], #nodefined
[ 0, 192, 128],
[128, 192, 128],
[ 64, 64, 0],
[192, 64, 0],
[ 64, 192, 0],
[192, 192, 0],
[ 64, 64, 128],
[192, 64, 128],
[ 64, 192, 128],
[192, 192, 128],
[ 0, 0, 64],
[128, 0, 64],
[ 0, 128, 64],
[128, 128, 64],
[ 0, 0, 192],
[128, 0, 192],
[ 0, 128, 192],
[128, 128, 192],
[ 64, 0, 64],
[192, 0, 64],
[ 64, 128, 64],
[192, 128, 64],
[ 64, 0, 192],
[192, 0, 192],
[ 64, 128, 192],
[192, 128, 192],
[ 0, 64, 64],
[128, 64, 64],
[ 0, 192, 64],
[128, 192, 64],
[ 0, 64, 192],
[128, 64, 192],
[ 0, 192, 192],
[128, 192, 192],
[ 64, 64, 64],
[192, 64, 64],
[ 64, 192, 64],
[192, 192, 64],
[ 64, 64, 192],
[192, 64, 192],
[ 64, 192, 192],
[192, 192, 192],
[ 32, 0, 0],
[160, 0, 0],
[ 32, 128, 0],
[160, 128, 0],
[ 32, 0, 128],
[160, 0, 128],
[ 32, 128, 128],
[160, 128, 128],
[ 96, 0, 0],
[224, 0, 0],
[ 96, 128, 0],
[224, 128, 0],
[ 96, 0, 128],
[224, 0, 128],
[ 96, 128, 128],
[224, 128, 128],
[ 32, 64, 0],
[160, 64, 0],
[ 32, 192, 0],
[160, 192, 0],
[ 32, 64, 128],
[160, 64, 128],
[ 32, 192, 128],
[160, 192, 128],
[ 96, 64, 0],
[224, 64, 0],
[ 96, 192, 0],
[224, 192, 0],
[ 96, 64, 128],
[224, 64, 128],
[ 96, 192, 128],
[224, 192, 128],
[ 32, 0, 64],
[160, 0, 64],
[ 32, 128, 64],
[160, 128, 64],
[ 32, 0, 192],
[160, 0, 192],
[ 32, 128, 192],
[160, 128, 192],
[ 96, 0, 64],
[224, 0, 64],
[ 96, 128, 64],
[224, 128, 64],
[ 96, 0, 192],
[224, 0, 192],
[ 96, 128, 192],
[224, 128, 192],
[ 32, 64, 64],
[160, 64, 64],
[ 32, 192, 64],
[160, 192, 64],
[ 32, 64, 192],
[160, 64, 192],
[ 32, 192, 192],
[160, 192, 192],
[ 96, 64, 64],
[224, 64, 64],
[ 96, 192, 64],
[224, 192, 64],
[ 96, 64, 192],
[224, 64, 192],
[ 96, 192, 192],
[224, 192, 192],
[ 0, 32, 0],
[128, 32, 0],
[ 0, 160, 0],
[128, 160, 0],
[ 0, 32, 128],
[128, 32, 128],
[ 0, 160, 128],
[128, 160, 128],
[ 64, 32, 0],
[192, 32, 0],
[ 64, 160, 0],
[192, 160, 0],
[ 64, 32, 128],
[192, 32, 128],
[ 64, 160, 128],
[192, 160, 128],
[ 0, 96, 0],
[128, 96, 0],
[ 0, 224, 0],
[128, 224, 0],
[ 0, 96, 128],
[128, 96, 128],
[ 0, 224, 128],
[128, 224, 128],
[ 64, 96, 0],
[192, 96, 0],
[ 64, 224, 0],
[192, 224, 0],
[ 64, 96, 128],
[192, 96, 128],
[ 64, 224, 128],
[192, 224, 128],
[ 0, 32, 64],
[128, 32, 64],
[ 0, 160, 64],
[128, 160, 64],
[ 0, 32, 192],
[128, 32, 192],
[ 0, 160, 192],
[128, 160, 192],
[ 64, 32, 64],
[192, 32, 64],
[ 64, 160, 64],
[192, 160, 64],
[ 64, 32, 192],
[192, 32, 192],
[ 64, 160, 192],
[192, 160, 192],
[ 0, 96, 64],
[128, 96, 64],
[ 0, 224, 64],
[128, 224, 64],
[ 0, 96, 192],
[128, 96, 192],
[ 0, 224, 192],
[128, 224, 192],
[ 64, 96, 64],
[192, 96, 64],
[ 64, 224, 64],
[192, 224, 64],
[ 64, 96, 192],
[192, 96, 192],
[ 64, 224, 192],
[192, 224, 192],
[ 32, 32, 0],
[160, 32, 0],
[ 32, 160, 0],
[160, 160, 0],
[ 32, 32, 128],
[160, 32, 128],
[ 32, 160, 128],
[160, 160, 128],
[ 96, 32, 0],
[224, 32, 0],
[ 96, 160, 0],
[224, 160, 0],
[ 96, 32, 128],
[224, 32, 128],
[ 96, 160, 128],
[224, 160, 128],
[ 32, 96, 0],
[160, 96, 0],
[ 32, 224, 0],
[160, 224, 0],
[ 32, 96, 128],
[160, 96, 128],
[ 32, 224, 128],
[160, 224, 128],
[ 96, 96, 0],
[224, 96, 0],
[ 96, 224, 0],
[224, 224, 0],
[ 96, 96, 128],
[224, 96, 128],
[ 96, 224, 128],
[224, 224, 128],
[ 32, 32, 64],
[160, 32, 64],
[ 32, 160, 64],
[160, 160, 64],
[ 32, 32, 192],
[160, 32, 192],
[ 32, 160, 192],
[160, 160, 192],
[ 96, 32, 64],
[224, 32, 64],
[ 96, 160, 64],
[224, 160, 64],
[ 96, 32, 192],
[224, 32, 192],
[ 96, 160, 192],
[224, 160, 192],
[ 32, 96, 64],
[160, 96, 64],
[ 32, 224, 64],
[160, 224, 64],
[ 32, 96, 192],
[160, 96, 192],
[ 32, 224, 192],
[160, 224, 192],
[ 96, 96, 64],
[224, 96, 64],
[ 96, 224, 64],
[224, 224, 64],
[ 96, 96, 192],
[224, 96, 192],
[ 96, 224, 192],
[224, 224, 192]]
def draw_bounding_boxes(imgchwtensor, boxestensor, labels=['Object'], colors=['red'], fill=False, width=1, font='simhei', font_size=12):
    # imgchwtensor: C x H x W image tensor
    # boxestensor: N x 4 tensor of boxes as (left x, top y, right x, bottom y)
    # Scale float images in [0, 1] up to [0, 255] before converting to uint8
    x = (imgchwtensor * (255 if imgchwtensor.max() <= 1 else 1)).byte()
    image = Image.fromarray(x.permute(1, 2, 0).numpy())
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype(font, font_size)
    except Exception:
        # Fall back to PIL's built-in font if the requested font is not installed
        font = ImageFont.load_default()
    boxes = boxestensor.tolist()  # plain Python floats are safe to pass to PIL
    numofbox = len(boxes)
    # A single color or label is reused for every box
    colors = colors * numofbox if len(colors) == 1 else colors
    labels = labels * numofbox if len(labels) == 1 else labels
    textcolor = 'white'
    for i in range(numofbox):
        box = boxes[i]
        draw.rectangle(xy=box, fill=None, outline=colors[i], width=width)  # box
        label = labels[i]
        if label != '':
            _, _, w, h = font.getbbox(label)  # right and bottom of the text box, used as text width and height
            outside = box[1] - h >= 0  # label fits above the box
            # Filled background behind the label text
            draw.rectangle(
                (box[0], box[1] - h if outside else box[1], box[0] + w + 1,
                 box[1] + 1 if outside else box[1] + h + 1),
                fill=colors[i],
            )
            draw.text((box[0], box[1] - h if outside else box[1]), label, fill=textcolor, font=font)
    # Return a C x H x W array, the layout visdom expects
    return np.asarray(image).transpose((2, 0, 1))
# Other detection models can be swapped in the same way, for example:
# model = detection.fasterrcnn_resnet50_fpn(weights=detection.FasterRCNN_ResNet50_FPN_Weights.DEFAULT)
# model = detection.retinanet_resnet50_fpn(weights=detection.RetinaNet_ResNet50_FPN_Weights.DEFAULT)
# SSDlite (pass a member of the weights enum such as DEFAULT, not the enum class itself)
model = detection.ssdlite320_mobilenet_v3_large(weights=detection.SSDLite320_MobileNet_V3_Large_Weights.DEFAULT)
imgpath = 'PennPed00036.png'
imgorg = read_image(imgpath, ImageReadMode.RGB)
img = F.convert_image_dtype(imgorg, tc.float)  # uint8 [0, 255] -> float [0, 1]
model = model.eval()
with tc.no_grad():
    detection_outputs = model([img,])
print(detection_outputs)
res = detection_outputs[0]
boxes = res['boxes']
scores = res['scores']
labels = res['labels']
# Threshold filtering
threshold = 0.7  # keep detections whose confidence score is above 0.7
mask = scores > threshold
scores = scores[mask]
labels = labels[mask]
boxes = boxes[mask]
# Build "name(score)" strings and pick one color per class index
labelnames = [COCO_INSTANCE_CATEGORY_NAMES[label] + f'({scores[idx]:.2f})' for idx, label in enumerate(labels.tolist())]
colors = [tuple(colormap[i]) for i in labels.tolist()]
img = draw_bounding_boxes(imgorg, boxes, labels=labelnames, colors=colors, width=3, font_size=30, font='simhei')
vis = visdom.Visdom()
print(img.shape)
vis.image(img)
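COCO_INSTANCE_CATEGORY_NAMES defines the human-readable class labels. The predicted class index is mapped to a color with the colormap scheme used for VOC image segmentation, so that boxes of different classes are drawn in different colors. The draw_bounding_boxes() function then draws the prepared boxes, class labels, and related information onto the image at the specified positions.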
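The 256-entry colormap does not have to be typed out by hand. The sketch below generates it with the bit-interleaving scheme commonly used for the PASCAL VOC palette, where bits of the class index are spread across the high bits of the R, G, and B channels. The function name `voc_colormap` is my own, and I only checked the generated table against a handful of entries from the list in the code above.

```python
def voc_colormap(num_colors=256):
    """Generate a VOC-style palette as a list of [R, G, B] lists."""
    colormap = []
    for index in range(num_colors):
        r = g = b = 0
        c = index
        for j in range(8):
            # Take the lowest three bits of c and place them at bit (7 - j)
            # of the red, green, and blue channels respectively
            r |= ((c >> 0) & 1) << (7 - j)
            g |= ((c >> 1) & 1) << (7 - j)
            b |= ((c >> 2) & 1) << (7 - j)
            c >>= 3
        colormap.append([r, g, b])
    return colormap


colormap = voc_colormap()
print(colormap[:4])
```

Neighbouring class indices get clearly different colors this way, which is what makes the palette useful for telling classes apart.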
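In the detection result produced by the pretrained Torchvision model above, two people, one car, and one bus are detected.
That completes object detection with a Torchvision pretrained model. I will keep reading and try other kinds of object detection methods next.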