TUN-3427: Define a struct that only implements RegistrationServer in tunnelpogs

This commit is contained in:
cthuang
2020-09-28 10:10:30 +01:00
parent 8e8513e325
commit 2c9b7361b7
9 changed files with 242 additions and 201 deletions

View File

@@ -7,7 +7,9 @@ import (
"time"
"github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc"
@@ -18,6 +20,156 @@ type RegistrationServer interface {
UnregisterConnection(ctx context.Context)
}
type RegistrationServer_PogsImpl struct {
impl RegistrationServer
}
func RegistrationServer_ServerToClient(s RegistrationServer) tunnelrpc.RegistrationServer {
return tunnelrpc.RegistrationServer_ServerToClient(RegistrationServer_PogsImpl{s})
}
func (i RegistrationServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
server.Ack(p.Options)
auth, err := p.Params.Auth()
if err != nil {
return err
}
var pogsAuth TunnelAuth
err = pogsAuth.UnmarshalCapnproto(auth)
if err != nil {
return err
}
uuidBytes, err := p.Params.TunnelId()
if err != nil {
return err
}
tunnelID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
connIndex := p.Params.ConnIndex()
options, err := p.Params.Options()
if err != nil {
return err
}
var pogsOptions ConnectionOptions
err = pogsOptions.UnmarshalCapnproto(options)
if err != nil {
return err
}
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
resp, err := p.Results.NewResult()
if err != nil {
return err
}
if callError != nil {
if connError, err := resp.Result().NewError(); err != nil {
return err
} else {
return MarshalError(connError, callError)
}
}
if details, err := resp.Result().NewConnectionDetails(); err != nil {
return err
} else {
return connDetails.MarshalCapnproto(details)
}
}
func (i RegistrationServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
server.Ack(p.Options)
i.impl.UnregisterConnection(p.Ctx)
return nil
}
type RegistrationServer_PogsClient struct {
Client capnp.Client
Conn *rpc.Conn
}
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
tunnelAuth, err := p.NewAuth()
if err != nil {
return err
}
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
return err
}
err = p.SetAuth(tunnelAuth)
if err != nil {
return err
}
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
p.SetConnIndex(connIndex)
connectionOptions, err := p.NewOptions()
if err != nil {
return err
}
err = options.MarshalCapnproto(connectionOptions)
if err != nil {
return err
}
return nil
})
response, err := promise.Result().Struct()
if err != nil {
return nil, wrapRPCError(err)
}
result := response.Result()
switch result.Which() {
case tunnelrpc.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, wrapRPCError(err)
}
cause, err := resultError.Cause()
if err != nil {
return nil, wrapRPCError(err)
}
err = errors.New(cause)
if resultError.ShouldRetry() {
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
}
return nil, err
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, wrapRPCError(err)
}
details := new(ConnectionDetails)
if err = details.UnmarshalCapnproto(connDetails); err != nil {
return nil, wrapRPCError(err)
}
return details, nil
}
return nil, newRPCError("unknown result which %d", result.Which())
}
func (c RegistrationServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil
})
_, err := promise.Struct()
if err != nil {
return wrapRPCError(err)
}
return nil
}
type ClientInfo struct {
ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
Features []string
@@ -98,140 +250,3 @@ func MarshalError(s tunnelrpc.ConnectionError, err error) error {
return nil
}
func (i TunnelServer_PogsImpl) RegisterConnection(p tunnelrpc.RegistrationServer_registerConnection) error {
server.Ack(p.Options)
auth, err := p.Params.Auth()
if err != nil {
return err
}
var pogsAuth TunnelAuth
err = pogsAuth.UnmarshalCapnproto(auth)
if err != nil {
return err
}
uuidBytes, err := p.Params.TunnelId()
if err != nil {
return err
}
tunnelID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
connIndex := p.Params.ConnIndex()
options, err := p.Params.Options()
if err != nil {
return err
}
var pogsOptions ConnectionOptions
err = pogsOptions.UnmarshalCapnproto(options)
if err != nil {
return err
}
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
resp, err := p.Results.NewResult()
if err != nil {
return err
}
if callError != nil {
if connError, err := resp.Result().NewError(); err != nil {
return err
} else {
return MarshalError(connError, callError)
}
}
if details, err := resp.Result().NewConnectionDetails(); err != nil {
return err
} else {
return connDetails.MarshalCapnproto(details)
}
}
func (i TunnelServer_PogsImpl) UnregisterConnection(p tunnelrpc.RegistrationServer_unregisterConnection) error {
server.Ack(p.Options)
i.impl.UnregisterConnection(p.Ctx)
return nil
}
func (c TunnelServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
tunnelAuth, err := p.NewAuth()
if err != nil {
return err
}
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
return err
}
err = p.SetAuth(tunnelAuth)
if err != nil {
return err
}
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
p.SetConnIndex(connIndex)
connectionOptions, err := p.NewOptions()
if err != nil {
return err
}
err = options.MarshalCapnproto(connectionOptions)
if err != nil {
return err
}
return nil
})
response, err := promise.Result().Struct()
if err != nil {
return nil, wrapRPCError(err)
}
result := response.Result()
switch result.Which() {
case tunnelrpc.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, wrapRPCError(err)
}
cause, err := resultError.Cause()
if err != nil {
return nil, wrapRPCError(err)
}
err = errors.New(cause)
if resultError.ShouldRetry() {
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
}
return nil, err
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, wrapRPCError(err)
}
details := new(ConnectionDetails)
if err = details.UnmarshalCapnproto(connDetails); err != nil {
return nil, wrapRPCError(err)
}
return details, nil
}
return nil, newRPCError("unknown result which %d", result.Which())
}
func (c TunnelServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p tunnelrpc.RegistrationServer_unregisterConnection_Params) error {
return nil
})
_, err := promise.Struct()
if err != nil {
return wrapRPCError(err)
}
return nil
}

View File

@@ -62,6 +62,10 @@ func TestConnectionRegistrationRPC(t *testing.T) {
clientConn := rpc.NewConn(t2)
defer clientConn.Close()
client := TunnelServer_PogsClient{
RegistrationServer_PogsClient: RegistrationServer_PogsClient{
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
},
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
}

View File

@@ -210,10 +210,11 @@ type TunnelServer interface {
}
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{s})
return tunnelrpc.TunnelServer_ServerToClient(TunnelServer_PogsImpl{RegistrationServer_PogsImpl{s}, s})
}
type TunnelServer_PogsImpl struct {
RegistrationServer_PogsImpl
impl TunnelServer
}
@@ -268,6 +269,7 @@ func (i TunnelServer_PogsImpl) ObsoleteDeclarativeTunnelConnect(p tunnelrpc.Tunn
}
type TunnelServer_PogsClient struct {
RegistrationServer_PogsClient
Client capnp.Client
Conn *rpc.Conn
}