AUTH-2588 add DoH to service mode

This commit is contained in:
Dalton
2020-05-01 10:30:50 -05:00
committed by Dalton Cherry
parent 2c878c47ed
commit 2b7fbbb7b7
11 changed files with 406 additions and 81 deletions

View File

@@ -0,0 +1,48 @@
package main
import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
)
// ForwardServiceType is used to identify what kind of overwatch service this is
const ForwardServiceType = "forward"
// ForwarderService is used to wrap the access package websocket forwarders
// into a service model for the overwatch package.
// it also holds a reference to the config object that represents its state
type ForwarderService struct {
forwarder config.Forwarder
shutdown chan struct{}
}
// NewForwardService creates a new forwarder service
func NewForwardService(f config.Forwarder) *ForwarderService {
return &ForwarderService{forwarder: f, shutdown: make(chan struct{}, 1)}
}
// Name is used to figure out this service is related to the others (normally the addr it binds to)
// e.g. localhost:78641 or 127.0.0.1:2222 since this is a websocket forwarder
func (s *ForwarderService) Name() string {
return s.forwarder.Listener
}
// Type is used to identify what kind of overwatch service this is
func (s *ForwarderService) Type() string {
return ForwardServiceType
}
// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates
func (s *ForwarderService) Hash() string {
return s.forwarder.Hash()
}
// Shutdown stops the websocket listener
func (s *ForwarderService) Shutdown() {
s.shutdown <- struct{}{}
}
// Run is the run loop that is started by the overwatch service
func (s *ForwarderService) Run() error {
return access.StartForwarder(s.forwarder, s.shutdown)
}

View File

