mirror of
https://github.com/cloudflare/cloudflared.git
synced 2025-07-27 16:59:57 +00:00
TUN-7373: Streaming logs override for same actor
To help accommodate web browser interactions with websockets, when a streaming logs session is requested for the same actor while already serving a session for that user in a separate request, the original request will be closed and the new request start streaming logs instead. This should help with rogue sessions holding on for too long with no client on the other side (before idle timeout or connection close).
This commit is contained in:
@@ -11,16 +11,9 @@ import (
|
||||
|
||||
var json = jsoniter.ConfigFastest
|
||||
|
||||
const (
|
||||
// Indicates how many log messages the listener will hold before dropping.
|
||||
// Provides a throttling mechanism to drop latest messages if the sender
|
||||
// can't keep up with the influx of log messages.
|
||||
logWindow = 30
|
||||
)
|
||||
|
||||
// Logger manages the number of management streaming log sessions
|
||||
type Logger struct {
|
||||
sessions []*Session
|
||||
sessions []*session
|
||||
mu sync.RWMutex
|
||||
|
||||
// Unique logger that isn't a io.Writer of the list of zerolog writers. This helps prevent management log
|
||||
@@ -40,69 +33,47 @@ func NewLogger() *Logger {
|
||||
}
|
||||
|
||||
type LoggerListener interface {
|
||||
Listen(*StreamingFilters) *Session
|
||||
Close(*Session)
|
||||
// ActiveSession returns the first active session for the requested actor.
|
||||
ActiveSession(actor) *session
|
||||
// ActiveSession returns the count of active sessions.
|
||||
ActiveSessions() int
|
||||
// Listen appends the session to the list of sessions that receive log events.
|
||||
Listen(*session)
|
||||
// Remove a session from the available sessions that were receiving log events.
|
||||
Remove(*session)
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
// Buffered channel that holds the recent log events
|
||||
listener chan *Log
|
||||
// Types of log events that this session will provide through the listener
|
||||
filters *StreamingFilters
|
||||
}
|
||||
|
||||
func newSession(size int, filters *StreamingFilters) *Session {
|
||||
s := &Session{
|
||||
listener: make(chan *Log, size),
|
||||
}
|
||||
if filters != nil {
|
||||
s.filters = filters
|
||||
} else {
|
||||
s.filters = &StreamingFilters{}
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Insert attempts to insert the log to the session. If the log event matches the provided session filters, it
|
||||
// will be applied to the listener.
|
||||
func (s *Session) Insert(log *Log) {
|
||||
// Level filters are optional
|
||||
if s.filters.Level != nil {
|
||||
if *s.filters.Level > log.Level {
|
||||
return
|
||||
func (l *Logger) ActiveSession(actor actor) *session {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
for _, session := range l.sessions {
|
||||
if session.actor.ID == actor.ID && session.active.Load() {
|
||||
return session
|
||||
}
|
||||
}
|
||||
// Event filters are optional
|
||||
if len(s.filters.Events) != 0 && !contains(s.filters.Events, log.Event) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case s.listener <- log:
|
||||
default:
|
||||
// buffer is full, discard
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func contains(array []LogEventType, t LogEventType) bool {
|
||||
for _, v := range array {
|
||||
if v == t {
|
||||
return true
|
||||
func (l *Logger) ActiveSessions() int {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
count := 0
|
||||
for _, session := range l.sessions {
|
||||
if session.active.Load() {
|
||||
count += 1
|
||||
}
|
||||
}
|
||||
return false
|
||||
return count
|
||||
}
|
||||
|
||||
// Listen creates a new Session that will append filtered log events as they are created.
|
||||
func (l *Logger) Listen(filters *StreamingFilters) *Session {
|
||||
func (l *Logger) Listen(session *session) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
listener := newSession(logWindow, filters)
|
||||
l.sessions = append(l.sessions, listener)
|
||||
return listener
|
||||
session.active.Store(true)
|
||||
l.sessions = append(l.sessions, session)
|
||||
}
|
||||
|
||||
// Close will remove a Session from the available sessions that were receiving log events.
|
||||
func (l *Logger) Close(session *Session) {
|
||||
func (l *Logger) Remove(session *session) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
index := -1
|
||||
|
@@ -1,8 +1,8 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,9 +21,14 @@ func TestLoggerWrite_NoSessions(t *testing.T) {
|
||||
func TestLoggerWrite_OneSession(t *testing.T) {
|
||||
logger := NewLogger()
|
||||
zlog := zerolog.New(logger).With().Timestamp().Logger().Level(zerolog.InfoLevel)
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
session := logger.Listen(nil)
|
||||
defer logger.Close(session)
|
||||
session := newSession(logWindow, actor{ID: actorID}, cancel)
|
||||
logger.Listen(session)
|
||||
defer logger.Remove(session)
|
||||
assert.Equal(t, 1, logger.ActiveSessions())
|
||||
assert.Equal(t, session, logger.ActiveSession(actor{ID: actorID}))
|
||||
zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello")
|
||||
select {
|
||||
case event := <-session.listener:
|
||||
@@ -40,12 +45,20 @@ func TestLoggerWrite_OneSession(t *testing.T) {
|
||||
func TestLoggerWrite_MultipleSessions(t *testing.T) {
|
||||
logger := NewLogger()
|
||||
zlog := zerolog.New(logger).With().Timestamp().Logger().Level(zerolog.InfoLevel)
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
session1 := newSession(logWindow, actor{}, cancel)
|
||||
logger.Listen(session1)
|
||||
defer logger.Remove(session1)
|
||||
assert.Equal(t, 1, logger.ActiveSessions())
|
||||
|
||||
session2 := newSession(logWindow, actor{}, cancel)
|
||||
logger.Listen(session2)
|
||||
assert.Equal(t, 2, logger.ActiveSessions())
|
||||
|
||||
session1 := logger.Listen(nil)
|
||||
defer logger.Close(session1)
|
||||
session2 := logger.Listen(nil)
|
||||
zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello")
|
||||
for _, session := range []*Session{session1, session2} {
|
||||
for _, session := range []*session{session1, session2} {
|
||||
select {
|
||||
case event := <-session.listener:
|
||||
assert.NotEmpty(t, event.Time)
|
||||
@@ -58,7 +71,7 @@ func TestLoggerWrite_MultipleSessions(t *testing.T) {
|
||||
}
|
||||
|
||||
// Close session2 and make sure session1 still receives events
|
||||
logger.Close(session2)
|
||||
logger.Remove(session2)
|
||||
zlog.Info().Int(EventTypeKey, int(HTTP)).Msg("hello2")
|
||||
select {
|
||||
case event := <-session1.listener:
|
||||
@@ -79,104 +92,6 @@ func TestLoggerWrite_MultipleSessions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that the session filters events
|
||||
func TestSession_Insert(t *testing.T) {
|
||||
infoLevel := new(LogLevel)
|
||||
*infoLevel = Info
|
||||
warnLevel := new(LogLevel)
|
||||
*warnLevel = Warn
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
filters StreamingFilters
|
||||
expectLog bool
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
expectLog: true,
|
||||
},
|
||||
{
|
||||
name: "level",
|
||||
filters: StreamingFilters{
|
||||
Level: infoLevel,
|
||||
},
|
||||
expectLog: true,
|
||||
},
|
||||
{
|
||||
name: "filtered out level",
|
||||
filters: StreamingFilters{
|
||||
Level: warnLevel,
|
||||
},
|
||||
expectLog: false,
|
||||
},
|
||||
{
|
||||
name: "events",
|
||||
filters: StreamingFilters{
|
||||
Events: []LogEventType{HTTP},
|
||||
},
|
||||
expectLog: true,
|
||||
},
|
||||
{
|
||||
name: "filtered out event",
|
||||
filters: StreamingFilters{
|
||||
Events: []LogEventType{Cloudflared},
|
||||
},
|
||||
expectLog: false,
|
||||
},
|
||||
{
|
||||
name: "filter and event",
|
||||
filters: StreamingFilters{
|
||||
Level: infoLevel,
|
||||
Events: []LogEventType{HTTP},
|
||||
},
|
||||
expectLog: true,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
session := newSession(4, &test.filters)
|
||||
log := Log{
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
Event: HTTP,
|
||||
Level: Info,
|
||||
Message: "test",
|
||||
}
|
||||
session.Insert(&log)
|
||||
select {
|
||||
case <-session.listener:
|
||||
require.True(t, test.expectLog)
|
||||
default:
|
||||
require.False(t, test.expectLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that the session has a max amount of events to hold
|
||||
func TestSession_InsertOverflow(t *testing.T) {
|
||||
session := newSession(1, nil)
|
||||
log := Log{
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
Event: HTTP,
|
||||
Level: Info,
|
||||
Message: "test",
|
||||
}
|
||||
// Insert 2 but only max channel size for 1
|
||||
session.Insert(&log)
|
||||
session.Insert(&log)
|
||||
select {
|
||||
case <-session.listener:
|
||||
// pass
|
||||
default:
|
||||
require.Fail(t, "expected one log event")
|
||||
}
|
||||
// Second dequeue should fail
|
||||
select {
|
||||
case <-session.listener:
|
||||
require.Fail(t, "expected no more remaining log events")
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
event *Log
|
||||
err error
|
||||
|
76
management/middleware.go
Normal file
76
management/middleware.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ctxKey int
|
||||
|
||||
const (
|
||||
accessClaimsCtxKey ctxKey = iota
|
||||
)
|
||||
|
||||
const (
|
||||
connectorIDQuery = "connector_id"
|
||||
accessTokenQuery = "access_token"
|
||||
)
|
||||
|
||||
var (
|
||||
errMissingAccessToken = managementError{Code: 1001, Message: "missing access_token query parameter"}
|
||||
)
|
||||
|
||||
// HTTP middleware setting the parsed access_token claims in the request context
|
||||
func ValidateAccessTokenQueryMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Validate access token
|
||||
accessToken := r.URL.Query().Get("access_token")
|
||||
if accessToken == "" {
|
||||
writeHTTPErrorResponse(w, errMissingAccessToken)
|
||||
return
|
||||
}
|
||||
token, err := parseToken(accessToken)
|
||||
if err != nil {
|
||||
writeHTTPErrorResponse(w, errMissingAccessToken)
|
||||
return
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), accessClaimsCtxKey, token))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Middleware validation error struct for returning to the eyeball
|
||||
type managementError struct {
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (m *managementError) Error() string {
|
||||
return m.Message
|
||||
}
|
||||
|
||||
// Middleware validation error HTTP response JSON for returning to the eyeball
|
||||
type managementErrorResponse struct {
|
||||
Success bool `json:"success,omitempty"`
|
||||
Errors []managementError `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
// writeErrorResponse will respond to the eyeball with basic HTTP JSON payloads with validation failure information
|
||||
func writeHTTPErrorResponse(w http.ResponseWriter, errResp managementError) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
err := json.NewEncoder(w).Encode(managementErrorResponse{
|
||||
Success: false,
|
||||
Errors: []managementError{errResp},
|
||||
})
|
||||
// we have already written the header, so write a basic error response if unable to encode the error
|
||||
if err != nil {
|
||||
// fallback to text message
|
||||
http.Error(w, fmt.Sprintf(
|
||||
"%d %s",
|
||||
http.StatusBadRequest,
|
||||
http.StatusText(http.StatusBadRequest),
|
||||
), http.StatusBadRequest)
|
||||
}
|
||||
}
|
71
management/middleware_test.go
Normal file
71
management/middleware_test.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateAccessTokenQueryMiddleware(t *testing.T) {
|
||||
r := chi.NewRouter()
|
||||
r.Use(ValidateAccessTokenQueryMiddleware)
|
||||
r.Get("/valid", func(w http.ResponseWriter, r *http.Request) {
|
||||
claims, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, claims.verify())
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
r.Get("/invalid", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, ok := r.Context().Value(accessClaimsCtxKey).(*managementTokenClaims)
|
||||
assert.False(t, ok)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(r)
|
||||
defer ts.Close()
|
||||
|
||||
// valid: with access_token query param
|
||||
path := "/valid?access_token=" + validToken
|
||||
resp, _ := testRequest(t, ts, "GET", path, nil)
|
||||
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// invalid: unset token
|
||||
path = "/invalid"
|
||||
resp, err := testRequest(t, ts, "GET", path, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, errMissingAccessToken, err.Errors[0])
|
||||
|
||||
// invalid: invalid token
|
||||
path = "/invalid?access_token=eyJ"
|
||||
resp, err = testRequest(t, ts, "GET", path, nil)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||
assert.NotNil(t, err)
|
||||
assert.Equal(t, errMissingAccessToken, err.Errors[0])
|
||||
}
|
||||
|
||||
func testRequest(t *testing.T, ts *httptest.Server, method, path string, body io.Reader) (*http.Response, *managementErrorResponse) {
|
||||
req, err := http.NewRequest(method, ts.URL+path, body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
resp, err := ts.Client().Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil, nil
|
||||
}
|
||||
var claims managementErrorResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&claims)
|
||||
if err != nil {
|
||||
return resp, nil
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp, &claims
|
||||
}
|
@@ -7,7 +7,6 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@@ -24,7 +23,7 @@ const (
|
||||
// value will return this error to incoming requests.
|
||||
StatusSessionLimitExceeded websocket.StatusCode = 4002
|
||||
reasonSessionLimitExceeded = "limit exceeded for streaming sessions"
|
||||
|
||||
// There is a limited idle time while not actively serving a session for a request before dropping the connection.
|
||||
StatusIdleLimitExceeded websocket.StatusCode = 4003
|
||||
reasonIdleLimitExceeded = "session was idle for too long"
|
||||
)
|
||||
@@ -41,9 +40,6 @@ type ManagementService struct {
|
||||
log *zerolog.Logger
|
||||
router chi.Router
|
||||
|
||||
// streaming signifies if the service is already streaming logs. Helps limit the number of active users streaming logs
|
||||
// from this cloudflared instance.
|
||||
streaming atomic.Bool
|
||||
// streamingMut is a lock to prevent concurrent requests to start streaming. Utilizing the atomic.Bool is not
|
||||
// sufficient to complete this operation since many other checks during an incoming new request are needed
|
||||
// to validate this before setting streaming to true.
|
||||
@@ -67,6 +63,7 @@ func New(managementHostname string,
|
||||
label: label,
|
||||
}
|
||||
r := chi.NewRouter()
|
||||
r.Use(ValidateAccessTokenQueryMiddleware)
|
||||
r.Get("/ping", ping)
|
||||
r.Head("/ping", ping)
|
||||
r.Get("/logs", s.logs)
|
||||
@@ -92,7 +89,6 @@ type getHostDetailsResponse struct {
|
||||
}
|
||||
|
||||
func (m *ManagementService) getHostDetails(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var getHostDetailsResponse = getHostDetailsResponse{
|
||||
ClientID: m.clientID.String(),
|
||||
}
|
||||
@@ -157,12 +153,11 @@ func (m *ManagementService) readEvents(c *websocket.Conn, ctx context.Context, e
|
||||
}
|
||||
|
||||
// streamLogs will begin the process of reading from the Session listener and write the log events to the client.
|
||||
func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *Session) {
|
||||
defer m.logger.Close(session)
|
||||
for m.streaming.Load() {
|
||||
func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, session *session) {
|
||||
for session.Active() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
m.streaming.Store(false)
|
||||
session.Stop()
|
||||
return
|
||||
case event := <-session.listener:
|
||||
err := WriteEvent(c, ctx, &EventLog{
|
||||
@@ -176,7 +171,7 @@ func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, s
|
||||
m.log.Err(c.Close(websocket.StatusInternalError, err.Error())).Send()
|
||||
}
|
||||
// Any errors when writing the messages to the client will stop streaming and close the connection
|
||||
m.streaming.Store(false)
|
||||
session.Stop()
|
||||
return
|
||||
}
|
||||
default:
|
||||
@@ -185,28 +180,38 @@ func (m *ManagementService) streamLogs(c *websocket.Conn, ctx context.Context, s
|
||||
}
|
||||
}
|
||||
|
||||
// startStreaming will check the conditions of the request and begin streaming or close the connection for invalid
|
||||
// requests.
|
||||
func (m *ManagementService) startStreaming(c *websocket.Conn, ctx context.Context, event *ClientEvent) {
|
||||
// canStartStream will check the conditions of the request and return if the session can begin streaming.
|
||||
func (m *ManagementService) canStartStream(session *session) bool {
|
||||
m.streamingMut.Lock()
|
||||
defer m.streamingMut.Unlock()
|
||||
// Limits to one user for streaming logs
|
||||
if m.streaming.Load() {
|
||||
m.log.Warn().
|
||||
Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
|
||||
m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
|
||||
return
|
||||
// Limits to one actor for streaming logs
|
||||
if m.logger.ActiveSessions() > 0 {
|
||||
// Allow the same user to preempt their existing session to disconnect their old session and start streaming
|
||||
// with this new session instead.
|
||||
if existingSession := m.logger.ActiveSession(session.actor); existingSession != nil {
|
||||
m.log.Info().
|
||||
Msgf("Another management session request for the same actor was requested; the other session will be disconnected to handle the new request.")
|
||||
existingSession.Stop()
|
||||
m.logger.Remove(existingSession)
|
||||
existingSession.cancel()
|
||||
} else {
|
||||
m.log.Warn().
|
||||
Msgf("Another management session request was attempted but one session already being served; there is a limit of streaming log sessions to reduce overall performance impact.")
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// parseFilters will check the ClientEvent for start_streaming and assign filters if provided to the session
|
||||
func (m *ManagementService) parseFilters(c *websocket.Conn, event *ClientEvent, session *session) bool {
|
||||
// Expect the first incoming request
|
||||
startEvent, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
|
||||
if !ok {
|
||||
m.log.Warn().Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Msgf("expected start_streaming as first recieved event")
|
||||
return
|
||||
return false
|
||||
}
|
||||
m.streaming.Store(true)
|
||||
listener := m.logger.Listen(startEvent.Filters)
|
||||
m.log.Debug().Msgf("Streaming logs")
|
||||
go m.streamLogs(c, ctx, listener)
|
||||
session.Filters(startEvent.Filters)
|
||||
return true
|
||||
}
|
||||
|
||||
// Management Streaming Logs accept handler
|
||||
@@ -227,11 +232,23 @@ func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
|
||||
ping := time.NewTicker(15 * time.Second)
|
||||
defer ping.Stop()
|
||||
|
||||
// Close the connection if no operation has occurred after the idle timeout.
|
||||
// Close the connection if no operation has occurred after the idle timeout. The timeout is halted
|
||||
// when streaming logs is active.
|
||||
idleTimeout := 5 * time.Minute
|
||||
idle := time.NewTimer(idleTimeout)
|
||||
defer idle.Stop()
|
||||
|
||||
// Fetch the claims from the request context to acquire the actor
|
||||
claims, ok := ctx.Value(accessClaimsCtxKey).(*managementTokenClaims)
|
||||
if !ok || claims == nil {
|
||||
// Typically should never happen as it is provided in the context from the middleware
|
||||
m.log.Err(c.Close(websocket.StatusInternalError, "missing access_token")).Send()
|
||||
return
|
||||
}
|
||||
|
||||
session := newSession(logWindow, claims.Actor, cancel)
|
||||
defer m.logger.Remove(session)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -242,12 +259,28 @@ func (m *ManagementService) logs(w http.ResponseWriter, r *http.Request) {
|
||||
switch event.Type {
|
||||
case StartStreaming:
|
||||
idle.Stop()
|
||||
m.startStreaming(c, ctx, event)
|
||||
// Expect the first incoming request
|
||||
startEvent, ok := IntoClientEvent[EventStartStreaming](event, StartStreaming)
|
||||
if !ok {
|
||||
m.log.Warn().Msgf("expected start_streaming as first recieved event")
|
||||
m.log.Err(c.Close(StatusInvalidCommand, reasonInvalidCommand)).Send()
|
||||
return
|
||||
}
|
||||
// Make sure the session can start
|
||||
if !m.canStartStream(session) {
|
||||
m.log.Err(c.Close(StatusSessionLimitExceeded, reasonSessionLimitExceeded)).Send()
|
||||
return
|
||||
}
|
||||
session.Filters(startEvent.Filters)
|
||||
m.logger.Listen(session)
|
||||
m.log.Debug().Msgf("Streaming logs")
|
||||
go m.streamLogs(c, ctx, session)
|
||||
continue
|
||||
case StopStreaming:
|
||||
idle.Reset(idleTimeout)
|
||||
// TODO: limit StopStreaming to only halt streaming for clients that are already streaming
|
||||
m.streaming.Store(false)
|
||||
// Stop the current session for the current actor who requested it
|
||||
session.Stop()
|
||||
m.logger.Remove(session)
|
||||
case UnknownClientEventType:
|
||||
fallthrough
|
||||
default:
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"nhooyr.io/websocket"
|
||||
|
||||
@@ -59,3 +60,67 @@ func TestReadEventsLoop_ContextCancelled(t *testing.T) {
|
||||
m.readEvents(server, ctx, events)
|
||||
server.Close(websocket.StatusInternalError, "")
|
||||
}
|
||||
|
||||
func TestCanStartStream_NoSessions(t *testing.T) {
|
||||
m := ManagementService{
|
||||
log: &noopLogger,
|
||||
logger: &Logger{
|
||||
Log: &noopLogger,
|
||||
},
|
||||
}
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
session := newSession(0, actor{}, cancel)
|
||||
assert.True(t, m.canStartStream(session))
|
||||
}
|
||||
|
||||
func TestCanStartStream_ExistingSessionDifferentActor(t *testing.T) {
|
||||
m := ManagementService{
|
||||
log: &noopLogger,
|
||||
logger: &Logger{
|
||||
Log: &noopLogger,
|
||||
},
|
||||
}
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
session1 := newSession(0, actor{ID: "test"}, cancel)
|
||||
assert.True(t, m.canStartStream(session1))
|
||||
m.logger.Listen(session1)
|
||||
assert.True(t, session1.Active())
|
||||
|
||||
// Try another session
|
||||
session2 := newSession(0, actor{ID: "test2"}, cancel)
|
||||
assert.Equal(t, 1, m.logger.ActiveSessions())
|
||||
assert.False(t, m.canStartStream(session2))
|
||||
|
||||
// Close session1
|
||||
m.logger.Remove(session1)
|
||||
assert.True(t, session1.Active()) // Remove doesn't stop a session
|
||||
session1.Stop()
|
||||
assert.False(t, session1.Active())
|
||||
assert.Equal(t, 0, m.logger.ActiveSessions())
|
||||
|
||||
// Try session2 again
|
||||
assert.True(t, m.canStartStream(session2))
|
||||
}
|
||||
|
||||
func TestCanStartStream_ExistingSessionSameActor(t *testing.T) {
|
||||
m := ManagementService{
|
||||
log: &noopLogger,
|
||||
logger: &Logger{
|
||||
Log: &noopLogger,
|
||||
},
|
||||
}
|
||||
actor := actor{ID: "test"}
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
session1 := newSession(0, actor, cancel)
|
||||
assert.True(t, m.canStartStream(session1))
|
||||
m.logger.Listen(session1)
|
||||
assert.True(t, session1.Active())
|
||||
|
||||
// Try another session
|
||||
session2 := newSession(0, actor, cancel)
|
||||
assert.Equal(t, 1, m.logger.ActiveSessions())
|
||||
assert.True(t, m.canStartStream(session2))
|
||||
// session1 is removed and stopped
|
||||
assert.Equal(t, 0, m.logger.ActiveSessions())
|
||||
assert.False(t, session1.Active())
|
||||
}
|
||||
|
88
management/session.go
Normal file
88
management/session.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
// Indicates how many log messages the listener will hold before dropping.
|
||||
// Provides a throttling mechanism to drop latest messages if the sender
|
||||
// can't keep up with the influx of log messages.
|
||||
logWindow = 30
|
||||
)
|
||||
|
||||
// session captures a streaming logs session for a connection of an actor.
|
||||
type session struct {
|
||||
// Indicates if the session is streaming or not. Modifying this will affect the active session.
|
||||
active atomic.Bool
|
||||
// Allows the session to control the context of the underlying connection to close it out when done. Mostly
|
||||
// used by the LoggerListener to close out and cleanup a session.
|
||||
cancel context.CancelFunc
|
||||
// Actor who started the session
|
||||
actor actor
|
||||
// Buffered channel that holds the recent log events
|
||||
listener chan *Log
|
||||
// Types of log events that this session will provide through the listener
|
||||
filters *StreamingFilters
|
||||
}
|
||||
|
||||
// NewSession creates a new session.
|
||||
func newSession(size int, actor actor, cancel context.CancelFunc) *session {
|
||||
s := &session{
|
||||
active: atomic.Bool{},
|
||||
cancel: cancel,
|
||||
actor: actor,
|
||||
listener: make(chan *Log, size),
|
||||
filters: &StreamingFilters{},
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// Filters assigns the StreamingFilters to the session
|
||||
func (s *session) Filters(filters *StreamingFilters) {
|
||||
if filters != nil {
|
||||
s.filters = filters
|
||||
} else {
|
||||
s.filters = &StreamingFilters{}
|
||||
}
|
||||
}
|
||||
|
||||
// Insert attempts to insert the log to the session. If the log event matches the provided session filters, it
|
||||
// will be applied to the listener.
|
||||
func (s *session) Insert(log *Log) {
|
||||
// Level filters are optional
|
||||
if s.filters.Level != nil {
|
||||
if *s.filters.Level > log.Level {
|
||||
return
|
||||
}
|
||||
}
|
||||
// Event filters are optional
|
||||
if len(s.filters.Events) != 0 && !contains(s.filters.Events, log.Event) {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case s.listener <- log:
|
||||
default:
|
||||
// buffer is full, discard
|
||||
}
|
||||
}
|
||||
|
||||
// Active returns if the session is active
|
||||
func (s *session) Active() bool {
|
||||
return s.active.Load()
|
||||
}
|
||||
|
||||
// Stop will halt the session
|
||||
func (s *session) Stop() {
|
||||
s.active.Store(false)
|
||||
}
|
||||
|
||||
func contains(array []LogEventType, t LogEventType) bool {
|
||||
for _, v := range array {
|
||||
if v == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
126
management/session_test.go
Normal file
126
management/session_test.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Validate the active states of the session
|
||||
func TestSession_ActiveControl(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
session := newSession(4, actor{}, cancel)
|
||||
// session starts out not active
|
||||
assert.False(t, session.Active())
|
||||
session.active.Store(true)
|
||||
assert.True(t, session.Active())
|
||||
session.Stop()
|
||||
assert.False(t, session.Active())
|
||||
}
|
||||
|
||||
// Validate that the session filters events
|
||||
func TestSession_Insert(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
infoLevel := new(LogLevel)
|
||||
*infoLevel = Info
|
||||
warnLevel := new(LogLevel)
|
||||
*warnLevel = Warn
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
filters StreamingFilters
|
||||
expectLog bool
|
||||
}{
|
||||
{
|
||||
name: "none",
|
||||
expectLog: true,
|
||||
},
|
||||
{
|
||||
name: "level",
|
||||
filters: StreamingFilters{
|
||||
Level: infoLevel,
|
||||
},
|
||||
expectLog: true,
|
||||
},
|
||||
{
|
||||
name: "filtered out level",
|
||||
filters: StreamingFilters{
|
||||
Level: warnLevel,
|
||||
},
|
||||
expectLog: false,
|
||||
},
|
||||
{
|
||||
name: "events",
|
||||
filters: StreamingFilters{
|
||||
Events: []LogEventType{HTTP},
|
||||
},
|
||||
expectLog: true,
|
||||
},
|
||||
{
|
||||
name: "filtered out event",
|
||||
filters: StreamingFilters{
|
||||
Events: []LogEventType{Cloudflared},
|
||||
},
|
||||
expectLog: false,
|
||||
},
|
||||
{
|
||||
name: "filter and event",
|
||||
filters: StreamingFilters{
|
||||
Level: infoLevel,
|
||||
Events: []LogEventType{HTTP},
|
||||
},
|
||||
expectLog: true,
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
session := newSession(4, actor{}, cancel)
|
||||
session.Filters(&test.filters)
|
||||
log := Log{
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
Event: HTTP,
|
||||
Level: Info,
|
||||
Message: "test",
|
||||
}
|
||||
session.Insert(&log)
|
||||
select {
|
||||
case <-session.listener:
|
||||
require.True(t, test.expectLog)
|
||||
default:
|
||||
require.False(t, test.expectLog)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Validate that the session has a max amount of events to hold
|
||||
func TestSession_InsertOverflow(t *testing.T) {
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
session := newSession(1, actor{}, cancel)
|
||||
log := Log{
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
Event: HTTP,
|
||||
Level: Info,
|
||||
Message: "test",
|
||||
}
|
||||
// Insert 2 but only max channel size for 1
|
||||
session.Insert(&log)
|
||||
session.Insert(&log)
|
||||
select {
|
||||
case <-session.listener:
|
||||
// pass
|
||||
default:
|
||||
require.Fail(t, "expected one log event")
|
||||
}
|
||||
// Second dequeue should fail
|
||||
select {
|
||||
case <-session.listener:
|
||||
require.Fail(t, "expected no more remaining log events")
|
||||
default:
|
||||
// pass
|
||||
}
|
||||
}
|
55
management/token.go
Normal file
55
management/token.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/go-jose/go-jose/v3/jwt"
|
||||
)
|
||||
|
||||
type managementTokenClaims struct {
|
||||
Tunnel tunnel `json:"tun"`
|
||||
Actor actor `json:"actor"`
|
||||
}
|
||||
|
||||
// VerifyTunnel compares the tun claim isn't empty
|
||||
func (c *managementTokenClaims) verify() bool {
|
||||
return c.Tunnel.verify() && c.Actor.verify()
|
||||
}
|
||||
|
||||
// tunnel is the "tun" claim of a management token: the tunnel being
// managed and the account that owns it.
type tunnel struct {
	// ID is the tunnel identifier (a UUID in the tests below).
	ID string `json:"id"`
	// AccountTag identifies the account that owns the tunnel.
	AccountTag string `json:"account_tag"`
}
|
||||
|
||||
// verify compares the tun claim isn't empty
|
||||
func (t *tunnel) verify() bool {
|
||||
return t.AccountTag != "" && t.ID != ""
|
||||
}
|
||||
|
||||
// actor is the "actor" claim of a management token: who requested the
// management session.
type actor struct {
	// ID identifies the requesting actor (an email address in the tests).
	ID string `json:"id"`
	// Support indicates whether the actor is support staff.
	Support bool `json:"support"`
}
|
||||
|
||||
// verify checks the ID claim isn't empty
|
||||
func (t *actor) verify() bool {
|
||||
return t.ID != ""
|
||||
}
|
||||
|
||||
func parseToken(token string) (*managementTokenClaims, error) {
|
||||
jwt, err := jwt.ParseSigned(token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("malformed jwt: %v", err)
|
||||
}
|
||||
|
||||
var claims managementTokenClaims
|
||||
// This is actually safe because we verify the token in the edge before it reaches cloudflared
|
||||
err = jwt.UnsafeClaimsWithoutVerification(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("malformed jwt: %v", err)
|
||||
}
|
||||
if !claims.verify() {
|
||||
return nil, fmt.Errorf("invalid management token format provided")
|
||||
}
|
||||
return &claims, nil
|
||||
}
|
130
management/token_test.go
Normal file
130
management/token_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/square/go-jose.v2"
|
||||
)
|
||||
|
||||
const (
	// validToken is a pre-signed management JWT fixture whose claims match
	// accountTag and tunnelID below (already expired; only parsed, never
	// signature-verified by parseToken).
	validToken = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiIsImtpZCI6IjEifQ.eyJ0dW4iOnsiaWQiOiI3YjA5ODE0OS01MWZlLTRlZTUtYTY4Ny0zZTM3NDQ2NmVmYzciLCJhY2NvdW50X3RhZyI6ImNkMzkxZTljMDYyNmE4Zjc2Y2IxZjY3MGY2NTkxYjA1In0sImFjdG9yIjp7ImlkIjoiZGNhcnJAY2xvdWRmbGFyZS5jb20iLCJzdXBwb3J0IjpmYWxzZX0sInJlcyI6WyJsb2dzIl0sImV4cCI6MTY3NzExNzY5NiwiaWF0IjoxNjc3MTE0MDk2LCJpc3MiOiJ0dW5uZWxzdG9yZSJ9.mKenOdOy3Xi4O-grldFnAAemdlE9WajEpTDC_FwezXQTstWiRTLwU65P5jt4vNsIiZA4OJRq7bH-QYID9wf9NA"

	// Claim values used to build test tokens.
	accountTag = "cd391e9c0626a8f76cb1f670f6591b05"
	tunnelID   = "7b098149-51fe-4ee5-a687-3e374466efc7"
	actorID    = "45d2751e-6b59-45a9-814d-f630786bd0cd"
)
|
||||
|
||||
// invalidManagementTokenClaims is a claim payload that shares no fields
// with managementTokenClaims, used to exercise the rejection path of
// parseToken.
type invalidManagementTokenClaims struct {
	Invalid string `json:"invalid"`
}
|
||||
|
||||
func TestParseToken(t *testing.T) {
|
||||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range []struct {
|
||||
name string
|
||||
claims any
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "Valid",
|
||||
claims: managementTokenClaims{
|
||||
Tunnel: tunnel{
|
||||
ID: tunnelID,
|
||||
AccountTag: accountTag,
|
||||
},
|
||||
Actor: actor{
|
||||
ID: actorID,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Invalid claims",
|
||||
claims: invalidManagementTokenClaims{Invalid: "invalid"},
|
||||
err: errors.New("invalid management token format provided"),
|
||||
},
|
||||
{
|
||||
name: "Missing Tunnel",
|
||||
claims: managementTokenClaims{
|
||||
Actor: actor{
|
||||
ID: actorID,
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid management token format provided"),
|
||||
},
|
||||
{
|
||||
name: "Missing Tunnel ID",
|
||||
claims: managementTokenClaims{
|
||||
Tunnel: tunnel{
|
||||
AccountTag: accountTag,
|
||||
},
|
||||
Actor: actor{
|
||||
ID: actorID,
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid management token format provided"),
|
||||
},
|
||||
{
|
||||
name: "Missing Account Tag",
|
||||
claims: managementTokenClaims{
|
||||
Tunnel: tunnel{
|
||||
ID: tunnelID,
|
||||
},
|
||||
Actor: actor{
|
||||
ID: actorID,
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid management token format provided"),
|
||||
},
|
||||
{
|
||||
name: "Missing Actor",
|
||||
claims: managementTokenClaims{
|
||||
Tunnel: tunnel{
|
||||
ID: tunnelID,
|
||||
AccountTag: accountTag,
|
||||
},
|
||||
},
|
||||
err: errors.New("invalid management token format provided"),
|
||||
},
|
||||
{
|
||||
name: "Missing Actor ID",
|
||||
claims: managementTokenClaims{
|
||||
Tunnel: tunnel{
|
||||
ID: tunnelID,
|
||||
},
|
||||
Actor: actor{},
|
||||
},
|
||||
err: errors.New("invalid management token format provided"),
|
||||
},
|
||||
} {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
jwt := signToken(t, test.claims, key)
|
||||
claims, err := parseToken(jwt)
|
||||
if test.err != nil {
|
||||
require.EqualError(t, err, test.err.Error())
|
||||
return
|
||||
}
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, test.claims, *claims)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func signToken(t *testing.T, token any, key *ecdsa.PrivateKey) string {
|
||||
opts := (&jose.SignerOptions{}).WithType("JWT").WithHeader("kid", "1")
|
||||
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.ES256, Key: key}, opts)
|
||||
require.NoError(t, err)
|
||||
payload, err := json.Marshal(token)
|
||||
require.NoError(t, err)
|
||||
jws, err := signer.Sign(payload)
|
||||
require.NoError(t, err)
|
||||
jwt, err := jws.CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
return jwt
|
||||
}
|
Reference in New Issue
Block a user