mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:19:57 +00:00
TUN-3427: Define a struct that only implements RegistrationServer in tunnelpogs
This commit is contained in:
@@ -7,7 +7,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/connection"
|
||||
"github.com/cloudflare/cloudflared/h2mux"
|
||||
"github.com/cloudflare/cloudflared/logger"
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
@@ -164,18 +163,17 @@ func ReconnectTunnel(
|
||||
}
|
||||
|
||||
config.TransportLogger.Debug("initiating RPC stream to reconnect")
|
||||
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
|
||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, reconnect)
|
||||
return err
|
||||
}
|
||||
defer tunnelServer.Close()
|
||||
defer rpcClient.Close()
|
||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
return nil
|
||||
})
|
||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
|
||||
registration := tunnelServer.ReconnectTunnel(
|
||||
registration := rpcClient.ReconnectTunnel(
|
||||
ctx,
|
||||
token,
|
||||
eventDigest,
|
||||
|
@@ -323,16 +323,16 @@ func (s *Supervisor) authenticate(ctx context.Context, numPreviousAttempts int)
|
||||
<-muxer.Shutdown()
|
||||
}()
|
||||
|
||||
tunnelServer, err := connection.NewRPCClient(ctx, muxer, s.logger, openStreamTimeout)
|
||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, s.config, authenticate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer tunnelServer.Close()
|
||||
defer rpcClient.Close()
|
||||
|
||||
const arbitraryConnectionID = uint8(0)
|
||||
registrationOptions := s.config.RegistrationOptions(arbitraryConnectionID, edgeConn.LocalAddr().String(), s.cloudflaredUUID)
|
||||
registrationOptions.NumPreviousAttempts = uint8(numPreviousAttempts)
|
||||
authResponse, err := tunnelServer.Authenticate(
|
||||
authResponse, err := rpcClient.Authenticate(
|
||||
ctx,
|
||||
s.config.OriginCert,
|
||||
s.config.Hostname,
|
||||
|
@@ -44,11 +44,13 @@ const (
|
||||
FeatureQuickReconnects = "quick_reconnects"
|
||||
)
|
||||
|
||||
type registerRPCName string
|
||||
type rpcName string
|
||||
|
||||
const (
|
||||
register registerRPCName = "register"
|
||||
reconnect registerRPCName = "reconnect"
|
||||
register rpcName = "register"
|
||||
reconnect rpcName = "reconnect"
|
||||
unregister rpcName = "unregister"
|
||||
authenticate rpcName = " authenticate"
|
||||
)
|
||||
|
||||
type TunnelConfig struct {
|
||||
@@ -121,7 +123,7 @@ type clientRegisterTunnelError struct {
|
||||
cause error
|
||||
}
|
||||
|
||||
func newClientRegisterTunnelError(cause error, counter *prometheus.CounterVec, name registerRPCName) clientRegisterTunnelError {
|
||||
func newRPCError(cause error, counter *prometheus.CounterVec, name rpcName) clientRegisterTunnelError {
|
||||
counter.WithLabelValues(cause.Error(), string(name)).Inc()
|
||||
return clientRegisterTunnelError{cause: cause}
|
||||
}
|
||||
@@ -337,7 +339,7 @@ func ServeTunnel(
|
||||
if config.NamedTunnel != nil {
|
||||
_ = UnregisterConnection(ctx, handler.muxer, config)
|
||||
} else {
|
||||
_ = UnregisterTunnel(handler.muxer, config.GracePeriod, config.TransportLogger)
|
||||
_ = UnregisterTunnel(handler.muxer, config)
|
||||
}
|
||||
}
|
||||
handler.muxer.Shutdown()
|
||||
@@ -417,14 +419,13 @@ func RegisterConnection(
|
||||
const registerConnection = "registerConnection"
|
||||
|
||||
config.TransportLogger.Debug("initiating RPC stream for RegisterConnection")
|
||||
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
|
||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, registerConnection)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, registerConnection)
|
||||
return err
|
||||
}
|
||||
defer rpc.Close()
|
||||
defer rpcClient.Close()
|
||||
|
||||
conn, err := rpc.RegisterConnection(
|
||||
conn, err := rpcClient.RegisterConnection(
|
||||
ctx,
|
||||
config.NamedTunnel.Auth,
|
||||
config.NamedTunnel.ID,
|
||||
@@ -470,14 +471,14 @@ func UnregisterConnection(
|
||||
config *TunnelConfig,
|
||||
) error {
|
||||
config.TransportLogger.Debug("initiating RPC stream for UnregisterConnection")
|
||||
rpc, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
|
||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
|
||||
return err
|
||||
}
|
||||
defer rpc.Close()
|
||||
defer rpcClient.Close()
|
||||
|
||||
return rpc.UnregisterConnection(ctx)
|
||||
return rpcClient.UnregisterConnection(ctx)
|
||||
}
|
||||
|
||||
func RegisterTunnel(
|
||||
@@ -494,18 +495,18 @@ func RegisterTunnel(
|
||||
if config.TunnelEventChan != nil {
|
||||
config.TunnelEventChan <- ui.TunnelEvent{EventType: ui.RegisteringTunnel}
|
||||
}
|
||||
tunnelServer, err := connection.NewRPCClient(ctx, muxer, config.TransportLogger, openStreamTimeout)
|
||||
|
||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, register)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return newClientRegisterTunnelError(err, config.Metrics.rpcFail, register)
|
||||
return err
|
||||
}
|
||||
defer tunnelServer.Close()
|
||||
defer rpcClient.Close()
|
||||
// Request server info without blocking tunnel registration; must use capnp library directly.
|
||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: tunnelServer.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
|
||||
return nil
|
||||
})
|
||||
LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
|
||||
registration := tunnelServer.RegisterTunnel(
|
||||
registration := rpcClient.RegisterTunnel(
|
||||
ctx,
|
||||
config.OriginCert,
|
||||
config.Hostname,
|
||||
@@ -529,7 +530,7 @@ func processRegistrationSuccess(
|
||||
logger logger.Service,
|
||||
connectionID uint8,
|
||||
registration *tunnelpogs.TunnelRegistration,
|
||||
name registerRPCName,
|
||||
name rpcName,
|
||||
credentialManager *reconnectCredentialManager,
|
||||
) error {
|
||||
for _, logLine := range registration.LogLines {
|
||||
@@ -563,7 +564,7 @@ func processRegistrationSuccess(
|
||||
return nil
|
||||
}
|
||||
|
||||
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name registerRPCName) error {
|
||||
func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics *TunnelMetrics, name rpcName) error {
|
||||
if err.Error() == DuplicateConnectionError {
|
||||
metrics.regFail.WithLabelValues("dup_edge_conn", string(name)).Inc()
|
||||
return errDuplicationConnection
|
||||
@@ -575,18 +576,18 @@ func processRegisterTunnelError(err tunnelpogs.TunnelRegistrationError, metrics
|
||||
}
|
||||
}
|
||||
|
||||
func UnregisterTunnel(muxer *h2mux.Muxer, gracePeriod time.Duration, logger logger.Service) error {
|
||||
logger.Debug("initiating RPC stream to unregister")
|
||||
func UnregisterTunnel(muxer *h2mux.Muxer, config *TunnelConfig) error {
|
||||
config.TransportLogger.Debug("initiating RPC stream to unregister")
|
||||
ctx := context.Background()
|
||||
tunnelServer, err := connection.NewRPCClient(ctx, muxer, logger, openStreamTimeout)
|
||||
rpcClient, err := newTunnelRPCClient(ctx, muxer, config, unregister)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return err
|
||||
}
|
||||
defer tunnelServer.Close()
|
||||
defer rpcClient.Close()
|
||||
|
||||
// gracePeriod is encoded in int64 using capnproto
|
||||
return tunnelServer.UnregisterTunnel(ctx, gracePeriod.Nanoseconds())
|
||||
return rpcClient.UnregisterTunnel(ctx, config.GracePeriod.Nanoseconds())
|
||||
}
|
||||
|
||||
func LogServerInfo(
|
||||
@@ -909,3 +910,18 @@ func findCfRayHeader(h1 *http.Request) string {
|
||||
func isLBProbeRequest(req *http.Request) bool {
|
||||
return strings.HasPrefix(req.UserAgent(), lbProbeUserAgentPrefix)
|
||||
}
|
||||
|
||||
func newTunnelRPCClient(ctx context.Context, muxer *h2mux.Muxer, config *TunnelConfig, rpcName rpcName) (tunnelpogs.TunnelServer_PogsClient, error) {
|
||||
openStreamCtx, openStreamCancel := context.WithTimeout(ctx, openStreamTimeout)
|
||||
defer openStreamCancel()
|
||||
stream, err := muxer.OpenRPCStream(openStreamCtx)
|
||||
if err != nil {
|
||||
return tunnelpogs.TunnelServer_PogsClient{}, err
|
||||
}
|
||||
rpcClient, err := connection.NewTunnelRPCClient(ctx, stream, config.TransportLogger)
|
||||
if err != nil {
|
||||
// RPC stream open error
|
||||
return tunnelpogs.TunnelServer_PogsClient{}, newRPCError(err, config.Metrics.rpcFail, rpcName)
|
||||
}
|
||||
return rpcClient, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user