TUN-2547: TunnelRPC definitions for Authenticate flow

This commit is contained in:
Adam Chalmers
2019-11-18 10:28:18 -06:00
parent 6ea9b5c3ff
commit ca7fbf43da
7 changed files with 876 additions and 224 deletions

View File

@@ -0,0 +1,83 @@
package pogs
import (
"fmt"
"time"
)
// AuthenticateResponse is the serialized response from the Authenticate RPC.
// It's a 1:1 representation of the capnp message, so it's not very useful for programmers.
// Instead, you should call the `Outcome()` method to get a programmer-friendly sum type, with one
// case for each possible outcome.
type AuthenticateResponse struct {
PermanentErr string
RetryableErr string
Jwt []byte
HoursUntilRefresh uint8
}
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
func (ar AuthenticateResponse) Outcome() AuthOutcome {
// If there was a network error, then cloudflared should retry later,
// because origintunneld couldn't prove whether auth was correct or not.
if ar.RetryableErr != "" {
return &AuthUnknown{Err: fmt.Errorf(ar.RetryableErr), HoursUntilRefresh: ar.HoursUntilRefresh}
}
// If the user's authentication was unsuccessful, the server will return an error explaining why.
// cloudflared should fatal with this error.
if ar.PermanentErr != "" {
return &AuthFail{Err: fmt.Errorf(ar.PermanentErr)}
}
// If auth succeeded, return the token and refresh it when instructed.
if ar.PermanentErr == "" && len(ar.Jwt) > 0 {
return &AuthSuccess{Jwt: ar.Jwt, HoursUntilRefresh: ar.HoursUntilRefresh}
}
// Otherwise the state got messed up.
return nil
}
// AuthOutcome is a programmer-friendly sum type denoting the possible outcomes of Authenticate.
//go-sumtype:decl AuthOutcome
type AuthOutcome interface {
isAuthOutcome()
}
// AuthSuccess means the backend successfully authenticated this cloudflared.
type AuthSuccess struct {
Jwt []byte
HoursUntilRefresh uint8
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao *AuthSuccess) RefreshAfter() time.Duration {
return hoursToTime(ao.HoursUntilRefresh)
}
func (ao *AuthSuccess) isAuthOutcome() {}
// AuthFail means this cloudflared has the wrong auth and should exit.
type AuthFail struct {
Err error
}
func (ao *AuthFail) isAuthOutcome() {}
// AuthUnknown means the backend couldn't finish checking authentication. Try again later.
type AuthUnknown struct {
Err error
HoursUntilRefresh uint8
}
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
func (ao *AuthUnknown) RefreshAfter() time.Duration {
return hoursToTime(ao.HoursUntilRefresh)
}
func (ao *AuthUnknown) isAuthOutcome() {}
func hoursToTime(hours uint8) time.Duration {
return time.Duration(hours) * time.Hour
}

View File

@@ -0,0 +1,78 @@
package pogs
import (
"context"
"github.com/cloudflare/cloudflared/tunnelrpc"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/server"
)
func (i TunnelServer_PogsImpl) Authenticate(p tunnelrpc.TunnelServer_authenticate) error {
originCert, err := p.Params.OriginCert()
if err != nil {
return err
}
hostname, err := p.Params.Hostname()
if err != nil {
return err
}
options, err := p.Params.Options()
if err != nil {
return err
}
pogsOptions, err := UnmarshalRegistrationOptions(options)
if err != nil {
return err
}
server.Ack(p.Options)
resp, err := i.impl.Authenticate(p.Ctx, originCert, hostname, pogsOptions)
if err != nil {
return err
}
result, err := p.Results.NewResult()
if err != nil {
return err
}
return MarshalAuthenticateResponse(result, resp)
}
func MarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse, p *AuthenticateResponse) error {
return pogs.Insert(tunnelrpc.AuthenticateResponse_TypeID, s.Struct, p)
}
func (c TunnelServer_PogsClient) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
client := tunnelrpc.TunnelServer{Client: c.Client}
promise := client.Authenticate(ctx, func(p tunnelrpc.TunnelServer_authenticate_Params) error {
err := p.SetOriginCert(originCert)
if err != nil {
return err
}
err = p.SetHostname(hostname)
if err != nil {
return err
}
registrationOptions, err := p.NewOptions()
if err != nil {
return err
}
err = MarshalRegistrationOptions(registrationOptions, options)
if err != nil {
return err
}
return nil
})
retval, err := promise.Result().Struct()
if err != nil {
return nil, err
}
return UnmarshalAuthenticateResponse(retval)
}
func UnmarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse) (*AuthenticateResponse, error) {
p := new(AuthenticateResponse)
err := pogs.Extract(p, tunnelrpc.AuthenticateResponse_TypeID, s.Struct)
return p, err
}

109
tunnelrpc/pogs/auth_test.go Normal file
View File

@@ -0,0 +1,109 @@
package pogs
import (
"fmt"
"reflect"
"testing"
"time"
"github.com/cloudflare/cloudflared/tunnelrpc"
"github.com/stretchr/testify/assert"
capnp "zombiezen.com/go/capnproto2"
)
// Ensure the AuthOutcome sum is correct
var _ AuthOutcome = &AuthSuccess{}
var _ AuthOutcome = &AuthFail{}
var _ AuthOutcome = &AuthUnknown{}
// Unit tests for AuthenticateResponse.Outcome()
func TestAuthenticateResponseOutcome(t *testing.T) {
type fields struct {
PermanentErr string
RetryableErr string
Jwt []byte
HoursUntilRefresh uint8
}
tests := []struct {
name string
fields fields
want AuthOutcome
}{
{"success",
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
&AuthSuccess{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
},
{"fail",
fields{PermanentErr: "bad creds"},
&AuthFail{Err: fmt.Errorf("bad creds")},
},
{"error",
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
&AuthUnknown{Err: fmt.Errorf("bad conn"), HoursUntilRefresh: 6},
},
{"nil (no fields are set)",
fields{},
nil,
},
{"nil (too few fields are set)",
fields{HoursUntilRefresh: 6},
nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ar := AuthenticateResponse{
PermanentErr: tt.fields.PermanentErr,
RetryableErr: tt.fields.RetryableErr,
Jwt: tt.fields.Jwt,
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
}
if got := ar.Outcome(); !reflect.DeepEqual(got, tt.want) {
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
}
})
}
}
func TestWhenToRefresh(t *testing.T) {
expected := 4 * time.Hour
actual := hoursToTime(4)
if expected != actual {
t.Fatalf("expected %v hours, got %v", expected, actual)
}
}
// Test that serializing and deserializing AuthenticationResponse undo each other.
func TestSerializeAuthenticationResponse(t *testing.T) {
tests := []*AuthenticateResponse{
&AuthenticateResponse{
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
HoursUntilRefresh: 24,
},
&AuthenticateResponse{
PermanentErr: "bad auth",
},
&AuthenticateResponse{
RetryableErr: "bad connection",
HoursUntilRefresh: 24,
},
}
for i, testCase := range tests {
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg)
if !assert.NoError(t, err) {
t.Fatal("Couldn't initialize a new message")
}
err = MarshalAuthenticateResponse(capnpEntity, testCase)
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
continue
}
result, err := UnmarshalAuthenticateResponse(capnpEntity)
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
continue
}
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
}
}

View File

@@ -327,6 +327,7 @@ type TunnelServer interface {
GetServerInfo(ctx context.Context) (*ServerInfo, error)
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error)
}
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {