From ffdfe97ebd6a401c7f21dcc02f991213fc79fcbb Mon Sep 17 00:00:00 2001 From: hidnplayr Date: Mon, 26 Jul 2021 09:44:08 +0000 Subject: [PATCH] MPINT: Less fiddling with bits and bytes, more performance. git-svn-id: svn://kolibrios.org@9090 a494cfbc-eb01-0410-851d-a64ba20cac60 --- programs/network/ssh/mpint.inc | 766 +++++++++++++++++++-------- programs/network/ssh/test/modexp.asm | 6 +- programs/network/ssh/test/mpint.asm | 8 +- 3 files changed, 560 insertions(+), 220 deletions(-) diff --git a/programs/network/ssh/mpint.inc b/programs/network/ssh/mpint.inc index c69c6ab0db..c56168094d 100644 --- a/programs/network/ssh/mpint.inc +++ b/programs/network/ssh/mpint.inc @@ -15,9 +15,9 @@ ; You should have received a copy of the GNU General Public License ; along with this program. If not, see . -; Notes: +; Note: ; -; These procedures work only with positive integers. +; These procedures have been designed to work with unsigned integers. ; For compatibility reasons, the highest bit must always be 0. ; ; You have been warned! @@ -181,7 +181,6 @@ proc mpint_bits uses esi ecx, dst ;//////////////////////////////////////////;; endp - ;;===========================================================================;; proc mpint_bytes uses esi, dst ;/////////////////////////////////////////////;; ;;---------------------------------------------------------------------------;; @@ -212,7 +211,7 @@ proc mpint_bytes uses esi, dst ;/////////////////////////////////////////////;; endp ;;===========================================================================;; -proc mpint_cmp uses esi edi ecx eax, src, dst ;//////////////////////////////;; +proc mpint_cmp uses esi edi edx ecx ebx eax, src, dst ;//////////////////////;; ;;---------------------------------------------------------------------------;; ;? Compare two MPINTS. ;; ;;---------------------------------------------------------------------------;; @@ -238,36 +237,44 @@ proc mpint_cmp uses esi edi ecx eax, src, dst ;//////////////////////////////;; ; Start comparing from the MSB towards the LSB mov esi, [src] mov edi, [dst] - add esi, ecx - add edi, ecx - add esi, 4 - add edi, 4 - std + lea esi, [esi + 4 + ecx] + lea edi, [edi + 4 + ecx] ; If remaining bytes is not divisible by 4, compare only one byte at a time - .do_byte: - test ecx, 1b - jz .do_dword + .loop_1: + test ecx, 111b + jz .done_1 dec esi dec edi mov al, byte[esi] cmp al, byte[edi] jne .got_answer dec ecx -; Remaining bytes is divisable by 4, compare dwords - .do_dword: - shr ecx, 2 + jmp .loop_1 +; Remaining bytes is divisable by 8, compare dwords + .done_1: + shr ecx, 3 jz .got_answer - sub esi, 4 - sub edi, 4 - repe cmpsd + +align 8 + .loop_8: + lea esi, [esi-8] + lea edi, [edi-8] + mov eax, [esi+04] + mov ebx, [esi+00] + cmp eax, [edi+04] + jne .got_answer + cmp ebx, [edi+00] + jne .got_answer + dec ecx + jnz .loop_8 + .got_answer: - cld ret endp ;;===========================================================================;; -proc mpint_mov uses esi edi ecx, dst, src ;//////////////////////////////////;; +proc mpint_mov uses esi edi edx ecx ebx eax, dst, src ;//////////////////////;; ;;---------------------------------------------------------------------------;; ;? Copy MPINT. ;; ;;---------------------------------------------------------------------------;; @@ -281,23 +288,84 @@ proc mpint_mov uses esi edi ecx, dst, src ;//////////////////////////////////;; mov esi, [src] mov edi, [dst] - mov ecx, [esi] - push ecx - shr ecx, 2 - inc ecx ; for length dword - rep movsd - pop ecx - and ecx, 11b - jz @f - rep movsb - @@: + mov ecx, [esi] ; Get dword count + 1 + add ecx, 7 ; + shr ecx, 2 ; + mov edx, ecx + + shr ecx, 3 + test ecx, ecx + jz .no32 +align 8 + .loop32: + mov eax, [esi+00] + mov ebx, [esi+04] + mov [edi+00], eax + mov [edi+04], ebx + + mov eax, [esi+08] + mov ebx, [esi+12] + mov [edi+08], eax + mov [edi+12], ebx + + mov eax, [esi+16] + mov ebx, [esi+20] + mov [edi+16], eax + mov [edi+20], ebx + + mov eax, [esi+24] + mov ebx, [esi+28] + mov [edi+24], eax + mov [edi+28], ebx + + lea esi, [esi+32] + lea edi, [edi+32] + dec ecx + jnz .loop32 + .no32: + + test edx, 100b + jz .no16 + + mov eax, [esi+00] + mov ebx, [esi+04] + mov [edi+00], eax + mov [edi+04], ebx + + mov eax, [esi+08] + mov ebx, [esi+12] + mov [edi+08], eax + mov [edi+12], ebx + + lea esi, [esi+16] + lea edi, [edi+16] + .no16: + + test edx, 010b + jz .no8 + + mov eax, [esi+00] + mov ebx, [esi+04] + mov [edi+00], eax + mov [edi+04], ebx + + lea esi, [esi+08] + lea edi, [edi+08] + .no8: + + test edx, 001b + jz .no4 + + mov eax, [esi+00] + mov [edi+00], eax + .no4: ret endp ;;===========================================================================;; -proc mpint_shl1 uses esi ecx, dst ;//////////////////////////////////////////;; +proc mpint_shl1 uses edi ecx, dst ;//////////////////////////////////////////;; ;;---------------------------------------------------------------------------;; ;? Shift little endian MPINT one bit to the left. ;; ;;---------------------------------------------------------------------------;; @@ -308,35 +376,59 @@ proc mpint_shl1 uses esi ecx, dst ;//////////////////////////////////////////;; DEBUGF 1, "mpint_shl1(0x%x)\n", [dst] - mov esi, [dst] - mov ecx, [esi] - test ecx, ecx + mov edi, [dst] + mov ecx, [edi] + test ecx, 11b + jnz .adjust_needed + shr ecx, 2 jz .done -; Test if high order byte will overflow -; Remember: highest bit must never be set for positive numbers! - test byte[esi+ecx+3], 11000000b - jz @f -; We must grow a byte in size! -; TODO: check for overflow - inc ecx - mov [esi], ecx - mov byte[esi+ecx+3], 0 ; Add the new MSB - @@: - add esi, 4 -; Do the lowest order byte first - shl byte[esi], 1 + .length_ok: + add edi, 4 +; Do the lowest order dword first + shl dword[edi], 1 + lea edi, [edi+4] dec ecx jz .done -; And the remaining bytes - @@: - inc esi - rcl byte[esi], 1 +; And the remaining dwords + .loop: + rcl dword[edi], 1 + lea edi, [edi+4] dec ecx - jnz @r + jnz .loop .done: + jc .carry + test dword[edi-4], 0x10000000 + jnz .add0 ret + .carry: + mov ecx, [dst] + cmp dword[ecx], MPINT_MAX_LEN + je .ovf + add dword[ecx], 4 + mov dword[edi], 1 + ret + + .add0: + mov ecx, [dst] + cmp dword[ecx], MPINT_MAX_LEN + je .ovf + add dword[ecx], 4 + mov dword[edi], 0 + ret + + .ovf: + int3 +;;;; + ret + + .adjust_needed: + add ecx, 3 + and ecx, not 3 + stdcall mpint_grow, edi, ecx + jmp .length_ok + endp ;;===========================================================================;; @@ -353,24 +445,33 @@ proc mpint_shr1 uses edi ecx, dst ;//////////////////////////////////////////;; mov edi, [dst] mov ecx, [edi] - test ecx, ecx - jz .done + test ecx, 11b + jnz .adjust_needed -; Do the highest order byte first - add edi, 4-1 - add edi, ecx - shr byte[edi], 1 + .length_ok: + shr ecx, 2 + jz .done + lea edi, [edi+ecx*4] +; Do the highest order dword first + shr dword[edi], 1 + lea edi, [edi-4] dec ecx jz .done -; Now do the trailing bytes - @@: - dec edi - rcr byte[edi], 1 - dec ecx ; does not affect carry flag, hooray! - jnz @r +; And the remaining dwords + .loop: + rcr dword[edi], 1 + lea edi, [edi-4] + dec ecx + jnz .loop .done: ret + .adjust_needed: + add ecx, 3 + and ecx, not 3 + stdcall mpint_grow, edi, ecx + jmp .length_ok + endp ;;===========================================================================;; @@ -465,7 +566,7 @@ proc mpint_shlmov uses eax ebx ecx edx esi edi, dst, src, shift ;////////////;; mov [edi], eax cmp eax, MPINT_MAX_LEN - jae .overflow ;;;; + jae .overflow mov esi, [src] add esi, MPINT_MAX_LEN+4-4 @@ -514,7 +615,7 @@ proc mpint_shlmov uses eax ebx ecx edx esi edi, dst, src, shift ;////////////;; endp ;;===========================================================================;; -proc mpint_add uses esi edi ecx eax, dst, src ;//////////////////////////////;; +proc mpint_add uses esi edi edx ecx ebx eax, dst, src ;//////////////////////;; ;;---------------------------------------------------------------------------;; ;? Add a little endian MPINT to another little endian MPINT. ;; ;;---------------------------------------------------------------------------;; @@ -524,78 +625,131 @@ proc mpint_add uses esi edi ecx eax, dst, src ;//////////////////////////////;; ;< dst = dst + src ;; ;;===========================================================================;; +locals + dd_cnt dd ? +endl + DEBUGF 1, "mpint_add(0x%x, 0x%x)\n", [dst], [src] +; Grow both numbers to same 4-byte boundary, if not already the case mov esi, [src] mov edi, [dst] - stdcall mpint_bytes, esi - mov ecx, eax - stdcall mpint_bytes, edi - cmp ecx, eax - jb .grow_src - ja .grow_dst - test ecx, ecx - jz .done + mov ecx, [esi] + test ecx, 11b + jnz .adjust_needed + cmp ecx, [edi] + jne .adjust_needed +; Do the additions .length_ok: - push ecx add esi, 4 add edi, 4 -; Add the first byte - lodsb - add byte[edi], al - dec ecx + shr ecx, 2 jz .done -; Add the other bytes - @@: - inc edi - lodsb - adc byte[edi], al + mov eax, ecx + and eax, 111b + mov [dd_cnt], eax + shr ecx, 3 + test ecx, ecx ; Clear carry flag + jz .no32 + + .loop32: + mov eax, [esi+00] + mov edx, [esi+04] + adc [edi+00], eax + adc [edi+04], edx + + mov eax, [esi+08] + mov edx, [esi+12] + adc [edi+08], eax + adc [edi+12], edx + + mov eax, [esi+16] + mov edx, [esi+20] + adc [edi+16], eax + adc [edi+20], edx + + mov eax, [esi+24] + mov edx, [esi+28] + adc [edi+24], eax + adc [edi+28], edx + + lea esi, [esi + 32] + lea edi, [edi + 32] dec ecx - jnz @r + jnz .loop32 + + .no32: + mov ecx, [dd_cnt] + dec ecx + js .check_ovf + inc ecx + .dword_loop: + mov eax, [esi+0] + adc [edi+0], eax + lea esi, [esi + 4] + lea edi, [edi + 4] + dec ecx + jnz .dword_loop + + .check_ovf: + jc .add_1 ; Carry + test byte[edi-1], 0x80 + jnz .add_0 ; Currently highest bit set .done: - -; check if highest bit OR carry flag is set -; if so, add a byte if we have the buffer space -; TODO: check if we have the buffer space - pop ecx - jc .carry - cmp byte[edi], 0x80 - jnz .high_bit_set - ret - .carry: - mov eax, [dst] - cmp [eax], ecx - ja @f - inc dword[eax] - @@: - mov byte[edi+1], 1 +; Highest bit was set, add a 0 byte as MSB if possible + .add_0: + mov ecx, [dst] + cmp dword[ecx], MPINT_MAX_LEN + jae .ovf_0 + mov byte[edi], 0 + inc dword[ecx] ret - .high_bit_set: - mov eax, [dst] - cmp [eax], ecx - ja @f - inc dword[eax] - @@: - mov byte[edi+1], 0 + .ovf_0: + int3 + clc +; TODO: set overflow flag? ret - .grow_dst: - stdcall mpint_grow, edi, ecx - jmp .length_ok +; Carry bit was set, add a 1 byte as MSB if possible + .add_1: + mov ecx, [dst] + cmp dword[ecx], MPINT_MAX_LEN + jae .ovf_1 + mov byte[edi], 1 + inc dword[ecx] + ret - .grow_src: - mov ecx, eax + .ovf_1: + int3 + stc +; TODO: set overflow flag? + ret + + .adjust_needed: +; mov ecx, [esi] + mov eax, [edi] +; find the maximum of the two in ecx + mov edx, ecx + sub edx, eax + sbb ebx, ebx + and ebx, edx + sub ecx, ebx +; align to 4 byte boundary + add ecx, 3 + and ecx, not 3 +; adjust both mpints stdcall mpint_grow, esi, ecx + stdcall mpint_grow, edi, ecx jmp .length_ok endp ;;===========================================================================;; -proc mpint_sub uses eax esi edi ecx, dst, src ;//////////////////////////////;; +proc mpint_sub uses esi edi edx ecx ebx eax, dst, src ;//////////////////////;; ;;---------------------------------------------------------------------------;; ;? Subtract a little endian MPINT to another little endian MPINT. ;; ;;---------------------------------------------------------------------------;; @@ -605,49 +759,91 @@ proc mpint_sub uses eax esi edi ecx, dst, src ;//////////////////////////////;; ;< dst = dst - src ;; ;;===========================================================================;; +locals + dd_cnt dd ? +endl + DEBUGF 1, "mpint_sub(0x%x, 0x%x)\n", [dst], [src] +; Grow both numbers to same 4-byte boundary, if not already the case mov esi, [src] mov edi, [dst] - stdcall mpint_bytes, esi - mov ecx, eax - stdcall mpint_bytes, edi - cmp ecx, eax - jb .grow_src - ja .grow_dst - test ecx, ecx - jz .done + mov ecx, [esi] + test ecx, 11b + jnz .adjust_needed + cmp ecx, [edi] + jne .adjust_needed +; Do the subtractions .length_ok: add esi, 4 add edi, 4 -; Subtract the first byte - lodsb - sub byte[edi], al - dec ecx + shr ecx, 2 jz .done -; Subtract the other bytes - @@: - inc edi - lodsb - sbb byte[edi], al + mov eax, ecx + and eax, 111b + mov [dd_cnt], eax + shr ecx, 3 + test ecx, ecx ; Clear carry flag + jz .no32 + + .loop32: + mov eax, [esi+00] + mov edx, [esi+04] + sbb [edi+00], eax + sbb [edi+04], edx + + mov eax, [esi+08] + mov edx, [esi+12] + sbb [edi+08], eax + sbb [edi+12], edx + + mov eax, [esi+16] + mov edx, [esi+20] + sbb [edi+16], eax + sbb [edi+20], edx + + mov eax, [esi+24] + mov edx, [esi+28] + sbb [edi+24], eax + sbb [edi+28], edx + + lea esi, [esi + 32] + lea edi, [edi + 32] dec ecx - jnz @r + jnz .loop32 + + .no32: + mov ecx, [dd_cnt] + dec ecx + js .done + inc ecx + .dword_loop: + mov eax, [esi+0] + sbb [edi+0], eax + lea esi, [esi + 4] + lea edi, [edi + 4] + dec ecx + jnz .dword_loop + .done: ret - .overflow: - mov dword[edi], 0 - stc - ret - - .grow_dst: - stdcall mpint_grow, edi, ecx - jmp .length_ok - - .grow_src: - mov ecx, eax + .adjust_needed: +; mov ecx, [esi] + mov eax, [edi] +; find the maximum of the two in ecx + mov edx, ecx + sub edx, eax + sbb ebx, ebx + and ebx, edx + sub ecx, ebx +; align to 4 byte boundary + add ecx, 3 + and ecx, not 3 +; adjust both mpints stdcall mpint_grow, esi, ecx + stdcall mpint_grow, edi, ecx jmp .length_ok endp @@ -665,19 +861,6 @@ proc mpint_shrink uses eax edi, dst ;////////////////////////////////////////;; DEBUGF 1, "mpint_shrink(0x%x)\n", [dst] -; mov edi, [dst] -; lodsd -; std -; mov ecx, eax -; dec eax ; total length minus one -; add edi, eax -; xor al, al -; repe cmpsb -; inc ecx -; mov edi, [dst] -; mov [edi], ecx -; cld - stdcall mpint_bits, [dst] shr eax, 3 inc eax @@ -688,7 +871,6 @@ proc mpint_shrink uses eax edi, dst ;////////////////////////////////////////;; endp - ;;===========================================================================;; proc mpint_grow uses eax edi ecx, dst, length ;//////////////////////////////;; ;;---------------------------------------------------------------------------;; @@ -723,88 +905,224 @@ proc mpint_grow uses eax edi ecx, dst, length ;//////////////////////////////;; endp ;;===========================================================================;; -proc mpint_mul uses esi edi ecx ebx eax, dst, A, B ;/////////////////////////;; +proc mpint_mul uses eax ebx ecx edx esi edi, dst, a, b ;///////////////////////;; ;;---------------------------------------------------------------------------;; -;? Multiply two little endian MPINTS and store them in a third one. ;; +;? Multiply a little endian MPINT with another little endian MPINT and store ;; +;? in a third one. ;; ;;---------------------------------------------------------------------------;; -;> A = pointer to little endian MPINT ;; -;> B = pointer to little endian MPINT ;; -;> dst = pointer to buffer for little endian MPINT ;; +;> dst = pointer to little endian MPINT ;; +;> a = pointer to little endian MPINT ;; +;> b = pointer to little endian MPINT ;; ;;---------------------------------------------------------------------------;; -;< dst = A * B ;; +;< dst = a * b ;; ;;===========================================================================;; - DEBUGF 1, "mpint_mul(0x%x, 0x%x, 0x%x)\n", [dst], [A], [B] +locals + asize dd ? + bsize dd ? + counter dd ? + esp_ dd ? +endl - ; Set result to zero + DEBUGF 1, "mpint_mul(0x%x, 0x%x, 0x%x)\n", [dst], [a], [b] + +; Grow both numbers to individual 4-byte boundary, if not already the case + mov esi, [a] + mov edx, [b] + mov ecx, [esi] + mov ebx, [edx] + test ecx, 11b + jnz .adjust_needed + test ebx, 11b + jnz .adjust_needed + .length_ok: + +; Must have Asize >= Bsize. + cmp ebx, ecx + ja .swap_a_b + .conditions_ok: + +; D size will be A size + B size + lea eax, [ebx + ecx] + cmp eax, MPINT_MAX_LEN + ja .ovf + +; [Asize] = number of dwords in x + shr ecx, 2 + jz .zero + mov [asize], ecx +; esi = x ptr + add esi, 4 + +; [Bsize] = number of dwords in y + shr ebx, 2 + jz .zero + mov [bsize], ebx +; edx = y ptr (temporarily) + add edx, 4 + +; store D size + mov edi, [dst] + mov [edi], eax +; edi = D ptr + add edi, 4 + +; Use esp as frame pointer instead of ebp +; ! Use the stack with extreme caution from here on ! + mov [esp_], esp + mov esp, ebp + +; ebp = B ptr + mov ebp, edx + +; Do the first multiplication + mov eax, [esi] ; load A[0] + mul dword[ebp] ; multiply by B[0] + mov [edi], eax ; store to D[0] +; mov ecx, [Asize] ; Asize + dec ecx ; if Asize = 1, Bsize = 1 too + jz .done + +; Prepare to enter loop1 + mov eax, [asize-ebp+esp] + + mov ebx, edx + lea esi, [esi + eax * 4] ; make A ptr point at end + lea edi, [edi + eax * 4] ; offset D ptr by Asize + neg ecx ; negate j size/index for inner loop + xor eax, eax ; clear carry + +align 8 + .loop1: + adc ebx, 0 + mov eax, [esi + ecx * 4] ; load next dword at A[j] + mul dword[ebp] + add eax, ebx + mov [edi + ecx * 4], eax + inc ecx + mov ebx, edx + jnz .loop1 + + adc ebx, 0 + mov eax, [bsize-ebp+esp] + mov [edi], ebx ; most significant dword of the product + add edi, 4 ; increment dst + dec eax + jz .skip + mov [counter-ebp+esp], eax ; set index i to Bsize + + .outer: + add ebp, 4 ; make ebp point to next B dword + mov ecx, [asize-ebp+esp] + neg ecx + xor ebx, ebx + + .loop2: + adc ebx, 0 + mov eax, [esi + ecx * 4] + mul dword[ebp] + add eax, ebx + mov ebx, [edi + ecx * 4] + adc edx, 0 + add ebx, eax + mov [edi + ecx * 4], ebx + inc ecx + mov ebx, edx + jnz .loop2 + + adc ebx, 0 + + mov [edi], ebx + add edi, 4 + mov eax, [counter-ebp+esp] + dec eax + mov [counter-ebp+esp], eax + jnz .outer + + .skip: +; restore esp, ebp + mov ebp, esp + mov esp, [esp_] + + ret + + .done: + mov [edi+4], edx ; store to D[1] +; restore esp, ebp + mov ebp, esp + mov esp, [esp_] + + ret + + .ovf: + int3 + + .zero: mov eax, [dst] mov dword[eax], 0 - mov edi, [A] - stdcall mpint_bytes, edi - test eax, eax - jz .zero - add edi, 4-1 - add edi, eax - mov ecx, eax -; Iterate through the bits in A, -; starting from the highest order bit down to the lowest order bit. - .next_byte: - mov al, [edi] - dec edi - mov bl, 8 - .next_bit: - stdcall mpint_shl1, [dst] - shl al, 1 - jnc .zero_bit - stdcall mpint_add, [dst], [B] - .zero_bit: - dec bl - jnz .next_bit - dec ecx - jnz .next_byte - .zero: ret + .adjust_needed: +; align to 4 byte boundary + add ecx, 3 + and ecx, not 3 + add ebx, 3 + and ebx, not 3 +; adjust both mpints + stdcall mpint_grow, esi, ecx + stdcall mpint_grow, edx, ebx + jmp .length_ok + + .swap_a_b: + mov eax, esi + mov esi, edx + mov edx, eax + + mov eax, ebx + mov ebx, ecx + mov ecx, eax + jmp .conditions_ok + endp ;;===========================================================================;; -proc mpint_mod uses eax ebx ecx, dst, mod ;//////////////////////////////////;; +proc mpint_mod uses eax ebx ecx, dst, m ;////////////////////////////////////;; ;;---------------------------------------------------------------------------;; ;? Find the modulo (remainder after division) of dst by mod. ;; ;;---------------------------------------------------------------------------;; ;> dst = pointer to little endian MPINT ;; ;> mod = pointer to little endian MPINT ;; ;;---------------------------------------------------------------------------;; -;< dst = dst MOD mod ;; +;< dst = dst MOD m ;; ;;===========================================================================;; - DEBUGF 1, "mpint_mod(0x%x, 0x%x)\n", [dst], [mod] - locals mpint_tmp rb MPINT_MAX_LEN+4 endl - stdcall mpint_cmp, [mod], [dst] - ja .done ; if mod > dst, dst = dst ;;;;;;; + DEBUGF 1, "mpint_mod(0x%x, 0x%x)\n", [dst], [m] + + stdcall mpint_cmp, [m], [dst] + ja .done ; if mod > dst, dst = dst je .zero ; if mod == dst, dst = 0 ; left shift mod until the high order bits of mod and dst are aligned stdcall mpint_bits, [dst] mov ecx, eax - stdcall mpint_bits, [mod] + stdcall mpint_bits, [m] test eax, eax jz .zero ; if mod is zero, return sub ecx, eax lea ebx, [mpint_tmp] - stdcall mpint_shlmov, ebx, [mod], ecx + stdcall mpint_shlmov, ebx, [m], ecx inc ecx ; For every bit in dst (starting from the high order bit): .bitloop: stdcall mpint_cmp, [dst], ebx ; If dst > mpint_tmp - jb @f ;;;;;;;; + jb @f stdcall mpint_sub, [dst], ebx ; dst = dst - mpint_tmp @@: dec ecx @@ -813,16 +1131,27 @@ endl stdcall mpint_shr1, ebx ; mpint = mpint >> 1 jmp .bitloop - .zero: - mov eax, [dst] - mov dword[eax], 0 .done: +; adjust size of dst so it is no larger than mod + mov ebx, [dst] + mov ecx, [ebx] ; current size + mov eax, [m] + mov eax, [eax] ; size of mod + cmp ecx, eax + jb .ret + mov [ebx], eax + .ret: + ret + + .zero: + mov ebx, [dst] + mov dword[ebx], 0 ret endp ;;===========================================================================;; -proc mpint_modexp uses edi eax ebx ecx edx, dst, base, exp, mod ;////////////;; +proc mpint_modexp uses edi eax ebx ecx edx, dst, b, e, m ;///////////////////;; ;;---------------------------------------------------------------------------;; ;? Find the modulo (remainder after division) of dst by mod. ;; ;;---------------------------------------------------------------------------;; @@ -831,26 +1160,29 @@ proc mpint_modexp uses edi eax ebx ecx edx, dst, base, exp, mod ;////////////;; ;> exp = pointer to little endian MPINT ;; ;> mod = pointer to little endian MPINT ;; ;;---------------------------------------------------------------------------;; -;< dst = base ** exp MOD mod ;; +;< dst = b ** e MOD m ;; ;;===========================================================================;; - ;DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dst], [base], [exp], [mod] - locals mpint_tmp rb MPINT_MAX_LEN+4 endl + DEBUGF 1, "mpint_modexp(0x%x, 0x%x, 0x%x, 0x%x)\n", [dst], [b], [e], [m] + ; If mod is zero, return - stdcall mpint_bits, [mod] + stdcall mpint_bytes, [m] test eax, eax jz .mod_zero + test eax, 3 + jnz .grow_mod + .modsize_ok: ; Find highest order byte in exponent - stdcall mpint_bytes, [exp] + stdcall mpint_bytes, [e] test eax, eax jz .exp_zero mov ecx, eax - mov edi, [exp] + mov edi, [e] lea edi, [edi + 4 + ecx - 1] ; Find the highest order bit in this byte @@ -867,21 +1199,21 @@ endl lea edx, [mpint_tmp] ; Initialise result to base, to take care of the highest order bit - stdcall mpint_mov, [dst], [base] + stdcall mpint_mov, [dst], [b] dec bl jz .next_byte .bit_loop: ; For each bit, square result stdcall mpint_mov, edx, [dst] stdcall mpint_mul, [dst], edx, edx - stdcall mpint_mod, [dst], [mod] + stdcall mpint_mod, [dst], [m] ; If the bit is set, multiply result by the base shl al, 1 jnc .next_bit stdcall mpint_mov, edx, [dst] - stdcall mpint_mul, [dst], [base], edx - stdcall mpint_mod, [dst], [mod] + stdcall mpint_mul, [dst], [b], edx + stdcall mpint_mod, [dst], [m] .next_bit: dec bl jnz .bit_loop @@ -893,7 +1225,6 @@ endl mov bl, 8 jmp .bit_loop .done: - ;stdcall mpint_print, [dst] ret .mod_zero: @@ -915,4 +1246,11 @@ endl DEBUGF 3, "modexp: Invalid input!\n" ret -endp \ No newline at end of file + .grow_mod: + add eax, 3 + and eax, not 3 + stdcall mpint_grow, [m], eax + jmp .modsize_ok + +endp + diff --git a/programs/network/ssh/test/modexp.asm b/programs/network/ssh/test/modexp.asm index 34474f1360..ef1383f19a 100644 --- a/programs/network/ssh/test/modexp.asm +++ b/programs/network/ssh/test/modexp.asm @@ -127,7 +127,7 @@ start: mov dword[mpint_B+4], 497 stdcall mpint_cmp, mpint_A, mpint_B stdcall mpint_mod, mpint_A, mpint_B - DEBUGF 1, "mpint_mod(936, 497)\n" + DEBUGF 1, "mpint_mod(1936, 497)\n" stdcall mpint_print, mpint_A mov dword[mpint_A+00], 32 @@ -155,8 +155,8 @@ start: stdcall mpint_mul, mpint_C, mpint_B, mpint_A stdcall mpint_print, mpint_C - stdcall mpint_hob, mpint_C - DEBUGF 1, "mpint_hob(C): %u\n", eax + stdcall mpint_bits, mpint_C + DEBUGF 1, "mpint_bits(C): %u\n", eax mov dword[mpint_A+0], 1 mov dword[mpint_A+4], 3 diff --git a/programs/network/ssh/test/mpint.asm b/programs/network/ssh/test/mpint.asm index 4df4ab7f15..ca608b04f3 100644 --- a/programs/network/ssh/test/mpint.asm +++ b/programs/network/ssh/test/mpint.asm @@ -18,7 +18,7 @@ format binary as "" __DEBUG__ = 1 -__DEBUG_LEVEL__ = 1 +__DEBUG_LEVEL__ = 2 MAX_BITS = 4096 @@ -68,7 +68,7 @@ cmptestctr = cmptestctr + 1 start: - DEBUGF 1, "MPINT Test suite\n" + DEBUGF 3, "MPINT Test suite\n" ; First, do some checks on the compare routine cmptesteq mpint_0_0, mpint_0_0 @@ -216,7 +216,7 @@ endg include "tests.inc" - DEBUGF 1, "All tests completed\n" + DEBUGF 3, "All tests completed\n" mcall -1 @@ -224,6 +224,8 @@ IncludeIGlobals i_end: +starttime dq ? + mpint_tmp rb MPINT_MAX_LEN+4 include_debug_strings