Source file src/simd/archsimd/internal/simd_test/shift_amd64_test.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build goexperiment.simd && amd64
     6  
     7  package simd_test
     8  
     9  import (
    10  	"simd/archsimd"
    11  	"testing"
    12  )
    13  
    14  func TestRotateAllLeftAMD64(t *testing.T) {
    15  	x := uint8(0x81)
    16  	if y := rotl(x, 1); y != 3 {
    17  		t.Errorf("Expected 3, got 0x%x", y)
    18  	}
    19  	if y := rotl(x, 7); y != 0xc0 {
    20  		t.Errorf("Expected 0xc0, got 0x%x", y)
    21  	}
    22  	if y := rotr(x, 4); y != 0x18 {
    23  		t.Errorf("Expected 0x18, got 0x%x", y)
    24  	}
    25  
    26  	for i := uint64(0); i < 65; i++ {
    27  		testUint64x4Unary(t, curry2(archsimd.Uint64x4.RotateAllLeft, i), rotlOfSlice[uint64](i))
    28  		testUint32x8Unary(t, curry2(archsimd.Uint32x8.RotateAllLeft, i), rotlOfSlice[uint32](i))
    29  		//		testUint16x16Unary(t, curry2(archsimd.Uint16x16.RotateAllLeft, i), rotlOfSlice[uint16](i))
    30  		//		testUint8x32Unary(t, curry2(archsimd.Uint8x32.RotateAllLeft, i), rotlOfSlice[uint8](i))
    31  	}
    32  
    33  }
    34  
    35  func TestRotateAllRightAMD64(t *testing.T) {
    36  	x := uint8(0x81)
    37  	if y := rotr(x, 1); y != 0xc0 {
    38  		t.Errorf("Expected 0xc0, got 0x%x", y)
    39  	}
    40  	if y := rotr(x, 7); y != 3 {
    41  		t.Errorf("Expected 3, got 0x%x", y)
    42  	}
    43  	if y := rotr(x, 4); y != 0x18 {
    44  		t.Errorf("Expected 0x18, got 0x%x", y)
    45  	}
    46  
    47  	for i := uint64(0); i < 65; i++ {
    48  		testUint64x4Unary(t, curry2(archsimd.Uint64x4.RotateAllRight, i), rotrOfSlice[uint64](i))
    49  		testUint32x8Unary(t, curry2(archsimd.Uint32x8.RotateAllRight, i), rotrOfSlice[uint32](i))
    50  		//		testUint16x16Unary(t, curry2(archsimd.Uint16x16.RotateAllLeft, i), rotlOfSlice[uint16](i))
    51  		//		testUint8x32Unary(t, curry2(archsimd.Uint8x32.RotateAllLeft, i), rotlOfSlice[uint8](i))
    52  	}
    53  }
    54  
    55  func TestShift(t *testing.T) {
    56  	if !archsimd.X86.AVX2() {
    57  		t.Skip("requires AVX2")
    58  	}
    59  
    60  	testInt32x4Binary(t,
    61  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftLeft(y.AsUint32x4()) },
    62  		map2(func(x, y int32) int32 { return x << uint32(y) }))
    63  	testInt32x4Binary(t,
    64  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftRight(y.AsUint32x4()) },
    65  		map2(func(x, y int32) int32 { return x >> uint32(y) }))
    66  	testUint32x4Binary(t,
    67  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftRight(y) },
    68  		map2(func(x, y uint32) uint32 { return x >> y }))
    69  }
    70  
    71  func concatInt32s(x, y int32) int64 {
    72  	return (int64(x) << 32) | int64(uint32(y))
    73  }
    74  
    75  func concatUint32s(x, y uint32) uint64 {
    76  	return (uint64(x) << 32) | uint64(y)
    77  }
    78  
    79  func TestShiftAllConcat(t *testing.T) {
    80  	if !archsimd.X86.AVX512VBMI2() {
    81  		t.Skip("requires AVX512-VBMI2")
    82  	}
    83  
    84  	// Note that unlike their non-Concat counterparts, these wrap the shift count.
    85  
    86  	hide := hideConst[uint64]
    87  
    88  	// ShiftAllLeftConcat
    89  	salc := func(shift uint64) func(x, y int32) int32 {
    90  		return func(x, y int32) int32 {
    91  			return int32(concatInt32s(x, y) >> (32 - shift%32))
    92  		}
    93  	}
    94  
    95  	testInt32x4Binary(t,
    96  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, 2) },
    97  		map2(salc(2)))
    98  	testInt32x4Binary(t,
    99  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, hide(2)) },
   100  		map2(salc(hide(2))))
   101  
   102  	testInt32x4Binary(t,
   103  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, 128) },
   104  		map2(salc(128)))
   105  	testInt32x4Binary(t,
   106  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllLeftConcatMod32(y, hide(128)) },
   107  		map2(salc(hide(128))))
   108  
   109  	// Signed ShiftAllRightConcat
   110  	sarc := func(shift uint64) func(x, y int32) int32 {
   111  		return func(x, y int32) int32 {
   112  			return int32(concatInt32s(y, x) >> (shift % 32))
   113  		}
   114  	}
   115  
   116  	testInt32x4Binary(t,
   117  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, 2) },
   118  		map2(sarc(2)))
   119  	testInt32x4Binary(t,
   120  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, hide(2)) },
   121  		map2(sarc(hide(2))))
   122  
   123  	testInt32x4Binary(t,
   124  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, 128) },
   125  		map2(sarc(128)))
   126  	testInt32x4Binary(t,
   127  		func(x, y archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftAllRightConcatMod32(y, hide(128)) },
   128  		map2(sarc(hide(128))))
   129  
   130  	// Unsigned ShiftAllRightConcat
   131  	usarc := func(shift uint64) func(x, y uint32) uint32 {
   132  		return func(x, y uint32) uint32 {
   133  			return uint32(concatUint32s(y, x) >> (shift % 32))
   134  		}
   135  	}
   136  
   137  	testUint32x4Binary(t,
   138  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, 2) },
   139  		map2(usarc(2)))
   140  	testUint32x4Binary(t,
   141  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, hide(2)) },
   142  		map2(usarc(hide(2))))
   143  
   144  	testUint32x4Binary(t,
   145  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, 128) },
   146  		map2(usarc(128)))
   147  	testUint32x4Binary(t,
   148  		func(x, y archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftAllRightConcatMod32(y, hide(128)) },
   149  		map2(usarc(hide(128))))
   150  }
   151  
   152  func TestShiftConcat(t *testing.T) {
   153  	if !archsimd.X86.AVX512VBMI2() {
   154  		t.Skip("requires AVX512-VBMI2")
   155  	}
   156  
   157  	// Note that unlike their non-Concat counterparts, these wrap the shift count.
   158  
   159  	testInt32x4Ternary(t,
   160  		func(x, y, z archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftLeftConcatMod32(y, z.AsUint32x4()) },
   161  		map3(func(x, y, z int32) int32 {
   162  			return int32(concatInt32s(x, y) >> (32 - uint32(z)%32))
   163  		}))
   164  
   165  	testInt32x4Ternary(t,
   166  		func(x, y, z archsimd.Int32x4) archsimd.Int32x4 { return x.ShiftRightConcatMod32(y, z.AsUint32x4()) },
   167  		map3(func(x, y, z int32) int32 {
   168  			return int32(concatInt32s(y, x) >> (uint32(z) % 32))
   169  		}))
   170  
   171  	testUint32x4Ternary(t,
   172  		func(x, y, z archsimd.Uint32x4) archsimd.Uint32x4 { return x.ShiftRightConcatMod32(y, z) },
   173  		map3(func(x, y, z uint32) uint32 {
   174  			return uint32(concatUint32s(y, x) >> (z % 32))
   175  		}))
   176  }
   177  
   178  func TestConcatShiftBytesRight(t *testing.T) {
   179  	hide := hideConst[uint64]
   180  
   181  	csbr := func(shift uint64) func(x, y []uint8) []uint8 {
   182  		return func(x, y []uint8) []uint8 {
   183  			z := make([]uint8, len(x))
   184  			for i := range z {
   185  				target := i + int(shift)
   186  				if target < 16 {
   187  					z[i] = y[target]
   188  				} else if target < 32 {
   189  					z[i] = x[(target - 16)]
   190  				}
   191  			}
   192  			return z
   193  		}
   194  	}
   195  
   196  	t.Run("Uint8x16", func(t *testing.T) {
   197  		if !archsimd.X86.AVX() {
   198  			t.Skip("requires AVX")
   199  		}
   200  		for _, shift := range []uint64{0, 2, 16, 20, 32, 128} {
   201  			t.Log("shift", shift)
   202  			testUint8x16Binary(t,
   203  				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(y, shift) },
   204  				csbr(shift))
   205  			testUint8x16Binary(t,
   206  				func(x, y archsimd.Uint8x16) archsimd.Uint8x16 { return x.ConcatShiftBytesRight(y, hide(shift)) },
   207  				csbr(hide(shift)))
   208  		}
   209  	})
   210  
   211  	t.Run("Uint8x32", func(t *testing.T) {
   212  		if !archsimd.X86.AVX2() {
   213  			t.Skip("requires AVX2")
   214  		}
   215  		for _, shift := range []uint64{0, 2, 16, 20, 32, 128} {
   216  			t.Log("shift", shift)
   217  			testUint8x32Binary(t,
   218  				func(x, y archsimd.Uint8x32) archsimd.Uint8x32 { return x.ConcatShiftBytesRightGrouped(y, shift) },
   219  				grouped2(csbr(shift)))
   220  			testUint8x32Binary(t,
   221  				func(x, y archsimd.Uint8x32) archsimd.Uint8x32 { return x.ConcatShiftBytesRightGrouped(y, hide(shift)) },
   222  				grouped2(csbr(hide(shift))))
   223  		}
   224  	})
   225  
   226  	t.Run("Uint8x64", func(t *testing.T) {
   227  		if !archsimd.X86.AVX512() {
   228  			t.Skip("requires AVX512")
   229  		}
   230  		for _, shift := range []uint64{0, 2, 16, 20, 32, 128} {
   231  			t.Log("shift", shift)
   232  			testUint8x64Binary(t,
   233  				func(x, y archsimd.Uint8x64) archsimd.Uint8x64 { return x.ConcatShiftBytesRightGrouped(y, shift) },
   234  				grouped2(csbr(shift)))
   235  			testUint8x64Binary(t,
   236  				func(x, y archsimd.Uint8x64) archsimd.Uint8x64 { return x.ConcatShiftBytesRightGrouped(y, hide(shift)) },
   237  				grouped2(csbr(hide(shift))))
   238  		}
   239  	})
   240  }
   241  
   242  func TestShiftAllAMD64(t *testing.T) {
   243  	if archsimd.X86.AVX2() {
   244  		// ShiftAllLeft
   245  		testInt16x16ShiftAll(t, archsimd.Int16x16.ShiftAllLeft, shiftAllLeftSlice[int16])
   246  		testInt32x8ShiftAll(t, archsimd.Int32x8.ShiftAllLeft, shiftAllLeftSlice[int32])
   247  		testInt64x4ShiftAll(t, archsimd.Int64x4.ShiftAllLeft, shiftAllLeftSlice[int64])
   248  		testUint16x16ShiftAll(t, archsimd.Uint16x16.ShiftAllLeft, shiftAllLeftSlice[uint16])
   249  		testUint32x8ShiftAll(t, archsimd.Uint32x8.ShiftAllLeft, shiftAllLeftSlice[uint32])
   250  		testUint64x4ShiftAll(t, archsimd.Uint64x4.ShiftAllLeft, shiftAllLeftSlice[uint64])
   251  
   252  		// ShiftAllRight signed
   253  		testInt16x16ShiftAll(t, archsimd.Int16x16.ShiftAllRight, shiftAllRightSlice[int16])
   254  		testInt32x8ShiftAll(t, archsimd.Int32x8.ShiftAllRight, shiftAllRightSlice[int32])
   255  		// Int64x4 ShiftAllRight requires AVX-512
   256  
   257  		// ShiftAllRight unsigned
   258  		testUint16x16ShiftAll(t, archsimd.Uint16x16.ShiftAllRight, shiftAllRightSlice[uint16])
   259  		testUint32x8ShiftAll(t, archsimd.Uint32x8.ShiftAllRight, shiftAllRightSlice[uint32])
   260  		testUint64x4ShiftAll(t, archsimd.Uint64x4.ShiftAllRight, shiftAllRightSlice[uint64])
   261  	}
   262  
   263  	if archsimd.X86.AVX512() {
   264  		// 512-bit vectors (AVX512)
   265  		// ShiftAllLeft
   266  		testInt16x32ShiftAll(t, archsimd.Int16x32.ShiftAllLeft, shiftAllLeftSlice[int16])
   267  		testInt32x16ShiftAll(t, archsimd.Int32x16.ShiftAllLeft, shiftAllLeftSlice[int32])
   268  		testInt64x8ShiftAll(t, archsimd.Int64x8.ShiftAllLeft, shiftAllLeftSlice[int64])
   269  		testUint16x32ShiftAll(t, archsimd.Uint16x32.ShiftAllLeft, shiftAllLeftSlice[uint16])
   270  		testUint32x16ShiftAll(t, archsimd.Uint32x16.ShiftAllLeft, shiftAllLeftSlice[uint32])
   271  		testUint64x8ShiftAll(t, archsimd.Uint64x8.ShiftAllLeft, shiftAllLeftSlice[uint64])
   272  
   273  		// ShiftAllRight signed
   274  		testInt16x32ShiftAll(t, archsimd.Int16x32.ShiftAllRight, shiftAllRightSlice[int16])
   275  		testInt32x16ShiftAll(t, archsimd.Int32x16.ShiftAllRight, shiftAllRightSlice[int32])
   276  		testInt64x8ShiftAll(t, archsimd.Int64x8.ShiftAllRight, shiftAllRightSlice[int64])
   277  		// 256-bit Int64x4 ShiftAllRight (requires AVX-512)
   278  		testInt64x4ShiftAll(t, archsimd.Int64x4.ShiftAllRight, shiftAllRightSlice[int64])
   279  
   280  		// ShiftAllRight unsigned
   281  		testUint16x32ShiftAll(t, archsimd.Uint16x32.ShiftAllRight, shiftAllRightSlice[uint16])
   282  		testUint32x16ShiftAll(t, archsimd.Uint32x16.ShiftAllRight, shiftAllRightSlice[uint32])
   283  		testUint64x8ShiftAll(t, archsimd.Uint64x8.ShiftAllRight, shiftAllRightSlice[uint64])
   284  	}
   285  }
   286  

View as plain text