1
2
3
4
5
6
7
8 package cookiejar
9
10 import (
11 "cmp"
12 "errors"
13 "fmt"
14 "net"
15 "net/http"
16 "net/http/internal/ascii"
17 "net/netip"
18 "net/url"
19 "slices"
20 "strings"
21 "sync"
22 "time"
23 )
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// PublicSuffixList reports the public suffix of a domain name.
//
// A Jar uses it in two places: jarKey groups cookies under the label directly
// below the public suffix, and domainAndType refuses to scope a cookie to a
// domain that is itself a public suffix.
//
// NOTE(review): a Jar calls PublicSuffix from whichever goroutine invokes its
// methods, so implementations presumably need to be safe for concurrent
// use — confirm.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// Within this file it is called with a canonicalized host name or with a
	// lower-cased Domain attribute that has no leading or trailing dot. The
	// callers tolerate a result that is not a dot-separated suffix of the
	// argument.
	PublicSuffix(domain string) string

	// String returns a textual description of the list. Nothing in this file
	// calls it.
	String() string
}
52
53
54 type Options struct {
55
56
57
58
59
60
61 PublicSuffixList PublicSuffixList
62 }
63
64
65 type Jar struct {
66 psList PublicSuffixList
67
68
69 mu sync.Mutex
70
71
72
73 entries map[string]map[string]entry
74
75
76
77 nextSeqNum uint64
78 }
79
80
81
82 func New(o *Options) (*Jar, error) {
83 jar := &Jar{
84 entries: make(map[string]map[string]entry),
85 }
86 if o != nil {
87 jar.psList = o.PublicSuffixList
88 }
89 return jar, nil
90 }
91
92
93
94
95
// entry is the jar's internal record for one stored cookie. It is filled in
// by newEntry and setCookies and stored by value in Jar.entries.
type entry struct {
	// Name is the cookie name.
	Name string

	// Value is the cookie value.
	Value string

	// Quoted is copied from http.Cookie.Quoted and handed back unchanged when
	// the cookie is returned.
	Quoted bool

	// Domain is the request host for host-only cookies, otherwise the
	// lower-cased Domain attribute without its leading dot.
	Domain string

	// Path is the cookie's Path attribute when it starts with '/', otherwise
	// the default path derived from the request URL.
	Path string

	// SameSite is "", "SameSite", "SameSite=Strict" or "SameSite=Lax".
	SameSite string

	// Secure restricts when the cookie is sent; see secureMatch.
	Secure bool

	// HttpOnly is copied from the cookie; nothing in this file consults it.
	HttpOnly bool

	// Persistent is true when the cookie carried a positive Max-Age or an
	// Expires time. Only persistent entries are expired by cookies.
	Persistent bool

	// HostOnly limits domainMatch to hosts equal to Domain.
	HostOnly bool

	// Expires is the expiry time; non-persistent entries use endOfTime.
	Expires time.Time

	// Creation is when the entry was first stored. Replacing an entry with
	// the same id keeps the old value.
	Creation time.Time

	// LastAccess is updated whenever the entry is stored or returned.
	LastAccess time.Time

	// seqNum is assigned from Jar.nextSeqNum when the entry is first stored.
	// It breaks ties when entries with equal Path and Creation are sorted.
	seqNum uint64
}
116
117
118 func (e *entry) id() string {
119 return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
120 }
121
122
123
124
125 func (e *entry) shouldSend(https bool, host, path string) bool {
126 return e.domainMatch(host) && e.pathMatch(path) && e.secureMatch(https)
127 }
128
129
130
131
132 func (e *entry) domainMatch(host string) bool {
133 if e.Domain == host {
134 return true
135 }
136 return !e.HostOnly && hasDotSuffix(host, e.Domain)
137 }
138
139
140 func (e *entry) pathMatch(requestPath string) bool {
141 if requestPath == e.Path {
142 return true
143 }
144 if strings.HasPrefix(requestPath, e.Path) {
145 if e.Path[len(e.Path)-1] == '/' {
146 return true
147 } else if requestPath[len(e.Path)] == '/' {
148 return true
149 }
150 }
151 return false
152 }
153
154
155
156
157 func (e *entry) secureMatch(https bool) bool {
158 if !e.Secure {
159
160 return true
161 }
162
163 if https {
164
165 return true
166 }
167
168 if isLocalhost(e.Domain) {
169 return true
170 }
171 ip, err := netip.ParseAddr(e.Domain)
172 if err == nil && ip.IsLoopback() {
173 return true
174 }
175 return false
176 }
177
178 func isLocalhost(host string) bool {
179 host = strings.TrimSuffix(host, ".")
180 if idx := strings.LastIndex(host, "."); idx >= 0 {
181 host = host[idx+1:]
182 }
183 return ascii.EqualFold(host, "localhost")
184 }
185
186
// hasDotSuffix reports whether s ends in "." + suffix.
//
// s must be strictly longer than suffix, so a string never has itself as a
// dot suffix.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	if cut <= 0 {
		return false
	}
	return s[cut-1] == '.' && s[cut:] == suffix
}
190
191
192
193
194 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
195 return j.cookies(u, time.Now())
196 }
197
198
199 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
200 if u.Scheme != "http" && u.Scheme != "https" {
201 return cookies
202 }
203 host, err := canonicalHost(u.Host)
204 if err != nil {
205 return cookies
206 }
207 key := jarKey(host, j.psList)
208
209 j.mu.Lock()
210 defer j.mu.Unlock()
211
212 submap := j.entries[key]
213 if submap == nil {
214 return cookies
215 }
216
217 https := u.Scheme == "https"
218 path := u.Path
219 if path == "" {
220 path = "/"
221 }
222
223 modified := false
224 var selected []entry
225 for id, e := range submap {
226 if e.Persistent && !e.Expires.After(now) {
227 delete(submap, id)
228 modified = true
229 continue
230 }
231 if !e.shouldSend(https, host, path) {
232 continue
233 }
234 e.LastAccess = now
235 submap[id] = e
236 selected = append(selected, e)
237 modified = true
238 }
239 if modified {
240 if len(submap) == 0 {
241 delete(j.entries, key)
242 } else {
243 j.entries[key] = submap
244 }
245 }
246
247
248
249 slices.SortFunc(selected, func(a, b entry) int {
250 if r := cmp.Compare(b.Path, a.Path); r != 0 {
251 return r
252 }
253 if r := a.Creation.Compare(b.Creation); r != 0 {
254 return r
255 }
256 return cmp.Compare(a.seqNum, b.seqNum)
257 })
258 for _, e := range selected {
259 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value, Quoted: e.Quoted})
260 }
261
262 return cookies
263 }
264
265
266
267
268 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
269 j.setCookies(u, cookies, time.Now())
270 }
271
272
273 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
274 if len(cookies) == 0 {
275 return
276 }
277 if u.Scheme != "http" && u.Scheme != "https" {
278 return
279 }
280 host, err := canonicalHost(u.Host)
281 if err != nil {
282 return
283 }
284 key := jarKey(host, j.psList)
285 defPath := defaultPath(u.Path)
286
287 j.mu.Lock()
288 defer j.mu.Unlock()
289
290 submap := j.entries[key]
291
292 modified := false
293 for _, cookie := range cookies {
294 e, remove, err := j.newEntry(cookie, now, defPath, host)
295 if err != nil {
296 continue
297 }
298 id := e.id()
299 if remove {
300 if submap != nil {
301 if _, ok := submap[id]; ok {
302 delete(submap, id)
303 modified = true
304 }
305 }
306 continue
307 }
308 if submap == nil {
309 submap = make(map[string]entry)
310 }
311
312 if old, ok := submap[id]; ok {
313 e.Creation = old.Creation
314 e.seqNum = old.seqNum
315 } else {
316 e.Creation = now
317 e.seqNum = j.nextSeqNum
318 j.nextSeqNum++
319 }
320 e.LastAccess = now
321 submap[id] = e
322 modified = true
323 }
324
325 if modified {
326 if len(submap) == 0 {
327 delete(j.entries, key)
328 } else {
329 j.entries[key] = submap
330 }
331 }
332 }
333
334
335
336 func canonicalHost(host string) (string, error) {
337 var err error
338 if hasPort(host) {
339 host, _, err = net.SplitHostPort(host)
340 if err != nil {
341 return "", err
342 }
343 }
344
345 host = strings.TrimSuffix(host, ".")
346 encoded, err := toASCII(host)
347 if err != nil {
348 return "", err
349 }
350
351 lower, _ := ascii.ToLower(encoded)
352 return lower, nil
353 }
354
355
356
// hasPort reports whether host looks like "host:port".
//
// No colon means no port and exactly one colon means a port. With several
// colons the host is taken to be an IPv6 literal, which carries a port only
// in the bracketed form "[addr]:port".
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	default:
		return host[0] == '[' && strings.Contains(host, "]:")
	}
}
367
368
369 func jarKey(host string, psl PublicSuffixList) string {
370 if isIP(host) {
371 return host
372 }
373
374 var i int
375 if psl == nil {
376 i = strings.LastIndex(host, ".")
377 if i <= 0 {
378 return host
379 }
380 } else {
381 suffix := psl.PublicSuffix(host)
382 if suffix == host {
383 return host
384 }
385 i = len(host) - len(suffix)
386 if i <= 0 || host[i-1] != '.' {
387
388
389 return host
390 }
391
392
393
394 }
395 prevDot := strings.LastIndex(host[:i-1], ".")
396 return host[prevDot+1:]
397 }
398
399
// isIP reports whether host should be treated as an IP address.
//
// Any host containing ':' or '%' is accepted without further validation;
// those bytes are taken as the sign of an IPv6 literal or zone. Everything
// else must parse with net.ParseIP.
func isIP(host string) bool {
	return strings.ContainsAny(host, ":%") || net.ParseIP(host) != nil
}
410
411
412
// defaultPath returns the directory part of a request path, used as the
// cookie path when a cookie supplies none of its own.
//
// The result is "/" when path is empty, does not start with '/', or contains
// no further '/'. Otherwise it is path up to, but excluding, its last '/'.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/"
	}
	if last := strings.LastIndex(path, "/"); last > 0 {
		return path[:last]
	}
	return "/"
}
424
425
426
427
428
429
430
431
432
433
434 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
435 e.Name = c.Name
436
437 if c.Path == "" || c.Path[0] != '/' {
438 e.Path = defPath
439 } else {
440 e.Path = c.Path
441 }
442
443 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
444 if err != nil {
445 return e, false, err
446 }
447
448
449 if c.MaxAge < 0 {
450 return e, true, nil
451 } else if c.MaxAge > 0 {
452 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
453 e.Persistent = true
454 } else {
455 if c.Expires.IsZero() {
456 e.Expires = endOfTime
457 e.Persistent = false
458 } else {
459 if !c.Expires.After(now) {
460 return e, true, nil
461 }
462 e.Expires = c.Expires
463 e.Persistent = true
464 }
465 }
466
467 e.Value = c.Value
468 e.Quoted = c.Quoted
469 e.Secure = c.Secure
470 e.HttpOnly = c.HttpOnly
471
472 switch c.SameSite {
473 case http.SameSiteDefaultMode:
474 e.SameSite = "SameSite"
475 case http.SameSiteStrictMode:
476 e.SameSite = "SameSite=Strict"
477 case http.SameSiteLaxMode:
478 e.SameSite = "SameSite=Lax"
479 }
480
481 return e, false, nil
482 }
483
// Errors returned by domainAndType. setCookies discards any cookie whose
// conversion fails, so they never reach the jar's callers.
var (
	// errIllegalDomain: the Domain attribute is well formed, but the request
	// host is not allowed to set a cookie for it.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")

	// errMalformedDomain: the Domain attribute is syntactically unacceptable
	// (empty after the leading dot, doubled leading dot, trailing dot, or
	// non-ASCII).
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
)
488
489
490
491
// endOfTime is the expiry stored for session cookies: the last second of the
// year 9999 in UTC. newEntry uses it when a cookie has neither Max-Age nor
// Expires.
var endOfTime = time.Date(9999, time.December, 31, 23, 59, 59, 0, time.UTC)
493
494
495 func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
496 if domain == "" {
497
498
499 return host, true, nil
500 }
501
502 if isIP(host) {
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519 if host != domain {
520 return "", false, errIllegalDomain
521 }
522
523
524
525
526
527
528
529
530
531
532 return host, true, nil
533 }
534
535
536
537
538 domain = strings.TrimPrefix(domain, ".")
539
540 if len(domain) == 0 || domain[0] == '.' {
541
542
543 return "", false, errMalformedDomain
544 }
545
546 domain, isASCII := ascii.ToLower(domain)
547 if !isASCII {
548
549 return "", false, errMalformedDomain
550 }
551
552 if domain[len(domain)-1] == '.' {
553
554
555
556
557
558
559 return "", false, errMalformedDomain
560 }
561
562
563 if j.psList != nil {
564 if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
565 if host == domain {
566
567
568 return host, true, nil
569 }
570 return "", false, errIllegalDomain
571 }
572 }
573
574
575
576 if host != domain && !hasDotSuffix(host, domain) {
577 return "", false, errIllegalDomain
578 }
579
580 return domain, false, nil
581 }
582
View as plain text