; mpint.inc - Multi precision integer procedures
;
; Copyright (C) 2015-2017 Jeffrey Amelynck
;
; This program is free software: you can redistribute it and/or modify
; it under the terms of the GNU General Public License as published by
; the Free Software Foundation, either version 3 of the License, or
; (at your option) any later version.
;
; This program is distributed in the hope that it will be useful,
; but WITHOUT ANY WARRANTY; without even the implied warranty of
; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
; GNU General Public License for more details.
;
; You should have received a copy of the GNU General Public License
; along with this program.  If not, see <http://www.gnu.org/licenses/>.

; Notes:
;
; These procedures work only with non-negative integers.
; An MPINT is a length dword followed by that many data bytes; the number
; zero has length 0.
; For compatibility reasons, the highest bit of the highest order data byte
; must always be 0: when it would be set, one extra 0 byte is kept in front.
; Leading 0 bytes MUST be omitted in every other case.
; The procedures use string instructions assuming the direction flag is clear
; on entry, and every STD below is paired with a CLD.
;
; You have been warned!

; Maximum number of data bytes in an MPINT buffer (buffers are MPINT_MAX_LEN+4
; bytes including the length dword). MAX_BITS comes from the including file.
MPINT_MAX_LEN = MAX_BITS/8

;;===========================================================================;;
proc mpint_to_little_endian uses esi edi ecx ;///////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Convert big endian MPINT (length and data big endian) to little endian.
;;---------------------------------------------------------------------------;;
;> esi = pointer to big endian MPINT
;> edi = pointer to buffer for little endian MPINT
;;---------------------------------------------------------------------------;;
;< eax = MPINT number length (bytes)
;;===========================================================================;;
; NOTE(review): the length is not checked against MPINT_MAX_LEN. If the input
; comes from an untrusted source, the caller must validate it first.

; Load length dword and convert it to little endian
        lodsd
        bswap   eax
        stosd
        test    eax, eax
        jz      .zero

; Copy data bytes in reverse order: source backwards, destination forwards
        push    eax                     ; lodsb below clobbers al
        mov     ecx, eax
        lea     esi, [esi + eax - 1]    ; esi -> last (lowest order) source byte
        std
  @@:
        lodsb
        mov     [edi], al
        inc     edi
        dec     ecx
        jnz     @r
        cld
        pop     eax

  .zero:
        ret

endp

;;===========================================================================;;
proc mpint_to_big_endian uses esi edi ecx ;//////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Convert little endian MPINT to big endian MPINT.
;;---------------------------------------------------------------------------;;
;> esi = pointer to little endian MPINT
;> edi = pointer to buffer for big endian MPINT
;;---------------------------------------------------------------------------;;
;< eax = MPINT number length (bytes)
;;===========================================================================;;

; Load length dword
        lodsd
        test    eax, eax
        jz      .zero
        mov     ecx, eax
        lea     esi, [esi + eax - 1]    ; esi -> highest order source byte
        push    eax                     ; we'll return length to the caller later
        bswap   eax
        stosd

; Copy data, convert to big endian meanwhile
        std
  @@:
        lodsb
        mov     [edi], al
        inc     edi
        dec     ecx
        jnz     @r
        cld
        pop     eax
        ret

  .zero:
        stosd                           ; Number 0 has 0 data bytes
        ret

endp

;;===========================================================================;;
proc mpint_print uses ecx esi eax, src ;/////////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Print MPINT to the debug board, highest order byte first.
;;---------------------------------------------------------------------------;;
;> src = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< -
;;===========================================================================;;

        DEBUGF  1, "0x"
        mov     esi, [src]
        mov     ecx, [esi]
        test    ecx, ecx
        jz      .zero
        lea     esi, [esi + ecx + 4 - 1]        ; esi -> highest order byte
        pushf
        std
  .loop:
        lodsb
        DEBUGF  1, "%x", eax:2
        dec     ecx
        jnz     .loop
        DEBUGF  1, "\n"
        popf
        ret

  .zero:
        DEBUGF  1, "00\n"
        ret

