mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 08:09:58 +00:00
TUN-2547: TunnelRPC definitions for Authenticate flow
This commit is contained in:
83
tunnelrpc/pogs/auth_outcome.go
Normal file
83
tunnelrpc/pogs/auth_outcome.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package pogs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AuthenticateResponse is the serialized response from the Authenticate RPC.
|
||||
// It's a 1:1 representation of the capnp message, so it's not very useful for programmers.
|
||||
// Instead, you should call the `Outcome()` method to get a programmer-friendly sum type, with one
|
||||
// case for each possible outcome.
|
||||
type AuthenticateResponse struct {
|
||||
PermanentErr string
|
||||
RetryableErr string
|
||||
Jwt []byte
|
||||
HoursUntilRefresh uint8
|
||||
}
|
||||
|
||||
// Outcome turns the deserialized response of Authenticate into a programmer-friendly sum type.
|
||||
func (ar AuthenticateResponse) Outcome() AuthOutcome {
|
||||
// If there was a network error, then cloudflared should retry later,
|
||||
// because origintunneld couldn't prove whether auth was correct or not.
|
||||
if ar.RetryableErr != "" {
|
||||
return &AuthUnknown{Err: fmt.Errorf(ar.RetryableErr), HoursUntilRefresh: ar.HoursUntilRefresh}
|
||||
}
|
||||
|
||||
// If the user's authentication was unsuccessful, the server will return an error explaining why.
|
||||
// cloudflared should fatal with this error.
|
||||
if ar.PermanentErr != "" {
|
||||
return &AuthFail{Err: fmt.Errorf(ar.PermanentErr)}
|
||||
}
|
||||
|
||||
// If auth succeeded, return the token and refresh it when instructed.
|
||||
if ar.PermanentErr == "" && len(ar.Jwt) > 0 {
|
||||
return &AuthSuccess{Jwt: ar.Jwt, HoursUntilRefresh: ar.HoursUntilRefresh}
|
||||
}
|
||||
|
||||
// Otherwise the state got messed up.
|
||||
return nil
|
||||
}
|
||||
|
||||
// AuthOutcome is a programmer-friendly sum type denoting the possible outcomes of Authenticate.
|
||||
//go-sumtype:decl AuthOutcome
|
||||
type AuthOutcome interface {
|
||||
isAuthOutcome()
|
||||
}
|
||||
|
||||
// AuthSuccess means the backend successfully authenticated this cloudflared.
|
||||
type AuthSuccess struct {
|
||||
Jwt []byte
|
||||
HoursUntilRefresh uint8
|
||||
}
|
||||
|
||||
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||
func (ao *AuthSuccess) RefreshAfter() time.Duration {
|
||||
return hoursToTime(ao.HoursUntilRefresh)
|
||||
}
|
||||
|
||||
func (ao *AuthSuccess) isAuthOutcome() {}
|
||||
|
||||
// AuthFail means this cloudflared has the wrong auth and should exit.
|
||||
type AuthFail struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (ao *AuthFail) isAuthOutcome() {}
|
||||
|
||||
// AuthUnknown means the backend couldn't finish checking authentication. Try again later.
|
||||
type AuthUnknown struct {
|
||||
Err error
|
||||
HoursUntilRefresh uint8
|
||||
}
|
||||
|
||||
// RefreshAfter is how long cloudflared should wait before rerunning Authenticate.
|
||||
func (ao *AuthUnknown) RefreshAfter() time.Duration {
|
||||
return hoursToTime(ao.HoursUntilRefresh)
|
||||
}
|
||||
|
||||
func (ao *AuthUnknown) isAuthOutcome() {}
|
||||
|
||||
func hoursToTime(hours uint8) time.Duration {
|
||||
return time.Duration(hours) * time.Hour
|
||||
}
|
78
tunnelrpc/pogs/auth_serialize.go
Normal file
78
tunnelrpc/pogs/auth_serialize.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package pogs
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
|
||||
"zombiezen.com/go/capnproto2/pogs"
|
||||
"zombiezen.com/go/capnproto2/server"
|
||||
)
|
||||
|
||||
func (i TunnelServer_PogsImpl) Authenticate(p tunnelrpc.TunnelServer_authenticate) error {
|
||||
originCert, err := p.Params.OriginCert()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
hostname, err := p.Params.Hostname()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
options, err := p.Params.Options()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pogsOptions, err := UnmarshalRegistrationOptions(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
server.Ack(p.Options)
|
||||
resp, err := i.impl.Authenticate(p.Ctx, originCert, hostname, pogsOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := p.Results.NewResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return MarshalAuthenticateResponse(result, resp)
|
||||
}
|
||||
|
||||
func MarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse, p *AuthenticateResponse) error {
|
||||
return pogs.Insert(tunnelrpc.AuthenticateResponse_TypeID, s.Struct, p)
|
||||
}
|
||||
|
||||
func (c TunnelServer_PogsClient) Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error) {
|
||||
client := tunnelrpc.TunnelServer{Client: c.Client}
|
||||
promise := client.Authenticate(ctx, func(p tunnelrpc.TunnelServer_authenticate_Params) error {
|
||||
err := p.SetOriginCert(originCert)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.SetHostname(hostname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
registrationOptions, err := p.NewOptions()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = MarshalRegistrationOptions(registrationOptions, options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
retval, err := promise.Result().Struct()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return UnmarshalAuthenticateResponse(retval)
|
||||
}
|
||||
|
||||
func UnmarshalAuthenticateResponse(s tunnelrpc.AuthenticateResponse) (*AuthenticateResponse, error) {
|
||||
p := new(AuthenticateResponse)
|
||||
err := pogs.Extract(p, tunnelrpc.AuthenticateResponse_TypeID, s.Struct)
|
||||
return p, err
|
||||
}
|
109
tunnelrpc/pogs/auth_test.go
Normal file
109
tunnelrpc/pogs/auth_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package pogs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudflare/cloudflared/tunnelrpc"
|
||||
"github.com/stretchr/testify/assert"
|
||||
capnp "zombiezen.com/go/capnproto2"
|
||||
)
|
||||
|
||||
// Ensure the AuthOutcome sum is correct
|
||||
var _ AuthOutcome = &AuthSuccess{}
|
||||
var _ AuthOutcome = &AuthFail{}
|
||||
var _ AuthOutcome = &AuthUnknown{}
|
||||
|
||||
// Unit tests for AuthenticateResponse.Outcome()
|
||||
func TestAuthenticateResponseOutcome(t *testing.T) {
|
||||
type fields struct {
|
||||
PermanentErr string
|
||||
RetryableErr string
|
||||
Jwt []byte
|
||||
HoursUntilRefresh uint8
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want AuthOutcome
|
||||
}{
|
||||
{"success",
|
||||
fields{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
|
||||
&AuthSuccess{Jwt: []byte("asdf"), HoursUntilRefresh: 6},
|
||||
},
|
||||
{"fail",
|
||||
fields{PermanentErr: "bad creds"},
|
||||
&AuthFail{Err: fmt.Errorf("bad creds")},
|
||||
},
|
||||
{"error",
|
||||
fields{RetryableErr: "bad conn", HoursUntilRefresh: 6},
|
||||
&AuthUnknown{Err: fmt.Errorf("bad conn"), HoursUntilRefresh: 6},
|
||||
},
|
||||
{"nil (no fields are set)",
|
||||
fields{},
|
||||
nil,
|
||||
},
|
||||
{"nil (too few fields are set)",
|
||||
fields{HoursUntilRefresh: 6},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ar := AuthenticateResponse{
|
||||
PermanentErr: tt.fields.PermanentErr,
|
||||
RetryableErr: tt.fields.RetryableErr,
|
||||
Jwt: tt.fields.Jwt,
|
||||
HoursUntilRefresh: tt.fields.HoursUntilRefresh,
|
||||
}
|
||||
if got := ar.Outcome(); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AuthenticateResponse.Outcome() = %T, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhenToRefresh(t *testing.T) {
|
||||
expected := 4 * time.Hour
|
||||
actual := hoursToTime(4)
|
||||
if expected != actual {
|
||||
t.Fatalf("expected %v hours, got %v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that serializing and deserializing AuthenticationResponse undo each other.
|
||||
func TestSerializeAuthenticationResponse(t *testing.T) {
|
||||
|
||||
tests := []*AuthenticateResponse{
|
||||
&AuthenticateResponse{
|
||||
Jwt: []byte("\xbd\xb2\x3d\xbc\x20\xe2\x8c\x98"),
|
||||
HoursUntilRefresh: 24,
|
||||
},
|
||||
&AuthenticateResponse{
|
||||
PermanentErr: "bad auth",
|
||||
},
|
||||
&AuthenticateResponse{
|
||||
RetryableErr: "bad connection",
|
||||
HoursUntilRefresh: 24,
|
||||
},
|
||||
}
|
||||
|
||||
for i, testCase := range tests {
|
||||
_, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
|
||||
capnpEntity, err := tunnelrpc.NewAuthenticateResponse(seg)
|
||||
if !assert.NoError(t, err) {
|
||||
t.Fatal("Couldn't initialize a new message")
|
||||
}
|
||||
err = MarshalAuthenticateResponse(capnpEntity, testCase)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to marshal", i) {
|
||||
continue
|
||||
}
|
||||
result, err := UnmarshalAuthenticateResponse(capnpEntity)
|
||||
if !assert.NoError(t, err, "testCase index %v failed to unmarshal", i) {
|
||||
continue
|
||||
}
|
||||
assert.Equal(t, testCase, result, "testCase index %v didn't preserve struct through marshalling and unmarshalling", i)
|
||||
}
|
||||
}
|
@@ -327,6 +327,7 @@ type TunnelServer interface {
|
||||
GetServerInfo(ctx context.Context) (*ServerInfo, error)
|
||||
UnregisterTunnel(ctx context.Context, gracePeriodNanoSec int64) error
|
||||
Connect(ctx context.Context, parameters *ConnectParameters) (ConnectResult, error)
|
||||
Authenticate(ctx context.Context, originCert []byte, hostname string, options *RegistrationOptions) (*AuthenticateResponse, error)
|
||||
}
|
||||
|
||||
func TunnelServer_ServerToClient(s TunnelServer) tunnelrpc.TunnelServer {
|
||||
|
Reference in New Issue
Block a user