mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 19:59:58 +00:00

To help support users with environments that don't work well with the DNS local resolver's automatic resolution process for local resolver addresses, we introduce a flag to provide them statically to the runtime. When the resolver addresses are provided, cloudflared will no longer look up the DNS resolver addresses and will use the user input directly. When provided with a list of more than one DNS resolver, the resolver service will select one at random for each incoming request. Closes TUN-9473
196 lines
5.4 KiB
Go
196 lines
5.4 KiB
Go
package origins
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net"
|
|
"net/netip"
|
|
"slices"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/rs/zerolog"
|
|
)
|
|
|
|
func TestDNSResolver_DefaultResolver(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
service := NewDNSResolverService(NewDNSDialer(), &log)
|
|
mockResolver := &mockPeekResolver{
|
|
address: "127.0.0.2:53",
|
|
}
|
|
service.resolver = mockResolver
|
|
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
|
|
}
|
|
|
|
func TestStaticDNSResolver_DefaultResolver(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")}
|
|
service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log)
|
|
mockResolver := &mockPeekResolver{
|
|
address: "127.0.0.2:53",
|
|
}
|
|
service.resolver = mockResolver
|
|
validateAddrs(t, addresses, service.addresses)
|
|
}
|
|
|
|
func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
service := NewDNSResolverService(NewDNSDialer(), &log)
|
|
|
|
mockResolver := &mockPeekResolver{}
|
|
service.resolver = mockResolver
|
|
|
|
tests := []struct {
|
|
addr string
|
|
expected netip.AddrPort
|
|
}{
|
|
{"127.0.0.2:53", netip.MustParseAddrPort("127.0.0.2:53")},
|
|
// missing port should be added (even though this is unlikely to happen)
|
|
{"127.0.0.3", netip.MustParseAddrPort("127.0.0.3:53")},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
mockResolver.address = test.addr
|
|
// Update the resolver address
|
|
err := service.update(t.Context())
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
// Validate expected
|
|
validateAddrs(t, []netip.AddrPort{test.expected}, service.addresses)
|
|
}
|
|
}
|
|
|
|
func TestStaticDNSResolver_RefreshLoopExits(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")}
|
|
service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log)
|
|
|
|
mockResolver := &mockPeekResolver{
|
|
address: "127.0.0.2:53",
|
|
}
|
|
service.resolver = mockResolver
|
|
|
|
ctx, cancel := context.WithCancel(t.Context())
|
|
defer cancel()
|
|
|
|
go service.StartRefreshLoop(ctx)
|
|
|
|
// Wait for the refresh loop to end _and_ not update the addresses
|
|
time.Sleep(10 * time.Millisecond)
|
|
|
|
// Validate expected
|
|
validateAddrs(t, addresses, service.addresses)
|
|
}
|
|
|
|
func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
service := NewDNSResolverService(NewDNSDialer(), &log)
|
|
mockResolver := &mockPeekResolver{}
|
|
service.resolver = mockResolver
|
|
|
|
invalidAddresses := []string{
|
|
"999.999.999.999",
|
|
"localhost",
|
|
"255.255.255",
|
|
}
|
|
|
|
for _, addr := range invalidAddresses {
|
|
mockResolver.address = addr
|
|
// Update the resolver address should not update for these invalid addresses
|
|
err := service.update(t.Context())
|
|
if err == nil {
|
|
t.Error("service update should throw an error")
|
|
}
|
|
// Validate expected
|
|
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
|
|
}
|
|
}
|
|
|
|
func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
service := NewDNSResolverService(NewDNSDialer(), &log)
|
|
resolverErr := errors.New("test resolver error")
|
|
mockResolver := &mockPeekResolver{err: resolverErr}
|
|
service.resolver = mockResolver
|
|
|
|
// Update the resolver address should not update when the resolver cannot complete the lookup
|
|
err := service.update(t.Context())
|
|
if err != resolverErr {
|
|
t.Error("service update should throw an error")
|
|
}
|
|
// Validate expected
|
|
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
|
|
}
|
|
|
|
func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
mockDialer := &mockDialer{expected: defaultResolverAddr}
|
|
service := NewDNSResolverService(mockDialer, &log)
|
|
mockResolver := &mockPeekResolver{}
|
|
service.resolver = mockResolver
|
|
|
|
// Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53
|
|
_, err := service.DialUDP(netip.MustParseAddrPort("127.0.0.2:53"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
func TestDNSResolver_DialTCPUsesResolvedAddress(t *testing.T) {
|
|
log := zerolog.Nop()
|
|
mockDialer := &mockDialer{expected: defaultResolverAddr}
|
|
service := NewDNSResolverService(mockDialer, &log)
|
|
mockResolver := &mockPeekResolver{}
|
|
service.resolver = mockResolver
|
|
|
|
// Attempt a dial to 127.0.0.2:53 which should be ignored and instead resolve to 127.0.0.1:53
|
|
_, err := service.DialTCP(t.Context(), netip.MustParseAddrPort("127.0.0.2:53"))
|
|
if err != nil {
|
|
t.Error(err)
|
|
}
|
|
}
|
|
|
|
// mockPeekResolver is a test double for the service's resolver. It reports a
// configurable address and can be told to fail lookups.
type mockPeekResolver struct {
	// err is returned verbatim from lookupNetIP.
	err error
	// address is the value reported by addr.
	address string
}

// addr reports the configured address, always over the "udp" network.
func (m *mockPeekResolver) addr() (network, address string) {
	network = "udp"
	address = m.address
	return network, address
}

// lookupNetIP ignores its arguments and returns an empty, non-nil result
// together with the configured error. The result content does not matter as
// long as the lookup itself does not fail.
func (m *mockPeekResolver) lookupNetIP(ctx context.Context, host string) ([]netip.Addr, error) {
	none := make([]netip.Addr, 0)
	return none, m.err
}
|
|
|
|
// mockDialer is a test double for the dialer used by the resolver service. It
// never opens a connection; it only verifies which address it was asked to dial.
type mockDialer struct {
	// expected is the only address a dial call may target.
	expected netip.AddrPort
}

// verify returns nil when addr matches the expected address and an error otherwise.
func (d *mockDialer) verify(addr netip.AddrPort) error {
	if addr == d.expected {
		return nil
	}
	return errors.New("unexpected address dialed")
}

// DialTCP returns a nil connection; the error is non-nil when addr is not the
// expected address. ctx is unused.
func (d *mockDialer) DialTCP(ctx context.Context, addr netip.AddrPort) (net.Conn, error) {
	return nil, d.verify(addr)
}

// DialUDP returns a nil connection; the error is non-nil when addr is not the
// expected address.
func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) {
	return nil, d.verify(addr)
}
|
|
|
|
// validateAddrs reports a test error when actual does not have the same number
// of entries as expected, or when any address in expected is absent from
// actual. Order is not significant. Failures are reported with t.Errorf, so
// the calling test keeps running.
func validateAddrs(t *testing.T, expected []netip.AddrPort, actual []netip.AddrPort) {
	// Attribute failures to the calling test line rather than this helper.
	t.Helper()
	if len(actual) != len(expected) {
		// Callers pass expected lists of different sizes, so report the real counts.
		t.Errorf("addresses should contain %d element(s), got %d: %s", len(expected), len(actual), actual)
	}
	for _, e := range expected {
		if !slices.Contains(actual, e) {
			t.Errorf("missing address: %s in %s", e, actual)
		}
	}
}
|