Small speedup in modular exponentiation routine (still not side-channel resilient)

git-svn-id: svn://kolibrios.org@9985 a494cfbc-eb01-0410-851d-a64ba20cac60
This commit is contained in:
hidnplayr 2024-03-05 19:57:16 +00:00
parent eb7e44a0e0
commit 199ad2d9a4

View File

@ -1,6 +1,6 @@
; mpint.inc - Multi precision integer procedures
;
; Copyright (C) 2015-2021 Jeffrey Amelynck
; Copyright (C) 2015-2024 Jeffrey Amelynck
;
; This program is free software: you can redistribute it and/or modify
; it under the terms of the GNU General Public License as published by
@ -905,7 +905,7 @@ proc mpint_grow uses eax edi ecx, dst, length ;//////////////////////////////;;
endp
;;===========================================================================;;
proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;///////////////////////;;
proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;/////////////////////;;
;;---------------------------------------------------------------------------;;
;? Multiply a little endian MPINT with another little endian MPINT and store ;;
;? in a third one. ;;
@ -937,34 +937,34 @@ endl
jnz .adjust_needed
.length_ok:
; Must have Asize >= Bsize.
; Must have a size >= b size.
cmp ebx, ecx
ja .swap_a_b
.conditions_ok:
; D size will be A size + B size
; dst size will be a size + b size
lea eax, [ebx + ecx]
cmp eax, MPINT_MAX_LEN
ja .ovf
; [Asize] = number of dwords in x
; [asize] = number of dwords in a
shr ecx, 2
jz .zero
mov [asize], ecx
; esi = a ptr
add esi, 4
; [Bsize] = number of dwords in y
; [bsize] = number of dwords in b
shr ebx, 2
jz .zero
mov [bsize], ebx
; edx = y ptr (temporarily)
; edx = b ptr (temporarily)
add edx, 4
; store D size
; store dst size
mov edi, [dst]
mov [edi], eax
; edi = D ptr
; edi = dst ptr
add edi, 4
; Use esp as frame pointer instead of ebp
@ -972,30 +972,30 @@ endl
mov [esp_], esp
mov esp, ebp
; ebp = B ptr
; ebp = b ptr
mov ebp, edx
; Do the first multiplication
mov eax, [esi] ; load A[0]
mul dword[ebp] ; multiply by B[0]
mov [edi], eax ; store to D[0]
; mov ecx, [Asize] ; Asize
dec ecx ; if Asize = 1, Bsize = 1 too
mov eax, [esi] ; load a[0]
mul dword[ebp] ; multiply by b[0]
mov [edi], eax ; store to dest[0]
; mov ecx, [asize] ; asize
dec ecx ; if asize = 1, bsize = 1 too
jz .done
; Prepare to enter loop1
mov eax, [asize-ebp+esp]
mov ebx, edx
lea esi, [esi + eax * 4] ; make A ptr point at end
lea edi, [edi + eax * 4] ; offset D ptr by Asize
lea esi, [esi + eax * 4] ; make a ptr point at end
lea edi, [edi + eax * 4] ; offset dst ptr by asize
neg ecx ; negate j size/index for inner loop
xor eax, eax ; clear carry
align 8
.loop1:
adc ebx, 0
mov eax, [esi + ecx * 4] ; load next dword at A[j]
mov eax, [esi + ecx * 4] ; load next dword at a[j]
mul dword[ebp]
add eax, ebx
mov [edi + ecx * 4], eax
@ -1009,10 +1009,10 @@ align 8
add edi, 4 ; increment dst
dec eax
jz .skip
mov [counter-ebp+esp], eax ; set index i to Bsize
mov [counter-ebp+esp], eax ; set index i to bsize
.outer:
add ebp, 4 ; make ebp point to next B dword
add ebp, 4 ; make ebp point to next b dword
mov ecx, [asize-ebp+esp]
neg ecx
xor ebx, ebx
@ -1047,7 +1047,7 @@ align 8
ret
.done:
mov [edi+4], edx ; store to D[1]
mov [edi+4], edx ; store to dst[1]
; restore esp, ebp
mov ebp, esp
mov esp, [esp_]
@ -1151,7 +1151,7 @@ endl
endp
;;===========================================================================;;
proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;;
proc mpint_modexp uses edi eax ebx ecx edx, dest, b, e, m ;//////////////////;;
;;---------------------------------------------------------------------------;;
;? Modular exponentiation: raise b to the power e, modulo m; store in dest. ;;
;;---------------------------------------------------------------------------;;
@ -1164,10 +1164,12 @@ proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;;
;;===========================================================================;;
locals
mpint_tmp rb MPINT_MAX_LEN+4
dst1 dd ?
dst2 dd ?
tmp rb MPINT_MAX_LEN+4
endl
DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dst], [b], [e], [m]
DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dest], [b], [e], [m]
; If mod is zero, return
stdcall mpint_bytes, [m]
@ -1185,6 +1187,12 @@ endl
mov edi, [e]
lea edi, [edi + 4 + ecx - 1]
; Set up temp variables
lea eax, [tmp]
mov edx, [dest]
mov [dst1], eax
mov [dst2], edx
; Find the highest order bit in this byte
mov al, [edi]
test al, al
@ -1195,28 +1203,32 @@ endl
shl al, 1
jnc @r
; Make pointer to tmp mpint for convenient access
lea edx, [mpint_tmp]
; Initialise result to base, to take care of the highest order bit
stdcall mpint_mov, [dst], [b]
stdcall mpint_mov, [dst1], [b]
dec bl
jz .next_byte
.bit_loop:
; For each bit, square result
stdcall mpint_mov, edx, [dst]
stdcall mpint_mul, [dst], edx, edx
stdcall mpint_mod, [dst], [m]
stdcall mpint_mul, [dst2], [dst1], [dst1]
stdcall mpint_mod, [dst2], [m]
; If the bit is set, multiply result by the base
shl al, 1
jnc .next_bit
stdcall mpint_mov, edx, [dst]
stdcall mpint_mul, [dst], [b], edx
stdcall mpint_mod, [dst], [m]
.next_bit:
jnc .bit_zero
stdcall mpint_mul, [dst1], [b], [dst2]
stdcall mpint_mod, [dst1], [m]
dec bl
jnz .bit_loop
jmp .next_byte
.bit_zero:
mov edx, [dst1]
mov esi, [dst2]
mov [dst2], edx
mov [dst1], esi
dec bl
jnz .bit_loop
.next_byte:
dec ecx
jz .done
@ -1225,19 +1237,25 @@ endl
mov bl, 8
jmp .bit_loop
.done:
mov edx, [dest]
cmp edx, [dst1]
je @f
stdcall mpint_mov, [dest], [dst1]
@@:
ret
.mod_zero:
DEBUGF 3, "modexp with modulo 0\n"
; if mod is zero, result = 0
mov eax, [dst]
mov eax, [dest]
mov dword[eax], 0
ret
.exp_zero:
DEBUGF 3, "modexp with exponent 0\n"
; if exponent is zero, result = 1
mov eax, [dst]
mov eax, [dest]
mov dword[eax], 1
mov byte[eax+4], 1
ret
@ -1254,3 +1272,4 @@ endl
endp