Skip to content

Commit

Permalink
proxy: query statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jan 16, 2025
1 parent 429c98c commit 4f686ad
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 19 deletions.
25 changes: 16 additions & 9 deletions proxy/dnscontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"net"
"net/http"
"net/netip"
"time"

"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/ameshkov/dnscrypt/v2"
Expand Down Expand Up @@ -47,17 +46,19 @@ type DNSContext struct {
// servers if it's not nil.
CustomUpstreamConfig *CustomUpstreamConfig

// queryStatistics contains the DNS query statistics for both the upstream
// and fallback DNS servers.
queryStatistics *QueryStatistics

// Req is the request message.
Req *dns.Msg

// Res is the response message.
Res *dns.Msg

// Proto is the DNS protocol of the query.
Proto Proto

// CachedUpstreamAddr is the address of the upstream which the answer was
// cached with. It's empty for responses resolved by the upstream server.
CachedUpstreamAddr string

// RequestedPrivateRDNS is the subnet extracted from the ARPA domain of
// request's question if it's a PTR, SOA, or NS query for a private IP
// address. It can be a single-address subnet as well as a zero-length one.
Expand All @@ -69,10 +70,6 @@ type DNSContext struct {
// Addr is the address of the client.
Addr netip.AddrPort

// QueryDuration is the duration of a successful query to an upstream
// server or, if the upstream server is unavailable, to a fallback server.
QueryDuration time.Duration

// DoQVersion is the DoQ protocol version. It can (and should) be read from
// ALPN, but in the current version we also use the way DNS messages are
// encoded as a signal.
Expand Down Expand Up @@ -115,6 +112,16 @@ func (p *Proxy) newDNSContext(proto Proto, req *dns.Msg, addr netip.AddrPort) (d
}
}

// QueryStatistics returns the DNS query statistics for both the upstream and
// fallback DNS servers.
func (dctx *DNSContext) QueryStatistics() (s *QueryStatistics) {
if dctx == nil {
return nil
}

return dctx.queryStatistics
}

// calcFlagsAndSize lazily calculates some values required for Resolve method.
func (dctx *DNSContext) calcFlagsAndSize() {
if dctx.udpSize != 0 || dctx.Req == nil {
Expand Down
20 changes: 11 additions & 9 deletions proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -554,42 +554,44 @@ func (p *Proxy) replyFromUpstream(d *DNSContext) (ok bool, err error) {
p.recDetector.add(d.Req)
}

start := time.Now()
src := "upstream"
wrapped := upstreamsWithStats(upstreams, false)

// Perform the DNS request.
resp, u, err := p.exchangeUpstreams(req, upstreams)
if dns64Ups := p.performDNS64(req, resp, upstreams); dns64Ups != nil {
resp, u, err := p.exchangeUpstreams(req, wrapped)
if dns64Ups := p.performDNS64(req, resp, wrapped); dns64Ups != nil {
u = dns64Ups
} else if p.isBogusNXDomain(resp) {
p.logger.Debug("response contains bogus-nxdomain ip")
resp = p.messages.NewMsgNXDOMAIN(req)
}

var wrappedFallbacks []upstream.Upstream
if err != nil && !isPrivate && p.Fallbacks != nil {
p.logger.Debug("using fallback", slogutil.KeyError, err)

// Reset the timer.
start = time.Now()
src = "fallback"

// upstreams mustn't appear empty since they have been validated when
// creating proxy.
upstreams = p.Fallbacks.getUpstreamsForDomain(req.Question[0].Name)

resp, u, err = upstream.ExchangeParallel(upstreams, req)
wrappedFallbacks = upstreamsWithStats(upstreams, true)
resp, u, err = upstream.ExchangeParallel(wrappedFallbacks, req)
}

if err != nil {
p.logger.Debug("resolving err", "src", src, slogutil.KeyError, err)
}

if resp != nil {
d.QueryDuration = time.Since(start)
p.logger.Debug("resolved", "src", src, "rtt", d.QueryDuration)
p.logger.Debug("resolved", "src", src)
}

p.handleExchangeResult(d, req, resp, u)
unwrapped, stats := collectQueryStats(p.UpstreamMode, u, wrapped, wrappedFallbacks)
d.queryStatistics = stats

p.handleExchangeResult(d, req, resp, unwrapped)

return resp != nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion proxy/proxycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func (p *Proxy) replyFromCache(d *DNSContext) (hit bool) {
}

d.Res = ci.m
d.CachedUpstreamAddr = ci.u
d.queryStatistics = cachedQueryStatistics(ci.u)

p.logger.Debug(
"replying from cache",
Expand Down
219 changes: 219 additions & 0 deletions proxy/stats.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
package proxy

import (
"fmt"
"sync"
"time"

"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/miekg/dns"
)

// upstreamWithStats is a wrapper around the [upstream.Upstream] interface that
// gathers statistics.
type upstreamWithStats struct {
// upstream is the upstream DNS resolver.
upstream upstream.Upstream

// mu protects err and queryDuration.
mu *sync.Mutex

// err is the DNS lookup error, if any.
err error

// queryDuration is the duration of the successful DNS lookup.
queryDuration time.Duration

// isFallback indicates whether the upstream is a fallback upstream.
isFallback bool
}

// newUpstreamWithStats returns a new initialized *upstreamWithStats.
func newUpstreamWithStats(
upstream upstream.Upstream,
isFallback bool,
) (u *upstreamWithStats) {
return &upstreamWithStats{
upstream: upstream,
mu: &sync.Mutex{},
isFallback: isFallback,
}
}

// stats returns the stored statistics.
func (u *upstreamWithStats) stats() (dur time.Duration, err error) {
u.mu.Lock()
defer u.mu.Unlock()

return u.queryDuration, u.err
}

// type check
var _ upstream.Upstream = (*upstreamWithStats)(nil)

// Exchange implements the [upstream.Upstream] for *upstreamWithStats.
func (u *upstreamWithStats) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
start := time.Now()
resp, err = u.upstream.Exchange(req)

u.mu.Lock()
defer u.mu.Unlock()

u.err = err
u.queryDuration = time.Since(start)

return resp, err
}

// Address implements the [upstream.Upstream] for *upstreamWithStats.
func (u *upstreamWithStats) Address() (addr string) {
return u.upstream.Address()
}

// Close implements the [upstream.Upstream] for *upstreamWithStats.
func (u *upstreamWithStats) Close() (err error) {
return u.upstream.Close()
}

// upstreamsWithStats takes a list of upstreams, wraps each upstream with
// [upstreamWithStats] to gather statistics, and returns the wrapped upstreams.
func upstreamsWithStats(
upstreams []upstream.Upstream,
isFallback bool,
) (wrapped []upstream.Upstream) {
wrapped = make([]upstream.Upstream, 0, len(upstreams))
for _, u := range upstreams {
w := newUpstreamWithStats(u, isFallback)
wrapped = append(wrapped, w)
}

return wrapped
}

// QueryStatistics contains the DNS query statistics for both the upstream and
// fallback DNS servers.
type QueryStatistics struct {
main []*UpstreamStatistics
fallback []*UpstreamStatistics
}

// cachedQueryStatistics returns the DNS query statistics for cached queries.
func cachedQueryStatistics(addr string) (s *QueryStatistics) {
return &QueryStatistics{
main: []*UpstreamStatistics{{
Address: addr,
IsCached: true,
}},
}
}

// Main returns the DNS query statistics for the upstream DNS servers.
func (s *QueryStatistics) Main() (us []*UpstreamStatistics) {
return s.main
}

// Fallback returns the DNS query statistics for the fallback DNS servers.
func (s *QueryStatistics) Fallback() (us []*UpstreamStatistics) {
return s.fallback
}

// collectQueryStats gathers the statistics from the wrapped upstreams,
// considering the upstream mode. resolver is an upstream DNS resolver that
// successfully resolved the request, and it will be unwrapped. Provided
// upstreams must be of type *upstreamWithStats.
//
// If the DNS query was not resolved (i.e., if resolver is nil) or upstream mode
// is [UpstreamModeFastestAddr], the function returns the gathered statistics
// for both the upstream and fallback DNS servers. Otherwise, it returns the
// query statistics specifically for resolver.
func collectQueryStats(
mode UpstreamMode,
resolver upstream.Upstream,
upstreams []upstream.Upstream,
fallbacks []upstream.Upstream,
) (unwrapped upstream.Upstream, stats *QueryStatistics) {
var wrapped *upstreamWithStats
if resolver != nil {
var ok bool
wrapped, ok = resolver.(*upstreamWithStats)
if !ok {
// Should never happen.
err := fmt.Errorf("unexpected type %T", resolver)
panic(err)
}

unwrapped = wrapped.upstream
}

if wrapped == nil || mode == UpstreamModeFastestAddr {
return unwrapped, &QueryStatistics{
main: collectUpstreamStats(upstreams),
fallback: collectUpstreamStats(fallbacks),
}
}

return unwrapped, collectResolverQueryStats(wrapped)
}

// collectResolverQueryStats gathers the statistics from an upstream DNS
// resolver that successfully resolved the request. resolver must be not nil.
func collectResolverQueryStats(resolver *upstreamWithStats) (stats *QueryStatistics) {
dur, err := resolver.stats()
s := &UpstreamStatistics{
Address: resolver.upstream.Address(),
Error: err,
QueryDuration: dur,
}

if resolver.isFallback {
return &QueryStatistics{
fallback: []*UpstreamStatistics{s},
}
}

return &QueryStatistics{
main: []*UpstreamStatistics{s},
}
}

// UpstreamStatistics contains the DNS query statistics.
type UpstreamStatistics struct {
// Error is the DNS lookup error, if any.
Error error

// Address is the address of the upstream DNS resolver.
//
// TODO(s.chzhen): Use [upstream.Upstream] when [cacheItem] starts to
// contain one.
Address string

// QueryDuration is the duration of the successful DNS lookup.
QueryDuration time.Duration

// IsCached indicates whether the response was served from a cache.
IsCached bool
}

// collectUpstreamStats gathers the upstream statistics from the list of wrapped
// upstreams. upstreams must be of type *upstreamWithStats.
func collectUpstreamStats(upstreams []upstream.Upstream) (stats []*UpstreamStatistics) {
stats = make([]*UpstreamStatistics, 0, len(upstreams))

for _, u := range upstreams {
w, ok := u.(*upstreamWithStats)
if !ok {
// Should never happen.
err := fmt.Errorf("unexpected type %T", u)
panic(err)
}

dur, err := w.stats()
stats = append(stats, &UpstreamStatistics{
Error: err,
Address: w.Address(),
QueryDuration: dur,
})
}

return stats
}

0 comments on commit 4f686ad

Please sign in to comment.