endp

;;===========================================================================;;
proc mpint_hob uses edi ecx, dst ;///////////////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Return the position of the highest order set bit.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< eax = number of significant bits (1-based position of the highest set
;<       bit), 0 when the number is zero
;;===========================================================================;;
; eax must NOT be in the uses list: it carries the return value.

        mov     edi, [dst]
        mov     eax, [edi]              ; eax = length in bytes
        test    eax, eax
        jz      .end                    ; zero has no set bits
        dec     eax                     ; total length minus one
        mov     cl, [edi + 4 + eax]     ; load the highest order byte
        shl     eax, 3                  ; bits in all lower order bytes

; Shift the highest order byte right until it reaches zero, counting bits
        test    cl, cl
        jz      .end                    ; padding byte: next byte has bit 7 set
  @@:
        inc     eax
        shr     cl, 1
        jnz     @r

  .end:
        ret

endp

;;===========================================================================;;
proc mpint_cmp uses esi edi ecx eax, dst, src ;//////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Compare two MPINTs.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;> src = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< flags are set as for single precision "CMP <src value>, <dst value>"
;<   (note the order: e.g. JB is taken when src < dst)
;;===========================================================================;;

; First, check if number of significant bytes is the same.
; If not, the number with more bytes is bigger.
        mov     esi, [src]
        mov     edi, [dst]
        mov     ecx, [esi]
        cmp     ecx, [edi]
        jne     .got_answer

; Equal lengths: compare starting from the high order byte.
; esi/edi point one byte past the highest order data byte.
        lea     esi, [esi + 4 + ecx]
        lea     edi, [edi + 4 + ecx]
        std
  .do_byte:
        test    ecx, 11b
        jz      .do_dword
        dec     esi
        dec     edi
        mov     al, [esi]
        cmp     al, [edi]
        jne     .got_answer
        dec     ecx
        jmp     .do_byte

; ecx is now a multiple of 4; esi/edi point just past the unprocessed bytes
  .do_dword:
        shr     ecx, 2
        jz      .got_answer             ; ZF=1, CF=0: numbers are equal
        sub     esi, 4
        sub     edi, 4
        repe cmpsd

  .got_answer:
        cld                             ; does not alter the arithmetic flags
        ret

endp

;;===========================================================================;;
proc mpint_mov uses esi edi ecx, dst, src ;//////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Copy MPINT (length dword plus data bytes).
;;---------------------------------------------------------------------------;;
;> dst = pointer to buffer for little endian MPINT
;> src = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = src
;;===========================================================================;;

        mov     esi, [src]
        mov     edi, [dst]
        mov     ecx, [esi]
        push    ecx
        shr     ecx, 2
        inc     ecx                     ; for length dword
        rep movsd
        pop     ecx
        and     ecx, 11b
        jz      @f
        rep movsb
  @@:
        ret

endp

;;===========================================================================;;
proc mpint_shl1 uses esi ecx, dst ;//////////////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Shift little endian MPINT one bit to the left.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = dst SHL 1
;;===========================================================================;;

        mov     esi, [dst]
        mov     ecx, [esi]
        test    ecx, ecx
        jz      .done

; Test if high order byte will overflow.
; Remember: highest bit must never be set for positive numbers!
        test    byte[esi+ecx+3], 11000000b
        jz      @f
; We must grow a byte in size!
; TODO: check for overflow
        inc     ecx
        mov     [esi], ecx
        mov     byte[esi+ecx+3], 0      ; Add the new MSB
  @@:
        add     esi, 4

; Do the lowest order byte first
        shl     byte[esi], 1
        dec     ecx
        jz      .done

; And the remaining bytes (inc/dec keep CF, so the bit carries upward)
  @@:
        inc     esi
        rcl     byte[esi], 1
        dec     ecx
        jnz     @r

  .done:
        ret

endp

