diff --git a/connection/quic.go b/connection/quic.go index a7f15e69..5de16dc2 100644 --- a/connection/quic.go +++ b/connection/quic.go @@ -9,6 +9,7 @@ import ( "net/http" "strconv" "strings" + "sync/atomic" "time" "github.com/google/uuid" @@ -156,9 +157,10 @@ func (q *QUICConnection) runStream(quicStream quic.Stream) { defer stream.Close() // we are going to fuse readers/writers from stream <- cloudflared -> origin, and we want to guarantee that - // code executed in the code path of handleStream don't trigger an earlier close to the downstream stream. - // So, we wrap the stream with a no-op closer and only this method can actually close the stream. - noCloseStream := &nopCloserReadWriter{stream} + // code executed in the code path of handleStream don't trigger an earlier close to the downstream write stream. + // So, we wrap the stream with a no-op write closer and only this method can actually close write side of the stream. + // A call to close will simulate a close to the read-side, which will fail subsequent reads. + noCloseStream := &nopCloserReadWriter{ReadWriteCloser: stream} if err := q.handleStream(ctx, noCloseStream); err != nil { q.logger.Err(err).Msg("Failed to handle QUIC stream") } @@ -408,10 +410,39 @@ func isTransferEncodingChunked(req *http.Request) bool { return strings.Contains(strings.ToLower(transferEncodingVal), "chunked") } +// A helper struct that guarantees a call to close only affects read side, but not write side. type nopCloserReadWriter struct { io.ReadWriteCloser + + // for use by Read only + // we don't need a memory barrier here because there is an implicit assumption that + // Read calls can't happen concurrently by different go-routines. + sawEOF bool + // should be updated and read using atomic primitives. + // value is read in Read method and written in Close method, which could be done by different + // go-routines. + closed uint32 } -func (n *nopCloserReadWriter) Close() error { +func (np *nopCloserReadWriter) Read(p []byte) (n int, err error) { + if np.sawEOF { + return 0, io.EOF + } + + if atomic.LoadUint32(&np.closed) > 0 { + return 0, fmt.Errorf("closed by handler") + } + + n, err = np.ReadWriteCloser.Read(p) + if err == io.EOF { + np.sawEOF = true + } + + return +} + +func (np *nopCloserReadWriter) Close() error { + atomic.StoreUint32(&np.closed, 1) + return nil } diff --git a/connection/quic_test.go b/connection/quic_test.go index d82947c2..0afb3953 100644 --- a/connection/quic_test.go +++ b/connection/quic_test.go @@ -10,6 +10,7 @@ import ( "net/http" "net/url" "os" + "strings" "sync" "testing" "time" @@ -527,6 +528,44 @@ func TestServeUDPSession(t *testing.T) { cancel() } +func TestNopCloserReadWriterCloseBeforeEOF(t *testing.T) { + readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}} + buffer := make([]byte, 5) + + n, err := readerWriter.Read(buffer) + require.NoError(t, err) + require.Equal(t, n, 5) + + // close + require.NoError(t, readerWriter.Close()) + + // read should get error + n, err = readerWriter.Read(buffer) + require.Equal(t, n, 0) + require.Equal(t, err, fmt.Errorf("closed by handler")) +} + +func TestNopCloserReadWriterCloseAfterEOF(t *testing.T) { + readerWriter := nopCloserReadWriter{ReadWriteCloser: &mockReaderNoopWriter{Reader: strings.NewReader("123456789")}} + buffer := make([]byte, 20) + + n, err := readerWriter.Read(buffer) + require.NoError(t, err) + require.Equal(t, n, 9) + + // force another read to read eof + n, err = readerWriter.Read(buffer) + require.Equal(t, err, io.EOF) + + // close + require.NoError(t, readerWriter.Close()) + + // read should get EOF still + n, err = readerWriter.Read(buffer) + require.Equal(t, n, 0) + require.Equal(t, err, io.EOF) +} + func serveSession(ctx context.Context, qc *QUICConnection, edgeQUICSession quic.Connection, closeType closeReason, expectedReason string, t *testing.T) { var ( payload = []byte(t.Name()) @@ -647,3 +686,15 @@ func testQUICConnection(udpListenerAddr net.Addr, t *testing.T) *QUICConnection require.NoError(t, err) return qc } + +type mockReaderNoopWriter struct { + io.Reader +} + +func (m *mockReaderNoopWriter) Write(p []byte) (n int, err error) { + return len(p), nil +} + +func (m *mockReaderNoopWriter) Close() error { + return nil +}