Source file src/crypto/internal/edwards25519/field/_asm/fe_amd64_asm.go

     1  // Copyright (c) 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"fmt"
     9  
    10  	. "github.com/mmcloughlin/avo/build"
    11  	. "github.com/mmcloughlin/avo/gotypes"
    12  	. "github.com/mmcloughlin/avo/operand"
    13  	. "github.com/mmcloughlin/avo/reg"
    14  )
    15  
    16  //go:generate go run . -out ../fe_amd64.s -stubs ../fe_amd64.go -pkg field
    17  
    18  func main() {
    19  	Package("crypto/internal/edwards25519/field")
    20  	ConstraintExpr("!purego")
    21  	feMul()
    22  	feSquare()
    23  	Generate()
    24  }
    25  
    26  type namedComponent struct {
    27  	Component
    28  	name string
    29  }
    30  
    31  func (c namedComponent) String() string { return c.name }
    32  
    33  type uint128 struct {
    34  	name   string
    35  	hi, lo GPVirtual
    36  }
    37  
    38  func (c uint128) String() string { return c.name }
    39  
// feSquare emits the assembly routine feSquare(out, a *Element).
//
// The routine computes five 128-bit column sums r0..r4 from the limbs l0..l4
// of a, runs two carry/reduction passes over them with a 51-bit mask, and
// stores the five resulting 64-bit limbs to out.
func feSquare() {
	TEXT("feSquare", NOSPLIT, "func(out, a *Element)")
	Doc("feSquare sets out = a * a. It works like feSquareGeneric.")
	Pragma("noescape")

	// Label the limbs of a so the comments emitted by mul64/addMul64 read
	// l0..l4.
	a := Dereference(Param("a"))
	l0 := namedComponent{a.Field("l0"), "l0"}
	l1 := namedComponent{a.Field("l1"), "l1"}
	l2 := namedComponent{a.Field("l2"), "l2"}
	l3 := namedComponent{a.Field("l3"), "l3"}
	l4 := namedComponent{a.Field("l4"), "l4"}

	// Each rN below is a 128-bit accumulator: mul64 initializes it with the
	// first product and addMul64 adds the remaining ones. The small constants
	// (2, 19, 38) are folded into the first operand before the 64×64 multiply.

	// r0 = l0×l0 + 19×2×(l1×l4 + l2×l3)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, l0, l0)
	addMul64(r0, 38, l1, l4)
	addMul64(r0, 38, l2, l3)

	// r1 = 2×l0×l1 + 19×2×l2×l4 + 19×l3×l3
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 2, l0, l1)
	addMul64(r1, 38, l2, l4)
	addMul64(r1, 19, l3, l3)

	// r2 = 2×l0×l2 + l1×l1 + 19×2×l3×l4
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 2, l0, l2)
	addMul64(r2, 1, l1, l1)
	addMul64(r2, 38, l3, l4)

	// r3 = 2×l0×l3 + 2×l1×l2 + 19×l4×l4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 2, l0, l3)
	addMul64(r3, 2, l1, l2)
	addMul64(r3, 19, l4, l4)

	// r4 = 2×l0×l4 + 2×l1×l3 + l2×l2
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 2, l0, l4)
	addMul64(r4, 2, l1, l3)
	addMul64(r4, 1, l2, l2)

	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	// cN = rN >> 51 (computed in what was rN's high register); rNlo is rN's
	// low register, still unmasked. r0..r4 are unusable after these calls.
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	// Keep the low 51 bits of each limb and add the carry from the limb
	// below; the carry out of the top limb (c4) wraps into limb 0 times 19.
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// The limbs now live in single 64-bit registers, so the carries are
	// recomputed with plain shifts, reusing the c0..c4 registers.
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	// Same mask-and-add pattern as the first chain, including the ×19 wrap
	// of c4 into limb 0.
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}
   128  