@@ -0,0 +1,73 @@
package main
import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/tunneldns"
"github.com/sirupsen/logrus"
)
// ResolverServiceType is used to identify what kind of overwatch service this is
const ResolverServiceType = "resolver"
// ResolverService is used to wrap the tunneldns package's DNS over HTTP
// into a service model for the overwatch package.
// it also holds a reference to the config object that represents its state
type ResolverService struct {
resolver config.DNSResolver
shutdown chan struct{}
logger *logrus.Logger
}
// NewResolverService creates a new resolver service
func NewResolverService(r config.DNSResolver, logger *logrus.Logger) *ResolverService {
return &ResolverService{resolver: r,
shutdown: make(chan struct{}),
logger: logger,
}
}
// Name is used to figure out this service is related to the others (normally the addr it binds to)
// this is just "resolver" since there can only be one DNS resolver running
func (s *ResolverService) Name() string {
return ResolverServiceType
}
// Type is used to identify what kind of overwatch service this is
func (s *ResolverService) Type() string {
return ResolverServiceType
}
// Hash is used to figure out if this forwarder is the unchanged or not from the config file updates
func (s *ResolverService) Hash() string {
return s.resolver.Hash()
}
// Shutdown stops the tunneldns listener
func (s *ResolverService) Shutdown() {
s.shutdown <- struct{}{}
}
// Run is the run loop that is started by the overwatch service
func (s *ResolverService) Run() error {
// create a listener
l, err := tunneldns.CreateListener(s.resolver.AddressOrDefault(), s.resolver.PortOrDefault(),
s.resolver.UpstreamsOrDefault(), s.resolver.BootstrapsOrDefault())
if err != nil {
return err
}
// start the listener.
readySignal := make(chan struct{})
err = l.Start(readySignal)
if err != nil {
l.Stop()
return err
}
<-readySignal
s.logger.Infof("start resolver on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault())
// wait for shutdown signal
<-s.shutdown
s.logger.Infof("shutdown on: %s:%d", s.resolver.AddressOrDefault(), s.resolver.PortOrDefault())
return l.Stop()
}

View File

@@ -1,36 +1,27 @@
package main
import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/access"
"github.com/cloudflare/cloudflared/cmd/cloudflared/config"
"github.com/cloudflare/cloudflared/overwatch"
"github.com/sirupsen/logrus"
)
type forwarderState struct {
forwarder config.Forwarder
shutdown chan struct{}
}
func (s *forwarderState) Shutdown() {
s.shutdown <- struct{}{}
}
// AppService is the main service that runs when no command lines flags are passed to cloudflared
// it manages all the running services such as tunnels, forwarders, DNS resolver, etc
type AppService struct {
configManager config.Manager
serviceManager overwatch.Manager
shutdownC chan struct{}
forwarders map[string]forwarderState
configUpdateChan chan config.Root
logger *logrus.Logger
}
// NewAppService creates a new AppService with needed supporting services
func NewAppService(configManager config.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService {
func NewAppService(configManager config.Manager, serviceManager overwatch.Manager, shutdownC chan struct{}, logger *logrus.Logger) *AppService {
return &AppService{
configManager: configManager,
serviceManager: serviceManager,
shutdownC: shutdownC,
forwarders: make(map[string]forwarderState),
configUpdateChan: make(chan config.Root),
logger: logger,
}
@@ -45,6 +36,7 @@ func (s *AppService) Run() error {
// Shutdown kills all the running services
func (s *AppService) Shutdown() error {
s.configManager.Shutdown()
s.shutdownC <- struct{}{}
return nil
}
@@ -62,8 +54,8 @@ func (s *AppService) actionLoop() {
case c := <-s.configUpdateChan:
s.handleConfigUpdate(c)
case <-s.shutdownC:
for _, state := range s.forwarders {
state.Shutdown()
for _, service := range s.serviceManager.Services() {
service.Shutdown()
}
return
}
@@ -72,41 +64,27 @@ func (s *AppService) actionLoop() {
func (s *AppService) handleConfigUpdate(c config.Root) {
// handle the client forward listeners
activeListeners := map[string]struct{}{}
activeServices := map[string]struct{}{}
for _, f := range c.Forwarders {
s.handleForwarderUpdate(f)
activeListeners[f.Listener] = struct{}{}
service := NewForwardService(f)
s.serviceManager.Add(service)
activeServices[service.Name()] = struct{}{}
}
// remove any listeners that are no longer active
for key, state := range s.forwarders {
if _, ok := activeListeners[key]; !ok {
state.Shutdown()
delete(s.forwarders, key)
// handle resolver changes
if c.Resolver.Enabled {
service := NewResolverService(c.Resolver, s.logger)
s.serviceManager.Add(service)
activeServices[service.Name()] = struct{}{}
}
// TODO: TUN-1451 - tunnels
// remove any services that are no longer active
for _, service := range s.serviceManager.Services() {
if _, ok := activeServices[service.Name()]; !ok {
s.serviceManager.Remove(service.Name())
}
}
// TODO: AUTH-2588, TUN-1451 - tunnels and dns proxy
}
// handle managing a forwarder service
func (s *AppService) handleForwarderUpdate(f config.Forwarder) {
// check if we need to start a new listener or stop an old one
if state, ok := s.forwarders[f.Listener]; ok {
if state.forwarder.Hash() == f.Hash() {
return // the exact same listener, no changes, so move along
}
state.Shutdown() //shutdown the listener since a new one is starting
}
// add a new forwarder to the list
state := forwarderState{forwarder: f, shutdown: make(chan struct{}, 1)}
s.forwarders[f.Listener] = state
// start the forwarder
go func(f forwarderState) {
err := access.StartForwarder(f.forwarder, f.shutdown)
if err != nil {
s.logger.WithError(err).Errorf("Forwarder at address: %s", f.forwarder)
}
}(state)
}

View File

@@ -28,14 +28,16 @@ type FileManager struct {
notifier Notifier
configPath string
logger *logrus.Logger
ReadConfig func(string) (Root, error)
}
// NewFileManager creates a config manager
func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (Manager, error) {
func NewFileManager(watcher watcher.Notifier, configPath string, logger *logrus.Logger) (*FileManager, error) {
m := &FileManager{
watcher: watcher,
configPath: configPath,
logger: logger,
ReadConfig: readConfigFromPath,
}
err := watcher.Add(configPath)
return m, err
@@ -58,11 +60,20 @@ func (m *FileManager) Start(notifier Notifier) error {
// GetConfig reads the yaml file from the disk
func (m *FileManager) GetConfig() (Root, error) {
if m.configPath == "" {
return m.ReadConfig(m.configPath)
}
// Shutdown stops the watcher
func (m *FileManager) Shutdown() {
m.watcher.Shutdown()
}
func readConfigFromPath(configPath string) (Root, error) {
if configPath == "" {
return Root{}, errors.New("unable to find config file")
}
file, err := os.Open(m.configPath)
file, err := os.Open(configPath)
if err != nil {
return Root{}, err
}
@@ -76,11 +87,6 @@ func (m *FileManager) GetConfig() (Root, error) {
return config, nil
}
// Shutdown stops the watcher
func (m *FileManager) Shutdown() {
m.watcher.Shutdown()
}
// File change notifications from the watcher
// WatcherItemDidChange triggers when the yaml config is updated

View File

@@ -1,15 +1,12 @@
package config
import (
"bufio"
"os"
"testing"
"time"
"github.com/cloudflare/cloudflared/log"
"github.com/cloudflare/cloudflared/watcher"
"github.com/stretchr/testify/assert"
"gopkg.in/yaml.v2"
)
type mockNotifier struct {
@@ -20,17 +17,27 @@ func (n *mockNotifier) ConfigDidUpdate(c Root) {
n.configs = append(n.configs, c)
}
func writeConfig(t *testing.T, f *os.File, c *Root) {
f.Sync()
b, err := yaml.Marshal(c)
assert.NoError(t, err)
type mockFileWatcher struct {
path string
notifier watcher.Notification
ready chan struct{}
}
w := bufio.NewWriter(f)
_, err = w.Write(b)
assert.NoError(t, err)
func (w *mockFileWatcher) Start(n watcher.Notification) {
w.notifier = n
w.ready <- struct{}{}
}
err = w.Flush()
assert.NoError(t, err)
func (w *mockFileWatcher) Add(string) error {
return nil
}
func (w *mockFileWatcher) Shutdown() {
}
func (w *mockFileWatcher) TriggerChange() {
w.notifier.WatcherItemDidChange(w.path)
}
func TestConfigChanged(t *testing.T) {
@@ -52,22 +59,24 @@ func TestConfigChanged(t *testing.T) {
},
},
}
writeConfig(t, f, c)
configRead := func(configPath string) (Root, error) {
return *c, nil
}
wait := make(chan struct{})
w := &mockFileWatcher{path: filePath, ready: wait}
w, err := watcher.NewFile()
assert.NoError(t, err)
logger := log.CreateLogger()
service, err := NewFileManager(w, filePath, logger)
service.ReadConfig = configRead
assert.NoError(t, err)
n := &mockNotifier{}
go service.Start(n)
<-wait
c.Forwarders = append(c.Forwarders, Forwarder{URL: "add.daltoniam.com", Listener: "127.0.0.1:8081"})
writeConfig(t, f, c)
w.TriggerChange()
// give it time to trigger
time.Sleep(10 * time.Millisecond)
service.Shutdown()
assert.Len(t, n.configs, 2, "did not get 2 config updates as expected")

View File

@@ -4,6 +4,7 @@ import (
"crypto/md5"
"fmt"
"io"
"strings"
)
// Forwarder represents a client side listener to forward traffic to the edge
@@ -19,6 +20,15 @@ type Tunnel struct {
ProtocolType string `json:"type"`
}
// DNSResolver represents a client side DNS resolver
type DNSResolver struct {
Enabled bool `json:"enabled"`
Address string `json:"address"`
Port uint16 `json:"port"`
Upstreams []string `json:"upstreams"`
Bootstraps []string `json:"bootstraps"`
}
// Root is the base options to configure the service
type Root struct {
OrgKey string `json:"org_key"`
@@ -26,6 +36,7 @@ type Root struct {
CheckinInterval int `json:"checkin_interval"`
Forwarders []Forwarder `json:"forwarders,omitempty"`
Tunnels []Tunnel `json:"tunnels,omitempty"`
Resolver DNSResolver `json:"resolver"`
}
// Hash returns the computed values to see if the forwarder values change
@@ -35,3 +46,51 @@ func (f *Forwarder) Hash() string {
io.WriteString(h, f.Listener)
return fmt.Sprintf("%x", h.Sum(nil))
}
// Hash returns the computed values to see if the forwarder values change
func (r *DNSResolver) Hash() string {
h := md5.New()
io.WriteString(h, r.Address)
io.WriteString(h, strings.Join(r.Bootstraps, ","))
io.WriteString(h, strings.Join(r.Upstreams, ","))
io.WriteString(h, fmt.Sprintf("%d", r.Port))
io.WriteString(h, fmt.Sprintf("%v", r.Enabled))
return fmt.Sprintf("%x", h.Sum(nil))
}
// EnabledOrDefault returns the enabled property
func (r *DNSResolver) EnabledOrDefault() bool {
return r.Enabled
}
// AddressOrDefault returns the address or returns the default if empty
func (r *DNSResolver) AddressOrDefault() string {
if r.Address != "" {
return r.Address
}
return "localhost"
}
// PortOrDefault return the port or returns the default if 0
func (r *DNSResolver) PortOrDefault() uint16 {
if r.Port > 0 {
return r.Port
}
return 53
}
// UpstreamsOrDefault returns the upstreams or returns the default if empty
func (r *DNSResolver) UpstreamsOrDefault() []string {
if len(r.Upstreams) > 0 {
return r.Upstreams
}
return []string{"https://1.1.1.1/dns-query", "https://1.0.0.1/dns-query"}
}
// BootstrapsOrDefault returns the bootstraps or returns the default if empty
func (r *DNSResolver) BootstrapsOrDefault() []string {
if len(r.Bootstraps) > 0 {
return r.Bootstraps
}
return []string{"https://162.159.36.1/dns-query", "https://162.159.46.1/dns-query", "https://[2606:4700:4700::1111]/dns-query", "https://[2606:4700:4700::1001]/dns-query"}
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/cloudflare/cloudflared/cmd/cloudflared/updater"
"github.com/cloudflare/cloudflared/log"
"github.com/cloudflare/cloudflared/metrics"
"github.com/cloudflare/cloudflared/overwatch"
"github.com/cloudflare/cloudflared/watcher"
raven "github.com/getsentry/raven-go"
@@ -180,7 +181,9 @@ func handleServiceMode(shutdownC chan struct{}) error {
return err
}
appService := NewAppService(configManager, shutdownC, logger)
serviceManager := overwatch.NewAppManager(nil)
appService := NewAppService(configManager, serviceManager, shutdownC, logger)
if err := appService.Run(); err != nil {
logger.WithError(err).Error("Failed to start app service")
return err