cloudflared/tunnelrpc/pogs/connectionrpc.go
Devin Carr eb2e4349e8 TUN-8415: Refactor capnp rpc into a single module
Combines the tunnelrpc and quic/schema capnp files into the same module.

To help reduce future issues with capnp id generation, capnpids are
provided in the capnp files from the existing capnp struct ids generated
in the go files.

Reduces the overall interface of the Capnp methods to the rest of
the code by providing an interface that will handle the quic protocol
selection.

Introduces a new `rpc-timeout` config that will allow all of the
SessionManager and ConfigurationManager RPC requests to have a timeout.
The timeout for these values is set to 5 seconds as non of these operations
for the managers should take a long time to complete.

Removed the RPC-specific logger as it never provided good debugging value
as the RPC method names were not visible in the logs.
2024-05-17 11:22:07 -07:00

291 lines
7.1 KiB
Go

package pogs
import (
"context"
"errors"
"net"
"time"
"github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
type RegistrationServer interface {
RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error)
UnregisterConnection(ctx context.Context)
UpdateLocalConfiguration(ctx context.Context, config []byte) error
}
type RegistrationServer_PogsImpl struct {
impl RegistrationServer
}
func RegistrationServer_ServerToClient(s RegistrationServer) proto.RegistrationServer {
return proto.RegistrationServer_ServerToClient(RegistrationServer_PogsImpl{s})
}
func (i RegistrationServer_PogsImpl) RegisterConnection(p proto.RegistrationServer_registerConnection) error {
server.Ack(p.Options)
auth, err := p.Params.Auth()
if err != nil {
return err
}
var pogsAuth TunnelAuth
err = pogsAuth.UnmarshalCapnproto(auth)
if err != nil {
return err
}
uuidBytes, err := p.Params.TunnelId()
if err != nil {
return err
}
tunnelID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
connIndex := p.Params.ConnIndex()
options, err := p.Params.Options()
if err != nil {
return err
}
var pogsOptions ConnectionOptions
err = pogsOptions.UnmarshalCapnproto(options)
if err != nil {
return err
}
connDetails, callError := i.impl.RegisterConnection(p.Ctx, pogsAuth, tunnelID, connIndex, &pogsOptions)
resp, err := p.Results.NewResult()
if err != nil {
return err
}
if callError != nil {
if connError, err := resp.Result().NewError(); err != nil {
return err
} else {
return MarshalError(connError, callError)
}
}
if details, err := resp.Result().NewConnectionDetails(); err != nil {
return err
} else {
return connDetails.MarshalCapnproto(details)
}
}
func (i RegistrationServer_PogsImpl) UnregisterConnection(p proto.RegistrationServer_unregisterConnection) error {
server.Ack(p.Options)
i.impl.UnregisterConnection(p.Ctx)
return nil
}
func (i RegistrationServer_PogsImpl) UpdateLocalConfiguration(c proto.RegistrationServer_updateLocalConfiguration) error {
server.Ack(c.Options)
configBytes, err := c.Params.Config()
if err != nil {
return err
}
return i.impl.UpdateLocalConfiguration(c.Ctx, configBytes)
}
type RegistrationServer_PogsClient struct {
Client capnp.Client
Conn *rpc.Conn
}
func (c RegistrationServer_PogsClient) Close() error {
c.Client.Close()
return c.Conn.Close()
}
func (c RegistrationServer_PogsClient) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
client := proto.TunnelServer{Client: c.Client}
promise := client.RegisterConnection(ctx, func(p proto.RegistrationServer_registerConnection_Params) error {
tunnelAuth, err := p.NewAuth()
if err != nil {
return err
}
if err = auth.MarshalCapnproto(tunnelAuth); err != nil {
return err
}
err = p.SetAuth(tunnelAuth)
if err != nil {
return err
}
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
p.SetConnIndex(connIndex)
connectionOptions, err := p.NewOptions()
if err != nil {
return err
}
err = options.MarshalCapnproto(connectionOptions)
if err != nil {
return err
}
return nil
})
response, err := promise.Result().Struct()
if err != nil {
return nil, wrapRPCError(err)
}
result := response.Result()
switch result.Which() {
case proto.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, wrapRPCError(err)
}
cause, err := resultError.Cause()
if err != nil {
return nil, wrapRPCError(err)
}
err = errors.New(cause)
if resultError.ShouldRetry() {
err = RetryErrorAfter(err, time.Duration(resultError.RetryAfter()))
}
return nil, err
case proto.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, wrapRPCError(err)
}
details := new(ConnectionDetails)
if err = details.UnmarshalCapnproto(connDetails); err != nil {
return nil, wrapRPCError(err)
}
return details, nil
}
return nil, newRPCError("unknown result which %d", result.Which())
}
func (c RegistrationServer_PogsClient) SendLocalConfiguration(ctx context.Context, config []byte) error {
client := proto.TunnelServer{Client: c.Client}
promise := client.UpdateLocalConfiguration(ctx, func(p proto.RegistrationServer_updateLocalConfiguration_Params) error {
if err := p.SetConfig(config); err != nil {
return err
}
return nil
})
_, err := promise.Struct()
if err != nil {
return wrapRPCError(err)
}
return nil
}
func (c RegistrationServer_PogsClient) UnregisterConnection(ctx context.Context) error {
client := proto.TunnelServer{Client: c.Client}
promise := client.UnregisterConnection(ctx, func(p proto.RegistrationServer_unregisterConnection_Params) error {
return nil
})
_, err := promise.Struct()
if err != nil {
return wrapRPCError(err)
}
return nil
}
type ClientInfo struct {
ClientID []byte `capnp:"clientId"` // must be a slice for capnp compatibility
Features []string
Version string
Arch string
}
type ConnectionOptions struct {
Client ClientInfo
OriginLocalIP net.IP `capnp:"originLocalIp"`
ReplaceExisting bool
CompressionQuality uint8
NumPreviousAttempts uint8
}
type TunnelAuth struct {
AccountTag string
TunnelSecret []byte
}
func (p *ConnectionOptions) MarshalCapnproto(s proto.ConnectionOptions) error {
return pogs.Insert(proto.ConnectionOptions_TypeID, s.Struct, p)
}
func (p *ConnectionOptions) UnmarshalCapnproto(s proto.ConnectionOptions) error {
return pogs.Extract(p, proto.ConnectionOptions_TypeID, s.Struct)
}
func (a *TunnelAuth) MarshalCapnproto(s proto.TunnelAuth) error {
return pogs.Insert(proto.TunnelAuth_TypeID, s.Struct, a)
}
func (a *TunnelAuth) UnmarshalCapnproto(s proto.TunnelAuth) error {
return pogs.Extract(a, proto.TunnelAuth_TypeID, s.Struct)
}
type ConnectionDetails struct {
UUID uuid.UUID
Location string
TunnelIsRemotelyManaged bool
}
func (details *ConnectionDetails) MarshalCapnproto(s proto.ConnectionDetails) error {
if err := s.SetUuid(details.UUID[:]); err != nil {
return err
}
if err := s.SetLocationName(details.Location); err != nil {
return err
}
s.SetTunnelIsRemotelyManaged(details.TunnelIsRemotelyManaged)
return nil
}
func (details *ConnectionDetails) UnmarshalCapnproto(s proto.ConnectionDetails) error {
uuidBytes, err := s.Uuid()
if err != nil {
return err
}
details.UUID, err = uuid.FromBytes(uuidBytes)
if err != nil {
return err
}
details.Location, err = s.LocationName()
if err != nil {
return err
}
details.TunnelIsRemotelyManaged = s.TunnelIsRemotelyManaged()
return err
}
func MarshalError(s proto.ConnectionError, err error) error {
if err := s.SetCause(err.Error()); err != nil {
return err
}
if retryableErr, ok := err.(*RetryableError); ok {
s.SetShouldRetry(true)
s.SetRetryAfter(int64(retryableErr.Delay))
}
return nil
}