mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:29:58 +00:00
TUN-3007: Implement named tunnel connection registration and unregistration.
Removed flag for using quick reconnect, this logic is now always enabled.
This commit is contained in:
@@ -68,8 +68,6 @@ type Supervisor struct {
|
||||
connDigest map[uint8][]byte
|
||||
|
||||
bufferPool *buffer.Pool
|
||||
|
||||
namedTunnel *NamedTunnelConfig
|
||||
}
|
||||
|
||||
type resolveResult struct {
|
||||
@@ -82,7 +80,7 @@ type tunnelError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel *NamedTunnelConfig) (*Supervisor, error) {
|
||||
func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID) (*Supervisor, error) {
|
||||
var (
|
||||
edgeIPs *edgediscovery.Edge
|
||||
err error
|
||||
@@ -95,6 +93,7 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Supervisor{
|
||||
cloudflaredUUID: cloudflaredUUID,
|
||||
config: config,
|
||||
@@ -104,7 +103,6 @@ func NewSupervisor(config *TunnelConfig, cloudflaredUUID uuid.UUID, namedTunnel
|
||||
logger: config.Logger,
|
||||
connDigest: make(map[uint8][]byte),
|
||||
bufferPool: buffer.NewPool(512 * 1024),
|
||||
namedTunnel: namedTunnel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -229,17 +227,17 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
addr *net.TCPAddr
|
||||
err error
|
||||
)
|
||||
const thisConnID = 0
|
||||
const firstConnIndex = 0
|
||||
defer func() {
|
||||
s.tunnelErrors <- tunnelError{index: thisConnID, addr: addr, err: err}
|
||||
s.tunnelErrors <- tunnelError{index: firstConnIndex, addr: addr, err: err}
|
||||
}()
|
||||
|
||||
addr, err = s.edgeIPs.GetAddr(thisConnID)
|
||||
addr, err = s.edgeIPs.GetAddr(firstConnIndex)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
// If the first tunnel disconnects, keep restarting it.
|
||||
edgeErrors := 0
|
||||
for s.unusedIPs() {
|
||||
@@ -257,12 +255,12 @@ func (s *Supervisor) startFirstTunnel(ctx context.Context, connectedSignal *sign
|
||||
return
|
||||
}
|
||||
if edgeErrors >= 2 {
|
||||
addr, err = s.edgeIPs.GetDifferentAddr(thisConnID)
|
||||
addr, err = s.edgeIPs.GetDifferentAddr(firstConnIndex)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, thisConnID, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
err = ServeTunnelLoop(ctx, s, s.config, addr, firstConnIndex, connectedSignal, s.cloudflaredUUID, s.bufferPool, reconnectCh)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -48,7 +48,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
@@ -92,7 +92,7 @@ func TestRefreshAuthSuccess(t *testing.T) {
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
@@ -120,7 +120,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
|
||||
return time.After(d)
|
||||
}
|
||||
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
@@ -142,7 +142,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
|
||||
func TestRefreshAuthFail(t *testing.T) {
|
||||
logger := logger.NewOutputWriter(logger.NewMockWriteManager())
|
||||
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New(), nil)
|
||||
s, err := NewSupervisor(testConfig(logger), uuid.New())
|
||||
if !assert.NoError(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
227
origin/tunnel.go
227
origin/tunnel.go
@@ -48,47 +48,46 @@ type registerRPCName string
|
||||
const (
|
||||
register registerRPCName = "register"
|
||||
reconnect registerRPCName = "reconnect"
|
||||
unknown registerRPCName = "unknown"
|
||||
)
|
||||
|
||||
type TunnelConfig struct {
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
ClientID string
|
||||
ClientTlsConfig *tls.Config
|
||||
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
||||
CompressionQuality uint64
|
||||
EdgeAddrs []string
|
||||
GracePeriod time.Duration
|
||||
HAConnections int
|
||||
HTTPTransport http.RoundTripper
|
||||
HeartbeatInterval time.Duration
|
||||
Hostname string
|
||||
HTTPHostHeader string
|
||||
IncidentLookup IncidentLookup
|
||||
IsAutoupdated bool
|
||||
IsFreeTunnel bool
|
||||
LBPool string
|
||||
Logger logger.Service
|
||||
TransportLogger logger.Service
|
||||
MaxHeartbeats uint64
|
||||
Metrics *TunnelMetrics
|
||||
MetricsUpdateFreq time.Duration
|
||||
NoChunkedEncoding bool
|
||||
OriginCert []byte
|
||||
ReportedVersion string
|
||||
Retries uint
|
||||
RunFromTerminal bool
|
||||
Tags []tunnelpogs.Tag
|
||||
TlsConfig *tls.Config
|
||||
UseDeclarativeTunnel bool
|
||||
WSGI bool
|
||||
BuildInfo *buildinfo.BuildInfo
|
||||
ClientID string
|
||||
ClientTlsConfig *tls.Config
|
||||
CloseConnOnce *sync.Once // Used to close connectedSignal no more than once
|
||||
CompressionQuality uint64
|
||||
EdgeAddrs []string
|
||||
GracePeriod time.Duration
|
||||
HAConnections int
|
||||
HTTPTransport http.RoundTripper
|
||||
HeartbeatInterval time.Duration
|
||||
Hostname string
|
||||
HTTPHostHeader string
|
||||
IncidentLookup IncidentLookup
|
||||
IsAutoupdated bool
|
||||
IsFreeTunnel bool
|
||||
LBPool string
|
||||
Logger logger.Service
|
||||
TransportLogger logger.Service
|
||||
MaxHeartbeats uint64
|
||||
Metrics *TunnelMetrics
|
||||
MetricsUpdateFreq time.Duration
|
||||
NoChunkedEncoding bool
|
||||
OriginCert []byte
|
||||
ReportedVersion string
|
||||
Retries uint
|
||||
RunFromTerminal bool
|
||||
Tags []tunnelpogs.Tag
|
||||
TlsConfig *tls.Config
|
||||
WSGI bool
|
||||
// OriginUrl may not be used if a user specifies a unix socket.
|
||||
OriginUrl string
|
||||
|
||||
// feature-flag to use new edge reconnect tokens
|
||||
UseReconnectToken bool
|
||||
// feature-flag for using ConnectionDigest
|
||||
UseQuickReconnects bool
|
||||
|
||||
NamedTunnel *NamedTunnelConfig
|
||||
ReplaceExisting bool
|
||||
}
|
||||
|
||||
// ReconnectTunnelCredentialManager is invoked by functions in this file to
|
||||
@@ -103,6 +102,8 @@ type ReconnectTunnelCredentialManager interface {
|
||||
|
||||
type dupConnRegisterTunnelError struct{}
|
||||
|
||||
var errDuplicationConnection = &dupConnRegisterTunnelError{}
|
||||
|
||||
func (e dupConnRegisterTunnelError) Error() string {
|
||||
return "already connected to this server"
|
||||
}
|
||||
@@ -171,21 +172,35 @@ func (c *TunnelConfig) RegistrationOptions(connectionID uint8, OriginLocalIP str
|
||||
}
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) SupportedFeatures() []string {
|
||||
basic := []string{FeatureSerializedHeaders}
|
||||
if c.UseQuickReconnects {
|
||||
basic = append(basic, FeatureQuickReconnects)
|
||||
func (c *TunnelConfig) ConnectionOptions(originLocalAddr string) *tunnelpogs.ConnectionOptions {
|
||||
// attempt to parse out origin IP, but don't fail since it's informational field
|
||||
host, _, _ := net.SplitHostPort(originLocalAddr)
|
||||
originIP := net.ParseIP(host)
|
||||
|
||||
return &tunnelpogs.ConnectionOptions{
|
||||
Client: c.NamedTunnel.Client,
|
||||
OriginLocalIP: originIP,
|
||||
ReplaceExisting: c.ReplaceExisting,
|
||||
CompressionQuality: uint8(c.CompressionQuality),
|
||||
}
|
||||
return basic
|
||||
}
|
||||
|
||||
func (c *TunnelConfig) SupportedFeatures() []string {
|
||||
features := []string{FeatureSerializedHeaders}
|
||||
if c.NamedTunnel == nil {
|
||||
features = append(features, FeatureQuickReconnects)
|
||||
}
|
||||
return features
|
||||
}
|
||||
|
||||
type NamedTunnelConfig struct {
|
||||
Auth pogs.TunnelAuth
|
||||
ID string
|
||||
Auth pogs.TunnelAuth
|
||||
ID uuid.UUID
|
||||
Client pogs.ClientInfo
|
||||
}
|
||||
|
||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal, namedTunnel *NamedTunnelConfig) error {
|
||||
s, err := NewSupervisor(config, cloudflaredID, namedTunnel)
|
||||
func StartTunnelDaemon(ctx context.Context, config *TunnelConfig, connectedSignal *signal.Signal, cloudflaredID uuid.UUID, reconnectCh chan ReconnectSignal) error {
|
||||
s, err := NewSupervisor(config, cloudflaredID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -196,7 +211,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||
credentialManager ReconnectTunnelCredentialManager,
|
||||
config *TunnelConfig,
|
||||
addr *net.TCPAddr,
|
||||
connectionID uint8,
|
||||
connectionIndex uint8,
|
||||
connectedSignal *signal.Signal,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
bufferPool *buffer.Pool,
|
||||
@@ -219,7 +234,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||
credentialManager,
|
||||
config,
|
||||
config.Logger,
|
||||
addr, connectionID,
|
||||
addr, connectionIndex,
|
||||
connectedFuse,
|
||||
&backoff,
|
||||
cloudflaredUUID,
|
||||
@@ -228,7 +243,7 @@ func ServeTunnelLoop(ctx context.Context,
|
||||
)
|
||||
if recoverable {
|
||||
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
|
||||
config.Logger.Infof("Retrying in %s seconds: connectionID: %d", duration, connectionID)
|
||||
config.Logger.Infof("Retrying connection %d in %s seconds", connectionIndex, duration)
|
||||
backoff.Backoff(ctx)
|
||||
continue
|
||||
}
|
||||
@@ -243,7 +258,7 @@ func ServeTunnel(
|
||||
config *TunnelConfig,
|
||||
logger logger.Service,
|
||||
addr *net.TCPAddr,
|
||||
connectionID uint8,
|
||||
connectionIndex uint8,
|
||||
connectedFuse *h2mux.BooleanFuse,
|
||||
backoff *BackoffHandler,
|
||||
cloudflaredUUID uuid.UUID,
|
||||
@@ -262,22 +277,18 @@ func ServeTunnel(
|
||||
}
|
||||
}()
|
||||
|
||||
connectionTag := uint8ToString(connectionID)
|
||||
|
||||
// additional tags to send other than hostname which is set in cloudflared main package
|
||||
tags := make(map[string]string)
|
||||
tags["ha"] = connectionTag
|
||||
connectionTag := uint8ToString(connectionIndex)
|
||||
|
||||
// Returns error from parsing the origin URL or handshake errors
|
||||
handler, originLocalIP, err := NewTunnelHandler(ctx, config, addr, connectionID, bufferPool)
|
||||
handler, originLocalAddr, err := NewTunnelHandler(ctx, config, addr, connectionIndex, bufferPool)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case connection.DialError:
|
||||
logger.Errorf("Unable to dial edge: %s connectionID: %d", err, connectionID)
|
||||
logger.Errorf("Connection %d unable to dial edge: %s", connectionIndex, err)
|
||||
case h2mux.MuxerHandshakeError:
|
||||
logger.Errorf("Handshake failed with edge server: %s connectionID: %d", err, connectionID)
|
||||
logger.Errorf("Connection %d handshake with edge server failed: %s", connectionIndex, err)
|
||||
default:
|
||||
logger.Errorf("Tunnel creation failure: %s connectionID: %d", err, connectionID)
|
||||
logger.Errorf("Connection %d failed: %s", connectionIndex, err)
|
||||
return err, false
|
||||
}
|
||||
return err, true
|
||||
@@ -293,20 +304,21 @@ func ServeTunnel(
|
||||
}
|
||||
}()
|
||||
|
||||
if config.NamedTunnel != nil {
|
||||
return RegisterConnection(ctx, handler.muxer, config, connectionIndex, originLocalAddr)
|
||||
}
|
||||
|
||||
if config.UseReconnectToken && connectedFuse.Value() {
|
||||
token, tokenErr := credentialManager.ReconnectToken()
|
||||
eventDigest, eventDigestErr := credentialManager.EventDigest()
|
||||
// if we have both credentials, we can reconnect
|
||||
if tokenErr == nil && eventDigestErr == nil {
|
||||
var connDigest []byte
|
||||
|
||||
// check if we can use Quick Reconnects
|
||||
if config.UseQuickReconnects {
|
||||
if digest, connDigestErr := credentialManager.ConnDigest(connectionID); connDigestErr == nil {
|
||||
connDigest = digest
|
||||
}
|
||||
if digest, connDigestErr := credentialManager.ConnDigest(connectionIndex); connDigestErr == nil {
|
||||
connDigest = digest
|
||||
}
|
||||
return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID, credentialManager)
|
||||
|
||||
return ReconnectTunnel(serveCtx, token, eventDigest, connDigest, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID, credentialManager)
|
||||
}
|
||||
// log errors and proceed to RegisterTunnel
|
||||
if tokenErr != nil {
|
||||
@@ -316,7 +328,7 @@ func ServeTunnel(
|
||||
logger.Errorf("Couldn't get event digest: %s", eventDigestErr)
|
||||
}
|
||||
}
|
||||
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionID, originLocalIP, cloudflaredUUID)
|
||||
return RegisterTunnel(serveCtx, credentialManager, handler.muxer, config, logger, connectionIndex, originLocalAddr, cloudflaredUUID)
|
||||
})
|
||||
|
||||
errGroup.Go(func() error {
|
||||
@@ -325,12 +337,15 @@ func ServeTunnel(
|
||||
select {
|
||||
case <-serveCtx.Done():
|
||||
// UnregisterTunnel blocks until the RPC call returns
|
||||
var err error
|
||||
if connectedFuse.Value() {
|
||||
err = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
|
||||
if config.NamedTunnel != nil {
|
||||
_ = UnregisterConnection(ctx, handler.muxer, config)
|
||||
} else {
|
||||
_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
|
||||
}
|
||||
}
|
||||
handler.muxer.Shutdown()
|
||||
return err
|
||||
return nil
|
||||
case <-updateMetricsTickC:
|
||||
handler.UpdateMetrics(connectionTag)
|
||||
}
|
||||
@@ -361,8 +376,6 @@ func ServeTunnel(
|
||||
|
||||
err = errGroup.Wait()
|
||||
if err != nil {
|
||||
_ = newClientRegisterTunnelError(err, config.Metrics.regFail, unknown)
|
||||
|
||||
switch castedErr := err.(type) {
|
||||
case dupConnRegisterTunnelError:
|
||||
logger.Info("Already connected to this server, selecting a different one")
|
||||
@@ -382,7 +395,7 @@ func ServeTunnel(
|
||||
logger.Info("Muxer shutdown")
|
||||
return err, true
|
||||
case *ReconnectSignal:
|
||||
logger.Infof("Restarting due to reconnect signal in %d seconds", castedErr.Delay)
|
||||
logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, castedErr.Delay)
|
||||
castedErr.DelayBeforeReconnect()
|
||||
return err, true
|
||||
default:
|
||||
@@ -393,6 +406,74 @@ func ServeTunnel(
|
||||
return nil, true
|
||||
}
|
||||
|
||||
func RegisterConnection(
|
||||
ctx context.Context,
|
||||
muxer *h2mux.Muxer,
|
||||
config *TunnelConfig,
|
||||
connectionIndex uint8,
|
||||
originLocalAddr string,
|
||||
) error {
|
||||
const registerConnection = "registerConnection"
|
||||
|
||||
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
|
||||
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
|
||||
}
|
||||
defer rpc.Close()
|
||||
|
||||
conn, err := rpc.RegisterConnection(
|
||||
ctx,
|
||||
config.NamedTunnel.Auth,
|
||||
config.NamedTunnel.ID,
|
||||
connectionIndex,
|
||||
config.ConnectionOptions(originLocalAddr),
|
||||
)
|
||||
if err != nil {
|
||||
if err.Error() == DuplicateConnectionError {
|
||||
config.Metrics.regFail.WithLabelValues("dup_edge_conn", registerConnection).Inc()
|
||||
return errDuplicationConnection
|
||||
}
|
||||
config.Metrics.regFail.WithLabelValues("server_error", registerConnection).Inc()
|
||||
return serverRegistrationErrorFromRPC(err)
|
||||
}
|
||||
|
||||
config.Metrics.regSuccess.WithLabelValues(registerConnection).Inc()
|
||||
config.Logger.Infof("Connection %d registered with %s using ID %s", connectionIndex, conn.Location, conn.UUID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func serverRegistrationErrorFromRPC(err error) *serverRegisterTunnelError {
|
||||
if retryable, ok := err.(*tunnelpogs.RetryableError); ok {
|
||||
return &serverRegisterTunnelError{
|
||||
cause: retryable.Unwrap(),
|
||||
permanent: false,
|
||||
}
|
||||
}
|
||||
return &serverRegisterTunnelError{
|
||||
cause: err,
|
||||
permanent: true,
|
||||
}
|
||||
}
|
||||
|
||||
func UnregisterConnection(
|
||||
ctx context.Context,
|
||||
muxer *h2mux.Muxer,
|
||||
config *TunnelConfig,
|
||||
) error {
|
||||
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
|
||||
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
|
||||
}
|
||||
defer rpc.Close()
|
||||
|
||||
return rpc.UnregisterConnection(ctx)
|
||||
}
|
||||
|
||||
func RegisterTunnel(
|
||||
ctx context.Context,
|
||||
credentialManager ReconnectTunnelCredentialManager,
|
||||
@@ -437,7 +518,7 @@ func ReconnectTunnel(
|
||||
config *TunnelConfig,
|
||||
logger logger.Service,
|
||||
connectionID uint8,
|
||||
originLocalIP string,
|
||||
originLocalAddr string,
|
||||
uuid uuid.UUID,
|
||||
credentialManager ReconnectTunnelCredentialManager,
|
||||
) error {
|
||||
@@ -459,7 +540,7 @@ func ReconnectTunnel(
|
||||
eventDigest,
|
||||
connDigest,
|
||||
config.Hostname,
|
||||
config.RegistrationOptions(connectionID, originLocalIP, uuid),
|
||||
config.RegistrationOptions(connectionID, originLocalAddr, uuid),
|
||||
)
|
||||
if registrationErr := registration.DeserializeError(); registrationErr != nil {
|
||||
// ReconnectTunnel RPC failure
|
||||
@@ -508,11 +589,11 @@ func processRegistrationSuccess(
|
||||
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
|
||||
if err.Error() == DuplicateConnectionError {
|
||||
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
|
||||
return dupConnRegisterTunnelError{}
|
||||
return errDuplicationConnection
|
||||
}
|
||||
metrics.regFail.WithLabelValues("server_error", string(name)).Inc()
|
||||
return serverRegisterTunnelError{
|
||||
cause: fmt.Errorf("Server error: %s", err.Error()),
|
||||
cause: err,
|
||||
permanent: err.IsPermanent(),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user