Source file src/crypto/subtle/constant_time_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package subtle
     6  
     7  import (
     8  	"testing"
     9  	"testing/quick"
    10  )
    11  
    12  type TestConstantTimeCompareStruct struct {
    13  	a, b []byte
    14  	out  int
    15  }
    16  
    17  var testConstantTimeCompareData = []TestConstantTimeCompareStruct{
    18  	{[]byte{}, []byte{}, 1},
    19  	{[]byte{0x11}, []byte{0x11}, 1},
    20  	{[]byte{0x12}, []byte{0x11}, 0},
    21  	{[]byte{0x11}, []byte{0x11, 0x12}, 0},
    22  	{[]byte{0x11, 0x12}, []byte{0x11}, 0},
    23  }
    24  
    25  func TestConstantTimeCompare(t *testing.T) {
    26  	for i, test := range testConstantTimeCompareData {
    27  		if r := ConstantTimeCompare(test.a, test.b); r != test.out {
    28  			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
    29  		}
    30  	}
    31  }
    32  
    33  type TestConstantTimeByteEqStruct struct {
    34  	a, b uint8
    35  	out  int
    36  }
    37  
    38  var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{
    39  	{0, 0, 1},
    40  	{0, 1, 0},
    41  	{1, 0, 0},
    42  	{0xff, 0xff, 1},
    43  	{0xff, 0xfe, 0},
    44  }
    45  
    46  func byteEq(a, b uint8) int {
    47  	if a == b {
    48  		return 1
    49  	}
    50  	return 0
    51  }
    52  
    53  func TestConstantTimeByteEq(t *testing.T) {
    54  	for i, test := range testConstandTimeByteEqData {
    55  		if r := ConstantTimeByteEq(test.a, test.b); r != test.out {
    56  			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
    57  		}
    58  	}
    59  	err := quick.CheckEqual(ConstantTimeByteEq, byteEq, nil)
    60  	if err != nil {
    61  		t.Error(err)
    62  	}
    63  }
    64  
    65  func eq(a, b int32) int {
    66  	if a == b {
    67  		return 1
    68  	}
    69  	return 0
    70  }
    71  
    72  func TestConstantTimeEq(t *testing.T) {
    73  	err := quick.CheckEqual(ConstantTimeEq, eq, nil)
    74  	if err != nil {
    75  		t.Error(err)
    76  	}
    77  }
    78  
    79  func makeCopy(v int, x, y []byte) []byte {
    80  	if len(x) > len(y) {
    81  		x = x[:len(y)]
    82  	} else {
    83  		y = y[:len(x)]
    84  	}
    85  	if v == 1 {
    86  		copy(x, y)
    87  	}
    88  	return x
    89  }
    90  
    91  func constantTimeCopyWrapper(v int, x, y []byte) []byte {
    92  	if len(x) > len(y) {
    93  		x = x[:len(y)]
    94  	} else {
    95  		y = y[:len(x)]
    96  	}
    97  	v &= 1
    98  	ConstantTimeCopy(v, x, y)
    99  	return x
   100  }
   101  
   102  func TestConstantTimeCopy(t *testing.T) {
   103  	err := quick.CheckEqual(constantTimeCopyWrapper, makeCopy, nil)
   104  	if err != nil {
   105  		t.Error(err)
   106  	}
   107  }
   108  
   109  var lessOrEqTests = []struct {
   110  	x, y, result int
   111  }{
   112  	{0, 0, 1},
   113  	{1, 0, 0},
   114  	{0, 1, 1},
   115  	{10, 20, 1},
   116  	{20, 10, 0},
   117  	{10, 10, 1},
   118  }
   119  
   120  func TestConstantTimeLessOrEq(t *testing.T) {
   121  	for i, test := range lessOrEqTests {
   122  		result := ConstantTimeLessOrEq(test.x, test.y)
   123  		if result != test.result {
   124  			t.Errorf("#%d: %d <= %d gave %d, expected %d", i, test.x, test.y, result, test.result)
   125  		}
   126  	}
   127  }
   128  
   129  var benchmarkGlobal uint8
   130  
   131  func BenchmarkConstantTimeByteEq(b *testing.B) {
   132  	var x, y uint8
   133  
   134  	for i := 0; i < b.N; i++ {
   135  		x, y = uint8(ConstantTimeByteEq(x, y)), x
   136  	}
   137  
   138  	benchmarkGlobal = x
   139  }
   140  
   141  func BenchmarkConstantTimeEq(b *testing.B) {
   142  	var x, y int
   143  
   144  	for i := 0; i < b.N; i++ {
   145  		x, y = ConstantTimeEq(int32(x), int32(y)), x
   146  	}
   147  
   148  	benchmarkGlobal = uint8(x)
   149  }
   150  
   151  func BenchmarkConstantTimeLessOrEq(b *testing.B) {
   152  	var x, y int
   153  
   154  	for i := 0; i < b.N; i++ {
   155  		x, y = ConstantTimeLessOrEq(x, y), x
   156  	}
   157  
   158  	benchmarkGlobal = uint8(x)
   159  }
   160  

View as plain text