// feMul emits the assembly routine feMul(out, a, b *Element).
//
// The routine computes five 128-bit column sums r0..r4 from the limbs a0..a4
// of a and b0..b4 of b, runs two carry/reduction passes over them with a
// 51-bit mask, and stores the five resulting 64-bit limbs to out.
func feMul() {
	TEXT("feMul", NOSPLIT, "func(out, a, b *Element)")
	Doc("feMul sets out = a * b. It works like feMulGeneric.")
	Pragma("noescape")

	// Label the limbs of a and b so the comments emitted by mul64/addMul64
	// read a0..a4 and b0..b4.
	a := Dereference(Param("a"))
	a0 := namedComponent{a.Field("l0"), "a0"}
	a1 := namedComponent{a.Field("l1"), "a1"}
	a2 := namedComponent{a.Field("l2"), "a2"}
	a3 := namedComponent{a.Field("l3"), "a3"}
	a4 := namedComponent{a.Field("l4"), "a4"}

	b := Dereference(Param("b"))
	b0 := namedComponent{b.Field("l0"), "b0"}
	b1 := namedComponent{b.Field("l1"), "b1"}
	b2 := namedComponent{b.Field("l2"), "b2"}
	b3 := namedComponent{b.Field("l3"), "b3"}
	b4 := namedComponent{b.Field("l4"), "b4"}

	// Each rN below is a 128-bit accumulator: mul64 initializes it with the
	// first product and addMul64 adds the remaining ones. The factor 19 is
	// folded into the a operand before the 64×64 multiply.

	// r0 = a0×b0 + 19×(a1×b4 + a2×b3 + a3×b2 + a4×b1)
	r0 := uint128{"r0", GP64(), GP64()}
	mul64(r0, 1, a0, b0)
	addMul64(r0, 19, a1, b4)
	addMul64(r0, 19, a2, b3)
	addMul64(r0, 19, a3, b2)
	addMul64(r0, 19, a4, b1)

	// r1 = a0×b1 + a1×b0 + 19×(a2×b4 + a3×b3 + a4×b2)
	r1 := uint128{"r1", GP64(), GP64()}
	mul64(r1, 1, a0, b1)
	addMul64(r1, 1, a1, b0)
	addMul64(r1, 19, a2, b4)
	addMul64(r1, 19, a3, b3)
	addMul64(r1, 19, a4, b2)

	// r2 = a0×b2 + a1×b1 + a2×b0 + 19×(a3×b4 + a4×b3)
	r2 := uint128{"r2", GP64(), GP64()}
	mul64(r2, 1, a0, b2)
	addMul64(r2, 1, a1, b1)
	addMul64(r2, 1, a2, b0)
	addMul64(r2, 19, a3, b4)
	addMul64(r2, 19, a4, b3)

	// r3 = a0×b3 + a1×b2 + a2×b1 + a3×b0 + 19×a4×b4
	r3 := uint128{"r3", GP64(), GP64()}
	mul64(r3, 1, a0, b3)
	addMul64(r3, 1, a1, b2)
	addMul64(r3, 1, a2, b1)
	addMul64(r3, 1, a3, b0)
	addMul64(r3, 19, a4, b4)

	// r4 = a0×b4 + a1×b3 + a2×b2 + a3×b1 + a4×b0
	r4 := uint128{"r4", GP64(), GP64()}
	mul64(r4, 1, a0, b4)
	addMul64(r4, 1, a1, b3)
	addMul64(r4, 1, a2, b2)
	addMul64(r4, 1, a3, b1)
	addMul64(r4, 1, a4, b0)

	Comment("First reduction chain")
	maskLow51Bits := GP64()
	MOVQ(Imm((1<<51)-1), maskLow51Bits)
	// cN = rN >> 51 (computed in what was rN's high register); rNlo is rN's
	// low register, still unmasked. r0..r4 are unusable after these calls.
	c0, r0lo := shiftRightBy51(&r0)
	c1, r1lo := shiftRightBy51(&r1)
	c2, r2lo := shiftRightBy51(&r2)
	c3, r3lo := shiftRightBy51(&r3)
	c4, r4lo := shiftRightBy51(&r4)
	// Keep the low 51 bits of each limb and add the carry from the limb
	// below; the carry out of the top limb (c4) wraps into limb 0 times 19.
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Second reduction chain (carryPropagate)")
	// The limbs now live in single 64-bit registers, so the carries are
	// recomputed with plain shifts, reusing the c0..c4 registers.
	// c0 = r0 >> 51
	MOVQ(r0lo, c0)
	SHRQ(Imm(51), c0)
	// c1 = r1 >> 51
	MOVQ(r1lo, c1)
	SHRQ(Imm(51), c1)
	// c2 = r2 >> 51
	MOVQ(r2lo, c2)
	SHRQ(Imm(51), c2)
	// c3 = r3 >> 51
	MOVQ(r3lo, c3)
	SHRQ(Imm(51), c3)
	// c4 = r4 >> 51
	MOVQ(r4lo, c4)
	SHRQ(Imm(51), c4)
	// Same mask-and-add pattern as the first chain, including the ×19 wrap
	// of c4 into limb 0.
	maskAndAdd(r0lo, maskLow51Bits, c4, 19)
	maskAndAdd(r1lo, maskLow51Bits, c0, 1)
	maskAndAdd(r2lo, maskLow51Bits, c1, 1)
	maskAndAdd(r3lo, maskLow51Bits, c2, 1)
	maskAndAdd(r4lo, maskLow51Bits, c3, 1)

	Comment("Store output")
	out := Dereference(Param("out"))
	Store(r0lo, out.Field("l0"))
	Store(r1lo, out.Field("l1"))
	Store(r2lo, out.Field("l2"))
	Store(r3lo, out.Field("l3"))
	Store(r4lo, out.Field("l4"))

	RET()
}
   234  
   235  // mul64 sets r to i * aX * bX.
   236  func mul64(r uint128, i int, aX, bX namedComponent) {
   237  	switch i {
   238  	case 1:
   239  		Comment(fmt.Sprintf("%s = %s×%s", r, aX, bX))
   240  		Load(aX, RAX)
   241  	case 2:
   242  		Comment(fmt.Sprintf("%s = 2×%s×%s", r, aX, bX))
   243  		Load(aX, RAX)
   244  		SHLQ(Imm(1), RAX)
   245  	default:
   246  		panic("unsupported i value")
   247  	}
   248  	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
   249  	MOVQ(RAX, r.lo)
   250  	MOVQ(RDX, r.hi)
   251  }
   252  
   253  // addMul64 sets r to r + i * aX * bX.
   254  func addMul64(r uint128, i uint64, aX, bX namedComponent) {
   255  	switch i {
   256  	case 1:
   257  		Comment(fmt.Sprintf("%s += %s×%s", r, aX, bX))
   258  		Load(aX, RAX)
   259  	default:
   260  		Comment(fmt.Sprintf("%s += %d×%s×%s", r, i, aX, bX))
   261  		IMUL3Q(Imm(i), Load(aX, GP64()), RAX)
   262  	}
   263  	MULQ(mustAddr(bX)) // RDX, RAX = RAX * bX
   264  	ADDQ(RAX, r.lo)
   265  	ADCQ(RDX, r.hi)
   266  }
   267  
   268  // shiftRightBy51 returns r >> 51 and r.lo.
   269  //
   270  // After this function is called, the uint128 may not be used anymore.
   271  func shiftRightBy51(r *uint128) (out, lo GPVirtual) {
   272  	out = r.hi
   273  	lo = r.lo
   274  	SHLQ(Imm(64-51), r.lo, r.hi)
   275  	r.lo, r.hi = nil, nil // make sure the uint128 is unusable
   276  	return
   277  }
   278  
   279  // maskAndAdd sets r = r&mask + c*i.
   280  func maskAndAdd(r, mask, c GPVirtual, i uint64) {
   281  	ANDQ(mask, r)
   282  	if i != 1 {
   283  		IMUL3Q(Imm(i), c, c)
   284  	}
   285  	ADDQ(c, r)
   286  }
   287  
   288  func mustAddr(c Component) Op {
   289  	b, err := c.Resolve()
   290  	if err != nil {
   291  		panic(err)
   292  	}
   293  	return b.Addr
   294  }
   295  

View as plain text