0%

移动端算法优化

移动端算法优化是个很庞大的话题。从计算机体系到指令,涉及到非常广而深的东西。本文尝试以常见的算法为例,阐述算法在单线程场景下的加速与优化,多线程是最后的收尾,没啥可说的。而至于具体的场景,如金字塔、滤波、降噪等,优化的思路都是相同的:减少 IO,一次 IO 完成尽可能多的计算。

本文会使用 Neon, OpenCL 来优化算法,如果有可能也会引入 DSP。本文持续更新,整理算法优化相关的经验。额外的,确保打开了 O3 编译选项,打开 release 模式等,否则会影响算法的执行时间。

矩阵乘法

注:本文不考虑数学角度的优化,如修改计算公式得到相同结果什么的。实现的浮点矩阵计算为:

简单起见,$A$ 的维度为 $512\times 128$,矩阵 $B$ 的维度为 $128 \times 256$。在高通骁龙某芯片上,目前的加速结果如下:

版本 时间
常规矩阵乘法 59.84ms
Neon 加速版本 1 12.90 ms
Neon 加速版本 2 3.85ms
Cache 友好的矩阵乘法 2.52ms
Neon 加速版本 3 2.77ms
Neon 加速版本 4 2.01ms
Neon 加速版本 5 1.09ms

为什么没 OpenCL?因为还没来得及写,仿佛欠着好多博客。

常规矩阵乘法

以线性代数中的矩阵乘法为例,目标矩阵的第 $i, j$ 个元素是矩阵 $A$ 的第 $i$ 行和矩阵 $B$ 的第 $j$ 列逐元素相乘相加的结果。根据这一原理写出最直观的代码,耗时 59.84ms:

1
2
3
4
5
6
7
8
9
10
11
12
/*
 * Naive single-precision GEMM: C = A * B + bias.
 *
 * A is d0 x d1, B is d1 x d2, C and bias are both d0 x d2,
 * all stored row-major.
 *
 * Fix vs. the original: the dot product is accumulated in a local scalar
 * that starts from the bias value and is stored into C exactly once, so
 * the caller no longer has to pre-zero C (the original `+=` into C made
 * that an undocumented requirement). Results are identical when C was
 * zero-initialized.
 */
void sgemm_c(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
    int row, col, m;
    for (row = 0; row < d0; row++) {
        for (col = 0; col < d2; col++) {
            /* start from the bias so no separate add pass is needed */
            float sum = bias[row * d2 + col];
            for (m = 0; m < d1; m++) {
                /* note: B is accessed with stride d2 here -- the cache-
                 * unfriendly pattern discussed in the surrounding text */
                sum += A[row * d1 + m] * B[m * d2 + col];
            }
            C[row * d2 + col] = sum;
        }
    }
}

我们知道矩阵在计算机中是行主序存储的,即访问矩阵 $B[i, j]$ 时,会将 $B[i, j+1], B[i, j+2],…$ 等相邻元素也一同取到 CPU 的 cache 中。当需要 $B[i, j+1]$ 时就直接从 cache 中读取而不必再访问内存,这样会节省很多时间。

所以上述代码的性能瓶颈在于:

1
2
3
for (m = 0; m < d1; m++) {
C[row * d2 + col] += A[row * d1 + m] * B[m * d2 + col];
}

由于最内层的循环中 m 逐渐增加,矩阵 $B$ 的寻址方式为跳行寻址(每读一个元素就跨越一整行,即 d2 个浮点数)。在我们看不见的地方,随同加载进 cache 的相邻元素无法被复用,几乎每次读取 $B$ 矩阵的元素都会发生 cache miss、重新从内存加载,这就导致这份代码很耗时。

Neon 加速版本 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/*
 * NEON GEMM, version 1: C = A * B + bias, producing one 1x4 output tile
 * of C per iteration of the col loop.
 *
 * Contract (from the loop strides / loads below):
 *   - assumes d1 % 4 == 0 and d2 % 4 == 0; there is no tail handling, so
 *     other sizes read out of bounds -- TODO confirm callers guarantee this.
 *   - C is overwritten (sum starts at zero), unlike the scalar sgemm_c
 *     which accumulates into C.
 */
