FROMLIST: crypto: x86/sha256-ni - add support for finup_mb

Add an implementation of finup_mb to sha256-ni, using an interleaving
factor of 2.  It interleaves a finup operation for two equal-length
messages that share a common prefix.  dm-verity and fs-verity will take
advantage of this for greatly improved performance on capable CPUs.

This increases the throughput of SHA-256 hashing 4096-byte messages by
the following amounts on the following CPUs:

    AMD Zen 1:                  84%
    AMD Zen 4:                  98%
    Intel Ice Lake:              4%
    Intel Sapphire Rapids:      20%

For now, this seems to benefit AMD much more than Intel.  This seems to
be because current AMD CPUs support concurrent execution of the SHA-NI
instructions, but unfortunately current Intel CPUs don't, except for the
sha256msg2 instruction.  Hopefully future Intel CPUs will support SHA-NI
on more execution ports.  Zen 1 supports 2 concurrent sha256rnds2, and
Zen 4 supports 4 concurrent sha256rnds2, which suggests that even better
performance may be achievable on Zen 4 by interleaving more than two
hashes; however, doing so poses a number of trade-offs.

It's been reported that the method that achieves the highest SHA-256
throughput on Intel CPUs is actually computing 16 hashes simultaneously
using AVX512.  That method would be quite different to the SHA-NI method
used in this patch.  However, such a high interleaving factor isn't
practical for the use cases being targeted in the kernel.

Reviewed-by: Sami Tolvanen <samitolvanen@google.com>
Acked-by: Ard Biesheuvel <ardb@kernel.org>
Signed-off-by: Eric Biggers <ebiggers@google.com>

Bug: 330611177
Link: https://lore.kernel.org/r/20240621165922.77672-5-ebiggers@kernel.org
Change-Id: I67204992677a80826c61e29ee3ca3c8be477d2f3
Signed-off-by: Eric Biggers <ebiggers@google.com>
This commit is contained in:
Eric Biggers 2024-06-21 10:37:49 -07:00 committed by William McVicker
parent a2372f602d
commit 16e22de481
2 changed files with 407 additions and 0 deletions

View File

