Small speedup in modular exponentation routine (still not side channel resiliant)

git-svn-id: svn://kolibrios.org@9985 a494cfbc-eb01-0410-851d-a64ba20cac60
This commit is contained in:
hidnplayr 2024-03-05 19:57:16 +00:00
parent eb7e44a0e0
commit 199ad2d9a4

View File

@ -1,6 +1,6 @@
; mpint.inc - Multi precision integer procedures ; mpint.inc - Multi precision integer procedures
; ;
; Copyright (C) 2015-2021 Jeffrey Amelynck ; Copyright (C) 2015-2024 Jeffrey Amelynck
; ;
; This program is free software: you can redistribute it and/or modify ; This program is free software: you can redistribute it and/or modify
; it under the terms of the GNU General Public License as published by ; it under the terms of the GNU General Public License as published by
@ -905,7 +905,7 @@ proc mpint_grow uses eax edi ecx, dst, length ;//////////////////////////////;;
endp endp
;;===========================================================================;; ;;===========================================================================;;
proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;///////////////////////;; proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;/////////////////////;;
;;---------------------------------------------------------------------------;; ;;---------------------------------------------------------------------------;;
;? Multiply a little endian MPINT with another little endian MPINT and store ;; ;? Multiply a little endian MPINT with another little endian MPINT and store ;;
;? in a third one. ;; ;? in a third one. ;;
@ -937,34 +937,34 @@ endl
jnz .adjust_needed jnz .adjust_needed
.length_ok: .length_ok:
; Must have Asize >= Bsize. ; Must have a size >= b size.
cmp ebx, ecx cmp ebx, ecx
ja .swap_a_b ja .swap_a_b
.conditions_ok: .conditions_ok:
; D size will be A size + B size ; dst size will be a size + b size
lea eax, [ebx + ecx] lea eax, [ebx + ecx]
cmp eax, MPINT_MAX_LEN cmp eax, MPINT_MAX_LEN
ja .ovf ja .ovf
; [Asize] = number of dwords in x ; [asize] = number of dwords in a
shr ecx, 2 shr ecx, 2
jz .zero jz .zero
mov [asize], ecx mov [asize], ecx
; esi = x ptr ; esi = x ptr
add esi, 4 add esi, 4
; [Bsize] = number of dwords in y ; [bsize] = number of dwords in b
shr ebx, 2 shr ebx, 2
jz .zero jz .zero
mov [bsize], ebx mov [bsize], ebx
; edx = y ptr (temporarily) ; edx = b ptr (temporarily)
add edx, 4 add edx, 4
; store D size ; store dst size
mov edi, [dst] mov edi, [dst]
mov [edi], eax mov [edi], eax
; edi = D ptr ; edi = dst ptr
add edi, 4 add edi, 4
; Use esp as frame pointer instead of ebp ; Use esp as frame pointer instead of ebp
@ -972,30 +972,30 @@ endl
mov [esp_], esp mov [esp_], esp
mov esp, ebp mov esp, ebp
; ebp = B ptr ; ebp = b ptr
mov ebp, edx mov ebp, edx
; Do the first multiplication ; Do the first multiplication
mov eax, [esi] ; load A[0] mov eax, [esi] ; load a[0]
mul dword[ebp] ; multiply by B[0] mul dword[ebp] ; multiply by b[0]
mov [edi], eax ; store to D[0] mov [edi], eax ; store to dest[0]
; mov ecx, [Asize] ; Asize ; mov ecx, [asize] ; asize
dec ecx ; if Asize = 1, Bsize = 1 too dec ecx ; if asize = 1, bsize = 1 too
jz .done jz .done
; Prepare to enter loop1 ; Prepare to enter loop1
mov eax, [asize-ebp+esp] mov eax, [asize-ebp+esp]
mov ebx, edx mov ebx, edx
lea esi, [esi + eax * 4] ; make A ptr point at end lea esi, [esi + eax * 4] ; make a ptr point at end
lea edi, [edi + eax * 4] ; offset D ptr by Asize lea edi, [edi + eax * 4] ; offset dst ptr by asize
neg ecx ; negate j size/index for inner loop neg ecx ; negate j size/index for inner loop
xor eax, eax ; clear carry xor eax, eax ; clear carry
align 8 align 8
.loop1: .loop1:
adc ebx, 0 adc ebx, 0
mov eax, [esi + ecx * 4] ; load next dword at A[j] mov eax, [esi + ecx * 4] ; load next dword at a[j]
mul dword[ebp] mul dword[ebp]
add eax, ebx add eax, ebx
mov [edi + ecx * 4], eax mov [edi + ecx * 4], eax
@ -1009,10 +1009,10 @@ align 8
add edi, 4 ; increment dst add edi, 4 ; increment dst
dec eax dec eax
jz .skip jz .skip
mov [counter-ebp+esp], eax ; set index i to Bsize mov [counter-ebp+esp], eax ; set index i to bsize
.outer: .outer:
add ebp, 4 ; make ebp point to next B dword add ebp, 4 ; make ebp point to next b dword
mov ecx, [asize-ebp+esp] mov ecx, [asize-ebp+esp]
neg ecx neg ecx
xor ebx, ebx xor ebx, ebx
@ -1047,7 +1047,7 @@ align 8
ret ret
.done: .done:
mov [edi+4], edx ; store to D[1] mov [edi+4], edx ; store to dst[1]
; restore esp, ebp ; restore esp, ebp
mov ebp, esp mov ebp, esp
mov esp, [esp_] mov esp, [esp_]
@ -1151,7 +1151,7 @@ endl
endp endp
;;===========================================================================;; ;;===========================================================================;;
proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;; proc mpint_modexp uses edi eax ebx ecx edx, dest, b, e, m ;//////////////////;;
;;---------------------------------------------------------------------------;; ;;---------------------------------------------------------------------------;;
;? Find the modulo (remainder after division) of dst by mod. ;; ;? Find the modulo (remainder after division) of dst by mod. ;;
;;---------------------------------------------------------------------------;; ;;---------------------------------------------------------------------------;;
@ -1164,10 +1164,12 @@ proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;;
;;===========================================================================;; ;;===========================================================================;;
locals locals
mpint_tmp rb MPINT_MAX_LEN+4 dst1 dd ?
dst2 dd ?
tmp rb MPINT_MAX_LEN+4
endl endl
DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dst], [b], [e], [m] DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dest], [b], [e], [m]
; If mod is zero, return ; If mod is zero, return
stdcall mpint_bytes, [m] stdcall mpint_bytes, [m]
@ -1185,6 +1187,12 @@ endl
mov edi, [e] mov edi, [e]
lea edi, [edi + 4 + ecx - 1] lea edi, [edi + 4 + ecx - 1]
; Set up temp variables
lea eax, [tmp]
mov edx, [dest]
mov [dst1], eax
mov [dst2], edx
; Find the highest order bit in this byte ; Find the highest order bit in this byte
mov al, [edi] mov al, [edi]
test al, al test al, al
@ -1195,28 +1203,32 @@ endl
shl al, 1 shl al, 1
jnc @r jnc @r
; Make pointer to tmp mpint for convenient access
lea edx, [mpint_tmp]
; Initialise result to base, to take care of the highest order bit ; Initialise result to base, to take care of the highest order bit
stdcall mpint_mov, [dst], [b] stdcall mpint_mov, [dst1], [b]
dec bl dec bl
jz .next_byte jz .next_byte
.bit_loop: .bit_loop:
; For each bit, square result ; For each bit, square result
stdcall mpint_mov, edx, [dst] stdcall mpint_mul, [dst2], [dst1], [dst1]
stdcall mpint_mul, [dst], edx, edx stdcall mpint_mod, [dst2], [m]
stdcall mpint_mod, [dst], [m]
; If the bit is set, multiply result by the base ; If the bit is set, multiply result by the base
shl al, 1 shl al, 1
jnc .next_bit jnc .bit_zero
stdcall mpint_mov, edx, [dst] stdcall mpint_mul, [dst1], [b], [dst2]
stdcall mpint_mul, [dst], [b], edx stdcall mpint_mod, [dst1], [m]
stdcall mpint_mod, [dst], [m]
.next_bit:
dec bl dec bl
jnz .bit_loop jnz .bit_loop
jmp .next_byte
.bit_zero:
mov edx, [dst1]
mov esi, [dst2]
mov [dst2], edx
mov [dst1], esi
dec bl
jnz .bit_loop
.next_byte: .next_byte:
dec ecx dec ecx
jz .done jz .done
@ -1225,19 +1237,25 @@ endl
mov bl, 8 mov bl, 8
jmp .bit_loop jmp .bit_loop
.done: .done:
mov edx, [dest]
cmp edx, [dst1]
je @f
stdcall mpint_mov, [dest], [dst1]
@@:
ret ret
.mod_zero: .mod_zero:
DEBUGF 3, "modexp with modulo 0\n" DEBUGF 3, "modexp with modulo 0\n"
; if mod is zero, result = 0 ; if mod is zero, result = 0
mov eax, [dst] mov eax, [dest]
mov dword[eax], 0 mov dword[eax], 0
ret ret
.exp_zero: .exp_zero:
DEBUGF 3, "modexp with exponent 0\n" DEBUGF 3, "modexp with exponent 0\n"
; if exponent is zero, result = 1 ; if exponent is zero, result = 1
mov eax, [dst] mov eax, [dest]
mov dword[eax], 1 mov dword[eax], 1
mov byte[eax+4], 1 mov byte[eax+4], 1
ret ret
@ -1254,3 +1272,4 @@ endl
endp endp