void sgemm_neon1(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row++) {
for (col = 0; col < d2; col+=4) {
/* accumulator for C[row, col..col+3] */
float32x4_t sum4 = vdupq_n_f32(0.0f);

float *pa = A + row * d1;           /* row `row` of A, walked 4 at a time */
float *pb = B + col;                /* column block `col` of B */
float *pc = C + row * d2 + col;
float *pd = bias + row * d2 + col;

for (m = 0; m < d1; m+=4) {
/* 4 consecutive A elements and the 4 matching rows of B */
float32x4_t a4 = vld1q_f32(pa);
float32x4_t b0 = vld1q_f32(pb + 0 * d2);
float32x4_t b1 = vld1q_f32(pb + 1 * d2);
float32x4_t b2 = vld1q_f32(pb + 2 * d2);
float32x4_t b3 = vld1q_f32(pb + 3 * d2);

/* sum4 += A[row, m+k] * B[m+k, col..col+3] for k = 0..3:
 * each lane-MAC scales a whole B row-segment by one A scalar */
sum4 = vmlaq_lane_f32(sum4, b0, vget_low_f32(a4), 0);
sum4 = vmlaq_lane_f32(sum4, b1, vget_low_f32(a4), 1);
sum4 = vmlaq_lane_f32(sum4, b2, vget_high_f32(a4), 0);
sum4 = vmlaq_lane_f32(sum4, b3, vget_high_f32(a4), 1);

pa += 4;
pb += 4 * d2;       /* jump down 4 rows of B */
}

/* add bias once per tile, then store the finished 1x4 result */
float32x4_t d4 = vld1q_f32(pd);
sum4 = vaddq_f32(sum4, d4);
vst1q_f32(pc, sum4);
}
}
}

Neon 加速版本 2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/*
 * NEON GEMM, version 2: C = A * B + bias with 4x4 register tiling.
 * Each (row, col) iteration produces a 4x4 tile of C held in four q
 * registers (sum0..sum3), amortizing the B loads over four output rows.
 *
 * Contract (from the loop strides / loads below):
 *   - assumes d0 % 4 == 0, d1 % 4 == 0 and d2 % 4 == 0; no tail handling,
 *     so other sizes read/write out of bounds -- TODO confirm callers.
 *   - C is overwritten; the accumulators are seeded with bias up front,
 *     which also removes the separate bias-add pass of version 1.
 */
void sgemm_neon2(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row+=4) {
for (col = 0; col < d2; col+=4) {

float *pa = A + row * d1;
float *pb = B + col;
float *pc = C + row * d2 + col;
float *pd = bias + row * d2 + col;

/* seed the four output-row accumulators with the bias tile */
float32x4_t sum0 = vld1q_f32(pd + 0 * d2);
float32x4_t sum1 = vld1q_f32(pd + 1 * d2);
float32x4_t sum2 = vld1q_f32(pd + 2 * d2);
float32x4_t sum3 = vld1q_f32(pd + 3 * d2);

for (m = 0; m < d1; m+=4) {
/* 4 rows of B (shared by all 4 output rows of this tile) */
float32x4_t b0 = vld1q_f32(pb + 0 * d2);
float32x4_t b1 = vld1q_f32(pb + 1 * d2);
float32x4_t b2 = vld1q_f32(pb + 2 * d2);
float32x4_t b3 = vld1q_f32(pb + 3 * d2);

/* 4 consecutive elements of each of the 4 A rows */
float32x4_t a0 = vld1q_f32(pa + 0 * d1);
float32x4_t a1 = vld1q_f32(pa + 1 * d1);
float32x4_t a2 = vld1q_f32(pa + 2 * d1);
float32x4_t a3 = vld1q_f32(pa + 3 * d1);

/* sumR += A[row+R, m+k] * B[m+k, col..col+3], k = 0..3 */
sum0 = vmlaq_lane_f32(sum0, b0, vget_low_f32(a0), 0);
sum0 = vmlaq_lane_f32(sum0, b1, vget_low_f32(a0), 1);
sum0 = vmlaq_lane_f32(sum0, b2, vget_high_f32(a0), 0);
sum0 = vmlaq_lane_f32(sum0, b3, vget_high_f32(a0), 1);

sum1 = vmlaq_lane_f32(sum1, b0, vget_low_f32(a1), 0);
sum1 = vmlaq_lane_f32(sum1, b1, vget_low_f32(a1), 1);
sum1 = vmlaq_lane_f32(sum1, b2, vget_high_f32(a1), 0);
sum1 = vmlaq_lane_f32(sum1, b3, vget_high_f32(a1), 1);

sum2 = vmlaq_lane_f32(sum2, b0, vget_low_f32(a2), 0);
sum2 = vmlaq_lane_f32(sum2, b1, vget_low_f32(a2), 1);
sum2 = vmlaq_lane_f32(sum2, b2, vget_high_f32(a2), 0);
sum2 = vmlaq_lane_f32(sum2, b3, vget_high_f32(a2), 1);

sum3 = vmlaq_lane_f32(sum3, b0, vget_low_f32(a3), 0);
sum3 = vmlaq_lane_f32(sum3, b1, vget_low_f32(a3), 1);
sum3 = vmlaq_lane_f32(sum3, b2, vget_high_f32(a3), 0);
sum3 = vmlaq_lane_f32(sum3, b3, vget_high_f32(a3), 1);

pa += 4;
pb += 4 * d2;
}

/* write the finished 4x4 tile */
vst1q_f32(pc + 0 * d2, sum0);
vst1q_f32(pc + 1 * d2, sum1);
vst1q_f32(pc + 2 * d2, sum2);
vst1q_f32(pc + 3 * d2, sum3);
}
}
}