@ -329,6 +329,374 @@ SYM_TYPED_FUNC_START(sha256_ni_transform)
RET
SYM_FUNC_END(sha256_ni_transform)
#undef DIGEST_PTR
#undef DATA_PTR
#undef NUM_BLKS
#undef SHA256CONSTANTS
#undef MSG
#undef STATE0
#undef STATE1
#undef MSG0
#undef MSG1
#undef MSG2
#undef MSG3
#undef TMP
#undef SHUF_MASK
#undef ABEF_SAVE
#undef CDGH_SAVE
// parameters for __sha256_ni_finup2x()
#define SCTX %rdi
#define DATA1 %rsi
#define DATA2 %rdx
#define LEN %ecx
#define LEN8 %cl
#define LEN64 %rcx
#define OUT1 %r8
#define OUT2 %r9
// other scalar variables
#define SHA256CONSTANTS %rax
#define COUNT %r10
#define COUNT32 %r10d
#define FINAL_STEP %r11d
// rbx is used as a temporary.
#define MSG %xmm0 // sha256rnds2 implicit operand
#define STATE0_A %xmm1
#define STATE1_A %xmm2
#define STATE0_B %xmm3
#define STATE1_B %xmm4
#define TMP_A %xmm5
#define TMP_B %xmm6
#define MSG0_A %xmm7
#define MSG1_A %xmm8
#define MSG2_A %xmm9
#define MSG3_A %xmm10
#define MSG0_B %xmm11
#define MSG1_B %xmm12
#define MSG2_B %xmm13
#define MSG3_B %xmm14
#define SHUF_MASK %xmm15
#define OFFSETOF_STATE 0 // offsetof(struct sha256_state, state)
#define OFFSETOF_COUNT 32 // offsetof(struct sha256_state, count)
#define OFFSETOF_BUF 40 // offsetof(struct sha256_state, buf)
// Do 4 rounds of SHA-256 for each of two messages (interleaved). m0_a and m0_b
// contain the current 4 message schedule words for the first and second message
// respectively.
//
// If not all the message schedule words have been computed yet, then this also
// computes 4 more message schedule words for each message. m1_a-m3_a contain
// the next 3 groups of 4 message schedule words for the first message, and
// likewise m1_b-m3_b for the second. After consuming the current value of
// m0_a, this macro computes the group after m3_a and writes it to m0_a, and
// likewise for *_b. This means that the next (m0_a, m1_a, m2_a, m3_a) is the
// current (m1_a, m2_a, m3_a, m0_a), and likewise for *_b, so the caller must
// cycle through the registers accordingly.
.macro do_4rounds_2x i, m0_a, m1_a, m2_a, m3_a, m0_b, m1_b, m2_b, m3_b
movdqa (\i-32)*4(SHA256CONSTANTS), TMP_A
movdqa TMP_A, TMP_B
paddd \m0_a, TMP_A
paddd \m0_b, TMP_B
.if \i < 48
sha256msg1 \m1_a, \m0_a
sha256msg1 \m1_b, \m0_b
.endif
movdqa TMP_A, MSG
sha256rnds2 STATE0_A, STATE1_A
movdqa TMP_B, MSG
sha256rnds2 STATE0_B, STATE1_B
pshufd $0x0E, TMP_A, MSG
sha256rnds2 STATE1_A, STATE0_A
pshufd $0x0E, TMP_B, MSG
sha256rnds2 STATE1_B, STATE0_B
.if \i < 48
movdqa \m3_a, TMP_A
movdqa \m3_b, TMP_B
palignr $4, \m2_a, TMP_A
palignr $4, \m2_b, TMP_B
paddd TMP_A, \m0_a
paddd TMP_B, \m0_b
sha256msg2 \m3_a, \m0_a
sha256msg2 \m3_b, \m0_b
.endif
.endm
//
// void __sha256_ni_finup2x(const struct sha256_state *sctx,
// const u8 *data1, const u8 *data2, int len,
// u8 out1[SHA256_DIGEST_SIZE],
// u8 out2[SHA256_DIGEST_SIZE]);
//
// This function computes the SHA-256 digests of two messages |data1| and
// |data2| that are both |len| bytes long, starting from the initial state
// |sctx|. |len| must be at least SHA256_BLOCK_SIZE.
//
// The instructions for the two SHA-256 operations are interleaved. On many
// CPUs, this is almost twice as fast as hashing each message individually due
// to taking better advantage of the CPU's SHA-256 and SIMD throughput.
//
SYM_FUNC_START(__sha256_ni_finup2x)
// Allocate 128 bytes of stack space, 16-byte aligned.
push %rbx
push %rbp
mov %rsp, %rbp
sub $128, %rsp
and $~15, %rsp
// Load the shuffle mask for swapping the endianness of 32-bit words.
movdqa PSHUFFLE_BYTE_FLIP_MASK(%rip), SHUF_MASK
// Set up pointer to the round constants.
lea K256+32*4(%rip), SHA256CONSTANTS
// Initially we're not processing the final blocks.
xor FINAL_STEP, FINAL_STEP
// Load the initial state from sctx->state.
movdqu OFFSETOF_STATE+0*16(SCTX), STATE0_A // DCBA
movdqu OFFSETOF_STATE+1*16(SCTX), STATE1_A // HGFE
movdqa STATE0_A, TMP_A
punpcklqdq STATE1_A, STATE0_A // FEBA
punpckhqdq TMP_A, STATE1_A // DCHG
pshufd $0x1B, STATE0_A, STATE0_A // ABEF
pshufd $0xB1, STATE1_A, STATE1_A // CDGH
// Load sctx->count. Take the mod 64 of it to get the number of bytes
// that are buffered in sctx->buf. Also save it in a register with LEN
// added to it.
mov LEN, LEN
mov OFFSETOF_COUNT(SCTX), %rbx
lea (%rbx, LEN64, 1), COUNT
and $63, %ebx
jz .Lfinup2x_enter_loop // No bytes buffered?
// %ebx bytes (1 to 63) are currently buffered in sctx->buf. Load them
// followed by the first 64 - %ebx bytes of data. Since LEN >= 64, we
// just load 64 bytes from each of sctx->buf, DATA1, and DATA2
// unconditionally and rearrange the data as needed.
movdqu OFFSETOF_BUF+0*16(SCTX), MSG0_A
movdqu OFFSETOF_BUF+1*16(SCTX), MSG1_A
movdqu OFFSETOF_BUF+2*16(SCTX), MSG2_A
movdqu OFFSETOF_BUF+3*16(SCTX), MSG3_A
movdqa MSG0_A, 0*16(%rsp)
movdqa MSG1_A, 1*16(%rsp)
movdqa MSG2_A, 2*16(%rsp)
movdqa MSG3_A, 3*16(%rsp)
movdqu 0*16(DATA1), MSG0_A
movdqu 1*16(DATA1), MSG1_A
movdqu 2*16(DATA1), MSG2_A
movdqu 3*16(DATA1), MSG3_A
movdqu MSG0_A, 0*16(%rsp,%rbx)
movdqu MSG1_A, 1*16(%rsp,%rbx)
movdqu MSG2_A, 2*16(%rsp,%rbx)
movdqu MSG3_A, 3*16(%rsp,%rbx)
movdqa 0*16(%rsp), MSG0_A
movdqa 1*16(%rsp), MSG1_A
movdqa 2*16(%rsp), MSG2_A
movdqa 3*16(%rsp), MSG3_A
movdqu 0*16(DATA2), MSG0_B
movdqu 1*16(DATA2), MSG1_B
movdqu 2*16(DATA2), MSG2_B
movdqu 3*16(DATA2), MSG3_B
movdqu MSG0_B, 0*16(%rsp,%rbx)
movdqu MSG1_B, 1*16(%rsp,%rbx)
movdqu MSG2_B, 2*16(%rsp,%rbx)
movdqu MSG3_B, 3*16(%rsp,%rbx)
movdqa 0*16(%rsp), MSG0_B
movdqa 1*16(%rsp), MSG1_B
movdqa 2*16(%rsp), MSG2_B
movdqa 3*16(%rsp), MSG3_B
sub $64, %rbx // rbx = buffered - 64
sub %rbx, DATA1 // DATA1 += 64 - buffered
sub %rbx, DATA2 // DATA2 += 64 - buffered
add %ebx, LEN // LEN += buffered - 64
movdqa STATE0_A, STATE0_B
movdqa STATE1_A, STATE1_B
jmp .Lfinup2x_loop_have_data
.Lfinup2x_enter_loop:
sub $64, LEN
movdqa STATE0_A, STATE0_B
movdqa STATE1_A, STATE1_B
.Lfinup2x_loop:
// Load the next two data blocks.
movdqu 0*16(DATA1), MSG0_A
movdqu 0*16(DATA2), MSG0_B
movdqu 1*16(DATA1), MSG1_A
movdqu 1*16(DATA2), MSG1_B
movdqu 2*16(DATA1), MSG2_A
movdqu 2*16(DATA2), MSG2_B
movdqu 3*16(DATA1), MSG3_A
movdqu 3*16(DATA2), MSG3_B
add $64, DATA1
add $64, DATA2
.Lfinup2x_loop_have_data:
// Convert the words of the data blocks from big endian.
pshufb SHUF_MASK, MSG0_A
pshufb SHUF_MASK, MSG0_B
pshufb SHUF_MASK, MSG1_A
pshufb SHUF_MASK, MSG1_B
pshufb SHUF_MASK, MSG2_A
pshufb SHUF_MASK, MSG2_B
pshufb SHUF_MASK, MSG3_A
pshufb SHUF_MASK, MSG3_B
.Lfinup2x_loop_have_bswapped_data:
// Save the original state for each block.
movdqa STATE0_A, 0*16(%rsp)
movdqa STATE0_B, 1*16(%rsp)
movdqa STATE1_A, 2*16(%rsp)
movdqa STATE1_B, 3*16(%rsp)
// Do the SHA-256 rounds on each block.
.irp i, 0, 16, 32, 48
do_4rounds_2x (\i + 0), MSG0_A, MSG1_A, MSG2_A, MSG3_A, \
MSG0_B, MSG1_B, MSG2_B, MSG3_B
do_4rounds_2x (\i + 4), MSG1_A, MSG2_A, MSG3_A, MSG0_A, \
MSG1_B, MSG2_B, MSG3_B, MSG0_B
do_4rounds_2x (\i + 8), MSG2_A, MSG3_A, MSG0_A, MSG1_A, \
MSG2_B, MSG3_B, MSG0_B, MSG1_B
do_4rounds_2x (\i + 12), MSG3_A, MSG0_A, MSG1_A, MSG2_A, \
MSG3_B, MSG0_B, MSG1_B, MSG2_B
.endr
// Add the original state for each block.
paddd 0*16(%rsp), STATE0_A
paddd 1*16(%rsp), STATE0_B
paddd 2*16(%rsp), STATE1_A
paddd 3*16(%rsp), STATE1_B
// Update LEN and loop back if more blocks remain.
sub $64, LEN
jge .Lfinup2x_loop
// Check if any final blocks need to be handled.
// FINAL_STEP = 2: all done
// FINAL_STEP = 1: need to do count-only padding block
// FINAL_STEP = 0: need to do the block with 0x80 padding byte
cmp $1, FINAL_STEP
jg .Lfinup2x_done
je .Lfinup2x_finalize_countonly
add $64, LEN
jz .Lfinup2x_finalize_blockaligned
// Not block-aligned; 1 <= LEN <= 63 data bytes remain. Pad the block.
// To do this, write the padding starting with the 0x80 byte to
// &sp[64]. Then for each message, copy the last 64 data bytes to sp
// and load from &sp[64 - LEN] to get the needed padding block. This
// code relies on the data buffers being >= 64 bytes in length.
mov $64, %ebx
sub LEN, %ebx // ebx = 64 - LEN
sub %rbx, DATA1 // DATA1 -= 64 - LEN
sub %rbx, DATA2 // DATA2 -= 64 - LEN
mov $0x80, FINAL_STEP // using FINAL_STEP as a temporary
movd FINAL_STEP, MSG0_A
pxor MSG1_A, MSG1_A
movdqa MSG0_A, 4*16(%rsp)
movdqa MSG1_A, 5*16(%rsp)
movdqa MSG1_A, 6*16(%rsp)
movdqa MSG1_A, 7*16(%rsp)
cmp $56, LEN
jge 1f // will COUNT spill into its own block?
shl $3, COUNT
bswap COUNT
mov COUNT, 56(%rsp,%rbx)
mov $2, FINAL_STEP // won't need count-only block
jmp 2f
1:
mov $1, FINAL_STEP // will need count-only block
2:
movdqu 0*16(DATA1), MSG0_A
movdqu 1*16(DATA1), MSG1_A
movdqu 2*16(DATA1), MSG2_A
movdqu 3*16(DATA1), MSG3_A
movdqa MSG0_A, 0*16(%rsp)
movdqa MSG1_A, 1*16(%rsp)
movdqa MSG2_A, 2*16(%rsp)
movdqa MSG3_A, 3*16(%rsp)
movdqu 0*16(%rsp,%rbx), MSG0_A
movdqu 1*16(%rsp,%rbx), MSG1_A
movdqu 2*16(%rsp,%rbx), MSG2_A
movdqu 3*16(%rsp,%rbx), MSG3_A
movdqu 0*16(DATA2), MSG0_B
movdqu 1*16(DATA2), MSG1_B
movdqu 2*16(DATA2), MSG2_B
movdqu 3*16(DATA2), MSG3_B
movdqa MSG0_B, 0*16(%rsp)
movdqa MSG1_B, 1*16(%rsp)
movdqa MSG2_B, 2*16(%rsp)
movdqa MSG3_B, 3*16(%rsp)
movdqu 0*16(%rsp,%rbx), MSG0_B
movdqu 1*16(%rsp,%rbx), MSG1_B
movdqu 2*16(%rsp,%rbx), MSG2_B
movdqu 3*16(%rsp,%rbx), MSG3_B
jmp .Lfinup2x_loop_have_data
// Prepare a padding block, either:
//
// {0x80, 0, 0, 0, ..., count (as __be64)}
// This is for a block aligned message.
//
// { 0, 0, 0, 0, ..., count (as __be64)}
// This is for a message whose length mod 64 is >= 56.
//
// Pre-swap the endianness of the words.
.Lfinup2x_finalize_countonly:
pxor MSG0_A, MSG0_A
jmp 1f
.Lfinup2x_finalize_blockaligned:
mov $0x80000000, %ebx
movd %ebx, MSG0_A
1:
pxor MSG1_A, MSG1_A
pxor MSG2_A, MSG2_A
ror $29, COUNT
movq COUNT, MSG3_A
pslldq $8, MSG3_A
movdqa MSG0_A, MSG0_B
pxor MSG1_B, MSG1_B
pxor MSG2_B, MSG2_B
movdqa MSG3_A, MSG3_B
mov $2, FINAL_STEP
jmp .Lfinup2x_loop_have_bswapped_data
.Lfinup2x_done:
// Write the two digests with all bytes in the correct order.
movdqa STATE0_A, TMP_A
movdqa STATE0_B, TMP_B
punpcklqdq STATE1_A, STATE0_A // GHEF
punpcklqdq STATE1_B, STATE0_B
punpckhqdq TMP_A, STATE1_A // ABCD
punpckhqdq TMP_B, STATE1_B
pshufd $0xB1, STATE0_A, STATE0_A // HGFE
pshufd $0xB1, STATE0_B, STATE0_B
pshufd $0x1B, STATE1_A, STATE1_A // DCBA
pshufd $0x1B, STATE1_B, STATE1_B
pshufb SHUF_MASK, STATE0_A
pshufb SHUF_MASK, STATE0_B
pshufb SHUF_MASK, STATE1_A
pshufb SHUF_MASK, STATE1_B
movdqu STATE0_A, 1*16(OUT1)
movdqu STATE0_B, 1*16(OUT2)
movdqu STATE1_A, 0*16(OUT1)
movdqu STATE1_B, 0*16(OUT2)
mov %rbp, %rsp
pop %rbp
pop %rbx
RET
SYM_FUNC_END(__sha256_ni_finup2x)
.section .rodata.cst256.K256, "aM", @progbits, 256
.align 64
K256:

