Source file src/cmd/go/internal/vcweb/vcstest/vcstest.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package vcstest serves the repository scripts in cmd/go/testdata/vcstest
     6  // using the [vcweb] script engine.
     7  package vcstest
     8  
     9  import (
    10  	"cmd/go/internal/vcs"
    11  	"cmd/go/internal/vcweb"
    12  	"cmd/go/internal/web"
    13  	"crypto/tls"
    14  	"crypto/x509"
    15  	"encoding/pem"
    16  	"fmt"
    17  	"internal/testenv"
    18  	"io"
    19  	"log"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"net/url"
    23  	"os"
    24  	"path/filepath"
    25  	"testing"
    26  )
    27  
    28  var Hosts = []string{
    29  	"vcs-test.golang.org",
    30  }
    31  
    32  type Server struct {
    33  	vcweb   *vcweb.Server
    34  	workDir string
    35  	HTTP    *httptest.Server
    36  	HTTPS   *httptest.Server
    37  }
    38  
    39  // NewServer returns a new test-local vcweb server that serves VCS requests
    40  // for modules with paths that begin with "vcs-test.golang.org" using the
    41  // scripts in cmd/go/testdata/vcstest.
    42  func NewServer() (srv *Server, err error) {
    43  	if vcs.VCSTestRepoURL != "" {
    44  		panic("vcs URL hooks already set")
    45  	}
    46  
    47  	scriptDir := filepath.Join(testenv.GOROOT(nil), "src/cmd/go/testdata/vcstest")
    48  
    49  	workDir, err := os.MkdirTemp("", "vcstest")
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	defer func() {
    54  		if err != nil {
    55  			os.RemoveAll(workDir)
    56  		}
    57  	}()
    58  
    59  	logger := log.Default()
    60  	if !testing.Verbose() {
    61  		logger = log.New(io.Discard, "", log.LstdFlags)
    62  	}
    63  	handler, err := vcweb.NewServer(scriptDir, workDir, logger)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	defer func() {
    68  		if err != nil {
    69  			handler.Close()
    70  		}
    71  	}()
    72  
    73  	srvHTTP := httptest.NewServer(handler)
    74  	httpURL, err := url.Parse(srvHTTP.URL)
    75  	if err != nil {
    76  		return nil, err
    77  	}
    78  	defer func() {
    79  		if err != nil {
    80  			srvHTTP.Close()
    81  		}
    82  	}()
    83  
    84  	srvHTTPS := httptest.NewTLSServer(handler)
    85  	httpsURL, err := url.Parse(srvHTTPS.URL)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	defer func() {
    90  		if err != nil {
    91  			srvHTTPS.Close()
    92  		}
    93  	}()
    94  
    95  	srv = &Server{
    96  		vcweb:   handler,
    97  		workDir: workDir,
    98  		HTTP:    srvHTTP,
    99  		HTTPS:   srvHTTPS,
   100  	}
   101  	vcs.VCSTestRepoURL = srv.HTTP.URL
   102  	vcs.VCSTestHosts = Hosts
   103  
   104  	var interceptors []web.Interceptor
   105  	for _, host := range Hosts {
   106  		interceptors = append(interceptors,
   107  			web.Interceptor{Scheme: "http", FromHost: host, ToHost: httpURL.Host, Client: srv.HTTP.Client()},
   108  			web.Interceptor{Scheme: "https", FromHost: host, ToHost: httpsURL.Host, Client: srv.HTTPS.Client()})
   109  	}
   110  	web.EnableTestHooks(interceptors)
   111  
   112  	fmt.Fprintln(os.Stderr, "vcs-test.golang.org rerouted to "+srv.HTTP.URL)
   113  	fmt.Fprintln(os.Stderr, "https://vcs-test.golang.org rerouted to "+srv.HTTPS.URL)
   114  
   115  	return srv, nil
   116  }
   117  
   118  func (srv *Server) Close() error {
   119  	if vcs.VCSTestRepoURL != srv.HTTP.URL {
   120  		panic("vcs URL hooks modified before Close")
   121  	}
   122  	vcs.VCSTestRepoURL = ""
   123  	vcs.VCSTestHosts = nil
   124  	web.DisableTestHooks()
   125  
   126  	srv.HTTP.Close()
   127  	srv.HTTPS.Close()
   128  	err := srv.vcweb.Close()
   129  	if rmErr := os.RemoveAll(srv.workDir); err == nil {
   130  		err = rmErr
   131  	}
   132  	return err
   133  }
   134  
   135  func (srv *Server) WriteCertificateFile() (string, error) {
   136  	b := pem.EncodeToMemory(&pem.Block{
   137  		Type:  "CERTIFICATE",
   138  		Bytes: srv.HTTPS.Certificate().Raw,
   139  	})
   140  
   141  	filename := filepath.Join(srv.workDir, "cert.pem")
   142  	if err := os.WriteFile(filename, b, 0644); err != nil {
   143  		return "", err
   144  	}
   145  	return filename, nil
   146  }
   147  
   148  // TLSClient returns an http.Client that can talk to the httptest.Server
   149  // whose certificate is written to the given file path.
   150  func TLSClient(certFile string) (*http.Client, error) {
   151  	client := &http.Client{
   152  		Transport: http.DefaultTransport.(*http.Transport).Clone(),
   153  	}
   154  
   155  	pemBytes, err := os.ReadFile(certFile)
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  
   160  	certpool := x509.NewCertPool()
   161  	if !certpool.AppendCertsFromPEM(pemBytes) {
   162  		return nil, fmt.Errorf("no certificates found in %s", certFile)
   163  	}
   164  	client.Transport.(*http.Transport).TLSClientConfig = &tls.Config{
   165  		RootCAs: certpool,
   166  	}
   167  
   168  	return client, nil
   169  }
   170  

View as plain text