TUN-7373: Streaming logs override for same actor

To help accommodate web browser interactions over websockets, when a
streaming logs session is requested for an actor who is already being
served a session in a separate request, the original request will be
closed and the new request will start streaming logs instead. This
should help with rogue sessions that hold on for too long with no
client on the other side (before the idle timeout or connection close
kicks in).
Devin Carr
2023-04-21 11:54:37 -07:00
parent ee5e447d44
commit 38cd455e4d
109 changed files with 12691 additions and 1798 deletions
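
The core of the change is a small admission rule: at most one active streaming session, except that a new session from the same actor preempts the old one. A minimal standalone sketch of that rule (sessionGuard and its fields are illustrative stand-ins, not the real management types; the real implementation splits this across Logger.ActiveSession/ActiveSessions and ManagementService.canStartStream in the diffs below):

package main

import (
	"fmt"
	"sync"
)

type session struct {
	actorID string
	active  bool
	cancel  func() // closes the underlying connection
}

type sessionGuard struct {
	mu       sync.Mutex
	sessions []*session
}

// tryStart enforces "one streaming session, but the same actor preempts".
func (g *sessionGuard) tryStart(s *session) bool {
	g.mu.Lock()
	defer g.mu.Unlock()
	blocked := false
	for i, existing := range g.sessions {
		if !existing.active {
			continue
		}
		if existing.actorID == s.actorID {
			// Same actor: disconnect the old session and admit the new one.
			existing.active = false
			existing.cancel()
			g.sessions = append(g.sessions[:i], g.sessions[i+1:]...)
			blocked = false
			break
		}
		// A different actor is already streaming.
		blocked = true
	}
	if blocked {
		return false
	}
	s.active = true
	g.sessions = append(g.sessions, s)
	return true
}

func main() {
	g := &sessionGuard{}
	first := &session{actorID: "alice", cancel: func() { fmt.Println("closing alice's old connection") }}
	fmt.Println(g.tryStart(first)) // true: nobody is streaming yet
	other := &session{actorID: "bob", cancel: func() {}}
	fmt.Println(g.tryStart(other)) // false: a different actor holds the session
	second := &session{actorID: "alice", cancel: func() {}}
	fmt.Println(g.tryStart(second)) // true: same actor preempts the first session
}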

management/logger.go

@@ -11,16 +11,9 @@ import (
var json = jsoniter.ConfigFastest
const (
// Indicates how many log messages the listener will hold before dropping.
// Provides a throttling mechanism to drop latest messages if the sender
// can't keep up with the influx of log messages.
logWindow = 30
)
// Logger manages the number of management streaming log sessions
type Logger struct {
sessions []*Session
sessions []*session
mu sync.RWMutex
// Unique logger that isn't a io.Writer of the list of zerolog writers. This helps prevent management log
@@ -40,69 +33,47 @@ func NewLogger() *Logger {
}
type LoggerListener interface {
Listen(*StreamingFilters) *Session
Close(*Session)
// ActiveSession returns the first active session for the requested actor.
ActiveSession(actor) *session
// ActiveSessions returns the count of active sessions.
ActiveSessions() int
// Listen appends the session to the list of sessions that receive log events.
Listen(*session)
// Remove deletes a session from the list of sessions receiving log events.
Remove(*session)
}
type Session struct {
// Buffered channel that holds the recent log events
listener chan *Log
// Types of log events that this session will provide through the listener
filters *StreamingFilters
}
func newSession(size int, filters *StreamingFilters) *Session {
s := &Session{
listener: make(chan *Log, size),
}
if filters != nil {
s.filters = filters
} else {
s.filters = &StreamingFilters{}
}
return s
}
// Insert attempts to insert the log into the session. If the log event matches the provided session filters, it
// will be delivered to the listener.
func (s *Session) Insert(log *Log) {
// Level filters are optional
if s.filters.Level != nil {
if *s.filters.Level > log.Level {
return
func (l *Logger) ActiveSession(actor actor) *session {
l.mu.RLock()
defer l.mu.RUnlock()
for _, session := range l.sessions {
if session.actor.ID == actor.ID && session.active.Load() {
return session
}
}
// Event filters are optional
if len(s.filters.Events) != 0 && !contains(s.filters.Events, log.Event) {
return
}
select {
case s.listener <- log:
default:
// buffer is full, discard
}
return nil
}
func contains(array []LogEventType, t LogEventType) bool {
for _, v := range array {
if v == t {
return true
func (l *Logger) ActiveSessions() int {
l.mu.RLock()
defer l.mu.RUnlock()
count := 0
for _, session := range l.sessions {
if session.active.Load() {
count += 1
}
}
return false
return count
}
// Listen creates a new Session that will append filtered log events as they are created.
func (l *Logger) Listen(filters *StreamingFilters) *Session {
func (l *Logger) Listen(session *session) {
l.mu.Lock()
defer l.mu.Unlock()
listener := newSession(logWindow, filters)
l.sessions = append(l.sessions, listener)
return listener
session.active.Store(true)
l.sessions = append(l.sessions, session)
}
// Close will remove a Session from the available sessions that were receiving log events.
func (l *Logger) Close(session *Session) {
func (l *Logger) Remove(session *session) {
l.mu.Lock()
defer l.mu.Unlock()
index := -1
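
The Logger half of the diff above amounts to a zerolog writer that fans each log line out to every registered session over a bounded channel, dropping when a session's buffer is full. A standalone distillation (fanout is an illustrative stand-in; the real Logger also tracks actors, filters, and a mutex-guarded session list):

package main

import "fmt"

// fanout multiplexes each write to per-session bounded channels.
type fanout struct {
	sessions []chan string
}

func (f *fanout) Write(p []byte) (int, error) {
	for _, ch := range f.sessions {
		select {
		case ch <- string(p):
		default: // session buffer full: drop for this session only
		}
	}
	return len(p), nil
}

func main() {
	f := &fanout{sessions: []chan string{make(chan string, 1)}}
	fmt.Fprintln(f, "hello")                     // zerolog is the caller in the real code
	fmt.Fprintln(f, "dropped: buffer already full")
	fmt.Println(<-f.sessions[0])                 // prints "hello"
}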

management/logger_test.go

@@ -1,8 +1,8 @@
package management
import (
"context"
"testing"
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
@@ -21,9 +21,14 @@ func TestLoggerWrite_NoSessions(t *testing.T) {
func TestLoggerWrite_OneSession(t *testing.T) {
logger := NewLogger()
zlog := zerolog.New(logger).With().Timestamp().Logger().Level(zerolog.InfoLevel)
_, cancel := context.WithCancel(context.Background())
defer cancel()
session := logger.Listen(nil)
defer logger.Close(session)
session := newSession(logWindow, actor{ID: actorID}, cancel)
logger.Listen(session)
defer logger.Remove(session)
assert.Equal(t, 1, logger.ActiveSessions())
assert.Equal(t, session, logger.ActiveSession(actor{ID: actorID}))
zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello")
select {
case event := <-session.listener:
@@ -40,12 +45,20 @@ func TestLoggerWrite_OneSession(t *testing.T) {
func TestLoggerWrite_MultipleSessions(t *testing.T) {
logger := NewLogger()
zlog := zerolog.New(logger).With().Timestamp().Logger().Level(zerolog.InfoLevel)
_, cancel := context.WithCancel(context.Background())
defer cancel()
session1 := newSession(logWindow, actor{}, cancel)
logger.Listen(session1)
defer logger.Remove(session1)
assert.Equal(t, 1, logger.ActiveSessions())
session2 := newSession(logWindow, actor{}, cancel)
logger.Listen(session2)
assert.Equal(t, 2, logger.ActiveSessions())
session1 := logger.Listen(nil)
defer logger.Close(session1)
session2 := logger.Listen(nil)
zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello")
for _, session := range []*Session{session1, session2} {
for _, session := range []*session{session1, session2} {
select {
case event := <-session.listener:
assert.NotEmpty(t, event.Time)
@@ -58,7 +71,7 @@ func TestLoggerWrite_MultipleSessions(t *testing.T) {
}
// Close session2 and make sure session1 still receives events
logger.Close(session2)
logger.Remove(session2)
zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello2")
select {
case event := <-session1.listener:
@@ -79,104 +92,6 @@ func TestLoggerWrite_MultipleSessions(t *testing.T) {
}
}
// Validate that the session filters events
func TestSession_Insert(t *testing.T) {
infoLevel := new(LogLevel)
*infoLevel = Info
warnLevel := new(LogLevel)
*warnLevel = Warn
for _, test := range []struct {
name string
filters StreamingFilters
expectLog bool
}{
{
name: "none",
expectLog: true,
},
{
name: "level",
filters: StreamingFilters{
Level: infoLevel,
},
expectLog: true,
},
{
name: "filtered out level",
filters: StreamingFilters{
Level: warnLevel,
},
expectLog: false,
},
{
name: "events",
filters: StreamingFilters{
Events: []LogEventType{HTTP},
},
expectLog: true,
},
{
name: "filtered out event",
filters: StreamingFilters{
Events: []LogEventType{Cloudflared},
},
expectLog: false,
},
{
name: "filter and event",
filters: StreamingFilters{
Level: infoLevel,
Events: []LogEventType{HTTP},
},
expectLog: true,
},
} {
t.Run(test.name, func(t *testing.T) {
session := newSession(4, &test.filters)
log := Log{
Time: time.Now().UTC().Format(time.RFC3339),
Event: HTTP,
Level: Info,
Message: "test",
}
session.Insert(&log)
select {
case <-session.listener:
require.True(t, test.expectLog)
default:
require.False(t, test.expectLog)
}
})
}
}
// Validate that the session has a maximum number of events it will hold
func TestSession_InsertOverflow(t *testing.T) {
session := newSession(1, nil)
log := Log{
Time: time.Now().UTC().Format(time.RFC3339),
Event: HTTP,
Level: Info,
Message: "test",
}
// Insert 2 events into a channel with capacity for only 1
session.Insert(&log)
session.Insert(&log)
select {
case <-session.listener:
// pass
default:
require.Fail(t, "expected one log event")
}
// Second dequeue should fail
select {
case <-session.listener:
require.Fail(t, "expected no more remaining log events")
default:
// pass
}
}
type mockWriter struct {
event *Log
err error

management/middleware.go (new file, 76 lines)

@@ -0,0 +1,76 @@
package management
import (
"context"
"fmt"
"net/http"
)
type ctxKey int
const (
accessClaimsCtxKey ctxKey = iota
)
const (
connectorIDQuery = "connector_id"
accessTokenQuery = "access_token"
)
var (
errMissingAccessToken = managementError{Code: 1001, Message: "missing access_token query parameter"}
)
// HTTP middleware setting the parsed access_token claims in the request context
func ValidateAccessTokenQueryMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Validate access token
accessToken := r.URL.Query().Get("access_token")
if accessToken == "" {
writeHTTPErrorResponse(w, errMissingAccessToken)
return
}
token, err := parseToken(accessToken)
if err != nil {
writeHTTPErrorResponse(w, errMissingAccessToken)
return
}
r = r.WithContext(context.WithValue(r.Context(), accessClaimsCtxKey, token))
next.ServeHTTP(w, r)
})
}
// Middleware validation error struct for returning to the eyeball
type managementError struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
func (m *managementError) Error() string {
return m.Message
}
// Middleware validation error HTTP response JSON for returning to the eyeball
type managementErrorResponse struct {
Success bool `json:"success,omitempty"`
Errors []managementError `json:"errors,omitempty"`
}
// writeHTTPErrorResponse will respond to the eyeball with a basic HTTP JSON payload carrying validation failure information
func writeHTTPErrorResponse(w http.ResponseWriter, errResp managementError) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
err := json.NewEncoder(w).Encode(managementErrorResponse{
Success: false,
Errors: []managementError{errResp},
})
// we have already written the header, so write a basic error response if unable to encode the error
if err != nil {
// fallback to text message
http.Error(w, fmt.Sprintf(
"%d %s",
http.StatusBadRequest,
http.StatusText(http.StatusBadRequest),
), http.StatusBadRequest)
}
}
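
For reference, the eyeball's half of this contract is just a query parameter. A hedged sketch of a call against the /ping route registered in service.go (the hostname and token are placeholders):

package main

import (
	"fmt"
	"io"
	"net/http"
)

func main() {
	// The middleware reads access_token from the query string; a missing or
	// unparsable token yields a 400 with a JSON managementErrorResponse body.
	url := "https://management-host.example/ping?access_token=" + "jwt-placeholder"
	resp, err := http.Get(url)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	fmt.Println(resp.StatusCode, string(body))
}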

management/middleware_test.go (new file)

@@ -0,0 +1,71 @@
package management
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
)
func TestValidateAccessTokenQueryMiddleware(t *testing.T) {
r := chi.NewRouter()
r.Use(ValidateAccessTokenQueryMiddleware)
r.Get("/valid", func(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
assert.True(t, ok)
assert.True(t, claims.verify())
w.WriteHeader(http.StatusOK)
})
r.Get("/invalid", func(w http.ResponseWriter, r *http.Request) {
_, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
assert.False(t, ok)
w.WriteHeader(http.StatusOK)
})
ts := httptest.NewServer(r)
defer ts.Close()
// valid: with access_token query param
path := "/valid?access_token=" + validToken
resp, _ := testRequest(t, ts, "GET", path, nil)
assert.Equal(t, http.StatusOK, resp.StatusCode)
// invalid: unset token
path = "/invalid"
resp, err := testRequest(t, ts, "GET", path, nil)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.NotNil(t, err)
assert.Equal(t, errMissingAccessToken, err.Errors[0])
// invalid: invalid token
path = "/invalid?access_token=eyJ"
resp, err = testRequest(t, ts, "GET", path, nil)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
assert.NotNil(t, err)
assert.Equal(t, errMissingAccessToken, err.Errors[0])
}
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, *managementErrorResponse) {
req, err := http.NewRequest(method, ts.URL+path, body)
if err != nil {
t.Fatal(err)
return nil, nil
}
resp, err := ts.Client().Do(req)
if err != nil {
t.Fatal(err)
return nil, nil
}
var claims managementErrorResponse
err = json.NewDecoder(resp.Body).Decode(&claims)
if err != nil {
return resp, nil
}
defer resp.Body.Close()
return resp, &claims
}

management/service.go

@@ -7,7 +7,6 @@ import (
"net/http"
"os"
"sync"
"sync/atomic"
"time"
"github.com/go-chi/chi/v5"
@@ -24,7 +23,7 @@ const (
// value will return this error to incoming requests.
StatusSessionLimitExceeded websocket.StatusCode = 4002
reasonSessionLimitExceeded = "limit exceeded for streaming sessions"
// A request that is not actively serving a session may only remain idle for a limited time before the connection is dropped.
StatusIdleLimitExceeded websocket.StatusCode = 4003
reasonIdleLimitExceeded = "session was idle for too long"
)
@@ -41,9 +40,6 @@ type ManagementService struct {
log *zerolog.Logger
router chi.Router
// streaming signifies if the service is already streaming logs. Helps limit the number of active users streaming logs
// from this cloudflared instance.
streaming atomic.Bool
// streamingMut is a lock to prevent concurrent requests to start streaming. Utilizing the atomic.Bool is not
// sufficient to complete this operation since many other checks during an incoming new request are needed
// to validate this before setting streaming to true.
@@ -67,6 +63,7 @@ func New(managementHostname string,
label: label,
}
r := chi.NewRouter()
r.Use(ValidateAccessTokenQueryMiddleware)
r.Get("/ping", ping)
r.Head("/ping", ping)
r.Get("/logs", s.logs)
@@ -92,7 +89,6 @@ type getHostDetailsResponse struct {
}
func (m *ManagementService) getHostDetails(w http.ResponseWriter, r *http.Request) {
var getHostDetailsResponse = getHostDetailsResponse{
ClientID: m.clientID.String(),
}
@@ -157,12 +153,11 @@ func (m *ManagementService) readEvents(c *websocket.Conn, ctx context.Context, e
}
// streamLogs will begin the process of reading from the session listener and writing the log events to the client.
func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *Session) {
defer m.logger.Close(session)
for m.streaming.Load() {
func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *session) {
for session.Active() {
select {
case <-ctx.Done():
m.streaming.Store(false)
session.Stop()
return
case event := <-session.listener:
err := WriteEvent(c, ctx, &EventLog{
@@ -176,7 +171,7 @@ func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, s
m.log.Err(c.Close(websocket.StatusInternalError, err.Error())).Send()
}
// Any errors when writing the messages to the client will stop streaming and close the connection
m.streaming.Store(false)
session.Stop()
return
}
default:
@@ -185,28 +180,38 @@ func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, s
}
}
// startStreaming will check the conditions of the request and begin streaming or close the connection for invalid
// requests.
func (m *ManagementService) startStreaming(c *websocket.Conn, ctx context.Context, event *ClientEvent) {
// canStartStream will check the conditions of the request and return if the session can begin streaming.
func (m *ManagementService) canStartStream(session *session) bool {
m.streamingMut.Lock()
defer m.streamingMut.Unlock()
// Limits to one user for streaming logs
if m.streaming.Load() {
m.log.Warn().
Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
return
// Limits to one actor for streaming logs
if m.logger.ActiveSessions() > 0 {
// Allow the same user to preempt their existing session to disconnect their old session and start streaming
// with this new session instead.
if existingSession := m.logger.ActiveSession(session.actor); existingSession != nil {
m.log.Info().
Msgf("Another management session request for the same actor was requested; the other session will be disconnected to handle the new request.")
existingSession.Stop()
m.logger.Remove(existingSession)
existingSession.cancel()
} else {
m.log.Warn().
Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
return false
}
}
return true
}
// parseFilters will check the ClientEvent for start_streaming and assign any provided filters to the session
func (m *ManagementService) parseFilters(c *websocket.Conn, event *ClientEvent, session *session) bool {
// Expect the first incoming request
startEvent, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
if !ok {
m.log.Warn().Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Msgf("expected start_streaming as first received event")
return
return false
}
m.streaming.Store(true)
listener := m.logger.Listen(startEvent.Filters)
m.log.Debug().Msgf("Streaming logs")
go m.streamLogs(c, ctx, listener)
session.Filters(startEvent.Filters)
return true
}
// Management Streaming Logs accept handler
@@ -227,11 +232,23 @@ func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
ping := time.NewTicker(15 * time.Second)
defer ping.Stop()
// Close the connection if no operation has occurred after the idle timeout.
// Close the connection if no operation has occurred after the idle timeout. The timeout is halted
// when streaming logs is active.
idleTimeout := 5 * time.Minute
idle := time.NewTimer(idleTimeout)
defer idle.Stop()
// Fetch the claims from the request context to acquire the actor
claims, ok := ctx.Value(accessClaimsCtxKey).(*managementTokenClaims)
if !ok || claims == nil {
// This should never happen in practice, as the middleware provides the claims in the request context
m.log.Err(c.Close(websocket.StatusInternalError, "missing access_token")).Send()
return
}
session := newSession(logWindow, claims.Actor, cancel)
defer m.logger.Remove(session)
for {
select {
case <-ctx.Done():
@@ -242,12 +259,28 @@ func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
switch event.Type {
case StartStreaming:
idle.Stop()
m.startStreaming(c, ctx, event)
// Expect the first incoming request
startEvent, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
if !ok {
m.log.Warn().Msgf("expected start_streaming as first received event")
m.log.Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Send()
return
}
// Make sure the session can start
if !m.canStartStream(session) {
m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
return
}
session.Filters(startEvent.Filters)
m.logger.Listen(session)
m.log.Debug().Msgf("Streaming logs")
go m.streamLogs(c, ctx, session)
continue
case StopStreaming:
idle.Reset(idleTimeout)
// TODO: limit StopStreaming to only halt streaming for clients that are already streaming
m.streaming.Store(false)
// Stop the current session for the current actor who requested it
session.Stop()
m.logger.Remove(session)
case UnknownClientEventType:
fallthrough
default:

management/service_test.go

@@ -7,6 +7,7 @@ import (
"time"
"github.com/rs/zerolog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"nhooyr.io/websocket"
@@ -59,3 +60,67 @@ func TestReadEventsLoop_ContextCancelled(t *testing.T) {
m.readEvents(server, ctx, events)
server.Close(websocket.StatusInternalError, "")
}
func TestCanStartStream_NoSessions(t *testing.T) {
m := ManagementService{
log: &noopLogger,
logger: &Logger{
Log: &noopLogger,
},
}
_, cancel := context.WithCancel(context.Background())
session := newSession(0, actor{}, cancel)
assert.True(t, m.canStartStream(session))
}
func TestCanStartStream_ExistingSessionDifferentActor(t *testing.T) {
m := ManagementService{
log: &noopLogger,
logger: &Logger{
Log: &noopLogger,
},
}
_, cancel := context.WithCancel(context.Background())
session1 := newSession(0, actor{ID: "test"}, cancel)
assert.True(t, m.canStartStream(session1))
m.logger.Listen(session1)
assert.True(t, session1.Active())
// Try another session
session2 := newSession(0, actor{ID: "test2"}, cancel)
assert.Equal(t, 1, m.logger.ActiveSessions())
assert.False(t, m.canStartStream(session2))
// Close session1
m.logger.Remove(session1)
assert.True(t, session1.Active()) // Remove doesn't stop a session
session1.Stop()
assert.False(t, session1.Active())
assert.Equal(t, 0, m.logger.ActiveSessions())
// Try session2 again
assert.True(t, m.canStartStream(session2))
}
func TestCanStartStream_ExistingSessionSameActor(t *testing.T) {
m := ManagementService{
log: &noopLogger,
logger: &Logger{
Log: &noopLogger,
},
}
actor := actor{ID: "test"}
_, cancel := context.WithCancel(context.Background())
session1 := newSession(0, actor, cancel)
assert.True(t, m.canStartStream(session1))
m.logger.Listen(session1)
assert.True(t, session1.Active())
// Try another session
session2 := newSession(0, actor, cancel)
assert.Equal(t, 1, m.logger.ActiveSessions())
assert.True(t, m.canStartStream(session2))
// session1 is removed and stopped
assert.Equal(t, 0, m.logger.ActiveSessions())
assert.False(t, session1.Active())
}
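
Putting the service changes together, the client's half of the streaming exchange looks roughly like this. The sketch uses the same websocket library as the code above; the wss URL and the "type" JSON field name are assumptions inferred from the /logs route and the ClientEvent/start_streaming names, not confirmed by this diff:

package main

import (
	"context"
	"fmt"
	"time"

	"nhooyr.io/websocket"
	"nhooyr.io/websocket/wsjson"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	// URL shape inferred from the /logs route plus the access_token middleware.
	c, _, err := websocket.Dial(ctx, "wss://management-host.example/logs?access_token=jwt-placeholder", nil)
	if err != nil {
		panic(err)
	}
	defer c.Close(websocket.StatusNormalClosure, "")

	// The handler expects start_streaming as the first event; filters are optional.
	if err := wsjson.Write(ctx, c, map[string]any{"type": "start_streaming"}); err != nil {
		panic(err)
	}

	for {
		var event map[string]any
		if err := wsjson.Read(ctx, c, &event); err != nil {
			// e.g. close code 4002 when a different actor is already streaming
			fmt.Println("closed:", err)
			return
		}
		fmt.Println(event)
	}
}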

management/session.go (new file, 88 lines)

@@ -0,0 +1,88 @@
package management
import (
"context"
"sync/atomic"
)
const (
// Indicates how many log messages the listener will hold before dropping.
// Provides a throttling mechanism to drop latest messages if the sender
// can't keep up with the influx of log messages.
logWindow = 30
)
// session captures a streaming logs session for a connection of an actor.
type session struct {
// Indicates if the session is streaming or not. Modifying this will affect the active session.
active atomic.Bool
// Allows the session to control the context of the underlying connection to close it out when done. Mostly
// used by the LoggerListener to close out and cleanup a session.
cancel context.CancelFunc
// Actor who started the session
actor actor
// Buffered channel that holds the recent log events
listener chan *Log
// Types of log events that this session will provide through the listener
filters *StreamingFilters
}
// newSession creates a new session.
func newSession(size int, actor actor, cancel context.CancelFunc) *session {
s := &session{
active: atomic.Bool{},
cancel: cancel,
actor: actor,
listener: make(chan *Log, size),
filters: &StreamingFilters{},
}
return s
}
// Filters assigns the StreamingFilters to the session
func (s *session) Filters(filters *StreamingFilters) {
if filters != nil {
s.filters = filters
} else {
s.filters = &StreamingFilters{}
}
}
// Insert attempts to insert the log into the session. If the log event matches the provided session filters, it
// will be delivered to the listener.
func (s *session) Insert(log *Log) {
// Level filters are optional
if s.filters.Level != nil {
if *s.filters.Level > log.Level {
return
}
}
// Event filters are optional
if len(s.filters.Events) != 0 && !contains(s.filters.Events, log.Event) {
return
}
select {
case s.listener <- log:
default:
// buffer is full, discard
}
}
// Active returns if the session is active
func (s *session) Active() bool {
return s.active.Load()
}
// Stop will halt the session
func (s *session) Stop() {
s.active.Store(false)
}
func contains(array []LogEventType, t LogEventType) bool {
for _, v := range array {
if v == t {
return true
}
}
return false
}
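
The two checks in Insert compose as an AND when both filters are set; the level and event filters are each optional. A compact standalone distillation (plain ints stand in for LogLevel and LogEventType):

package main

import "fmt"

// matches distills Insert's two optional checks: a minimum level and an
// event allowlist; both must pass when present.
func matches(minLevel *int, events []int, level, event int) bool {
	if minLevel != nil && *minLevel > level {
		return false
	}
	if len(events) != 0 {
		allowed := false
		for _, e := range events {
			allowed = allowed || e == event
		}
		if !allowed {
			return false
		}
	}
	return true
}

func main() {
	info, warn, httpEvent := 0, 1, 7
	fmt.Println(matches(&warn, nil, info, httpEvent))             // false: level below warn
	fmt.Println(matches(nil, []int{httpEvent}, info, httpEvent))  // true: event allowed
	fmt.Println(matches(&warn, []int{8}, warn, httpEvent))        // false: event filtered out
}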

management/session_test.go (new file, 126 lines)

@@ -0,0 +1,126 @@
package management
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Validate the active states of the session
func TestSession_ActiveControl(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
session := newSession(4, actor{}, cancel)
// session starts out not active
assert.False(t, session.Active())
session.active.Store(true)
assert.True(t, session.Active())
session.Stop()
assert.False(t, session.Active())
}
// Validate that the session filters events
func TestSession_Insert(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
infoLevel := new(LogLevel)
*infoLevel = Info
warnLevel := new(LogLevel)
*warnLevel = Warn
for _, test := range []struct {
name string
filters StreamingFilters
expectLog bool
}{
{
name: "none",
expectLog: true,
},
{
name: "level",
filters: StreamingFilters{
Level: infoLevel,
},
expectLog: true,
},
{
name: "filtered out level",
filters: StreamingFilters{
Level: warnLevel,
},
expectLog: false,
},
{
name: "events",
filters: StreamingFilters{
Events: []LogEventType{HTTP},
},
expectLog: true,
},
{
name: "filtered out event",
filters: StreamingFilters{
Events: []LogEventType{Cloudflared},
},
expectLog: false,
},
{
name: "filter and event",
filters: StreamingFilters{
Level: infoLevel,
Events: []LogEventType{HTTP},
},
expectLog: true,
},
} {
t.Run(test.name, func(t *testing.T) {
session := newSession(4, actor{}, cancel)
session.Filters(&test.filters)
log := Log{
Time: time.Now().UTC().Format(time.RFC3339),
Event: HTTP,
Level: Info,
Message: "test",
}
session.Insert(&log)
select {
case <-session.listener:
require.True(t, test.expectLog)
default:
require.False(t, test.expectLog)
}
})
}
}
// Validate that the session has a maximum number of events it will hold
func TestSession_InsertOverflow(t *testing.T) {
_, cancel := context.WithCancel(context.Background())
defer cancel()
session := newSession(1, actor{}, cancel)
log := Log{
Time: time.Now().UTC().Format(time.RFC3339),
Event: HTTP,
Level: Info,
Message: "test",
}
// Insert 2 events into a channel with capacity for only 1
session.Insert(&log)
session.Insert(&log)
select {
case <-session.listener:
// pass
default:
require.Fail(t, "expected one log event")
}
// Second dequeue should fail
select {
case <-session.listener:
require.Fail(t, "expected no more remaining log events")
default:
// pass
}
}

management/token.go (new file, 55 lines)

@@ -0,0 +1,55 @@
package management
import (
"fmt"
"github.com/go-jose/go-jose/v3/jwt"
)
type managementTokenClaims struct {
Tunnel tunnel `json:"tun"`
Actor actor `json:"actor"`
}
// verify checks that both the tunnel and actor claims are valid
func (c *managementTokenClaims) verify() bool {
return c.Tunnel.verify() && c.Actor.verify()
}
type tunnel struct {
ID string `json:"id"`
AccountTag string `json:"account_tag"`
}
// verify checks that the tunnel claims aren't empty
func (t *tunnel) verify() bool {
return t.AccountTag != "" && t.ID != ""
}
type actor struct {
ID string `json:"id"`
Support bool `json:"support"`
}
// verify checks the ID claim isn't empty
func (t *actor) verify() bool {
return t.ID != ""
}
func parseToken(token string) (*managementTokenClaims, error) {
jwt, err := jwt.ParseSigned(token)
if err != nil {
return nil, fmt.Errorf("malformed jwt: %v", err)
}
var claims managementTokenClaims
// This is actually safe because we verify the token at the edge before it reaches cloudflared
err = jwt.UnsafeClaimsWithoutVerification(&claims)
if err != nil {
return nil, fmt.Errorf("malformed jwt: %v", err)
}
if !claims.verify() {
return nil, fmt.Errorf("invalid management token format provided")
}
return &claims, nil
}
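
Since parseToken and the claims types are unexported, callers live inside the management package. A minimal sketch of the two-step flow (exampleParse is an illustrative helper, not part of the commit):

// exampleParse shows the flow inside the management package: decode the JWT
// without signature verification (see the comment above), then shape-check.
func exampleParse(raw string) (string, error) {
	claims, err := parseToken(raw) // malformed JWT or invalid claims both error
	if err != nil {
		return "", err
	}
	// The actor ID is what the streaming-session preemption keys on.
	return claims.Actor.ID, nil
}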

management/token_test.go (new file, 130 lines)

@@ -0,0 +1,130 @@
package management
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"testing"
"github.com/stretchr/testify/require"
"gopkg.in/square/go-jose.v2"
)
const (
validToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6IjEifQ.eyJ0dW4iOnsiaWQiOiI3YjA5ODE0OS01MWZlLTRlZTUtYTY4Ny0zZTM3NDQ2NmVmYzciLCJhY2NvdW50X3RhZyI6ImNkMzkxZTljMDYyNmE4Zjc2Y2IxZjY3MGY2NTkxYjA1In0sImFjdG9yIjp7ImlkIjoiZGNhcnJAY2xvdWRmbGFyZS5jb20iLCJzdXBwb3J0IjpmYWxzZX0sInJlcyI6WyJsb2dzIl0sImV4cCI6MTY3NzExNzY5NiwiaWF0IjoxNjc3MTE0MDk2LCJpc3MiOiJ0dW5uZWxzdG9yZSJ9.mKenOdOy3Xi4O-grldFnAAemdlE9WajEpTDC_FwezXQTstWiRTLwU65P5jt4vNsIiZA4OJRq7bH-QYID9wf9NA"
accountTag = "cd391e9c0626a8f76cb1f670f6591b05"
tunnelID = "7b098149-51fe-4ee5-a687-3e374466efc7"
actorID = "45d2751e-6b59-45a9-814d-f630786bd0cd"
)
type invalidManagementTokenClaims struct {
Invalid string `json:"invalid"`
}
func TestParseToken(t *testing.T) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
for _, test := range []struct {
name string
claims any
err error
}{
{
name: "Valid",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
AccountTag: accountTag,
},
Actor: actor{
ID: actorID,
},
},
},
{
name: "Invalid claims",
claims: invalidManagementTokenClaims{Invalid: "invalid"},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Tunnel",
claims: managementTokenClaims{
Actor: actor{
ID: actorID,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Tunnel ID",
claims: managementTokenClaims{
Tunnel: tunnel{
AccountTag: accountTag,
},
Actor: actor{
ID: actorID,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Account Tag",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
},
Actor: actor{
ID: actorID,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Actor",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
AccountTag: accountTag,
},
},
err: errors.New("invalid management token format provided"),
},
{
name: "Missing Actor ID",
claims: managementTokenClaims{
Tunnel: tunnel{
ID: tunnelID,
},
Actor: actor{},
},
err: errors.New("invalid management token format provided"),
},
} {
t.Run(test.name, func(t *testing.T) {
jwt := signToken(t, test.claims, key)
claims, err := parseToken(jwt)
if test.err != nil {
require.EqualError(t, err, test.err.Error())
return
}
require.Nil(t, err)
require.Equal(t, test.claims, *claims)
})
}
}
func signToken(t *testing.T, token any, key *ecdsa.PrivateKey) string {
opts := (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "1")
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, opts)
require.NoError(t, err)
payload, err := json.Marshal(token)
require.NoError(t, err)
jws, err := signer.Sign(payload)
require.NoError(t, err)
jwt, err := jws.CompactSerialize()
require.NoError(t, err)
return jwt
}