cloudflared/ingress/origin_connection_test.go
Igor Postelnik 8ca0d86c85 TUN-3863: Consolidate header handling logic in the connection package; move headers definitions from h2mux to packages that manage them; cleanup header conversions
All header transformation code from h2mux has been consolidated in the connection package since it's used by both h2mux and http2 logic.
Exported headers used by proxying between edge and cloudflared so they can be shared by tunnel service on the edge.
Moved access-related headers to corresponding packages that have the code that sets/uses these headers.
Removed tunnel hostname tracking from h2mux since it wasn't used by anything. We will continue to set the tunnel hostname header from the edge for backward compatibility, but it's no longer used by cloudflared.
Move bastion-related logic into carrier package, untangled dependencies between carrier, origin, and websocket packages.
2021-03-29 21:57:56 +00:00

297 lines
7.5 KiB
Go

package ingress
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/gobwas/ws/wsutil"
gorillaWS "github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/proxy"
"golang.org/x/sync/errgroup"
"github.com/cloudflare/cloudflared/logger"
"github.com/cloudflare/cloudflared/socks"
"github.com/cloudflare/cloudflared/websocket"
)
const (
// testStreamTimeout bounds the context handed to each streaming test in this file.
testStreamTimeout = time.Second * 3
// echoHeaderName is the HTTP header the tests set on requests and check on
// responses to confirm that headers make the round trip to the origin and back.
echoHeaderName = "Test-Cloudflared-Echo"
)
var (
// testLogger is the logger passed to the Stream calls exercised by these tests.
testLogger = logger.Create(nil)
// testMessage is the payload the simulated eyeball sends towards the origin.
testMessage = []byte("TestStreamOriginConnection")
// testResponse is the payload the echo origins send back: testMessage prefixed with "echo-".
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
)
// TestStreamTCPConnection checks that tcpConnection.Stream carries testMessage
// from the eyeball to a TCP origin and testResponse back again.
//
// Two in-memory pipes stand in for the network: eyeballConn<->edgeConn on the
// eyeball side and cfdConn<->originConn on the origin side.
func TestStreamTCPConnection(t *testing.T) {
	cfdConn, originConn := net.Pipe()
	tcpConn := tcpConnection{
		conn: cfdConn,
	}

	eyeballConn, edgeConn := net.Pipe()

	ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
	defer cancel()

	errGroup, ctx := errgroup.WithContext(ctx)
	errGroup.Go(func() error {
		// Failures are returned instead of raised with require: FailNow must
		// only be called from the goroutine running the test function.
		if _, err := eyeballConn.Write(testMessage); err != nil {
			return fmt.Errorf("writing test message: %w", err)
		}

		readBuffer := make([]byte, len(testResponse))
		n, err := eyeballConn.Read(readBuffer)
		if err != nil {
			return fmt.Errorf("reading echoed response: %w", err)
		}
		if !bytes.Equal(testResponse, readBuffer[:n]) {
			return fmt.Errorf("unexpected response: want %q, got %q", testResponse, readBuffer[:n])
		}
		return nil
	})
	errGroup.Go(func() error {
		echoTCPOrigin(t, originConn)
		originConn.Close()
		return nil
	})

	tcpConn.Stream(ctx, edgeConn, testLogger)
	require.NoError(t, errGroup.Wait())
}
// TestDefaultStreamWSOverTCPConnection runs a websocket-speaking eyeball against
// a plain TCP echo origin, with a tcpOverWSConnection using DefaultStreamHandler
// in between, and expects the echo exchange to complete without error.
func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
	// Eyeball leg and origin leg are both in-memory pipes.
	eyeballSide, edgeSide := net.Pipe()
	cfdSide, originSide := net.Pipe()

	streamer := tcpOverWSConnection{
		conn:          cfdSide,
		streamHandler: DefaultStreamHandler,
	}

	timeoutCtx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
	defer cancel()

	group, streamCtx := errgroup.WithContext(timeoutCtx)

	// Eyeball: sends testMessage as a websocket frame and expects testResponse.
	group.Go(func() error {
		echoWSEyeball(t, eyeballSide)
		return nil
	})

	// Origin: reads testMessage as raw bytes, replies, then hangs up.
	group.Go(func() error {
		echoTCPOrigin(t, originSide)
		originSide.Close()
		return nil
	})

	streamer.Stream(streamCtx, edgeSide, testLogger)

	waitErr := group.Wait()
	require.NoError(t, waitErr)
}
// TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
// Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
// wraps SOCKS5 traffic in websocket
// Origin side runs a tcpOverWSConnection with socks.StreamHandler
func TestSocksStreamWSOverTCPConnection(t *testing.T) {
var (
sendMessage = t.Name()
echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
echoMessage = fmt.Sprintf("echo-%s", sendMessage)
echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
)
statusCodes := []int{
http.StatusOK,
http.StatusTemporaryRedirect,
http.StatusBadRequest,
http.StatusInternalServerError,
}
for _, status := range statusCodes {
handler := func(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
require.NoError(t, err)
require.Equal(t, []byte(sendMessage), body)
require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
w.Header().Set(echoHeaderName, echoHeaderReturnValue)
w.WriteHeader(status)
w.Write([]byte(echoMessage))
}
origin := httptest.NewServer(http.HandlerFunc(handler))
defer origin.Close()
originURL, err := url.Parse(origin.URL)
require.NoError(t, err)
originConn, err := net.Dial("tcp", originURL.Host)
require.NoError(t, err)
tcpOverWSConn := tcpOverWSConnection{
conn: originConn,
streamHandler: socks.StreamHandler,
}
wsForwarderOutConn, edgeConn := net.Pipe()
ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
defer cancel()
errGroup, ctx := errgroup.WithContext(ctx)
errGroup.Go(func() error {
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
return nil
})
wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
errGroup.Go(func() error {
wsForwarderInConn, err := wsForwarderListener.Accept()
require.NoError(t, err)
defer wsForwarderInConn.Close()
websocket.Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
return nil
})
eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
require.NoError(t, err)
transport := &http.Transport{
Dial: eyeballDialer.Dial,
}
// Request URL doesn't matter because the transport is using eyeballDialer to connectq
req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
assert.NoError(t, err)
req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
resp, err := transport.RoundTrip(req)
assert.NoError(t, err)
assert.Equal(t, status, resp.StatusCode)
require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
body, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, []byte(echoMessage), body)
wsForwarderOutConn.Close()
edgeConn.Close()
tcpOverWSConn.Close()
require.NoError(t, errGroup.Wait())
}
}
// TestStreamWSConnection opens a websocket connection to a TLS echo origin,
// verifies the upgrade response, and then streams a websocket eyeball through
// the connection, expecting the echo exchange to complete without error.
func TestStreamWSConnection(t *testing.T) {
	wsOrigin := echoWSOrigin(t)
	defer wsOrigin.Close()

	upgradeReq, err := http.NewRequest(http.MethodGet, wsOrigin.URL, nil)
	require.NoError(t, err)
	upgradeReq.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")

	// The origin is an httptest TLS server, so certificate verification is skipped.
	wsConn, resp, err := newWSConnection(&tls.Config{InsecureSkipVerify: true}, upgradeReq)
	require.NoError(t, err)
	require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)

	// Expected handshake headers, checked in a fixed order.
	wantHeaders := []struct{ name, value string }{
		{"Connection", "Upgrade"},
		{"Sec-Websocket-Accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
		{"Upgrade", "websocket"},
	}
	for _, want := range wantHeaders {
		require.Equal(t, want.value, resp.Header.Get(want.name))
	}

	eyeballConn, edgeConn := net.Pipe()

	timeoutCtx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
	defer cancel()

	group, streamCtx := errgroup.WithContext(timeoutCtx)
	group.Go(func() error {
		echoWSEyeball(t, eyeballConn)
		return nil
	})

	wsConn.Stream(streamCtx, edgeConn, testLogger)

	waitErr := group.Wait()
	require.NoError(t, waitErr)
}
// wsEyeball adapts a net.Conn to Read/Write calls that exchange websocket
// binary frames: its Write sends client binary frames and its Read receives
// server binary frames, both via wsutil.
type wsEyeball struct {
conn net.Conn
}
// Read receives one server binary websocket message from the underlying
// connection and copies its payload into p.
//
// Each call consumes a whole message. If the payload is larger than p, the
// part that fits is copied and an error is returned, so the truncation is
// reported to the caller instead of silently dropping the remainder.
func (wse *wsEyeball) Read(p []byte) (int, error) {
	data, err := wsutil.ReadServerBinary(wse.conn)
	if err != nil {
		return 0, err
	}

	n := copy(p, data)
	if n < len(data) {
		return n, fmt.Errorf("wsEyeball: %d-byte message truncated to %d-byte read buffer", len(data), len(p))
	}
	return n, nil
}
// Write sends p as a single client binary websocket message on the underlying
// connection. It returns len(p) on success and 0 together with the error on
// failure, so a failed write is never reported as bytes written.
func (wse *wsEyeball) Write(p []byte) (int, error) {
	if err := wsutil.WriteClientBinary(wse.conn, p); err != nil {
		return 0, err
	}
	return len(p), nil
}
func echoWSEyeball(t *testing.T, conn net.Conn) {
require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
readMsg, err := wsutil.ReadServerBinary(conn)
require.NoError(t, err)
require.Equal(t, testResponse, readMsg)
require.NoError(t, conn.Close())
}
// echoWSOrigin starts a TLS test server that upgrades each request to a
// websocket and answers every testMessage it receives with testResponse. If
// the request carries the echoHeaderName header, its values are copied into
// the upgrade response. The caller is responsible for closing the returned
// server.
func echoWSOrigin(t *testing.T) *httptest.Server {
	var upgrader = gorillaWS.Upgrader{
		ReadBufferSize:  10,
		WriteBufferSize: 10,
	}

	ws := func(w http.ResponseWriter, r *http.Request) {
		header := make(http.Header)
		// Request header keys are stored in canonical form, which is the form
		// echoHeaderName is written in, so a direct map lookup is sufficient.
		if vs, ok := r.Header[echoHeaderName]; ok {
			header[echoHeaderName] = vs
		}

		// This handler runs on a server goroutine: report failures with assert
		// and return, because FailNow (used by require) must only be called
		// from the goroutine running the test.
		conn, err := upgrader.Upgrade(w, r, header)
		if !assert.NoError(t, err) {
			return
		}
		defer conn.Close()

		for {
			messageType, p, err := conn.ReadMessage()
			if err != nil {
				return
			}
			if !assert.Equal(t, testMessage, p) {
				return
			}
			if err := conn.WriteMessage(messageType, testResponse); err != nil {
				return
			}
		}
	}

	// NewTLSServer starts the server in another thread
	return httptest.NewTLSServer(http.HandlerFunc(ws))
}
// echoTCPOrigin plays a raw TCP echo origin on conn: it performs one read
// sized to testMessage, checks the bytes received, and then writes
// testResponse. Failures are recorded with assert, so the function always
// runs through to the write.
func echoTCPOrigin(t *testing.T, conn net.Conn) {
	received := make([]byte, len(testMessage))

	_, readErr := conn.Read(received)
	assert.NoError(t, readErr)
	assert.Equal(t, testMessage, received)

	_, writeErr := conn.Write(testResponse)
	assert.NoError(t, writeErr)
}