Source file src/crypto/internal/fips140/bigmod/nat.go
1 // Copyright 2021 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package bigmod 6 7 import ( 8 _ "crypto/internal/fips140/check" 9 "crypto/internal/fips140deps/byteorder" 10 "errors" 11 "math/bits" 12 ) 13 14 const ( 15 // _W is the size in bits of our limbs. 16 _W = bits.UintSize 17 // _S is the size in bytes of our limbs. 18 _S = _W / 8 19 ) 20 21 // Note: These functions make many loops over all the words in a Nat. 22 // These loops used to be in assembly, invisible to -race, -asan, and -msan, 23 // but now they are in Go and incur significant overhead in those modes. 24 // To bring the old performance back, we mark all functions that loop 25 // over Nat words with //go:norace. Because //go:norace does not 26 // propagate across inlining, we must also mark functions that inline 27 // //go:norace functions - specifically, those that inline add, addMulVVW, 28 // assign, cmpGeq, rshift1, and sub. 29 30 // choice represents a constant-time boolean. The value of choice is always 31 // either 1 or 0. We use an int instead of bool in order to make decisions in 32 // constant time by turning it into a mask. 33 type choice uint 34 35 func not(c choice) choice { return 1 ^ c } 36 37 const yes = choice(1) 38 const no = choice(0) 39 40 // ctMask is all 1s if on is yes, and all 0s otherwise. 41 func ctMask(on choice) uint { return -uint(on) } 42 43 // ctEq returns 1 if x == y, and 0 otherwise. The execution time of this 44 // function does not depend on its inputs. 45 func ctEq(x, y uint) choice { 46 // If x != y, then either x - y or y - x will generate a carry. 47 _, c1 := bits.Sub(x, y, 0) 48 _, c2 := bits.Sub(y, x, 0) 49 return not(choice(c1 | c2)) 50 } 51 52 // Nat represents an arbitrary natural number 53 // 54 // Each Nat has an announced length, which is the number of limbs it has stored. 55 // Operations on this number are allowed to leak this length, but will not leak 56 // any information about the values contained in those limbs. 57 type Nat struct { 58 // limbs is little-endian in base 2^W with W = bits.UintSize. 59 limbs []uint 60 } 61 62 // preallocTarget is the size in bits of the numbers used to implement the most 63 // common and most performant RSA key size. It's also enough to cover some of 64 // the operations of key sizes up to 4096. 65 const preallocTarget = 2048 66 const preallocLimbs = (preallocTarget + _W - 1) / _W 67 68 // NewNat returns a new nat with a size of zero, just like new(Nat), but with 69 // the preallocated capacity to hold a number of up to preallocTarget bits. 70 // NewNat inlines, so the allocation can live on the stack. 71 func NewNat() *Nat { 72 limbs := make([]uint, 0, preallocLimbs) 73 return &Nat{limbs} 74 } 75 76 // expand expands x to n limbs, leaving its value unchanged. 77 func (x *Nat) expand(n int) *Nat { 78 if len(x.limbs) > n { 79 panic("bigmod: internal error: shrinking nat") 80 } 81 if cap(x.limbs) < n { 82 newLimbs := make([]uint, n) 83 copy(newLimbs, x.limbs) 84 x.limbs = newLimbs 85 return x 86 } 87 extraLimbs := x.limbs[len(x.limbs):n] 88 clear(extraLimbs) 89 x.limbs = x.limbs[:n] 90 return x 91 } 92 93 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). 94 func (x *Nat) reset(n int) *Nat { 95 if cap(x.limbs) < n { 96 x.limbs = make([]uint, n) 97 return x 98 } 99 clear(x.limbs) 100 x.limbs = x.limbs[:n] 101 return x 102 } 103 104 // resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing 105 // n to the appropriate size. 106 // 107 // The announced length of x is set based on the actual bit size of the input, 108 // ignoring leading zeroes. 109 func (x *Nat) resetToBytes(b []byte) *Nat { 110 x.reset((len(b) + _S - 1) / _S) 111 if err := x.setBytes(b); err != nil { 112 panic("bigmod: internal error: bad arithmetic") 113 } 114 return x.trim() 115 } 116 117 // trim reduces the size of x to match its value. 118 func (x *Nat) trim() *Nat { 119 // Trim most significant (trailing in little-endian) zero limbs. 120 // We assume comparison with zero (but not the branch) is constant time. 121 for i := len(x.limbs) - 1; i >= 0; i-- { 122 if x.limbs[i] != 0 { 123 break 124 } 125 x.limbs = x.limbs[:i] 126 } 127 return x 128 } 129 130 // set assigns x = y, optionally resizing x to the appropriate size. 131 func (x *Nat) set(y *Nat) *Nat { 132 x.reset(len(y.limbs)) 133 copy(x.limbs, y.limbs) 134 return x 135 } 136 137 // Bits returns x as a little-endian slice of uint. The length of the slice 138 // matches the announced length of x. The result and x share the same underlying 139 // array. 140 func (x *Nat) Bits() []uint { 141 return x.limbs 142 } 143 144 // Bytes returns x as a zero-extended big-endian byte slice. The size of the 145 // slice will match the size of m. 146 // 147 // x must have the same size as m and it must be less than or equal to m. 148 func (x *Nat) Bytes(m *Modulus) []byte { 149 i := m.Size() 150 bytes := make([]byte, i) 151 for _, limb := range x.limbs { 152 for j := 0; j < _S; j++ { 153 i-- 154 if i < 0 { 155 if limb == 0 { 156 break 157 } 158 panic("bigmod: modulus is smaller than nat") 159 } 160 bytes[i] = byte(limb) 161 limb >>= 8 162 } 163 } 164 return bytes 165 } 166 167 // SetBytes assigns x = b, where b is a slice of big-endian bytes. 168 // SetBytes returns an error if b >= m. 169 // 170 // The output will be resized to the size of m and overwritten. 171 // 172 //go:norace 173 func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) { 174 x.resetFor(m) 175 if err := x.setBytes(b); err != nil { 176 return nil, err 177 } 178 if x.cmpGeq(m.nat) == yes { 179 return nil, errors.New("input overflows the modulus") 180 } 181 return x, nil 182 } 183 184 // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. 185 // SetOverflowingBytes returns an error if b has a longer bit length than m, but 186 // reduces overflowing values up to 2^⌈log2(m)⌉ - 1. 187 // 188 // The output will be resized to the size of m and overwritten. 189 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) { 190 x.resetFor(m) 191 if err := x.setBytes(b); err != nil { 192 return nil, err 193 } 194 // setBytes would have returned an error if the input overflowed the limb 195 // size of the modulus, so now we only need to check if the most significant 196 // limb of x has more bits than the most significant limb of the modulus. 197 if bitLen(x.limbs[len(x.limbs)-1]) > bitLen(m.nat.limbs[len(m.nat.limbs)-1]) { 198 return nil, errors.New("input overflows the modulus size") 199 } 200 x.maybeSubtractModulus(no, m) 201 return x, nil 202 } 203 204 // bigEndianUint returns the contents of buf interpreted as a 205 // big-endian encoded uint value. 206 func bigEndianUint(buf []byte) uint { 207 if _W == 64 { 208 return uint(byteorder.BEUint64(buf)) 209 } 210 return uint(byteorder.BEUint32(buf)) 211 } 212 213 func (x *Nat) setBytes(b []byte) error { 214 i, k := len(b), 0 215 for k < len(x.limbs) && i >= _S { 216 x.limbs[k] = bigEndianUint(b[i-_S : i]) 217 i -= _S 218 k++ 219 } 220 for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 { 221 x.limbs[k] |= uint(b[i-1]) << s 222 i-- 223 } 224 if i > 0 { 225 return errors.New("input overflows the modulus size") 226 } 227 return nil 228 } 229 230 // SetUint assigns x = y. 231 // 232 // The output will be resized to a single limb and overwritten. 233 func (x *Nat) SetUint(y uint) *Nat { 234 x.reset(1) 235 x.limbs[0] = y 236 return x 237 } 238 239 // Equal returns 1 if x == y, and 0 otherwise. 240 // 241 // Both operands must have the same announced length. 242 // 243 //go:norace 244 func (x *Nat) Equal(y *Nat) choice { 245 // Eliminate bounds checks in the loop. 246 size := len(x.limbs) 247 xLimbs := x.limbs[:size] 248 yLimbs := y.limbs[:size] 249 250 equal := yes 251 for i := 0; i < size; i++ { 252 equal &= ctEq(xLimbs[i], yLimbs[i]) 253 } 254 return equal 255 } 256 257 // IsZero returns 1 if x == 0, and 0 otherwise. 258 // 259 //go:norace 260 func (x *Nat) IsZero() choice { 261 // Eliminate bounds checks in the loop. 262 size := len(x.limbs) 263 xLimbs := x.limbs[:size] 264 265 zero := yes 266 for i := 0; i < size; i++ { 267 zero &= ctEq(xLimbs[i], 0) 268 } 269 return zero 270 } 271 272 // IsOne returns 1 if x == 1, and 0 otherwise. 273 // 274 //go:norace 275 func (x *Nat) IsOne() choice { 276 // Eliminate bounds checks in the loop. 277 size := len(x.limbs) 278 xLimbs := x.limbs[:size] 279 280 if len(xLimbs) == 0 { 281 return no 282 } 283 284 one := ctEq(xLimbs[0], 1) 285 for i := 1; i < size; i++ { 286 one &= ctEq(xLimbs[i], 0) 287 } 288 return one 289 } 290 291 // IsMinusOne returns 1 if x == -1 mod m, and 0 otherwise. 292 // 293 // The length of x must be the same as the modulus. x must already be reduced 294 // modulo m. 295 // 296 //go:norace 297 func (x *Nat) IsMinusOne(m *Modulus) choice { 298 minusOne := m.Nat() 299 minusOne.SubOne(m) 300 return x.Equal(minusOne) 301 } 302 303 // IsOdd returns 1 if x is odd, and 0 otherwise. 304 func (x *Nat) IsOdd() choice { 305 if len(x.limbs) == 0 { 306 return no 307 } 308 return choice(x.limbs[0] & 1) 309 } 310 311 // TrailingZeroBitsVarTime returns the number of trailing zero bits in x. 312 func (x *Nat) TrailingZeroBitsVarTime() uint { 313 var t uint 314 limbs := x.limbs 315 for _, l := range limbs { 316 if l == 0 { 317 t += _W 318 continue 319 } 320 t += uint(bits.TrailingZeros(l)) 321 break 322 } 323 return t 324 } 325 326 // cmpGeq returns 1 if x >= y, and 0 otherwise. 327 // 328 // Both operands must have the same announced length. 329 // 330 //go:norace 331 func (x *Nat) cmpGeq(y *Nat) choice { 332 // Eliminate bounds checks in the loop. 333 size := len(x.limbs) 334 xLimbs := x.limbs[:size] 335 yLimbs := y.limbs[:size] 336 337 var c uint 338 for i := 0; i < size; i++ { 339 _, c = bits.Sub(xLimbs[i], yLimbs[i], c) 340 } 341 // If there was a carry, then subtracting y underflowed, so 342 // x is not greater than or equal to y. 343 return not(choice(c)) 344 } 345 346 // assign sets x <- y if on == 1, and does nothing otherwise. 347 // 348 // Both operands must have the same announced length. 349 // 350 //go:norace 351 func (x *Nat) assign(on choice, y *Nat) *Nat { 352 // Eliminate bounds checks in the loop. 353 size := len(x.limbs) 354 xLimbs := x.limbs[:size] 355 yLimbs := y.limbs[:size] 356 357 mask := ctMask(on) 358 for i := 0; i < size; i++ { 359 xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i]) 360 } 361 return x 362 } 363 364 // add computes x += y and returns the carry. 365 // 366 // Both operands must have the same announced length. 367 // 368 //go:norace 369 func (x *Nat) add(y *Nat) (c uint) { 370 // Eliminate bounds checks in the loop. 371 size := len(x.limbs) 372 xLimbs := x.limbs[:size] 373 yLimbs := y.limbs[:size] 374 375 for i := 0; i < size; i++ { 376 xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c) 377 } 378 return 379 } 380 381 // sub computes x -= y. It returns the borrow of the subtraction. 382 // 383 // Both operands must have the same announced length. 384 // 385 //go:norace 386 func (x *Nat) sub(y *Nat) (c uint) { 387 // Eliminate bounds checks in the loop. 388 size := len(x.limbs) 389 xLimbs := x.limbs[:size] 390 yLimbs := y.limbs[:size] 391 392 for i := 0; i < size; i++ { 393 xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c) 394 } 395 return 396 } 397 398 // ShiftRightVarTime sets x = x >> n. 399 // 400 // The announced length of x is unchanged. 401 // 402 //go:norace 403 func (x *Nat) ShiftRightVarTime(n uint) *Nat { 404 // Eliminate bounds checks in the loop. 405 size := len(x.limbs) 406 xLimbs := x.limbs[:size] 407 408 shift := int(n % _W) 409 shiftLimbs := int(n / _W) 410 411 var shiftedLimbs []uint 412 if shiftLimbs < size { 413 shiftedLimbs = xLimbs[shiftLimbs:] 414 } 415 416 for i := range xLimbs { 417 if i >= len(shiftedLimbs) { 418 xLimbs[i] = 0 419 continue 420 } 421 422 xLimbs[i] = shiftedLimbs[i] >> shift 423 if i+1 < len(shiftedLimbs) { 424 xLimbs[i] |= shiftedLimbs[i+1] << (_W - shift) 425 } 426 } 427 428 return x 429 } 430 431 // BitLenVarTime returns the actual size of x in bits. 432 // 433 // The actual size of x (but nothing more) leaks through timing side-channels. 434 // Note that this is ordinarily secret, as opposed to the announced size of x. 435 func (x *Nat) BitLenVarTime() int { 436 // Eliminate bounds checks in the loop. 437 size := len(x.limbs) 438 xLimbs := x.limbs[:size] 439 440 for i := size - 1; i >= 0; i-- { 441 if xLimbs[i] != 0 { 442 return i*_W + bitLen(xLimbs[i]) 443 } 444 } 445 return 0 446 } 447 448 // bitLen is a version of bits.Len that only leaks the bit length of n, but not 449 // its value. bits.Len and bits.LeadingZeros use a lookup table for the 450 // low-order bits on some architectures. 451 func bitLen(n uint) int { 452 len := 0 453 // We assume, here and elsewhere, that comparison to zero is constant time 454 // with respect to different non-zero values. 455 for n != 0 { 456 len++ 457 n >>= 1 458 } 459 return len 460 } 461 462 // Modulus is used for modular arithmetic, precomputing relevant constants. 463 // 464 // A Modulus can leak the exact number of bits needed to store its value 465 // and is stored without padding. Its actual value is still kept secret. 466 type Modulus struct { 467 // The underlying natural number for this modulus. 468 // 469 // This will be stored without any padding, and shouldn't alias with any 470 // other natural number being used. 471 nat *Nat 472 473 // If m is even, the following fields are not set. 474 odd bool 475 m0inv uint // -nat.limbs[0]⁻¹ mod _W 476 rr *Nat // R*R for montgomeryRepresentation 477 } 478 479 // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). 480 func rr(m *Modulus) *Nat { 481 rr := NewNat().ExpandFor(m) 482 n := uint(len(rr.limbs)) 483 mLen := uint(m.BitLen()) 484 logR := _W * n 485 486 // We start by computing R = 2^(_W * n) mod m. We can get pretty close, to 487 // 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce. 488 rr.limbs[n-1] = 1 << ((mLen - 1) % _W) 489 // Then we double until we reach 2^(_W * n). 490 for i := mLen - 1; i < logR; i++ { 491 rr.Add(rr, m) 492 } 493 494 // Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in 495 // the Montgomery domain, meaning we can use Montgomery multiplication now). 496 // We could do that by doubling _W * n times, or with a square-and-double 497 // chain log2(_W * n) long. Turns out the fastest thing is to start out with 498 // doublings, and switch to square-and-double once the exponent is large 499 // enough to justify the cost of the multiplications. 500 501 // The threshold is selected experimentally as a linear function of n. 502 threshold := n / 4 503 504 // We calculate how many of the most-significant bits of the exponent we can 505 // compute before crossing the threshold, and we do it with doublings. 506 i := bits.UintSize 507 for logR>>i <= threshold { 508 i-- 509 } 510 for k := uint(0); k < logR>>i; k++ { 511 rr.Add(rr, m) 512 } 513 514 // Then we process the remaining bits of the exponent with a 515 // square-and-double chain. 516 for i > 0 { 517 rr.montgomeryMul(rr, rr, m) 518 i-- 519 if logR>>i&1 != 0 { 520 rr.Add(rr, m) 521 } 522 } 523 524 return rr 525 } 526 527 // minusInverseModW computes -x⁻¹ mod _W with x odd. 528 // 529 // This operation is used to precompute a constant involved in Montgomery 530 // multiplication. 531 func minusInverseModW(x uint) uint { 532 // Every iteration of this loop doubles the least-significant bits of 533 // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, 534 // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough 535 // for 64 bits (and wastes only one iteration for 32 bits). 536 // 537 // See https://crypto.stackexchange.com/a/47496. 538 y := x 539 for i := 0; i < 5; i++ { 540 y = y * (2 - x*y) 541 } 542 return -y 543 } 544 545 // NewModulus creates a new Modulus from a slice of big-endian bytes. The 546 // modulus must be greater than one. 547 // 548 // The number of significant bits and whether the modulus is even is leaked 549 // through timing side-channels. 550 func NewModulus(b []byte) (*Modulus, error) { 551 n := NewNat().resetToBytes(b) 552 return newModulus(n) 553 } 554 555 // NewModulusProduct creates a new Modulus from the product of two numbers 556 // represented as big-endian byte slices. The result must be greater than one. 557 // 558 //go:norace 559 func NewModulusProduct(a, b []byte) (*Modulus, error) { 560 x := NewNat().resetToBytes(a) 561 y := NewNat().resetToBytes(b) 562 n := NewNat().reset(len(x.limbs) + len(y.limbs)) 563 for i := range y.limbs { 564 n.limbs[i+len(x.limbs)] = addMulVVW(n.limbs[i:i+len(x.limbs)], x.limbs, y.limbs[i]) 565 } 566 return newModulus(n.trim()) 567 } 568 569 func newModulus(n *Nat) (*Modulus, error) { 570 m := &Modulus{nat: n} 571 if m.nat.IsZero() == yes || m.nat.IsOne() == yes { 572 return nil, errors.New("modulus must be > 1") 573 } 574 if m.nat.IsOdd() == 1 { 575 m.odd = true 576 m.m0inv = minusInverseModW(m.nat.limbs[0]) 577 m.rr = rr(m) 578 } 579 return m, nil 580 } 581 582 // Size returns the size of m in bytes. 583 func (m *Modulus) Size() int { 584 return (m.BitLen() + 7) / 8 585 } 586 587 // BitLen returns the size of m in bits. 588 func (m *Modulus) BitLen() int { 589 return m.nat.BitLenVarTime() 590 } 591 592 // Nat returns m as a Nat. 593 func (m *Modulus) Nat() *Nat { 594 // Make a copy so that the caller can't modify m.nat or alias it with 595 // another Nat in a modulus operation. 596 n := NewNat() 597 n.set(m.nat) 598 return n 599 } 600 601 // shiftIn calculates x = x << _W + y mod m. 602 // 603 // This assumes that x is already reduced mod m. 604 // 605 //go:norace 606 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { 607 d := NewNat().resetFor(m) 608 609 // Eliminate bounds checks in the loop. 610 size := len(m.nat.limbs) 611 xLimbs := x.limbs[:size] 612 dLimbs := d.limbs[:size] 613 mLimbs := m.nat.limbs[:size] 614 615 // Each iteration of this loop computes x = 2x + b mod m, where b is a bit 616 // from y. Effectively, it left-shifts x and adds y one bit at a time, 617 // reducing it every time. 618 // 619 // To do the reduction, each iteration computes both 2x + b and 2x + b - m. 620 // The next iteration (and finally the return line) will use either result 621 // based on whether 2x + b overflows m. 622 needSubtraction := no 623 for i := _W - 1; i >= 0; i-- { 624 carry := (y >> i) & 1 625 var borrow uint 626 mask := ctMask(needSubtraction) 627 for i := 0; i < size; i++ { 628 l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i])) 629 xLimbs[i], carry = bits.Add(l, l, carry) 630 dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow) 631 } 632 // Like in maybeSubtractModulus, we need the subtraction if either it 633 // didn't underflow (meaning 2x + b > m) or if computing 2x + b 634 // overflowed (meaning 2x + b > 2^_W*n > m). 635 needSubtraction = not(choice(borrow)) | choice(carry) 636 } 637 return x.assign(needSubtraction, d) 638 } 639 640 // Mod calculates out = x mod m. 641 // 642 // This works regardless how large the value of x is. 643 // 644 // The output will be resized to the size of m and overwritten. 645 // 646 //go:norace 647 func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { 648 out.resetFor(m) 649 // Working our way from the most significant to the least significant limb, 650 // we can insert each limb at the least significant position, shifting all 651 // previous limbs left by _W. This way each limb will get shifted by the 652 // correct number of bits. We can insert at least N - 1 limbs without 653 // overflowing m. After that, we need to reduce every time we shift. 654 i := len(x.limbs) - 1 655 // For the first N - 1 limbs we can skip the actual shifting and position 656 // them at the shifted position, which starts at min(N - 2, i). 657 start := len(m.nat.limbs) - 2 658 if i < start { 659 start = i 660 } 661 for j := start; j >= 0; j-- { 662 out.limbs[j] = x.limbs[i] 663 i-- 664 } 665 // We shift in the remaining limbs, reducing modulo m each time. 666 for i >= 0 { 667 out.shiftIn(x.limbs[i], m) 668 i-- 669 } 670 return out 671 } 672 673 // ExpandFor ensures x has the right size to work with operations modulo m. 674 // 675 // The announced size of x must be smaller than or equal to that of m. 676 func (x *Nat) ExpandFor(m *Modulus) *Nat { 677 return x.expand(len(m.nat.limbs)) 678 } 679 680 // resetFor ensures out has the right size to work with operations modulo m. 681 // 682 // out is zeroed and may start at any size. 683 func (out *Nat) resetFor(m *Modulus) *Nat { 684 return out.reset(len(m.nat.limbs)) 685 } 686 687 // maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes. 688 // 689 // It can be used to reduce modulo m a value up to 2m - 1, which is a common 690 // range for results computed by higher level operations. 691 // 692 // always is usually a carry that indicates that the operation that produced x 693 // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m. 694 // 695 // x and m operands must have the same announced length. 696 // 697 //go:norace 698 func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { 699 t := NewNat().set(x) 700 underflow := t.sub(m.nat) 701 // We keep the result if x - m didn't underflow (meaning x >= m) 702 // or if always was set. 703 keep := not(choice(underflow)) | choice(always) 704 x.assign(keep, t) 705 } 706 707 // Sub computes x = x - y mod m. 708 // 709 // The length of both operands must be the same as the modulus. Both operands 710 // must already be reduced modulo m. 711 // 712 //go:norace 713 func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { 714 underflow := x.sub(y) 715 // If the subtraction underflowed, add m. 716 t := NewNat().set(x) 717 t.add(m.nat) 718 x.assign(choice(underflow), t) 719 return x 720 } 721 722 // SubOne computes x = x - 1 mod m. 723 // 724 // The length of x must be the same as the modulus. 725 func (x *Nat) SubOne(m *Modulus) *Nat { 726 one := NewNat().ExpandFor(m) 727 one.limbs[0] = 1 728 // Sub asks for x to be reduced modulo m, while SubOne doesn't, but when 729 // y = 1, it works, and this is an internal use. 730 return x.Sub(one, m) 731 } 732 733 // Add computes x = x + y mod m. 734 // 735 // The length of both operands must be the same as the modulus. Both operands 736 // must already be reduced modulo m. 737 // 738 //go:norace 739 func (x *Nat) Add(y *Nat, m *Modulus) *Nat { 740 overflow := x.add(y) 741 x.maybeSubtractModulus(choice(overflow), m) 742 return x 743 } 744 745 // montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and 746 // n = len(m.nat.limbs). 747 // 748 // Faster Montgomery multiplication replaces standard modular multiplication for 749 // numbers in this representation. 750 // 751 // This assumes that x is already reduced mod m. 752 func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat { 753 // A Montgomery multiplication (which computes a * b / R) by R * R works out 754 // to a multiplication by R, which takes the value out of the Montgomery domain. 755 return x.montgomeryMul(x, m.rr, m) 756 } 757 758 // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and 759 // n = len(m.nat.limbs). 760 // 761 // This assumes that x is already reduced mod m. 762 func (x *Nat) montgomeryReduction(m *Modulus) *Nat { 763 // By Montgomery multiplying with 1 not in Montgomery representation, we 764 // convert out back from Montgomery representation, because it works out to 765 // dividing by R. 766 one := NewNat().ExpandFor(m) 767 one.limbs[0] = 1 768 return x.montgomeryMul(x, one, m) 769 } 770 771 // montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and 772 // n = len(m.nat.limbs), also known as a Montgomery multiplication. 773 // 774 // All inputs should be the same length and already reduced modulo m. 775 // x will be resized to the size of m and overwritten. 776 // 777 //go:norace 778 func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { 779 n := len(m.nat.limbs) 780 mLimbs := m.nat.limbs[:n] 781 aLimbs := a.limbs[:n] 782 bLimbs := b.limbs[:n] 783 784 switch n { 785 default: 786 // Attempt to use a stack-allocated backing array. 787 T := make([]uint, 0, preallocLimbs*2) 788 if cap(T) < n*2 { 789 T = make([]uint, 0, n*2) 790 } 791 T = T[:n*2] 792 793 // This loop implements Word-by-Word Montgomery Multiplication, as 794 // described in Algorithm 4 (Fig. 3) of "Efficient Software 795 // Implementations of Modular Exponentiation" by Shay Gueron 796 // [https://eprint.iacr.org/2011/239.pdf]. 797 var c uint 798 for i := 0; i < n; i++ { 799 _ = T[n+i] // bounds check elimination hint 800 801 // Step 1 (T = a × b) is computed as a large pen-and-paper column 802 // multiplication of two numbers with n base-2^_W digits. If we just 803 // wanted to produce 2n-wide T, we would do 804 // 805 // for i := 0; i < n; i++ { 806 // d := bLimbs[i] 807 // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d) 808 // } 809 // 810 // where d is a digit of the multiplier, T[i:n+i] is the shifted 811 // position of the product of that digit, and T[n+i] is the final carry. 812 // Note that T[i] isn't modified after processing the i-th digit. 813 // 814 // Instead of running two loops, one for Step 1 and one for Steps 2–6, 815 // the result of Step 1 is computed during the next loop. This is 816 // possible because each iteration only uses T[i] in Step 2 and then 817 // discards it in Step 6. 818 d := bLimbs[i] 819 c1 := addMulVVW(T[i:n+i], aLimbs, d) 820 821 // Step 6 is replaced by shifting the virtual window we operate 822 // over: T of the algorithm is T[i:] for us. That means that T1 in 823 // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv. 824 Y := T[i] * m.m0inv 825 826 // Step 4 and 5 add Y × m to T, which as mentioned above is stored 827 // at T[i:]. The two carries (from a × d and Y × m) are added up in 828 // the next word T[n+i], and the carry bit from that addition is 829 // brought forward to the next iteration. 830 c2 := addMulVVW(T[i:n+i], mLimbs, Y) 831 T[n+i], c = bits.Add(c1, c2, c) 832 } 833 834 // Finally for Step 7 we copy the final T window into x, and subtract m 835 // if necessary (which as explained in maybeSubtractModulus can be the 836 // case both if x >= m, or if x overflowed). 837 // 838 // The paper suggests in Section 4 that we can do an "Almost Montgomery 839 // Multiplication" by subtracting only in the overflow case, but the 840 // cost is very similar since the constant time subtraction tells us if 841 // x >= m as a side effect, and taking care of the broken invariant is 842 // highly undesirable (see https://go.dev/issue/13907). 843 copy(x.reset(n).limbs, T[n:]) 844 x.maybeSubtractModulus(choice(c), m) 845 846 // The following specialized cases follow the exact same algorithm, but 847 // optimized for the sizes most used in RSA. addMulVVW is implemented in 848 // assembly with loop unrolling depending on the architecture and bounds 849 // checks are removed by the compiler thanks to the constant size. 850 case 1024 / _W: 851 const n = 1024 / _W // compiler hint 852 T := make([]uint, n*2) 853 var c uint 854 for i := 0; i < n; i++ { 855 d := bLimbs[i] 856 c1 := addMulVVW1024(&T[i], &aLimbs[0], d) 857 Y := T[i] * m.m0inv 858 c2 := addMulVVW1024(&T[i], &mLimbs[0], Y) 859 T[n+i], c = bits.Add(c1, c2, c) 860 } 861 copy(x.reset(n).limbs, T[n:]) 862 x.maybeSubtractModulus(choice(c), m) 863 864 case 1536 / _W: 865 const n = 1536 / _W // compiler hint 866 T := make([]uint, n*2) 867 var c uint 868 for i := 0; i < n; i++ { 869 d := bLimbs[i] 870 c1 := addMulVVW1536(&T[i], &aLimbs[0], d) 871 Y := T[i] * m.m0inv 872 c2 := addMulVVW1536(&T[i], &mLimbs[0], Y) 873 T[n+i], c = bits.Add(c1, c2, c) 874 } 875 copy(x.reset(n).limbs, T[n:]) 876 x.maybeSubtractModulus(choice(c), m) 877 878 case 2048 / _W: 879 const n = 2048 / _W // compiler hint 880 T := make([]uint, n*2) 881 var c uint 882 for i := 0; i < n; i++ { 883 d := bLimbs[i] 884 c1 := addMulVVW2048(&T[i], &aLimbs[0], d) 885 Y := T[i] * m.m0inv 886 c2 := addMulVVW2048(&T[i], &mLimbs[0], Y) 887 T[n+i], c = bits.Add(c1, c2, c) 888 } 889 copy(x.reset(n).limbs, T[n:]) 890 x.maybeSubtractModulus(choice(c), m) 891 } 892 893 return x 894 } 895 896 // addMulVVW multiplies the multi-word value x by the single-word value y, 897 // adding the result to the multi-word value z and returning the final carry. 898 // It can be thought of as one row of a pen-and-paper column multiplication. 899 // 900 //go:norace 901 func addMulVVW(z, x []uint, y uint) (carry uint) { 902 _ = x[len(z)-1] // bounds check elimination hint 903 for i := range z { 904 hi, lo := bits.Mul(x[i], y) 905 lo, c := bits.Add(lo, z[i], 0) 906 // We use bits.Add with zero to get an add-with-carry instruction that 907 // absorbs the carry from the previous bits.Add. 908 hi, _ = bits.Add(hi, 0, c) 909 lo, c = bits.Add(lo, carry, 0) 910 hi, _ = bits.Add(hi, 0, c) 911 carry = hi 912 z[i] = lo 913 } 914 return carry 915 } 916 917 // Mul calculates x = x * y mod m. 918 // 919 // The length of both operands must be the same as the modulus. Both operands 920 // must already be reduced modulo m. 921 // 922 //go:norace 923 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { 924 if m.odd { 925 // A Montgomery multiplication by a value out of the Montgomery domain 926 // takes the result out of Montgomery representation. 927 xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m 928 return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m 929 } 930 931 n := len(m.nat.limbs) 932 xLimbs := x.limbs[:n] 933 yLimbs := y.limbs[:n] 934 935 switch n { 936 default: 937 // Attempt to use a stack-allocated backing array. 938 T := make([]uint, 0, preallocLimbs*2) 939 if cap(T) < n*2 { 940 T = make([]uint, 0, n*2) 941 } 942 T = T[:n*2] 943 944 // T = x * y 945 for i := 0; i < n; i++ { 946 T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i]) 947 } 948 949 // x = T mod m 950 return x.Mod(&Nat{limbs: T}, m) 951 952 // The following specialized cases follow the exact same algorithm, but 953 // optimized for the sizes most used in RSA. See montgomeryMul for details. 954 case 1024 / _W: 955 const n = 1024 / _W // compiler hint 956 T := make([]uint, n*2) 957 for i := 0; i < n; i++ { 958 T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i]) 959 } 960 return x.Mod(&Nat{limbs: T}, m) 961 case 1536 / _W: 962 const n = 1536 / _W // compiler hint 963 T := make([]uint, n*2) 964 for i := 0; i < n; i++ { 965 T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i]) 966 } 967 return x.Mod(&Nat{limbs: T}, m) 968 case 2048 / _W: 969 const n = 2048 / _W // compiler hint 970 T := make([]uint, n*2) 971 for i := 0; i < n; i++ { 972 T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i]) 973 } 974 return x.Mod(&Nat{limbs: T}, m) 975 } 976 } 977 978 // Exp calculates out = x^e mod m. 979 // 980 // The exponent e is represented in big-endian order. The output will be resized 981 // to the size of m and overwritten. x must already be reduced modulo m. 982 // 983 // m must be odd, or Exp will panic. 984 // 985 //go:norace 986 func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { 987 if !m.odd { 988 panic("bigmod: modulus for Exp must be odd") 989 } 990 991 // We use a 4 bit window. For our RSA workload, 4 bit windows are faster 992 // than 2 bit windows, but use an extra 12 nats worth of scratch space. 993 // Using bit sizes that don't divide 8 are more complex to implement, but 994 // are likely to be more efficient if necessary. 995 996 table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1) 997 // newNat calls are unrolled so they are allocated on the stack. 998 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 999 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 1000 NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), 1001 } 1002 table[0].set(x).montgomeryRepresentation(m) 1003 for i := 1; i < len(table); i++ { 1004 table[i].montgomeryMul(table[i-1], table[0], m) 1005 } 1006 1007 out.resetFor(m) 1008 out.limbs[0] = 1 1009 out.montgomeryRepresentation(m) 1010 tmp := NewNat().ExpandFor(m) 1011 for _, b := range e { 1012 for _, j := range []int{4, 0} { 1013 // Square four times. Optimization note: this can be implemented 1014 // more efficiently than with generic Montgomery multiplication. 1015 out.montgomeryMul(out, out, m) 1016 out.montgomeryMul(out, out, m) 1017 out.montgomeryMul(out, out, m) 1018 out.montgomeryMul(out, out, m) 1019 1020 // Select x^k in constant time from the table. 1021 k := uint((b >> j) & 0b1111) 1022 for i := range table { 1023 tmp.assign(ctEq(k, uint(i+1)), table[i]) 1024 } 1025 1026 // Multiply by x^k, discarding the result if k = 0. 1027 tmp.montgomeryMul(out, tmp, m) 1028 out.assign(not(ctEq(k, 0)), tmp) 1029 } 1030 } 1031 1032 return out.montgomeryReduction(m) 1033 } 1034 1035 // ExpShortVarTime calculates out = x^e mod m. 1036 // 1037 // The output will be resized to the size of m and overwritten. x must already 1038 // be reduced modulo m. This leaks the exponent through timing side-channels. 1039 // 1040 // m must be odd, or ExpShortVarTime will panic. 1041 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { 1042 if !m.odd { 1043 panic("bigmod: modulus for ExpShortVarTime must be odd") 1044 } 1045 // For short exponents, precomputing a table and using a window like in Exp 1046 // doesn't pay off. Instead, we do a simple conditional square-and-multiply 1047 // chain, skipping the initial run of zeroes. 1048 xR := NewNat().set(x).montgomeryRepresentation(m) 1049 out.set(xR) 1050 for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ { 1051 out.montgomeryMul(out, out, m) 1052 if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { 1053 out.montgomeryMul(out, xR, m) 1054 } 1055 } 1056 return out.montgomeryReduction(m) 1057 } 1058 1059 // InverseVarTime calculates x = a⁻¹ mod m and returns (x, true) if a is 1060 // invertible. Otherwise, InverseVarTime returns (x, false) and x is not 1061 // modified. 1062 // 1063 // a must be reduced modulo m, but doesn't need to have the same size. The 1064 // output will be resized to the size of m and overwritten. 1065 // 1066 //go:norace 1067 func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) { 1068 u, A, err := extendedGCD(a, m.nat) 1069 if err != nil { 1070 return x, false 1071 } 1072 if u.IsOne() == no { 1073 return x, false 1074 } 1075 return x.set(A), true 1076 } 1077 1078 // GCDVarTime calculates x = GCD(a, b) where at least one of a or b is odd, and 1079 // both are non-zero. If GCDVarTime returns an error, x is not modified. 1080 // 1081 // The output will be resized to the size of the larger of a and b. 1082 func (x *Nat) GCDVarTime(a, b *Nat) (*Nat, error) { 1083 u, _, err := extendedGCD(a, b) 1084 if err != nil { 1085 return nil, err 1086 } 1087 return x.set(u), nil 1088 } 1089 1090 // extendedGCD computes u and A such that a = GCD(a, m) and u = A*a - B*m. 1091 // 1092 // u will have the size of the larger of a and m, and A will have the size of m. 1093 // 1094 // It is an error if either a or m is zero, or if they are both even. 1095 func extendedGCD(a, m *Nat) (u, A *Nat, err error) { 1096 // This is the extended binary GCD algorithm described in the Handbook of 1097 // Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound 1098 // coefficients and avoid negative numbers. For more details and proof of 1099 // correctness, see https://github.com/mit-plv/fiat-crypto/pull/333/files. 1100 // 1101 // Following the proof linked in the PR above, the changes are: 1102 // 1103 // 1. Negate [B] and [C] so they are positive. The invariant now involves a 1104 // subtraction. 1105 // 2. If step 2 (both [x] and [y] are even) runs, abort immediately. This 1106 // case needs to be handled by the caller. 1107 // 3. Subtract copies of [x] and [y] as needed in step 6 (both [u] and [v] 1108 // are odd) so coefficients stay in bounds. 1109 // 4. Replace the [u >= v] check with [u > v]. This changes the end 1110 // condition to [v = 0] rather than [u = 0]. This saves an extra 1111 // subtraction due to which coefficients were negated. 1112 // 5. Rename x and y to a and n, to capture that one is a modulus. 1113 // 6. Rearrange steps 4 through 6 slightly. Merge the loops in steps 4 and 1114 // 5 into the main loop (step 7's goto), and move step 6 to the start of 1115 // the loop iteration, ensuring each loop iteration halves at least one 1116 // value. 1117 // 1118 // Note this algorithm does not handle either input being zero. 1119 1120 if a.IsZero() == yes || m.IsZero() == yes { 1121 return nil, nil, errors.New("extendedGCD: a or m is zero") 1122 } 1123 if a.IsOdd() == no && m.IsOdd() == no { 1124 return nil, nil, errors.New("extendedGCD: both a and m are even") 1125 } 1126 1127 size := max(len(a.limbs), len(m.limbs)) 1128 u = NewNat().set(a).expand(size) 1129 v := NewNat().set(m).expand(size) 1130 1131 A = NewNat().reset(len(m.limbs)) 1132 A.limbs[0] = 1 1133 B := NewNat().reset(len(a.limbs)) 1134 C := NewNat().reset(len(m.limbs)) 1135 D := NewNat().reset(len(a.limbs)) 1136 D.limbs[0] = 1 1137 1138 // Before and after each loop iteration, the following hold: 1139 // 1140 // u = A*a - B*m 1141 // v = D*m - C*a 1142 // 0 < u <= a 1143 // 0 <= v <= m 1144 // 0 <= A < m 1145 // 0 <= B <= a 1146 // 0 <= C < m 1147 // 0 <= D <= a 1148 // 1149 // After each loop iteration, u and v only get smaller, and at least one of 1150 // them shrinks by at least a factor of two. 1151 for { 1152 // If both u and v are odd, subtract the smaller from the larger. 1153 // If u = v, we need to subtract from v to hit the modified exit condition. 1154 if u.IsOdd() == yes && v.IsOdd() == yes { 1155 if v.cmpGeq(u) == no { 1156 u.sub(v) 1157 A.Add(C, &Modulus{nat: m}) 1158 B.Add(D, &Modulus{nat: a}) 1159 } else { 1160 v.sub(u) 1161 C.Add(A, &Modulus{nat: m}) 1162 D.Add(B, &Modulus{nat: a}) 1163 } 1164 } 1165 1166 // Exactly one of u and v is now even. 1167 if u.IsOdd() == v.IsOdd() { 1168 panic("bigmod: internal error: u and v are not in the expected state") 1169 } 1170 1171 // Halve the even one and adjust the corresponding coefficient. 1172 if u.IsOdd() == no { 1173 rshift1(u, 0) 1174 if A.IsOdd() == yes || B.IsOdd() == yes { 1175 rshift1(A, A.add(m)) 1176 rshift1(B, B.add(a)) 1177 } else { 1178 rshift1(A, 0) 1179 rshift1(B, 0) 1180 } 1181 } else { // v.IsOdd() == no 1182 rshift1(v, 0) 1183 if C.IsOdd() == yes || D.IsOdd() == yes { 1184 rshift1(C, C.add(m)) 1185 rshift1(D, D.add(a)) 1186 } else { 1187 rshift1(C, 0) 1188 rshift1(D, 0) 1189 } 1190 } 1191 1192 if v.IsZero() == yes { 1193 return u, A, nil 1194 } 1195 } 1196 } 1197 1198 //go:norace 1199 func rshift1(a *Nat, carry uint) { 1200 size := len(a.limbs) 1201 aLimbs := a.limbs[:size] 1202 1203 for i := range size { 1204 aLimbs[i] >>= 1 1205 if i+1 < size { 1206 aLimbs[i] |= aLimbs[i+1] << (_W - 1) 1207 } else { 1208 aLimbs[i] |= carry << (_W - 1) 1209 } 1210 } 1211 } 1212 1213 // DivShortVarTime calculates x = x / y and returns the remainder. 1214 // 1215 // It panics if y is zero. 1216 // 1217 //go:norace 1218 func (x *Nat) DivShortVarTime(y uint) uint { 1219 if y == 0 { 1220 panic("bigmod: division by zero") 1221 } 1222 1223 var r uint 1224 for i := len(x.limbs) - 1; i >= 0; i-- { 1225 x.limbs[i], r = bits.Div(r, x.limbs[i], y) 1226 } 1227 return r 1228 } 1229