Cache 友好的矩阵乘法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
/*
 * Cache-friendly scalar GEMM: C = A * B + bias.
 *
 * Loop order is (row, m, col) so the innermost loop walks B and C
 * sequentially in memory (row-major), avoiding the stride-d2 accesses
 * of the naive version.
 *
 * Fix vs. the original: on the first m-pass each output element is
 * *assigned* bias + product instead of accumulated, so C no longer has
 * to be pre-zeroed by the caller (an undocumented requirement of the
 * original `+=`). Results are identical when C was zero-initialized.
 * The A element, invariant over the col loop, is also hoisted out.
 */
void rsgemm_c(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
    int row, col, m;
    for (row = 0; row < d0; row++) {
        for (m = 0; m < d1; m++) {
            float a = A[row * d1 + m];  /* invariant over the col loop */
            for (col = 0; col < d2; col++) {
                float prod = a * B[m * d2 + col];
                if (0 == m) {
                    /* first pass initializes C: no pre-zeroing needed */
                    C[row * d2 + col] = bias[row * d2 + col] + prod;
                } else {
                    C[row * d2 + col] += prod;
                }
            }
        }
    }
}

Neon 加速版本 3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/*
 * NEON version of the cache-friendly GEMM: C = A * B + bias.
 * Loop order (row, m, col): one scalar A[row, m] is broadcast into a
 * vector, then the inner loop streams a whole row of B and of C
 * sequentially, 4 floats at a time.
 *
 * Contract (from the loads below):
 *   - accumulates into C (c4 is loaded and added to), so the caller must
 *     zero-initialize C -- unlike sgemm_neon1/2 which overwrite it.
 *   - bias is folded in once, on the first m-pass only.
 *   - assumes d2 % 4 == 0; no tail handling -- TODO confirm callers.
 */
void rsgemm_neon1(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row++) {
for (m = 0; m < d1; m++) {

/* broadcast the single A element across all 4 lanes */
float32x4_t a4 = vdupq_n_f32(A[row * d1 + m]);
float *pb = B + m * d2;         /* row m of B, walked sequentially */
float *pc = C + row * d2;       /* row `row` of C, walked sequentially */
float *pd = bias + row * d2;

for (col = 0; col < d2; col+=4) {
float32x4_t b4 = vld1q_f32(pb);
float32x4_t c4 = vld1q_f32(pc);
float32x4_t val = vmulq_f32(a4, b4);
val = vaddq_f32(c4, val);       /* accumulate onto existing C */

if (0 == m) {
/* fold the bias in exactly once, on the first m-pass */
val = vaddq_f32(vld1q_f32(pd), val);
}

vst1q_f32(pc, val);

pb += 4;
pc += 4;
pd += 4;
}
}
}
}

Neon 加速版本 4

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*
 * NEON cache-friendly GEMM, m unrolled by 4: C = A * B + bias.
 * Four consecutive A scalars are broadcast once per m-block, and each
 * inner iteration combines four rows of B into one 4-wide chunk of C,
 * quadrupling the work done per C load/store versus rsgemm_neon1.
 *
 * Contract (from the loads below):
 *   - accumulates into C (c4 is loaded and added to), so the caller must
 *     zero-initialize C.
 *   - bias is folded in once, on the m == 0 block.
 *   - assumes d1 % 4 == 0 and d2 % 4 == 0; no tail handling -- TODO
 *     confirm callers guarantee this.
 */
