mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 18:59:57 +00:00
TUN-6676: Add suport for trailers in http2 connections
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -29,6 +28,8 @@ const (
|
||||
LogFieldRule = "ingressRule"
|
||||
LogFieldOriginService = "originService"
|
||||
LogFieldFlowID = "flowID"
|
||||
|
||||
trailerHeaderName = "Trailer"
|
||||
)
|
||||
|
||||
// Proxy represents a means to Proxy between cloudflared and the origin services.
|
||||
@@ -207,15 +208,16 @@ func (p *Proxy) proxyHTTPRequest(
|
||||
tracing.EndWithStatusCode(ttfbSpan, resp.StatusCode)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// resp headers can be nil
|
||||
if resp.Header == nil {
|
||||
resp.Header = make(http.Header)
|
||||
headers := make(http.Header, len(resp.Header))
|
||||
// copy headers
|
||||
for k, v := range resp.Header {
|
||||
headers[k] = v
|
||||
}
|
||||
|
||||
// Add spans to response header (if available)
|
||||
tr.AddSpans(resp.Header)
|
||||
tr.AddSpans(headers)
|
||||
|
||||
err = w.WriteRespHeaders(resp.StatusCode, resp.Header)
|
||||
err = w.WriteRespHeaders(resp.StatusCode, headers)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Error writing response header")
|
||||
}
|
||||
@@ -236,12 +238,10 @@ func (p *Proxy) proxyHTTPRequest(
|
||||
return nil
|
||||
}
|
||||
|
||||
if connection.IsServerSentEvent(resp.Header) {
|
||||
p.log.Debug().Msg("Detected Server-Side Events from Origin")
|
||||
p.writeEventStream(w, resp.Body)
|
||||
} else {
|
||||
_, _ = cfio.Copy(w, resp.Body)
|
||||
}
|
||||
_, _ = cfio.Copy(w, resp.Body)
|
||||
|
||||
// copy trailers
|
||||
copyTrailers(w, resp)
|
||||
|
||||
p.logOriginResponse(resp, fields)
|
||||
return nil
|
||||
@@ -296,26 +296,6 @@ func (wr *bidirectionalStream) Write(p []byte) (n int, err error) {
|
||||
return wr.writer.Write(p)
|
||||
}
|
||||
|
||||
func (p *Proxy) writeEventStream(w connection.ResponseWriter, respBody io.ReadCloser) {
|
||||
reader := bufio.NewReader(respBody)
|
||||
for {
|
||||
line, readErr := reader.ReadBytes('\n')
|
||||
|
||||
// We first try to write whatever we read even if an error occurred
|
||||
// The reason for doing it is to guarantee we really push everything to the eyeball side
|
||||
// before returning
|
||||
if len(line) > 0 {
|
||||
if _, writeErr := w.Write(line); writeErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) appendTagHeaders(r *http.Request) {
|
||||
for _, tag := range p.tags {
|
||||
r.Header.Add(TagHeaderNamePrefix+tag.Name, tag.Value)
|
||||
@@ -329,6 +309,14 @@ type logFields struct {
|
||||
flowID string
|
||||
}
|
||||
|
||||
func copyTrailers(w connection.ResponseWriter, response *http.Response) {
|
||||
for trailerHeader, trailerValues := range response.Trailer {
|
||||
for _, trailerValue := range trailerValues {
|
||||
w.AddTrailer(trailerHeader, trailerValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) logRequest(r *http.Request, fields logFields) {
|
||||
if fields.cfRay != "" {
|
||||
p.log.Debug().Msgf("CF-RAY: %s %s %s %s", fields.cfRay, r.Method, r.URL, r.Proto)
|
||||
|
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/urfave/cli/v2"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/cloudflare/cloudflared/cfio"
|
||||
|
||||
"github.com/cloudflare/cloudflared/config"
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/hello"
|
||||
@@ -62,6 +64,10 @@ func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *mockHTTPRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
|
||||
return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
|
||||
}
|
||||
@@ -117,7 +123,10 @@ func newMockSSERespWriter() *mockSSERespWriter {
|
||||
}
|
||||
|
||||
func (w *mockSSERespWriter) Write(data []byte) (int, error) {
|
||||
w.writeNotification <- data
|
||||
newData := make([]byte, len(data))
|
||||
copy(newData, data)
|
||||
|
||||
w.writeNotification <- newData
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
@@ -256,11 +265,8 @@ func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
|
||||
|
||||
for i := 0; i < pushCount; i++ {
|
||||
line := responseWriter.ReadBytes()
|
||||
expect := fmt.Sprintf("%d\n", i)
|
||||
expect := fmt.Sprintf("%d\n\n", i)
|
||||
require.Equal(t, []byte(expect), line, fmt.Sprintf("Expect to read %v, got %v", expect, line))
|
||||
|
||||
line = responseWriter.ReadBytes()
|
||||
require.Equal(t, []byte("\n"), line, fmt.Sprintf("Expect to read '\n', got %v", line))
|
||||
}
|
||||
|
||||
cancel()
|
||||
@@ -276,7 +282,7 @@ func testProxySSEAllData(proxy *Proxy) func(t *testing.T) {
|
||||
responseWriter := newMockSSERespWriter()
|
||||
|
||||
// responseWriter uses an unbuffered channel, so we call in a different go-routine
|
||||
go proxy.writeEventStream(responseWriter, eyeballReader)
|
||||
go cfio.Copy(responseWriter, eyeballReader)
|
||||
|
||||
result := string(<-responseWriter.writeNotification)
|
||||
require.Equal(t, "data\r\r", result)
|
||||
@@ -825,6 +831,10 @@ func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
// respHeaders is a test function to read respHeaders
|
||||
func (w *wsRespWriter) headers() http.Header {
|
||||
// Removing indeterminstic header because it cannot be asserted.
|
||||
@@ -852,6 +862,10 @@ func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
|
||||
return m.w.Write(p)
|
||||
}
|
||||
|
||||
func (w *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
m.responseHeaders = header
|
||||
m.code = status
|
||||
|
Reference in New Issue
Block a user