mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 16:49:57 +00:00
TUN-8236: Add write timeout to quic and tcp connections
## Summary To prevent bad eyeballs and severs to be able to exhaust the quic control flows we are adding the possibility of having a timeout for a write operation to be acknowledged. This will prevent hanging connections from exhausting the quic control flows, creating a DDoS.
This commit is contained in:
7
ingress/constants_test.go
Normal file
7
ingress/constants_test.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ingress
|
||||
|
||||
import "github.com/cloudflare/cloudflared/logger"
|
||||
|
||||
var (
|
||||
TestLogger = logger.Create(nil)
|
||||
)
|
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
@@ -31,15 +32,32 @@ func DefaultStreamHandler(originConn io.ReadWriter, remoteConn net.Conn, log *ze
|
||||
|
||||
// tcpConnection is an OriginConnection that directly streams to raw TCP.
|
||||
type tcpConnection struct {
|
||||
conn net.Conn
|
||||
net.Conn
|
||||
writeTimeout time.Duration
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Stream(ctx context.Context, tunnelConn io.ReadWriter, log *zerolog.Logger) {
|
||||
stream.Pipe(tunnelConn, tc.conn, log)
|
||||
func (tc *tcpConnection) Stream(_ context.Context, tunnelConn io.ReadWriter, _ *zerolog.Logger) {
|
||||
stream.Pipe(tunnelConn, tc, tc.logger)
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Write(b []byte) (int, error) {
|
||||
if tc.writeTimeout > 0 {
|
||||
if err := tc.Conn.SetWriteDeadline(time.Now().Add(tc.writeTimeout)); err != nil {
|
||||
tc.logger.Err(err).Msg("Error setting write deadline for TCP connection")
|
||||
}
|
||||
}
|
||||
|
||||
nBytes, err := tc.Conn.Write(b)
|
||||
if err != nil {
|
||||
tc.logger.Err(err).Msg("Error writing to the TCP connection")
|
||||
}
|
||||
|
||||
return nBytes, err
|
||||
}
|
||||
|
||||
func (tc *tcpConnection) Close() {
|
||||
tc.conn.Close()
|
||||
tc.Conn.Close()
|
||||
}
|
||||
|
||||
// tcpOverWSConnection is an OriginConnection that streams to TCP over WS.
|
||||
|
@@ -19,7 +19,6 @@ import (
|
||||
"golang.org/x/net/proxy"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/socks"
|
||||
"github.com/cloudflare/cloudflared/stream"
|
||||
"github.com/cloudflare/cloudflared/websocket"
|
||||
@@ -31,7 +30,6 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
testLogger = logger.Create(nil)
|
||||
testMessage = []byte("TestStreamOriginConnection")
|
||||
testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
|
||||
)
|
||||
@@ -39,7 +37,8 @@ var (
|
||||
func TestStreamTCPConnection(t *testing.T) {
|
||||
cfdConn, originConn := net.Pipe()
|
||||
tcpConn := tcpConnection{
|
||||
conn: cfdConn,
|
||||
Conn: cfdConn,
|
||||
writeTimeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
eyeballConn, edgeConn := net.Pipe()
|
||||
@@ -66,7 +65,7 @@ func TestStreamTCPConnection(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
tcpConn.Stream(ctx, edgeConn, testLogger)
|
||||
tcpConn.Stream(ctx, edgeConn, TestLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
@@ -93,7 +92,7 @@ func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
|
||||
return nil
|
||||
})
|
||||
|
||||
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
||||
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
|
||||
require.NoError(t, errGroup.Wait())
|
||||
}
|
||||
|
||||
@@ -147,7 +146,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||
|
||||
errGroup, ctx := errgroup.WithContext(ctx)
|
||||
errGroup.Go(func() error {
|
||||
tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
|
||||
tcpOverWSConn.Stream(ctx, edgeConn, TestLogger)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -159,7 +158,7 @@ func TestSocksStreamWSOverTCPConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
defer wsForwarderInConn.Close()
|
||||
|
||||
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
|
||||
stream.Pipe(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, TestLogger)
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -209,7 +208,7 @@ func TestWsConnReturnsBeforeStreamReturns(t *testing.T) {
|
||||
originConn.Close()
|
||||
}()
|
||||
ctx := context.WithValue(r.Context(), websocket.PingPeriodContextKey, time.Microsecond)
|
||||
tcpOverWSConn.Stream(ctx, eyeballConn, testLogger)
|
||||
tcpOverWSConn.Stream(ctx, eyeballConn, TestLogger)
|
||||
})
|
||||
server := httptest.NewServer(handler)
|
||||
defer server.Close()
|
||||
|
@@ -4,6 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
// HTTPOriginProxy can be implemented by origin services that want to proxy http requests.
|
||||
@@ -14,7 +16,7 @@ type HTTPOriginProxy interface {
|
||||
|
||||
// StreamBasedOriginProxy can be implemented by origin services that want to proxy ws/TCP.
|
||||
type StreamBasedOriginProxy interface {
|
||||
EstablishConnection(ctx context.Context, dest string) (OriginConnection, error)
|
||||
EstablishConnection(ctx context.Context, dest string, log *zerolog.Logger) (OriginConnection, error)
|
||||
}
|
||||
|
||||
// HTTPLocalProxy can be implemented by cloudflared services that want to handle incoming http requests.
|
||||
@@ -62,19 +64,21 @@ func (o *statusCode) RoundTrip(_ *http.Request) (*http.Response, error) {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
|
||||
func (o *rawTCPService) EstablishConnection(ctx context.Context, dest string, logger *zerolog.Logger) (OriginConnection, error) {
|
||||
conn, err := o.dialer.DialContext(ctx, "tcp", dest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
originConn := &tcpConnection{
|
||||
conn: conn,
|
||||
Conn: conn,
|
||||
writeTimeout: o.writeTimeout,
|
||||
logger: logger,
|
||||
}
|
||||
return originConn, nil
|
||||
}
|
||||
|
||||
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string) (OriginConnection, error) {
|
||||
func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string, _ *zerolog.Logger) (OriginConnection, error) {
|
||||
var err error
|
||||
if !o.isBastion {
|
||||
dest = o.dest
|
||||
@@ -92,6 +96,6 @@ func (o *tcpOverWSService) EstablishConnection(ctx context.Context, dest string)
|
||||
|
||||
}
|
||||
|
||||
func (o *socksProxyOverWSService) EstablishConnection(_ctx context.Context, _dest string) (OriginConnection, error) {
|
||||
func (o *socksProxyOverWSService) EstablishConnection(_ context.Context, _ string, _ *zerolog.Logger) (OriginConnection, error) {
|
||||
return o.conn, nil
|
||||
}
|
||||
|
@@ -36,7 +36,7 @@ func TestRawTCPServiceEstablishConnection(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Origin not listening for new connection, should return an error
|
||||
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String())
|
||||
_, err = rawTCPService.EstablishConnection(context.Background(), req.URL.String(), TestLogger)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
t.Run(test.testCase, func(t *testing.T) {
|
||||
if test.expectErr {
|
||||
bastionHost, _ := carrier.ResolveBastionDest(test.req)
|
||||
_, err := test.service.EstablishConnection(context.Background(), bastionHost)
|
||||
_, err := test.service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
@@ -99,7 +99,7 @@ func TestTCPOverWSServiceEstablishConnection(t *testing.T) {
|
||||
for _, service := range []*tcpOverWSService{newTCPOverWSService(originURL), newBastionService()} {
|
||||
// Origin not listening for new connection, should return an error
|
||||
bastionHost, _ := carrier.ResolveBastionDest(bastionReq)
|
||||
_, err := service.EstablishConnection(context.Background(), bastionHost)
|
||||
_, err := service.EstablishConnection(context.Background(), bastionHost, TestLogger)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func TestHTTPServiceHostHeaderOverride(t *testing.T) {
|
||||
url: originURL,
|
||||
}
|
||||
shutdownC := make(chan struct{})
|
||||
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
|
||||
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, originURL.String(), nil)
|
||||
require.NoError(t, err)
|
||||
@@ -167,7 +167,7 @@ func TestHTTPServiceUsesIngressRuleScheme(t *testing.T) {
|
||||
url: originURL,
|
||||
}
|
||||
shutdownC := make(chan struct{})
|
||||
require.NoError(t, httpService.start(testLogger, shutdownC, cfg))
|
||||
require.NoError(t, httpService.start(TestLogger, shutdownC, cfg))
|
||||
|
||||
// Tunnel uses scheme defined in the service field of the ingress rule, independent of the X-Forwarded-Proto header
|
||||
protos := []string{"https", "http", "dne"}
|
||||
|
@@ -94,15 +94,17 @@ func (o httpService) MarshalJSON() ([]byte, error) {
|
||||
// rawTCPService dials TCP to the destination specified by the client
|
||||
// It's used by warp routing
|
||||
type rawTCPService struct {
|
||||
name string
|
||||
dialer net.Dialer
|
||||
name string
|
||||
dialer net.Dialer
|
||||
writeTimeout time.Duration
|
||||
logger *zerolog.Logger
|
||||
}
|
||||
|
||||
func (o *rawTCPService) String() string {
|
||||
return o.name
|
||||
}
|
||||
|
||||
func (o *rawTCPService) start(log *zerolog.Logger, _ <-chan struct{}, cfg OriginRequestConfig) error {
|
||||
func (o *rawTCPService) start(_ *zerolog.Logger, _ <-chan struct{}, _ OriginRequestConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -285,13 +287,14 @@ type WarpRoutingService struct {
|
||||
Proxy StreamBasedOriginProxy
|
||||
}
|
||||
|
||||
func NewWarpRoutingService(config WarpRoutingConfig) *WarpRoutingService {
|
||||
func NewWarpRoutingService(config WarpRoutingConfig, writeTimeout time.Duration) *WarpRoutingService {
|
||||
svc := &rawTCPService{
|
||||
name: ServiceWarpRouting,
|
||||
dialer: net.Dialer{
|
||||
Timeout: config.ConnectTimeout.Duration,
|
||||
KeepAlive: config.TCPKeepAlive.Duration,
|
||||
},
|
||||
writeTimeout: writeTimeout,
|
||||
}
|
||||
|
||||
return &WarpRoutingService{Proxy: svc}
|
||||
|
Reference in New Issue
Block a user