1
2
3
4
5
6
7 package tls
8
9
10
11
12
13
14 import (
15 "bytes"
16 "context"
17 "crypto"
18 "crypto/ecdsa"
19 "crypto/ed25519"
20 "crypto/rsa"
21 "crypto/x509"
22 "encoding/pem"
23 "errors"
24 "fmt"
25 "internal/godebug"
26 "net"
27 "os"
28 "strings"
29 )
30
31
32
33
34
35 func Server(conn net.Conn, config *Config) *Conn {
36 c := &Conn{
37 conn: conn,
38 config: config,
39 }
40 c.handshakeFn = c.serverHandshake
41 return c
42 }
43
44
45
46
47
48 func Client(conn net.Conn, config *Config) *Conn {
49 c := &Conn{
50 conn: conn,
51 config: config,
52 isClient: true,
53 }
54 c.handshakeFn = c.clientHandshake
55 return c
56 }
57
58
59 type listener struct {
60 net.Listener
61 config *Config
62 }
63
64
65
66 func (l *listener) Accept() (net.Conn, error) {
67 c, err := l.Listener.Accept()
68 if err != nil {
69 return nil, err
70 }
71 return Server(c, l.config), nil
72 }
73
74
75
76
77
78 func NewListener(inner net.Listener, config *Config) net.Listener {
79 l := new(listener)
80 l.Listener = inner
81 l.config = config
82 return l
83 }
84
85
86
87
88
89 func Listen(network, laddr string, config *Config) (net.Listener, error) {
90
91 if config == nil || len(config.Certificates) == 0 &&
92 config.GetCertificate == nil && config.GetConfigForClient == nil {
93 return nil, errors.New("tls: neither Certificates, GetCertificate, nor GetConfigForClient set in Config")
94 }
95 l, err := net.Listen(network, laddr)
96 if err != nil {
97 return nil, err
98 }
99 return NewListener(l, config), nil
100 }
101
// timeoutError is an error value signalling that a TLS dial timed out.
// It reports itself as both a timeout and a temporary condition.
type timeoutError struct{}

// Error returns the fixed timed-out message.
func (timeoutError) Error() string {
	return "tls: DialWithDialer timed out"
}

// Timeout always reports true.
func (timeoutError) Timeout() bool {
	return true
}

