diff --git a/cmd/cloudflared/flags/flags.go b/cmd/cloudflared/flags/flags.go index a7bf1b7e..975ee401 100644 --- a/cmd/cloudflared/flags/flags.go +++ b/cmd/cloudflared/flags/flags.go @@ -157,4 +157,7 @@ const ( // ApiURL is the command line flag used to define the base URL of the API ApiURL = "api-url" + + // Virtual DNS resolver service resolver addresses to use instead of dynamically fetching them from the OS. + VirtualDNSServiceResolverAddresses = "dns-resolver-addrs" ) diff --git a/cmd/cloudflared/tunnel/configuration.go b/cmd/cloudflared/tunnel/configuration.go index 3960d0db..de41184f 100644 --- a/cmd/cloudflared/tunnel/configuration.go +++ b/cmd/cloudflared/tunnel/configuration.go @@ -227,7 +227,17 @@ func prepareTunnelConfig( DefaultDialer: ingress.NewDialer(warpRoutingConfig), TCPWriteTimeout: c.Duration(flags.WriteStreamTimeout), }, log) + + // Setup DNS Resolver Service + dnsResolverAddrs := c.StringSlice(flags.VirtualDNSServiceResolverAddresses) dnsService := origins.NewDNSResolverService(origins.NewDNSDialer(), log) + if len(dnsResolverAddrs) > 0 { + addrs, err := parseResolverAddrPorts(dnsResolverAddrs) + if err != nil { + return nil, nil, fmt.Errorf("invalid %s provided: %w", flags.VirtualDNSServiceResolverAddresses, err) + } + dnsService = origins.NewStaticDNSResolverService(addrs, origins.NewDNSDialer(), log) + } originDialerService.AddReservedService(dnsService, []netip.AddrPort{origins.VirtualDNSServiceAddr}) tunnelConfig := &supervisor.TunnelConfig{ @@ -507,3 +517,19 @@ func findLocalAddr(dst net.IP, port int) (netip.Addr, error) { localAddr := localAddrPort.Addr() return localAddr, nil } + +func parseResolverAddrPorts(input []string) ([]netip.AddrPort, error) { + // We don't allow more than 10 resolvers to be provided statically for the resolver service. + if len(input) > 10 { + return nil, errors.New("too many addresses provided, max: 10") + } + addrs := make([]netip.AddrPort, 0, len(input)) + for _, val := range input { + addr, err := netip.ParseAddrPort(val) + if err != nil { + return nil, err + } + addrs = append(addrs, addr) + } + return addrs, nil +} diff --git a/cmd/cloudflared/tunnel/subcommands.go b/cmd/cloudflared/tunnel/subcommands.go index 4be655a0..f89e05c1 100644 --- a/cmd/cloudflared/tunnel/subcommands.go +++ b/cmd/cloudflared/tunnel/subcommands.go @@ -241,6 +241,11 @@ var ( Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports", EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"}, } + dnsResolverAddrsFlag = &cli.StringSliceFlag{ + Name: flags.VirtualDNSServiceResolverAddresses, + Usage: "Overrides the dynamic DNS resolver resolution to use these address:port's instead.", + EnvVars: []string{"TUNNEL_DNS_RESOLVER_ADDRS"}, + } ) func buildCreateCommand() *cli.Command { @@ -718,6 +723,7 @@ func buildRunCommand() *cli.Command { icmpv4SrcFlag, icmpv6SrcFlag, maxActiveFlowsFlag, + dnsResolverAddrsFlag, } flags = append(flags, configureProxyFlags(false)...) return &cli.Command{ diff --git a/ingress/origins/dns.go b/ingress/origins/dns.go index a3936b9f..b5599f89 100644 --- a/ingress/origins/dns.go +++ b/ingress/origins/dns.go @@ -2,8 +2,11 @@ package origins import ( "context" + "crypto/rand" + "math/big" "net" "net/netip" + "slices" "sync" "time" @@ -42,42 +45,50 @@ type netDial func(network string, address string) (net.Conn, error) // DNSResolverService will make DNS requests to the local DNS resolver via the Dial method. type DNSResolverService struct { - address netip.AddrPort - addressM sync.RWMutex - - dialer ingress.OriginDialer - resolver peekResolver - logger *zerolog.Logger + addresses []netip.AddrPort + addressesM sync.RWMutex + static bool + dialer ingress.OriginDialer + resolver peekResolver + logger *zerolog.Logger } func NewDNSResolverService(dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService { return &DNSResolverService{ - address: defaultResolverAddr, - dialer: dialer, - resolver: &resolver{dialFunc: net.Dial}, - logger: logger, + addresses: []netip.AddrPort{defaultResolverAddr}, + dialer: dialer, + resolver: &resolver{dialFunc: net.Dial}, + logger: logger, } } +func NewStaticDNSResolverService(resolverAddrs []netip.AddrPort, dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService { + s := NewDNSResolverService(dialer, logger) + s.addresses = resolverAddrs + s.static = true + return s +} + func (s *DNSResolverService) DialTCP(ctx context.Context, _ netip.AddrPort) (net.Conn, error) { - s.addressM.RLock() - dest := s.address - s.addressM.RUnlock() + dest := s.getAddress() // The dialer ignores the provided address because the request will instead go to the local DNS resolver. return s.dialer.DialTCP(ctx, dest) } func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) { - s.addressM.RLock() - dest := s.address - s.addressM.RUnlock() + dest := s.getAddress() // The dialer ignores the provided address because the request will instead go to the local DNS resolver. return s.dialer.DialUDP(dest) } // StartRefreshLoop is a routine that is expected to run in the background to update the DNS local resolver if // adjusted while the cloudflared process is running. +// Does not run when the resolver was provided with external resolver addresses via CLI. func (s *DNSResolverService) StartRefreshLoop(ctx context.Context) { + if s.static { + s.logger.Debug().Msgf("Canceled DNS local resolver refresh loop because static resolver addresses were provided: %s", s.addresses) + return + } // Call update first to load an address before handling traffic err := s.update(ctx) if err != nil { @@ -122,14 +133,38 @@ func (s *DNSResolverService) update(ctx context.Context) error { return nil } +// returns the address from the peekResolver or from the static addresses if provided. +// If multiple addresses are provided in the static addresses pick one randomly. +func (s *DNSResolverService) getAddress() netip.AddrPort { + s.addressesM.RLock() + defer s.addressesM.RUnlock() + l := len(s.addresses) + if l <= 0 { + return defaultResolverAddr + } + if l == 1 { + return s.addresses[0] + } + // Only initialize the random selection if there is more than one element in the list. + var i int64 = 0 + r, err := rand.Int(rand.Reader, big.NewInt(int64(l))) + // We ignore errors from crypto rand and use index 0; this should be extremely unlikely and the + // list index doesn't need to be cryptographically secure, but linters insist. + if err == nil { + i = r.Int64() + } + return s.addresses[i] +} + // lock and update the address used for the local DNS resolver func (s *DNSResolverService) setAddress(addr netip.AddrPort) { - s.addressM.Lock() - defer s.addressM.Unlock() - if s.address != addr { + s.addressesM.Lock() + defer s.addressesM.Unlock() + if !slices.Contains(s.addresses, addr) { s.logger.Debug().Msgf("Updating DNS local resolver: %s", addr) } - s.address = addr + // We only store one address when reading the peekResolver, so we just replace the whole list. + s.addresses = []netip.AddrPort{addr} } type peekResolver interface { diff --git a/ingress/origins/dns_test.go b/ingress/origins/dns_test.go index a137c814..3ea24510 100644 --- a/ingress/origins/dns_test.go +++ b/ingress/origins/dns_test.go @@ -5,7 +5,9 @@ import ( "errors" "net" "net/netip" + "slices" "testing" + "time" "github.com/rs/zerolog" ) @@ -17,9 +19,18 @@ func TestDNSResolver_DefaultResolver(t *testing.T) { address: "127.0.0.2:53", } service.resolver = mockResolver - if service.address != defaultResolverAddr { - t.Errorf("resolver address should be the default: %s, was: %s", defaultResolverAddr, service.address) + validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses) +} + +func TestStaticDNSResolver_DefaultResolver(t *testing.T) { + log := zerolog.Nop() + addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")} + service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log) + mockResolver := &mockPeekResolver{ + address: "127.0.0.2:53", } + service.resolver = mockResolver + validateAddrs(t, addresses, service.addresses) } func TestDNSResolver_UpdateResolverAddress(t *testing.T) { @@ -29,26 +40,49 @@ func TestDNSResolver_UpdateResolverAddress(t *testing.T) { mockResolver := &mockPeekResolver{} service.resolver = mockResolver - expectedAddr := netip.MustParseAddrPort("127.0.0.2:53") - addresses := []string{ - "127.0.0.2:53", - "127.0.0.2", // missing port should be added (even though this is unlikely to happen) + tests := []struct { + addr string + expected netip.AddrPort + }{ + {"127.0.0.2:53", netip.MustParseAddrPort("127.0.0.2:53")}, + // missing port should be added (even though this is unlikely to happen) + {"127.0.0.3", netip.MustParseAddrPort("127.0.0.3:53")}, } - for _, addr := range addresses { - mockResolver.address = addr + for _, test := range tests { + mockResolver.address = test.addr // Update the resolver address err := service.update(t.Context()) if err != nil { t.Error(err) } // Validate expected - if service.address != expectedAddr { - t.Errorf("resolver address should be: %s, was: %s", expectedAddr, service.address) - } + validateAddrs(t, []netip.AddrPort{test.expected}, service.addresses) } } +func TestStaticDNSResolver_RefreshLoopExits(t *testing.T) { + log := zerolog.Nop() + addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")} + service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log) + + mockResolver := &mockPeekResolver{ + address: "127.0.0.2:53", + } + service.resolver = mockResolver + + ctx, cancel := context.WithCancel(t.Context()) + defer cancel() + + go service.StartRefreshLoop(ctx) + + // Wait for the refresh loop to end _and_ not update the addresses + time.Sleep(10 * time.Millisecond) + + // Validate expected + validateAddrs(t, addresses, service.addresses) +} + func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) { log := zerolog.Nop() service := NewDNSResolverService(NewDNSDialer(), &log) @@ -69,9 +103,7 @@ func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) { t.Error("service update should throw an error") } // Validate expected - if service.address != defaultResolverAddr { - t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address) - } + validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses) } } @@ -88,9 +120,7 @@ func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) { t.Error("service update should throw an error") } // Validate expected - if service.address != defaultResolverAddr { - t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address) - } + validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses) } func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) { @@ -152,3 +182,14 @@ func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) { } return nil, nil } + +func validateAddrs(t *testing.T, expected []netip.AddrPort, actual []netip.AddrPort) { + if len(actual) != len(expected) { + t.Errorf("addresses should only contain one element: %s", actual) + } + for _, e := range expected { + if !slices.Contains(actual, e) { + t.Errorf("missing address: %s in %s", e, actual) + } + } +}