mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 20:59:58 +00:00
TUN-3863: Consolidate header handling logic in the connection package; move headers definitions from h2mux to packages that manage them; cleanup header conversions
All header transformation code from h2mux has been consolidated in the connection package since it's used by both h2mux and http2 logic. Exported headers used by proxying between edge and cloudflared so then can be shared by tunnel service on the edge. Moved access-related headers to corresponding packages that have the code that sets/uses these headers. Removed tunnel hostname tracking from h2mux since it wasn't used by anything. We will continue to set the tunnel hostname header from the edge for backward compatibilty, but it's no longer used by cloudflared. Move bastion-related logic into carrier package, untangled dependencies between carrier, origin, and websocket packages.
This commit is contained in:
@@ -234,7 +234,7 @@ func (h *h2muxConnection) newRequest(stream *h2mux.MuxedStream) (*http.Request,
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Unexpected error from http.NewRequest")
|
||||
}
|
||||
err = h2mux.H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
err = H2RequestHeadersToH1Request(stream.Headers, req)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "invalid request received")
|
||||
}
|
||||
@@ -246,15 +246,15 @@ type h2muxRespWriter struct {
|
||||
}
|
||||
|
||||
func (rp *h2muxRespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
headers := h2mux.H1ResponseToH2ResponseHeaders(status, header)
|
||||
headers = append(headers, h2mux.Header{Name: ResponseMetaHeaderField, Value: responseMetaHeaderOrigin})
|
||||
headers := H1ResponseToH2ResponseHeaders(status, header)
|
||||
headers = append(headers, h2mux.Header{Name: ResponseMetaHeader, Value: responseMetaHeaderOrigin})
|
||||
return rp.WriteHeaders(headers)
|
||||
}
|
||||
|
||||
func (rp *h2muxRespWriter) WriteErrorResponse() {
|
||||
_ = rp.WriteHeaders([]h2mux.Header{
|
||||
{Name: ":status", Value: "502"},
|
||||
{Name: ResponseMetaHeaderField, Value: responseMetaHeaderCfd},
|
||||
{Name: ResponseMetaHeader, Value: responseMetaHeaderCfd},
|
||||
})
|
||||
_, _ = rp.Write([]byte("502 Bad Gateway"))
|
||||
}
|
||||
|
@@ -115,9 +115,9 @@ func TestServeStreamHTTP(t *testing.T) {
|
||||
require.True(t, hasHeader(stream, ":status", strconv.Itoa(test.expectedStatus)))
|
||||
|
||||
if test.isProxyError {
|
||||
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderCfd))
|
||||
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderCfd))
|
||||
} else {
|
||||
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
|
||||
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||
body := make([]byte, len(test.expectedBody))
|
||||
_, err = stream.Read(body)
|
||||
require.NoError(t, err)
|
||||
@@ -164,7 +164,7 @@ func TestServeStreamWS(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, hasHeader(stream, ":status", strconv.Itoa(http.StatusSwitchingProtocols)))
|
||||
assert.True(t, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
|
||||
assert.True(t, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||
|
||||
data := []byte("test websocket")
|
||||
err = wsutil.WriteClientText(writePipe, data)
|
||||
@@ -268,7 +268,7 @@ func benchmarkServeStreamHTTPSimple(b *testing.B, test testRequest) {
|
||||
b.StopTimer()
|
||||
|
||||
require.NoError(b, openstreamErr)
|
||||
assert.True(b, hasHeader(stream, ResponseMetaHeaderField, responseMetaHeaderOrigin))
|
||||
assert.True(b, hasHeader(stream, ResponseMetaHeader, responseMetaHeaderOrigin))
|
||||
require.True(b, hasHeader(stream, ":status", strconv.Itoa(http.StatusOK)))
|
||||
require.NoError(b, readBodyErr)
|
||||
require.Equal(b, test.expectedBody, body)
|
||||
|
@@ -1,21 +1,33 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
const (
|
||||
ResponseMetaHeaderField = "cf-cloudflared-response-meta"
|
||||
var (
|
||||
// h2mux-style special headers
|
||||
RequestUserHeaders = "cf-cloudflared-request-headers"
|
||||
ResponseUserHeaders = "cf-cloudflared-response-headers"
|
||||
ResponseMetaHeader = "cf-cloudflared-response-meta"
|
||||
|
||||
// h2mux-style special headers
|
||||
CanonicalResponseUserHeaders = http.CanonicalHeaderKey(ResponseUserHeaders)
|
||||
CanonicalResponseMetaHeader = http.CanonicalHeaderKey(ResponseMetaHeader)
|
||||
)
|
||||
|
||||
var (
|
||||
canonicalResponseUserHeadersField = http.CanonicalHeaderKey(h2mux.ResponseUserHeadersField)
|
||||
canonicalResponseMetaHeaderField = http.CanonicalHeaderKey(ResponseMetaHeaderField)
|
||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||
// pre-generate possible values for res
|
||||
responseMetaHeaderCfd = mustInitRespMetaHeader("cloudflared")
|
||||
responseMetaHeaderOrigin = mustInitRespMetaHeader("origin")
|
||||
)
|
||||
|
||||
type responseMetaHeader struct {
|
||||
@@ -29,3 +41,204 @@ func mustInitRespMetaHeader(src string) string {
|
||||
}
|
||||
return string(header)
|
||||
}
|
||||
|
||||
var headerEncoding = base64.RawStdEncoding
|
||||
|
||||
// note: all h2mux headers should be lower-case (http/2 style)
|
||||
const ()
|
||||
|
||||
// H2RequestHeadersToH1Request converts the HTTP/2 headers coming from origintunneld
|
||||
// to an HTTP/1 Request object destined for the local origin web service.
|
||||
// This operation includes conversion of the pseudo-headers into their closest
|
||||
// HTTP/1 equivalents. See https://tools.ietf.org/html/rfc7540#section-8.1.2.3
|
||||
func H2RequestHeadersToH1Request(h2 []h2mux.Header, h1 *http.Request) error {
|
||||
for _, header := range h2 {
|
||||
name := strings.ToLower(header.Name)
|
||||
if !IsControlHeader(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
switch name {
|
||||
case ":method":
|
||||
h1.Method = header.Value
|
||||
case ":scheme":
|
||||
// noop - use the preexisting scheme from h1.URL
|
||||
case ":authority":
|
||||
// Otherwise the host header will be based on the origin URL
|
||||
h1.Host = header.Value
|
||||
case ":path":
|
||||
// We don't want to be an "opinionated" proxy, so ideally we would use :path as-is.
|
||||
// However, this HTTP/1 Request object belongs to the Go standard library,
|
||||
// whose URL package makes some opinionated decisions about the encoding of
|
||||
// URL characters: see the docs of https://godoc.org/net/url#URL,
|
||||
// in particular the EscapedPath method https://godoc.org/net/url#URL.EscapedPath,
|
||||
// which is always used when computing url.URL.String(), whether we'd like it or not.
|
||||
//
|
||||
// Well, not *always*. We could circumvent this by using url.URL.Opaque. But
|
||||
// that would present unusual difficulties when using an HTTP proxy: url.URL.Opaque
|
||||
// is treated differently when HTTP_PROXY is set!
|
||||
// See https://github.com/golang/go/issues/5684#issuecomment-66080888
|
||||
//
|
||||
// This means we are subject to the behavior of net/url's function `shouldEscape`
|
||||
// (as invoked with mode=encodePath): https://github.com/golang/go/blob/go1.12.7/src/net/url/url.go#L101
|
||||
|
||||
if header.Value == "*" {
|
||||
h1.URL.Path = "*"
|
||||
continue
|
||||
}
|
||||
// Due to the behavior of validation.ValidateUrl, h1.URL may
|
||||
// already have a partial value, with or without a trailing slash.
|
||||
base := h1.URL.String()
|
||||
base = strings.TrimRight(base, "/")
|
||||
// But we know :path begins with '/', because we handled '*' above - see RFC7540
|
||||
requestURL, err := url.Parse(base + header.Value)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, fmt.Sprintf("invalid path '%v'", header.Value))
|
||||
}
|
||||
h1.URL = requestURL
|
||||
case "content-length":
|
||||
contentLength, err := strconv.ParseInt(header.Value, 10, 64)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unparseable content length")
|
||||
}
|
||||
h1.ContentLength = contentLength
|
||||
case RequestUserHeaders:
|
||||
// Do not forward the serialized headers to the origin -- deserialize them, and ditch the serialized version
|
||||
// Find and parse user headers serialized into a single one
|
||||
userHeaders, err := DeserializeHeaders(header.Value)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "Unable to parse user headers")
|
||||
}
|
||||
for _, userHeader := range userHeaders {
|
||||
h1.Header.Add(userHeader.Name, userHeader.Value)
|
||||
}
|
||||
default:
|
||||
// All other control headers shall just be proxied transparently
|
||||
h1.Header.Add(header.Name, header.Value)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsControlHeader(headerName string) bool {
|
||||
return headerName == "content-length" ||
|
||||
headerName == "connection" || headerName == "upgrade" || // Websocket headers
|
||||
strings.HasPrefix(headerName, ":") ||
|
||||
strings.HasPrefix(headerName, "cf-")
|
||||
}
|
||||
|
||||
// isWebsocketClientHeader returns true if the header name is required by the client to upgrade properly
|
||||
func IsWebsocketClientHeader(headerName string) bool {
|
||||
return headerName == "sec-websocket-accept" ||
|
||||
headerName == "connection" ||
|
||||
headerName == "upgrade"
|
||||
}
|
||||
|
||||
func H1ResponseToH2ResponseHeaders(status int, h1 http.Header) (h2 []h2mux.Header) {
|
||||
h2 = []h2mux.Header{
|
||||
{Name: ":status", Value: strconv.Itoa(status)},
|
||||
}
|
||||
userHeaders := make(http.Header, len(h1))
|
||||
for header, values := range h1 {
|
||||
h2name := strings.ToLower(header)
|
||||
if h2name == "content-length" {
|
||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||
// so it should be sent as an HTTP/2 response header.
|
||||
|
||||
// Since these are http2 headers, they're required to be lowercase
|
||||
h2 = append(h2, h2mux.Header{Name: "content-length", Value: values[0]})
|
||||
} else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) {
|
||||
// User headers, on the other hand, must all be serialized so that
|
||||
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||
userHeaders[header] = values
|
||||
}
|
||||
}
|
||||
|
||||
// Perform user header serialization and set them in the single header
|
||||
h2 = append(h2, h2mux.Header{Name: ResponseUserHeaders, Value: SerializeHeaders(userHeaders)})
|
||||
return h2
|
||||
}
|
||||
|
||||
// Serialize HTTP1.x headers by base64-encoding each header name and value,
|
||||
// and then joining them in the format of [key:value;]
|
||||
func SerializeHeaders(h1Headers http.Header) string {
|
||||
// compute size of the fully serialized value and largest temp buffer we will need
|
||||
serializedLen := 0
|
||||
maxTempLen := 0
|
||||
for headerName, headerValues := range h1Headers {
|
||||
for _, headerValue := range headerValues {
|
||||
nameLen := headerEncoding.EncodedLen(len(headerName))
|
||||
valueLen := headerEncoding.EncodedLen(len(headerValue))
|
||||
const delims = 2
|
||||
serializedLen += delims + nameLen + valueLen
|
||||
if nameLen > maxTempLen {
|
||||
maxTempLen = nameLen
|
||||
}
|
||||
if valueLen > maxTempLen {
|
||||
maxTempLen = valueLen
|
||||
}
|
||||
}
|
||||
}
|
||||
var buf strings.Builder
|
||||
buf.Grow(serializedLen)
|
||||
|
||||
temp := make([]byte, maxTempLen)
|
||||
writeB64 := func(s string) {
|
||||
n := headerEncoding.EncodedLen(len(s))
|
||||
if n > len(temp) {
|
||||
temp = make([]byte, n)
|
||||
}
|
||||
headerEncoding.Encode(temp[:n], []byte(s))
|
||||
buf.Write(temp[:n])
|
||||
}
|
||||
|
||||
for headerName, headerValues := range h1Headers {
|
||||
for _, headerValue := range headerValues {
|
||||
if buf.Len() > 0 {
|
||||
buf.WriteByte(';')
|
||||
}
|
||||
writeB64(headerName)
|
||||
buf.WriteByte(':')
|
||||
writeB64(headerValue)
|
||||
}
|
||||
}
|
||||
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// Deserialize headers serialized by `SerializeHeader`
|
||||
func DeserializeHeaders(serializedHeaders string) ([]h2mux.Header, error) {
|
||||
const unableToDeserializeErr = "Unable to deserialize headers"
|
||||
|
||||
var deserialized []h2mux.Header
|
||||
for _, serializedPair := range strings.Split(serializedHeaders, ";") {
|
||||
if len(serializedPair) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
serializedHeaderParts := strings.Split(serializedPair, ":")
|
||||
if len(serializedHeaderParts) != 2 {
|
||||
return nil, errors.New(unableToDeserializeErr)
|
||||
}
|
||||
|
||||
serializedName := serializedHeaderParts[0]
|
||||
serializedValue := serializedHeaderParts[1]
|
||||
deserializedName := make([]byte, headerEncoding.DecodedLen(len(serializedName)))
|
||||
deserializedValue := make([]byte, headerEncoding.DecodedLen(len(serializedValue)))
|
||||
|
||||
if _, err := headerEncoding.Decode(deserializedName, []byte(serializedName)); err != nil {
|
||||
return nil, errors.Wrap(err, unableToDeserializeErr)
|
||||
}
|
||||
if _, err := headerEncoding.Decode(deserializedValue, []byte(serializedValue)); err != nil {
|
||||
return nil, errors.Wrap(err, unableToDeserializeErr)
|
||||
}
|
||||
|
||||
deserialized = append(deserialized, h2mux.Header{
|
||||
Name: string(deserializedName),
|
||||
Value: string(deserializedValue),
|
||||
})
|
||||
}
|
||||
|
||||
return deserialized, nil
|
||||
}
|
||||
|
677
connection/header_test.go
Normal file
677
connection/header_test.go
Normal file
@@ -0,0 +1,677 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
)
|
||||
|
||||
type ByName []h2mux.Header
|
||||
|
||||
func (a ByName) Len() int { return len(a) }
|
||||
func (a ByName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByName) Less(i, j int) bool {
|
||||
if a[i].Name == a[j].Name {
|
||||
return a[i].Value < a[j].Value
|
||||
}
|
||||
|
||||
return a[i].Name < a[j].Name
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_RegularHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockHeaders := http.Header{
|
||||
"Mock header 1": {"Mock value 1"},
|
||||
"Mock header 2": {"Mock value 2"},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(createSerializedHeaders(RequestUserHeaders, mockHeaders), request)
|
||||
|
||||
assert.True(t, reflect.DeepEqual(mockHeaders, request.Header))
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func createSerializedHeaders(headersField string, headers http.Header) []h2mux.Header {
|
||||
return []h2mux.Header{{
|
||||
Name: headersField,
|
||||
Value: SerializeHeaders(headers),
|
||||
}}
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_NoHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
emptyHeaders := make(http.Header)
|
||||
headersConversionErr := H2RequestHeadersToH1Request(
|
||||
[]h2mux.Header{{
|
||||
Name: RequestUserHeaders,
|
||||
Value: SerializeHeaders(emptyHeaders),
|
||||
}},
|
||||
request,
|
||||
)
|
||||
|
||||
assert.True(t, reflect.DeepEqual(emptyHeaders, request.Header))
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_InvalidHostPath(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: "//bad_path/"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com//bad_path/", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_HostPathWithQuery(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: "/?query=mock%20value"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com/?query=mock%20value", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_HostPathWithURLEncoding(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: "/mock%20path"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
|
||||
assert.Equal(t, http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
}, request.Header)
|
||||
|
||||
assert.Equal(t, "http://example.com/mock%20path", request.URL.String())
|
||||
|
||||
assert.NoError(t, headersConversionErr)
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_WeirdURLs(t *testing.T) {
|
||||
type testCase struct {
|
||||
path string
|
||||
want string
|
||||
}
|
||||
testCases := []testCase{
|
||||
{
|
||||
path: "",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
path: "/",
|
||||
want: "/",
|
||||
},
|
||||
{
|
||||
path: "//",
|
||||
want: "//",
|
||||
},
|
||||
{
|
||||
path: "/test",
|
||||
want: "/test",
|
||||
},
|
||||
{
|
||||
path: "//test",
|
||||
want: "//test",
|
||||
},
|
||||
{
|
||||
// https://github.com/cloudflare/cloudflared/issues/81
|
||||
path: "//test/",
|
||||
want: "//test/",
|
||||
},
|
||||
{
|
||||
path: "/%2Ftest",
|
||||
want: "/%2Ftest",
|
||||
},
|
||||
{
|
||||
path: "//%20test",
|
||||
want: "//%20test",
|
||||
},
|
||||
{
|
||||
// https://github.com/cloudflare/cloudflared/issues/124
|
||||
path: "/test?get=somthing%20a",
|
||||
want: "/test?get=somthing%20a",
|
||||
},
|
||||
{
|
||||
path: "/%20",
|
||||
want: "/%20",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode ' '
|
||||
path: "/ ",
|
||||
want: "/%20",
|
||||
},
|
||||
{
|
||||
path: "/ a ",
|
||||
want: "/%20a%20",
|
||||
},
|
||||
{
|
||||
path: "/a%20b",
|
||||
want: "/a%20b",
|
||||
},
|
||||
{
|
||||
path: "/foo/bar;param?query#frag",
|
||||
want: "/foo/bar;param?query#frag",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode non-ASCII chars
|
||||
path: "/a␠b",
|
||||
want: "/a%E2%90%A0b",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-ä",
|
||||
want: "/a-umlaut-%C3%A4",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-%C3%A4",
|
||||
want: "/a-umlaut-%C3%A4",
|
||||
},
|
||||
{
|
||||
path: "/a-umlaut-%c3%a4",
|
||||
want: "/a-umlaut-%c3%a4",
|
||||
},
|
||||
{
|
||||
// here the second '#' is treated as part of the fragment
|
||||
path: "/a#b#c",
|
||||
want: "/a#b%23c",
|
||||
},
|
||||
{
|
||||
path: "/a#b␠c",
|
||||
want: "/a#b%E2%90%A0c",
|
||||
},
|
||||
{
|
||||
path: "/a#b%20c",
|
||||
want: "/a#b%20c",
|
||||
},
|
||||
{
|
||||
path: "/a#b c",
|
||||
want: "/a#b%20c",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode '\'
|
||||
path: "/\\",
|
||||
want: "/%5C",
|
||||
},
|
||||
{
|
||||
path: "/a\\",
|
||||
want: "/a%5C",
|
||||
},
|
||||
{
|
||||
path: "/a,b.c.",
|
||||
want: "/a,b.c.",
|
||||
},
|
||||
{
|
||||
path: "/.",
|
||||
want: "/.",
|
||||
},
|
||||
{
|
||||
// stdlib's EscapedPath() will always percent-encode '`'
|
||||
path: "/a`",
|
||||
want: "/a%60",
|
||||
},
|
||||
{
|
||||
path: "/a[0]",
|
||||
want: "/a[0]",
|
||||
},
|
||||
{
|
||||
path: "/?a[0]=5 &b[]=",
|
||||
want: "/?a[0]=5 &b[]=",
|
||||
},
|
||||
{
|
||||
path: "/?a=%22b%20%22",
|
||||
want: "/?a=%22b%20%22",
|
||||
},
|
||||
}
|
||||
|
||||
for index, testCase := range testCases {
|
||||
requestURL := "https://example.com"
|
||||
|
||||
request, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockRequestHeaders := []h2mux.Header{
|
||||
{Name: ":path", Value: testCase.path},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(http.Header{"Mock header": {"Mock value"}})},
|
||||
}
|
||||
|
||||
headersConversionErr := H2RequestHeadersToH1Request(mockRequestHeaders, request)
|
||||
assert.NoError(t, headersConversionErr)
|
||||
|
||||
assert.Equal(t,
|
||||
http.Header{
|
||||
"Mock header": []string{"Mock value"},
|
||||
},
|
||||
request.Header)
|
||||
|
||||
assert.Equal(t,
|
||||
"https://example.com"+testCase.want,
|
||||
request.URL.String(),
|
||||
"Failed URL index: %v %#v", index, testCase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestH2RequestHeadersToH1Request_QuickCheck(t *testing.T) {
|
||||
config := &quick.Config{
|
||||
Values: func(args []reflect.Value, rand *rand.Rand) {
|
||||
args[0] = reflect.ValueOf(randomHTTP2Path(t, rand))
|
||||
},
|
||||
}
|
||||
|
||||
type testOrigin struct {
|
||||
url string
|
||||
|
||||
expectedScheme string
|
||||
expectedBasePath string
|
||||
}
|
||||
testOrigins := []testOrigin{
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/api",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
{
|
||||
url: "http://origin.hostname.example.com:8080/api/",
|
||||
expectedScheme: "http",
|
||||
expectedBasePath: "http://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
{
|
||||
url: "https://origin.hostname.example.com:8080/api",
|
||||
expectedScheme: "https",
|
||||
expectedBasePath: "https://origin.hostname.example.com:8080/api",
|
||||
},
|
||||
}
|
||||
|
||||
// use multiple schemes to demonstrate that the URL is based on the
|
||||
// origin's scheme, not the :scheme header
|
||||
for _, testScheme := range []string{"http", "https"} {
|
||||
for _, testOrigin := range testOrigins {
|
||||
assertion := func(testPath string) bool {
|
||||
const expectedMethod = "POST"
|
||||
const expectedHostname = "request.hostname.example.com"
|
||||
|
||||
h2 := []h2mux.Header{
|
||||
{Name: ":method", Value: expectedMethod},
|
||||
{Name: ":scheme", Value: testScheme},
|
||||
{Name: ":authority", Value: expectedHostname},
|
||||
{Name: ":path", Value: testPath},
|
||||
{Name: RequestUserHeaders, Value: ""},
|
||||
}
|
||||
h1, err := http.NewRequest("GET", testOrigin.url, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = H2RequestHeadersToH1Request(h2, h1)
|
||||
return assert.NoError(t, err) &&
|
||||
assert.Equal(t, expectedMethod, h1.Method) &&
|
||||
assert.Equal(t, expectedHostname, h1.Host) &&
|
||||
assert.Equal(t, testOrigin.expectedScheme, h1.URL.Scheme) &&
|
||||
assert.Equal(t, testOrigin.expectedBasePath+testPath, h1.URL.String())
|
||||
}
|
||||
err := quick.Check(assertion, config)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func randomASCIIPrintableChar(rand *rand.Rand) int {
|
||||
// smallest printable ASCII char is 32, largest is 126
|
||||
const startPrintable = 32
|
||||
const endPrintable = 127
|
||||
return startPrintable + rand.Intn(endPrintable-startPrintable)
|
||||
}
|
||||
|
||||
// randomASCIIText generates an ASCII string, some of whose characters may be
|
||||
// percent-encoded. Its "logical length" (ignoring percent-encoding) is
|
||||
// between 1 and `maxLength`.
|
||||
func randomASCIIText(rand *rand.Rand, minLength int, maxLength int) string {
|
||||
length := minLength + rand.Intn(maxLength)
|
||||
var result strings.Builder
|
||||
for i := 0; i < length; i++ {
|
||||
c := randomASCIIPrintableChar(rand)
|
||||
|
||||
// 1/4 chance of using percent encoding when not necessary
|
||||
if c == '%' || rand.Intn(4) == 0 {
|
||||
result.WriteString(fmt.Sprintf("%%%02X", c))
|
||||
} else {
|
||||
result.WriteByte(byte(c))
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL path,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Path(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
re, err := regexp.Compile("[^/;,]*")
|
||||
require.NoError(t, err)
|
||||
return "/" + re.ReplaceAllStringFunc(text, url.PathEscape)
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL query,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Query(rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
return "?" + strings.ReplaceAll(text, "#", "%23")
|
||||
}
|
||||
|
||||
// Calls `randomASCIIText` and ensures the result is a valid URL fragment,
|
||||
// i.e. one that can pass unchanged through url.URL.String()
|
||||
func randomHTTP1Fragment(t *testing.T, rand *rand.Rand, minLength int, maxLength int) string {
|
||||
text := randomASCIIText(rand, minLength, maxLength)
|
||||
u, err := url.Parse("#" + text)
|
||||
require.NoError(t, err)
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// Assemble a random :path pseudoheader that is legal by Go stdlib standards
|
||||
// (i.e. all characters will satisfy "net/url".shouldEscape for their respective locations)
|
||||
func randomHTTP2Path(t *testing.T, rand *rand.Rand) string {
|
||||
result := randomHTTP1Path(t, rand, 1, 64)
|
||||
if rand.Intn(2) == 1 {
|
||||
result += randomHTTP1Query(rand, 1, 32)
|
||||
}
|
||||
if rand.Intn(2) == 1 {
|
||||
result += randomHTTP1Fragment(t, rand, 1, 16)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func stdlibHeaderToH2muxHeader(headers http.Header) (h2muxHeaders []h2mux.Header) {
|
||||
for name, values := range headers {
|
||||
for _, value := range values {
|
||||
h2muxHeaders = append(h2muxHeaders, h2mux.Header{Name: name, Value: value})
|
||||
}
|
||||
}
|
||||
|
||||
return h2muxHeaders
|
||||
}
|
||||
|
||||
func TestSerializeHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mockHeaders := http.Header{
|
||||
"Mock-Header-One": {"Mock header one value", "three"},
|
||||
"Mock-Header-Two-Long": {"Mock header two value\nlong"},
|
||||
":;": {":;", ";:"},
|
||||
":": {":"},
|
||||
";": {";"},
|
||||
";;": {";;"},
|
||||
"Empty values": {"", ""},
|
||||
"": {"Empty key"},
|
||||
"control\tcharacter\b\n": {"value\n\b\t"},
|
||||
";\v:": {":\v;"},
|
||||
}
|
||||
|
||||
for header, values := range mockHeaders {
|
||||
for _, value := range values {
|
||||
// Note that Golang's http library is opinionated;
|
||||
// at this point every header name will be title-cased in order to comply with the HTTP RFC
|
||||
// This means our proxy is not completely transparent when it comes to proxying headers
|
||||
request.Header.Add(header, value)
|
||||
}
|
||||
}
|
||||
|
||||
serializedHeaders := SerializeHeaders(request.Header)
|
||||
|
||||
// Sanity check: the headers serialized to something that's not an empty string
|
||||
assert.NotEqual(t, "", serializedHeaders)
|
||||
|
||||
// Deserialize back, and ensure we get the same set of headers
|
||||
deserializedHeaders, err := DeserializeHeaders(serializedHeaders)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 13, len(deserializedHeaders))
|
||||
h2muxExpectedHeaders := stdlibHeaderToH2muxHeader(mockHeaders)
|
||||
|
||||
sort.Sort(ByName(deserializedHeaders))
|
||||
sort.Sort(ByName(h2muxExpectedHeaders))
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
reflect.DeepEqual(h2muxExpectedHeaders, deserializedHeaders),
|
||||
fmt.Sprintf("got = %#v, want = %#v\n", deserializedHeaders, h2muxExpectedHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
func TestSerializeNoHeaders(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
|
||||
assert.NoError(t, err)
|
||||
|
||||
serializedHeaders := SerializeHeaders(request.Header)
|
||||
deserializedHeaders, err := DeserializeHeaders(serializedHeaders)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(deserializedHeaders))
|
||||
}
|
||||
|
||||
func TestDeserializeMalformed(t *testing.T) {
|
||||
var err error
|
||||
|
||||
malformedData := []string{
|
||||
"malformed data",
|
||||
"bW9jawo=", // "mock"
|
||||
"bW9jawo=:ZGF0YQo=:bW9jawo=", // "mock:data:mock"
|
||||
"::",
|
||||
}
|
||||
|
||||
for _, malformedValue := range malformedData {
|
||||
_, err = DeserializeHeaders(malformedValue)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseHeaders(t *testing.T) {
|
||||
mockUserHeadersToSerialize := http.Header{
|
||||
"Mock-Header-One": {"1", "1.5"},
|
||||
"Mock-Header-Two": {"2"},
|
||||
"Mock-Header-Three": {"3"},
|
||||
}
|
||||
|
||||
mockHeaders := []h2mux.Header{
|
||||
{Name: "One", Value: "1"}, // will be dropped
|
||||
{Name: "Cf-Two", Value: "cf-value-1"},
|
||||
{Name: "Cf-Two", Value: "cf-value-2"},
|
||||
{Name: RequestUserHeaders, Value: SerializeHeaders(mockUserHeadersToSerialize)},
|
||||
}
|
||||
|
||||
expectedHeaders := []h2mux.Header{
|
||||
{Name: "Cf-Two", Value: "cf-value-1"},
|
||||
{Name: "Cf-Two", Value: "cf-value-2"},
|
||||
{Name: "Mock-Header-One", Value: "1"},
|
||||
{Name: "Mock-Header-One", Value: "1.5"},
|
||||
{Name: "Mock-Header-Two", Value: "2"},
|
||||
{Name: "Mock-Header-Three", Value: "3"},
|
||||
}
|
||||
h1 := &http.Request{
|
||||
Header: make(http.Header),
|
||||
}
|
||||
err := H2RequestHeadersToH1Request(mockHeaders, h1)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, expectedHeaders, stdlibHeaderToH2muxHeader(h1.Header))
|
||||
}
|
||||
|
||||
func TestIsControlHeader(t *testing.T) {
|
||||
controlHeaders := []string{
|
||||
// Anything that begins with cf-
|
||||
"cf-sample-header",
|
||||
|
||||
// Any http2 pseudoheader
|
||||
":sample-pseudo-header",
|
||||
|
||||
// content-length is a special case, it has to be there
|
||||
// for some requests to work (per the HTTP2 spec)
|
||||
"content-length",
|
||||
}
|
||||
|
||||
for _, header := range controlHeaders {
|
||||
assert.True(t, IsControlHeader(header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNotControlHeader(t *testing.T) {
|
||||
notControlHeaders := []string{
|
||||
"mock-header",
|
||||
"another-sample-header",
|
||||
}
|
||||
|
||||
for _, header := range notControlHeaders {
|
||||
assert.False(t, IsControlHeader(header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestH1ResponseToH2ResponseHeaders(t *testing.T) {
|
||||
mockHeaders := http.Header{
|
||||
"User-header-one": {""},
|
||||
"User-header-two": {"1", "2"},
|
||||
"cf-header": {"cf-value"},
|
||||
"Content-Length": {"123"},
|
||||
}
|
||||
mockResponse := http.Response{
|
||||
StatusCode: 200,
|
||||
Header: mockHeaders,
|
||||
}
|
||||
|
||||
headers := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
|
||||
|
||||
serializedHeadersIndex := -1
|
||||
for i, header := range headers {
|
||||
if header.Name == ResponseUserHeaders {
|
||||
serializedHeadersIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.NotEqual(t, -1, serializedHeadersIndex)
|
||||
actualControlHeaders := append(
|
||||
headers[:serializedHeadersIndex],
|
||||
headers[serializedHeadersIndex+1:]...,
|
||||
)
|
||||
expectedControlHeaders := []h2mux.Header{
|
||||
{Name: ":status", Value: "200"},
|
||||
{Name: "content-length", Value: "123"},
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, expectedControlHeaders, actualControlHeaders)
|
||||
|
||||
actualUserHeaders, err := DeserializeHeaders(headers[serializedHeadersIndex].Value)
|
||||
expectedUserHeaders := []h2mux.Header{
|
||||
{Name: "User-header-one", Value: ""},
|
||||
{Name: "User-header-two", Value: "1"},
|
||||
{Name: "User-header-two", Value: "2"},
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, expectedUserHeaders, actualUserHeaders)
|
||||
}
|
||||
|
||||
// The purpose of this test is to check that our code and the http.Header
|
||||
// implementation don't throw validation errors about header size
|
||||
func TestHeaderSize(t *testing.T) {
|
||||
largeValue := randSeq(5 * 1024 * 1024) // 5Mb
|
||||
largeHeaders := http.Header{
|
||||
"User-header": {largeValue},
|
||||
}
|
||||
mockResponse := http.Response{
|
||||
StatusCode: 200,
|
||||
Header: largeHeaders,
|
||||
}
|
||||
|
||||
serializedHeaders := H1ResponseToH2ResponseHeaders(mockResponse.StatusCode, mockResponse.Header)
|
||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/", nil)
|
||||
assert.NoError(t, err)
|
||||
for _, header := range serializedHeaders {
|
||||
request.Header.Set(header.Name, header.Value)
|
||||
}
|
||||
|
||||
for _, header := range serializedHeaders {
|
||||
if header.Name != ResponseUserHeaders {
|
||||
continue
|
||||
}
|
||||
|
||||
deserializedHeaders, err := DeserializeHeaders(header.Value)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, largeValue, deserializedHeaders[0].Value)
|
||||
}
|
||||
}
|
||||
|
||||
func randSeq(n int) string {
|
||||
randomizer := rand.New(rand.NewSource(17))
|
||||
var letters = []rune(":;,+/=abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
|
||||
b := make([]rune, n)
|
||||
for i := range b {
|
||||
b[i] = letters[randomizer.Intn(len(letters))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func BenchmarkH1ResponseToH2ResponseHeaders(b *testing.B) {
|
||||
ser := "eC1mb3J3YXJkZWQtcHJvdG8:aHR0cHM;dXBncmFkZS1pbnNlY3VyZS1yZXF1ZXN0cw:MQ;YWNjZXB0LWxhbmd1YWdl:ZW4tVVMsZW47cT0wLjkscnU7cT0wLjg;YWNjZXB0LWVuY29kaW5n:Z3ppcA;eC1mb3J3YXJkZWQtZm9y:MTczLjI0NS42MC42;dXNlci1hZ2VudA:TW96aWxsYS81LjAgKE1hY2ludG9zaDsgSW50ZWwgTWFjIE9TIFggMTBfMTRfNikgQXBwbGVXZWJLaXQvNTM3LjM2IChLSFRNTCwgbGlrZSBHZWNrbykgQ2hyb21lLzg0LjAuNDE0Ny44OSBTYWZhcmkvNTM3LjM2;c2VjLWZldGNoLW1vZGU:bmF2aWdhdGU;Y2RuLWxvb3A:Y2xvdWRmbGFyZQ;c2VjLWZldGNoLWRlc3Q:ZG9jdW1lbnQ;c2VjLWZldGNoLXVzZXI:PzE;c2VjLWZldGNoLXNpdGU:bm9uZQ;Y29va2ll:X19jZmR1aWQ9ZGNkOWZjOGNjNWMxMzE0NTMyYTFkMjhlZDEyOWRhOTYwMTU2OTk1MTYzNDsgX19jZl9ibT1mYzY2MzMzYzAzZmM0MWFiZTZmOWEyYzI2ZDUwOTA0YzIxYzZhMTQ2LTE1OTU2MjIzNDEtMTgwMC1BZTVzS2pIU2NiWGVFM05mMUhrTlNQMG1tMHBLc2pQWkloVnM1Z2g1SkNHQkFhS1UxVDB2b003alBGN3FjMHVSR2NjZGcrWHdhL1EzbTJhQzdDVU4xZ2M9;YWNjZXB0:dGV4dC9odG1sLGFwcGxpY2F0aW9uL3hodG1sK3htbCxhcHBsaWNhdGlvbi94bWw7cT0wLjksaW1hZ2Uvd2VicCxpbWFnZS9hcG5nLCovKjtxPTAuOCxhcHBsaWNhdGlvbi9zaWduZWQtZXhjaGFuZ2U7dj1iMztxPTAuOQ"
|
||||
h2, _ := DeserializeHeaders(ser)
|
||||
h1 := make(http.Header)
|
||||
for _, header := range h2 {
|
||||
h1.Add(header.Name, header.Value)
|
||||
}
|
||||
h1.Add("Content-Length", "200")
|
||||
h1.Add("Cf-Something", "Else")
|
||||
h1.Add("Upgrade", "websocket")
|
||||
|
||||
h1resp := &http.Response{
|
||||
StatusCode: 200,
|
||||
Header: h1,
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = H1ResponseToH2ResponseHeaders(h1resp.StatusCode, h1resp.Header)
|
||||
}
|
||||
}
|
@@ -13,15 +13,15 @@ import (
|
||||
"github.com/rs/zerolog"
|
||||
"golang.org/x/net/http2"
|
||||
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
|
||||
)
|
||||
|
||||
// note: these constants are exported so we can reuse them in the edge-side code
|
||||
const (
|
||||
internalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
||||
tcpStreamHeader = "Cf-Cloudflared-Proxy-Src"
|
||||
websocketUpgrade = "websocket"
|
||||
controlStreamUpgrade = "control-stream"
|
||||
InternalUpgradeHeader = "Cf-Cloudflared-Proxy-Connection-Upgrade"
|
||||
InternalTCPProxySrcHeader = "Cf-Cloudflared-Proxy-Src"
|
||||
WebsocketUpgrade = "websocket"
|
||||
ControlStreamUpgrade = "control-stream"
|
||||
)
|
||||
|
||||
var errEdgeConnectionClosed = fmt.Errorf("connection with edge closed")
|
||||
@@ -178,25 +178,23 @@ func newHTTP2RespWriter(r *http.Request, w http.ResponseWriter, connType Type) (
|
||||
func (rp *http2RespWriter) WriteRespHeaders(status int, header http.Header) error {
|
||||
dest := rp.w.Header()
|
||||
userHeaders := make(http.Header, len(header))
|
||||
for header, values := range header {
|
||||
for name, values := range header {
|
||||
// Since these are http2 headers, they're required to be lowercase
|
||||
h2name := strings.ToLower(header)
|
||||
for _, v := range values {
|
||||
if h2name == "content-length" {
|
||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||
// so it should be sent as an HTTP/2 response header.
|
||||
dest.Add(h2name, v)
|
||||
// Since these are http2 headers, they're required to be lowercase
|
||||
} else if !h2mux.IsControlHeader(h2name) || h2mux.IsWebsocketClientHeader(h2name) {
|
||||
// User headers, on the other hand, must all be serialized so that
|
||||
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||
userHeaders.Add(h2name, v)
|
||||
}
|
||||
h2name := strings.ToLower(name)
|
||||
if h2name == "content-length" {
|
||||
// This header has meaning in HTTP/2 and will be used by the edge,
|
||||
// so it should be sent as an HTTP/2 response header.
|
||||
dest[name] = values
|
||||
// Since these are http2 headers, they're required to be lowercase
|
||||
} else if !IsControlHeader(h2name) || IsWebsocketClientHeader(h2name) {
|
||||
// User headers, on the other hand, must all be serialized so that
|
||||
// HTTP/2 header validation won't be applied to HTTP/1 header values
|
||||
userHeaders[name] = values
|
||||
}
|
||||
}
|
||||
|
||||
// Perform user header serialization and set them in the single header
|
||||
dest.Set(canonicalResponseUserHeadersField, h2mux.SerializeHeaders(userHeaders))
|
||||
dest.Set(CanonicalResponseUserHeaders, SerializeHeaders(userHeaders))
|
||||
rp.setResponseMetaHeader(responseMetaHeaderOrigin)
|
||||
// HTTP2 removes support for 101 Switching Protocols https://tools.ietf.org/html/rfc7540#section-8.1.1
|
||||
if status == http.StatusSwitchingProtocols {
|
||||
@@ -218,7 +216,7 @@ func (rp *http2RespWriter) WriteErrorResponse() {
|
||||
}
|
||||
|
||||
func (rp *http2RespWriter) setResponseMetaHeader(value string) {
|
||||
rp.w.Header().Set(canonicalResponseMetaHeaderField, value)
|
||||
rp.w.Header().Set(CanonicalResponseMetaHeader, value)
|
||||
}
|
||||
|
||||
func (rp *http2RespWriter) Read(p []byte) (n int, err error) {
|
||||
@@ -258,18 +256,18 @@ func determineHTTP2Type(r *http.Request) Type {
|
||||
}
|
||||
|
||||
func isControlStreamUpgrade(r *http.Request) bool {
|
||||
return r.Header.Get(internalUpgradeHeader) == controlStreamUpgrade
|
||||
return r.Header.Get(InternalUpgradeHeader) == ControlStreamUpgrade
|
||||
}
|
||||
|
||||
func isWebsocketUpgrade(r *http.Request) bool {
|
||||
return r.Header.Get(internalUpgradeHeader) == websocketUpgrade
|
||||
return r.Header.Get(InternalUpgradeHeader) == WebsocketUpgrade
|
||||
}
|
||||
|
||||
// IsTCPStream discerns if the connection request needs a tcp stream proxy.
|
||||
func IsTCPStream(r *http.Request) bool {
|
||||
return r.Header.Get(tcpStreamHeader) != ""
|
||||
return r.Header.Get(InternalTCPProxySrcHeader) != ""
|
||||
}
|
||||
|
||||
func stripWebsocketUpgradeHeader(r *http.Request) {
|
||||
r.Header.Del(internalUpgradeHeader)
|
||||
r.Header.Del(InternalUpgradeHeader)
|
||||
}
|
||||
|
@@ -103,9 +103,9 @@ func TestServeHTTP(t *testing.T) {
|
||||
require.Equal(t, test.expectedBody, respBody)
|
||||
}
|
||||
if test.isProxyError {
|
||||
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeaderField))
|
||||
require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
|
||||
} else {
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField))
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
@@ -191,7 +191,7 @@ func TestServeWS(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws", readPipe)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(internalUpgradeHeader, websocketUpgrade)
|
||||
req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -211,7 +211,7 @@ func TestServeWS(t *testing.T) {
|
||||
resp := respWriter.Result()
|
||||
// http2RespWriter should rewrite status 101 to 200
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeaderField))
|
||||
require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -235,7 +235,7 @@ func TestServeControlStream(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
||||
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
@@ -274,7 +274,7 @@ func TestFailRegistration(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
||||
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
@@ -310,7 +310,7 @@ func TestGracefulShutdownHTTP2(t *testing.T) {
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(internalUpgradeHeader, controlStreamUpgrade)
|
||||
req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
|
||||
|
||||
edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
|
||||
require.NoError(t, err)
|
||||
|
Reference in New Issue
Block a user