diff --git a/upstream/resolver.go b/upstream/resolver.go index 451416072..2dbe8f887 100644 --- a/upstream/resolver.go +++ b/upstream/resolver.go @@ -6,6 +6,7 @@ import ( "math" "net/netip" "net/url" + "slices" "strings" "sync" "time" @@ -265,13 +266,13 @@ type CachingResolver struct { // resolver is the underlying resolver to use for lookups. resolver *UpstreamResolver - // mu protects cached and it's elements. + // mu protects cache and it's elements. mu *sync.RWMutex - // cached is the set of cached results sorted by [resolveResult.name]. + // cache is the set of resolved hostnames mapped to cached addresses. // // TODO(e.burkov): Use expiration cache. - cached map[string]*ipResult + cache map[string]*ipResult } // NewCachingResolver creates a new caching resolver that uses r for lookups. @@ -279,7 +280,7 @@ func NewCachingResolver(r *UpstreamResolver) (cr *CachingResolver) { return &CachingResolver{ resolver: r, mu: &sync.RWMutex{}, - cached: map[string]*ipResult{}, + cache: map[string]*ipResult{}, } } @@ -300,32 +301,38 @@ func (r *CachingResolver) LookupNetIP( addrs = r.findCached(host, now) if addrs != nil { - return addrs, nil + return slices.Clone(addrs), nil } - newRes, err := r.resolver.lookupNetIP(ctx, network, host) + res, err := r.resolver.lookupNetIP(ctx, network, host) if err != nil { return []netip.Addr{}, err } - r.mu.Lock() - defer r.mu.Unlock() + r.setCached(host, res) - r.cached[host] = newRes - - return newRes.addrs, nil + return slices.Clone(res.addrs), nil } // findCached returns the cached addresses for host if it's not expired yet, and -// the corresponding cached result, if any. +// the corresponding cached result, if any. It's safe for concurrent use. func (r *CachingResolver) findCached(host string, now time.Time) (addrs []netip.Addr) { r.mu.RLock() defer r.mu.RUnlock() - res, ok := r.cached[host] + res, ok := r.cache[host] if !ok || res.expire.Before(now) { return nil } return res.addrs } + +// setCached sets the result into the address cache for host. It's safe for +// concurrent use. +func (r *CachingResolver) setCached(host string, res *ipResult) { + r.mu.Lock() + defer r.mu.Unlock() + + r.cache[host] = res +}