From 199ad2d9a494fc6641250b28d66d314b632666c2 Mon Sep 17 00:00:00 2001 From: hidnplayr Date: Tue, 5 Mar 2024 19:57:16 +0000 Subject: [PATCH] Small speedup in modular exponentiation routine (still not side-channel resilient) git-svn-id: svn://kolibrios.org@9985 a494cfbc-eb01-0410-851d-a64ba20cac60 --- programs/network/ssh/mpint.inc | 95 ++++++++++++++++++++-------------- 1 file changed, 57 insertions(+), 38 deletions(-) diff --git a/programs/network/ssh/mpint.inc b/programs/network/ssh/mpint.inc index c56168094d..153aac4c34 100644 --- a/programs/network/ssh/mpint.inc +++ b/programs/network/ssh/mpint.inc @@ -1,6 +1,6 @@ ; mpint.inc - Multi precision integer procedures ; -; Copyright (C) 2015-2021 Jeffrey Amelynck +; Copyright (C) 2015-2024 Jeffrey Amelynck ; ; This program is free software: you can redistribute it and/or modify ; it under the terms of the GNU General Public License as published by @@ -905,7 +905,7 @@ proc mpint_grow uses eax edi ecx, dst, length ;//////////////////////////////;; endp ;;===========================================================================;; -proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;///////////////////////;; +proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;/////////////////////;; ;;---------------------------------------------------------------------------;; ;? Multiply a little endian MPINT with another little endian MPINT and store ;; ;? in a third one. ;; @@ -937,34 +937,34 @@ endl jnz .adjust_needed .length_ok: -; Must have Asize >= Bsize. +; Must have a size >= b size. 
cmp ebx, ecx ja .swap_a_b .conditions_ok: -; D size will be A size + B size +; dst size will be a size + b size lea eax, [ebx + ecx] cmp eax, MPINT_MAX_LEN ja .ovf -; [Asize] = number of dwords in x +; [asize] = number of dwords in a shr ecx, 2 jz .zero mov [asize], ecx ; esi = x ptr add esi, 4 -; [Bsize] = number of dwords in y +; [bsize] = number of dwords in b shr ebx, 2 jz .zero mov [bsize], ebx -; edx = y ptr (temporarily) +; edx = b ptr (temporarily) add edx, 4 -; store D size +; store dst size mov edi, [dst] mov [edi], eax -; edi = D ptr +; edi = dst ptr add edi, 4 ; Use esp as frame pointer instead of ebp @@ -972,30 +972,30 @@ endl mov [esp_], esp mov esp, ebp -; ebp = B ptr +; ebp = b ptr mov ebp, edx ; Do the first multiplication - mov eax, [esi] ; load A[0] - mul dword[ebp] ; multiply by B[0] - mov [edi], eax ; store to D[0] -; mov ecx, [Asize] ; Asize - dec ecx ; if Asize = 1, Bsize = 1 too + mov eax, [esi] ; load a[0] + mul dword[ebp] ; multiply by b[0] + mov [edi], eax ; store to dest[0] +; mov ecx, [asize] ; asize + dec ecx ; if asize = 1, bsize = 1 too jz .done ; Prepare to enter loop1 mov eax, [asize-ebp+esp] mov ebx, edx - lea esi, [esi + eax * 4] ; make A ptr point at end - lea edi, [edi + eax * 4] ; offset D ptr by Asize + lea esi, [esi + eax * 4] ; make a ptr point at end + lea edi, [edi + eax * 4] ; offset dst ptr by asize neg ecx ; negate j size/index for inner loop xor eax, eax ; clear carry align 8 .loop1: adc ebx, 0 - mov eax, [esi + ecx * 4] ; load next dword at A[j] + mov eax, [esi + ecx * 4] ; load next dword at a[j] mul dword[ebp] add eax, ebx mov [edi + ecx * 4], eax @@ -1009,10 +1009,10 @@ align 8 add edi, 4 ; increment dst dec eax jz .skip - mov [counter-ebp+esp], eax ; set index i to Bsize + mov [counter-ebp+esp], eax ; set index i to bsize .outer: - add ebp, 4 ; make ebp point to next B dword + add ebp, 4 ; make ebp point to next b dword mov ecx, [asize-ebp+esp] neg ecx xor ebx, ebx @@ -1047,7 +1047,7 @@ align 8 ret .done: - mov 
[edi+4], edx ; store to D[1] + mov [edi+4], edx ; store to dst[1] ; restore esp, ebp mov ebp, esp mov esp, [esp_] @@ -1151,7 +1151,7 @@ endl endp ;;===========================================================================;; -proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;; +proc mpint_modexp uses edi eax ebx ecx edx, dest, b, e, m ;//////////////////;; ;;---------------------------------------------------------------------------;; ;? Find the modulo (remainder after division) of dst by mod. ;; ;;---------------------------------------------------------------------------;; @@ -1164,10 +1164,12 @@ proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;; ;;===========================================================================;; locals - mpint_tmp rb MPINT_MAX_LEN+4 + dst1 dd ? + dst2 dd ? + tmp rb MPINT_MAX_LEN+4 endl - DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dst], [b], [e], [m] + DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dest], [b], [e], [m] ; If mod is zero, return stdcall mpint_bytes, [m] @@ -1185,6 +1187,12 @@ endl mov edi, [e] lea edi, [edi + 4 + ecx - 1] + ; Set up temp variables + lea eax, [tmp] + mov edx, [dest] + mov [dst1], eax + mov [dst2], edx + ; Find the highest order bit in this byte mov al, [edi] test al, al @@ -1195,28 +1203,32 @@ endl shl al, 1 jnc @r - ; Make pointer to tmp mpint for convenient access - lea edx, [mpint_tmp] - ; Initialise result to base, to take care of the highest order bit - stdcall mpint_mov, [dst], [b] + stdcall mpint_mov, [dst1], [b] dec bl jz .next_byte .bit_loop: ; For each bit, square result - stdcall mpint_mov, edx, [dst] - stdcall mpint_mul, [dst], edx, edx - stdcall mpint_mod, [dst], [m] + stdcall mpint_mul, [dst2], [dst1], [dst1] + stdcall mpint_mod, [dst2], [m] ; If the bit is set, multiply result by the base shl al, 1 - jnc .next_bit - stdcall mpint_mov, edx, [dst] - stdcall mpint_mul, [dst], [b], edx - stdcall mpint_mod, [dst], [m] - 
.next_bit: + jnc .bit_zero + stdcall mpint_mul, [dst1], [b], [dst2] + stdcall mpint_mod, [dst1], [m] dec bl jnz .bit_loop + jmp .next_byte + + .bit_zero: + mov edx, [dst1] + mov esi, [dst2] + mov [dst2], edx + mov [dst1], esi + dec bl + jnz .bit_loop + .next_byte: dec ecx jz .done @@ -1225,19 +1237,25 @@ endl mov bl, 8 jmp .bit_loop .done: + mov edx, [dest] + cmp edx, [dst1] + je @f + stdcall mpint_mov, [dest], [dst1] + @@: + ret .mod_zero: DEBUGF 3, "modexp with modulo 0\n" ; if mod is zero, result = 0 - mov eax, [dst] + mov eax, [dest] mov dword[eax], 0 ret .exp_zero: DEBUGF 3, "modexp with exponent 0\n" ; if exponent is zero, result = 1 - mov eax, [dst] + mov eax, [dest] mov dword[eax], 1 mov byte[eax+4], 1 ret @@ -1254,3 +1272,4 @@ endl endp +