mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-2309: Split ConnectResult into ConnectError and ConnectSuccess, each implementing its own capnp serialization logic
This commit is contained in:
@@ -2,10 +2,12 @@ package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
@@ -73,28 +75,129 @@ func UnmarshalRegistrationOptions(s tunnelrpc.RegistrationOptions) (*Registratio
|
||||
return p, err
|
||||
}
|
||||
|
||||
type ConnectResult struct {
|
||||
Err *ConnectError
|
||||
ServerInfo ServerInfo
|
||||
ClientConfig ClientConfig
|
||||
// ConnectResult models the result of Connect RPC, implemented by ConnectError and ConnectSuccess.
|
||||
type ConnectResult interface {
|
||||
ConnectError() *ConnectError
|
||||
ConnectedTo() string
|
||||
ClientConfig() *ClientConfig
|
||||
Marshal(s tunnelrpc.ConnectResult) error
|
||||
}
|
||||
|
||||
func MarshalConnectResult(s tunnelrpc.ConnectResult, p *ConnectResult) error {
|
||||
return pogs.Insert(tunnelrpc.ConnectResult_TypeID, s.Struct, p)
|
||||
func MarshalConnectResult(s tunnelrpc.ConnectResult, p ConnectResult) error {
|
||||
return p.Marshal(s)
|
||||
}
|
||||
|
||||
func UnmarshalConnectResult(s tunnelrpc.ConnectResult) (*ConnectResult, error) {
|
||||
p := new(ConnectResult)
|
||||
err := pogs.Extract(p, tunnelrpc.ConnectResult_TypeID, s.Struct)
|
||||
return p, err
|
||||
func UnmarshalConnectResult(s tunnelrpc.ConnectResult) (ConnectResult, error) {
|
||||
switch s.Result().Which() {
|
||||
case tunnelrpc.ConnectResult_result_Which_err:
|
||||
capnpConnectError, err := s.Result().Err()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalConnectError(capnpConnectError)
|
||||
case tunnelrpc.ConnectResult_result_Which_success:
|
||||
capnpConnectSuccess, err := s.Result().Success()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalConnectSuccess(capnpConnectSuccess)
|
||||
default:
|
||||
return nil, fmt.Errorf("Unmarshal %v not implemented yet", s.Result().Which().String())
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectSuccess is the concrete returned type when Connect RPC succeed
|
||||
type ConnectSuccess struct {
|
||||
ServerLocationName string
|
||||
Config *ClientConfig
|
||||
}
|
||||
|
||||
func (*ConnectSuccess) ConnectError() *ConnectError {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cs *ConnectSuccess) ConnectedTo() string {
|
||||
return cs.ServerLocationName
|
||||
}
|
||||
|
||||
func (cs *ConnectSuccess) ClientConfig() *ClientConfig {
|
||||
return cs.Config
|
||||
}
|
||||
|
||||
func (cs *ConnectSuccess) Marshal(s tunnelrpc.ConnectResult) error {
|
||||
capnpConnectSuccess, err := s.Result().NewSuccess()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = capnpConnectSuccess.SetServerLocationName(cs.ServerLocationName)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set ConnectSuccess.ServerLocationName")
|
||||
}
|
||||
|
||||
if cs.Config != nil {
|
||||
capnpClientConfig, err := capnpConnectSuccess.NewClientConfig()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to initialize ConnectSuccess.ClientConfig")
|
||||
}
|
||||
if err := MarshalClientConfig(capnpClientConfig, cs.Config); err != nil {
|
||||
return errors.Wrap(err, "failed to marshal ClientConfig")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UnmarshalConnectSuccess(s tunnelrpc.ConnectSuccess) (*ConnectSuccess, error) {
|
||||
p := new(ConnectSuccess)
|
||||
|
||||
serverLocationName, err := s.ServerLocationName()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get tunnelrpc.ConnectSuccess.ServerLocationName")
|
||||
}
|
||||
p.ServerLocationName = serverLocationName
|
||||
|
||||
if s.HasClientConfig() {
|
||||
capnpClientConfig, err := s.ClientConfig()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get tunnelrpc.ConnectSuccess.ClientConfig")
|
||||
}
|
||||
p.Config, err = UnmarshalClientConfig(capnpClientConfig)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get unmarshal ClientConfig")
|
||||
}
|
||||
}
|
||||
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// ConnectError is the concrete returned type when Connect RPC encounters some error
|
||||
type ConnectError struct {
|
||||
Cause string
|
||||
RetryAfter time.Duration
|
||||
ShouldRetry bool
|
||||
}
|
||||
|
||||
func (ce *ConnectError) ConnectError() *ConnectError {
|
||||
return ce
|
||||
}
|
||||
|
||||
func (*ConnectError) ConnectedTo() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (*ConnectError) ClientConfig() *ClientConfig {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ce *ConnectError) Marshal(s tunnelrpc.ConnectResult) error {
|
||||
capnpConnectError, err := s.Result().NewErr()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalConnectError(capnpConnectError, ce)
|
||||
}
|
||||
|
||||
func MarshalConnectError(s tunnelrpc.ConnectError, p *ConnectError) error {
|
||||
return pogs.Insert(tunnelrpc.ConnectError_TypeID, s.Struct, p)
|
||||
}
|
||||
@@ -223,7 +326,7 @@ type TunnelServer interface {
|
||||
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
|
||||
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
|
||||
Connect(ctx context.Context, paramaters *ConnectParameters) (*ConnectResult, error)
|
||||
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
|
||||
}
|
||||
|
||||
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||
@@ -284,11 +387,11 @@ func (i TunnelServer_PogsImpl) UnregisterTunnel(p tunnelrpc.TunnelServer_unregis
|
||||
}
|
||||
|
||||
func (i TunnelServer_PogsImpl) Connect(p tunnelrpc.TunnelServer_connect) error {
|
||||
paramaters, err := p.Params.Parameters()
|
||||
parameters, err := p.Params.Parameters()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pogsParameters, err := UnmarshalConnectParameters(paramaters)
|
||||
pogsParameters, err := UnmarshalConnectParameters(parameters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -301,7 +404,7 @@ func (i TunnelServer_PogsImpl) Connect(p tunnelrpc.TunnelServer_connect) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalConnectResult(result, connectResult)
|
||||
return connectResult.Marshal(result)
|
||||
}
|
||||
|
||||
type TunnelServer_PogsClient struct {
|
||||
@@ -365,7 +468,7 @@ func (c TunnelServer_PogsClient) UnregisterTunnel(ctx context.Context, gracePeri
|
||||
|
||||
func (c TunnelServer_PogsClient) Connect(ctx context.Context,
|
||||
parameters *ConnectParameters,
|
||||
) (*ConnectResult, error) {
|
||||
) (ConnectResult, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.Connect(ctx, func(p tunnelrpc.TunnelServer_connect_Params) error {
|
||||
connectParameters, err := p.NewParameters()
|
||||
|
@@ -11,21 +11,21 @@ import (
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
func sampleTestConnectResult() *ConnectResult {
|
||||
return &ConnectResult{
|
||||
Err: &ConnectError{
|
||||
func TestConnectResult(t *testing.T) {
|
||||
testCases := []ConnectResult{
|
||||
&ConnectError{
|
||||
Cause: "it broke",
|
||||
ShouldRetry: false,
|
||||
RetryAfter: 2 * time.Second,
|
||||
},
|
||||
ServerInfo: ServerInfo{LocationName: "computer"},
|
||||
ClientConfig: *sampleClientConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectResult(t *testing.T) {
|
||||
testCases := []*ConnectResult{
|
||||
sampleTestConnectResult(),
|
||||
&ConnectSuccess{
|
||||
ServerLocationName: "SFO",
|
||||
Config: sampleClientConfig(),
|
||||
},
|
||||
&ConnectSuccess{
|
||||
ServerLocationName: "",
|
||||
Config: nil,
|
||||
},
|
||||
}
|
||||
for i, testCase := range testCases {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
|
Reference in New Issue
Block a user