TUN-3863: Consolidate header handling logic in the connection package; move header definitions from h2mux to the packages that manage them; clean up header conversions

All header transformation code from h2mux has been consolidated in the connection package since it's used by both h2mux and http2 logic.
Exported the headers used for proxying between the edge and cloudflared so they can be shared by the tunnel service on the edge.
Moved access-related headers into the packages that set or use them.
Removed tunnel hostname tracking from h2mux since it wasn't used by anything. We will continue to set the tunnel hostname header from the edge for backward compatibility, but it's no longer used by cloudflared.
Moved bastion-related logic into the carrier package and untangled dependencies between the carrier, origin, and websocket packages.
Igor Postelnik
2021-03-25 23:04:56 -05:00
parent ebf5292bf9
commit 8ca0d86c85
29 changed files with 541 additions and 713 deletions

View File

@@ -1,113 +0,0 @@
package origin
import (
"context"
"math/rand"
"time"
)
// Redeclare time functions so they can be overridden in tests.
var (
timeNow = time.Now
timeAfter = time.After
)
// BackoffHandler manages exponential backoff and limits the maximum number of retries.
// The base time period is 1 second, doubling with each retry.
// After initial success, a grace period can be set to reset the backoff timer if
// a connection is maintained successfully for a long enough period. The base grace period
// is 2 seconds, doubling with each retry.
type BackoffHandler struct {
// MaxRetries sets the maximum number of retries to perform. The default value
// of 0 disables retry completely.
MaxRetries uint
// RetryForever caps the exponential backoff period according to MaxRetries
// but allows you to retry indefinitely.
RetryForever bool
// BaseTime sets the initial backoff period.
BaseTime time.Duration
retries uint
resetDeadline time.Time
}
func (b BackoffHandler) GetMaxBackoffDuration(ctx context.Context) (time.Duration, bool) {
// Follows the same logic as Backoff, but without mutating the receiver.
// This select has to happen first to reflect the actual behaviour of the Backoff function.
select {
case <-ctx.Done():
return time.Duration(0), false
default:
}
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
// b.retries would be set to 0 at this point
return time.Second, true
}
if b.retries >= b.MaxRetries && !b.RetryForever {
return time.Duration(0), false
}
maxTimeToWait := b.GetBaseTime() * 1 << (b.retries + 1)
return maxTimeToWait, true
}
// BackoffTimer returns a channel that sends the current time when the exponential backoff timeout expires.
// Returns nil if the maximum number of retries has been reached.
func (b *BackoffHandler) BackoffTimer() <-chan time.Time {
if !b.resetDeadline.IsZero() && timeNow().After(b.resetDeadline) {
b.retries = 0
b.resetDeadline = time.Time{}
}
if b.retries >= b.MaxRetries {
if !b.RetryForever {
return nil
}
} else {
b.retries++
}
maxTimeToWait := time.Duration(b.GetBaseTime() * 1 << (b.retries))
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
return timeAfter(timeToWait)
}
// Backoff is used to wait according to exponential backoff. Returns false if the
// maximum number of retries has been reached or if the underlying context has been cancelled.
func (b *BackoffHandler) Backoff(ctx context.Context) bool {
c := b.BackoffTimer()
if c == nil {
return false
}
select {
case <-c:
return true
case <-ctx.Done():
return false
}
}
// Sets a grace period within which the backoff timer is maintained. After the grace
// period expires, the number of retries and the backoff duration are reset.
func (b *BackoffHandler) SetGracePeriod() {
maxTimeToWait := b.GetBaseTime() * 2 << (b.retries + 1)
timeToWait := time.Duration(rand.Int63n(maxTimeToWait.Nanoseconds()))
b.resetDeadline = timeNow().Add(timeToWait)
}
func (b BackoffHandler) GetBaseTime() time.Duration {
if b.BaseTime == 0 {
return time.Second
}
return b.BaseTime
}
// Retries returns the number of retries consumed so far.
func (b *BackoffHandler) Retries() int {
return int(b.retries)
}
func (b *BackoffHandler) ReachedMaxRetries() bool {
return b.retries == b.MaxRetries
}
func (b *BackoffHandler) resetNow() {
b.resetDeadline = time.Now()
}
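
As a usage sketch (not part of this commit), here is how a caller might drive this handler; dialEdge and connectLoop are hypothetical stand-ins for real connection logic, assumed to sit alongside the BackoffHandler above:

package origin

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// dialEdge is a hypothetical stand-in for an operation that may fail.
func dialEdge(ctx context.Context) error { return errors.New("edge unreachable") }

// connectLoop retries dialEdge with exponential backoff.
func connectLoop(ctx context.Context) error {
	backoff := BackoffHandler{MaxRetries: 5, BaseTime: 500 * time.Millisecond}
	for {
		err := dialEdge(ctx)
		if err == nil {
			// Success: arm the grace period so that a connection which
			// stays healthy long enough resets the retry counter.
			backoff.SetGracePeriod()
			return nil
		}
		// Backoff sleeps for a randomized exponential interval; it returns
		// false once MaxRetries is exhausted (unless RetryForever is set)
		// or when ctx is cancelled.
		if !backoff.Backoff(ctx) {
			return fmt.Errorf("giving up after %d retries: %w", backoff.Retries(), err)
		}
	}
}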

View File

@@ -1,147 +0,0 @@
package origin
import (
"context"
"testing"
"time"
)
func immediateTimeAfter(time.Duration) <-chan time.Time {
c := make(chan time.Time, 1)
c <- time.Now()
return c
}
func TestBackoffRetries(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed immediately")
}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed after 1 retry")
}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed after 2 retry")
}
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after 3 (max) retries")
}
}
func TestBackoffCancel(t *testing.T) {
// prevent backoff from returning normally
timeAfter = func(time.Duration) <-chan time.Time { return make(chan time.Time) }
ctx, cancelFunc := context.WithCancel(context.Background())
backoff := BackoffHandler{MaxRetries: 3}
cancelFunc()
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after cancel")
}
if _, ok := backoff.GetMaxBackoffDuration(ctx); ok {
t.Fatalf("backoff allowed after cancel")
}
}
func TestBackoffGracePeriod(t *testing.T) {
currentTime := time.Now()
// make timeNow return whatever we like
timeNow = func() time.Time { return currentTime }
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 1}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed immediately")
}
// the next call to Backoff would fail unless it's after the grace period
backoff.SetGracePeriod()
// advance time to after the grace period (~4 seconds) and see what happens
currentTime = currentTime.Add(time.Second * 5)
if !backoff.Backoff(ctx) {
t.Fatalf("backoff failed after the grace period expired")
}
// confirm we ignore grace period after backoff
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after 1 (max) retry")
}
}
func TestGetMaxBackoffDurationRetries(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed immediately")
}
backoff.Backoff(ctx) // noop
if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed after 1 retry")
}
backoff.Backoff(ctx) // noop
if _, ok := backoff.GetMaxBackoffDuration(ctx); !ok {
t.Fatalf("backoff failed after 2 retry")
}
backoff.Backoff(ctx) // noop
if _, ok := backoff.GetMaxBackoffDuration(ctx); ok {
t.Fatalf("backoff allowed after 3 (max) retries")
}
if backoff.Backoff(ctx) {
t.Fatalf("backoff allowed after 3 (max) retries")
}
}
func TestGetMaxBackoffDuration(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3}
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
t.Fatalf("backoff (%s) didn't return < 2 seconds on first retry", duration)
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*4 {
t.Fatalf("backoff (%s) didn't return < 4 seconds on second retry", duration)
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*8 {
t.Fatalf("backoff (%s) didn't return < 8 seconds on third retry", duration)
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetMaxBackoffDuration(ctx); ok || duration != 0 {
t.Fatalf("backoff (%s) didn't return 0 seconds on fourth retry (exceeding limit)", duration)
}
}
func TestBackoffRetryForever(t *testing.T) {
// make backoff return immediately
timeAfter = immediateTimeAfter
ctx := context.Background()
backoff := BackoffHandler{MaxRetries: 3, RetryForever: true}
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*2 {
t.Fatalf("backoff (%s) didn't return < 2 seconds on first retry", duration)
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*4 {
t.Fatalf("backoff (%s) didn't return < 4 seconds on second retry", duration)
}
backoff.Backoff(ctx) // noop
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*8 {
t.Fatalf("backoff (%s) didn't return < 8 seconds on third retry", duration)
}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff refused on fourth retry despire RetryForever")
}
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*16 {
t.Fatalf("backoff returned %v instead of 8 seconds on fourth retry", duration)
}
if !backoff.Backoff(ctx) {
t.Fatalf("backoff refused on fifth retry despire RetryForever")
}
if duration, ok := backoff.GetMaxBackoffDuration(ctx); !ok || duration > time.Second*16 {
t.Fatalf("backoff returned %v instead of 8 seconds on fifth retry", duration)
}
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/prometheus/client_golang/prometheus"
"github.com/cloudflare/cloudflared/retry"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@@ -103,7 +104,7 @@ func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte)
func (cm *reconnectCredentialManager) RefreshAuth(
ctx context.Context,
backoff *BackoffHandler,
backoff *retry.BackoffHandler,
authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
) (retryTimer <-chan time.Time, err error) {
authOutcome, err := authenticate(ctx, backoff.Retries())
@@ -121,11 +122,11 @@ func (cm *reconnectCredentialManager) RefreshAuth(
case tunnelpogs.AuthSuccess:
cm.SetReconnectToken(outcome.JWT())
cm.authSuccess.Inc()
return timeAfter(outcome.RefreshAfter()), nil
return retry.Clock.After(outcome.RefreshAfter()), nil
case tunnelpogs.AuthUnknown:
duration := outcome.RefreshAfter()
cm.authFail.WithLabelValues(outcome.Error()).Inc()
return timeAfter(duration), nil
return retry.Clock.After(duration), nil
case tunnelpogs.AuthFail:
cm.authFail.WithLabelValues(outcome.Error()).Inc()
return nil, outcome
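
The retry.Clock used above replaces the old package-local timeNow/timeAfter variables. Its definition isn't part of the diffs shown here; a minimal sketch consistent with the call sites (retry.Clock.After(...) above, retry.Clock.Now = ... in the tests below) might look like:

package retry

import "time"

// clock exposes time functions as struct fields so tests can swap them
// out, just as the old timeNow/timeAfter package variables were overridden.
type clock struct {
	Now   func() time.Time
	After func(d time.Duration) <-chan time.Time
}

// Clock is the package-level instance called by production code and
// reassigned field-by-field in tests.
var Clock = clock{
	Now:   time.Now,
	After: time.After,
}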

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/cloudflare/cloudflared/retry"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@@ -17,11 +18,11 @@ func TestRefreshAuthBackoff(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
retry.Clock.After = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
backoff := &BackoffHandler{MaxRetries: 3}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return nil, fmt.Errorf("authentication failure")
}
@@ -45,7 +46,7 @@ func TestRefreshAuthBackoff(t *testing.T) {
// The backoff timer should have been reset. To confirm this, make timeNow
// return a value after the backoff timer's grace period
timeNow = func() time.Time {
retry.Clock.Now = func() time.Time {
expectedGracePeriod := time.Duration(time.Second * 2 << backoff.MaxRetries)
return time.Now().Add(expectedGracePeriod * 2)
}
@@ -57,12 +58,12 @@ func TestRefreshAuthSuccess(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
retry.Clock.After = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
backoff := &BackoffHandler{MaxRetries: 3}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthSuccess([]byte("jwt"), 19), nil
}
@@ -81,12 +82,12 @@ func TestRefreshAuthUnknown(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
var wait time.Duration
timeAfter = func(d time.Duration) <-chan time.Time {
retry.Clock.After = func(d time.Duration) <-chan time.Time {
wait = d
return time.After(d)
}
backoff := &BackoffHandler{MaxRetries: 3}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthUnknown(errors.New("auth unknown"), 19), nil
}
@@ -104,7 +105,7 @@ func TestRefreshAuthUnknown(t *testing.T) {
func TestRefreshAuthFail(t *testing.T) {
rcm := newReconnectCredentialManager(t.Name(), t.Name(), 4)
backoff := &BackoffHandler{MaxRetries: 3}
backoff := &retry.BackoffHandler{MaxRetries: 3}
auth := func(ctx context.Context, n int) (tunnelpogs.AuthOutcome, error) {
return tunnelpogs.NewAuthFail(errors.New("auth fail")), nil
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
)
@@ -112,10 +113,10 @@ func (s *Supervisor) Run(
var tunnelsWaiting []int
tunnelsActive := s.config.HAConnections
backoff := BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
backoff := retry.BackoffHandler{MaxRetries: s.config.Retries, BaseTime: tunnelRetryDuration, RetryForever: true}
var backoffTimer <-chan time.Time
refreshAuthBackoff := &BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
refreshAuthBackoff := &retry.BackoffHandler{MaxRetries: refreshAuthMaxBackoff, BaseTime: refreshAuthRetryDuration, RetryForever: true}
var refreshAuthBackoffTimer <-chan time.Time
if s.useReconnectToken {

View File

@@ -18,6 +18,7 @@ import (
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/edgediscovery"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/retry"
"github.com/cloudflare/cloudflared/signal"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
@@ -138,7 +139,7 @@ func ServeTunnelLoop(
connLog := config.Log.With().Uint8(connection.LogFieldConnIndex, connIndex).Logger()
protocolFallback := &protocolFallback{
BackoffHandler{MaxRetries: config.Retries},
retry.BackoffHandler{MaxRetries: config.Retries},
config.ProtocolSelector.Current(),
false,
}
@@ -195,18 +196,18 @@ func ServeTunnelLoop(
// protocolFallback is a wrapper around BackoffHandler that will try a fallback option when backoff reaches
// max retries
type protocolFallback struct {
BackoffHandler
retry.BackoffHandler
protocol connection.Protocol
inFallback bool
}
func (pf *protocolFallback) reset() {
pf.resetNow()
pf.ResetNow()
pf.inFallback = false
}
func (pf *protocolFallback) fallback(fallback connection.Protocol) {
pf.resetNow()
pf.ResetNow()
pf.protocol = fallback
pf.inFallback = true
}
@@ -281,7 +282,7 @@ func ServeTunnel(
}
if protocol == connection.HTTP2 {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.retries))
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(backoff.Retries()))
err = ServeHTTP2(
ctx,
connLog,
@@ -382,7 +383,7 @@ func ServeH2mux(
errGroup.Go(func() error {
if config.NamedTunnel != nil {
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.retries))
connOptions := config.ConnectionOptions(edgeConn.LocalAddr().String(), uint8(connectedFuse.backoff.Retries()))
return handler.ServeNamedTunnel(serveCtx, config.NamedTunnel, connOptions, connectedFuse)
}
registrationOptions := config.RegistrationOptions(connIndex, edgeConn.LocalAddr().String(), cloudflaredUUID)
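
For intuition, here is a hypothetical sketch (not code from this commit) of how a caller might drive protocolFallback; serveOnce and fallbackFor are invented helpers:

// serveWithFallback retries the current protocol until the embedded
// retry.BackoffHandler gives up, then switches to a fallback protocol once.
func serveWithFallback(ctx context.Context, pf *protocolFallback) error {
	for {
		err := serveOnce(ctx, pf.protocol) // hypothetical single connection attempt
		if err == nil {
			pf.reset() // success: clear backoff state and the fallback flag
			return nil
		}
		// Retry the same protocol until the embedded BackoffHandler gives up.
		if pf.Backoff(ctx) {
			continue
		}
		if pf.inFallback {
			return err // the fallback protocol is exhausted too
		}
		// Out of retries: switch protocols once and start over.
		pf.fallback(fallbackFor(pf.protocol)) // hypothetical protocol mapping
	}
}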

View File

@@ -8,6 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/cloudflare/cloudflared/connection"
"github.com/cloudflare/cloudflared/retry"
)
type dynamicMockFetcher struct {
@@ -26,7 +27,7 @@ func (dmf *dynamicMockFetcher) fetch() connection.PercentageFetcher {
func TestWaitForBackoffFallback(t *testing.T) {
maxRetries := uint(3)
backoff := BackoffHandler{
backoff := retry.BackoffHandler{
MaxRetries: maxRetries,
BaseTime: time.Millisecond * 10,
}