TUN-2309: Split ConnectResult into ConnectError and ConnectSuccess, each implementing its own capnp serialization logic

This commit is contained in:
Chung-Ting Huang
2019-09-17 16:58:49 -05:00
parent 4f23da2a6d
commit 5bcb2da0fe
6 changed files with 564 additions and 315 deletions

View File

@@ -2,10 +2,12 @@ package pogs
import (
"context"
"fmt"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/google/uuid"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
capnp "zombiezen.com/go/capnproto2"
@@ -73,28 +75,129 @@ func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*Registratio
return p, err
}
type ConnectResult struct {
Err *ConnectError
ServerInfo ServerInfo
ClientConfig ClientConfig
// ConnectResult models the result of Connect RPC, implemented by ConnectError and ConnectSuccess.
type ConnectResult interface {
ConnectError() *ConnectError
ConnectedTo() string
ClientConfig() *ClientConfig
Marshal(s tunnelrpc.ConnectResult) error
}
func MarshalConnectResult(s tunnelrpc.ConnectResult, p *ConnectResult) error {
return pogs.Insert(tunnelrpc.ConnectResult_TypeID, s.Struct, p)
func MarshalConnectResult(s tunnelrpc.ConnectResult, p ConnectResult) error {
return p.Marshal(s)
}
func UnmarshalConnectResult(s tunnelrpc.ConnectResult) (*ConnectResult, error) {
p := new(ConnectResult)
err := pogs.Extract(p, tunnelrpc.ConnectResult_TypeID, s.Struct)
return p, err
func UnmarshalConnectResult(s tunnelrpc.ConnectResult) (ConnectResult, error) {
switch s.Result().Which() {
case tunnelrpc.ConnectResult_result_Which_err:
capnpConnectError, err := s.Result().Err()
if err != nil {
return nil, err
}
return UnmarshalConnectError(capnpConnectError)
case tunnelrpc.ConnectResult_result_Which_success:
capnpConnectSuccess, err := s.Result().Success()
if err != nil {
return nil, err
}
return UnmarshalConnectSuccess(capnpConnectSuccess)
default:
return nil, fmt.Errorf("Unmarshal %v not implemented yet", s.Result().Which().String())
}
}
// ConnectSuccess is the concrete returned type when Connect RPC succeed
type ConnectSuccess struct {
ServerLocationName string
Config *ClientConfig
}
func (*ConnectSuccess) ConnectError() *ConnectError {
return nil
}
func (cs *ConnectSuccess) ConnectedTo() string {
return cs.ServerLocationName
}
func (cs *ConnectSuccess) ClientConfig() *ClientConfig {
return cs.Config
}
func (cs *ConnectSuccess) Marshal(s tunnelrpc.ConnectResult) error {
capnpConnectSuccess, err := s.Result().NewSuccess()
if err != nil {
return err
}
err = capnpConnectSuccess.SetServerLocationName(cs.ServerLocationName)
if err != nil {
return errors.Wrap(err, "failed to set ConnectSuccess.ServerLocationName")
}
if cs.Config != nil {
capnpClientConfig, err := capnpConnectSuccess.NewClientConfig()
if err != nil {
return errors.Wrap(err, "failed to initialize ConnectSuccess.ClientConfig")
}
if err := MarshalClientConfig(capnpClientConfig, cs.Config); err != nil {
return errors.Wrap(err, "failed to marshal ClientConfig")
}
}
return nil
}
func UnmarshalConnectSuccess(s tunnelrpc.ConnectSuccess) (*ConnectSuccess, error) {
p := new(ConnectSuccess)
serverLocationName, err := s.ServerLocationName()
if err != nil {
return nil, errors.Wrap(err, "failed to get tunnelrpc.ConnectSuccess.ServerLocationName")
}
p.ServerLocationName = serverLocationName
if s.HasClientConfig() {
capnpClientConfig, err := s.ClientConfig()
if err != nil {
return nil, errors.Wrap(err, "failed to get tunnelrpc.ConnectSuccess.ClientConfig")
}
p.Config, err = UnmarshalClientConfig(capnpClientConfig)
if err != nil {
return nil, errors.Wrap(err, "failed to get unmarshal ClientConfig")
}
}
return p, nil
}
// ConnectError is the concrete returned type when Connect RPC encounters some error
type ConnectError struct {
Cause string
RetryAfter time.Duration
ShouldRetry bool
}
func (ce *ConnectError) ConnectError() *ConnectError {
return ce
}
func (*ConnectError) ConnectedTo() string {
return ""
}
func (*ConnectError) ClientConfig() *ClientConfig {
return nil
}
func (ce *ConnectError) Marshal(s tunnelrpc.ConnectResult) error {
capnpConnectError, err := s.Result().NewErr()
if err != nil {
return err
}
return MarshalConnectError(capnpConnectError, ce)
}
func MarshalConnectError(s tunnelrpc.ConnectError, p *ConnectError) error {
return pogs.Insert(tunnelrpc.ConnectError_TypeID, s.Struct, p)
}
@@ -223,7 +326,7 @@ type TunnelServer interface {
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
GetServerInfo(ctx context.Context) (*ServerInfo, error)
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
Connect(ctx context.Context, paramaters *ConnectParameters) (*ConnectResult, error)
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
}
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
@@ -284,11 +387,11 @@ func (i TunnelServer_PogsImpl) UnregisterTunnel(p tunnelrpc.TunnelServer_unregis
}
func (i TunnelServer_PogsImpl) Connect(p tunnelrpc.TunnelServer_connect) error {
paramaters, err := p.Params.Parameters()
parameters, err := p.Params.Parameters()
if err != nil {
return err
}
pogsParameters, err := UnmarshalConnectParameters(paramaters)
pogsParameters, err := UnmarshalConnectParameters(parameters)
if err != nil {
return err
}
@@ -301,7 +404,7 @@ func (i TunnelServer_PogsImpl) Connect(p tunnelrpc.TunnelServer_connect) error {
if err != nil {
return err
}
return MarshalConnectResult(result, connectResult)
return connectResult.Marshal(result)
}
type TunnelServer_PogsClient struct {
@@ -365,7 +468,7 @@ func (c TunnelServer_PogsClient) UnregisterTunnel(ctx context.Context, gracePeri
func (c TunnelServer_PogsClient) Connect(ctx context.Context,
parameters *ConnectParameters,
) (*ConnectResult, error) {
) (ConnectResult, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.Connect(ctx, func(p tunnelrpc.TunnelServer_connect_Params) error {
connectParameters, err := p.NewParameters()