mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 13:09:58 +00:00
TUN-9473: Add --dns-resolver-addrs flag
To help support users with environments that don't work well with the DNS local resolver's automatic resolution process for local resolver addresses, we introduce a flag to provide them statically to the runtime. When providing the resolver addresses, cloudflared will no longer lookup the DNS resolver addresses and use the user input directly. When provided with a list of DNS resolvers larger than one, the resolver service will randomly select one at random for each incoming request. Closes TUN-9473
This commit is contained in:
@@ -157,4 +157,7 @@ const (
|
||||
|
||||
// ApiURL is the command line flag used to define the base URL of the API
|
||||
ApiURL = "api-url"
|
||||
|
||||
// Virtual DNS resolver service resolver addresses to use instead of dynamically fetching them from the OS.
|
||||
VirtualDNSServiceResolverAddresses = "dns-resolver-addrs"
|
||||
)
|
||||
|
@@ -227,7 +227,17 @@ func prepareTunnelConfig(
|
||||
DefaultDialer: ingress.NewDialer(warpRoutingConfig),
|
||||
TCPWriteTimeout: c.Duration(flags.WriteStreamTimeout),
|
||||
}, log)
|
||||
|
||||
// Setup DNS Resolver Service
|
||||
dnsResolverAddrs := c.StringSlice(flags.VirtualDNSServiceResolverAddresses)
|
||||
dnsService := origins.NewDNSResolverService(origins.NewDNSDialer(), log)
|
||||
if len(dnsResolverAddrs) > 0 {
|
||||
addrs, err := parseResolverAddrPorts(dnsResolverAddrs)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid %s provided: %w", flags.VirtualDNSServiceResolverAddresses, err)
|
||||
}
|
||||
dnsService = origins.NewStaticDNSResolverService(addrs, origins.NewDNSDialer(), log)
|
||||
}
|
||||
originDialerService.AddReservedService(dnsService, []netip.AddrPort{origins.VirtualDNSServiceAddr})
|
||||
|
||||
tunnelConfig := &supervisor.TunnelConfig{
|
||||
@@ -507,3 +517,19 @@ func findLocalAddr(dst net.IP, port int) (netip.Addr, error) {
|
||||
localAddr := localAddrPort.Addr()
|
||||
return localAddr, nil
|
||||
}
|
||||
|
||||
func parseResolverAddrPorts(input []string) ([]netip.AddrPort, error) {
|
||||
// We don't allow more than 10 resolvers to be provided statically for the resolver service.
|
||||
if len(input) > 10 {
|
||||
return nil, errors.New("too many addresses provided, max: 10")
|
||||
}
|
||||
addrs := make([]netip.AddrPort, 0, len(input))
|
||||
for _, val := range input {
|
||||
addr, err := netip.ParseAddrPort(val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
@@ -241,6 +241,11 @@ var (
|
||||
Usage: "Overrides the remote configuration for max active private network flows (TCP/UDP) that this cloudflared instance supports",
|
||||
EnvVars: []string{"TUNNEL_MAX_ACTIVE_FLOWS"},
|
||||
}
|
||||
dnsResolverAddrsFlag = &cli.StringSliceFlag{
|
||||
Name: flags.VirtualDNSServiceResolverAddresses,
|
||||
Usage: "Overrides the dynamic DNS resolver resolution to use these address:port's instead.",
|
||||
EnvVars: []string{"TUNNEL_DNS_RESOLVER_ADDRS"},
|
||||
}
|
||||
)
|
||||
|
||||
func buildCreateCommand() *cli.Command {
|
||||
@@ -718,6 +723,7 @@ func buildRunCommand() *cli.Command {
|
||||
icmpv4SrcFlag,
|
||||
icmpv6SrcFlag,
|
||||
maxActiveFlowsFlag,
|
||||
dnsResolverAddrsFlag,
|
||||
}
|
||||
flags = append(flags, configureProxyFlags(false)...)
|
||||
return &cli.Command{
|
||||
|
@@ -2,8 +2,11 @@ package origins
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -42,42 +45,50 @@ type netDial func(network string, address string) (net.Conn, error)
|
||||
|
||||
// DNSResolverService will make DNS requests to the local DNS resolver via the Dial method.
|
||||
type DNSResolverService struct {
|
||||
address netip.AddrPort
|
||||
addressM sync.RWMutex
|
||||
|
||||
dialer ingress.OriginDialer
|
||||
resolver peekResolver
|
||||
logger *zerolog.Logger
|
||||
addresses []netip.AddrPort
|
||||
addressesM sync.RWMutex
|
||||
static bool
|
||||
dialer ingress.OriginDialer
|
||||
resolver peekResolver
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func NewDNSResolverService(dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService {
|
||||
return &DNSResolverService{
|
||||
address: defaultResolverAddr,
|
||||
dialer: dialer,
|
||||
resolver: &resolver{dialFunc: net.Dial},
|
||||
logger: logger,
|
||||
addresses: []netip.AddrPort{defaultResolverAddr},
|
||||
dialer: dialer,
|
||||
resolver: &resolver{dialFunc: net.Dial},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func NewStaticDNSResolverService(resolverAddrs []netip.AddrPort, dialer ingress.OriginDialer, logger *zerolog.Logger) *DNSResolverService {
|
||||
s := NewDNSResolverService(dialer, logger)
|
||||
s.addresses = resolverAddrs
|
||||
s.static = true
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *DNSResolverService) DialTCP(ctx context.Context, _ netip.AddrPort) (net.Conn, error) {
|
||||
s.addressM.RLock()
|
||||
dest := s.address
|
||||
s.addressM.RUnlock()
|
||||
dest := s.getAddress()
|
||||
// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
|
||||
return s.dialer.DialTCP(ctx, dest)
|
||||
}
|
||||
|
||||
func (s *DNSResolverService) DialUDP(_ netip.AddrPort) (net.Conn, error) {
|
||||
s.addressM.RLock()
|
||||
dest := s.address
|
||||
s.addressM.RUnlock()
|
||||
dest := s.getAddress()
|
||||
// The dialer ignores the provided address because the request will instead go to the local DNS resolver.
|
||||
return s.dialer.DialUDP(dest)
|
||||
}
|
||||
|
||||
// StartRefreshLoop is a routine that is expected to run in the background to update the DNS local resolver if
|
||||
// adjusted while the cloudflared process is running.
|
||||
// Does not run when the resolver was provided with external resolver addresses via CLI.
|
||||
func (s *DNSResolverService) StartRefreshLoop(ctx context.Context) {
|
||||
if s.static {
|
||||
s.logger.Debug().Msgf("Canceled DNS local resolver refresh loop because static resolver addresses were provided: %s", s.addresses)
|
||||
return
|
||||
}
|
||||
// Call update first to load an address before handling traffic
|
||||
err := s.update(ctx)
|
||||
if err != nil {
|
||||
@@ -122,14 +133,38 @@ func (s *DNSResolverService) update(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// returns the address from the peekResolver or from the static addresses if provided.
|
||||
// If multiple addresses are provided in the static addresses pick one randomly.
|
||||
func (s *DNSResolverService) getAddress() netip.AddrPort {
|
||||
s.addressesM.RLock()
|
||||
defer s.addressesM.RUnlock()
|
||||
l := len(s.addresses)
|
||||
if l <= 0 {
|
||||
return defaultResolverAddr
|
||||
}
|
||||
if l == 1 {
|
||||
return s.addresses[0]
|
||||
}
|
||||
// Only initialize the random selection if there is more than one element in the list.
|
||||
var i int64 = 0
|
||||
r, err := rand.Int(rand.Reader, big.NewInt(int64(l)))
|
||||
// We ignore errors from crypto rand and use index 0; this should be extremely unlikely and the
|
||||
// list index doesn't need to be cryptographically secure, but linters insist.
|
||||
if err == nil {
|
||||
i = r.Int64()
|
||||
}
|
||||
return s.addresses[i]
|
||||
}
|
||||
|
||||
// lock and update the address used for the local DNS resolver
|
||||
func (s *DNSResolverService) setAddress(addr netip.AddrPort) {
|
||||
s.addressM.Lock()
|
||||
defer s.addressM.Unlock()
|
||||
if s.address != addr {
|
||||
s.addressesM.Lock()
|
||||
defer s.addressesM.Unlock()
|
||||
if !slices.Contains(s.addresses, addr) {
|
||||
s.logger.Debug().Msgf("Updating DNS local resolver: %s", addr)
|
||||
}
|
||||
s.address = addr
|
||||
// We only store one address when reading the peekResolver, so we just replace the whole list.
|
||||
s.addresses = []netip.AddrPort{addr}
|
||||
}
|
||||
|
||||
type peekResolver interface {
|
||||
|
@@ -5,7 +5,9 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
@@ -17,9 +19,18 @@ func TestDNSResolver_DefaultResolver(t *testing.T) {
|
||||
address: "127.0.0.2:53",
|
||||
}
|
||||
service.resolver = mockResolver
|
||||
if service.address != defaultResolverAddr {
|
||||
t.Errorf("resolver address should be the default: %s, was: %s", defaultResolverAddr, service.address)
|
||||
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
|
||||
}
|
||||
|
||||
func TestStaticDNSResolver_DefaultResolver(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")}
|
||||
service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log)
|
||||
mockResolver := &mockPeekResolver{
|
||||
address: "127.0.0.2:53",
|
||||
}
|
||||
service.resolver = mockResolver
|
||||
validateAddrs(t, addresses, service.addresses)
|
||||
}
|
||||
|
||||
func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
|
||||
@@ -29,26 +40,49 @@ func TestDNSResolver_UpdateResolverAddress(t *testing.T) {
|
||||
mockResolver := &mockPeekResolver{}
|
||||
service.resolver = mockResolver
|
||||
|
||||
expectedAddr := netip.MustParseAddrPort("127.0.0.2:53")
|
||||
addresses := []string{
|
||||
"127.0.0.2:53",
|
||||
"127.0.0.2", // missing port should be added (even though this is unlikely to happen)
|
||||
tests := []struct {
|
||||
addr string
|
||||
expected netip.AddrPort
|
||||
}{
|
||||
{"127.0.0.2:53", netip.MustParseAddrPort("127.0.0.2:53")},
|
||||
// missing port should be added (even though this is unlikely to happen)
|
||||
{"127.0.0.3", netip.MustParseAddrPort("127.0.0.3:53")},
|
||||
}
|
||||
|
||||
for _, addr := range addresses {
|
||||
mockResolver.address = addr
|
||||
for _, test := range tests {
|
||||
mockResolver.address = test.addr
|
||||
// Update the resolver address
|
||||
err := service.update(t.Context())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
// Validate expected
|
||||
if service.address != expectedAddr {
|
||||
t.Errorf("resolver address should be: %s, was: %s", expectedAddr, service.address)
|
||||
}
|
||||
validateAddrs(t, []netip.AddrPort{test.expected}, service.addresses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticDNSResolver_RefreshLoopExits(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
addresses := []netip.AddrPort{netip.MustParseAddrPort("1.1.1.1:53"), netip.MustParseAddrPort("1.0.0.1:53")}
|
||||
service := NewStaticDNSResolverService(addresses, NewDNSDialer(), &log)
|
||||
|
||||
mockResolver := &mockPeekResolver{
|
||||
address: "127.0.0.2:53",
|
||||
}
|
||||
service.resolver = mockResolver
|
||||
|
||||
ctx, cancel := context.WithCancel(t.Context())
|
||||
defer cancel()
|
||||
|
||||
go service.StartRefreshLoop(ctx)
|
||||
|
||||
// Wait for the refresh loop to end _and_ not update the addresses
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
|
||||
// Validate expected
|
||||
validateAddrs(t, addresses, service.addresses)
|
||||
}
|
||||
|
||||
func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
|
||||
log := zerolog.Nop()
|
||||
service := NewDNSResolverService(NewDNSDialer(), &log)
|
||||
@@ -69,9 +103,7 @@ func TestDNSResolver_UpdateResolverAddressInvalid(t *testing.T) {
|
||||
t.Error("service update should throw an error")
|
||||
}
|
||||
// Validate expected
|
||||
if service.address != defaultResolverAddr {
|
||||
t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
|
||||
}
|
||||
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,9 +120,7 @@ func TestDNSResolver_UpdateResolverErrorIgnored(t *testing.T) {
|
||||
t.Error("service update should throw an error")
|
||||
}
|
||||
// Validate expected
|
||||
if service.address != defaultResolverAddr {
|
||||
t.Errorf("resolver address should not be updated from default: %s, was: %s", defaultResolverAddr, service.address)
|
||||
}
|
||||
validateAddrs(t, []netip.AddrPort{defaultResolverAddr}, service.addresses)
|
||||
}
|
||||
|
||||
func TestDNSResolver_DialUDPUsesResolvedAddress(t *testing.T) {
|
||||
@@ -152,3 +182,14 @@ func (d *mockDialer) DialUDP(addr netip.AddrPort) (net.Conn, error) {
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func validateAddrs(t *testing.T, expected []netip.AddrPort, actual []netip.AddrPort) {
|
||||
if len(actual) != len(expected) {
|
||||
t.Errorf("addresses should only contain one element: %s", actual)
|
||||
}
|
||||
for _, e := range expected {
|
||||
if !slices.Contains(actual, e) {
|
||||
t.Errorf("missing address: %s in %s", e, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user