1
2
3
4
5
6 package cookiejar
7
8 import (
9 "cmp"
10 "errors"
11 "fmt"
12 "net"
13 "net/http"
14 "net/http/internal/ascii"
15 "net/netip"
16 "net/url"
17 "slices"
18 "strings"
19 "sync"
20 "time"
21 )
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
// PublicSuffixList provides the public suffix of a domain, for example
// "com" for "example.com" or "co.uk" for "www.example.co.uk".
//
// A Jar may call PublicSuffix from several goroutines at once, so
// implementations must be safe for concurrent use.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// The jar tolerates a misbehaving list: a result that is not a
	// dot-separated suffix of domain does not cause a panic.
	PublicSuffix(domain string) string

	// String returns a human-readable description of this list, such as
	// its source or version. It is informational; the jar's cookie logic
	// does not depend on it.
	String() string
}
50
51
52 type Options struct {
53
54
55
56
57
58
59 PublicSuffixList PublicSuffixList
60 }
61
62
63 type Jar struct {
64 psList PublicSuffixList
65
66
67 mu sync.Mutex
68
69
70
71 entries map[string]map[string]entry
72
73
74
75 nextSeqNum uint64
76 }
77
78
79
80 func New(o *Options) (*Jar, error) {
81 jar := &Jar{
82 entries: make(map[string]map[string]entry),
83 }
84 if o != nil {
85 jar.psList = o.PublicSuffixList
86 }
87 return jar, nil
88 }
89
90
91
92
93
// entry is the jar's internal record of one stored cookie.
type entry struct {
	// Name and Value are the cookie's name and value; Quoted records
	// whether the value was quoted and is passed back by Cookies.
	Name   string
	Value  string
	Quoted bool

	// Domain is the domain the cookie is stored for (the request host
	// itself for host-only cookies); Path is its path.
	Domain string
	Path   string

	// SameSite is the textual attribute ("SameSite", "SameSite=Strict"
	// or "SameSite=Lax") recorded by newEntry, or empty otherwise.
	SameSite string

	// Secure and HttpOnly mirror the cookie attributes of the same name.
	// Persistent is false for session cookies. HostOnly restricts the
	// cookie to requests whose host equals Domain exactly.
	Secure     bool
	HttpOnly   bool
	Persistent bool
	HostOnly   bool

	// Expires is when the cookie lapses (endOfTime for session cookies);
	// Creation and LastAccess are maintained by the jar.
	Expires    time.Time
	Creation   time.Time
	LastAccess time.Time

	// seqNum orders cookies deterministically when both Path and
	// Creation are equal, which keeps Cookies' output stable.
	seqNum uint64
}
114
115
116 func (e *entry) id() string {
117 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
118 }
119
120
121
122
123 func (e *entry) shouldSend(https bool, host, path string) bool {
124 return e.domainMatch(host) && e.pathMatch(path) && e.secureMatch(https)
125 }
126
127
128
129
130 func (e *entry) domainMatch(host string) bool {
131 if e.Domain == host {
132 return true
133 }
134 return !e.HostOnly && hasDotSuffix(host, e.Domain)
135 }
136
137
138 func (e *entry) pathMatch(requestPath string) bool {
139 if requestPath == e.Path {
140 return true
141 }
142 if strings.HasPrefix(requestPath, e.Path) {
143 if e.Path[len(e.Path)-1] == '/' {
144 return true
145 } else if requestPath[len(e.Path)] == '/' {
146 return true
147 }
148 }
149 return false
150 }
151
152
153
154
155 func (e *entry) secureMatch(https bool) bool {
156 if !e.Secure {
157
158 return true
159 }
160
161 if https {
162
163 return true
164 }
165
166 if isLocalhost(e.Domain) {
167 return true
168 }
169 ip, err := netip.ParseAddr(e.Domain)
170 if err == nil && ip.IsLoopback() {
171 return true
172 }
173 return false
174 }
175
176 func isLocalhost(host string) bool {
177 host = strings.TrimSuffix(host, ".")
178 if idx := strings.LastIndex(host, "."); idx >= 0 {
179 host = host[idx+1:]
180 }
181 return ascii.EqualFold(host, "localhost")
182 }
183
184
// hasDotSuffix reports whether s ends with "." followed by suffix, that
// is, whether s is a strict subdomain of suffix when both are domain
// names.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	if cut < 1 {
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
188
189
190
191
192 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
193 return j.cookies(u, time.Now())
194 }
195
196
197 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
198 if u.Scheme != "http" && u.Scheme != "https" {
199 return cookies
200 }
201 host, err := canonicalHost(u.Host)
202 if err != nil {
203 return cookies
204 }
205 key := jarKey(host, j.psList)
206
207 j.mu.Lock()
208 defer j.mu.Unlock()
209
210 submap := j.entries[key]
211 if submap == nil {
212 return cookies
213 }
214
215 https := u.Scheme == "https"
216 path := u.Path
217 if path == "" {
218 path = "/"
219 }
220
221 modified := false
222 var selected []entry
223 for id, e := range submap {
224 if e.Persistent && !e.Expires.After(now) {
225 delete(submap, id)
226 modified = true
227 continue
228 }
229 if !e.shouldSend(https, host, path) {
230 continue
231 }
232 e.LastAccess = now
233 submap[id] = e
234 selected = append(selected, e)
235 modified = true
236 }
237 if modified {
238 if len(submap) == 0 {
239 delete(j.entries, key)
240 } else {
241 j.entries[key] = submap
242 }
243 }
244
245
246
247 slices.SortFunc(selected, func(a, b entry) int {
248 if r := cmp.Compare(b.Path, a.Path); r != 0 {
249 return r
250 }
251 if r := a.Creation.Compare(b.Creation); r != 0 {
252 return r
253 }
254 return cmp.Compare(a.seqNum, b.seqNum)
255 })
256 for _, e := range selected {
257 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value, Quoted: e.Quoted})
258 }
259
260 return cookies
261 }
262
263
264
265
266 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
267 j.setCookies(u, cookies, time.Now())
268 }
269
270
271 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
272 if len(cookies) == 0 {
273 return
274 }
275 if u.Scheme != "http" && u.Scheme != "https" {
276 return
277 }
278 host, err := canonicalHost(u.Host)
279 if err != nil {
280 return
281 }
282 key := jarKey(host, j.psList)
283 defPath := defaultPath(u.Path)
284
285 j.mu.Lock()
286 defer j.mu.Unlock()
287
288 submap := j.entries[key]
289
290 modified := false
291 for _, cookie := range cookies {
292 e, remove, err := j.newEntry(cookie, now, defPath, host)
293 if err != nil {
294 continue
295 }
296 id := e.id()
297 if remove {
298 if submap != nil {
299 if _, ok := submap[id]; ok {
300 delete(submap, id)
301 modified = true
302 }
303 }
304 continue
305 }
306 if submap == nil {
307 submap = make(map[string]entry)
308 }
309
310 if old, ok := submap[id]; ok {
311 e.Creation = old.Creation
312 e.seqNum = old.seqNum
313 } else {
314 e.Creation = now
315 e.seqNum = j.nextSeqNum
316 j.nextSeqNum++
317 }
318 e.LastAccess = now
319 submap[id] = e
320 modified = true
321 }
322
323 if modified {
324 if len(submap) == 0 {
325 delete(j.entries, key)
326 } else {
327 j.entries[key] = submap
328 }
329 }
330 }
331
332
333
334 func canonicalHost(host string) (string, error) {
335 var err error
336 if hasPort(host) {
337 host, _, err = net.SplitHostPort(host)
338 if err != nil {
339 return "", err
340 }
341 }
342
343 host = strings.TrimSuffix(host, ".")
344 encoded, err := toASCII(host)
345 if err != nil {
346 return "", err
347 }
348
349 lower, _ := ascii.ToLower(encoded)
350 return lower, nil
351 }
352
353
354
// hasPort reports whether host carries a ":port" suffix. One colon always
// means host:port. Several colons mean an IPv6 literal, which has a port
// only in the bracketed "[addr]:port" form.
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	default:
		return host[0] == '[' && strings.Contains(host, "]:")
	}
}
365
366
367 func jarKey(host string, psl PublicSuffixList) string {
368 if isIP(host) {
369 return host
370 }
371
372 var i int
373 if psl == nil {
374 i = strings.LastIndex(host, ".")
375 if i <= 0 {
376 return host
377 }
378 } else {
379 suffix := psl.PublicSuffix(host)
380 if suffix == host {
381 return host
382 }
383 i = len(host) - len(suffix)
384 if i <= 0 || host[i-1] != '.' {
385
386
387 return host
388 }
389
390
391
392 }
393 prevDot := strings.LastIndex(host[:i-1], ".")
394 return host[prevDot+1:]
395 }
396
397
// isIP reports whether host should be treated as an IP address.
//
// Any host containing ':' or '%' qualifies without parsing. Such
// characters cannot appear in a hostname, so treating the value as a
// (probably IPv6, possibly zoned) address is the conservative choice. It
// also keeps e.g. "::1%.www.example.com" from being read as a subdomain
// of www.example.com.
func isIP(host string) bool {
	return strings.ContainsAny(host, ":%") || net.ParseIP(host) != nil
}
408
409
410
// defaultPath returns the default cookie path for a request path, as in
// RFC 6265 section 5.1.4. This is everything before the last '/'. The
// result is "/" when the path is empty, does not start with '/', or has
// no '/' after its first character.
func defaultPath(path string) string {
	slash := strings.LastIndex(path, "/")
	if slash <= 0 || path[0] != '/' {
		return "/"
	}
	return path[:slash]
}
422
423
424
425
426
427
428
429
430
431
432 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
433 e.Name = c.Name
434
435 if c.Path == "" || c.Path[0] != '/' {
436 e.Path = defPath
437 } else {
438 e.Path = c.Path
439 }
440
441 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
442 if err != nil {
443 return e, false, err
444 }
445
446
447 if c.MaxAge < 0 {
448 return e, true, nil
449 } else if c.MaxAge > 0 {
450 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
451 e.Persistent = true
452 } else {
453 if c.Expires.IsZero() {
454 e.Expires = endOfTime
455 e.Persistent = false
456 } else {
457 if !c.Expires.After(now) {
458 return e, true, nil
459 }
460 e.Expires = c.Expires
461 e.Persistent = true
462 }
463 }
464
465 e.Value = c.Value
466 e.Quoted = c.Quoted
467 e.Secure = c.Secure
468 e.HttpOnly = c.HttpOnly
469
470 switch c.SameSite {
471 case http.SameSiteDefaultMode:
472 e.SameSite = "SameSite"
473 case http.SameSiteStrictMode:
474 e.SameSite = "SameSite=Strict"
475 case http.SameSiteLaxMode:
476 e.SameSite = "SameSite=Lax"
477 }
478
479 return e, false, nil
480 }
481
// Package-level values shared by the cookie validation code.
var (
	// errIllegalDomain reports a Domain attribute that the request host
	// may not use: it does not domain-match the host, differs from an IP
	// host, or names a public suffix.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")

	// errMalformedDomain reports a syntactically unusable Domain
	// attribute.
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")

	// endOfTime is the Expires value newEntry assigns to session
	// (non-persistent) cookies: an instant far in the future.
	endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
)
491
492
493 func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
494 if domain == "" {
495
496
497 return host, true, nil
498 }
499
500 if isIP(host) {
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517 if host != domain {
518 return "", false, errIllegalDomain
519 }
520
521
522
523
524
525
526
527
528
529
530 return host, true, nil
531 }
532
533
534
535
536 domain = strings.TrimPrefix(domain, ".")
537
538 if len(domain) == 0 || domain[0] == '.' {
539
540
541 return "", false, errMalformedDomain
542 }
543
544 domain, isASCII := ascii.ToLower(domain)
545 if !isASCII {
546
547 return "", false, errMalformedDomain
548 }
549
550 if domain[len(domain)-1] == '.' {
551
552
553
554
555
556
557 return "", false, errMalformedDomain
558 }
559
560
561 if j.psList != nil {
562 if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
563 if host == domain {
564
565
566 return host, true, nil
567 }
568 return "", false, errIllegalDomain
569 }
570 }
571
572
573
574 if host != domain && !hasDotSuffix(host, domain) {
575 return "", false, errIllegalDomain
576 }
577
578 return domain, false, nil
579 }
580
View as plain text