void rsgemm_neon2(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row++) {
for (m = 0; m < d1; m+=4) {

/* the four B rows consumed by this m-block */
float *pb0 = B + (m + 0) * d2;
float *pb1 = B + (m + 1) * d2;
float *pb2 = B + (m + 2) * d2;
float *pb3 = B + (m + 3) * d2;

float *pc = C + row * d2;
float *pd = bias + row * d2;

/* load A[row, m..m+3] once, then broadcast each lane to a full vector */
float32x4_t a4 = vld1q_f32(A + row * d1 + m);
float32x4_t a0 = vdupq_n_f32(vgetq_lane_f32(a4, 0));
float32x4_t a1 = vdupq_n_f32(vgetq_lane_f32(a4, 1));
float32x4_t a2 = vdupq_n_f32(vgetq_lane_f32(a4, 2));
float32x4_t a3 = vdupq_n_f32(vgetq_lane_f32(a4, 3));

for (col = 0; col < d2; col+=4) {
float32x4_t c4 = vld1q_f32(pc);

/* c4 += A[row, m+k] * B[m+k, col..col+3] for k = 0..3 */
c4 = vaddq_f32(c4, vmulq_f32(a0, vld1q_f32(pb0)));
c4 = vaddq_f32(c4, vmulq_f32(a1, vld1q_f32(pb1)));
c4 = vaddq_f32(c4, vmulq_f32(a2, vld1q_f32(pb2)));
c4 = vaddq_f32(c4, vmulq_f32(a3, vld1q_f32(pb3)));

if (0 == m) {
/* fold the bias in exactly once, on the first m-block */
c4 = vaddq_f32(vld1q_f32(pd), c4);
}

vst1q_f32(pc, c4);

pb0 += 4;
pb1 += 4;
pb2 += 4;
pb3 += 4;

pc += 4;
pd += 4;
}
}
}
}

Neon 加速版本 5

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/*
 * NEON cache-friendly GEMM, 4-row x 4-m register tile: C = A * B + bias.
 * Each (row, m) block broadcasts a 4x4 patch of A into sixteen vectors
 * (aRK = A[row+R, m+K] splatted); the inner loop then reuses four B row
 * segments across four C rows, maximizing arithmetic per load.
 *
 * Contract (from the loads below):
 *   - accumulates into C (cR4 is loaded and added to), so the caller
 *     must zero-initialize C.
 *   - bias is folded in once, on the m == 0 block.
 *   - assumes d0 % 4 == 0, d1 % 4 == 0 and d2 % 4 == 0; no tail
 *     handling -- TODO confirm callers guarantee this.
 */
