Fix db client deadlocks with non-blocking cleanup and RW locks (#4918)

Puskar Basu
2025-12-16 21:19:27 +05:30
committed by GitHub
parent c2421b0849
commit 3f4eaae1a8
10 changed files with 211 additions and 36 deletions
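
A minimal, self-contained Go sketch (illustrative only, not part of this commit's diff) of the deadlock class being fixed: a pool's close-time hook that blocks on a mutex the closing goroutine already holds will hang forever, while a TryLock-based hook skips the cleanup and lets Close() finish. The Pool/BeforeClose names below are stand-ins for the pgx pool hooks used in the real code.

// Sketch of the deadlock fixed in this commit (illustrative, not from the diff).
package main

import (
	"fmt"
	"sync"
)

// Pool stands in for a pgx-style connection pool with a close-time hook.
type Pool struct {
	BeforeClose func()
}

func (p *Pool) Close() {
	if p.BeforeClose != nil {
		p.BeforeClose()
	}
}

func main() {
	var mu sync.Mutex
	sessions := map[uint32]string{42: "session"}

	pool := &Pool{}
	// Non-blocking variant: best-effort cleanup, as in the fixed BeforeClose hook.
	pool.BeforeClose = func() {
		if mu.TryLock() { // Go 1.18+: returns false instead of blocking
			defer mu.Unlock()
			delete(sessions, 42)
			return
		}
		// lock busy: skip cleanup rather than deadlock pool.Close()
	}

	mu.Lock() // simulate Close() holding the sessions lock while closing the pool
	pool.Close()
	mu.Unlock()
	fmt.Println("pool closed without deadlock; sessions:", sessions)
}

Had BeforeClose called mu.Lock() instead, pool.Close() would block forever here, which is exactly the service stop/restart hang this commit addresses.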


@@ -60,11 +60,15 @@ func doRunPluginManager(cmd *cobra.Command) error {
 		log.Printf("[INFO] starting connection watcher")
 		connectionWatcher, err := connection.NewConnectionWatcher(pluginManager)
 		if err != nil {
+			log.Printf("[ERROR] failed to create connection watcher: %v", err)
 			return err
 		}
+		log.Printf("[INFO] connection watcher created successfully")
 		// close the connection watcher
 		defer connectionWatcher.Close()
+	} else {
+		log.Printf("[WARN] connection watcher is DISABLED")
 	}
 	log.Printf("[INFO] about to serve")


@@ -26,12 +26,19 @@ func NewConnectionWatcher(pluginManager pluginManager) (*ConnectionWatcher, error) {
 		pluginManager: pluginManager,
 	}
+	configDir := filepaths.EnsureConfigDir()
+	log.Printf("[INFO] ConnectionWatcher will watch directory: %s for %s files", configDir, constants.ConfigExtension)
 	watcherOptions := &filewatcher.WatcherOptions{
-		Directories: []string{filepaths.EnsureConfigDir()},
+		Directories: []string{configDir},
 		Include:     filehelpers.InclusionsFromExtensions([]string{constants.ConfigExtension}),
 		ListFlag:    filehelpers.FilesRecursive,
 		EventMask:   fsnotify.Create | fsnotify.Remove | fsnotify.Rename | fsnotify.Write | fsnotify.Chmod,
 		OnChange: func(events []fsnotify.Event) {
+			log.Printf("[INFO] ConnectionWatcher detected %d file events", len(events))
+			for _, event := range events {
+				log.Printf("[INFO] ConnectionWatcher event: %s - %s", event.Op, event.Name)
+			}
 			w.handleFileWatcherEvent(events)
 		},
 	}

@@ -80,13 +87,17 @@ func (w *ConnectionWatcher) handleFileWatcherEvent([]fsnotify.Event) {
 	// as these are both used by RefreshConnectionAndSearchPathsWithLocalClient
 	// set the global steampipe config
+	log.Printf("[DEBUG] ConnectionWatcher: setting GlobalConfig")
 	steampipeconfig.GlobalConfig = config
 	// call on changed callback - we must call this BEFORE calling refresh connections
 	// convert config to format expected by plugin manager
 	// (plugin manager cannot reference steampipe config to avoid circular deps)
+	log.Printf("[DEBUG] ConnectionWatcher: creating connection config map")
 	configMap := NewConnectionConfigMap(config.Connections)
+	log.Printf("[DEBUG] ConnectionWatcher: calling OnConnectionConfigChanged with %d connections", len(configMap))
 	w.pluginManager.OnConnectionConfigChanged(ctx, configMap, config.PluginsInstances)
+	log.Printf("[DEBUG] ConnectionWatcher: OnConnectionConfigChanged complete")
 	// The only configurations from GlobalConfig which have
 	// impact during Refresh are Database options and the Connections

@@ -99,7 +110,9 @@ func (w *ConnectionWatcher) handleFileWatcherEvent([]fsnotify.Event) {
 	// Workspace Profile does not have any setting which can alter
 	// behavior in service mode (namely search path). Therefore, it is safe
 	// to use the GlobalConfig here and ignore Workspace Profile in general
+	log.Printf("[DEBUG] ConnectionWatcher: calling SetDefaultsFromConfig")
 	cmdconfig.SetDefaultsFromConfig(steampipeconfig.GlobalConfig.ConfigMap())
+	log.Printf("[DEBUG] ConnectionWatcher: SetDefaultsFromConfig complete")
 	log.Printf("[INFO] calling RefreshConnections asyncronously")


@@ -50,12 +50,13 @@ type DbClient struct {
 	// allows locked access to the 'sessions' map
 	sessionsMutex *sync.Mutex
+	sessionsLockFlag atomic.Bool
 	// if a custom search path or a prefix is used, store it here
 	customSearchPath []string
 	searchPathPrefix []string
 	// allows locked access to customSearchPath and searchPathPrefix
-	searchPathMutex *sync.Mutex
+	searchPathMutex *sync.RWMutex
 	// the default user search path
 	userSearchPath []string
 	// disable timing - set whilst in process of querying the timing

@@ -73,7 +74,7 @@ func NewDbClient(ctx context.Context, connectionString string, opts ...ClientOpt
 		parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
 		sessions:                make(map[uint32]*db_common.DatabaseSession),
 		sessionsMutex:           &sync.Mutex{},
-		searchPathMutex:         &sync.Mutex{},
+		searchPathMutex:         &sync.RWMutex{},
 		connectionString:        connectionString,
 	}

@@ -152,6 +153,37 @@ func (c *DbClient) shouldFetchVerboseTiming() bool {
 		(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
 }

+// lockSessions acquires the sessionsMutex and tracks ownership for tryLock compatibility.
+func (c *DbClient) lockSessions() {
+	if c.sessionsMutex == nil {
+		return
+	}
+	c.sessionsLockFlag.Store(true)
+	c.sessionsMutex.Lock()
+}
+
+// sessionsTryLock attempts to acquire the sessionsMutex without blocking.
+// Returns false if the lock is already held.
+func (c *DbClient) sessionsTryLock() bool {
+	if c.sessionsMutex == nil {
+		return false
+	}
+	// best-effort: only one contender sets the flag and proceeds to lock
+	if !c.sessionsLockFlag.CompareAndSwap(false, true) {
+		return false
+	}
+	c.sessionsMutex.Lock()
+	return true
+}
+
+func (c *DbClient) sessionsUnlock() {
+	if c.sessionsMutex == nil {
+		return
+	}
+	c.sessionsMutex.Unlock()
+	c.sessionsLockFlag.Store(false)
+}
+
 // ServerSettings returns the settings of the steampipe service that this DbClient is connected to
 //
 // Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected.

@@ -173,9 +205,9 @@ func (c *DbClient) Close(context.Context) error {
 	// nullify active sessions, since with the closing of the pools
 	// none of the sessions will be valid anymore
 	// Acquire mutex to prevent concurrent access to sessions map
-	c.sessionsMutex.Lock()
+	c.lockSessions()
 	c.sessions = nil
-	c.sessionsMutex.Unlock()
+	c.sessionsUnlock()
 	return nil
 }
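
For reference, a runnable, standalone sketch of the semantics of the helpers added above: lockSessions always blocks, while sessionsTryLock gives up (returns false) whenever the atomic flag shows the lock is taken or contended. The client type below is an illustrative stand-in, not the commit's DbClient.

// Standalone sketch of the flag-guarded lock helpers (illustrative names).
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

type client struct {
	mu   sync.Mutex
	flag atomic.Bool
}

func (c *client) lockSessions() {
	c.flag.Store(true) // advertise intent so tryLock contenders back off
	c.mu.Lock()
}

func (c *client) sessionsTryLock() bool {
	// best-effort: if anyone has flagged the lock, give up without blocking
	if !c.flag.CompareAndSwap(false, true) {
		return false
	}
	c.mu.Lock()
	return true
}

func (c *client) sessionsUnlock() {
	c.mu.Unlock()
	c.flag.Store(false)
}

func main() {
	c := &client{}
	c.lockSessions()
	fmt.Println("tryLock while held:", c.sessionsTryLock()) // false: skips instead of blocking
	c.sessionsUnlock()
	if c.sessionsTryLock() {
		fmt.Println("tryLock when free: true")
		c.sessionsUnlock()
	}
}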


@@ -66,12 +66,14 @@ func (c *DbClient) establishConnectionPool(ctx context.Context, overrides client
 	config.BeforeClose = func(conn *pgx.Conn) {
 		if conn != nil && conn.PgConn() != nil {
 			backendPid := conn.PgConn().PID()
-			c.sessionsMutex.Lock()
+			// Best-effort cleanup: do not block pool.Close() if sessions lock is busy.
+			if c.sessionsTryLock() {
 			// Check if sessions map has been nil'd by Close()
 			if c.sessions != nil {
 				delete(c.sessions, backendPid)
 			}
-			c.sessionsMutex.Unlock()
+				c.sessionsUnlock()
+			}
 		}
 	}
 	// set an app name so that we can track database connections from this Steampipe execution


@@ -78,8 +78,8 @@ func (c *DbClient) loadUserSearchPath(ctx context.Context, connection *pgx.Conn)
 // GetRequiredSessionSearchPath implements Client
 func (c *DbClient) GetRequiredSessionSearchPath() []string {
-	c.searchPathMutex.Lock()
-	defer c.searchPathMutex.Unlock()
+	c.searchPathMutex.RLock()
+	defer c.searchPathMutex.RUnlock()
 	if c.customSearchPath != nil {
 		return c.customSearchPath

@@ -89,8 +89,8 @@ func (c *DbClient) GetRequiredSessionSearchPath() []string {
 }

 func (c *DbClient) GetCustomSearchPath() []string {
-	c.searchPathMutex.Lock()
-	defer c.searchPathMutex.Unlock()
+	c.searchPathMutex.RLock()
+	defer c.searchPathMutex.RUnlock()
 	return c.customSearchPath
 }


@@ -37,10 +37,10 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common
 	}
 	backendPid := databaseConnection.Conn().PgConn().PID()
-	c.sessionsMutex.Lock()
+	c.lockSessions()
 	// Check if client has been closed (sessions set to nil)
 	if c.sessions == nil {
-		c.sessionsMutex.Unlock()
+		c.sessionsUnlock()
 		sessionResult.Error = fmt.Errorf("client has been closed")
 		return sessionResult
 	}

@@ -52,7 +52,7 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common
 	// we get a new *sql.Conn everytime. USE IT!
 	session.Connection = databaseConnection
 	sessionResult.Session = session
-	c.sessionsMutex.Unlock()
+	c.sessionsUnlock()
 	// make sure that we close the acquired session, in case of error
 	defer func() {


@@ -2,10 +2,13 @@ package db_client

 import (
 	"context"
+	"os"
+	"strings"
 	"sync"
 	"testing"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/turbot/steampipe/v2/pkg/db/db_common"
 )

@@ -160,6 +163,28 @@ func TestDbClient_SearchPathUpdates(t *testing.T) {
 	assert.Len(t, client.customSearchPath, 2, "Should have 2 schemas in search path")
 }

+// TestSearchPathAccessShouldUseReadLocks checks that search path access does not block other goroutines unnecessarily.
+//
+// Holding an exclusive mutex during search-path reads in concurrent query setup can deadlock when
+// another goroutine is setting the path. The current code uses Lock/Unlock; this test documents
+// the expectation to move to a read/non-blocking lock so concurrent reads are safe.
+func TestSearchPathAccessShouldUseReadLocks(t *testing.T) {
+	content, err := os.ReadFile("db_client_search_path.go")
+	require.NoError(t, err, "should be able to read db_client_search_path.go")
+	source := string(content)
+
+	assert.Contains(t, source, "GetRequiredSessionSearchPath", "getter must exist")
+	assert.Contains(t, source, "searchPathMutex", "getter must guard access to searchPath state")
+
+	// Expect a read or non-blocking lock in getters; fail if only full Lock/Unlock is present.
+	hasRLock := strings.Contains(source, "RLock")
+	hasTry := strings.Contains(source, "TryLock") || strings.Contains(source, "tryLock")
+	if !hasRLock && !hasTry {
+		t.Fatalf("GetRequiredSessionSearchPath should avoid exclusive Lock/Unlock to prevent deadlocks under concurrent query setup")
+	}
+}
+
 // TestDbClient_SessionConnectionNilSafety verifies handling of nil connections
 func TestDbClient_SessionConnectionNilSafety(t *testing.T) {
 	session := db_common.NewDBSession(12345)

@@ -181,7 +206,7 @@ func TestDbClient_SessionSearchPathUpdatesThreadSafe(t *testing.T) {
 	client := &DbClient{
 		customSearchPath: []string{"public", "internal"},
 		userSearchPath:   []string{"public"},
-		searchPathMutex:  &sync.Mutex{},
+		searchPathMutex:  &sync.RWMutex{},
 	}
 	// Number of concurrent operations to test


@@ -52,6 +52,36 @@ func TestSessionMapCleanupImplemented(t *testing.T) {
 		"Comment should document automatic cleanup mechanism")
 }

+// TestBeforeCloseCleanupShouldBeNonBlocking ensures the cleanup hook does not take a blocking lock.
+//
+// A blocking mutex in the BeforeClose hook can deadlock pool.Close() when another goroutine
+// holds sessionsMutex (service stop/restart hangs). This test is intentionally strict and
+// will fail until the hook uses a non-blocking strategy (e.g., TryLock or similar).
+func TestBeforeCloseCleanupShouldBeNonBlocking(t *testing.T) {
+	content, err := os.ReadFile("db_client_connect.go")
+	require.NoError(t, err, "should be able to read db_client_connect.go")
+	source := string(content)
+
+	// Guardrail: the BeforeClose hook should avoid unconditionally blocking on sessionsMutex.
+	assert.Contains(t, source, "config.BeforeClose", "BeforeClose cleanup hook must exist")
+	assert.Contains(t, source, "sessionsTryLock", "BeforeClose cleanup should use non-blocking lock helper")
+
+	// Expect a non-blocking lock pattern; if we only find Lock()/Unlock, this fails.
+	nonBlockingPatterns := []string{"TryLock", "tryLock", "non-block", "select {"}
+	foundNonBlocking := false
+	for _, p := range nonBlockingPatterns {
+		if strings.Contains(source, p) {
+			foundNonBlocking = true
+			break
+		}
+	}
+	if !foundNonBlocking {
+		t.Fatalf("BeforeClose cleanup appears to take a blocking lock on sessionsMutex; add a non-blocking guard to prevent pool.Close deadlocks")
+	}
+}
+
 // TestDbClient_Close_Idempotent verifies that calling Close() multiple times does not cause issues
 // Reference: Similar to bug #4712 (Result.Close() idempotency)
 //

@@ -284,13 +314,14 @@ func TestDbClient_SessionsMutexProtectsMap(t *testing.T) {
 	sourceCode := string(content)

-	// Count occurrences of mutex locks
-	mutexLocks := strings.Count(sourceCode, "c.sessionsMutex.Lock()")
+	// Count occurrences of mutex lock helpers
+	mutexLocks := strings.Count(sourceCode, "lockSessions()") +
+		strings.Count(sourceCode, "sessionsTryLock()")

 	// This is a heuristic check - in practice, we'd need more sophisticated analysis
 	// But it serves as a reminder to use the mutex
 	assert.True(t, mutexLocks > 0,
-		"sessionsMutex.Lock() should be used when accessing sessions map")
+		"sessions lock helpers should be used when accessing sessions map")
 }

 // TestDbClient_SessionMapDocumentation verifies that session lifecycle is documented


@@ -51,7 +51,29 @@ type PluginManager struct {
 	// map of max cache size, keyed by plugin instance
 	pluginCacheSizeMap map[string]int64

-	// map lock
+	// mut protects concurrent access to plugin manager state (runningPluginMap, connectionConfigMap, etc.)
+	//
+	// LOCKING PATTERN TO PREVENT DEADLOCKS:
+	// - Functions that acquire mut.Lock() and call other methods MUST only call *Internal versions
+	// - Public methods that need locking: acquire lock → call internal version → release lock
+	// - Internal methods: assume caller holds lock, never acquire lock themselves
+	//
+	// Example:
+	//   func (m *PluginManager) SomeMethod() {
+	//       m.mut.Lock()
+	//       defer m.mut.Unlock()
+	//       return m.someMethodInternal()
+	//   }
+	//   func (m *PluginManager) someMethodInternal() {
+	//       // NOTE: caller must hold m.mut lock
+	//       // ... implementation without locking ...
+	//   }
+	//
+	// Functions with internal/external versions:
+	// - refreshRateLimiterTable / refreshRateLimiterTableInternal
+	// - updateRateLimiterStatus / updateRateLimiterStatusInternal
+	// - setRateLimiters / setRateLimitersInternal
+	// - getPluginsWithChangedLimiters / getPluginsWithChangedLimitersInternal
 	mut sync.RWMutex

 	// shutdown synchronization

@@ -231,23 +253,32 @@ func (m *PluginManager) doRefresh() {
 // OnConnectionConfigChanged is the callback function invoked by the connection watcher when the config changed
 func (m *PluginManager) OnConnectionConfigChanged(ctx context.Context, configMap connection.ConnectionConfigMap, plugins map[string]*plugin.Plugin) {
+	log.Printf("[DEBUG] OnConnectionConfigChanged: acquiring lock")
 	m.mut.Lock()
 	defer m.mut.Unlock()
+	log.Printf("[DEBUG] OnConnectionConfigChanged: lock acquired")

 	log.Printf("[TRACE] OnConnectionConfigChanged: connections: %s plugin instances: %s", strings.Join(utils.SortedMapKeys(configMap), ","), strings.Join(utils.SortedMapKeys(plugins), ","))

+	log.Printf("[DEBUG] OnConnectionConfigChanged: calling handleConnectionConfigChanges")
 	if err := m.handleConnectionConfigChanges(ctx, configMap); err != nil {
 		log.Printf("[WARN] handleConnectionConfigChanges failed: %s", err.Error())
 	}
+	log.Printf("[DEBUG] OnConnectionConfigChanged: handleConnectionConfigChanges complete")

 	// update our plugin configs
+	log.Printf("[DEBUG] OnConnectionConfigChanged: calling handlePluginInstanceChanges")
 	if err := m.handlePluginInstanceChanges(ctx, plugins); err != nil {
 		log.Printf("[WARN] handlePluginInstanceChanges failed: %s", err.Error())
 	}
+	log.Printf("[DEBUG] OnConnectionConfigChanged: handlePluginInstanceChanges complete")

+	log.Printf("[DEBUG] OnConnectionConfigChanged: calling handleUserLimiterChanges")
 	if err := m.handleUserLimiterChanges(ctx, plugins); err != nil {
 		log.Printf("[WARN] handleUserLimiterChanges failed: %s", err.Error())
 	}
+	log.Printf("[DEBUG] OnConnectionConfigChanged: handleUserLimiterChanges complete")
+	log.Printf("[DEBUG] OnConnectionConfigChanged: about to release lock and return")
 }

 func (m *PluginManager) GetConnectionConfig() connection.ConnectionConfigMap {

@@ -776,14 +807,19 @@ func (m *PluginManager) setCacheOptions(pluginClient *sdkgrpc.PluginClient) error {
 }

 func (m *PluginManager) setRateLimiters(pluginInstance string, pluginClient *sdkgrpc.PluginClient) error {
+	m.mut.RLock()
+	defer m.mut.RUnlock()
+	return m.setRateLimitersInternal(pluginInstance, pluginClient)
+}
+
+func (m *PluginManager) setRateLimitersInternal(pluginInstance string, pluginClient *sdkgrpc.PluginClient) error {
+	// NOTE: caller must hold m.mut lock (at least RLock)
 	log.Printf("[INFO] setRateLimiters for plugin '%s'", pluginInstance)
 	var defs []*sdkproto.RateLimiterDefinition
-	m.mut.RLock()
 	for _, l := range m.userLimiters[pluginInstance] {
 		defs = append(defs, RateLimiterAsProto(l))
 	}
-	m.mut.RUnlock()
 	req := &sdkproto.SetRateLimitersRequest{Definitions: defs}
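
A hedged sketch of the failure mode the internal/external split documented above prevents: Go's sync.RWMutex is not reentrant, so a method holding mut.Lock() that calls a public method which also takes the lock deadlocks itself. The manager type below is illustrative, not the PluginManager itself.

// Why the public/*Internal split exists (illustrative stand-in types).
package main

import (
	"fmt"
	"sync"
)

type manager struct {
	mut      sync.RWMutex
	limiters map[string]int
}

// public: acquire lock, delegate to internal
func (m *manager) refresh() {
	m.mut.Lock()
	defer m.mut.Unlock()
	m.refreshInternal()
}

// internal: caller must hold m.mut
func (m *manager) refreshInternal() {
	m.limiters["aws"]++
	// calling m.refresh() here instead would self-deadlock: RWMutex is not reentrant
}

func main() {
	m := &manager{limiters: map[string]int{}}
	m.mut.Lock()
	m.refreshInternal() // safe: internal version never re-locks
	m.mut.Unlock()
	m.refresh() // safe: public version manages its own lock
	fmt.Println("limiters:", m.limiters)
}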


@@ -29,6 +29,9 @@ func (m *PluginManager) ShouldFetchRateLimiterDefs() bool {
 // update the stored limiters, refrresh the rate limiter table and call `setRateLimiters`
 // for all plugins with changed limiters
 func (m *PluginManager) HandlePluginLimiterChanges(newLimiters connection.PluginLimiterMap) error {
+	m.mut.Lock()
+	defer m.mut.Unlock()
 	if m.pluginLimiters == nil {
 		// this must be the first time we have populated them
 		m.pluginLimiters = make(connection.PluginLimiterMap)

@@ -38,13 +41,22 @@ func (m *PluginManager) HandlePluginLimiterChanges(newLimiters connection.PluginLimiterMap) error {
 	}
 	// update the steampipe_plugin_limiters table
-	if err := m.refreshRateLimiterTable(context.Background()); err != nil {
+	// NOTE: we hold m.mut lock, so call internal version
+	if err := m.refreshRateLimiterTableInternal(context.Background()); err != nil {
 		log.Println("[WARN] could not refresh rate limiter table", err)
 	}
 	return nil
 }

 func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
+	m.mut.Lock()
+	defer m.mut.Unlock()
+	return m.refreshRateLimiterTableInternal(ctx)
+}
+
+func (m *PluginManager) refreshRateLimiterTableInternal(ctx context.Context) error {
+	// NOTE: caller must hold m.mut lock
 	// if we have not yet populated the rate limiter table, do nothing
 	if m.pluginLimiters == nil {
 		return nil

@@ -56,7 +68,7 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
 	}
 	// update the status of the plugin rate limiters (determine which are overriden and set state accordingly)
-	m.updateRateLimiterStatus()
+	m.updateRateLimiterStatusInternal()

 	queries := []db_common.QueryWithArgs{
 		introspection.GetRateLimiterTableDropSql(),

@@ -70,13 +82,12 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
 		}
 	}

-	m.mut.RLock()
+	// NOTE: no lock needed here, caller already holds m.mut
 	for _, limitersForPlugin := range m.userLimiters {
 		for _, l := range limitersForPlugin {
 			queries = append(queries, introspection.GetRateLimiterTablePopulateSql(l))
 		}
 	}
-	m.mut.RUnlock()

 	conn, err := m.pool.Acquire(ctx)
 	if err != nil {

@@ -92,30 +103,42 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
 // update the stored limiters, refresh the rate limiter table and call `setRateLimiters`
 // for all plugins with changed limiters
 func (m *PluginManager) handleUserLimiterChanges(_ context.Context, plugins connection.PluginMap) error {
+	log.Printf("[DEBUG] handleUserLimiterChanges: start")
 	limiterPluginMap := plugins.ToPluginLimiterMap()
-	pluginsWithChangedLimiters := m.getPluginsWithChangedLimiters(limiterPluginMap)
+	log.Printf("[DEBUG] handleUserLimiterChanges: got limiter plugin map")
+	// NOTE: caller (OnConnectionConfigChanged) already holds m.mut lock, so use internal version
+	pluginsWithChangedLimiters := m.getPluginsWithChangedLimitersInternal(limiterPluginMap)
+	log.Printf("[DEBUG] handleUserLimiterChanges: found %d plugins with changed limiters", len(pluginsWithChangedLimiters))
 	if len(pluginsWithChangedLimiters) == 0 {
+		log.Printf("[DEBUG] handleUserLimiterChanges: no changes, returning")
 		return nil
 	}

 	// update stored limiters to the new map
-	m.mut.Lock()
+	// NOTE: caller (OnConnectionConfigChanged) already holds m.mut lock, so we don't lock here
+	log.Printf("[DEBUG] handleUserLimiterChanges: updating user limiters")
 	m.userLimiters = limiterPluginMap
-	m.mut.Unlock()

 	// update the steampipe_plugin_limiters table
-	if err := m.refreshRateLimiterTable(context.Background()); err != nil {
+	// NOTE: caller already holds m.mut lock, so call internal version
+	log.Printf("[DEBUG] handleUserLimiterChanges: calling refreshRateLimiterTableInternal")
+	if err := m.refreshRateLimiterTableInternal(context.Background()); err != nil {
 		log.Println("[WARN] could not refresh rate limiter table", err)
 	}
+	log.Printf("[DEBUG] handleUserLimiterChanges: refreshRateLimiterTableInternal complete")

 	// now update the plugins - call setRateLimiters for any plugin with updated user limiters
+	log.Printf("[DEBUG] handleUserLimiterChanges: setting rate limiters for plugins")
 	for p := range pluginsWithChangedLimiters {
+		log.Printf("[DEBUG] handleUserLimiterChanges: calling setRateLimitersForPlugin for %s", p)
 		if err := m.setRateLimitersForPlugin(p); err != nil {
 			return err
 		}
+		log.Printf("[DEBUG] handleUserLimiterChanges: setRateLimitersForPlugin complete for %s", p)
 	}
+	log.Printf("[DEBUG] handleUserLimiterChanges: complete")
 	return nil
 }

@@ -138,17 +161,22 @@ func (m *PluginManager) setRateLimitersForPlugin(pluginShortName string) error {
 		return sperr.WrapWithMessage(err, "failed to create a plugin client when updating the rate limiter for plugin '%s'", imageRef)
 	}

-	if err := m.setRateLimiters(pluginShortName, pluginClient); err != nil {
+	// NOTE: caller (handleUserLimiterChanges via OnConnectionConfigChanged) already holds m.mut lock
+	if err := m.setRateLimitersInternal(pluginShortName, pluginClient); err != nil {
 		return sperr.WrapWithMessage(err, "failed to update rate limiters for plugin '%s'", imageRef)
 	}
 	return nil
 }

 func (m *PluginManager) getPluginsWithChangedLimiters(newLimiters connection.PluginLimiterMap) map[string]struct{} {
-	var pluginsWithChangedLimiters = make(map[string]struct{})
 	m.mut.RLock()
 	defer m.mut.RUnlock()
+	return m.getPluginsWithChangedLimitersInternal(newLimiters)
+}
+
+func (m *PluginManager) getPluginsWithChangedLimitersInternal(newLimiters connection.PluginLimiterMap) map[string]struct{} {
+	// NOTE: caller must hold m.mut lock (at least RLock)
+	var pluginsWithChangedLimiters = make(map[string]struct{})
 	for plugin, limitersForPlugin := range m.userLimiters {
 		newLimitersForPlugin := newLimiters[plugin]

@@ -169,7 +197,11 @@ func (m *PluginManager) getPluginsWithChangedLimiters(newLimiters connection.PluginLimiterMap) map[string]struct{} {
 func (m *PluginManager) updateRateLimiterStatus() {
 	m.mut.Lock()
 	defer m.mut.Unlock()
+	m.updateRateLimiterStatusInternal()
+}
+
+func (m *PluginManager) updateRateLimiterStatusInternal() {
+	// NOTE: caller must hold m.mut lock
 	// iterate through limiters for each plug
 	for p, pluginDefinedLimiters := range m.pluginLimiters {
 		// get user limiters for this plugin (already holding lock, so call internal version)