Source file src/errors/wrap_test.go

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package errors_test
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io/fs"
    11  	"os"
    12  	"reflect"
    13  	"testing"
    14  )
    15  
    16  func TestIs(t *testing.T) {
    17  	err1 := errors.New("1")
    18  	erra := wrapped{"wrap 2", err1}
    19  	errb := wrapped{"wrap 3", erra}
    20  
    21  	err3 := errors.New("3")
    22  
    23  	poser := &poser{"either 1 or 3", func(err error) bool {
    24  		return err == err1 || err == err3
    25  	}}
    26  
    27  	testCases := []struct {
    28  		err    error
    29  		target error
    30  		match  bool
    31  	}{
    32  		{nil, nil, true},
    33  		{nil, err1, false},
    34  		{err1, nil, false},
    35  		{err1, err1, true},
    36  		{erra, err1, true},
    37  		{errb, err1, true},
    38  		{err1, err3, false},
    39  		{erra, err3, false},
    40  		{errb, err3, false},
    41  		{poser, err1, true},
    42  		{poser, err3, true},
    43  		{poser, erra, false},
    44  		{poser, errb, false},
    45  		{errorUncomparable{}, errorUncomparable{}, true},
    46  		{errorUncomparable{}, &errorUncomparable{}, false},
    47  		{&errorUncomparable{}, errorUncomparable{}, true},
    48  		{&errorUncomparable{}, &errorUncomparable{}, false},
    49  		{errorUncomparable{}, err1, false},
    50  		{&errorUncomparable{}, err1, false},
    51  		{multiErr{}, err1, false},
    52  		{multiErr{err1, err3}, err1, true},
    53  		{multiErr{err3, err1}, err1, true},
    54  		{multiErr{err1, err3}, errors.New("x"), false},
    55  		{multiErr{err3, errb}, errb, true},
    56  		{multiErr{err3, errb}, erra, true},
    57  		{multiErr{err3, errb}, err1, true},
    58  		{multiErr{errb, err3}, err1, true},
    59  		{multiErr{poser}, err1, true},
    60  		{multiErr{poser}, err3, true},
    61  		{multiErr{nil}, nil, false},
    62  	}
    63  	for _, tc := range testCases {
    64  		t.Run("", func(t *testing.T) {
    65  			if got := errors.Is(tc.err, tc.target); got != tc.match {
    66  				t.Errorf("Is(%v, %v) = %v, want %v", tc.err, tc.target, got, tc.match)
    67  			}
    68  		})
    69  	}
    70  }
    71  
    72  type poser struct {
    73  	msg string
    74  	f   func(error) bool
    75  }
    76  
    77  var poserPathErr = &fs.PathError{Op: "poser"}
    78  
    79  func (p *poser) Error() string     { return p.msg }
    80  func (p *poser) Is(err error) bool { return p.f(err) }
    81  func (p *poser) As(err any) bool {
    82  	switch x := err.(type) {
    83  	case **poser:
    84  		*x = p
    85  	case *errorT:
    86  		*x = errorT{"poser"}
    87  	case **fs.PathError:
    88  		*x = poserPathErr
    89  	default:
    90  		return false
    91  	}
    92  	return true
    93  }
    94  
    95  func TestAs(t *testing.T) {
    96  	var errT errorT
    97  	var errP *fs.PathError
    98  	var timeout interface{ Timeout() bool }
    99  	var p *poser
   100  	_, errF := os.Open("non-existing")
   101  	poserErr := &poser{"oh no", nil}
   102  
   103  	testCases := []struct {
   104  		err    error
   105  		target any
   106  		match  bool
   107  		want   any // value of target on match
   108  	}{{
   109  		nil,
   110  		&errP,
   111  		false,
   112  		nil,
   113  	}, {
   114  		wrapped{"pitied the fool", errorT{"T"}},
   115  		&errT,
   116  		true,
   117  		errorT{"T"},
   118  	}, {
   119  		errF,
   120  		&errP,
   121  		true,
   122  		errF,
   123  	}, {
   124  		errorT{},
   125  		&errP,
   126  		false,
   127  		nil,
   128  	}, {
   129  		wrapped{"wrapped", nil},
   130  		&errT,
   131  		false,
   132  		nil,
   133  	}, {
   134  		&poser{"error", nil},
   135  		&errT,
   136  		true,
   137  		errorT{"poser"},
   138  	}, {
   139  		&poser{"path", nil},
   140  		&errP,
   141  		true,
   142  		poserPathErr,
   143  	}, {
   144  		poserErr,
   145  		&p,
   146  		true,
   147  		poserErr,
   148  	}, {
   149  		errors.New("err"),
   150  		&timeout,
   151  		false,
   152  		nil,
   153  	}, {
   154  		errF,
   155  		&timeout,
   156  		true,
   157  		errF,
   158  	}, {
   159  		wrapped{"path error", errF},
   160  		&timeout,
   161  		true,
   162  		errF,
   163  	}, {
   164  		multiErr{},
   165  		&errT,
   166  		false,
   167  		nil,
   168  	}, {
   169  		multiErr{errors.New("a"), errorT{"T"}},
   170  		&errT,
   171  		true,
   172  		errorT{"T"},
   173  	}, {
   174  		multiErr{errorT{"T"}, errors.New("a")},
   175  		&errT,
   176  		true,
   177  		errorT{"T"},
   178  	}, {
   179  		multiErr{errorT{"a"}, errorT{"b"}},
   180  		&errT,
   181  		true,
   182  		errorT{"a"},
   183  	}, {
   184  		multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}},
   185  		&errT,
   186  		true,
   187  		errorT{"a"},
   188  	}, {
   189  		multiErr{wrapped{"path error", errF}},
   190  		&timeout,
   191  		true,
   192  		errF,
   193  	}, {
   194  		multiErr{nil},
   195  		&errT,
   196  		false,
   197  		nil,
   198  	}}
   199  	for i, tc := range testCases {
   200  		name := fmt.Sprintf("%d:As(Errorf(..., %v), %v)", i, tc.err, tc.target)
   201  		// Clear the target pointer, in case it was set in a previous test.
   202  		rtarget := reflect.ValueOf(tc.target)
   203  		rtarget.Elem().Set(reflect.Zero(reflect.TypeOf(tc.target).Elem()))
   204  		t.Run(name, func(t *testing.T) {
   205  			match := errors.As(tc.err, tc.target)
   206  			if match != tc.match {
   207  				t.Fatalf("match: got %v; want %v", match, tc.match)
   208  			}
   209  			if !match {
   210  				return
   211  			}
   212  			if got := rtarget.Elem().Interface(); got != tc.want {
   213  				t.Fatalf("got %#v, want %#v", got, tc.want)
   214  			}
   215  		})
   216  	}
   217  }
   218  
   219  func TestAsValidation(t *testing.T) {
   220  	var s string
   221  	testCases := []any{
   222  		nil,
   223  		(*int)(nil),
   224  		"error",
   225  		&s,
   226  	}
   227  	err := errors.New("error")
   228  	for _, tc := range testCases {
   229  		t.Run(fmt.Sprintf("%T(%v)", tc, tc), func(t *testing.T) {
   230  			defer func() {
   231  				recover()
   232  			}()
   233  			if errors.As(err, tc) {
   234  				t.Errorf("As(err, %T(%v)) = true, want false", tc, tc)
   235  				return
   236  			}
   237  			t.Errorf("As(err, %T(%v)) did not panic", tc, tc)
   238  		})
   239  	}
   240  }
   241  
   242  func BenchmarkIs(b *testing.B) {
   243  	err1 := errors.New("1")
   244  	err2 := multiErr{multiErr{multiErr{err1, errorT{"a"}}, errorT{"b"}}}
   245  
   246  	for i := 0; i < b.N; i++ {
   247  		if !errors.Is(err2, err1) {
   248  			b.Fatal("Is failed")
   249  		}
   250  	}
   251  }
   252  
   253  func BenchmarkAs(b *testing.B) {
   254  	err := multiErr{multiErr{multiErr{errors.New("a"), errorT{"a"}}, errorT{"b"}}}
   255  	for i := 0; i < b.N; i++ {
   256  		var target errorT
   257  		if !errors.As(err, &target) {
   258  			b.Fatal("As failed")
   259  		}
   260  	}
   261  }
   262  
   263  func TestUnwrap(t *testing.T) {
   264  	err1 := errors.New("1")
   265  	erra := wrapped{"wrap 2", err1}
   266  
   267  	testCases := []struct {
   268  		err  error
   269  		want error
   270  	}{
   271  		{nil, nil},
   272  		{wrapped{"wrapped", nil}, nil},
   273  		{err1, nil},
   274  		{erra, err1},
   275  		{wrapped{"wrap 3", erra}, erra},
   276  	}
   277  	for _, tc := range testCases {
   278  		if got := errors.Unwrap(tc.err); got != tc.want {
   279  			t.Errorf("Unwrap(%v) = %v, want %v", tc.err, got, tc.want)
   280  		}
   281  	}
   282  }
   283  
   284  type errorT struct{ s string }
   285  
   286  func (e errorT) Error() string { return fmt.Sprintf("errorT(%s)", e.s) }
   287  
   288  type wrapped struct {
   289  	msg string
   290  	err error
   291  }
   292  
   293  func (e wrapped) Error() string { return e.msg }
   294  func (e wrapped) Unwrap() error { return e.err }
   295  
   296  type multiErr []error
   297  
   298  func (m multiErr) Error() string   { return "multiError" }
   299  func (m multiErr) Unwrap() []error { return []error(m) }
   300  
   301  type errorUncomparable struct {
   302  	f []string
   303  }
   304  
   305  func (errorUncomparable) Error() string {
   306  	return "uncomparable error"
   307  }
   308  
   309  func (errorUncomparable) Is(target error) bool {
   310  	_, ok := target.(errorUncomparable)
   311  	return ok
   312  }
   313  

View as plain text