mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:59:58 +00:00
TUN-1961: Create EdgeConnectionManager to maintain outbound connections to the edge
This commit is contained in:
@@ -2,15 +2,14 @@ package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
@@ -18,7 +17,6 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
dialTimeout = 5 * time.Second
|
||||
openStreamTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
@@ -30,123 +28,54 @@ func (e dialError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
type muxerShutdownError struct{}
|
||||
|
||||
func (e muxerShutdownError) Error() string {
|
||||
return "muxer shutdown"
|
||||
type Connection struct {
|
||||
id uuid.UUID
|
||||
muxer *h2mux.Muxer
|
||||
}
|
||||
|
||||
type ConnectionConfig struct {
|
||||
TLSConfig *tls.Config
|
||||
HeartbeatInterval time.Duration
|
||||
MaxHeartbeats uint64
|
||||
Logger *logrus.Entry
|
||||
}
|
||||
|
||||
type connectionHandler interface {
|
||||
serve(ctx context.Context) error
|
||||
connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error)
|
||||
shutdown()
|
||||
}
|
||||
|
||||
type h2muxHandler struct {
|
||||
muxer *h2mux.Muxer
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func (h *h2muxHandler) serve(ctx context.Context) error {
|
||||
// Serve doesn't return until h2mux is shutdown
|
||||
if err := h.muxer.Serve(ctx); err != nil {
|
||||
return err
|
||||
func newConnection(muxer *h2mux.Muxer, edgeIP *net.TCPAddr) (*Connection, error) {
|
||||
id, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return muxerShutdownError{}
|
||||
return &Connection{
|
||||
id: id,
|
||||
muxer: muxer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Connection) Serve(ctx context.Context) error {
|
||||
// Serve doesn't return until h2mux is shutdown
|
||||
return c.muxer.Serve(ctx)
|
||||
}
|
||||
|
||||
// Connect is used to establish connections with cloudflare's edge network
|
||||
func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) {
|
||||
func (c *Connection) Connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters, logger *logrus.Entry) (*pogs.ConnectResult, error) {
|
||||
openStreamCtx, cancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||
defer cancel()
|
||||
conn, err := h.newRPConn(openStreamCtx)
|
||||
|
||||
rpcConn, err := c.newRPConn(openStreamCtx, logger)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to create new RPC connection")
|
||||
return nil, errors.Wrap(err, "cannot create new RPC connection")
|
||||
}
|
||||
defer conn.Close()
|
||||
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
|
||||
defer rpcConn.Close()
|
||||
|
||||
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: rpcConn.Bootstrap(ctx)}
|
||||
|
||||
return tsClient.Connect(ctx, parameters)
|
||||
}
|
||||
|
||||
func (h *h2muxHandler) shutdown() {
|
||||
h.muxer.Shutdown()
|
||||
func (c *Connection) Shutdown() {
|
||||
c.muxer.Shutdown()
|
||||
}
|
||||
|
||||
func (h *h2muxHandler) newRPConn(ctx context.Context) (*rpc.Conn, error) {
|
||||
stream, err := h.muxer.OpenRPCStream(ctx)
|
||||
func (c *Connection) newRPConn(ctx context.Context, logger *logrus.Entry) (*rpc.Conn, error) {
|
||||
stream, err := c.muxer.OpenRPCStream(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rpc.NewConn(
|
||||
tunnelrpc.NewTransportLogger(h.logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(h.logger.WithField("subsystem", "rpc-transport")),
|
||||
tunnelrpc.NewTransportLogger(logger.WithField("rpc", "connect"), rpc.StreamTransport(stream)),
|
||||
tunnelrpc.ConnLog(logger.WithField("rpc", "connect")),
|
||||
), nil
|
||||
}
|
||||
|
||||
// NewConnectionHandler returns a connectionHandler, wrapping h2mux to make RPC calls
|
||||
func newH2MuxHandler(ctx context.Context,
|
||||
streamHandler *streamhandler.StreamHandler,
|
||||
config *ConnectionConfig,
|
||||
edgeIP *net.TCPAddr,
|
||||
) (connectionHandler, error) {
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
|
||||
defer dialCancel()
|
||||
dialer := net.Dialer{DualStack: true}
|
||||
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
|
||||
if err != nil {
|
||||
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
|
||||
}
|
||||
edgeConn := tls.Client(plaintextEdgeConn, config.TLSConfig)
|
||||
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
|
||||
err = edgeConn.Handshake()
|
||||
if err != nil {
|
||||
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
edgeConn.SetDeadline(time.Time{})
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
Timeout: dialTimeout,
|
||||
Handler: streamHandler,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: config.HeartbeatInterval,
|
||||
MaxHeartbeats: config.MaxHeartbeats,
|
||||
Logger: config.Logger,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &h2muxHandler{
|
||||
muxer: muxer,
|
||||
logger: config.Logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// connectionPool is a pool of connection handlers
|
||||
type connectionPool struct {
|
||||
sync.Mutex
|
||||
connectionHandlers []connectionHandler
|
||||
}
|
||||
|
||||
func (cp *connectionPool) put(h connectionHandler) {
|
||||
cp.Lock()
|
||||
defer cp.Unlock()
|
||||
cp.connectionHandlers = append(cp.connectionHandlers, h)
|
||||
}
|
||||
|
||||
func (cp *connectionPool) close() {
|
||||
cp.Lock()
|
||||
defer cp.Unlock()
|
||||
for _, h := range cp.connectionHandlers {
|
||||
h.shutdown()
|
||||
}
|
||||
}
|
||||
|
@@ -5,10 +5,11 @@ import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -22,6 +23,9 @@ const (
|
||||
dotServerName = "cloudflare-dns.com"
|
||||
dotServerAddr = "1.1.1.1:853"
|
||||
dotTimeout = time.Duration(15 * time.Second)
|
||||
|
||||
// SRV record resolution TTL
|
||||
resolveEdgeAddrTTL = 1 * time.Hour
|
||||
)
|
||||
|
||||
var friendlyDNSErrorLines = []string{
|
||||
@@ -34,20 +38,65 @@ var friendlyDNSErrorLines = []string{
|
||||
` https://developers.cloudflare.com/1.1.1.1/setting-up-1.1.1.1/`,
|
||||
}
|
||||
|
||||
func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, error) {
|
||||
if len(addresses) > 0 {
|
||||
var tcpAddrs []*net.TCPAddr
|
||||
for _, address := range addresses {
|
||||
// Addresses specified (for testing, usually)
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcpAddrs = append(tcpAddrs, tcpAddr)
|
||||
}
|
||||
return tcpAddrs, nil
|
||||
// EdgeServiceDiscoverer is an interface for looking up Cloudflare's edge network addresses
|
||||
type EdgeServiceDiscoverer interface {
|
||||
// Addr returns an address to connect to cloudflare's edge network
|
||||
Addr() *net.TCPAddr
|
||||
// AvailableAddrs returns the number of unique addresses
|
||||
AvailableAddrs() uint8
|
||||
// Refresh rediscover Cloudflare's edge network addresses
|
||||
Refresh() error
|
||||
}
|
||||
|
||||
// EdgeAddrResolver discovers the addresses of Cloudflare's edge network through SRV record.
|
||||
// It implements EdgeServiceDiscoverer interface
|
||||
type EdgeAddrResolver struct {
|
||||
sync.Mutex
|
||||
// Addrs to connect to cloudflare's edge network
|
||||
addrs []*net.TCPAddr
|
||||
// index of the next element to use in addrs
|
||||
nextAddrIndex int
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
func NewEdgeAddrResolver(logger *logrus.Logger) (EdgeServiceDiscoverer, error) {
|
||||
r := &EdgeAddrResolver{
|
||||
logger: logger.WithField("subsystem", " edgeAddrResolver"),
|
||||
}
|
||||
// HA service discovery lookup
|
||||
if err := r.Refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *EdgeAddrResolver) Addr() *net.TCPAddr {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
addr := r.addrs[r.nextAddrIndex]
|
||||
r.nextAddrIndex = (r.nextAddrIndex + 1) % len(r.addrs)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (r *EdgeAddrResolver) AvailableAddrs() uint8 {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return uint8(len(r.addrs))
|
||||
}
|
||||
|
||||
func (r *EdgeAddrResolver) Refresh() error {
|
||||
newAddrs, err := EdgeDiscovery(r.logger)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.addrs = newAddrs
|
||||
r.nextAddrIndex = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// HA service discovery lookup
|
||||
func EdgeDiscovery(logger *logrus.Entry) ([]*net.TCPAddr, error) {
|
||||
_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
|
||||
if err != nil {
|
||||
// Try to fall back to DoT from Cloudflare directly.
|
||||
@@ -78,7 +127,7 @@ func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, err
|
||||
var resolvedIPsPerCNAME [][]*net.TCPAddr
|
||||
var lookupErr error
|
||||
for _, addr := range addrs {
|
||||
ips, err := ResolveSRVToTCP(addr)
|
||||
ips, err := resolveSRVToTCP(addr)
|
||||
if err != nil || len(ips) == 0 {
|
||||
// don't return early, we might be able to resolve other addresses
|
||||
lookupErr = err
|
||||
@@ -86,14 +135,14 @@ func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, err
|
||||
}
|
||||
resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
|
||||
}
|
||||
ips := FlattenServiceIPs(resolvedIPsPerCNAME)
|
||||
ips := flattenServiceIPs(resolvedIPsPerCNAME)
|
||||
if lookupErr == nil && len(ips) == 0 {
|
||||
return nil, fmt.Errorf("Unknown service discovery error")
|
||||
}
|
||||
return ips, lookupErr
|
||||
}
|
||||
|
||||
func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
|
||||
func resolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
|
||||
ips, err := net.LookupIP(srv.Target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -107,7 +156,7 @@ func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
|
||||
|
||||
// FlattenServiceIPs transposes and flattens the input slices such that the
|
||||
// first element of the n inner slices are the first n elements of the result.
|
||||
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
|
||||
func flattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
|
||||
var result []*net.TCPAddr
|
||||
for len(ipsByService) > 0 {
|
||||
filtered := ipsByService[:0]
|
||||
@@ -141,3 +190,65 @@ func fallbackResolver(serverName, serverAddress string) *net.Resolver {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// EdgeHostnameResolver discovers the addresses of Cloudflare's edge network via a list of server hostnames.
|
||||
// It implements EdgeServiceDiscoverer interface, and is used mainly for testing connectivity.
|
||||
type EdgeHostnameResolver struct {
|
||||
sync.Mutex
|
||||
// hostnames of edge servers
|
||||
hostnames []string
|
||||
// Addrs to connect to cloudflare's edge network
|
||||
addrs []*net.TCPAddr
|
||||
// index of the next element to use in addrs
|
||||
nextAddrIndex int
|
||||
}
|
||||
|
||||
func NewEdgeHostnameResolver(edgeHostnames []string) (EdgeServiceDiscoverer, error) {
|
||||
r := &EdgeHostnameResolver{
|
||||
hostnames: edgeHostnames,
|
||||
}
|
||||
if err := r.Refresh(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *EdgeHostnameResolver) Addr() *net.TCPAddr {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
addr := r.addrs[r.nextAddrIndex]
|
||||
r.nextAddrIndex = (r.nextAddrIndex + 1) % len(r.addrs)
|
||||
return addr
|
||||
}
|
||||
|
||||
func (r *EdgeHostnameResolver) AvailableAddrs() uint8 {
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
return uint8(len(r.addrs))
|
||||
}
|
||||
|
||||
func (r *EdgeHostnameResolver) Refresh() error {
|
||||
newAddrs, err := ResolveAddrs(r.hostnames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.Lock()
|
||||
defer r.Unlock()
|
||||
r.addrs = newAddrs
|
||||
r.nextAddrIndex = 0
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resolve TCP address given a list of addresses. Address can be a hostname, however, it will return at most one
|
||||
// of the hostname's IP addresses
|
||||
func ResolveAddrs(addrs []string) ([]*net.TCPAddr, error) {
|
||||
var tcpAddrs []*net.TCPAddr
|
||||
for _, addr := range addrs {
|
||||
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcpAddrs = append(tcpAddrs, tcpAddr)
|
||||
}
|
||||
return tcpAddrs, nil
|
||||
}
|
||||
|
@@ -7,8 +7,26 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type mockEdgeServiceDiscoverer struct {
|
||||
}
|
||||
|
||||
func (mr *mockEdgeServiceDiscoverer) Addr() *net.TCPAddr {
|
||||
return &net.TCPAddr{
|
||||
IP: net.ParseIP("127.0.0.1"),
|
||||
Port: 63102,
|
||||
}
|
||||
}
|
||||
|
||||
func (mr *mockEdgeServiceDiscoverer) AvailableAddrs() uint8 {
|
||||
return 1
|
||||
}
|
||||
|
||||
func (mr *mockEdgeServiceDiscoverer) Refresh() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestFlattenServiceIPs(t *testing.T) {
|
||||
result := FlattenServiceIPs([][]*net.TCPAddr{
|
||||
result := flattenServiceIPs([][]*net.TCPAddr{
|
||||
[]*net.TCPAddr{
|
||||
&net.TCPAddr{Port: 1},
|
||||
&net.TCPAddr{Port: 2},
|
||||
|
281
connection/manager.go
Normal file
281
connection/manager.go
Normal file
@@ -0,0 +1,281 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
quickStartLink = "https://developers.cloudflare.com/argo-tunnel/quickstart/"
|
||||
faqLink = "https://developers.cloudflare.com/argo-tunnel/faq/"
|
||||
)
|
||||
|
||||
// EdgeManager manages connections with the edge
|
||||
type EdgeManager struct {
|
||||
// streamHandler handles stream opened by the edge
|
||||
streamHandler h2mux.MuxedStreamHandler
|
||||
// TLSConfig is the TLS configuration to connect with edge
|
||||
tlsConfig *tls.Config
|
||||
// cloudflaredConfig is the cloudflared configuration that is determined when the process first starts
|
||||
cloudflaredConfig *CloudflaredConfig
|
||||
// serviceDiscoverer returns the next edge addr to connect to
|
||||
serviceDiscoverer EdgeServiceDiscoverer
|
||||
// state is attributes of ConnectionManager that can change during runtime.
|
||||
state *edgeManagerState
|
||||
|
||||
logger *logrus.Entry
|
||||
}
|
||||
|
||||
// EdgeConnectionManagerConfigurable is the configurable attributes of a EdgeConnectionManager
|
||||
type EdgeManagerConfigurable struct {
|
||||
TunnelHostnames []h2mux.TunnelHostname
|
||||
*pogs.EdgeConnectionConfig
|
||||
}
|
||||
|
||||
type CloudflaredConfig struct {
|
||||
CloudflaredID uuid.UUID
|
||||
Tags []pogs.Tag
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
}
|
||||
|
||||
func NewEdgeManager(
|
||||
streamHandler h2mux.MuxedStreamHandler,
|
||||
edgeConnMgrConfigurable *EdgeManagerConfigurable,
|
||||
userCredential []byte,
|
||||
tlsConfig *tls.Config,
|
||||
serviceDiscoverer EdgeServiceDiscoverer,
|
||||
cloudflaredConfig *CloudflaredConfig,
|
||||
logger *logrus.Logger,
|
||||
) *EdgeManager {
|
||||
return &EdgeManager{
|
||||
streamHandler: streamHandler,
|
||||
tlsConfig: tlsConfig,
|
||||
cloudflaredConfig: cloudflaredConfig,
|
||||
serviceDiscoverer: serviceDiscoverer,
|
||||
state: newEdgeConnectionManagerState(edgeConnMgrConfigurable, userCredential),
|
||||
logger: logger.WithField("subsystem", "connectionManager"),
|
||||
}
|
||||
}
|
||||
|
||||
func (em *EdgeManager) Run(ctx context.Context) error {
|
||||
defer em.shutdown()
|
||||
|
||||
resolveEdgeIPTicker := time.Tick(resolveEdgeAddrTTL)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.Wrap(ctx.Err(), "EdgeConnectionManager terminated")
|
||||
case <-resolveEdgeIPTicker:
|
||||
if err := em.serviceDiscoverer.Refresh(); err != nil {
|
||||
em.logger.WithError(err).Warn("Cannot refresh Cloudflare edge addresses")
|
||||
}
|
||||
default:
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
// Create/delete connection one at a time, so we don't need to adjust for connections that are being created/deleted
|
||||
// in shouldCreateConnection or shouldReduceConnection calculation
|
||||
if em.state.shouldCreateConnection(em.serviceDiscoverer.AvailableAddrs()) {
|
||||
if err := em.newConnection(ctx); err != nil {
|
||||
em.logger.WithError(err).Error("cannot create new connection")
|
||||
}
|
||||
} else if em.state.shouldReduceConnection() {
|
||||
if err := em.closeConnection(ctx); err != nil {
|
||||
em.logger.WithError(err).Error("cannot close connection")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (em *EdgeManager) UpdateConfigurable(newConfigurable *EdgeManagerConfigurable) {
|
||||
em.logger.Infof("New edge connection manager configuration %+v", newConfigurable)
|
||||
em.state.updateConfigurable(newConfigurable)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) newConnection(ctx context.Context) error {
|
||||
edgeIP := em.serviceDiscoverer.Addr()
|
||||
edgeConn, err := em.dialEdge(ctx, edgeIP)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "dial edge error")
|
||||
}
|
||||
configurable := em.state.getConfigurable()
|
||||
// Establish a muxed connection with the edge
|
||||
// Client mux handshake with agent server
|
||||
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
|
||||
Timeout: configurable.Timeout,
|
||||
Handler: em.streamHandler,
|
||||
IsClient: true,
|
||||
HeartbeatInterval: configurable.HeartbeatInterval,
|
||||
MaxHeartbeats: configurable.MaxFailedHeartbeats,
|
||||
Logger: em.logger.WithField("subsystem", "muxer"),
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "handshake with edge error")
|
||||
}
|
||||
|
||||
h2muxConn, err := newConnection(muxer, edgeIP)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create h2mux connection error")
|
||||
}
|
||||
|
||||
go em.serveConn(ctx, h2muxConn)
|
||||
|
||||
connResult, err := h2muxConn.Connect(ctx, &pogs.ConnectParameters{
|
||||
OriginCert: em.state.getUserCredential(),
|
||||
CloudflaredID: em.cloudflaredConfig.CloudflaredID,
|
||||
NumPreviousAttempts: 0,
|
||||
CloudflaredVersion: em.cloudflaredConfig.BuildInfo.CloudflaredVersion,
|
||||
}, em.logger)
|
||||
if err != nil {
|
||||
h2muxConn.Shutdown()
|
||||
return errors.Wrap(err, "connect with edge error")
|
||||
}
|
||||
|
||||
if connErr := connResult.Err; connErr != nil {
|
||||
if !connErr.ShouldRetry {
|
||||
return errors.Wrap(connErr, em.noRetryMessage())
|
||||
}
|
||||
return errors.Wrapf(connErr, "server respond with retry at %v", connErr.RetryAfter)
|
||||
}
|
||||
|
||||
em.state.newConnection(h2muxConn)
|
||||
em.logger.Infof("connected to %s", connResult.ServerInfo.LocationName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (em *EdgeManager) closeConnection(ctx context.Context) error {
|
||||
conn := em.state.getFirstConnection()
|
||||
if conn == nil {
|
||||
return fmt.Errorf("no connection to close")
|
||||
}
|
||||
conn.Shutdown()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (em *EdgeManager) serveConn(ctx context.Context, conn *Connection) {
|
||||
err := conn.Serve(ctx)
|
||||
em.logger.WithError(err).Warn("Connection closed")
|
||||
em.state.closeConnection(conn)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) dialEdge(ctx context.Context, edgeIP *net.TCPAddr) (*tls.Conn, error) {
|
||||
timeout := em.state.getConfigurable().Timeout
|
||||
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
|
||||
dialCtx, dialCancel := context.WithTimeout(ctx, timeout)
|
||||
defer dialCancel()
|
||||
|
||||
dialer := net.Dialer{DualStack: true}
|
||||
edgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
|
||||
if err != nil {
|
||||
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
|
||||
}
|
||||
tlsEdgeConn := tls.Client(edgeConn, em.tlsConfig)
|
||||
tlsEdgeConn.SetDeadline(time.Now().Add(timeout))
|
||||
|
||||
if err = tlsEdgeConn.Handshake(); err != nil {
|
||||
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
|
||||
}
|
||||
// clear the deadline on the conn; h2mux has its own timeouts
|
||||
tlsEdgeConn.SetDeadline(time.Time{})
|
||||
return tlsEdgeConn, nil
|
||||
}
|
||||
|
||||
func (em *EdgeManager) noRetryMessage() string {
|
||||
messageTemplate := "cloudflared could not register an Argo Tunnel on your account. Please confirm the following before trying again:" +
|
||||
"1. You have Argo Smart Routing enabled in your account, See Enable Argo section of %s." +
|
||||
"2. Your credential at %s is still valid. See %s."
|
||||
return fmt.Sprintf(messageTemplate, quickStartLink, em.state.getConfigurable().UserCredentialPath, faqLink)
|
||||
}
|
||||
|
||||
func (em *EdgeManager) shutdown() {
|
||||
em.state.shutdown()
|
||||
}
|
||||
|
||||
type edgeManagerState struct {
|
||||
sync.RWMutex
|
||||
configurable *EdgeManagerConfigurable
|
||||
userCredential []byte
|
||||
conns map[uuid.UUID]*Connection
|
||||
}
|
||||
|
||||
func newEdgeConnectionManagerState(configurable *EdgeManagerConfigurable, userCredential []byte) *edgeManagerState {
|
||||
return &edgeManagerState{
|
||||
configurable: configurable,
|
||||
userCredential: userCredential,
|
||||
conns: make(map[uuid.UUID]*Connection),
|
||||
}
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) shouldCreateConnection(availableEdgeAddrs uint8) bool {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
expectedHAConns := ems.configurable.NumHAConnections
|
||||
if availableEdgeAddrs < expectedHAConns {
|
||||
expectedHAConns = availableEdgeAddrs
|
||||
}
|
||||
return uint8(len(ems.conns)) < expectedHAConns
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) shouldReduceConnection() bool {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
return uint8(len(ems.conns)) > ems.configurable.NumHAConnections
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) newConnection(conn *Connection) {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
ems.conns[conn.id] = conn
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) closeConnection(conn *Connection) {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
delete(ems.conns, conn.id)
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) getFirstConnection() *Connection {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
|
||||
for _, conn := range ems.conns {
|
||||
return conn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) shutdown() {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
for _, conn := range ems.conns {
|
||||
conn.Shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) getConfigurable() *EdgeManagerConfigurable {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
return ems.configurable
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) updateConfigurable(newConfigurable *EdgeManagerConfigurable) {
|
||||
ems.Lock()
|
||||
defer ems.Unlock()
|
||||
ems.configurable = newConfigurable
|
||||
}
|
||||
|
||||
func (ems *edgeManagerState) getUserCredential() []byte {
|
||||
ems.RLock()
|
||||
defer ems.RUnlock()
|
||||
return ems.userCredential
|
||||
}
|
77
connection/manager_test.go
Normal file
77
connection/manager_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cmd/cloudflared/buildinfo"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/google/uuid"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
configurable = &EdgeManagerConfigurable{
|
||||
[]h2mux.TunnelHostname{
|
||||
"http.example.com",
|
||||
"ws.example.com",
|
||||
"hello.example.com",
|
||||
},
|
||||
&pogs.EdgeConnectionConfig{
|
||||
NumHAConnections: 1,
|
||||
HeartbeatInterval: 1 * time.Second,
|
||||
Timeout: 5 * time.Second,
|
||||
MaxFailedHeartbeats: 3,
|
||||
UserCredentialPath: "/etc/cloudflared/cert.pem",
|
||||
},
|
||||
}
|
||||
cloudflaredConfig = &CloudflaredConfig{
|
||||
CloudflaredID: uuid.New(),
|
||||
Tags: []pogs.Tag{
|
||||
{Name: "pool", Value: "east-6"},
|
||||
},
|
||||
BuildInfo: &buildinfo.BuildInfo{
|
||||
GoOS: "linux",
|
||||
GoVersion: "1.12",
|
||||
GoArch: "amd64",
|
||||
CloudflaredVersion: "2019.6.0",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type mockStreamHandler struct {
|
||||
}
|
||||
|
||||
func (msh *mockStreamHandler) ServeStream(*h2mux.MuxedStream) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func mockEdgeManager() *EdgeManager {
|
||||
return NewEdgeManager(
|
||||
&mockStreamHandler{},
|
||||
configurable,
|
||||
[]byte{},
|
||||
nil,
|
||||
&mockEdgeServiceDiscoverer{},
|
||||
cloudflaredConfig,
|
||||
logrus.New(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestUpdateConfigurable(t *testing.T) {
|
||||
m := mockEdgeManager()
|
||||
newConfigurable := &EdgeManagerConfigurable{
|
||||
[]h2mux.TunnelHostname{
|
||||
"second.example.com",
|
||||
},
|
||||
&pogs.EdgeConnectionConfig{
|
||||
NumHAConnections: 2,
|
||||
},
|
||||
}
|
||||
m.UpdateConfigurable(newConfigurable)
|
||||
|
||||
assert.Equal(t, newConfigurable, m.state.getConfigurable())
|
||||
}
|
@@ -1,158 +0,0 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/streamhandler"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// Waiting time before retrying a failed tunnel connection
|
||||
reconnectDuration = time.Second * 10
|
||||
// SRV record resolution TTL
|
||||
resolveTTL = time.Hour
|
||||
// Interval between establishing new connection
|
||||
connectionInterval = time.Second
|
||||
)
|
||||
|
||||
type CloudflaredConfig struct {
|
||||
ConnectionConfig *ConnectionConfig
|
||||
OriginCert []byte
|
||||
Tags []tunnelpogs.Tag
|
||||
EdgeAddrs []string
|
||||
HAConnections uint
|
||||
Logger *logrus.Logger
|
||||
CloudflaredVersion string
|
||||
}
|
||||
|
||||
// Supervisor is a stateful object that manages connections with the edge
|
||||
type Supervisor struct {
|
||||
streamHandler *streamhandler.StreamHandler
|
||||
newConfigChan chan<- *pogs.ClientConfig
|
||||
useConfigResultChan <-chan *pogs.UseConfigurationResult
|
||||
config *CloudflaredConfig
|
||||
state *supervisorState
|
||||
connErrors chan error
|
||||
}
|
||||
|
||||
type supervisorState struct {
|
||||
// IPs to connect to cloudflare's edge network
|
||||
edgeIPs []*net.TCPAddr
|
||||
// index of the next element to use in edgeIPs
|
||||
nextEdgeIPIndex int
|
||||
// last time edgeIPs were refreshed
|
||||
lastResolveTime time.Time
|
||||
// ID of this cloudflared instance
|
||||
cloudflaredID uuid.UUID
|
||||
// connectionPool is a pool of connectionHandlers that can be used to make RPCs
|
||||
connectionPool *connectionPool
|
||||
}
|
||||
|
||||
func (s *supervisorState) getNextEdgeIP() *net.TCPAddr {
|
||||
ip := s.edgeIPs[s.nextEdgeIPIndex%len(s.edgeIPs)]
|
||||
s.nextEdgeIPIndex++
|
||||
return ip
|
||||
}
|
||||
|
||||
func NewSupervisor(config *CloudflaredConfig) *Supervisor {
|
||||
newConfigChan := make(chan *pogs.ClientConfig)
|
||||
useConfigResultChan := make(chan *pogs.UseConfigurationResult)
|
||||
return &Supervisor{
|
||||
streamHandler: streamhandler.NewStreamHandler(newConfigChan, useConfigResultChan, config.Logger),
|
||||
newConfigChan: newConfigChan,
|
||||
useConfigResultChan: useConfigResultChan,
|
||||
config: config,
|
||||
state: &supervisorState{
|
||||
connectionPool: &connectionPool{},
|
||||
},
|
||||
connErrors: make(chan error),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) Run(ctx context.Context) error {
|
||||
logger := s.config.Logger
|
||||
if err := s.initialize(); err != nil {
|
||||
logger.WithError(err).Error("Failed to get edge IPs")
|
||||
return err
|
||||
}
|
||||
defer s.state.connectionPool.close()
|
||||
|
||||
var currentConnectionCount uint
|
||||
expectedConnectionCount := s.config.HAConnections
|
||||
if uint(len(s.state.edgeIPs)) < s.config.HAConnections {
|
||||
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(s.state.edgeIPs))
|
||||
expectedConnectionCount = uint(len(s.state.edgeIPs))
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case connErr := <-s.connErrors:
|
||||
logger.WithError(connErr).Warnf("Connection dropped unexpectedly")
|
||||
currentConnectionCount--
|
||||
default:
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
if currentConnectionCount < expectedConnectionCount {
|
||||
h, err := newH2MuxHandler(ctx, s.streamHandler, s.config.ConnectionConfig, s.state.getNextEdgeIP())
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to create new connection handler")
|
||||
continue
|
||||
}
|
||||
go func() {
|
||||
s.connErrors <- h.serve(ctx)
|
||||
}()
|
||||
connResult, err := s.connect(ctx, s.config, s.state.cloudflaredID, h)
|
||||
if err != nil {
|
||||
logger.WithError(err).Errorf("Failed to connect to cloudflared's edge network")
|
||||
h.shutdown()
|
||||
continue
|
||||
}
|
||||
if connErr := connResult.Err; connErr != nil && !connErr.ShouldRetry {
|
||||
logger.WithError(connErr).Errorf("Server respond with don't retry to connect")
|
||||
h.shutdown()
|
||||
return err
|
||||
}
|
||||
logger.Infof("Connected to %s", connResult.ServerInfo.LocationName)
|
||||
s.state.connectionPool.put(h)
|
||||
currentConnectionCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Supervisor) initialize() error {
|
||||
edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "Failed to resolve cloudflare edge network address")
|
||||
}
|
||||
s.state.edgeIPs = edgeIPs
|
||||
s.state.lastResolveTime = time.Now()
|
||||
cloudflaredID, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Failed to generate cloudflared ID")
|
||||
}
|
||||
s.state.cloudflaredID = cloudflaredID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Supervisor) connect(ctx context.Context,
|
||||
config *CloudflaredConfig,
|
||||
cloudflaredID uuid.UUID,
|
||||
h connectionHandler,
|
||||
) (*tunnelpogs.ConnectResult, error) {
|
||||
connectParameters := &tunnelpogs.ConnectParameters{
|
||||
OriginCert: config.OriginCert,
|
||||
CloudflaredID: cloudflaredID,
|
||||
NumPreviousAttempts: 0,
|
||||
CloudflaredVersion: config.CloudflaredVersion,
|
||||
}
|
||||
return h.connect(ctx, connectParameters)
|
||||
}
|
Reference in New Issue
Block a user