0%

目标检测篇:自定义 MMDetection 的 Pipeline

经过几天连续的开坑和读源代码,对 mmdetection 的配置流程了解的差不多了。考虑一个实例应用,尝试着将 FGSM 攻击算法的制作的对抗样本植入目标检测中,企图增加网络的鲁棒性,也就是一个自定义输出处理 Pipeline 的实际流程。之后会尝试自定义损失函数,支持简单的对抗训练,如 MART 算法等 1

对抗攻击之植入 mmdetection

众所周知,攻击算法基于原始样本制作一种含有梯度的扰动信息,将扰动信息叠加至原始样本,就得到了对抗样本。因此,对抗样本应该添加在数据处理的 pipeline 中,而不是网络层。首先实现最简单的对抗训练:

\begin{equation}
\min_\theta \biggl( \max \Bigl( l \bigl(T_\theta(x’), y_i \bigl) \Bigr) \biggr)
\end{equation}

内部的对抗样本 $x’$ 企图最大化模型的分类误差,外部的训练企图最小化对抗样本带来的误差。查看训练数据的 pipeline,我们发现是在最后一步转为 tensor,而攻击算法是在tensor上进行操作的。所以,对抗攻击应该加在最后一步之后或之前。我们查看返回 tensor 的源代码。此部分代码在 mmdet/datasets/pipelines 目录下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def __call__(self, results):
"""Call function to transform and format common fields in results.
Args:
results (dict): Result dict contains the data to convert.
Returns:
dict: The result dict contains the data that is formatted with \
default bundle.
"""

