Skip to content

Commit

Permalink
all: add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
schzhn committed Jan 16, 2025
1 parent 4f686ad commit 966cabf
Show file tree
Hide file tree
Showing 5 changed files with 314 additions and 22 deletions.
4 changes: 0 additions & 4 deletions proxy/dnscontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ func (p *Proxy) newDNSContext(proto Proto, req *dns.Msg, addr netip.AddrPort) (d
// QueryStatistics returns the DNS query statistics for both the upstream and
// fallback DNS servers.
func (dctx *DNSContext) QueryStatistics() (s *QueryStatistics) {
if dctx == nil {
return nil
}

return dctx.queryStatistics
}

Expand Down
4 changes: 4 additions & 0 deletions proxy/exchange.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ func (p *Proxy) exchangeUpstreams(
if len(ups) == 1 {
u = ups[0]
resp, _, err = p.exchange(u, req, p.time)
if err != nil {
return nil, nil, err
}

// TODO(e.burkov): p.updateRTT(u.Address(), elapsed)

return resp, u, err
Expand Down
34 changes: 19 additions & 15 deletions proxy/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ type upstreamWithStats struct {
}

// newUpstreamWithStats returns a new initialized *upstreamWithStats.
func newUpstreamWithStats(
upstream upstream.Upstream,
isFallback bool,
) (u *upstreamWithStats) {
func newUpstreamWithStats(upstream upstream.Upstream, isFallback bool) (u *upstreamWithStats) {
return &upstreamWithStats{
upstream: upstream,
mu: &sync.Mutex{},
Expand All @@ -55,12 +52,13 @@ var _ upstream.Upstream = (*upstreamWithStats)(nil)
func (u *upstreamWithStats) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
start := time.Now()
resp, err = u.upstream.Exchange(req)
dur := time.Since(start)

u.mu.Lock()
defer u.mu.Unlock()

u.err = err
u.queryDuration = time.Since(start)
u.queryDuration = dur

return resp, err
}
Expand Down Expand Up @@ -119,13 +117,13 @@ func (s *QueryStatistics) Fallback() (us []*UpstreamStatistics) {

// collectQueryStats gathers the statistics from the wrapped upstreams,
// considering the upstream mode. resolver is an upstream DNS resolver that
// successfully resolved the request, and it will be unwrapped. Provided
// upstreams must be of type *upstreamWithStats.
//
// If the DNS query was not resolved (i.e., if resolver is nil) or upstream mode
// is [UpstreamModeFastestAddr], the function returns the gathered statistics
// for both the upstream and fallback DNS servers. Otherwise, it returns the
// query statistics specifically for resolver.
// successfully resolved the request, and it will be unwrapped. If resolver is
// nil (i.e. the DNS query was not resolved) or upstream mode is
// [UpstreamModeFastestAddr], the function returns the gathered statistics for
// both the upstream and fallback DNS servers. If resolver is fallback, it also
// gathers the statistics for the upstreams. Otherwise, it returns the query
// statistics specifically for upstream resolver. Provided upstreams must be of
// type *upstreamWithStats.
func collectQueryStats(
mode UpstreamMode,
resolver upstream.Upstream,
Expand All @@ -152,12 +150,17 @@ func collectQueryStats(
}
}

return unwrapped, collectResolverQueryStats(wrapped)
return unwrapped, collectResolverQueryStats(upstreams, wrapped)
}

// collectResolverQueryStats gathers the statistics from an upstream DNS
// resolver that successfully resolved the request. resolver must be not nil.
func collectResolverQueryStats(resolver *upstreamWithStats) (stats *QueryStatistics) {
// resolver that successfully resolved the request. If resolver is the fallback
// DNS resolver, it also gathers the statistics for the upstream DNS resolvers.
// resolver must be not nil.
func collectResolverQueryStats(
upstreams []upstream.Upstream,
resolver *upstreamWithStats,
) (stats *QueryStatistics) {
dur, err := resolver.stats()
s := &UpstreamStatistics{
Address: resolver.upstream.Address(),
Expand All @@ -167,6 +170,7 @@ func collectResolverQueryStats(resolver *upstreamWithStats) (stats *QueryStatist

if resolver.isFallback {
return &QueryStatistics{
main: collectUpstreamStats(upstreams),
fallback: []*UpstreamStatistics{s},
}
}
Expand Down
276 changes: 276 additions & 0 deletions proxy/stats_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
package proxy_test

import (
"net"
"net/netip"
"testing"

"github.com/AdguardTeam/dnsproxy/internal/dnsproxytest"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCollectQueryStats(t *testing.T) {
const (
listenIP = "127.0.0.1"
)

var (
testReq = &dns.Msg{
Question: []dns.Question{{
Name: "test.",
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}},
}

defaultTrustedProxies netutil.SubnetSet = netutil.SliceSubnetSet{
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("::0/0"),
}

localhostAnyPort = netip.MustParseAddrPort(netutil.JoinHostPort(listenIP, 0))
)

ups := &dnsproxytest.FakeUpstream{
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return (&dns.Msg{}).SetReply(req), nil
},
OnAddress: func() (addr string) { return "upstream" },
OnClose: func() (err error) { return nil },
}

failUps := &dnsproxytest.FakeUpstream{
OnExchange: func(req *dns.Msg) (resp *dns.Msg, err error) {
return nil, errors.Error("exchange error")
},
OnAddress: func() (addr string) { return "fail.upstream" },
OnClose: func() (err error) { return nil },
}

conf := &proxy.Config{
Logger: slogutil.NewDiscardLogger(),
UDPListenAddr: []*net.UDPAddr{net.UDPAddrFromAddrPort(localhostAnyPort)},
TCPListenAddr: []*net.TCPAddr{net.TCPAddrFromAddrPort(localhostAnyPort)},
TrustedProxies: defaultTrustedProxies,
RatelimitSubnetLenIPv4: 24,
RatelimitSubnetLenIPv6: 64,
}

testCases := []struct {
isExchangeErr assert.BoolAssertionFunc
config *proxy.UpstreamConfig
fallbackConfig *proxy.UpstreamConfig
name string
mode proxy.UpstreamMode
mainCount int
fallbackCount int
isMainErr bool
isFallbackErr bool
}{{
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "load_balance_success",
mode: proxy.UpstreamModeLoadBalance,
mainCount: 1,
fallbackCount: 0,
isMainErr: false,
isFallbackErr: false,
}, {
isExchangeErr: assert.True,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps, failUps},
},
name: "load_balance_bad",
mode: proxy.UpstreamModeLoadBalance,
mainCount: 1,
fallbackCount: 2,
isMainErr: true,
isFallbackErr: true,
}, {
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups, failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "parallel_success",
mode: proxy.UpstreamModeParallel,
mainCount: 1,
fallbackCount: 0,
isMainErr: false,
isFallbackErr: false,
}, {
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "parallel_bad_fallback_success",
mode: proxy.UpstreamModeParallel,
mainCount: 1,
fallbackCount: 1,
isMainErr: true,
isFallbackErr: false,
}, {
isExchangeErr: assert.True,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps, failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps, failUps, failUps},
},
name: "parallel_bad",
mode: proxy.UpstreamModeParallel,
mainCount: 2,
fallbackCount: 3,
isMainErr: true,
isFallbackErr: true,
}, {
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "fastest_single_success",
mode: proxy.UpstreamModeFastestAddr,
mainCount: 1,
fallbackCount: 0,
isMainErr: false,
isFallbackErr: false,
}, {
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups, ups},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "fastest_multiple_success",
mode: proxy.UpstreamModeFastestAddr,
mainCount: 2,
fallbackCount: 0,
isMainErr: false,
isFallbackErr: false,
}, {
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups, failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "fastest_mixed_success",
mode: proxy.UpstreamModeFastestAddr,
mainCount: 2,
fallbackCount: 0,
isMainErr: true,
isFallbackErr: false,
}, {
isExchangeErr: assert.True,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps, failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps, failUps, failUps},
},
name: "fastest_multiple_bad",
mode: proxy.UpstreamModeFastestAddr,
mainCount: 2,
fallbackCount: 3,
isMainErr: true,
isFallbackErr: true,
}, {
isExchangeErr: assert.False,
config: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{failUps, failUps},
},
fallbackConfig: &proxy.UpstreamConfig{
Upstreams: []upstream.Upstream{ups},
},
name: "fastest_bad_fallback_success",
mode: proxy.UpstreamModeFastestAddr,
mainCount: 2,
fallbackCount: 1,
isMainErr: true,
isFallbackErr: false,
}}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
conf.UpstreamConfig = tc.config
conf.Fallbacks = tc.fallbackConfig
conf.UpstreamMode = tc.mode

p, err := proxy.New(conf)
require.NoError(t, err)

d := &proxy.DNSContext{Req: testReq}

err = p.Resolve(d)
tc.isExchangeErr(t, err != nil)

stats := d.QueryStatistics()
assertQueryStats(
t,
stats,
tc.mainCount,
tc.isMainErr,
tc.fallbackCount,
tc.isFallbackErr,
)
})
}
}

// assertQueryStats asserts the statistics using the provided parameters.
func assertQueryStats(
t *testing.T,
stats *proxy.QueryStatistics,
mainCount int,
isMainErr bool,
fallbackCount int,
isFallbackErr bool,
) {
t.Helper()

main := stats.Main()
assert.Equal(t, mainCount, len(main), "main stats count")

fallback := stats.Fallback()
assert.Equal(t, fallbackCount, len(fallback), "fallback stats count")

assert.Equal(t, isMainErr, isErrorInStats(main), "main err")
assert.Equal(t, isFallbackErr, isErrorInStats(fallback), "fallback err")
}

// isErrorInStats is a helper function for tests that returns true if the
// upstream statistics contain an DNS lookup error.
func isErrorInStats(stats []*proxy.UpstreamStatistics) (ok bool) {
for _, u := range stats {
if u.Error != nil {
return true
}
}

return false
}
Loading

0 comments on commit 966cabf

Please sign in to comment.