Source file src/net/http/transport_internal_test.go

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // White-box tests for transport.go (in package http instead of http_test).
     6  
     7  package http
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"errors"
    14  	"io"
    15  	"net"
    16  	"net/http/internal/testcert"
    17  	"strings"
    18  	"testing"
    19  )
    20  
    21  // Issue 15446: incorrect wrapping of errors when server closes an idle connection.
    22  func TestTransportPersistConnReadLoopEOF(t *testing.T) {
    23  	ln := newLocalListener(t)
    24  	defer ln.Close()
    25  
    26  	connc := make(chan net.Conn, 1)
    27  	go func() {
    28  		defer close(connc)
    29  		c, err := ln.Accept()
    30  		if err != nil {
    31  			t.Error(err)
    32  			return
    33  		}
    34  		connc <- c
    35  	}()
    36  
    37  	tr := new(Transport)
    38  	req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
    39  	req = req.WithT(t)
    40  	ctx, cancel := context.WithCancelCause(context.Background())
    41  	treq := &transportRequest{Request: req, ctx: ctx, cancel: cancel}
    42  	cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
    43  	pc, err := tr.getConn(treq, cm)
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	defer pc.close(errors.New("test over"))
    48  
    49  	conn := <-connc
    50  	if conn == nil {
    51  		// Already called t.Error in the accept goroutine.
    52  		return
    53  	}
    54  	conn.Close() // simulate the server hanging up on the client
    55  
    56  	_, err = pc.roundTrip(treq)
    57  	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
    58  		t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle, transportReadFromServerError, or nothingWrittenError", err, err)
    59  	}
    60  
    61  	<-pc.closech
    62  	err = pc.closed
    63  	if !isNothingWrittenError(err) && !isTransportReadFromServerError(err) && err != errServerClosedIdle {
    64  		t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError, or nothingWrittenError", err, err)
    65  	}
    66  }
    67  
    68  func isNothingWrittenError(err error) bool {
    69  	_, ok := err.(nothingWrittenError)
    70  	return ok
    71  }
    72  
    73  func isTransportReadFromServerError(err error) bool {
    74  	_, ok := err.(transportReadFromServerError)
    75  	return ok
    76  }
    77  
    78  func newLocalListener(t *testing.T) net.Listener {
    79  	ln, err := net.Listen("tcp", "127.0.0.1:0")
    80  	if err != nil {
    81  		ln, err = net.Listen("tcp6", "[::1]:0")
    82  	}
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  	return ln
    87  }
    88  
    89  func dummyRequest(method string) *Request {
    90  	req, err := NewRequest(method, "http://fake.tld/", nil)
    91  	if err != nil {
    92  		panic(err)
    93  	}
    94  	return req
    95  }
    96  func dummyRequestWithBody(method string) *Request {
    97  	req, err := NewRequest(method, "http://fake.tld/", strings.NewReader("foo"))
    98  	if err != nil {
    99  		panic(err)
   100  	}
   101  	return req
   102  }
   103  
   104  func dummyRequestWithBodyNoGetBody(method string) *Request {
   105  	req := dummyRequestWithBody(method)
   106  	req.GetBody = nil
   107  	return req
   108  }
   109  
   110  // issue22091Error acts like a golang.org/x/net/http2.ErrNoCachedConn.
   111  type issue22091Error struct{}
   112  
   113  func (issue22091Error) IsHTTP2NoCachedConnError() {}
   114  func (issue22091Error) Error() string             { return "issue22091Error" }
   115  
   116  func TestTransportShouldRetryRequest(t *testing.T) {
   117  	tests := []struct {
   118  		pc  *persistConn
   119  		req *Request
   120  
   121  		err  error
   122  		want bool
   123  	}{
   124  		0: {
   125  			pc:   &persistConn{reused: false},
   126  			req:  dummyRequest("POST"),
   127  			err:  nothingWrittenError{},
   128  			want: false,
   129  		},
   130  		1: {
   131  			pc:   &persistConn{reused: true},
   132  			req:  dummyRequest("POST"),
   133  			err:  nothingWrittenError{},
   134  			want: true,
   135  		},
   136  		2: {
   137  			pc:   &persistConn{reused: true},
   138  			req:  dummyRequest("POST"),
   139  			err:  http2ErrNoCachedConn,
   140  			want: true,
   141  		},
   142  		3: {
   143  			pc:   nil,
   144  			req:  nil,
   145  			err:  issue22091Error{}, // like an external http2ErrNoCachedConn
   146  			want: true,
   147  		},
   148  		4: {
   149  			pc:   &persistConn{reused: true},
   150  			req:  dummyRequest("POST"),
   151  			err:  errMissingHost,
   152  			want: false,
   153  		},
   154  		5: {
   155  			pc:   &persistConn{reused: true},
   156  			req:  dummyRequest("POST"),
   157  			err:  transportReadFromServerError{},
   158  			want: false,
   159  		},
   160  		6: {
   161  			pc:   &persistConn{reused: true},
   162  			req:  dummyRequest("GET"),
   163  			err:  transportReadFromServerError{},
   164  			want: true,
   165  		},
   166  		7: {
   167  			pc:   &persistConn{reused: true},
   168  			req:  dummyRequest("GET"),
   169  			err:  errServerClosedIdle,
   170  			want: true,
   171  		},
   172  		8: {
   173  			pc:   &persistConn{reused: true},
   174  			req:  dummyRequestWithBody("POST"),
   175  			err:  nothingWrittenError{},
   176  			want: true,
   177  		},
   178  		9: {
   179  			pc:   &persistConn{reused: true},
   180  			req:  dummyRequestWithBodyNoGetBody("POST"),
   181  			err:  nothingWrittenError{},
   182  			want: false,
   183  		},
   184  	}
   185  	for i, tt := range tests {
   186  		got := tt.pc.shouldRetryRequest(tt.req, tt.err)
   187  		if got != tt.want {
   188  			t.Errorf("%d. shouldRetryRequest = %v; want %v", i, got, tt.want)
   189  		}
   190  	}
   191  }
   192  
   193  type roundTripFunc func(r *Request) (*Response, error)
   194  
   195  func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
   196  	return f(r)
   197  }
   198  
   199  // Issue 25009
   200  func TestTransportBodyAltRewind(t *testing.T) {
   201  	cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
   202  	if err != nil {
   203  		t.Fatal(err)
   204  	}
   205  	ln := newLocalListener(t)
   206  	defer ln.Close()
   207  
   208  	go func() {
   209  		tln := tls.NewListener(ln, &tls.Config{
   210  			NextProtos:   []string{"foo"},
   211  			Certificates: []tls.Certificate{cert},
   212  		})
   213  		for i := 0; i < 2; i++ {
   214  			sc, err := tln.Accept()
   215  			if err != nil {
   216  				t.Error(err)
   217  				return
   218  			}
   219  			if err := sc.(*tls.Conn).Handshake(); err != nil {
   220  				t.Error(err)
   221  				return
   222  			}
   223  			sc.Close()
   224  		}
   225  	}()
   226  
   227  	addr := ln.Addr().String()
   228  	req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
   229  	roundTripped := false
   230  	tr := &Transport{
   231  		DisableKeepAlives: true,
   232  		TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
   233  			"foo": func(authority string, c *tls.Conn) RoundTripper {
   234  				return roundTripFunc(func(r *Request) (*Response, error) {
   235  					n, _ := io.Copy(io.Discard, r.Body)
   236  					if n == 0 {
   237  						t.Error("body length is zero")
   238  					}
   239  					if roundTripped {
   240  						return &Response{
   241  							Body:       NoBody,
   242  							StatusCode: 200,
   243  						}, nil
   244  					}
   245  					roundTripped = true
   246  					return nil, http2noCachedConnError{}
   247  				})
   248  			},
   249  		},
   250  		DialTLS: func(_, _ string) (net.Conn, error) {
   251  			tc, err := tls.Dial("tcp", addr, &tls.Config{
   252  				InsecureSkipVerify: true,
   253  				NextProtos:         []string{"foo"},
   254  			})
   255  			if err != nil {
   256  				return nil, err
   257  			}
   258  			if err := tc.Handshake(); err != nil {
   259  				return nil, err
   260  			}
   261  			return tc, nil
   262  		},
   263  	}
   264  	c := &Client{Transport: tr}
   265  	_, err = c.Do(req)
   266  	if err != nil {
   267  		t.Error(err)
   268  	}
   269  }
   270  

View as plain text