;;===========================================================================;;
proc mpint_shr1 uses edi ecx, dst ;//////////////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Shift little endian MPINT one bit to the right.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = dst SHR 1
;;===========================================================================;;

        mov     edi, [dst]
        mov     ecx, [edi]
        test    ecx, ecx
        jz      .done

; Shift all bytes, highest order first. Only dec/rcr run between iterations,
; and dec leaves CF alone, so each shifted-out bit enters the next lower byte.
        lea     edi, [edi + 4 + ecx]    ; one past the highest order byte
        clc                             ; a 0 bit enters at the top
  .shift:
        dec     edi
        rcr     byte[edi], 1
        dec     ecx
        jnz     .shift

; Normalise the length: drop the top byte if it became 0, unless it is now
; needed as padding because the next byte has bit 7 set
        mov     edi, [dst]
        mov     ecx, [edi]
        cmp     byte[edi + 4 + ecx - 1], 0
        jne     .done
        dec     ecx
        jz      .store                  ; number became zero
        test    byte[edi + 4 + ecx - 1], 0x80
        jnz     .done
  .store:
        mov     [edi], ecx

  .done:
        ret

endp

;;===========================================================================;;
proc mpint_shl uses eax ebx ecx edx esi edi, dst, shift ;////////////////////;;
;;---------------------------------------------------------------------------;;
;? Left shift little endian MPINT by x bits, in place.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;> shift = number of bits to shift the MPINT
;;---------------------------------------------------------------------------;;
;< dst = dst SHL shift (0 when the result would not fit in MPINT_MAX_LEN)
;;===========================================================================;;

; mpint_shlmov works top-down and so supports src == dst
        stdcall mpint_shlmov, [dst], [dst], [shift]
        ret

endp

;;===========================================================================;;
proc mpint_shlmov uses eax ebx ecx edx esi edi, dst, src, shift ;////////////;;
;;---------------------------------------------------------------------------;;
;? Left shift by x bits and copy little endian MPINT.
;;---------------------------------------------------------------------------;;
;> src = pointer to little endian MPINT
;> dst = pointer to buffer for little endian MPINT (may be equal to src)
;> shift = number of bits to shift the MPINT to the left
;;---------------------------------------------------------------------------;;
;< dst = src SHL shift (normalised; set to 0 when src is 0 or when the
;<       unnormalised result, src length + shift/8 + 1 bytes, would not fit
;<       in MPINT_MAX_LEN)
;;===========================================================================;;

        mov     esi, [src]
        mov     edx, [esi]              ; edx = source length
        test    edx, edx
        jz      .zero
        mov     ebx, [shift]
        shr     ebx, 3                  ; ebx = whole bytes to shift
        lea     eax, [edx + ebx + 1]    ; eax = bytes written before normalising
        cmp     eax, MPINT_MAX_LEN
        ja      .zero                   ; TODO: report overflow to the caller
        mov     edi, [dst]
        mov     [edi], eax              ; provisional length (src length already read)
        mov     ecx, [shift]
        and     ecx, 111b               ; cl = remaining bit shift (0..7)

; Result byte (ebx + k) = high byte of ((src[k] << 8 | src[k-1]) << cl),
; for k = edx down to 0, with src[edx] = src[-1] = 0.
; Going top-down means every source byte is read before it can be overwritten.
        lea     edi, [edi + 4 + eax - 1]        ; edi -> result byte ebx + edx
        lea     esi, [esi + 4 + edx - 1]        ; esi -> source byte edx - 1
        push    ebx
        xor     eax, eax                ; ah = source byte edx = 0
  .byte_loop:
        mov     al, [esi]               ; al = src[k-1], ah = src[k]
        dec     esi
        mov     ebx, eax
        shl     bx, cl
        mov     [edi], bh
        dec     edi
        mov     ah, al                  ; src[k-1] is the next high byte
        dec     edx
        jnz     .byte_loop
