TUN-1626: Create new supervisor to establish connection with origintunneld

Author: Chung-Ting Huang
Date:   2019-03-18 18:14:47 -05:00
parent 980bee22a5
commit c18702f297
8 changed files with 332 additions and 6 deletions

connection/connection.go Normal file

@@ -0,0 +1,159 @@
package connection
import (
"context"
"crypto/tls"
"net"
"sync"
"time"
"github.com/cloudflare/cloudflared/h2mux"
"github.com/cloudflare/cloudflared/tunnelrpc"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
rpc "zombiezen.com/go/capnproto2/rpc"
)
const (
dialTimeout = 5 * time.Second
)
type dialError struct {
cause error
}
func (e dialError) Error() string {
return e.cause.Error()
}
type muxerShutdownError struct{}
func (e muxerShutdownError) Error() string {
return "muxer shutdown"
}
type ConnectionConfig struct {
TLSConfig *tls.Config
HeartbeatInterval time.Duration
MaxHeartbeats uint64
Logger *logrus.Entry
}
type connectionHandler interface {
serve(ctx context.Context) error
connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error)
shutdown()
}
type h2muxHandler struct {
muxer *h2mux.Muxer
logger *logrus.Entry
}
type muxedStreamHandler struct{}
// ServeStream implements the h2mux.MuxedStreamHandler interface
func (h *muxedStreamHandler) ServeStream(stream *h2mux.MuxedStream) error {
return nil
}
func (h *h2muxHandler) serve(ctx context.Context) error {
// Serve doesn't return until h2mux is shut down
if err := h.muxer.Serve(ctx); err != nil {
return err
}
return muxerShutdownError{}
}
// connect establishes a connection with Cloudflare's edge network
func (h *h2muxHandler) connect(ctx context.Context, parameters *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) {
conn, err := h.newRPConn()
if err != nil {
return nil, errors.Wrap(err, "Failed to create new RPC connection")
}
defer conn.Close()
tsClient := tunnelpogs.TunnelServer_PogsClient{Client: conn.Bootstrap(ctx)}
return tsClient.Connect(ctx, parameters)
}
func (h *h2muxHandler) shutdown() {
h.muxer.Shutdown()
}
func (h *h2muxHandler) newRPConn() (*rpc.Conn, error) {
stream, err := h.muxer.OpenStream([]h2mux.Header{
{Name: ":method", Value: "RPC"},
{Name: ":scheme", Value: "capnp"},
{Name: ":path", Value: "*"},
}, nil)
if err != nil {
return nil, err
}
return rpc.NewConn(
tunnelrpc.NewTransportLogger(h.logger.WithField("subsystem", "rpc-register"), rpc.StreamTransport(stream)),
tunnelrpc.ConnLog(h.logger.WithField("subsystem", "rpc-transport")),
), nil
}
// newH2MuxHandler returns a connectionHandler that wraps h2mux for making RPC calls
func newH2MuxHandler(ctx context.Context,
config *ConnectionConfig,
edgeIP *net.TCPAddr,
) (connectionHandler, error) {
// Inherit from parent context so we can cancel (Ctrl-C) while dialing
dialCtx, dialCancel := context.WithTimeout(ctx, dialTimeout)
defer dialCancel()
dialer := net.Dialer{DualStack: true}
plaintextEdgeConn, err := dialer.DialContext(dialCtx, "tcp", edgeIP.String())
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "DialContext error")}
}
edgeConn := tls.Client(plaintextEdgeConn, config.TLSConfig)
edgeConn.SetDeadline(time.Now().Add(dialTimeout))
err = edgeConn.Handshake()
if err != nil {
return nil, dialError{cause: errors.Wrap(err, "Handshake with edge error")}
}
// clear the deadline on the conn; h2mux has its own timeouts
edgeConn.SetDeadline(time.Time{})
// Establish a muxed connection with the edge (cloudflared is the mux client)
muxer, err := h2mux.Handshake(edgeConn, edgeConn, h2mux.MuxerConfig{
Timeout: dialTimeout,
Handler: &muxedStreamHandler{},
IsClient: true,
HeartbeatInterval: config.HeartbeatInterval,
MaxHeartbeats: config.MaxHeartbeats,
Logger: config.Logger,
})
if err != nil {
return nil, err
}
return &h2muxHandler{
muxer: muxer,
logger: config.Logger,
}, nil
}
// connectionPool is a pool of connection handlers
type connectionPool struct {
sync.Mutex
connectionHandlers []connectionHandler
}
func (cp *connectionPool) put(h connectionHandler) {
cp.Lock()
defer cp.Unlock()
cp.connectionHandlers = append(cp.connectionHandlers, h)
}
func (cp *connectionPool) close() {
cp.Lock()
defer cp.Unlock()
for _, h := range cp.connectionHandlers {
h.shutdown()
}
}
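
As a usage sketch (not part of this commit; connectOnce, cfg, addr, and params are hypothetical names), a single edge connection could be established and used for the Connect RPC like this:

func connectOnce(ctx context.Context, cfg *ConnectionConfig, addr *net.TCPAddr, params *tunnelpogs.ConnectParameters) (*tunnelpogs.ConnectResult, error) {
    h, err := newH2MuxHandler(ctx, cfg, addr)
    if err != nil {
        return nil, err
    }
    // serve blocks until the muxer shuts down, so run it in the background
    errCh := make(chan error, 1)
    go func() { errCh <- h.serve(ctx) }()
    result, err := h.connect(ctx, params)
    if err != nil {
        h.shutdown()
        return nil, err
    }
    return result, nil
}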

connection/discovery.go Normal file

@@ -0,0 +1,143 @@
package connection
import (
"context"
"crypto/tls"
"fmt"
"net"
"time"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
)
const (
// Used to discover HA Warp servers
srvService = "warp"
srvProto = "tcp"
srvName = "cloudflarewarp.com"
// Used to fall back to DoT when we can't use the default resolver to
// discover HA Warp servers (GitHub issue #75).
dotServerName = "cloudflare-dns.com"
dotServerAddr = "1.1.1.1:853"
dotTimeout = 15 * time.Second
)
var friendlyDNSErrorLines = []string{
`Please try the following things to diagnose this issue:`,
` 1. ensure that cloudflarewarp.com is returning "warp" service records.`,
` Run your system's equivalent of: dig srv _warp._tcp.cloudflarewarp.com`,
` 2. ensure that your DNS resolver is not returning compressed SRV records.`,
` See GitHub issue https://github.com/golang/go/issues/27546`,
` For example, you could use Cloudflare's 1.1.1.1 as your resolver:`,
` https://developers.cloudflare.com/1.1.1.1/setting-up-1.1.1.1/`,
}
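// ResolveEdgeIPs returns TCP addresses of Cloudflare's edge network, either
// parsed from explicitly supplied addresses or discovered via SRV records.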
func ResolveEdgeIPs(logger *log.Logger, addresses []string) ([]*net.TCPAddr, error) {
if len(addresses) > 0 {
// Addresses explicitly specified (usually for testing)
var tcpAddrs []*net.TCPAddr
for _, address := range addresses {
tcpAddr, err := net.ResolveTCPAddr("tcp", address)
if err != nil {
return nil, err
}
tcpAddrs = append(tcpAddrs, tcpAddr)
}
return tcpAddrs, nil
}
// HA service discovery lookup
_, addrs, err := net.LookupSRV(srvService, srvProto, srvName)
if err != nil {
// Try to fall back to DoT from Cloudflare directly.
//
// Note: Instead of DoT, we could also have used DoH. Either of these:
// - directly via the JSON API (https://1.1.1.1/dns-query?ct=application/dns-json&name=_warp._tcp.cloudflarewarp.com&type=srv)
// - indirectly via `tunneldns.NewUpstreamHTTPS()`
// But both of these cases miss out on a key feature from the stdlib:
// "The returned records are sorted by priority and randomized by weight within a priority."
// (https://golang.org/pkg/net/#Resolver.LookupSRV)
// Does this matter? I don't know. It may someday. Let's use DoT so we don't need to worry about it.
// See also: Go feature request for stdlib-supported DoH: https://github.com/golang/go/issues/27552
r := fallbackResolver(dotServerName, dotServerAddr)
ctx, cancel := context.WithTimeout(context.Background(), dotTimeout)
defer cancel()
_, fallbackAddrs, fallbackErr := r.LookupSRV(ctx, srvService, srvProto, srvName)
if fallbackErr != nil || len(fallbackAddrs) == 0 {
// use the original DNS error `err` in messages, not `fallbackErr`
logger.Errorln("Error looking up Cloudflare edge IPs: the DNS query failed:", err)
for _, s := range friendlyDNSErrorLines {
logger.Errorln(s)
}
return nil, errors.Wrap(err, "Could not look up SRV records for _warp._tcp.cloudflarewarp.com")
}
// Accept the fallback results and keep going
addrs = fallbackAddrs
}
var resolvedIPsPerCNAME [][]*net.TCPAddr
var lookupErr error
for _, addr := range addrs {
ips, err := ResolveSRVToTCP(addr)
if err != nil || len(ips) == 0 {
// don't return early, we might be able to resolve other addresses
lookupErr = err
continue
}
resolvedIPsPerCNAME = append(resolvedIPsPerCNAME, ips)
}
ips := FlattenServiceIPs(resolvedIPsPerCNAME)
if lookupErr == nil && len(ips) == 0 {
return nil, fmt.Errorf("Unknown service discovery error")
}
return ips, lookupErr
}
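// ResolveSRVToTCP looks up the IPs behind an SRV record's target and pairs
// each with the record's port.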
func ResolveSRVToTCP(srv *net.SRV) ([]*net.TCPAddr, error) {
ips, err := net.LookupIP(srv.Target)
if err != nil {
return nil, err
}
addrs := make([]*net.TCPAddr, len(ips))
for i, ip := range ips {
addrs[i] = &net.TCPAddr{IP: ip, Port: int(srv.Port)}
}
return addrs, nil
}
// FlattenServiceIPs transposes and flattens the input slices such that the
// first element of the n inner slices are the first n elements of the result.
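// For example, [[1, 2], [10, 12], [21, 22, 23]] flattens to
// [1, 10, 21, 2, 12, 22, 23]: one address per service per round.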
func FlattenServiceIPs(ipsByService [][]*net.TCPAddr) []*net.TCPAddr {
var result []*net.TCPAddr
for len(ipsByService) > 0 {
filtered := ipsByService[:0]
for _, ips := range ipsByService {
if len(ips) == 0 {
// sanity check
continue
}
result = append(result, ips[0])
if len(ips) > 1 {
filtered = append(filtered, ips[1:])
}
}
ipsByService = filtered
}
return result
}
// Inspiration: https://github.com/artyom/dot/blob/master/dot.go
func fallbackResolver(serverName, serverAddress string) *net.Resolver {
return &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _ string, _ string) (net.Conn, error) {
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", serverAddress)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{ServerName: serverName}
return tls.Client(conn, tlsConfig), nil
},
}
}
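
A usage sketch (printEdgeIPs is a hypothetical helper, not part of this commit):

func printEdgeIPs(logger *log.Logger) error {
    // Passing nil addresses triggers SRV-based discovery with the DoT fallback above.
    addrs, err := ResolveEdgeIPs(logger, nil)
    if err != nil {
        return err
    }
    for _, addr := range addrs {
        fmt.Println(addr) // each entry pairs an edge IP with its SRV port
    }
    return nil
}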

connection/discovery_test.go Normal file

@@ -0,0 +1,45 @@
package connection
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFlattenServiceIPs(t *testing.T) {
result := FlattenServiceIPs([][]*net.TCPAddr{
    {{Port: 1}, {Port: 2}, {Port: 3}, {Port: 4}},
    {{Port: 10}, {Port: 12}, {Port: 13}},
    {{Port: 21}, {Port: 22}, {Port: 23}, {Port: 24}, {Port: 25}},
})
assert.EqualValues(t, []*net.TCPAddr{
    {Port: 1}, {Port: 10}, {Port: 21},
    {Port: 2}, {Port: 12}, {Port: 22},
    {Port: 3}, {Port: 13}, {Port: 23},
    {Port: 4}, {Port: 24}, {Port: 25},
}, result)
}

connection/supervisor.go Normal file

@@ -0,0 +1,145 @@
package connection
import (
"context"
"net"
"time"
tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
"github.com/google/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
const (
// Waiting time before retrying a failed tunnel connection
reconnectDuration = time.Second * 10
// SRV record resolution TTL
resolveTTL = time.Hour
// Interval between establishing new connections
connectionInterval = time.Second
)
type CloudflaredConfig struct {
ConnectionConfig *ConnectionConfig
OriginCert []byte
Tags []tunnelpogs.Tag
EdgeAddrs []string
HAConnections uint
Logger *logrus.Logger
}
// Supervisor is a stateful object that manages connections with the edge
type Supervisor struct {
config *CloudflaredConfig
state *supervisorState
connErrors chan error
}
type supervisorState struct {
// IPs to connect to cloudflare's edge network
edgeIPs []*net.TCPAddr
// index of the next element to use in edgeIPs
nextEdgeIPIndex int
// last time edgeIPs were refreshed
lastResolveTime time.Time
// ID of this cloudflared instance
cloudflaredID uuid.UUID
// connectionPool is a pool of connectionHandlers that can be used to make RPCs
connectionPool *connectionPool
}
func (s *supervisorState) getNextEdgeIP() *net.TCPAddr {
ip := s.edgeIPs[s.nextEdgeIPIndex%len(s.edgeIPs)]
s.nextEdgeIPIndex++
return ip
}
func NewSupervisor(config *CloudflaredConfig) *Supervisor {
return &Supervisor{
config: config,
state: &supervisorState{
connectionPool: &connectionPool{},
},
connErrors: make(chan error),
}
}
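// Run maintains the desired number of HA connections to the edge, creating
// replacements for connections that drop, until ctx is canceled or the edge
// instructs cloudflared not to retry.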
func (s *Supervisor) Run(ctx context.Context) error {
logger := s.config.Logger
if err := s.initialize(); err != nil {
logger.WithError(err).Error("Failed to get edge IPs")
return err
}
defer s.state.connectionPool.close()
var currentConnectionCount uint
expectedConnectionCount := s.config.HAConnections
if uint(len(s.state.edgeIPs)) < s.config.HAConnections {
logger.Warnf("You requested %d HA connections but I can give you at most %d.", s.config.HAConnections, len(s.state.edgeIPs))
expectedConnectionCount = uint(len(s.state.edgeIPs))
}
for {
select {
case <-ctx.Done():
return nil
case connErr := <-s.connErrors:
logger.WithError(connErr).Warn("Connection dropped unexpectedly")
currentConnectionCount--
default:
time.Sleep(5 * time.Second)
}
if currentConnectionCount < expectedConnectionCount {
h, err := newH2MuxHandler(ctx, s.config.ConnectionConfig, s.state.getNextEdgeIP())
if err != nil {
logger.WithError(err).Error("Failed to create new connection handler")
continue
}
go func() {
s.connErrors <- h.serve(ctx)
}()
connResult, err := s.connect(ctx, s.config, s.state.cloudflaredID, h)
if err != nil {
logger.WithError(err).Error("Failed to connect to Cloudflare's edge network")
h.shutdown()
continue
}
if connErr := connResult.Err; connErr != nil && !connErr.ShouldRetry {
logger.WithError(connErr).Error("Server instructed cloudflared not to retry this connection")
h.shutdown()
return connErr
}
logger.Infof("Connected to %s", connResult.ServerInfo.LocationName)
s.state.connectionPool.put(h)
currentConnectionCount++
}
}
}
func (s *Supervisor) initialize() error {
edgeIPs, err := ResolveEdgeIPs(s.config.Logger, s.config.EdgeAddrs)
if err != nil {
return errors.Wrap(err, "Failed to resolve Cloudflare edge network addresses")
}
s.state.edgeIPs = edgeIPs
s.state.lastResolveTime = time.Now()
cloudflaredID, err := uuid.NewRandom()
if err != nil {
return errors.Wrap(err, "Failed to generate cloudflared ID")
}
s.state.cloudflaredID = cloudflaredID
return nil
}
func (s *Supervisor) connect(ctx context.Context,
config *CloudflaredConfig,
cloudflaredID uuid.UUID,
h connectionHandler,
) (*tunnelpogs.ConnectResult, error) {
connectParameters := &tunnelpogs.ConnectParameters{
OriginCert: config.OriginCert,
CloudflaredID: cloudflaredID,
NumPreviousAttempts: 0,
}
return h.connect(ctx, connectParameters)
}
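
A usage sketch for wiring the supervisor into a caller (runSupervisor is illustrative, not part of this commit):

func runSupervisor(ctx context.Context, config *CloudflaredConfig) error {
    supervisor := NewSupervisor(config)
    // Run blocks until ctx is canceled or the edge refuses further retries.
    return supervisor.Run(ctx)
}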