TUN-1961: Create EdgeConnectionManager to maintain outbound connections to the edge

This commit is contained in:
Chung-Ting Huang
2019-06-17 16:18:47 -05:00
parent d26a8c5d44
commit 80a15547e3
15 changed files with 856 additions and 569 deletions

View File

@@ -2,15 +2,14 @@ package connection
import (
"context"
"crypto/tls"
"net"
"sync"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
@@ -18,7 +17,6 @@ import (
)
const (
dialTimeout = 5 * time.Second
openStreamTimeout = 30 * time.Second
)
@@ -30,123 +28,54 @@ func (e dialError) Error() string {
return e.cause.Error()
}
type muxerShutdownError struct{}
func (e muxerShutdownError) Error() string {
return "muxer shutdown"
type Connection struct {
id uuid.UUID
muxer *h2mux.Muxer
}
type ConnectionConfig struct {
TLSConfig *tls.Config
HeartbeatInterval time.Duration
MaxHeartbeats uint64
Logger *logrus.Entry
}
type connectionHandler interface {
serve(ctx context.Context) error
connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error)
shutdown()
}
type h2muxHandler struct {
muxer *h2mux.Muxer
logger *logrus.Entry
}
func (h *h2muxHandler) serve(ctx context.Context) error {
// Serve doesn't return until h2mux is shutdown
if err := h.muxer.Serve(ctx); err != nil {
return err
func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) {
id, err := uuid.NewRandom()
if err != nil {
return nil, err
}
return muxerShutdownError{}
return &Connection{
id: id,
muxer: muxer,
}, nil
}
func (c *Connection) Serve(ctx context.Context) error {
// Serve doesn't return until h2mux is shutdown
return c.muxer.Serve(ctx)
}
// Connect is used to establish connections with cloudflare's edge network
func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) {
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (*pogs.ConnectResult, error) {
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
defer cancel()
conn, err := h.newRPConn(openStreamCtx)
rpcConn, err := c.newRPConn(openStreamCtx, logger)
if err != nil {
return nil, errors.Wrap(err, "Failed to create new RPC connection")
return nil, errors.Wrap(err, "cannot create new RPC connection")
}
defer conn.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
defer rpcConn.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
return tsClient.Connect(ctx, parameters)
}
func (h *h2muxHandler) shutdown() {
h.muxer.Shutdown()
func (c *Connection) Shutdown() {
c.muxer.Shutdown()
}
func (h *h2muxHandler) newRPConn(ctx context.Context) (*rpc.Conn, error) {
stream, err := h.muxer.OpenRPCStream(ctx)
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
stream, err := c.muxer.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
return rpc.NewConn(
tunnelrpc.NewTransportLogger(h.logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(h.logger.WithField("subsystem", "rpc-transport")),
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
), nil
}
// NewConnectionHandler returns a connectionHandler, wrapping h2mux to make RPC calls
func newH2MuxHandler(ctx context.Context,
streamHandler *streamhandler.StreamHandler,
config *ConnectionConfig,
edgeIP *net.TCPAddr,
) (connectionHandler, error) {
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
defer dialCancel()
dialer := net.Dialer{DualStack: true}
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
}
edgeConn := tls.Client(plaintextEdgeConn, config.TLSConfig)
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
err = edgeConn.Handshake()
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
edgeConn.SetDeadline(time.Time{})
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: dialTimeout,
Handler: streamHandler,
IsClient: true,
HeartbeatInterval: config.HeartbeatInterval,
MaxHeartbeats: config.MaxHeartbeats,
Logger: config.Logger,
})
if err != nil {
return nil, err
}
return &h2muxHandler{
muxer: muxer,
logger: config.Logger,
}, nil
}
// connectionPool is a pool of connection handlers
type connectionPool struct {
sync.Mutex
connectionHandlers []connectionHandler
}
func (cp *connectionPool) put(h connectionHandler) {
cp.Lock()
defer cp.Unlock()
cp.connectionHandlers = append(cp.connectionHandlers, h)
}
func (cp *connectionPool) close() {
cp.Lock()
defer cp.Unlock()
for _, h := range cp.connectionHandlers {
h.shutdown()
}
}

View File

@@ -5,10 +5,11 @@ import (
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus"
)
const (
@@ -22,6 +23,9 @@ const (
dotServerName = "cloudflare-dns.com"
dotServerAddr = "1.1.1.1:853"
dotTimeout = time.Duration(15 * time.Second)
// SRV record resolution TTL
resolveEdgeAddrTTL = 1 * time.Hour
)
var friendlyDNSErrorLines = []string{
@@ -34,20 +38,65 @@ var friendlyDNSErrorLines = []string{
` https://developers.cloudflare.com/1.1.1.1/setting-up-1.1.1.1/`,
}
func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, error) {
if len(addresses) > 0 {
var tcpAddrs []*net.TCPAddr
for _, address := range addresses {
// Addresses specified (for testing, usually)
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
tcpAddrs = append(tcpAddrs, tcpAddr)
}
return tcpAddrs, nil
// EdgeServiceDiscoverer is an interface for looking up Cloudflare's edge network addresses
type EdgeServiceDiscoverer interface {
// Addr returns an address to connect to cloudflare's edge network
Addr() *net.TCPAddr
// AvailableAddrs returns the number of unique addresses
AvailableAddrs() uint8
// Refresh rediscover Cloudflare's edge network addresses
Refresh() error
}
// EdgeAddrResolver discovers the addresses of Cloudflare's edge network through SRV record.
// It implements EdgeServiceDiscoverer interface
type EdgeAddrResolver struct {
sync.Mutex
// Addrs to connect to cloudflare's edge network
addrs []*net.TCPAddr
// index of the next element to use in addrs
nextAddrIndex int
logger *logrus.Entry
}
func NewEdgeAddrResolver(logger *logrus.Logger) (EdgeServiceDiscoverer, error) {
r := &EdgeAddrResolver{
logger: logger.WithField("subsystem", " edgeAddrResolver"),
}
// HA service discovery lookup
if err := r.Refresh(); err != nil {
return nil, err
}
return r, nil
}
func (r *EdgeAddrResolver) Addr() *net.TCPAddr {
r.Lock()
defer r.Unlock()
addr := r.addrs[r.nextAddrIndex]
r.nextAddrIndex = (r.nextAddrIndex + 1) % len(r.addrs)
return addr
}
func (r *EdgeAddrResolver) AvailableAddrs() uint8 {
r.Lock()
defer r.Unlock()
return uint8(len(r.addrs))
}
func (r *EdgeAddrResolver) Refresh() error {
newAddrs, err := EdgeDiscovery(r.logger)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.addrs = newAddrs
r.nextAddrIndex = 0
return nil
}
// HA service discovery lookup
func EdgeDiscovery(logger *logrus.Entry) ([]*net.TCPAddr, error) {
_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
if err != nil {
// Try to fall back to DoT from Cloudflare directly.
@@ -78,7 +127,7 @@ func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, err
var resolvedIPsPerCNAME [][]*net.TCPAddr
var lookupErr error
for _, addr := range addrs {
ips, err := ResolveSRVToTCP(addr)
ips, err := resolveSRVToTCP(addr)
if err != nil || len(ips) == 0 {
// don't return early, we might be able to resolve other addresses
lookupErr = err
@@ -86,14 +135,14 @@ func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, err
}
resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
}
ips := FlattenServiceIPs(resolvedIPsPerCNAME)
ips := flattenServiceIPs(resolvedIPsPerCNAME)
if lookupErr == nil && len(ips) == 0 {
return nil, fmt.Errorf("Unknown service discovery error")
}
return ips, lookupErr
}
func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
func resolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
ips, err := net.LookupIP(srv.Target)
if err != nil {
return nil, err
@@ -107,7 +156,7 @@ func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
// FlattenServiceIPs transposes and flattens the input slices such that the
// first element of the n inner slices are the first n elements of the result.
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
func flattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
var result []*net.TCPAddr
for len(ipsByService) > 0 {
filtered := ipsByService[:0]
@@ -141,3 +190,65 @@ func fallbackResolver(serverName, serverAddress string) *net.Resolver {
},
}
}
// EdgeHostnameResolver discovers the addresses of Cloudflare's edge network via a list of server hostnames.
// It implements EdgeServiceDiscoverer interface, and is used mainly for testing connectivity.
type EdgeHostnameResolver struct {
sync.Mutex
// hostnames of edge servers
hostnames []string
// Addrs to connect to cloudflare's edge network
addrs []*net.TCPAddr
// index of the next element to use in addrs
nextAddrIndex int
}
func NewEdgeHostnameResolver(edgeHostnames []string) (EdgeServiceDiscoverer, error) {
r := &EdgeHostnameResolver{
hostnames: edgeHostnames,
}
if err := r.Refresh(); err != nil {
return nil, err
}
return r, nil
}
func (r *EdgeHostnameResolver) Addr() *net.TCPAddr {
r.Lock()
defer r.Unlock()
addr := r.addrs[r.nextAddrIndex]
r.nextAddrIndex = (r.nextAddrIndex + 1) % len(r.addrs)
return addr
}
func (r *EdgeHostnameResolver) AvailableAddrs() uint8 {
r.Lock()
defer r.Unlock()
return uint8(len(r.addrs))
}
func (r *EdgeHostnameResolver) Refresh() error {
newAddrs, err := ResolveAddrs(r.hostnames)
if err != nil {
return err
}
r.Lock()
defer r.Unlock()
r.addrs = newAddrs
r.nextAddrIndex = 0
return nil
}
// Resolve TCP address given a list of addresses. Address can be a hostname, however, it will return at most one
// of the hostname's IP addresses
func ResolveAddrs(addrs []string) ([]*net.TCPAddr, error) {
var tcpAddrs []*net.TCPAddr
for _, addr := range addrs {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return nil, err
}
tcpAddrs = append(tcpAddrs, tcpAddr)
}
return tcpAddrs, nil
}

View File

@@ -7,8 +7,26 @@ import (
"github.com/stretchr/testify/assert"
)
type mockEdgeServiceDiscoverer struct {
}
func (mr *mockEdgeServiceDiscoverer) Addr() *net.TCPAddr {
return &net.TCPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 63102,
}
}
func (mr *mockEdgeServiceDiscoverer) AvailableAddrs() uint8 {
return 1
}
func (mr *mockEdgeServiceDiscoverer) Refresh() error {
return nil
}
func TestFlattenServiceIPs(t *testing.T) {
result := FlattenServiceIPs([][]*net.TCPAddr{
result := flattenServiceIPs([][]*net.TCPAddr{
[]*net.TCPAddr{
&net.TCPAddr{Port: 1},
&net.TCPAddr{Port: 2},

281
connection/manager.go Normal file
View File

@@ -0,0 +1,281 @@
package connection
import (
"context"
"crypto/tls"
"fmt"
"net"
"sync"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
quickStartLink = "https://developers.cloudflare.com/argo-tunnel/quickstart/"
faqLink = "https://developers.cloudflare.com/argo-tunnel/faq/"
)
// EdgeManager manages connections with the edge
type EdgeManager struct {
// streamHandler handles stream opened by the edge
streamHandler h2mux.MuxedStreamHandler
// TLSConfig is the TLS configuration to connect with edge
tlsConfig *tls.Config
// cloudflaredConfig is the cloudflared configuration that is determined when the process first starts
cloudflaredConfig *CloudflaredConfig
// serviceDiscoverer returns the next edge addr to connect to
serviceDiscoverer EdgeServiceDiscoverer
// state is attributes of ConnectionManager that can change during runtime.
state *edgeManagerState
logger *logrus.Entry
}
// EdgeConnectionManagerConfigurable is the configurable attributes of a EdgeConnectionManager
type EdgeManagerConfigurable struct {
TunnelHostnames []h2mux.TunnelHostname
*pogs.EdgeConnectionConfig
}
type CloudflaredConfig struct {
CloudflaredID uuid.UUID
Tags []pogs.Tag
BuildInfo *buildinfo.BuildInfo
}
func NewEdgeManager(
streamHandler h2mux.MuxedStreamHandler,
edgeConnMgrConfigurable *EdgeManagerConfigurable,
userCredential []byte,
tlsConfig *tls.Config,
serviceDiscoverer EdgeServiceDiscoverer,
cloudflaredConfig *CloudflaredConfig,
logger *logrus.Logger,
) *EdgeManager {
return &EdgeManager{
streamHandler: streamHandler,
tlsConfig: tlsConfig,
cloudflaredConfig: cloudflaredConfig,
serviceDiscoverer: serviceDiscoverer,
state: newEdgeConnectionManagerState(edgeConnMgrConfigurable, userCredential),
logger: logger.WithField("subsystem", "connectionManager"),
}
}
func (em *EdgeManager) Run(ctx context.Context) error {
defer em.shutdown()
resolveEdgeIPTicker := time.Tick(resolveEdgeAddrTTL)
for {
select {
case <-ctx.Done():
return errors.Wrap(ctx.Err(), "EdgeConnectionManager terminated")
case <-resolveEdgeIPTicker:
if err := em.serviceDiscoverer.Refresh(); err != nil {
em.logger.WithError(err).Warn("Cannot refresh Cloudflare edge addresses")
}
default:
time.Sleep(1 * time.Second)
}
// Create/delete connection one at a time, so we don't need to adjust for connections that are being created/deleted
// in shouldCreateConnection or shouldReduceConnection calculation
if em.state.shouldCreateConnection(em.serviceDiscoverer.AvailableAddrs()) {
if err := em.newConnection(ctx); err != nil {
em.logger.WithError(err).Error("cannot create new connection")
}
} else if em.state.shouldReduceConnection() {
if err := em.closeConnection(ctx); err != nil {
em.logger.WithError(err).Error("cannot close connection")
}
}
}
}
func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurable) {
em.logger.Infof("New edge connection manager configuration %+v", newConfigurable)
em.state.updateConfigurable(newConfigurable)
}
func (em *EdgeManager) newConnection(ctx context.Context) error {
edgeIP := em.serviceDiscoverer.Addr()
edgeConn, err := em.dialEdge(ctx, edgeIP)
if err != nil {
return errors.Wrap(err, "dial edge error")
}
configurable := em.state.getConfigurable()
// Establish a muxed connection with the edge
// Client mux handshake with agent server
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: configurable.Timeout,
Handler: em.streamHandler,
IsClient: true,
HeartbeatInterval: configurable.HeartbeatInterval,
MaxHeartbeats: configurable.MaxFailedHeartbeats,
Logger: em.logger.WithField("subsystem", "muxer"),
})
if err != nil {
return errors.Wrap(err, "handshake with edge error")
}
h2muxConn, err := newConnection(muxer, edgeIP)
if err != nil {
return errors.Wrap(err, "create h2mux connection error")
}
go em.serveConn(ctx, h2muxConn)
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
OriginCert: em.state.getUserCredential(),
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
NumPreviousAttempts: 0,
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
}, em.logger)
if err != nil {
h2muxConn.Shutdown()
return errors.Wrap(err, "connect with edge error")
}
if connErr := connResult.Err; connErr != nil {
if !connErr.ShouldRetry {
return errors.Wrap(connErr, em.noRetryMessage())
}
return errors.Wrapf(connErr, "server respond with retry at %v", connErr.RetryAfter)
}
em.state.newConnection(h2muxConn)
em.logger.Infof("connected to %s", connResult.ServerInfo.LocationName)
return nil
}
func (em *EdgeManager) closeConnection(ctx context.Context) error {
conn := em.state.getFirstConnection()
if conn == nil {
return fmt.Errorf("no connection to close")
}
conn.Shutdown()
return nil
}
func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
err := conn.Serve(ctx)
em.logger.WithError(err).Warn("Connection closed")
em.state.closeConnection(conn)
}
func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) {
timeout := em.state.getConfigurable().Timeout
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
defer dialCancel()
dialer := net.Dialer{DualStack: true}
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
}
tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig)
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
if err = tlsEdgeConn.Handshake(); err != nil {
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
tlsEdgeConn.SetDeadline(time.Time{})
return tlsEdgeConn, nil
}
func (em *EdgeManager) noRetryMessage() string {
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
"2. Your credential at %s is still valid. See %s."
return fmt.Sprintf(messageTemplate, quickStartLink, em.state.getConfigurable().UserCredentialPath, faqLink)
}
func (em *EdgeManager) shutdown() {
em.state.shutdown()
}
type edgeManagerState struct {
sync.RWMutex
configurable *EdgeManagerConfigurable
userCredential []byte
conns map[uuid.UUID]*Connection
}
func newEdgeConnectionManagerState(configurable *EdgeManagerConfigurable, userCredential []byte) *edgeManagerState {
return &edgeManagerState{
configurable: configurable,
userCredential: userCredential,
conns: make(map[uuid.UUID]*Connection),
}
}
func (ems *edgeManagerState) shouldCreateConnection(availableEdgeAddrs uint8) bool {
ems.RLock()
defer ems.RUnlock()
expectedHAConns := ems.configurable.NumHAConnections
if availableEdgeAddrs < expectedHAConns {
expectedHAConns = availableEdgeAddrs
}
return uint8(len(ems.conns)) < expectedHAConns
}
func (ems *edgeManagerState) shouldReduceConnection() bool {
ems.RLock()
defer ems.RUnlock()
return uint8(len(ems.conns)) > ems.configurable.NumHAConnections
}
func (ems *edgeManagerState) newConnection(conn *Connection) {
ems.Lock()
defer ems.Unlock()
ems.conns[conn.id] = conn
}
func (ems *edgeManagerState) closeConnection(conn *Connection) {
ems.Lock()
defer ems.Unlock()
delete(ems.conns, conn.id)
}
func (ems *edgeManagerState) getFirstConnection() *Connection {
ems.RLock()
defer ems.RUnlock()
for _, conn := range ems.conns {
return conn
}
return nil
}
func (ems *edgeManagerState) shutdown() {
ems.Lock()
defer ems.Unlock()
for _, conn := range ems.conns {
conn.Shutdown()
}
}
func (ems *edgeManagerState) getConfigurable() *EdgeManagerConfigurable {
ems.Lock()
defer ems.Unlock()
return ems.configurable
}
func (ems *edgeManagerState) updateConfigurable(newConfigurable *EdgeManagerConfigurable) {
ems.Lock()
defer ems.Unlock()
ems.configurable = newConfigurable
}
func (ems *edgeManagerState) getUserCredential() []byte {
ems.RLock()
defer ems.RUnlock()
return ems.userCredential
}

View File

@@ -0,0 +1,77 @@
package connection
import (
"testing"
"time"
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/sirupsen/logrus"
)
var (
configurable = &EdgeManagerConfigurable{
[]h2mux.TunnelHostname{
"http.example.com",
"ws.example.com",
"hello.example.com",
},
&pogs.EdgeConnectionConfig{
NumHAConnections: 1,
HeartbeatInterval: 1 * time.Second,
Timeout: 5 * time.Second,
MaxFailedHeartbeats: 3,
UserCredentialPath: "/etc/cloudflared/cert.pem",
},
}
cloudflaredConfig = &CloudflaredConfig{
CloudflaredID: uuid.New(),
Tags: []pogs.Tag{
{Name: "pool", Value: "east-6"},
},
BuildInfo: &buildinfo.BuildInfo{
GoOS: "linux",
GoVersion: "1.12",
GoArch: "amd64",
CloudflaredVersion: "2019.6.0",
},
}
)
type mockStreamHandler struct {
}
func (msh *mockStreamHandler) ServeStream(*h2mux.MuxedStream) error {
return nil
}
func mockEdgeManager() *EdgeManager {
return NewEdgeManager(
&mockStreamHandler{},
configurable,
[]byte{},
nil,
&mockEdgeServiceDiscoverer{},
cloudflaredConfig,
logrus.New(),
)
}
func TestUpdateConfigurable(t *testing.T) {
m := mockEdgeManager()
newConfigurable := &EdgeManagerConfigurable{
[]h2mux.TunnelHostname{
"second.example.com",
},
&pogs.EdgeConnectionConfig{
NumHAConnections: 2,
},
}
m.UpdateConfigurable(newConfigurable)
assert.Equal(t, newConfigurable, m.state.getConfigurable())
}

View File

@@ -1,158 +0,0 @@
package connection
import (
"context"
"net"
"time"
"github.com/cloudflare/cloudflared/streamhandler"
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
// Waiting time before retrying a failed tunnel connection
reconnectDuration = time.Second * 10
// SRV record resolution TTL
resolveTTL = time.Hour
// Interval between establishing new connection
connectionInterval = time.Second
)
type CloudflaredConfig struct {
ConnectionConfig *ConnectionConfig
OriginCert []byte
Tags []tunnelpogs.Tag
EdgeAddrs []string
HAConnections uint
Logger *logrus.Logger
CloudflaredVersion string
}
// Supervisor is a stateful object that manages connections with the edge
type Supervisor struct {
streamHandler *streamhandler.StreamHandler
newConfigChan chan<- *pogs.ClientConfig
useConfigResultChan <-chan *pogs.UseConfigurationResult
config *CloudflaredConfig
state *supervisorState
connErrors chan error
}
type supervisorState struct {
// IPs to connect to cloudflare's edge network
edgeIPs []*net.TCPAddr
// index of the next element to use in edgeIPs
nextEdgeIPIndex int
// last time edgeIPs were refreshed
lastResolveTime time.Time
// ID of this cloudflared instance
cloudflaredID uuid.UUID
// connectionPool is a pool of connectionHandlers that can be used to make RPCs
connectionPool *connectionPool
}
func (s *supervisorState) getNextEdgeIP() *net.TCPAddr {
ip := s.edgeIPs[s.nextEdgeIPIndex%len(s.edgeIPs)]
s.nextEdgeIPIndex++
return ip
}
func NewSupervisor(config *CloudflaredConfig) *Supervisor {
newConfigChan := make(chan *pogs.ClientConfig)
useConfigResultChan := make(chan *pogs.UseConfigurationResult)
return &Supervisor{
streamHandler: streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, config.Logger),
newConfigChan: newConfigChan,
useConfigResultChan: useConfigResultChan,
config: config,
state: &supervisorState{
connectionPool: &connectionPool{},
},
connErrors: make(chan error),
}
}
func (s *Supervisor) Run(ctx context.Context) error {
logger := s.config.Logger
if err := s.initialize(); err != nil {
logger.WithError(err).Error("Failed to get edge IPs")
return err
}
defer s.state.connectionPool.close()
var currentConnectionCount uint
expectedConnectionCount := s.config.HAConnections
if uint(len(s.state.edgeIPs)) < s.config.HAConnections {
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(s.state.edgeIPs))
expectedConnectionCount = uint(len(s.state.edgeIPs))
}
for {
select {
case <-ctx.Done():
return nil
case connErr := <-s.connErrors:
logger.WithError(connErr).Warnf("Connection dropped unexpectedly")
currentConnectionCount--
default:
time.Sleep(5 * time.Second)
}
if currentConnectionCount < expectedConnectionCount {
h, err := newH2MuxHandler(ctx, s.streamHandler, s.config.ConnectionConfig, s.state.getNextEdgeIP())
if err != nil {
logger.WithError(err).Error("Failed to create new connection handler")
continue
}
go func() {
s.connErrors <- h.serve(ctx)
}()
connResult, err := s.connect(ctx, s.config, s.state.cloudflaredID, h)
if err != nil {
logger.WithError(err).Errorf("Failed to connect to cloudflared's edge network")
h.shutdown()
continue
}
if connErr := connResult.Err; connErr != nil && !connErr.ShouldRetry {
logger.WithError(connErr).Errorf("Server respond with don't retry to connect")
h.shutdown()
return err
}
logger.Infof("Connected to %s", connResult.ServerInfo.LocationName)
s.state.connectionPool.put(h)
currentConnectionCount++
}
}
}
func (s *Supervisor) initialize() error {
edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs)
if err != nil {
return errors.Wrapf(err, "Failed to resolve cloudflare edge network address")
}
s.state.edgeIPs = edgeIPs
s.state.lastResolveTime = time.Now()
cloudflaredID, err := uuid.NewRandom()
if err != nil {
return errors.Wrap(err, "Failed to generate cloudflared ID")
}
s.state.cloudflaredID = cloudflaredID
return nil
}
func (s *Supervisor) connect(ctx context.Context,
config *CloudflaredConfig,
cloudflaredID uuid.UUID,
h connectionHandler,
) (*tunnelpogs.ConnectResult, error) {
connectParameters := &tunnelpogs.ConnectParameters{
OriginCert: config.OriginCert,
CloudflaredID: cloudflaredID,
NumPreviousAttempts: 0,
CloudflaredVersion: config.CloudflaredVersion,
}
return h.connect(ctx, connectParameters)
}