mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 00:59:58 +00:00
TUN-3427: Define a struct that only implements RegistrationServer in tunnelpogs
This commit is contained in:
@@ -7,7 +7,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/rpc"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
@@ -18,6 +20,156 @@ type RegistrationServer interface {
|
||||
UnregisterConnection(ctx context.Context)
|
||||
}
|
||||
|
||||
type RegistrationServer_PogsImpl struct {
|
||||
impl RegistrationServer
|
||||
}
|
||||
|
||||
func RegistrationServer_ServerToClient(s RegistrationServer) tunnelrpc.RegistrationServer {
|
||||
return tunnelrpc.RegistrationServer_ServerToClient(RegistrationServer_PogsImpl{s})
|
||||
}
|
||||
|
||||
func (i RegistrationServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
|
||||
server.Ack(p.Options)
|
||||
|
||||
auth, err := p.Params.Auth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pogsAuth TunnelAuth
|
||||
err = pogsAuth.UnmarshalCapnproto(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uuidBytes, err := p.Params.TunnelId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnelID, err := uuid.FromBytes(uuidBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
connIndex := p.Params.ConnIndex()
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pogsOptions ConnectionOptions
|
||||
err = pogsOptions.UnmarshalCapnproto(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
|
||||
|
||||
resp, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if callError != nil {
|
||||
if connError, err := resp.Result().NewError(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return MarshalError(connError, callError)
|
||||
}
|
||||
}
|
||||
|
||||
if details, err := resp.Result().NewConnectionDetails(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return connDetails.MarshalCapnproto(details)
|
||||
}
|
||||
}
|
||||
|
||||
func (i RegistrationServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
|
||||
server.Ack(p.Options)
|
||||
|
||||
i.impl.UnregisterConnection(p.Ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
type RegistrationServer_PogsClient struct {
|
||||
Client capnp.Client
|
||||
Conn *rpc.Conn
|
||||
}
|
||||
|
||||
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
||||
tunnelAuth, err := p.NewAuth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetAuth(tunnelAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetTunnelId(tunnelID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetConnIndex(connIndex)
|
||||
connectionOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = options.MarshalCapnproto(connectionOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
response, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
result := response.Result()
|
||||
switch result.Which() {
|
||||
case tunnelrpc.ConnectionResponse_result_Which_error:
|
||||
resultError, err := result.Error()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
cause, err := resultError.Cause()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
err = errors.New(cause)
|
||||
if resultError.ShouldRetry() {
|
||||
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
|
||||
}
|
||||
return nil, err
|
||||
|
||||
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
|
||||
connDetails, err := result.ConnectionDetails()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
details := new(ConnectionDetails)
|
||||
if err = details.UnmarshalCapnproto(connDetails); err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
return details, nil
|
||||
}
|
||||
|
||||
return nil, newRPCError("unknown result which %d", result.Which())
|
||||
}
|
||||
|
||||
func (c RegistrationServer_PogsClient) UnregisterConnection(ctx context.Context) error {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
|
||||
return nil
|
||||
})
|
||||
_, err := promise.Struct()
|
||||
if err != nil {
|
||||
return wrapRPCError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type ClientInfo struct {
|
||||
ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
|
||||
Features []string
|
||||
@@ -98,140 +250,3 @@ func MarshalError(s tunnelrpc.ConnectionError, err error) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
|
||||
server.Ack(p.Options)
|
||||
|
||||
auth, err := p.Params.Auth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pogsAuth TunnelAuth
|
||||
err = pogsAuth.UnmarshalCapnproto(auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uuidBytes, err := p.Params.TunnelId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnelID, err := uuid.FromBytes(uuidBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
connIndex := p.Params.ConnIndex()
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var pogsOptions ConnectionOptions
|
||||
err = pogsOptions.UnmarshalCapnproto(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
|
||||
|
||||
resp, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if callError != nil {
|
||||
if connError, err := resp.Result().NewError(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return MarshalError(connError, callError)
|
||||
}
|
||||
}
|
||||
|
||||
if details, err := resp.Result().NewConnectionDetails(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
return connDetails.MarshalCapnproto(details)
|
||||
}
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
|
||||
server.Ack(p.Options)
|
||||
|
||||
i.impl.UnregisterConnection(p.Ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
|
||||
tunnelAuth, err := p.NewAuth()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetAuth(tunnelAuth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetTunnelId(tunnelID[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.SetConnIndex(connIndex)
|
||||
connectionOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = options.MarshalCapnproto(connectionOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
response, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
result := response.Result()
|
||||
switch result.Which() {
|
||||
case tunnelrpc.ConnectionResponse_result_Which_error:
|
||||
resultError, err := result.Error()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
cause, err := resultError.Cause()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
err = errors.New(cause)
|
||||
if resultError.ShouldRetry() {
|
||||
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
|
||||
}
|
||||
return nil, err
|
||||
|
||||
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
|
||||
connDetails, err := result.ConnectionDetails()
|
||||
if err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
details := new(ConnectionDetails)
|
||||
if err = details.UnmarshalCapnproto(connDetails); err != nil {
|
||||
return nil, wrapRPCError(err)
|
||||
}
|
||||
return details, nil
|
||||
}
|
||||
|
||||
return nil, newRPCError("unknown result which %d", result.Which())
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
|
||||
return nil
|
||||
})
|
||||
_, err := promise.Struct()
|
||||
if err != nil {
|
||||
return wrapRPCError(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -62,6 +62,10 @@ func TestConnectionRegistrationRPC(t *testing.T) {
|
||||
clientConn := rpc.NewConn(t2)
|
||||
defer clientConn.Close()
|
||||
client := TunnelServer_PogsClient{
|
||||
RegistrationServer_PogsClient: RegistrationServer_PogsClient{
|
||||
Client: clientConn.Bootstrap(ctx),
|
||||
Conn: clientConn,
|
||||
},
|
||||
Client: clientConn.Bootstrap(ctx),
|
||||
Conn: clientConn,
|
||||
}
|
||||
|
@@ -210,10 +210,11 @@ type TunnelServer interface {
|
||||
}
|
||||
|
||||
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s})
|
||||
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{RegistrationServer_PogsImpl{s}, s})
|
||||
}
|
||||
|
||||
type TunnelServer_PogsImpl struct {
|
||||
RegistrationServer_PogsImpl
|
||||
impl TunnelServer
|
||||
}
|
||||
|
||||
@@ -268,6 +269,7 @@ func (i TunnelServer_PogsImpl) ObsoleteDeclarativeTunnelConnect(p tunnelrpc.Tunn
|
||||
}
|
||||
|
||||
type TunnelServer_PogsClient struct {
|
||||
RegistrationServer_PogsClient
|
||||
Client capnp.Client
|
||||
Conn *rpc.Conn
|
||||
}
|
||||
|
Reference in New Issue
Block a user