1
2
3
4
5
6 package cookiejar
7
8 import (
9 "cmp"
10 "errors"
11 "fmt"
12 "net"
13 "net/http"
14 "net/http/internal/ascii"
15 "net/url"
16 "slices"
17 "strings"
18 "sync"
19 "time"
20 )
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
// PublicSuffixList provides the public suffix of a domain. For example:
//   - the public suffix of "example.com" is "com",
//   - the public suffix of "foo1.foo2.foo3.co.uk" is "co.uk", and
//   - the public suffix of "bar.pvt.k12.ma.us" is "pvt.k12.ma.us".
//
// Implementations must be safe for concurrent use by multiple goroutines:
// the jar calls PublicSuffix without holding its own lock.
//
// An implementation that always returns "" is accepted, but it disables the
// jar's public-suffix check. In effect this lets the server for foo.com set a
// cookie for bar.com's public suffix.
//
// A public suffix list implementation is available in the package
// golang.org/x/net/publicsuffix.
type PublicSuffixList interface {
	// PublicSuffix returns the public suffix of domain.
	//
	// NOTE(review): the contract does not say whether caller or callee
	// handles IP addresses, leading/trailing dots, case, or IDN/Punycode.
	PublicSuffix(domain string) string

	// String returns a description of the source of this public suffix
	// list, such as a version or timestamp.
	String() string
}
49
50
// Options are the options for creating a new Jar.
type Options struct {
	// PublicSuffixList is the public suffix list that determines whether
	// an HTTP server can set a cookie for a domain.
	//
	// A nil value is valid and may be useful for testing, but it is not
	// secure. With no list, the public-suffix check in domainAndType is
	// skipped, so the server for foo.co.uk can set a cookie for bar.co.uk.
	PublicSuffixList PublicSuffixList
}
60
61
// Jar is an in-memory cookie jar providing the Cookies and SetCookies
// methods of the net/http CookieJar interface.
type Jar struct {
	// psList is the optional public suffix list; nil means none.
	psList PublicSuffixList

	// mu guards the remaining fields.
	mu sync.Mutex

	// entries holds the stored cookies. It is keyed by jarKey(host) and
	// sub-keyed by entry.id(), i.e. "Domain;Path;Name".
	entries map[string]map[string]entry

	// nextSeqNum is the sequence number handed to the next cookie newly
	// created by SetCookies. It is used as the final ordering tie-breaker
	// in Cookies.
	nextSeqNum uint64
}
76
77
78
79 func New(o *Options) (*Jar, error) {
80 jar := &Jar{
81 entries: make(map[string]map[string]entry),
82 }
83 if o != nil {
84 jar.psList = o.PublicSuffixList
85 }
86 return jar, nil
87 }
88
89
90
91
92
// entry is the jar's internal representation of a single cookie.
//
// The exported field names follow the cookie attributes of RFC 6265.
type entry struct {
	Name       string
	Value      string
	Quoted     bool
	Domain     string
	Path       string
	SameSite   string
	Secure     bool
	HttpOnly   bool
	Persistent bool
	HostOnly   bool
	Expires    time.Time
	Creation   time.Time
	LastAccess time.Time

	// seqNum orders cookies deterministically in Cookies when both Path
	// and Creation compare equal.
	seqNum uint64
}

// id returns the sub-key of e within its jar bucket: "Domain;Path;Name".
func (e *entry) id() string {
	return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name)
}

// shouldSend reports whether e belongs in the Cookie header of a request to
// host/path. A Secure cookie is only sent when https is true.
func (e *entry) shouldSend(https bool, host, path string) bool {
	if !e.domainMatch(host) || !e.pathMatch(path) {
		return false
	}
	return https || !e.Secure
}

