1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
/* SPDX-License-Identifier: GPL-2.0 */
/*
* NH - ε-almost-universal hash function, x86_64 AVX2 accelerated
*
* Copyright 2018 Google LLC
*
* Author: Eric Biggers <ebiggers@google.com>
*/
#include <linux/linkage.h>
/* Running 64-bit partial sums, one register per NH pass. */
#define		PASS0_SUMS	%ymm0
#define		PASS1_SUMS	%ymm1
#define		PASS2_SUMS	%ymm2
#define		PASS3_SUMS	%ymm3

/* Key registers, each with a name for its low 128-bit half. */
#define		K0		%ymm4
#define		K0_XMM		%xmm4
#define		K1		%ymm5
#define		K1_XMM		%xmm5
#define		K2		%ymm6
#define		K2_XMM		%xmm6
#define		K3		%ymm7
#define		K3_XMM		%xmm7

/* Temporaries; T2 and T3 also get names for their low 128-bit halves. */
#define		T0		%ymm8
#define		T1		%ymm9
#define		T2		%ymm10
#define		T2_XMM		%xmm10
#define		T3		%ymm11
#define		T3_XMM		%xmm11
#define		T4		%ymm12
#define		T5		%ymm13
#define		T6		%ymm14
#define		T7		%ymm15

/* General-purpose registers holding the four arguments of nh_avx2(). */
#define		KEY		%rdi
#define		MESSAGE		%rsi
#define		MESSAGE_LEN	%rdx
#define		HASH		%rcx
/*
 * _nh_2xstride - feed 32 bytes of message (two 16-byte strides) through all
 * four NH passes.
 *
 * On entry T3 holds the message data and the four macro arguments hold the
 * key words for passes 0, 1, 2 and 3 respectively.  For every pass the key
 * words are added to the message words as 32-bit lanes; then, inside each
 * 128-bit lane, word 0 is multiplied by word 2 and word 1 by word 3 as
 * unsigned 32x32 => 64-bit products, and the products are added to the
 * corresponding 64-bit lanes of that pass's PASSn_SUMS register.
 *
 * Clobbers T0-T7 (T3 included, so the message must be reloaded before the
 * next use).
 */
.macro	_nh_2xstride	pass0_keys, pass1_keys, pass2_keys, pass3_keys

	// Message words + key words, one temporary per pass.  The pass-3 add
	// has to stay last: it overwrites T3, which the others read.
	vpaddd		\pass0_keys, T3, T0
	vpaddd		\pass1_keys, T3, T1
	vpaddd		\pass2_keys, T3, T2
	vpaddd		\pass3_keys, T3, T3

	// Line up the multiplicands.  vpmuludq only reads the even 32-bit
	// slots of each lane, so $0x10 drops words 0 and 1 into those slots
	// and $0x32 drops words 2 and 3 into them.
	vpshufd		$0x10, T0, T4
	vpshufd		$0x32, T0, T0
	vpshufd		$0x10, T1, T5
	vpshufd		$0x32, T1, T1
	vpshufd		$0x10, T2, T6
	vpshufd		$0x32, T2, T2
	vpshufd		$0x10, T3, T7
	vpshufd		$0x32, T3, T3

	// 32x32 => 64-bit unsigned products
	vpmuludq	T4, T0, T0
	vpmuludq	T5, T1, T1
	vpmuludq	T6, T2, T2
	vpmuludq	T7, T3, T3

	// Fold the products into the per-pass running sums
	vpaddq		T0, PASS0_SUMS, PASS0_SUMS
	vpaddq		T1, PASS1_SUMS, PASS1_SUMS
	vpaddq		T2, PASS2_SUMS, PASS2_SUMS
	vpaddq		T3, PASS3_SUMS, PASS3_SUMS
.endm
/*
 * void nh_avx2(const u32 *key, const u8 *message, size_t message_len,
 *		u8 hash[NH_HASH_BYTES])
 *
 * It's guaranteed that message_len % 16 == 0.
 *
 * Arguments arrive in KEY, MESSAGE, MESSAGE_LEN and HASH.  The message is
 * consumed in 16-byte strides, two strides (one YMM register) at a time, and
 * four passes are computed in parallel; the key data used by each pass starts
 * 16 bytes after the key data used by the previous pass.  The result written
 * to 'hash' is 32 bytes: one 64-bit sum per pass.
 *
 * Clobbers KEY, MESSAGE, MESSAGE_LEN, flags and %ymm0-%ymm15.  The upper
 * halves of the YMM registers are cleared with vzeroupper before returning.
 */
