mirror of
https://github.com/turbot/steampipe.git
synced 2025-12-19 09:58:53 -05:00
Fix db client deadlocks with non-blocking cleanup and RW locks (#4918)
This commit is contained in:
@@ -60,11 +60,15 @@ func doRunPluginManager(cmd *cobra.Command) error {
|
||||
log.Printf("[INFO] starting connection watcher")
|
||||
connectionWatcher, err := connection.NewConnectionWatcher(pluginManager)
|
||||
if err != nil {
|
||||
log.Printf("[ERROR] failed to create connection watcher: %v", err)
|
||||
return err
|
||||
}
|
||||
log.Printf("[INFO] connection watcher created successfully")
|
||||
|
||||
// close the connection watcher
|
||||
defer connectionWatcher.Close()
|
||||
} else {
|
||||
log.Printf("[WARN] connection watcher is DISABLED")
|
||||
}
|
||||
|
||||
log.Printf("[INFO] about to serve")
|
||||
|
||||
@@ -26,12 +26,19 @@ func NewConnectionWatcher(pluginManager pluginManager) (*ConnectionWatcher, erro
|
||||
pluginManager: pluginManager,
|
||||
}
|
||||
|
||||
configDir := filepaths.EnsureConfigDir()
|
||||
log.Printf("[INFO] ConnectionWatcher will watch directory: %s for %s files", configDir, constants.ConfigExtension)
|
||||
|
||||
watcherOptions := &filewatcher.WatcherOptions{
|
||||
Directories: []string{filepaths.EnsureConfigDir()},
|
||||
Directories: []string{configDir},
|
||||
Include: filehelpers.InclusionsFromExtensions([]string{constants.ConfigExtension}),
|
||||
ListFlag: filehelpers.FilesRecursive,
|
||||
EventMask: fsnotify.Create | fsnotify.Remove | fsnotify.Rename | fsnotify.Write | fsnotify.Chmod,
|
||||
OnChange: func(events []fsnotify.Event) {
|
||||
log.Printf("[INFO] ConnectionWatcher detected %d file events", len(events))
|
||||
for _, event := range events {
|
||||
log.Printf("[INFO] ConnectionWatcher event: %s - %s", event.Op, event.Name)
|
||||
}
|
||||
w.handleFileWatcherEvent(events)
|
||||
},
|
||||
}
|
||||
@@ -80,13 +87,17 @@ func (w *ConnectionWatcher) handleFileWatcherEvent([]fsnotify.Event) {
|
||||
// as these are both used by RefreshConnectionAndSearchPathsWithLocalClient
|
||||
|
||||
// set the global steampipe config
|
||||
log.Printf("[DEBUG] ConnectionWatcher: setting GlobalConfig")
|
||||
steampipeconfig.GlobalConfig = config
|
||||
|
||||
// call on changed callback - we must call this BEFORE calling refresh connections
|
||||
// convert config to format expected by plugin manager
|
||||
// (plugin manager cannot reference steampipe config to avoid circular deps)
|
||||
log.Printf("[DEBUG] ConnectionWatcher: creating connection config map")
|
||||
configMap := NewConnectionConfigMap(config.Connections)
|
||||
log.Printf("[DEBUG] ConnectionWatcher: calling OnConnectionConfigChanged with %d connections", len(configMap))
|
||||
w.pluginManager.OnConnectionConfigChanged(ctx, configMap, config.PluginsInstances)
|
||||
log.Printf("[DEBUG] ConnectionWatcher: OnConnectionConfigChanged complete")
|
||||
|
||||
// The only configurations from GlobalConfig which have
|
||||
// impact during Refresh are Database options and the Connections
|
||||
@@ -99,7 +110,9 @@ func (w *ConnectionWatcher) handleFileWatcherEvent([]fsnotify.Event) {
|
||||
// Workspace Profile does not have any setting which can alter
|
||||
// behavior in service mode (namely search path). Therefore, it is safe
|
||||
// to use the GlobalConfig here and ignore Workspace Profile in general
|
||||
log.Printf("[DEBUG] ConnectionWatcher: calling SetDefaultsFromConfig")
|
||||
cmdconfig.SetDefaultsFromConfig(steampipeconfig.GlobalConfig.ConfigMap())
|
||||
log.Printf("[DEBUG] ConnectionWatcher: SetDefaultsFromConfig complete")
|
||||
|
||||
log.Printf("[INFO] calling RefreshConnections asyncronously")
|
||||
|
||||
|
||||
@@ -49,13 +49,14 @@ type DbClient struct {
|
||||
sessions map[uint32]*db_common.DatabaseSession
|
||||
|
||||
// allows locked access to the 'sessions' map
|
||||
sessionsMutex *sync.Mutex
|
||||
sessionsMutex *sync.Mutex
|
||||
sessionsLockFlag atomic.Bool
|
||||
|
||||
// if a custom search path or a prefix is used, store it here
|
||||
customSearchPath []string
|
||||
searchPathPrefix []string
|
||||
// allows locked access to customSearchPath and searchPathPrefix
|
||||
searchPathMutex *sync.Mutex
|
||||
searchPathMutex *sync.RWMutex
|
||||
// the default user search path
|
||||
userSearchPath []string
|
||||
// disable timing - set whilst in process of querying the timing
|
||||
@@ -73,7 +74,7 @@ func NewDbClient(ctx context.Context, connectionString string, opts ...ClientOpt
|
||||
parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
searchPathMutex: &sync.Mutex{},
|
||||
searchPathMutex: &sync.RWMutex{},
|
||||
connectionString: connectionString,
|
||||
}
|
||||
|
||||
@@ -152,6 +153,37 @@ func (c *DbClient) shouldFetchVerboseTiming() bool {
|
||||
(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
|
||||
}
|
||||
|
||||
// lockSessions acquires the sessionsMutex and tracks ownership for tryLock compatibility.
|
||||
func (c *DbClient) lockSessions() {
|
||||
if c.sessionsMutex == nil {
|
||||
return
|
||||
}
|
||||
c.sessionsLockFlag.Store(true)
|
||||
c.sessionsMutex.Lock()
|
||||
}
|
||||
|
||||
// sessionsTryLock attempts to acquire the sessionsMutex without blocking.
|
||||
// Returns false if the lock is already held.
|
||||
func (c *DbClient) sessionsTryLock() bool {
|
||||
if c.sessionsMutex == nil {
|
||||
return false
|
||||
}
|
||||
// best-effort: only one contender sets the flag and proceeds to lock
|
||||
if !c.sessionsLockFlag.CompareAndSwap(false, true) {
|
||||
return false
|
||||
}
|
||||
c.sessionsMutex.Lock()
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *DbClient) sessionsUnlock() {
|
||||
if c.sessionsMutex == nil {
|
||||
return
|
||||
}
|
||||
c.sessionsMutex.Unlock()
|
||||
c.sessionsLockFlag.Store(false)
|
||||
}
|
||||
|
||||
// ServerSettings returns the settings of the steampipe service that this DbClient is connected to
|
||||
//
|
||||
// Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected.
|
||||
@@ -173,9 +205,9 @@ func (c *DbClient) Close(context.Context) error {
|
||||
// nullify active sessions, since with the closing of the pools
|
||||
// none of the sessions will be valid anymore
|
||||
// Acquire mutex to prevent concurrent access to sessions map
|
||||
c.sessionsMutex.Lock()
|
||||
c.lockSessions()
|
||||
c.sessions = nil
|
||||
c.sessionsMutex.Unlock()
|
||||
c.sessionsUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -66,12 +66,14 @@ func (c *DbClient) establishConnectionPool(ctx context.Context, overrides client
|
||||
config.BeforeClose = func(conn *pgx.Conn) {
|
||||
if conn != nil && conn.PgConn() != nil {
|
||||
backendPid := conn.PgConn().PID()
|
||||
c.sessionsMutex.Lock()
|
||||
// Check if sessions map has been nil'd by Close()
|
||||
if c.sessions != nil {
|
||||
delete(c.sessions, backendPid)
|
||||
// Best-effort cleanup: do not block pool.Close() if sessions lock is busy.
|
||||
if c.sessionsTryLock() {
|
||||
// Check if sessions map has been nil'd by Close()
|
||||
if c.sessions != nil {
|
||||
delete(c.sessions, backendPid)
|
||||
}
|
||||
c.sessionsUnlock()
|
||||
}
|
||||
c.sessionsMutex.Unlock()
|
||||
}
|
||||
}
|
||||
// set an app name so that we can track database connections from this Steampipe execution
|
||||
|
||||
@@ -78,8 +78,8 @@ func (c *DbClient) loadUserSearchPath(ctx context.Context, connection *pgx.Conn)
|
||||
|
||||
// GetRequiredSessionSearchPath implements Client
|
||||
func (c *DbClient) GetRequiredSessionSearchPath() []string {
|
||||
c.searchPathMutex.Lock()
|
||||
defer c.searchPathMutex.Unlock()
|
||||
c.searchPathMutex.RLock()
|
||||
defer c.searchPathMutex.RUnlock()
|
||||
|
||||
if c.customSearchPath != nil {
|
||||
return c.customSearchPath
|
||||
@@ -89,8 +89,8 @@ func (c *DbClient) GetRequiredSessionSearchPath() []string {
|
||||
}
|
||||
|
||||
func (c *DbClient) GetCustomSearchPath() []string {
|
||||
c.searchPathMutex.Lock()
|
||||
defer c.searchPathMutex.Unlock()
|
||||
c.searchPathMutex.RLock()
|
||||
defer c.searchPathMutex.RUnlock()
|
||||
|
||||
return c.customSearchPath
|
||||
}
|
||||
|
||||
@@ -37,10 +37,10 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common
|
||||
}
|
||||
backendPid := databaseConnection.Conn().PgConn().PID()
|
||||
|
||||
c.sessionsMutex.Lock()
|
||||
c.lockSessions()
|
||||
// Check if client has been closed (sessions set to nil)
|
||||
if c.sessions == nil {
|
||||
c.sessionsMutex.Unlock()
|
||||
c.sessionsUnlock()
|
||||
sessionResult.Error = fmt.Errorf("client has been closed")
|
||||
return sessionResult
|
||||
}
|
||||
@@ -52,7 +52,7 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common
|
||||
// we get a new *sql.Conn everytime. USE IT!
|
||||
session.Connection = databaseConnection
|
||||
sessionResult.Session = session
|
||||
c.sessionsMutex.Unlock()
|
||||
c.sessionsUnlock()
|
||||
|
||||
// make sure that we close the acquired session, in case of error
|
||||
defer func() {
|
||||
|
||||
@@ -2,10 +2,13 @@ package db_client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/turbot/steampipe/v2/pkg/db/db_common"
|
||||
)
|
||||
|
||||
@@ -160,6 +163,28 @@ func TestDbClient_SearchPathUpdates(t *testing.T) {
|
||||
assert.Len(t, client.customSearchPath, 2, "Should have 2 schemas in search path")
|
||||
}
|
||||
|
||||
// TestSearchPathAccessShouldUseReadLocks checks that search path access does not block other goroutines unnecessarily.
|
||||
//
|
||||
// Holding an exclusive mutex during search-path reads in concurrent query setup can deadlock when
|
||||
// another goroutine is setting the path. The current code uses Lock/Unlock; this test documents
|
||||
// the expectation to move to a read/non-blocking lock so concurrent reads are safe.
|
||||
func TestSearchPathAccessShouldUseReadLocks(t *testing.T) {
|
||||
content, err := os.ReadFile("db_client_search_path.go")
|
||||
require.NoError(t, err, "should be able to read db_client_search_path.go")
|
||||
|
||||
source := string(content)
|
||||
|
||||
assert.Contains(t, source, "GetRequiredSessionSearchPath", "getter must exist")
|
||||
assert.Contains(t, source, "searchPathMutex", "getter must guard access to searchPath state")
|
||||
|
||||
// Expect a read or non-blocking lock in getters; fail if only full Lock/Unlock is present.
|
||||
hasRLock := strings.Contains(source, "RLock")
|
||||
hasTry := strings.Contains(source, "TryLock") || strings.Contains(source, "tryLock")
|
||||
if !hasRLock && !hasTry {
|
||||
t.Fatalf("GetRequiredSessionSearchPath should avoid exclusive Lock/Unlock to prevent deadlocks under concurrent query setup")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDbClient_SessionConnectionNilSafety verifies handling of nil connections
|
||||
func TestDbClient_SessionConnectionNilSafety(t *testing.T) {
|
||||
session := db_common.NewDBSession(12345)
|
||||
@@ -181,7 +206,7 @@ func TestDbClient_SessionSearchPathUpdatesThreadSafe(t *testing.T) {
|
||||
client := &DbClient{
|
||||
customSearchPath: []string{"public", "internal"},
|
||||
userSearchPath: []string{"public"},
|
||||
searchPathMutex: &sync.Mutex{},
|
||||
searchPathMutex: &sync.RWMutex{},
|
||||
}
|
||||
|
||||
// Number of concurrent operations to test
|
||||
|
||||
@@ -52,6 +52,36 @@ func TestSessionMapCleanupImplemented(t *testing.T) {
|
||||
"Comment should document automatic cleanup mechanism")
|
||||
}
|
||||
|
||||
// TestBeforeCloseCleanupShouldBeNonBlocking ensures the cleanup hook does not take a blocking lock.
|
||||
//
|
||||
// A blocking mutex in the BeforeClose hook can deadlock pool.Close() when another goroutine
|
||||
// holds sessionsMutex (service stop/restart hangs). This test is intentionally strict and
|
||||
// will fail until the hook uses a non-blocking strategy (e.g., TryLock or similar).
|
||||
func TestBeforeCloseCleanupShouldBeNonBlocking(t *testing.T) {
|
||||
content, err := os.ReadFile("db_client_connect.go")
|
||||
require.NoError(t, err, "should be able to read db_client_connect.go")
|
||||
|
||||
source := string(content)
|
||||
|
||||
// Guardrail: the BeforeClose hook should avoid unconditionally blocking on sessionsMutex.
|
||||
assert.Contains(t, source, "config.BeforeClose", "BeforeClose cleanup hook must exist")
|
||||
assert.Contains(t, source, "sessionsTryLock", "BeforeClose cleanup should use non-blocking lock helper")
|
||||
|
||||
// Expect a non-blocking lock pattern; if we only find Lock()/Unlock, this fails.
|
||||
nonBlockingPatterns := []string{"TryLock", "tryLock", "non-block", "select {"}
|
||||
foundNonBlocking := false
|
||||
for _, p := range nonBlockingPatterns {
|
||||
if strings.Contains(source, p) {
|
||||
foundNonBlocking = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundNonBlocking {
|
||||
t.Fatalf("BeforeClose cleanup appears to take a blocking lock on sessionsMutex; add a non-blocking guard to prevent pool.Close deadlocks")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDbClient_Close_Idempotent verifies that calling Close() multiple times does not cause issues
|
||||
// Reference: Similar to bug #4712 (Result.Close() idempotency)
|
||||
//
|
||||
@@ -284,13 +314,14 @@ func TestDbClient_SessionsMutexProtectsMap(t *testing.T) {
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Count occurrences of mutex locks
|
||||
mutexLocks := strings.Count(sourceCode, "c.sessionsMutex.Lock()")
|
||||
// Count occurrences of mutex lock helpers
|
||||
mutexLocks := strings.Count(sourceCode, "lockSessions()") +
|
||||
strings.Count(sourceCode, "sessionsTryLock()")
|
||||
|
||||
// This is a heuristic check - in practice, we'd need more sophisticated analysis
|
||||
// But it serves as a reminder to use the mutex
|
||||
assert.True(t, mutexLocks > 0,
|
||||
"sessionsMutex.Lock() should be used when accessing sessions map")
|
||||
"sessions lock helpers should be used when accessing sessions map")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionMapDocumentation verifies that session lifecycle is documented
|
||||
|
||||
@@ -51,7 +51,29 @@ type PluginManager struct {
|
||||
// map of max cache size, keyed by plugin instance
|
||||
pluginCacheSizeMap map[string]int64
|
||||
|
||||
// map lock
|
||||
// mut protects concurrent access to plugin manager state (runningPluginMap, connectionConfigMap, etc.)
|
||||
//
|
||||
// LOCKING PATTERN TO PREVENT DEADLOCKS:
|
||||
// - Functions that acquire mut.Lock() and call other methods MUST only call *Internal versions
|
||||
// - Public methods that need locking: acquire lock → call internal version → release lock
|
||||
// - Internal methods: assume caller holds lock, never acquire lock themselves
|
||||
//
|
||||
// Example:
|
||||
// func (m *PluginManager) SomeMethod() {
|
||||
// m.mut.Lock()
|
||||
// defer m.mut.Unlock()
|
||||
// return m.someMethodInternal()
|
||||
// }
|
||||
// func (m *PluginManager) someMethodInternal() {
|
||||
// // NOTE: caller must hold m.mut lock
|
||||
// // ... implementation without locking ...
|
||||
// }
|
||||
//
|
||||
// Functions with internal/external versions:
|
||||
// - refreshRateLimiterTable / refreshRateLimiterTableInternal
|
||||
// - updateRateLimiterStatus / updateRateLimiterStatusInternal
|
||||
// - setRateLimiters / setRateLimitersInternal
|
||||
// - getPluginsWithChangedLimiters / getPluginsWithChangedLimitersInternal
|
||||
mut sync.RWMutex
|
||||
|
||||
// shutdown synchronization
|
||||
@@ -231,23 +253,32 @@ func (m *PluginManager) doRefresh() {
|
||||
|
||||
// OnConnectionConfigChanged is the callback function invoked by the connection watcher when the config changed
|
||||
func (m *PluginManager) OnConnectionConfigChanged(ctx context.Context, configMap connection.ConnectionConfigMap, plugins map[string]*plugin.Plugin) {
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: acquiring lock")
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: lock acquired")
|
||||
|
||||
log.Printf("[TRACE] OnConnectionConfigChanged: connections: %s plugin instances: %s", strings.Join(utils.SortedMapKeys(configMap), ","), strings.Join(utils.SortedMapKeys(plugins), ","))
|
||||
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: calling handleConnectionConfigChanges")
|
||||
if err := m.handleConnectionConfigChanges(ctx, configMap); err != nil {
|
||||
log.Printf("[WARN] handleConnectionConfigChanges failed: %s", err.Error())
|
||||
}
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: handleConnectionConfigChanges complete")
|
||||
|
||||
// update our plugin configs
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: calling handlePluginInstanceChanges")
|
||||
if err := m.handlePluginInstanceChanges(ctx, plugins); err != nil {
|
||||
log.Printf("[WARN] handlePluginInstanceChanges failed: %s", err.Error())
|
||||
}
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: handlePluginInstanceChanges complete")
|
||||
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: calling handleUserLimiterChanges")
|
||||
if err := m.handleUserLimiterChanges(ctx, plugins); err != nil {
|
||||
log.Printf("[WARN] handleUserLimiterChanges failed: %s", err.Error())
|
||||
}
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: handleUserLimiterChanges complete")
|
||||
log.Printf("[DEBUG] OnConnectionConfigChanged: about to release lock and return")
|
||||
}
|
||||
|
||||
func (m *PluginManager) GetConnectionConfig() connection.ConnectionConfigMap {
|
||||
@@ -776,14 +807,19 @@ func (m *PluginManager) setCacheOptions(pluginClient *sdkgrpc.PluginClient) erro
|
||||
}
|
||||
|
||||
func (m *PluginManager) setRateLimiters(pluginInstance string, pluginClient *sdkgrpc.PluginClient) error {
|
||||
m.mut.RLock()
|
||||
defer m.mut.RUnlock()
|
||||
return m.setRateLimitersInternal(pluginInstance, pluginClient)
|
||||
}
|
||||
|
||||
func (m *PluginManager) setRateLimitersInternal(pluginInstance string, pluginClient *sdkgrpc.PluginClient) error {
|
||||
// NOTE: caller must hold m.mut lock (at least RLock)
|
||||
log.Printf("[INFO] setRateLimiters for plugin '%s'", pluginInstance)
|
||||
var defs []*sdkproto.RateLimiterDefinition
|
||||
|
||||
m.mut.RLock()
|
||||
for _, l := range m.userLimiters[pluginInstance] {
|
||||
defs = append(defs, RateLimiterAsProto(l))
|
||||
}
|
||||
m.mut.RUnlock()
|
||||
|
||||
req := &sdkproto.SetRateLimitersRequest{Definitions: defs}
|
||||
|
||||
|
||||
@@ -29,6 +29,9 @@ func (m *PluginManager) ShouldFetchRateLimiterDefs() bool {
|
||||
// update the stored limiters, refrresh the rate limiter table and call `setRateLimiters`
|
||||
// for all plugins with changed limiters
|
||||
func (m *PluginManager) HandlePluginLimiterChanges(newLimiters connection.PluginLimiterMap) error {
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
|
||||
if m.pluginLimiters == nil {
|
||||
// this must be the first time we have populated them
|
||||
m.pluginLimiters = make(connection.PluginLimiterMap)
|
||||
@@ -38,13 +41,22 @@ func (m *PluginManager) HandlePluginLimiterChanges(newLimiters connection.Plugin
|
||||
}
|
||||
|
||||
// update the steampipe_plugin_limiters table
|
||||
if err := m.refreshRateLimiterTable(context.Background()); err != nil {
|
||||
// NOTE: we hold m.mut lock, so call internal version
|
||||
if err := m.refreshRateLimiterTableInternal(context.Background()); err != nil {
|
||||
log.Println("[WARN] could not refresh rate limiter table", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
return m.refreshRateLimiterTableInternal(ctx)
|
||||
}
|
||||
|
||||
func (m *PluginManager) refreshRateLimiterTableInternal(ctx context.Context) error {
|
||||
// NOTE: caller must hold m.mut lock
|
||||
|
||||
// if we have not yet populated the rate limiter table, do nothing
|
||||
if m.pluginLimiters == nil {
|
||||
return nil
|
||||
@@ -56,7 +68,7 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// update the status of the plugin rate limiters (determine which are overriden and set state accordingly)
|
||||
m.updateRateLimiterStatus()
|
||||
m.updateRateLimiterStatusInternal()
|
||||
|
||||
queries := []db_common.QueryWithArgs{
|
||||
introspection.GetRateLimiterTableDropSql(),
|
||||
@@ -70,13 +82,12 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
m.mut.RLock()
|
||||
// NOTE: no lock needed here, caller already holds m.mut
|
||||
for _, limitersForPlugin := range m.userLimiters {
|
||||
for _, l := range limitersForPlugin {
|
||||
queries = append(queries, introspection.GetRateLimiterTablePopulateSql(l))
|
||||
}
|
||||
}
|
||||
m.mut.RUnlock()
|
||||
|
||||
conn, err := m.pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
@@ -92,30 +103,42 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
|
||||
// update the stored limiters, refresh the rate limiter table and call `setRateLimiters`
|
||||
// for all plugins with changed limiters
|
||||
func (m *PluginManager) handleUserLimiterChanges(_ context.Context, plugins connection.PluginMap) error {
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: start")
|
||||
limiterPluginMap := plugins.ToPluginLimiterMap()
|
||||
pluginsWithChangedLimiters := m.getPluginsWithChangedLimiters(limiterPluginMap)
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: got limiter plugin map")
|
||||
// NOTE: caller (OnConnectionConfigChanged) already holds m.mut lock, so use internal version
|
||||
pluginsWithChangedLimiters := m.getPluginsWithChangedLimitersInternal(limiterPluginMap)
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: found %d plugins with changed limiters", len(pluginsWithChangedLimiters))
|
||||
|
||||
if len(pluginsWithChangedLimiters) == 0 {
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: no changes, returning")
|
||||
return nil
|
||||
}
|
||||
|
||||
// update stored limiters to the new map
|
||||
m.mut.Lock()
|
||||
// NOTE: caller (OnConnectionConfigChanged) already holds m.mut lock, so we don't lock here
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: updating user limiters")
|
||||
m.userLimiters = limiterPluginMap
|
||||
m.mut.Unlock()
|
||||
|
||||
// update the steampipe_plugin_limiters table
|
||||
if err := m.refreshRateLimiterTable(context.Background()); err != nil {
|
||||
// NOTE: caller already holds m.mut lock, so call internal version
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: calling refreshRateLimiterTableInternal")
|
||||
if err := m.refreshRateLimiterTableInternal(context.Background()); err != nil {
|
||||
log.Println("[WARN] could not refresh rate limiter table", err)
|
||||
}
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: refreshRateLimiterTableInternal complete")
|
||||
|
||||
// now update the plugins - call setRateLimiters for any plugin with updated user limiters
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: setting rate limiters for plugins")
|
||||
for p := range pluginsWithChangedLimiters {
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: calling setRateLimitersForPlugin for %s", p)
|
||||
if err := m.setRateLimitersForPlugin(p); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: setRateLimitersForPlugin complete for %s", p)
|
||||
}
|
||||
|
||||
log.Printf("[DEBUG] handleUserLimiterChanges: complete")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -138,17 +161,22 @@ func (m *PluginManager) setRateLimitersForPlugin(pluginShortName string) error {
|
||||
return sperr.WrapWithMessage(err, "failed to create a plugin client when updating the rate limiter for plugin '%s'", imageRef)
|
||||
}
|
||||
|
||||
if err := m.setRateLimiters(pluginShortName, pluginClient); err != nil {
|
||||
// NOTE: caller (handleUserLimiterChanges via OnConnectionConfigChanged) already holds m.mut lock
|
||||
if err := m.setRateLimitersInternal(pluginShortName, pluginClient); err != nil {
|
||||
return sperr.WrapWithMessage(err, "failed to update rate limiters for plugin '%s'", imageRef)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *PluginManager) getPluginsWithChangedLimiters(newLimiters connection.PluginLimiterMap) map[string]struct{} {
|
||||
var pluginsWithChangedLimiters = make(map[string]struct{})
|
||||
|
||||
m.mut.RLock()
|
||||
defer m.mut.RUnlock()
|
||||
return m.getPluginsWithChangedLimitersInternal(newLimiters)
|
||||
}
|
||||
|
||||
func (m *PluginManager) getPluginsWithChangedLimitersInternal(newLimiters connection.PluginLimiterMap) map[string]struct{} {
|
||||
// NOTE: caller must hold m.mut lock (at least RLock)
|
||||
var pluginsWithChangedLimiters = make(map[string]struct{})
|
||||
|
||||
for plugin, limitersForPlugin := range m.userLimiters {
|
||||
newLimitersForPlugin := newLimiters[plugin]
|
||||
@@ -169,7 +197,11 @@ func (m *PluginManager) getPluginsWithChangedLimiters(newLimiters connection.Plu
|
||||
func (m *PluginManager) updateRateLimiterStatus() {
|
||||
m.mut.Lock()
|
||||
defer m.mut.Unlock()
|
||||
m.updateRateLimiterStatusInternal()
|
||||
}
|
||||
|
||||
func (m *PluginManager) updateRateLimiterStatusInternal() {
|
||||
// NOTE: caller must hold m.mut lock
|
||||
// iterate through limiters for each plug
|
||||
for p, pluginDefinedLimiters := range m.pluginLimiters {
|
||||
// get user limiters for this plugin (already holding lock, so call internal version)
|
||||
|
||||
Reference in New Issue
Block a user