TUN-3489: Add unit tests to cover proxy logic in connection package of cloudflared

This commit is contained in:
cthuang
2020-10-27 22:27:15 +00:00
parent 5974fb4cfd
commit d5769519b2
9 changed files with 754 additions and 92 deletions

View File

@@ -11,6 +11,7 @@ import (
"sync"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/logger"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"golang.org/x/net/http2"
@@ -26,17 +27,19 @@ var (
errNotFlusher = errors.New("ResponseWriter doesn't implement http.Flusher")
)
type HTTP2Connection struct {
conn net.Conn
server *http2.Server
config *Config
namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndexStr string
connIndex uint8
wg *sync.WaitGroup
connectedFuse ConnectedFuse
type http2Connection struct {
conn net.Conn
server *http2.Server
config *Config
namedTunnel *NamedTunnelConfig
connOptions *tunnelpogs.ConnectionOptions
observer *Observer
connIndexStr string
connIndex uint8
wg *sync.WaitGroup
// newRPCClientFunc allows us to mock RPCs during testing
newRPCClientFunc func(context.Context, io.ReadWriteCloser, logger.Service) NamedTunnelRPCClient
connectedFuse ConnectedFuse
}
func NewHTTP2Connection(
@@ -47,24 +50,25 @@ func NewHTTP2Connection(
observer *Observer,
connIndex uint8,
connectedFuse ConnectedFuse,
) *HTTP2Connection {
return &HTTP2Connection{
) *http2Connection {
return &http2Connection{
conn: conn,
server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32,
},
config: config,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
wg: &sync.WaitGroup{},
connectedFuse: connectedFuse,
config: config,
namedTunnel: namedTunnelConfig,
connOptions: connOptions,
observer: observer,
connIndexStr: uint8ToString(connIndex),
connIndex: connIndex,
wg: &sync.WaitGroup{},
newRPCClientFunc: newRegistrationRPCClient,
connectedFuse: connectedFuse,
}
}
func (c *HTTP2Connection) Serve(ctx context.Context) {
func (c *http2Connection) Serve(ctx context.Context) {
go func() {
<-ctx.Done()
c.close()
@@ -75,7 +79,7 @@ func (c *HTTP2Connection) Serve(ctx context.Context) {
})
}
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (c *http2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
c.wg.Add(1)
defer c.wg.Done()
@@ -86,65 +90,42 @@ func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
flusher, isFlusher := w.(http.Flusher)
if !isFlusher {
c.observer.Errorf("%T doesn't implement http.Flusher", w)
respWriter.WriteErrorResponse(errNotFlusher)
respWriter.WriteErrorResponse()
return
}
respWriter.flusher = flusher
var err error
if isControlStreamUpgrade(r) {
respWriter.shouldFlush = true
err := c.serveControlStream(r.Context(), respWriter)
if err != nil {
respWriter.WriteErrorResponse(err)
}
err = c.serveControlStream(r.Context(), respWriter)
} else if isWebsocketUpgrade(r) {
respWriter.shouldFlush = true
stripWebsocketUpgradeHeader(r)
c.config.OriginClient.Proxy(respWriter, r, true)
err = c.config.OriginClient.Proxy(respWriter, r, true)
} else {
c.config.OriginClient.Proxy(respWriter, r, false)
err = c.config.OriginClient.Proxy(respWriter, r, false)
}
if err != nil {
respWriter.WriteErrorResponse()
}
}
func (c *HTTP2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
rpcClient := newRegistrationRPCClient(ctx, respWriter, c.observer)
defer rpcClient.close()
func (c *http2Connection) serveControlStream(ctx context.Context, respWriter *http2RespWriter) error {
rpcClient := c.newRPCClientFunc(ctx, respWriter, c.observer)
defer rpcClient.Close()
if err := registerConnection(ctx, rpcClient, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
if err := rpcClient.RegisterConnection(ctx, c.namedTunnel, c.connOptions, c.connIndex, c.observer); err != nil {
return err
}
c.connectedFuse.Connected()
<-ctx.Done()
c.gracefulShutdown(ctx, rpcClient)
rpcClient.GracefulShutdown(ctx, c.config.GracePeriod)
return nil
}
func (c *HTTP2Connection) registerConnection(
ctx context.Context,
rpcClient tunnelpogs.RegistrationServer_PogsClient,
) error {
connDetail, err := rpcClient.RegisterConnection(
ctx,
c.namedTunnel.Auth,
c.namedTunnel.ID,
c.connIndex,
c.connOptions,
)
if err != nil {
c.observer.Errorf("Cannot register connection, err: %v", err)
return err
}
c.observer.Infof("Connection %s registered with %s using ID %s", c.connIndexStr, connDetail.Location, connDetail.UUID)
return nil
}
func (c *HTTP2Connection) gracefulShutdown(ctx context.Context, rpcClient *registrationServerClient) {
ctx, cancel := context.WithTimeout(ctx, c.config.GracePeriod)
defer cancel()
rpcClient.client.UnregisterConnection(ctx)
}
func (c *HTTP2Connection) close() {
func (c *http2Connection) close() {
// Wait for all serve HTTP handlers to return
c.wg.Wait()
c.conn.Close()
@@ -195,7 +176,7 @@ func (rp *http2RespWriter) WriteRespHeaders(resp *http.Response) error {
return nil
}
func (rp *http2RespWriter) WriteErrorResponse(err error) {
func (rp *http2RespWriter) WriteErrorResponse() {
rp.setResponseMetaHeader(responseMetaHeaderCfd)
rp.w.WriteHeader(http.StatusBadGateway)
}