; k = 0: the byte below src[0] is 0
        xor     al, al
        shl     ax, cl
        mov     [edi], ah

; Fill the ebx lowest order bytes with zeros
        pop     ecx
        mov     edi, [dst]
        add     edi, 4
        xor     eax, eax
        rep stosb

; Strip leading zeros (keeping a padding byte when needed)
        stdcall mpint_shrink, [dst]
        ret

  .zero:
        mov     eax, [dst]
        mov     dword[eax], 0
        ret

endp

;;===========================================================================;;
proc mpint_add uses esi edi ecx edx eax, dst, src ;//////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Add a little endian MPINT to another little endian MPINT.
;;---------------------------------------------------------------------------;;
;> src = pointer to little endian MPINT
;> dst = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = dst + src
;;===========================================================================;;

        mov     esi, [src]
        mov     edi, [dst]
        mov     ecx, [esi]              ; ecx = source length
        mov     edx, [edi]              ; edx = destination length
        cmp     edx, ecx
        jae     .length_ok

; Destination is shorter than the source: zero-extend it to the source length
        mov     [edi], ecx
        lea     edi, [edi + 4 + edx]    ; first byte past the destination data
        sub     ecx, edx                ; number of padding bytes
        xor     eax, eax
        rep stosb
        mov     edi, [dst]
        mov     ecx, [esi]
        mov     edx, ecx

  .length_ok:
; ecx = source length, edx = destination length, edx >= ecx
        test    ecx, ecx
        jz      .done                   ; adding zero changes nothing
        sub     edx, ecx                ; edx = destination bytes above the source
        add     esi, 4
        add     edi, 4

