TUN-3458: Upgrade to http2 when available, fallback to h2mux when we reach max retries

This commit is contained in:
cthuang
2020-10-14 14:42:00 +01:00
parent b5cdf3b2c7
commit a490443630
13 changed files with 632 additions and 159 deletions

View File

@@ -62,13 +62,14 @@ type TunnelConfig struct {
ReportedVersion string
Retries uint
RunFromTerminal bool
TLSConfig *tls.Config
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
NamedTunnel *connection.NamedTunnelConfig
ClassicTunnel *connection.ClassicTunnelConfig
MuxerConfig *connection.MuxerConfig
TunnelEventChan chan ui.TunnelEvent
IngressRules ingress.Ingress
ProtocolSelector connection.ProtocolSelector
EdgeTLSConfigs map[connection.Protocol]*tls.Config
}
type muxerShutdownError struct{}
@@ -157,7 +158,7 @@ func ServeTunnelLoop(ctx context.Context,
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionIndex uint8,
connIndex uint8,
connectedSignal *signal.Signal,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
@@ -165,7 +166,11 @@ func ServeTunnelLoop(ctx context.Context,
haConnections.Inc()
defer haConnections.Dec()
backoff := BackoffHandler{MaxRetries: config.Retries}
protocallFallback := &protocallFallback{
BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(),
false,
}
connectedFuse := h2mux.NewBooleanFuse()
go func() {
if connectedFuse.Await() {
@@ -174,29 +179,90 @@ func ServeTunnelLoop(ctx context.Context,
}()
// Ensure the above goroutine will terminate if we return without connecting
defer connectedFuse.Fuse(false)
// Each connection to keep its own copy of protocol, because individual connections might fallback
// to another protocol when a particular metal doesn't support new protocol
for {
err, recoverable := ServeTunnel(
ctx,
credentialManager,
config,
addr, connectionIndex,
addr,
connIndex,
connectedFuse,
&backoff,
protocallFallback,
cloudflaredUUID,
reconnectCh,
protocallFallback.protocol,
)
if recoverable {
if duration, ok := backoff.GetBackoffDuration(ctx); ok {
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connectionIndex, duration, err)
backoff.Backoff(ctx)
continue
}
if !recoverable {
return err
}
err = waitForBackoff(ctx, protocallFallback, config, connIndex, err)
if err != nil {
return err
}
}
}
// protocallFallback is a wrapper around backoffHandler that will try fallback option when backoff reaches
// max retries
type protocallFallback struct {
BackoffHandler
protocol connection.Protocol
inFallback bool
}
func (pf *protocallFallback) reset() {
pf.resetNow()
pf.inFallback = false
}
func (pf *protocallFallback) fallback(fallback connection.Protocol) {
pf.resetNow()
pf.protocol = fallback
pf.inFallback = true
}
// Expect err to always be non nil
func waitForBackoff(
ctx context.Context,
protobackoff *protocallFallback,
config *TunnelConfig,
connIndex uint8,
err error,
) error {
duration, ok := protobackoff.GetBackoffDuration(ctx)
if !ok {
return err
}
if config.TunnelEventChan != nil {
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Reconnecting}
}
config.Logger.Infof("Retrying connection %d in %s seconds, error %v", connIndex, duration, err)
protobackoff.Backoff(ctx)
if protobackoff.ReachedMaxRetries() {
fallback, hasFallback := config.ProtocolSelector.Fallback()
if !hasFallback {
return err
}
// Already using fallback protocol, no point to retry
if protobackoff.protocol == fallback {
return err
}
config.Logger.Infof("Fallback to use %s", fallback)
protobackoff.fallback(fallback)
} else if !protobackoff.inFallback {
current := config.ProtocolSelector.Current()
if protobackoff.protocol != current {
protobackoff.protocol = current
config.Logger.Infof("Change protocol to %s", current)
}
}
return nil
}
func ServeTunnel(
@@ -204,11 +270,12 @@ func ServeTunnel(
credentialManager *reconnectCredentialManager,
config *TunnelConfig,
addr *net.TCPAddr,
connectionIndex uint8,
connIndex uint8,
fuse *h2mux.BooleanFuse,
backoff *BackoffHandler,
backoff *protocallFallback,
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
protocol connection.Protocol,
) (err error, recoverable bool) {
// Treat panics as recoverable errors
defer func() {
@@ -226,11 +293,11 @@ func ServeTunnel(
// If launch-ui flag is set, send disconnect msg
if config.TunnelEventChan != nil {
defer func() {
config.TunnelEventChan <- ui.TunnelEvent{Index: connectionIndex, EventType: ui.Disconnected}
config.TunnelEventChan <- ui.TunnelEvent{Index: connIndex, EventType: ui.Disconnected}
}()
}
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.TLSConfig, addr)
edgeConn, err := edgediscovery.DialEdge(ctx, dialTimeout, config.EdgeTLSConfigs[protocol], addr)
if err != nil {
return err, true
}
@@ -238,11 +305,11 @@ func ServeTunnel(
fuse: fuse,
backoff: backoff,
}
if config.Protocol == connection.HTTP2 {
if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
return ServeHTTP2(ctx, config, edgeConn, connOptions, connectionIndex, connectedFuse, reconnectCh)
return ServeHTTP2(ctx, config, edgeConn, connOptions, connIndex, connectedFuse, reconnectCh)
}
return ServeH2mux(ctx, credentialManager, config, edgeConn, connectionIndex, connectedFuse, cloudflaredUUID, reconnectCh)
return ServeH2mux(ctx, credentialManager, config, edgeConn, connIndex, connectedFuse, cloudflaredUUID, reconnectCh)
}
func ServeH2mux(
@@ -255,6 +322,7 @@ func ServeH2mux(
cloudflaredUUID uuid.UUID,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
config.Logger.Debugf("Connecting via h2mux")
// Returns error from parsing the origin URL or handshake errors
handler, err, recoverable := connection.NewH2muxConnection(ctx, config.ConnectionConfig, config.MuxerConfig, config.ProxyConfig.URL.String(), edgeConn, connectionIndex, config.Observer)
if err != nil {
@@ -266,10 +334,10 @@ func ServeH2mux(
errGroup.Go(func() (err error) {
if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
return handler.ServeNamedTunnel(ctx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, credentialManager, connOptions, connectedFuse)
}
registrationOptions := config.RegistrationOptions(connectionIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
return handler.ServeClassicTunnel(ctx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
return handler.ServeClassicTunnel(serveCtx, config.ClassicTunnel, credentialManager, registrationOptions, connectedFuse)
})
errGroup.Go(listenReconnect(serveCtx, reconnectCh))
@@ -295,7 +363,7 @@ func ServeH2mux(
config.Logger.Info("Muxer shutdown")
return err, true
case *ReconnectSignal:
config.Logger.Infof("Restarting connection %d due to reconnect signal in %d seconds", connectionIndex, err.Delay)
config.Logger.Infof("Restarting connection %d due to reconnect signal in %s", connectionIndex, err.Delay)
err.DelayBeforeReconnect()
return err, true
default:
@@ -319,10 +387,8 @@ func ServeHTTP2(
connectedFuse connection.ConnectedFuse,
reconnectCh chan ReconnectSignal,
) (err error, recoverable bool) {
server, err := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
if err != nil {
return err, false
}
config.Logger.Debugf("Connecting via http2")
server := connection.NewHTTP2Connection(tlsServerConn, config.ConnectionConfig, config.ProxyConfig.URL, config.NamedTunnel, connOptions, config.Observer, connIndex, connectedFuse)
errGroup, serveCtx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
@@ -352,12 +418,12 @@ func listenReconnect(ctx context.Context, reconnectCh <-chan ReconnectSignal) fu
type connectedFuse struct {
fuse *h2mux.BooleanFuse
backoff *BackoffHandler
backoff *protocallFallback
}
func (cf *connectedFuse) Connected() {
cf.fuse.Fuse(true)
cf.backoff.SetGracePeriod()
cf.backoff.reset()
}
func (cf *connectedFuse) IsConnected() bool {