TUN-1977: Validate OriginConfig has valid URL, and use scheme to determine if a HTTPOriginService is expecting HTTP or Unix

This commit is contained in:
Chung-Ting Huang
2019-06-20 11:18:59 -05:00
parent 4090049fff
commit 4858ce79d0
10 changed files with 375 additions and 752 deletions

View File

@@ -134,49 +134,18 @@ type OriginConfig interface {
}
type HTTPOriginConfig struct {
URL OriginAddr `capnp:"url"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
DialDualStack bool
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
MaxIdleConnections uint64
IdleConnectionTimeout time.Duration
ProxyConnectTimeout time.Duration
ExpectContinueTimeout time.Duration
ChunkedEncoding bool
}
type OriginAddr interface {
Addr() string
}
type HTTPURL struct {
URL *url.URL
}
func (ha *HTTPURL) Addr() string {
return ha.URL.String()
}
func (ha *HTTPURL) capnpHTTPURL() *CapnpHTTPURL {
return &CapnpHTTPURL{
URL: ha.URL.String(),
}
}
// URL for a HTTP origin, capnp doesn't have native support for URL, so represent it as string
type CapnpHTTPURL struct {
URL string `capnp:"url"`
}
type UnixPath struct {
Path string
}
func (up *UnixPath) Addr() string {
return up.Path
URLString string `capnp:"urlString"`
TCPKeepAlive time.Duration `capnp:"tcpKeepAlive"`
DialDualStack bool
TLSHandshakeTimeout time.Duration `capnp:"tlsHandshakeTimeout"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
MaxIdleConnections uint64
IdleConnectionTimeout time.Duration
ProxyConnectionTimeout time.Duration
ExpectContinueTimeout time.Duration
ChunkedEncoding bool
}
func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
@@ -184,8 +153,9 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
if err != nil {
return nil, err
}
dialContext := (&net.Dialer{
Timeout: hc.ProxyConnectTimeout,
Timeout: hc.ProxyConnectionTimeout,
KeepAlive: hc.TCPKeepAlive,
DualStack: hc.DialDualStack,
}).DialContext
@@ -202,18 +172,22 @@ func (hc *HTTPOriginConfig) Service() (originservice.OriginService, error) {
IdleConnTimeout: hc.IdleConnectionTimeout,
ExpectContinueTimeout: hc.ExpectContinueTimeout,
}
if unixPath, ok := hc.URL.(*UnixPath); ok {
url, err := url.Parse(hc.URLString)
if err != nil {
return nil, errors.Wrapf(err, "%s is not a valid URL", hc.URLString)
}
if url.Scheme == "unix" {
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialContext(ctx, "unix", unixPath.Addr())
return dialContext(ctx, "unix", url.Host)
}
}
return originservice.NewHTTPService(transport, hc.URL.Addr(), hc.ChunkedEncoding), nil
return originservice.NewHTTPService(transport, url, hc.ChunkedEncoding), nil
}
func (_ *HTTPOriginConfig) isOriginConfig() {}
type WebSocketOriginConfig struct {
URL string `capnp:"url"`
URLString string `capnp:"urlString"`
TLSVerify bool `capnp:"tlsVerify"`
OriginCAPool string
OriginServerName string
@@ -229,7 +203,12 @@ func (wsc *WebSocketOriginConfig) Service() (originservice.OriginService, error)
ServerName: wsc.OriginServerName,
InsecureSkipVerify: wsc.TLSVerify,
}
return originservice.NewWebSocketService(tlsConfig, wsc.URL)
url, err := url.Parse(wsc.URLString)
if err != nil {
return nil, errors.Wrapf(err, "%s is not a valid URL", wsc.URLString)
}
return originservice.NewWebSocketService(tlsConfig, url)
}
func (_ *WebSocketOriginConfig) isOriginConfig() {}
@@ -550,115 +529,12 @@ func UnmarshalReverseProxyConfig(s tunnelrpc.ReverseProxyConfig) (*ReverseProxyC
}
func MarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig, p *HTTPOriginConfig) error {
switch originAddr := p.URL.(type) {
case *HTTPURL:
ss, err := s.OriginAddr().NewHttp()
if err != nil {
return err
}
if err := MarshalHTTPURL(ss, originAddr); err != nil {
return err
}
case *UnixPath:
ss, err := s.OriginAddr().NewUnix()
if err != nil {
return err
}
if err := MarshalUnixPath(ss, originAddr); err != nil {
return err
}
default:
return fmt.Errorf("Unknown type for OriginAddr: %T", originAddr)
}
s.SetTcpKeepAlive(p.TCPKeepAlive.Nanoseconds())
s.SetDialDualStack(p.DialDualStack)
s.SetTlsHandshakeTimeout(p.TLSHandshakeTimeout.Nanoseconds())
s.SetTlsVerify(p.TLSVerify)
s.SetOriginCAPool(p.OriginCAPool)
s.SetOriginServerName(p.OriginServerName)
s.SetMaxIdleConnections(p.MaxIdleConnections)
s.SetIdleConnectionTimeout(p.IdleConnectionTimeout.Nanoseconds())
s.SetProxyConnectionTimeout(p.ProxyConnectTimeout.Nanoseconds())
s.SetExpectContinueTimeout(p.ExpectContinueTimeout.Nanoseconds())
s.SetChunkedEncoding(p.ChunkedEncoding)
return nil
return pogs.Insert(tunnelrpc.HTTPOriginConfig_TypeID, s.Struct, p)
}
func UnmarshalHTTPOriginConfig(s tunnelrpc.HTTPOriginConfig) (*HTTPOriginConfig, error) {
p := new(HTTPOriginConfig)
switch s.OriginAddr().Which() {
case tunnelrpc.HTTPOriginConfig_originAddr_Which_http:
ss, err := s.OriginAddr().Http()
if err != nil {
return nil, err
}
originAddr, err := UnmarshalCapnpHTTPURL(ss)
if err != nil {
return nil, err
}
p.URL = originAddr
case tunnelrpc.HTTPOriginConfig_originAddr_Which_unix:
ss, err := s.OriginAddr().Unix()
if err != nil {
return nil, err
}
originAddr, err := UnmarshalUnixPath(ss)
if err != nil {
return nil, err
}
p.URL = originAddr
default:
return nil, fmt.Errorf("Unknown type for OriginAddr: %T", s.OriginAddr().Which())
}
p.TCPKeepAlive = time.Duration(s.TcpKeepAlive())
p.DialDualStack = s.DialDualStack()
p.TLSHandshakeTimeout = time.Duration(s.TlsHandshakeTimeout())
p.TLSVerify = s.TlsVerify()
originCAPool, err := s.OriginCAPool()
if err != nil {
return nil, err
}
p.OriginCAPool = originCAPool
originServerName, err := s.OriginServerName()
if err != nil {
return nil, err
}
p.OriginServerName = originServerName
p.MaxIdleConnections = s.MaxIdleConnections()
p.IdleConnectionTimeout = time.Duration(s.IdleConnectionTimeout())
p.ProxyConnectTimeout = time.Duration(s.ProxyConnectionTimeout())
p.ExpectContinueTimeout = time.Duration(s.ExpectContinueTimeout())
p.ChunkedEncoding = s.ChunkedEncoding()
return p, nil
}
func MarshalHTTPURL(s tunnelrpc.CapnpHTTPURL, p *HTTPURL) error {
return pogs.Insert(tunnelrpc.CapnpHTTPURL_TypeID, s.Struct, p.capnpHTTPURL())
}
func UnmarshalCapnpHTTPURL(s tunnelrpc.CapnpHTTPURL) (*HTTPURL, error) {
p := new(CapnpHTTPURL)
err := pogs.Extract(p, tunnelrpc.CapnpHTTPURL_TypeID, s.Struct)
if err != nil {
return nil, err
}
url, err := url.Parse(p.URL)
if err != nil {
return nil, err
}
return &HTTPURL{
URL: url,
}, nil
}
func MarshalUnixPath(s tunnelrpc.UnixPath, p *UnixPath) error {
err := pogs.Insert(tunnelrpc.UnixPath_TypeID, s.Struct, p)
return err
}
func UnmarshalUnixPath(s tunnelrpc.UnixPath) (*UnixPath, error) {
p := new(UnixPath)
err := pogs.Extract(p, tunnelrpc.UnixPath_TypeID, s.Struct)
err := pogs.Extract(p, tunnelrpc.HTTPOriginConfig_TypeID, s.Struct)
return p, err
}