; Add byte by byte (lodsb/inc/dec keep CF between the adc's)
        clc
  .add_loop:
        lodsb
        adc     byte[edi], al
        inc     edi
        dec     ecx
        jnz     .add_loop

; Propagate the carry through the remaining destination bytes
        mov     ecx, edx                ; mov keeps CF
  .propagate:
        jnc     .check_top
        jecxz   .carry_out
        adc     byte[edi], 0
        inc     edi
        dec     ecx
        jmp     .propagate

  .carry_out:
; Carry out of the top byte: append a byte of value 1 (bit 7 clear)
; TODO: check if we have the buffer space
        mov     byte[edi], 1
        mov     eax, [dst]
        inc     dword[eax]
        ret

  .check_top:
; Keep bit 7 of the top byte clear by appending a 0 byte when needed
; TODO: check if we have the buffer space
        mov     edi, [dst]
        mov     ecx, [edi]
        test    byte[edi + 4 + ecx - 1], 0x80
        jz      .done
        mov     byte[edi + 4 + ecx], 0
        inc     dword[edi]

  .done:
        ret

endp

;;===========================================================================;;
proc mpint_sub uses eax esi edi ecx edx, dst, src ;//////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Subtract a little endian MPINT from another little endian MPINT.
;;---------------------------------------------------------------------------;;
;> src = pointer to little endian MPINT
;> dst = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = dst - src, CF = 0
;< when src > dst the result would be negative: dst = 0, CF = 1
;;===========================================================================;;

        mov     esi, [src]
        mov     edi, [dst]
        mov     ecx, [esi]              ; ecx = source length
        mov     edx, [edi]              ; edx = destination length
        cmp     ecx, edx
        ja      .overflow               ; source has more bytes: src > dst
        sub     edx, ecx                ; edx = destination bytes above the source
        add     esi, 4
        add     edi, 4

; Subtract byte by byte (lodsb/inc/dec keep CF between the sbb's)
        clc
        jecxz   .sub_done
  .sub_loop:
        lodsb
        sbb     byte[edi], al
        inc     edi
        dec     ecx
        jnz     .sub_loop
  .sub_done:

; Propagate the borrow through the remaining destination bytes
        mov     ecx, edx                ; mov keeps CF
  .propagate:
        jnc     .done
        jecxz   .negative               ; borrow out of the top byte
        sbb     byte[edi], 0
        inc     edi
        dec     ecx
        jmp     .propagate

  .done:
        stdcall mpint_shrink, [dst]
        clc
        ret

  .negative:
  .overflow:
        mov     eax, [dst]
        mov     dword[eax], 0
        stc
        ret

endp

;;===========================================================================;;
proc mpint_shrink uses eax edi ecx edx, dst ;////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Get rid of leading zeroes on a little endian MPINT, keeping one 0 byte
;? when the highest remaining byte has bit 7 set.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst length dword updated
;;===========================================================================;;

        mov     edi, [dst]
        mov     edx, [edi]              ; edx = current length
        mov     ecx, edx
        test    ecx, ecx
        jz      .done
        lea     edi, [edi + 4 + ecx - 1]        ; edi -> highest order byte
        xor     eax, eax
        std
        repe scasb                      ; walk down over zero bytes
        cld
        je      .store                  ; all bytes were zero: ecx = 0
        inc     ecx                     ; bytes up to the highest non-zero one
        cmp     ecx, edx
        je      .store                  ; nothing was stripped
        test    byte[edi + 1], 0x80     ; edi + 1 -> highest non-zero byte
        jz      .store
        inc     ecx                     ; keep one 0 byte so the top bit stays clear
  .store:
        mov     edi, [dst]
        mov     [edi], ecx

  .done:
        ret

endp

;;===========================================================================;;
proc mpint_mul uses esi edi ecx ebx eax, dst, A, B ;/////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Multiply two little endian MPINTs and store the product in a new one.
;;---------------------------------------------------------------------------;;
;> A = pointer to little endian MPINT
;> B = pointer to little endian MPINT
;> dst = pointer to buffer for little endian MPINT (must not alias A or B)
;;---------------------------------------------------------------------------;;
;< dst = A * B
;;===========================================================================;;

; Set result to zero
        mov     eax, [dst]
        mov     dword[eax], 0

        mov     esi, [A]
        mov     ecx, [esi]              ; ecx = bytes of A left to process
        test    ecx, ecx
        jz      .done                   ; A = 0: product is 0
        lea     esi, [esi + 4 + ecx - 1]        ; esi -> highest order byte of A

; Iterate through the bits of A, highest order first:
; dst = dst * 2, then dst = dst + B when the bit is set.
; (Leading zero bits are harmless: shifting 0 leaves 0.)
  .next_byte:
        mov     al, [esi]
        dec     esi
        mov     bl, 8
  .next_bit:
        stdcall mpint_shl1, [dst]       ; preserves eax, ebx, ecx, esi
        shl     al, 1
        jnc     .zero_bit
        stdcall mpint_add, [dst], [B]   ; preserves eax, ebx, ecx, esi
  .zero_bit:
        dec     bl
        jnz     .next_bit
        dec     ecx
        jnz     .next_byte

  .done:
        ret

endp

;;===========================================================================;;
proc mpint_mod uses eax ebx ecx, dst, mod ;//////////////////////////////////;;
;;---------------------------------------------------------------------------;;
;? Find the modulo (remainder after division) of dst by mod.
;;---------------------------------------------------------------------------;;
;> dst = pointer to little endian MPINT
;> mod = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = dst MOD mod (0 when mod is 0)
;;===========================================================================;;

        locals
                mpint_tmp       rb MPINT_MAX_LEN+4
        endl

; if mod is zero, result is zero
        mov     eax, [mod]
        cmp     dword[eax], 0
        je      .zero

; mpint_cmp(mod, dst) sets flags as for "cmp dst, mod"
        stdcall mpint_cmp, eax, [dst]
        jb      .done                   ; if dst < mod, dst = dst
        je      .zero                   ; if dst == mod, dst = 0

; Left shift mod until the high order bits of mod and dst are aligned
        lea     ebx, [mpint_tmp]
        stdcall mpint_hob, [dst]
        mov     ecx, eax
        stdcall mpint_hob, [mod]
        sub     ecx, eax                ; ecx >= 0 because dst > mod
        stdcall mpint_shlmov, ebx, [mod], ecx
        inc     ecx                     ; one step per alignment position

; For every alignment position (highest first): subtract when possible
  .loop:
; mpint_cmp(dst, tmp) sets flags as for "cmp tmp, dst"
        stdcall mpint_cmp, [dst], ebx
        ja      @f                      ; tmp > dst: nothing to subtract
        stdcall mpint_sub, [dst], ebx
  @@:
        dec     ecx
        jz      .done
        stdcall mpint_shr1, ebx
        jmp     .loop

  .zero:
        mov     eax, [dst]
        mov     dword[eax], 0
  .done:
        ret

endp

;;===========================================================================;;
proc mpint_modexp uses edi eax ebx ecx edx, dst, base, exp, mod ;////////////;;
;;---------------------------------------------------------------------------;;
;? Modular exponentiation: raise base to exp, modulo mod
;? (left-to-right square and multiply).
;;---------------------------------------------------------------------------;;
;> dst = pointer to buffer for little endian MPINT (must not alias base)
;> base = pointer to little endian MPINT
;> exp = pointer to little endian MPINT
;> mod = pointer to little endian MPINT
;;---------------------------------------------------------------------------;;
;< dst = base ** exp MOD mod
;;===========================================================================;;

        locals
                mpint_tmp       rb MPINT_MAX_LEN+4
        endl

; If mod is zero, return
        mov     eax, [mod]
        cmp     dword[eax], 0
        je      .mod_zero

; Find the highest order byte in exponent
        mov     edi, [exp]
        mov     ecx, [edi]              ; ecx = exponent bytes left to process
        test    ecx, ecx
        jz      .exp_zero
        lea     edi, [edi + 4 + ecx - 1]
        mov     al, [edi]
        test    al, al
        jnz     .top_byte_found
; A 0 top byte is only valid as padding in front of a byte with bit 7 set
        dec     ecx
        jz      .invalid
        dec     edi
        mov     al, [edi]
        test    al, 0x80
        jz      .invalid
  .top_byte_found:

; Find the highest order bit in this byte. Afterwards bl = bits of this byte
; still to handle (including the highest set one), and al holds the lower
; bits left aligned.
        mov     bl, 9
  @@:
        dec     bl
        shl     al, 1
        jnc     @r

        lea     edx, [mpint_tmp]
; Initialise result to base MOD mod, to take care of the highest order bit
        stdcall mpint_mov, [dst], [base]
        stdcall mpint_mod, [dst], [mod]
        dec     bl
        jz      .next_byte

  .bit_loop:
; For each bit, square result
        stdcall mpint_mov, edx, [dst]
        stdcall mpint_mul, [dst], edx, edx
        stdcall mpint_mod, [dst], [mod]
; If the bit is set, multiply result by the base
        shl     al, 1
        jnc     .next_bit
        stdcall mpint_mov, edx, [dst]
        stdcall mpint_mul, [dst], [base], edx
        stdcall mpint_mod, [dst], [mod]
  .next_bit:
        dec     bl
        jnz     .bit_loop

  .next_byte:
        dec     ecx
        jz      .done
        dec     edi
        mov     al, [edi]
        mov     bl, 8
        jmp     .bit_loop

  .done:
        ret

  .mod_zero:
        DEBUGF  3, "modexp with modulo 0\n"
; if mod is zero, result = 0
        mov     eax, [dst]
        mov     dword[eax], 0
        ret

  .exp_zero:
        DEBUGF  3, "modexp with exponent 0\n"
; if exponent is zero, result = 1 MOD mod (0 when mod is 1)
        mov     eax, [dst]
        mov     dword[eax], 1
        mov     byte[eax+4], 1
        stdcall mpint_mod, [dst], [mod]
        ret

  .invalid:
        DEBUGF  3, "modexp: Invalid input!\n"
        ret

endp