TUN-8423: Deprecate older legacy tunnel capnp interfaces

Since legacy tunnels have been removed for a while now, we can remove
many of the capnp rpc interfaces that are no longer leveraged by the
legacy tunnel registration and authentication mechanisms.
This commit is contained in:
Devin Carr
2024-05-20 16:09:25 -07:00
parent e9f010111d
commit 43446bc692
25 changed files with 1468 additions and 2368 deletions

View File

@@ -1,131 +0,0 @@
package pogs
import (
"errors"
"time"
)
// AuthenticateResponse is the serialized response from the Authenticate RPC.
// It's a 1:1 representation of the capnp message, so it's not very useful for programmers.
// Instead, you should call the `Outcome()` method to get a programmer-friendly sum type, with one
// case for each possible outcome.
type AuthenticateResponse struct {
PermanentErr string
RetryableErr string
Jwt []byte
HoursUntilRefresh uint8
}
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
func (ar AuthenticateResponse) Outcome() AuthOutcome {
// If the user's authentication was unsuccessful, the server will return an error explaining why.
// cloudflared should fatal with this error.
if ar.PermanentErr != "" {
return NewAuthFail(errors.New(ar.PermanentErr))
}
// If there was a network error, then cloudflared should retry later,
// because origintunneld couldn't prove whether auth was correct or not.
if ar.RetryableErr != "" {
return NewAuthUnknown(errors.New(ar.RetryableErr), ar.HoursUntilRefresh)
}
// If auth succeeded, return the token and refresh it when instructed.
if len(ar.Jwt) > 0 {
return NewAuthSuccess(ar.Jwt, ar.HoursUntilRefresh)
}
// Otherwise the state got messed up.
return nil
}
// AuthOutcome is a programmer-friendly sum type denoting the possible outcomes of Authenticate.
type AuthOutcome interface {
isAuthOutcome()
// Serialize into an AuthenticateResponse which can be sent via Capnp
Serialize() AuthenticateResponse
}
// AuthSuccess means the backend successfully authenticated this cloudflared.
type AuthSuccess struct {
jwt []byte
hoursUntilRefresh uint8
}
func NewAuthSuccess(jwt []byte, hoursUntilRefresh uint8) AuthSuccess {
return AuthSuccess{jwt: jwt, hoursUntilRefresh: hoursUntilRefresh}
}
func (ao AuthSuccess) JWT() []byte {
return ao.jwt
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao AuthSuccess) RefreshAfter() time.Duration {
return hoursToTime(ao.hoursUntilRefresh)
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthSuccess) Serialize() AuthenticateResponse {
return AuthenticateResponse{
Jwt: ao.jwt,
HoursUntilRefresh: ao.hoursUntilRefresh,
}
}
func (ao AuthSuccess) isAuthOutcome() {}
// AuthFail means this cloudflared has the wrong auth and should exit.
type AuthFail struct {
err error
}
func NewAuthFail(err error) AuthFail {
return AuthFail{err: err}
}
func (ao AuthFail) Error() string {
return ao.err.Error()
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthFail) Serialize() AuthenticateResponse {
return AuthenticateResponse{
PermanentErr: ao.err.Error(),
}
}
func (ao AuthFail) isAuthOutcome() {}
// AuthUnknown means the backend couldn't finish checking authentication. Try again later.
type AuthUnknown struct {
err error
hoursUntilRefresh uint8
}
func NewAuthUnknown(err error, hoursUntilRefresh uint8) AuthUnknown {
return AuthUnknown{err: err, hoursUntilRefresh: hoursUntilRefresh}
}
func (ao AuthUnknown) Error() string {
return ao.err.Error()
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao AuthUnknown) RefreshAfter() time.Duration {
return hoursToTime(ao.hoursUntilRefresh)
}
// Serialize into an AuthenticateResponse which can be sent via Capnp
func (ao AuthUnknown) Serialize() AuthenticateResponse {
return AuthenticateResponse{
RetryableErr: ao.err.Error(),
HoursUntilRefresh: ao.hoursUntilRefresh,
}
}
func (ao AuthUnknown) isAuthOutcome() {}
func hoursToTime(hours uint8) time.Duration {
return time.Duration(hours) * time.Hour
}

View File

