TUN-5621: Correctly manage QUIC stream closing

Until this PR, we were naively closing the quic.Stream whenever
the callstack for handling the request (HTTP or TCP) finished.
However, our proxy handler may still be reading or writing from
the quic.Stream at that point, because we return the callstack if
either side finishes, but not necessarily both.

This is a problem for quic-go library because quic.Stream#Close
cannot be called concurrently with quic.Stream#Write

Furthermore, we also noticed that quic.Stream#Close does nothing
to do receiving stream (since, underneath, quic.Stream has 2 streams,
1 for each direction), thus leaking memory, as explained in:
https://github.com/lucas-clemente/quic-go/issues/3322

This PR addresses both problems by wrapping the quic.Stream that
is passed down to the proxying logic and handle all these concerns.
This commit is contained in:
Nuno Diegues
2022-01-27 22:37:45 +00:00
parent e09dcf6d60
commit ed2bac026d
7 changed files with 244 additions and 45 deletions

View File

@@ -122,7 +122,7 @@ func (q *QUICConnection) serveControlStream(ctx context.Context, controlStream q
func (q *QUICConnection) acceptStream(ctx context.Context) error {
defer q.Close()
for {
stream, err := q.session.AcceptStream(ctx)
quicStream, err := q.session.AcceptStream(ctx)
if err != nil {
// context.Canceled is usually a user ctrl+c. We don't want to log an error here as it's intentional.
if errors.Is(err, context.Canceled) || q.controlStreamHandler.IsStopped() {
@@ -131,7 +131,9 @@ func (q *QUICConnection) acceptStream(ctx context.Context) error {
return fmt.Errorf("failed to accept QUIC stream: %w", err)
}
go func() {
stream := quicpogs.NewSafeStreamCloser(quicStream)
defer stream.Close()
if err = q.handleStream(stream); err != nil {
q.logger.Err(err).Msg("Failed to handle QUIC stream")
}
@@ -144,7 +146,7 @@ func (q *QUICConnection) Close() {
q.session.CloseWithError(0, "")
}
func (q *QUICConnection) handleStream(stream quic.Stream) error {
func (q *QUICConnection) handleStream(stream io.ReadWriteCloser) error {
signature, err := quicpogs.DetermineProtocol(stream)
if err != nil {
return err

View File

@@ -3,14 +3,9 @@ package connection
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/url"
@@ -33,7 +28,7 @@ import (
)
var (
testTLSServerConfig = generateTLSConfig()
testTLSServerConfig = quicpogs.GenerateTLSConfig()
testQUICConfig = &quic.Config{
KeepAlive: true,
EnableDatagrams: true,
@@ -84,7 +79,7 @@ func TestQUICServer(t *testing.T) {
},
{
desc: "test http body request streaming",
dest: "/echo_body",
dest: "/slow_echo_body",
connectionType: quicpogs.ConnectionTypeHTTP,
metadata: []quicpogs.Metadata{
{
@@ -195,8 +190,9 @@ func quicServer(
session, err := earlyListener.Accept(ctx)
require.NoError(t, err)
stream, err := session.OpenStreamSync(context.Background())
quicStream, err := session.OpenStreamSync(context.Background())
require.NoError(t, err)
stream := quicpogs.NewSafeStreamCloser(quicStream)
reqClientStream := quicpogs.RequestClientStream{ReadWriteCloser: stream}
err = reqClientStream.WriteConnectRequestData(dest, connectionType, metadata...)
@@ -207,42 +203,20 @@ func quicServer(
if message != nil {
// ALPN successful. Write data.
_, err := stream.Write([]byte(message))
_, err := stream.Write(message)
require.NoError(t, err)
}
response := make([]byte, len(expectedResponse))
stream.Read(response)
require.NoError(t, err)
_, err = stream.Read(response)
if err != io.EOF {
require.NoError(t, err)
}
// For now it is an echo server. Verify if the same data is returned.
assert.Equal(t, expectedResponse, response)
}
// Setup a bare-bones TLS config for the server
func generateTLSConfig() *tls.Config {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
panic(err)
}
template := x509.Certificate{SerialNumber: big.NewInt(1)}
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
if err != nil {
panic(err)
}
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
panic(err)
}
return &tls.Config{
Certificates: []tls.Certificate{tlsCert},
NextProtos: []string{"argotunnel"},
}
}
type mockOriginProxyWithRequest struct{}
func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Request, isWebsocket bool) error {
@@ -264,6 +238,9 @@ func (moc *mockOriginProxyWithRequest) ProxyHTTP(w ResponseWriter, r *http.Reque
switch r.URL.Path {
case "/ok":
originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
case "/slow_echo_body":
time.Sleep(5)
fallthrough
case "/echo_body":
resp := &http.Response{
StatusCode: http.StatusOK,