View File

@ -330,6 +330,11 @@ static void unregister_sha256_avx2(void)
asmlinkage void sha256_ni_transform(struct sha256_state *digest,
const u8 *data, int rounds);
asmlinkage void __sha256_ni_finup2x(const struct sha256_state *sctx,
const u8 *data1, const u8 *data2, int len,
u8 out1[SHA256_DIGEST_SIZE],
u8 out2[SHA256_DIGEST_SIZE]);
static int sha256_ni_update(struct shash_desc *desc, const u8 *data,
unsigned int len)
{
@ -354,6 +359,38 @@ static int sha256_ni_digest(struct shash_desc *desc, const u8 *data,
sha256_ni_finup(desc, data, len, out);
}
static int sha256_ni_finup_mb(struct shash_desc *desc,
const u8 * const data[], unsigned int len,
u8 * const outs[], unsigned int num_msgs)
{
struct sha256_state *sctx = shash_desc_ctx(desc);
/*
* num_msgs != 2 should not happen here, since this algorithm sets
* mb_max_msgs=2, and the crypto API handles num_msgs <= 1 before
* calling into the algorithm's finup_mb method.
*/
if (WARN_ON_ONCE(num_msgs != 2))
return -EOPNOTSUPP;
if (unlikely(!crypto_simd_usable()))
return -EOPNOTSUPP;
/* __sha256_ni_finup2x() assumes SHA256_BLOCK_SIZE <= len <= INT_MAX. */
if (unlikely(len < SHA256_BLOCK_SIZE || len > INT_MAX))
return -EOPNOTSUPP;
/* __sha256_ni_finup2x() assumes the following offsets. */
BUILD_BUG_ON(offsetof(struct sha256_state, state) != 0);
BUILD_BUG_ON(offsetof(struct sha256_state, count) != 32);
BUILD_BUG_ON(offsetof(struct sha256_state, buf) != 40);
kernel_fpu_begin();
__sha256_ni_finup2x(sctx, data[0], data[1], len, outs[0], outs[1]);
kernel_fpu_end();
return 0;
}
static struct shash_alg sha256_ni_algs[] = { {
.digestsize = SHA256_DIGEST_SIZE,
.init = sha256_base_init,
@ -361,7 +398,9 @@ static struct shash_alg sha256_ni_algs[] = { {
.final = sha256_ni_final,
.finup = sha256_ni_finup,
.digest = sha256_ni_digest,
.finup_mb = sha256_ni_finup_mb,
.descsize = sizeof(struct sha256_state),
.mb_max_msgs = 2,
.base = {
.cra_name = "sha256",
.cra_driver_name = "sha256-ni",