SYM_FUNC_START(nh_avx2)

	// Preload the key words for passes 0 and 1 and zero the accumulators.
	vmovdqu		0x00(KEY), K0
	vmovdqu		0x10(KEY), K1
	add		$0x20, KEY
	vpxor		PASS0_SUMS, PASS0_SUMS, PASS0_SUMS
	vpxor		PASS1_SUMS, PASS1_SUMS, PASS1_SUMS
	vpxor		PASS2_SUMS, PASS2_SUMS, PASS2_SUMS
	vpxor		PASS3_SUMS, PASS3_SUMS, PASS3_SUMS

	// MESSAGE_LEN is kept biased by -0x40 while in the main loop, so the
	// sign of the subtraction result says whether 64 more bytes exist.
	sub		$0x40, MESSAGE_LEN
	jl		.Lloop4_done
.Lloop4:
	// Strides 0-1: K0/K1 are already loaded, fetch the next two keys.
	vmovdqu		(MESSAGE), T3
	vmovdqu		0x00(KEY), K2
	vmovdqu		0x10(KEY), K3
	_nh_2xstride	K0, K1, K2, K3

	// Strides 2-3: K2/K3 become the pass 0/1 keys; reload K0/K1 with the
	// next two, which also sets them up for the following iteration.
	vmovdqu		0x20(MESSAGE), T3
	vmovdqu		0x20(KEY), K0
	vmovdqu		0x30(KEY), K1
	_nh_2xstride	K2, K3, K0, K1

	add		$0x40, MESSAGE
	add		$0x40, KEY
	sub		$0x40, MESSAGE_LEN
	jge		.Lloop4

.Lloop4_done:
	// Drop the bias: the low six bits are the number of bytes still left.
	and		$0x3f, MESSAGE_LEN
	jz		.Ldone

	cmp		$0x20, MESSAGE_LEN
	jl		.Llast

	// 2 or 3 strides remain; do 2 more.
	vmovdqu		(MESSAGE), T3
	vmovdqu		0x00(KEY), K2
	vmovdqu		0x10(KEY), K3
	_nh_2xstride	K0, K1, K2, K3
	add		$0x20, MESSAGE
	add		$0x20, KEY
	sub		$0x20, MESSAGE_LEN
	jz		.Ldone
	// One stride is still left: move the keys into the registers that
	// .Llast expects for passes 0 and 1.
	vmovdqa		K2, K0
	vmovdqa		K3, K1
.Llast:
	// Last stride.  Zero the high 128 bits of the message and keys so they
	// don't affect the result when processing them like 2 strides.
	// (A VEX-encoded write to an XMM register clears the upper half of
	// the matching YMM register, so the self-moves below are not no-ops.)
	vmovdqu		(MESSAGE), T3_XMM
	vmovdqa		K0_XMM, K0_XMM
	vmovdqa		K1_XMM, K1_XMM
	vmovdqu		0x00(KEY), K2_XMM
	vmovdqu		0x10(KEY), K3_XMM
	_nh_2xstride	K0, K1, K2, K3

.Ldone:
	// Sum the accumulators for each pass, then store the sums to 'hash'

	// PASS0_SUMS is (0A 0B 0C 0D)
	// PASS1_SUMS is (1A 1B 1C 1D)
	// PASS2_SUMS is (2A 2B 2C 2D)
	// PASS3_SUMS is (3A 3B 3C 3D)
	// We need the horizontal sums:
	//     (0A + 0B + 0C + 0D,
	//      1A + 1B + 1C + 1D,
	//      2A + 2B + 2C + 2D,
	//      3A + 3B + 3C + 3D)
	//

	vpunpcklqdq	PASS1_SUMS, PASS0_SUMS, T0	// T0 = (0A 1A 0C 1C)
	vpunpckhqdq	PASS1_SUMS, PASS0_SUMS, T1	// T1 = (0B 1B 0D 1D)
	vpunpcklqdq	PASS3_SUMS, PASS2_SUMS, T2	// T2 = (2A 3A 2C 3C)
	vpunpckhqdq	PASS3_SUMS, PASS2_SUMS, T3	// T3 = (2B 3B 2D 3D)

	vinserti128	$0x1, T2_XMM, T0, T4		// T4 = (0A 1A 2A 3A)
	vinserti128	$0x1, T3_XMM, T1, T5		// T5 = (0B 1B 2B 3B)
	vperm2i128	$0x31, T2, T0, T0		// T0 = (0C 1C 2C 3C)
	vperm2i128	$0x31, T3, T1, T1		// T1 = (0D 1D 2D 3D)

	vpaddq		T5, T4, T4
	vpaddq		T1, T0, T0
	vpaddq		T4, T0, T0
	vmovdqu		T0, (HASH)

	// The upper halves of the YMM registers are dirty at this point.
	// Clear them so that legacy-SSE code run afterwards does not pay
	// AVX/SSE transition penalties.
	vzeroupper
	// NOTE(review): trees whose <linux/linkage.h> provides a RET macro
	// expect it here instead of a bare ret; confirm for the target tree.
	ret
SYM_FUNC_END(nh_avx2)
|