if 'img' in results:
img = results['img']
# add default meta keys
results = self._add_default_meta_keys(results)
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
results['img'] = DC(to_tensor(img), stack=True)
for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
if key not in results:
continue
results[key] = DC(to_tensor(results[key]))
if 'gt_masks' in results:
results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
if 'gt_semantic_seg' in results:
results['gt_semantic_seg'] = DC(
to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
return results

而对抗攻击只关注图像与标签,所以,只关注图像与标签,而这部分一目了然。所以,开始写自己的对抗样本:

  • 黑盒攻击,基于 resnet18 制作攻击样本,插入到 pipeline 中
  • 我暂时的想法是只攻击目标区域,所以筛选出目标区域的位置与标签,按照梯度上升的方向制作对抗样本。

自定义 backbone

注意这步不做也行。可以直接模改原有的 backbone,只是不推荐。而为了方便调试,我们使用自己定义的网络,网络只打印输入观察数据是否修改成功。因此,自定义的网络如下,并放在 mmdet/models/backbones/mybackbone下,

1
2
3
4
5
6
7
8
9
10
import torch.nn as nn
from ..builder import BACKBONES

@BACKBONES.register_module()
class mynet(nn.Module):
def __init__(self):
pass

def forward(self, x):
print(x)

而后为了导入模块,在 mmdet/models/backbones/__init__.py 中添加 from .mybackbone import mynet。为了适配 faster rcnn 的其它参数,我这里的 backbone 直接使用了 ResNet 里面的内容,只是打印了 x,这里是为了方便观察结果。这里需要注意,那些 super 初始化也要改,防止重名,因为不具备普适性,所以这里就不演示了。而后配置文件中指定以下就可以自由使用,我的配置文件是 adv_faster_rcnn.py

1
2
3
4
5
6
7
8
9
model = dict(
...
backbone=dict(
type='mynet',
arg1=xxx,
arg2=xxx),
...
# 测试用,跑一轮就停
runner = dict(type='EpochBasedRunner', max_epochs=1)

而后执行一下,成功的看到输入被打印了出来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from mmdet.apis import init_detector, inference_detector
# 目标检测配置文件
config_file = 'mmdetection/configs/faster_rcnn/adv_faster_rcnn.py'
# 训练模型
checkpoint_file = 'mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'

# 配置模型
model = init_detector(config=config_file,
checkpoint=checkpoint_file,
device='cuda:0')

img = 'a.png'
# 推理实际调用语句
# results = model(return_loss=False, rescale=True, **data)
result = inference_detector(model=model, imgs=img)

对抗样本代码思想

首先在 pipeline 文件夹下创建自己的 my_pipeline,而后定义自己的处理流程,比如我就看看处理的数据都有哪些:

1
2
3
4
5
6
7
8
9
10
from mmdet.datasets import PIPELINES

@PIPELINES.register_module()
class MyTransform:

def __call__(self, results):
print(type(results))
for key in results:
print(key, type(results[key]))
return results

最后在 __init__.py 中导入自己的东西,from .my_pipeline import MyTransform。综上发现,在 DefaultFormatBundleCollect 处理后,数据类型不是我能驾驭的常见数据类型: mmcv.parallel.data_container.DataContainer

所以选择在 Normalize 后加入对抗攻击,因为常见的攻击算法也都是在标准化之后,此外那里还是 ndarray 等常见类型,我能把握住这些基本的类似那个。Normalize 后的数据类型:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
img_prefix      <class 'str'>
seg_prefix <class 'NoneType'>
proposal_file <class 'NoneType'>
bbox_fields <class 'list'>
mask_fields <class 'list'>
seg_fields <class 'list'>
filename <class 'str'>
ori_filename <class 'str'>
img <class 'numpy.ndarray'>
img_shape <class 'tuple'>
ori_shape <class 'tuple'>
img_fields <class 'list'>
gt_bboxes <class 'numpy.ndarray'>
gt_bboxes_ignore <class 'numpy.ndarray'>
gt_labels <class 'numpy.ndarray'>
flip <class 'bool'>
flip_direction <class 'NoneType'>
img_norm_cfg <class 'dict'>

我利用的信息只有 imggt_labelsgt_bboxes。其实我也不知道以上字段是啥,源码和文档的信息很少,所以我只能自己都打印了一下。所以此时的任务就是,按照 gt_bboxes,截取 img,根据 gt_labels,制作对抗样本。但对抗样本返回的是 tensor,所以最后要转回到 numpy,在覆盖原来的数据。而 PGD、CW 等攻击算法过程会很慢,所以选用攻击强度大且快捷的 FGSM 单步算法。

而为了防止对抗样本带来的过大负担,所以添加了两个额外参数 $p_1$ 和 $p_2$。

  • 如果概率小于 $p_1$,在目标区域叠加高斯噪音
  • 如果概率在 $p_1$ 到 $p_2$ 之间,在目标区域制作对抗样本
  • 如果概率大于 $p_2$,使用原图,不做任何处理

综上,此时的配置文件应该为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
adv_para = dict(mu=0, std=0.1, epsilon=0.1, pro1=0.3, pro2=0.6, adv='fgsm')

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
dict(type='RandomFlip', flip_ratio=0.0001),
dict(type='Normalize', **img_norm_cfg_trian),
# 自己的 pipeline
dict(type='advTransform', **adv_para),
dict(type='Pad', size_divisor=32),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]

开始制作

为了使代码易于维护和扩展,我尽力使额外添加的程序符合设计模式,对扩展开放,对修改封闭,针对接口编程。

  • mmdet/datasets/pipelines/ 目录下增加 adv_example pipeline,生成对抗样本。
  • 至于攻击算法,依据设计模式,应使用额外的类来实现。位于 datasets/ 目录下,命名为 attack 文件夹。

对抗样本 pipeline

这里有几个点需要注意下:

  • 通过获得的 results,图像的维度是:高、宽、通道
  • bboxes 中的信息是 x1,x2,y1,y2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from mmdet.datasets import PIPELINES
# 导入攻击算法
from mmdet.datasets import attack
import random
import numpy as np


@PIPELINES.register_module()
class advTransform:

def __init__(self,
mu=0,
std=0.1,
epsilon=0.1,
pro1=0.3,
pro2=0.6,
adv='fgsm'):
# 高斯噪音的均值和方差
if isinstance(mu, int) or isinstance(mu, float):
self.mu = mu
else:
self.mu = 0

if isinstance(mu, int) or isinstance(mu, float):
self.std = std
else:
self.std = 0.1

# 对抗扰动值
if isinstance(mu, int) or isinstance(mu, float):
self.epsilon = epsilon
else:
self.epsilon = 0.5

# 概率 p1 和 p2
if isinstance(pro1, float):
self.pro1 = pro1
else:
self.pro1 = 0.3

if isinstance(pro2, float):
self.pro2 = pro2
else:
self.pro2 = 0.6

# 使用的攻击算法
if isinstance(adv, str):
self.adv = adv
else:
self.adv = 'fgsm'

# 产生一个随机数
# 如果位于区间 [0, pro1), 目标区域添加噪音
# 如果位于区间 [pro1, pro2), 目标区域用 adv 算法攻击
# 如果区间位于 [pro2, 1) 不做任何操作,返回原图

assert 0 < pro1 < pro2 < 1

self.pro1 = pro1
self.pro2 = pro2

self.rand_ = random.random()

# 复写这个方法
def __call__(self, results):

# print(len(results['gt_bboxes']))
# print(len(results['gt_labels']))

# 返回原图
if self.rand_ > self.pro2:
return results

# 目标区域叠加高斯噪音
elif self.rand_ < self.pro1:
# 可能有好几个盒子
bboxes = results['gt_bboxes']
img = results['img']

for box in bboxes:
box = box.tolist()
box = [int(i) for i in box]
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
noise = np.random.normal(
self.mu, self.std, size=(y2 - y1, x2 - x1, 3))
img[y1:y2, x1:x2] += noise

results['img'] = img
return results

# 对抗攻击
else:
# labels = results['ann_info']['labels']
# bboxes = results['ann_info']['bboxes']
labels = results['gt_labels']
bboxes = results['gt_bboxes']
img = results['img']
# 针对接口编程,只需要给攻击算法提供 图像、位置、标签和扰动值
img = attack.fgsm.fgsm_attack(img, bboxes, labels, self.epsilon)
results['img'] = img
return results

攻击算法

攻击算法位于 datasets/attack/ 文件夹下。此外,为了导入包,需要添加 __init__.py,并按以下格式添加内容:

1
2
3
from .fgsm import fgsm_attack

__all__ = ['fgsm_attack']

同样这里也有一些需要注意的地方:

  • 针对接口编程,只需要给攻击算法提供图像、目标区域、标签和扰动值,至于内部自己如何实现,不重要
  • 攻击算法返回的就是图像,内部如何实现,不重要,减少两个模块的耦合

fgsm.py 的内容为:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def fgsm_attack(img, bboxes, labels, epsilon):
# data_grad 转 tensor
import torch
import torch.nn.functional as F
model = torch.hub.load(
'pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
model.eval()

# 增加 batch 维度,然后 channel first
tmp_img = torch.from_numpy(img).clone().unsqueeze(dim=0).transpose(
1, 3).transpose(2, 3)

# tenor 转 numpy
for box, target in zip(bboxes, labels):
# numpy 2 list
box = box.tolist()
box = [int(i) for i in box]
# print(box, target)
x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

# deep copy
input_ = tmp_img[:, :, y1:y2, x1:x2].clone()
input_.requires_grad = True
label = torch.tensor([int(target)])

output = model(input_)

loss = F.nll_loss(output, label)
# model.zero_grad()
loss.backward()
sign_data = input_.grad.data.sign()
# 脱离计算图
perturbed_image = input_.detach() + epsilon * sign_data
# 指定区域生成对抗样本
tmp_img[:, :, y1:y2, x1:x2] += perturbed_image

# 删除 batch 维度
tmp_img = tmp_img.squeeze().detach().cpu().numpy().transpose(1, 2, 0)

return tmp_img

踩坑记录

  • 也许你看过一些对抗攻击的算法,知道最后的对抗样本应该 torch.clamp 到 $[0,1]$。可那是论文里用的玩具数据集才会做的事情,那些数据集知道均值和方差。现实世界的真实数据,又怎么会知道均值和方差,怎么去标准化到 $[0,1]$ 之间呢?
  • 我遇到了一个佷头疼的 bug,是多线程导致的反向传播异常,大概错误信息是:terminate called after throwing an instance of ‘c10::Error what(): CUDA error: initialization error。网上翻阅了无数 bug issue,才找到一篇有用的,也不知道那些 github 仓库的作者咋想的,问题还没解决就 close 掉,冲业绩还是图仓库 bug 少?这里把每个 GPU 的线程数量设为 0 就可以了 2 。而至于写多线程下的梯度反向传播,我貌似还没这个本事。但是,我服务器上没发现有这个 bug,应该是版本问题。

程序

官方的程序如何使用,我的就怎么使用,毕竟是直接 fork 过来的,从安装一步步来就行。此外,为了方便使用,我把配置文件放到了 configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py 文件中,可以进去参观下。

https://github.com/muyuuuu/mmdetection

此外,我在 colab 也创建了一份能用的,在不会用就真没救了。

https://github.com/muyuuuu/open-mmlab-colab/blob/main/Detection/mmdet_adv.ipynb

reference


  1. 1.https://openreview.net/forum?id=rklOg6EFwS
  2. 2.https://github.com/pytorch/pytorch/issues/1355#issuecomment-299094286
感谢上学期间打赏我的朋友们。赛博乞讨:我,秦始皇,打钱。

欢迎订阅我的文章