void rsgemm_neon3(float *C, float *A, float *B, float *bias, int d0, int d1, int d2)
{
int row, col, m;
for (row = 0; row < d0; row+=4) {
for (m = 0; m < d1; m+=4) {

/* the four B rows consumed by this m-block */
float *pb0 = B + (m + 0) * d2;
float *pb1 = B + (m + 1) * d2;
float *pb2 = B + (m + 2) * d2;
float *pb3 = B + (m + 3) * d2;

/* the four C rows produced by this row-block */
float *pc0 = C + (0 + row) * d2;
float *pc1 = C + (1 + row) * d2;
float *pc2 = C + (2 + row) * d2;
float *pc3 = C + (3 + row) * d2;

float *pd0 = bias + (0 + row) * d2;
float *pd1 = bias + (1 + row) * d2;
float *pd2 = bias + (2 + row) * d2;
float *pd3 = bias + (3 + row) * d2;

/* load the 4x4 A patch: aR = A[row+R, m..m+3] */
float32x4_t a0 = vld1q_f32(A + (row + 0) * d1 + m);
float32x4_t a1 = vld1q_f32(A + (row + 1) * d1 + m);
float32x4_t a2 = vld1q_f32(A + (row + 2) * d1 + m);
float32x4_t a3 = vld1q_f32(A + (row + 3) * d1 + m);

/* splat each of the 16 A scalars into its own vector (aRK) */
float32x4_t a00 = vdupq_n_f32(vgetq_lane_f32(a0, 0));
float32x4_t a01 = vdupq_n_f32(vgetq_lane_f32(a0, 1));
float32x4_t a02 = vdupq_n_f32(vgetq_lane_f32(a0, 2));
float32x4_t a03 = vdupq_n_f32(vgetq_lane_f32(a0, 3));

float32x4_t a10 = vdupq_n_f32(vgetq_lane_f32(a1, 0));
float32x4_t a11 = vdupq_n_f32(vgetq_lane_f32(a1, 1));
float32x4_t a12 = vdupq_n_f32(vgetq_lane_f32(a1, 2));
float32x4_t a13 = vdupq_n_f32(vgetq_lane_f32(a1, 3));

float32x4_t a20 = vdupq_n_f32(vgetq_lane_f32(a2, 0));
float32x4_t a21 = vdupq_n_f32(vgetq_lane_f32(a2, 1));
float32x4_t a22 = vdupq_n_f32(vgetq_lane_f32(a2, 2));
float32x4_t a23 = vdupq_n_f32(vgetq_lane_f32(a2, 3));

float32x4_t a30 = vdupq_n_f32(vgetq_lane_f32(a3, 0));
float32x4_t a31 = vdupq_n_f32(vgetq_lane_f32(a3, 1));
float32x4_t a32 = vdupq_n_f32(vgetq_lane_f32(a3, 2));
float32x4_t a33 = vdupq_n_f32(vgetq_lane_f32(a3, 3));

for (col = 0; col < d2; col+=4) {
float32x4_t c04 = vld1q_f32(pc0);
float32x4_t c14 = vld1q_f32(pc1);
float32x4_t c24 = vld1q_f32(pc2);
float32x4_t c34 = vld1q_f32(pc3);

/* 4 B segments shared by all 4 C rows of the tile */
float32x4_t b0 = vld1q_f32(pb0);
float32x4_t b1 = vld1q_f32(pb1);
float32x4_t b2 = vld1q_f32(pb2);
float32x4_t b3 = vld1q_f32(pb3);

/* cR4 += A[row+R, m+k] * B[m+k, col..col+3], k = 0..3 */
c04 = vaddq_f32(c04, vmulq_f32(a00, b0));
c04 = vaddq_f32(c04, vmulq_f32(a01, b1));
c04 = vaddq_f32(c04, vmulq_f32(a02, b2));
c04 = vaddq_f32(c04, vmulq_f32(a03, b3));

c14 = vaddq_f32(c14, vmulq_f32(a10, b0));
c14 = vaddq_f32(c14, vmulq_f32(a11, b1));
c14 = vaddq_f32(c14, vmulq_f32(a12, b2));
c14 = vaddq_f32(c14, vmulq_f32(a13, b3));

c24 = vaddq_f32(c24, vmulq_f32(a20, b0));
c24 = vaddq_f32(c24, vmulq_f32(a21, b1));
c24 = vaddq_f32(c24, vmulq_f32(a22, b2));
c24 = vaddq_f32(c24, vmulq_f32(a23, b3));

c34 = vaddq_f32(c34, vmulq_f32(a30, b0));
c34 = vaddq_f32(c34, vmulq_f32(a31, b1));
c34 = vaddq_f32(c34, vmulq_f32(a32, b2));
c34 = vaddq_f32(c34, vmulq_f32(a33, b3));

if (0 == m) {
/* fold the bias in exactly once, on the first m-block */
c04 = vaddq_f32(vld1q_f32(pd0), c04);
c14 = vaddq_f32(vld1q_f32(pd1), c14);
c24 = vaddq_f32(vld1q_f32(pd2), c24);
c34 = vaddq_f32(vld1q_f32(pd3), c34);
}

vst1q_f32(pc0, c04);
vst1q_f32(pc1, c14);
vst1q_f32(pc2, c24);
vst1q_f32(pc3, c34);

pb0 += 4;
pb1 += 4;
pb2 += 4;
pb3 += 4;

pc0 += 4;
pc1 += 4;
pc2 += 4;
pc3 += 4;

pd0 += 4;
pd1 += 4;
pd2 += 4;
pd3 += 4;
}
}
}
}
感谢上学期间打赏我的朋友们。赛博乞讨:我,秦始皇,打钱。

欢迎订阅我的文章