// Temporary always reports true.
func (timeoutError) Temporary() bool {
	return true
}
107
108
109
110
111
112
113
114
115
116
117
118 func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
119 return dial(context.Background(), dialer, network, addr, config)
120 }
121
122 func dial(ctx context.Context, netDialer *net.Dialer, network, addr string, config *Config) (*Conn, error) {
123 if netDialer.Timeout != 0 {
124 var cancel context.CancelFunc
125 ctx, cancel = context.WithTimeout(ctx, netDialer.Timeout)
126 defer cancel()
127 }
128
129 if !netDialer.Deadline.IsZero() {
130 var cancel context.CancelFunc
131 ctx, cancel = context.WithDeadline(ctx, netDialer.Deadline)
132 defer cancel()
133 }
134
135 rawConn, err := netDialer.DialContext(ctx, network, addr)
136 if err != nil {
137 return nil, err
138 }
139
140 colonPos := strings.LastIndex(addr, ":")
141 if colonPos == -1 {
142 colonPos = len(addr)
143 }
144 hostname := addr[:colonPos]
145
146 if config == nil {
147 config = defaultConfig()
148 }
149
150
151 if config.ServerName == "" {
152
153 c := config.Clone()
154 c.ServerName = hostname
155 config = c
156 }
157
158 conn := Client(rawConn, config)
159 if err := conn.HandshakeContext(ctx); err != nil {
160 rawConn.Close()
161 return nil, err
162 }
163 return conn, nil
164 }
165
166
167
168
169
170
171
172 func Dial(network, addr string, config *Config) (*Conn, error) {
173 return DialWithDialer(new(net.Dialer), network, addr, config)
174 }
175
176
177
178 type Dialer struct {
179
180
181
182 NetDialer *net.Dialer
183
184
185
186
187
188 Config *Config
189 }
190
191
192
193
194
195
196
197
198 func (d *Dialer) Dial(network, addr string) (net.Conn, error) {
199 return d.DialContext(context.Background(), network, addr)
200 }
201
202 func (d *Dialer) netDialer() *net.Dialer {
203 if d.NetDialer != nil {
204 return d.NetDialer
205 }
206 return new(net.Dialer)
207 }
208
209
210
211
212
213
214
215
216
217
218 func (d *Dialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
219 c, err := dial(ctx, d.netDialer(), network, addr, d.Config)
220 if err != nil {
221
222 return nil, err
223 }
224 return c, nil
225 }
226
227
228
229
230
231
232
233
234
235 func LoadX509KeyPair(certFile, keyFile string) (Certificate, error) {
236 certPEMBlock, err := os.ReadFile(certFile)
237 if err != nil {
238 return Certificate{}, err
239 }
240 keyPEMBlock, err := os.ReadFile(keyFile)
241 if err != nil {
242 return Certificate{}, err
243 }
244 return X509KeyPair(certPEMBlock, keyPEMBlock)
245 }
246
// x509keypairleaf is the GODEBUG setting consulted by X509KeyPair. Unless its
// value is "0", X509KeyPair stores the parsed first certificate in
// Certificate.Leaf. When it is "0", Leaf is left unset and the non-default
// behavior is recorded with IncNonDefault.
var x509keypairleaf = godebug.New("x509keypairleaf")
248
249
250
251
252
253
254
255 func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (Certificate, error) {
256 fail := func(err error) (Certificate, error) { return Certificate{}, err }
257
258 var cert Certificate
259 var skippedBlockTypes []string
260 for {
261 var certDERBlock *pem.Block
262 certDERBlock, certPEMBlock = pem.Decode(certPEMBlock)
263 if certDERBlock == nil {
264 break
265 }
266 if certDERBlock.Type == "CERTIFICATE" {
267 cert.Certificate = append(cert.Certificate, certDERBlock.Bytes)
268 } else {
269 skippedBlockTypes = append(skippedBlockTypes, certDERBlock.Type)
270 }
271 }
272
273 if len(cert.Certificate) == 0 {
274 if len(skippedBlockTypes) == 0 {
275 return fail(errors.New("tls: failed to find any PEM data in certificate input"))
276 }
277 if len(skippedBlockTypes) == 1 && strings.HasSuffix(skippedBlockTypes[0], "PRIVATE KEY") {
278 return fail(errors.New("tls: failed to find certificate PEM data in certificate input, but did find a private key; PEM inputs may have been switched"))
279 }
280 return fail(fmt.Errorf("tls: failed to find \"CERTIFICATE\" PEM block in certificate input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
281 }
282
283 skippedBlockTypes = skippedBlockTypes[:0]
284 var keyDERBlock *pem.Block
285 for {
286 keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock)
287 if keyDERBlock == nil {
288 if len(skippedBlockTypes) == 0 {
289 return fail(errors.New("tls: failed to find any PEM data in key input"))
290 }
291 if len(skippedBlockTypes) == 1 && skippedBlockTypes[0] == "CERTIFICATE" {
292 return fail(errors.New("tls: found a certificate rather than a key in the PEM for the private key"))
293 }
294 return fail(fmt.Errorf("tls: failed to find PEM block with type ending in \"PRIVATE KEY\" in key input after skipping PEM blocks of the following types: %v", skippedBlockTypes))
295 }
296 if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") {
297 break
298 }
299 skippedBlockTypes = append(skippedBlockTypes, keyDERBlock.Type)
300 }
301
302
303
304 x509Cert, err := x509.ParseCertificate(cert.Certificate[0])
305 if err != nil {
306 return fail(err)
307 }
308
309 if x509keypairleaf.Value() != "0" {
310 cert.Leaf = x509Cert
311 } else {
312 x509keypairleaf.IncNonDefault()
313 }
314
315 cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes)
316 if err != nil {
317 return fail(err)
318 }
319
320 switch pub := x509Cert.PublicKey.(type) {
321 case *rsa.PublicKey:
322 priv, ok := cert.PrivateKey.(*rsa.PrivateKey)
323 if !ok {
324 return fail(errors.New("tls: private key type does not match public key type"))
325 }
326 if pub.N.Cmp(priv.N) != 0 {
327 return fail(errors.New("tls: private key does not match public key"))
328 }
329 case *ecdsa.PublicKey:
330 priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey)
331 if !ok {
332 return fail(errors.New("tls: private key type does not match public key type"))
333 }
334 if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 {
335 return fail(errors.New("tls: private key does not match public key"))
336 }
337 case ed25519.PublicKey:
338 priv, ok := cert.PrivateKey.(ed25519.PrivateKey)
339 if !ok {
340 return fail(errors.New("tls: private key type does not match public key type"))
341 }
342 if !bytes.Equal(priv.Public().(ed25519.PublicKey), pub) {
343 return fail(errors.New("tls: private key does not match public key"))
344 }
345 default:
346 return fail(errors.New("tls: unknown public key algorithm"))
347 }
348
349 return cert, nil
350 }
351
352
353
354
// parsePrivateKey decodes a DER-encoded private key. It tries these formats
// in turn: PKCS #1 (RSA), then PKCS #8, then the EC private key format
// accepted by x509.ParseECPrivateKey.
//
// A PKCS #8 key is accepted only if it is RSA, ECDSA or Ed25519. Any other
// PKCS #8 key type is rejected immediately, without trying the EC format.
func parsePrivateKey(der []byte) (crypto.PrivateKey, error) {
	rsaKey, err := x509.ParsePKCS1PrivateKey(der)
	if err == nil {
		return rsaKey, nil
	}

	pkcs8Key, err := x509.ParsePKCS8PrivateKey(der)
	if err == nil {
		switch pkcs8Key.(type) {
		case *rsa.PrivateKey, *ecdsa.PrivateKey, ed25519.PrivateKey:
			return pkcs8Key, nil
		}
		return nil, errors.New("tls: found unknown private key type in PKCS#8 wrapping")
	}

	ecKey, err := x509.ParseECPrivateKey(der)
	if err == nil {
		return ecKey, nil
	}
	return nil, errors.New("tls: failed to parse private key")
}
373
View as plain text