@@ -1,78 +0,0 @@
package pogs
import (
"context"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
func (i TunnelServer_PogsImpl) Authenticate(p proto.TunnelServer_authenticate) error {
originCert, err := p.Params.OriginCert()
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
resp, err := i.impl.Authenticate(p.Ctx, originCert, hostname, pogsOptions)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalAuthenticateResponse(result, resp)
}
func MarshalAuthenticateResponse(s proto.AuthenticateResponse, p *AuthenticateResponse) error {
return pogs.Insert(proto.AuthenticateResponse_TypeID, s.Struct, p)
}
func (c TunnelServer_PogsClient) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
client := proto.TunnelServer{Client: c.Client}
promise := client.Authenticate(ctx, func(p proto.TunnelServer_authenticate_Params) error {
err := p.SetOriginCert(originCert)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalAuthenticateResponse(retval)
}
func UnmarshalAuthenticateResponse(s proto.AuthenticateResponse) (*AuthenticateResponse, error) {
p := new(AuthenticateResponse)
err := pogs.Extract(p, proto.AuthenticateResponse_TypeID, s.Struct)
return p, err
}

View File

@@ -1,136 +0,0 @@
package pogs
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/stretchr/testify/assert"
capnp "zombiezen.com/go/capnproto2"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
// Ensure the AuthOutcome sum is correct
var _ AuthOutcome = &AuthSuccess{}
var _ AuthOutcome = &AuthFail{}
var _ AuthOutcome = &AuthUnknown{}
// Unit tests for AuthenticateResponse.Outcome()
func TestAuthenticateResponseOutcome(t *testing.T) {
type fields struct {
PermanentErr string
RetryableErr string
Jwt []byte
HoursUntilRefresh uint8
}
tests := []struct {
name string
fields fields
want AuthOutcome
}{
{"success",
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
AuthSuccess{jwt: []byte("asdf"), hoursUntilRefresh: 6},
},
{"fail",
fields{PermanentErr: "bad creds"},
AuthFail{err: fmt.Errorf("bad creds")},
},
{"error",
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
AuthUnknown{err: fmt.Errorf("bad conn"), hoursUntilRefresh: 6},
},
{"nil (no fields are set)",
fields{},
nil,
},
{"nil (too few fields are set)",
fields{HoursUntilRefresh: 6},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ar := AuthenticateResponse{
PermanentErr: tt.fields.PermanentErr,
RetryableErr: tt.fields.RetryableErr,
Jwt: tt.fields.Jwt,
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
}
got := ar.Outcome()
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
}
if got != nil && !reflect.DeepEqual(got.Serialize(), ar) {
t.Errorf(".Outcome() and .Serialize() should be inverses but weren't. Expected %v, got %v", ar, got.Serialize())
}
})
}
}
func TestAuthSuccess(t *testing.T) {
input := NewAuthSuccess([]byte("asdf"), 6)
output, ok := input.Serialize().Outcome().(AuthSuccess)
assert.True(t, ok)
assert.Equal(t, input, output)
}
func TestAuthUnknown(t *testing.T) {
input := NewAuthUnknown(fmt.Errorf("pdx unreachable"), 6)
output, ok := input.Serialize().Outcome().(AuthUnknown)
assert.True(t, ok)
assert.Equal(t, input, output)
}
func TestAuthFail(t *testing.T) {
input := NewAuthFail(fmt.Errorf("wrong creds"))
output, ok := input.Serialize().Outcome().(AuthFail)
assert.True(t, ok)
assert.Equal(t, input, output)
}
func TestWhenToRefresh(t *testing.T) {
expected := 4 * time.Hour
actual := hoursToTime(4)
if expected != actual {
t.Fatalf("expected %v hours, got %v", expected, actual)
}
}
// Test that serializing and deserializing AuthenticationResponse undo each other.
func TestSerializeAuthenticationResponse(t *testing.T) {
tests := []*AuthenticateResponse{
{
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
HoursUntilRefresh: 24,
},
{
PermanentErr: "bad auth",
},
{
RetryableErr: "bad connection",
HoursUntilRefresh: 24,
},
}
for i, testCase := range tests {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
assert.NoError(t, err)
capnpEntity, err := proto.NewAuthenticateResponse(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalAuthenticateResponse(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
continue
}
result, err := UnmarshalAuthenticateResponse(capnpEntity)
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}

View File

@@ -1,93 +0,0 @@
package pogs
import (
"context"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
func (i TunnelServer_PogsImpl) ReconnectTunnel(p proto.TunnelServer_reconnectTunnel) error {
jwt, err := p.Params.Jwt()
if err != nil {
return err
}
eventDigest, err := p.Params.EventDigest()
if err != nil {
return err
}
connDigest, err := p.Params.ConnDigest()
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
registration, err := i.impl.ReconnectTunnel(p.Ctx, jwt, eventDigest, connDigest, hostname, pogsOptions)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalTunnelRegistration(result, registration)
}
func (c TunnelServer_PogsClient) ReconnectTunnel(
ctx context.Context,
jwt,
eventDigest []byte,
connDigest []byte,
hostname string,
options *RegistrationOptions,
) *TunnelRegistration {
client := proto.TunnelServer{Client: c.Client}
promise := client.ReconnectTunnel(ctx, func(p proto.TunnelServer_reconnectTunnel_Params) error {
err := p.SetJwt(jwt)
if err != nil {
return err
}
err = p.SetEventDigest(eventDigest)
if err != nil {
return err
}
err = p.SetConnDigest(connDigest)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
registration, err := UnmarshalTunnelRegistration(retval)
if err != nil {
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
return registration
}

View File

@@ -54,18 +54,14 @@ func TestConnectionRegistrationRPC(t *testing.T) {
// Server-side
testImpl := testConnectionRegistrationServer{}
srv := TunnelServer_ServerToClient(&testImpl)
srv := RegistrationServer_ServerToClient(&testImpl)
serverConn := rpc.NewConn(t1, rpc.MainInterface(srv.Client))
defer serverConn.Wait()
ctx := context.Background()
clientConn := rpc.NewConn(t2)
defer clientConn.Close()
client := TunnelServer_PogsClient{
RegistrationServer_PogsClient: RegistrationServer_PogsClient{
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
},
client := RegistrationServer_PogsClient{
Client: clientConn.Bootstrap(ctx),
Conn: clientConn,
}
@@ -123,8 +119,6 @@ func TestConnectionRegistrationRPC(t *testing.T) {
}
type testConnectionRegistrationServer struct {
mockTunnelServerBase
details *ConnectionDetails
err error
}
@@ -147,3 +141,7 @@ func (t *testConnectionRegistrationServer) RegisterConnection(ctx context.Contex
panic("either details or err mush be set")
}
func (t *testConnectionRegistrationServer) UnregisterConnection(ctx context.Context) {
panic("unimplemented: UnregisterConnection")
}

View File

@@ -1,40 +0,0 @@
package pogs
import (
"context"
"github.com/google/uuid"
)
// mockTunnelServerBase provides a placeholder implementation
// for TunnelServer interface that can be used to build
// mocks for specific unit tests without having to implement every method
type mockTunnelServerBase struct{}
func (mockTunnelServerBase) RegisterConnection(ctx context.Context, auth TunnelAuth, tunnelID uuid.UUID, connIndex byte, options *ConnectionOptions) (*ConnectionDetails, error) {
panic("unexpected call to RegisterConnection")
}
func (mockTunnelServerBase) UnregisterConnection(ctx context.Context) {
panic("unexpected call to UnregisterConnection")
}
func (mockTunnelServerBase) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
panic("unexpected call to RegisterTunnel")
}
func (mockTunnelServerBase) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
panic("unexpected call to GetServerInfo")
}
func (mockTunnelServerBase) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error {
panic("unexpected call to UnregisterTunnel")
}
func (mockTunnelServerBase) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
panic("unexpected call to Authenticate")
}
func (mockTunnelServerBase) ReconnectTunnel(ctx context.Context, jwt, eventDigest, connDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error) {
panic("unexpected call to ReconnectTunnel")
}

8
tunnelrpc/pogs/tag.go Normal file
View File

@@ -0,0 +1,8 @@
package pogs
// Tag previously was a legacy tunnel capnp struct but was deprecated. To help reduce the amount of changes imposed
// by removing this simple struct, it was copied out of the capnp and provided here instead.
type Tag struct {
Name string
Value string
}

View File

@@ -1,334 +0,0 @@
package pogs
import (
"context"
"fmt"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
const (
defaultRetryAfterSeconds = 15
)
type Authentication struct {
Key string
Email string
OriginCAKey string
}
func MarshalAuthentication(s proto.Authentication, p *Authentication) error {
return pogs.Insert(proto.Authentication_TypeID, s.Struct, p)
}
func UnmarshalAuthentication(s proto.Authentication) (*Authentication, error) {
p := new(Authentication)
err := pogs.Extract(p, proto.Authentication_TypeID, s.Struct)
return p, err
}
type TunnelRegistration struct {
SuccessfulTunnelRegistration
Err string
PermanentFailure bool
RetryAfterSeconds uint16
}
type SuccessfulTunnelRegistration struct {
Url string
LogLines []string
TunnelID string `capnp:"tunnelID"`
EventDigest []byte
ConnDigest []byte
}
func NewSuccessfulTunnelRegistration(
url string,
logLines []string,
tunnelID string,
eventDigest []byte,
connDigest []byte,
) *TunnelRegistration {
// Marshal nil will result in an error
if logLines == nil {
logLines = []string{}
}
return &TunnelRegistration{
SuccessfulTunnelRegistration: SuccessfulTunnelRegistration{
Url: url,
LogLines: logLines,
TunnelID: tunnelID,
EventDigest: eventDigest,
ConnDigest: connDigest,
},
}
}
// Not calling this function Error() to avoid confusion with implementing error interface
func (tr TunnelRegistration) DeserializeError() TunnelRegistrationError {
if tr.Err != "" {
err := fmt.Errorf(tr.Err)
if tr.PermanentFailure {
return NewPermanentRegistrationError(err)
}
retryAfterSeconds := tr.RetryAfterSeconds
if retryAfterSeconds < defaultRetryAfterSeconds {
retryAfterSeconds = defaultRetryAfterSeconds
}
return NewRetryableRegistrationError(err, retryAfterSeconds)
}
return nil
}
type TunnelRegistrationError interface {
error
Serialize() *TunnelRegistration
IsPermanent() bool
}
type PermanentRegistrationError struct {
err string
}
func NewPermanentRegistrationError(err error) TunnelRegistrationError {
return &PermanentRegistrationError{
err: err.Error(),
}
}
func (pre *PermanentRegistrationError) Error() string {
return pre.err
}
func (pre *PermanentRegistrationError) Serialize() *TunnelRegistration {
return &TunnelRegistration{
Err: pre.err,
PermanentFailure: true,
}
}
func (*PermanentRegistrationError) IsPermanent() bool {
return true
}
type RetryableRegistrationError struct {
err string
retryAfterSeconds uint16
}
func NewRetryableRegistrationError(err error, retryAfterSeconds uint16) TunnelRegistrationError {
return &RetryableRegistrationError{
err: err.Error(),
retryAfterSeconds: retryAfterSeconds,
}
}
func (rre *RetryableRegistrationError) Error() string {
return rre.err
}
func (rre *RetryableRegistrationError) Serialize() *TunnelRegistration {
return &TunnelRegistration{
Err: rre.err,
PermanentFailure: false,
RetryAfterSeconds: rre.retryAfterSeconds,
}
}
func (*RetryableRegistrationError) IsPermanent() bool {
return false
}
func MarshalTunnelRegistration(s proto.TunnelRegistration, p *TunnelRegistration) error {
return pogs.Insert(proto.TunnelRegistration_TypeID, s.Struct, p)
}
func UnmarshalTunnelRegistration(s proto.TunnelRegistration) (*TunnelRegistration, error) {
p := new(TunnelRegistration)
err := pogs.Extract(p, proto.TunnelRegistration_TypeID, s.Struct)
return p, err
}
type RegistrationOptions struct {
ClientID string `capnp:"clientId"`
Version string
OS string `capnp:"os"`
ExistingTunnelPolicy proto.ExistingTunnelPolicy
PoolName string `capnp:"poolName"`
Tags []Tag
ConnectionID uint8 `capnp:"connectionId"`
OriginLocalIP string `capnp:"originLocalIp"`
IsAutoupdated bool `capnp:"isAutoupdated"`
RunFromTerminal bool `capnp:"runFromTerminal"`
CompressionQuality uint64 `capnp:"compressionQuality"`
UUID string `capnp:"uuid"`
NumPreviousAttempts uint8
Features []string
}
func MarshalRegistrationOptions(s proto.RegistrationOptions, p *RegistrationOptions) error {
return pogs.Insert(proto.RegistrationOptions_TypeID, s.Struct, p)
}
func UnmarshalRegistrationOptions(s proto.RegistrationOptions) (*RegistrationOptions, error) {
p := new(RegistrationOptions)
err := pogs.Extract(p, proto.RegistrationOptions_TypeID, s.Struct)
return p, err
}
type Tag struct {
Name string `json:"name"`
Value string `json:"value"`
}
type ServerInfo struct {
LocationName string
}
func MarshalServerInfo(s proto.ServerInfo, p *ServerInfo) error {
return pogs.Insert(proto.ServerInfo_TypeID, s.Struct, p)
}
func UnmarshalServerInfo(s proto.ServerInfo) (*ServerInfo, error) {
p := new(ServerInfo)
err := pogs.Extract(p, proto.ServerInfo_TypeID, s.Struct)
return p, err
}
type TunnelServer interface {
RegistrationServer
RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration
GetServerInfo(ctx context.Context) (*ServerInfo, error)
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error)
ReconnectTunnel(ctx context.Context, jwt, eventDigest, connDigest []byte, hostname string, options *RegistrationOptions) (*TunnelRegistration, error)
}
func TunnelServer_ServerToClient(s TunnelServer) proto.TunnelServer {
return proto.TunnelServer_ServerToClient(TunnelServer_PogsImpl{RegistrationServer_PogsImpl{s}, s})
}
type TunnelServer_PogsImpl struct {
RegistrationServer_PogsImpl
impl TunnelServer
}
func (i TunnelServer_PogsImpl) RegisterTunnel(p proto.TunnelServer_registerTunnel) error {
originCert, err := p.Params.OriginCert()
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
registration := i.impl.RegisterTunnel(p.Ctx, originCert, hostname, pogsOptions)
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalTunnelRegistration(result, registration)
}
func (i TunnelServer_PogsImpl) GetServerInfo(p proto.TunnelServer_getServerInfo) error {
server.Ack(p.Options)
serverInfo, err := i.impl.GetServerInfo(p.Ctx)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalServerInfo(result, serverInfo)
}
func (i TunnelServer_PogsImpl) UnregisterTunnel(p proto.TunnelServer_unregisterTunnel) error {
gracePeriodNanoSec := p.Params.GracePeriodNanoSec()
server.Ack(p.Options)
return i.impl.UnregisterTunnel(p.Ctx, gracePeriodNanoSec)
}
func (i TunnelServer_PogsImpl) ObsoleteDeclarativeTunnelConnect(p proto.TunnelServer_obsoleteDeclarativeTunnelConnect) error {
return fmt.Errorf("RPC to create declarative tunnel connection has been deprecated")
}
type TunnelServer_PogsClient struct {
RegistrationServer_PogsClient
Client capnp.Client
Conn *rpc.Conn
}
func (c TunnelServer_PogsClient) Close() error {
c.Client.Close()
return c.Conn.Close()
}
func (c TunnelServer_PogsClient) RegisterTunnel(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) *TunnelRegistration {
client := proto.TunnelServer{Client: c.Client}
promise := client.RegisterTunnel(ctx, func(p proto.TunnelServer_registerTunnel_Params) error {
err := p.SetOriginCert(originCert)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
registration, err := UnmarshalTunnelRegistration(retval)
if err != nil {
return NewRetryableRegistrationError(err, defaultRetryAfterSeconds).Serialize()
}
return registration
}
func (c TunnelServer_PogsClient) GetServerInfo(ctx context.Context) (*ServerInfo, error) {
client := proto.TunnelServer{Client: c.Client}
promise := client.GetServerInfo(ctx, func(p proto.TunnelServer_getServerInfo_Params) error {
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalServerInfo(retval)
}
func (c TunnelServer_PogsClient) UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error {
client := proto.TunnelServer{Client: c.Client}
promise := client.UnregisterTunnel(ctx, func(p proto.TunnelServer_unregisterTunnel_Params) error {
p.SetGracePeriodNanoSec(gracePeriodNanoSec)
return nil
})
_, err := promise.Struct()
return err
}

View File

@@ -1,57 +0,0 @@
package pogs
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
capnp "zombiezen.com/go/capnproto2"
"github.com/cloudflare/cloudflared/tunnelrpc/proto"
)
const (
testURL = "tunnel.example.com"
testTunnelID = "asdfghjkl;"
testRetryAfterSeconds = 19
)
var (
testErr = fmt.Errorf("Invalid credential")
testLogLines = []string{"all", "working"}
testEventDigest = []byte("asdf")
testConnDigest = []byte("lkjh")
)
// *PermanentRegistrationError implements TunnelRegistrationError
var _ TunnelRegistrationError = (*PermanentRegistrationError)(nil)
// *RetryableRegistrationError implements TunnelRegistrationError
var _ TunnelRegistrationError = (*RetryableRegistrationError)(nil)
func TestTunnelRegistration(t *testing.T) {
testCases := []*TunnelRegistration{
NewSuccessfulTunnelRegistration(testURL, testLogLines, testTunnelID, testEventDigest, testConnDigest),
NewSuccessfulTunnelRegistration(testURL, nil, testTunnelID, testEventDigest, testConnDigest),
NewPermanentRegistrationError(testErr).Serialize(),
NewRetryableRegistrationError(testErr, testRetryAfterSeconds).Serialize(),
}
for i, testCase := range testCases {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
assert.NoError(t, err)
capnpEntity, err := proto.NewTunnelRegistration(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalTunnelRegistration(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase #%v failed to marshal", i) {
continue
}
result, err := UnmarshalTunnelRegistration(capnpEntity)
if !assert.NoError(t, err, "testCase #%v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}