// domainMatch reports whether host is e.Domain itself or, for a non
// host-only cookie, a subdomain of it (RFC 6265 section 5.1.3).
func (e *entry) domainMatch(host string) bool {
	switch {
	case host == e.Domain:
		return true
	case e.HostOnly:
		return false
	default:
		return hasDotSuffix(host, e.Domain)
	}
}

// pathMatch implements RFC 6265 section 5.1.4 "path-match".
func (e *entry) pathMatch(requestPath string) bool {
	switch {
	case requestPath == e.Path:
		return true
	case !strings.HasPrefix(requestPath, e.Path):
		return false
	case e.Path[len(e.Path)-1] == '/':
		return true // "/any/" matches "/any/path".
	default:
		return requestPath[len(e.Path)] == '/' // "/any" matches "/any/path".
	}
}

// hasDotSuffix reports whether s ends with "." followed by suffix.
func hasDotSuffix(s, suffix string) bool {
	cut := len(s) - len(suffix)
	return cut > 0 && s[cut-1] == '.' && s[cut:] == suffix
}
155
156
157
158
159 func (j *Jar) Cookies(u *url.URL) (cookies []*http.Cookie) {
160 return j.cookies(u, time.Now())
161 }
162
163
164 func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) {
165 if u.Scheme != "http" && u.Scheme != "https" {
166 return cookies
167 }
168 host, err := canonicalHost(u.Host)
169 if err != nil {
170 return cookies
171 }
172 key := jarKey(host, j.psList)
173
174 j.mu.Lock()
175 defer j.mu.Unlock()
176
177 submap := j.entries[key]
178 if submap == nil {
179 return cookies
180 }
181
182 https := u.Scheme == "https"
183 path := u.Path
184 if path == "" {
185 path = "/"
186 }
187
188 modified := false
189 var selected []entry
190 for id, e := range submap {
191 if e.Persistent && !e.Expires.After(now) {
192 delete(submap, id)
193 modified = true
194 continue
195 }
196 if !e.shouldSend(https, host, path) {
197 continue
198 }
199 e.LastAccess = now
200 submap[id] = e
201 selected = append(selected, e)
202 modified = true
203 }
204 if modified {
205 if len(submap) == 0 {
206 delete(j.entries, key)
207 } else {
208 j.entries[key] = submap
209 }
210 }
211
212
213
214 slices.SortFunc(selected, func(a, b entry) int {
215 if r := cmp.Compare(b.Path, a.Path); r != 0 {
216 return r
217 }
218 if r := a.Creation.Compare(b.Creation); r != 0 {
219 return r
220 }
221 return cmp.Compare(a.seqNum, b.seqNum)
222 })
223 for _, e := range selected {
224 cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value, Quoted: e.Quoted})
225 }
226
227 return cookies
228 }
229
230
231
232
233 func (j *Jar) SetCookies(u *url.URL, cookies []*http.Cookie) {
234 j.setCookies(u, cookies, time.Now())
235 }
236
237
238 func (j *Jar) setCookies(u *url.URL, cookies []*http.Cookie, now time.Time) {
239 if len(cookies) == 0 {
240 return
241 }
242 if u.Scheme != "http" && u.Scheme != "https" {
243 return
244 }
245 host, err := canonicalHost(u.Host)
246 if err != nil {
247 return
248 }
249 key := jarKey(host, j.psList)
250 defPath := defaultPath(u.Path)
251
252 j.mu.Lock()
253 defer j.mu.Unlock()
254
255 submap := j.entries[key]
256
257 modified := false
258 for _, cookie := range cookies {
259 e, remove, err := j.newEntry(cookie, now, defPath, host)
260 if err != nil {
261 continue
262 }
263 id := e.id()
264 if remove {
265 if submap != nil {
266 if _, ok := submap[id]; ok {
267 delete(submap, id)
268 modified = true
269 }
270 }
271 continue
272 }
273 if submap == nil {
274 submap = make(map[string]entry)
275 }
276
277 if old, ok := submap[id]; ok {
278 e.Creation = old.Creation
279 e.seqNum = old.seqNum
280 } else {
281 e.Creation = now
282 e.seqNum = j.nextSeqNum
283 j.nextSeqNum++
284 }
285 e.LastAccess = now
286 submap[id] = e
287 modified = true
288 }
289
290 if modified {
291 if len(submap) == 0 {
292 delete(j.entries, key)
293 } else {
294 j.entries[key] = submap
295 }
296 }
297 }
298
299
300
301 func canonicalHost(host string) (string, error) {
302 var err error
303 if hasPort(host) {
304 host, _, err = net.SplitHostPort(host)
305 if err != nil {
306 return "", err
307 }
308 }
309
310 host = strings.TrimSuffix(host, ".")
311 encoded, err := toASCII(host)
312 if err != nil {
313 return "", err
314 }
315
316 lower, _ := ascii.ToLower(encoded)
317 return lower, nil
318 }
319
320
321
// hasPort reports whether host carries a ":port" suffix.
//
// A single colon always separates a port. With several colons, only a
// bracketed IPv6 literal followed by "]:" has a port; a bare IPv6 address
// does not.
func hasPort(host string) bool {
	switch strings.Count(host, ":") {
	case 0:
		return false
	case 1:
		return true
	default:
		return host[0] == '[' && strings.Contains(host, "]:")
	}
}
332
333
334 func jarKey(host string, psl PublicSuffixList) string {
335 if isIP(host) {
336 return host
337 }
338
339 var i int
340 if psl == nil {
341 i = strings.LastIndex(host, ".")
342 if i <= 0 {
343 return host
344 }
345 } else {
346 suffix := psl.PublicSuffix(host)
347 if suffix == host {
348 return host
349 }
350 i = len(host) - len(suffix)
351 if i <= 0 || host[i-1] != '.' {
352
353
354 return host
355 }
356
357
358
359 }
360 prevDot := strings.LastIndex(host[:i-1], ".")
361 return host[prevDot+1:]
362 }
363
364
// isIP reports whether host should be treated as an IP address.
//
// Any host containing ':' or '%' counts as an IP, since neither character can
// occur in a valid hostname. Treating such input as a (probable IPv6) address
// is the conservative choice: it avoids reading "::1%.www.example.com" as a
// subdomain of www.example.com.
func isIP(host string) bool {
	return strings.ContainsAny(host, ":%") || net.ParseIP(host) != nil
}
375
376
377
// defaultPath computes the RFC 6265 section 5.1.4 "default-path" of a
// request path. The result is everything before the last '/', or "/" when
// the path is empty, relative, or has only a leading slash.
func defaultPath(path string) string {
	if !strings.HasPrefix(path, "/") {
		return "/" // Empty or malformed path.
	}
	if i := strings.LastIndex(path, "/"); i > 0 {
		return path[:i] // "/abc/xyz" or "/abc/xyz/".
	}
	return "/" // "/abc".
}
389
390
391
392
393
394
395
396
397
398
399 func (j *Jar) newEntry(c *http.Cookie, now time.Time, defPath, host string) (e entry, remove bool, err error) {
400 e.Name = c.Name
401
402 if c.Path == "" || c.Path[0] != '/' {
403 e.Path = defPath
404 } else {
405 e.Path = c.Path
406 }
407
408 e.Domain, e.HostOnly, err = j.domainAndType(host, c.Domain)
409 if err != nil {
410 return e, false, err
411 }
412
413
414 if c.MaxAge < 0 {
415 return e, true, nil
416 } else if c.MaxAge > 0 {
417 e.Expires = now.Add(time.Duration(c.MaxAge) * time.Second)
418 e.Persistent = true
419 } else {
420 if c.Expires.IsZero() {
421 e.Expires = endOfTime
422 e.Persistent = false
423 } else {
424 if !c.Expires.After(now) {
425 return e, true, nil
426 }
427 e.Expires = c.Expires
428 e.Persistent = true
429 }
430 }
431
432 e.Value = c.Value
433 e.Quoted = c.Quoted
434 e.Secure = c.Secure
435 e.HttpOnly = c.HttpOnly
436
437 switch c.SameSite {
438 case http.SameSiteDefaultMode:
439 e.SameSite = "SameSite"
440 case http.SameSiteStrictMode:
441 e.SameSite = "SameSite=Strict"
442 case http.SameSiteLaxMode:
443 e.SameSite = "SameSite=Lax"
444 }
445
446 return e, false, nil
447 }
448
// Errors returned by domainAndType when a cookie's Domain attribute is
// rejected.
var (
	// errIllegalDomain: the Domain attribute is well-formed but not
	// permitted for the request host.
	errIllegalDomain = errors.New("cookiejar: illegal cookie domain attribute")
	// errMalformedDomain: the Domain attribute itself is unusable, e.g.
	// empty after dot stripping, non-ASCII, or with a trailing dot.
	errMalformedDomain = errors.New("cookiejar: malformed cookie domain attribute")
)
453
454
455
456
// endOfTime is a far-future instant used as an Expires value meaning
// "does not expire". newEntry assigns it to session (non-persistent) cookies.
var endOfTime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC)
458
459
// domainAndType determines a cookie's stored domain and whether it is
// host-only.
//
// host is the canonical host of the request URL. domain is the cookie's
// Domain attribute, possibly empty. It returns the domain to store, the
// host-only flag, and either errMalformedDomain or errIllegalDomain if the
// attribute must be rejected. The order of the checks below matters.
func (j *Jar) domainAndType(host, domain string) (string, bool, error) {
	if domain == "" {
		// No Domain attribute: a host-only cookie for host.
		return host, true, nil
	}

	if isIP(host) {
		// For an IP-address host, the only acceptable Domain attribute is the
		// address itself, compared verbatim. No leading-dot stripping is done
		// on this path, so "Domain=.127.0.0.1" is rejected.
		if host != domain {
			return "", false, errIllegalDomain
		}

		// An IP address has no subdomains. Such a cookie is therefore stored
		// as host-only rather than as a domain cookie.
		return host, true, nil
	}

	// From here on, a valid cookie is a domain cookie, with the single
	// public-suffix exception below. Strip one optional leading dot
	// (RFC 6265 section 5.2.3).
	if domain[0] == '.' {
		domain = domain[1:]
	}

	if len(domain) == 0 || domain[0] == '.' {
		// "Domain=." or "Domain=..some.thing": both are illegal.
		return "", false, errMalformedDomain
	}

	domain, isASCII := ascii.ToLower(domain)
	if !isASCII {
		// Non-ASCII domains (e.g. an un-encoded IDN) are rejected.
		return "", false, errMalformedDomain
	}

	if domain[len(domain)-1] == '.' {
		// A trailing dot, e.g. "Domain=www.example.com.", is rejected.
		// Browsers tolerate this in different ways, but the domain-matching
		// rules of RFC 6265 (sections 5.1.2 and 5.1.3) leave no room for it.
		return "", false, errMalformedDomain
	}

	// RFC 6265 section 5.3 step 5: refuse a Domain that is not strictly
	// below its own public suffix, i.e. that is itself a public suffix.
	if j.psList != nil {
		if ps := j.psList.PublicSuffix(domain); ps != "" && !hasDotSuffix(domain, ps) {
			if host == domain {
				// The one case where a cookie with a Domain attribute
				// is host-only: the host itself is the public suffix.
				return host, true, nil
			}
			return "", false, errIllegalDomain
		}
	}

	// domain must domain-match host: www.mycompany.com cannot set cookies
	// for ourcompetitors.com.
	if host != domain && !hasDotSuffix(host, domain) {
		return "", false, errIllegalDomain
	}

	return domain, false, nil
}
549
View as plain text