TUN-3506: OriginService needs to set request host and scheme for websocket requests

This commit is contained in:
cthuang
2020-11-05 13:52:46 +00:00
parent be9a558867
commit 61c814bd79
3 changed files with 39 additions and 17 deletions

View File

@@ -55,9 +55,14 @@ func (o *unixSocketPath) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *unixSocketPath) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers)
func (o *unixSocketPath) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
NetDial: o.transport.Dial,
NetDialContext: o.transport.DialContext,
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Scheme = websocket.ChangeRequestScheme(reqURL)
return d.Dial(reqURL.String(), headers)
}
// localService is an OriginService listening on a TCP/IP address the user's origin can route to.
@@ -71,9 +76,12 @@ type localService struct {
transport *http.Transport
}
func (o *localService) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
func (o *localService) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers)
// Rewrite the request URL so that it goes to the origin service.
reqURL.Host = o.URL.Host
reqURL.Scheme = websocket.ChangeRequestScheme(o.URL)
return d.Dial(reqURL.String(), headers)
}
func (o *localService) address() string {
@@ -215,9 +223,13 @@ func (o *helloWorld) RoundTrip(req *http.Request) (*http.Response, error) {
return o.transport.RoundTrip(req)
}
func (o *helloWorld) Dial(url string, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{TLSClientConfig: o.transport.TLSClientConfig}
return d.Dial(url, headers)
func (o *helloWorld) Dial(reqURL *url.URL, headers http.Header) (*gws.Conn, *http.Response, error) {
d := &gws.Dialer{
TLSClientConfig: o.transport.TLSClientConfig,
}
reqURL.Host = o.server.Addr().String()
reqURL.Scheme = "wss"
return d.Dial(reqURL.String(), headers)
}
func originRequiresProxy(staticHost string, cfg OriginRequestConfig) bool {