Mirror of https://github.com/turbot/steampipe.git

Fix db client deadlocks with non-blocking cleanup and RW locks (#4918)
@@ -60,11 +60,15 @@ func doRunPluginManager(cmd *cobra.Command) error {
 		log.Printf("[INFO] starting connection watcher")
 		connectionWatcher, err := connection.NewConnectionWatcher(pluginManager)
 		if err != nil {
+			log.Printf("[ERROR] failed to create connection watcher: %v", err)
 			return err
 		}
+		log.Printf("[INFO] connection watcher created successfully")

 		// close the connection watcher
 		defer connectionWatcher.Close()
+	} else {
+		log.Printf("[WARN] connection watcher is DISABLED")
 	}

 	log.Printf("[INFO] about to serve")
@@ -26,12 +26,19 @@ func NewConnectionWatcher(pluginManager pluginManager) (*ConnectionWatcher, erro
 		pluginManager: pluginManager,
 	}

+	configDir := filepaths.EnsureConfigDir()
+	log.Printf("[INFO] ConnectionWatcher will watch directory: %s for %s files", configDir, constants.ConfigExtension)
+
 	watcherOptions := &filewatcher.WatcherOptions{
-		Directories: []string{filepaths.EnsureConfigDir()},
+		Directories: []string{configDir},
 		Include:     filehelpers.InclusionsFromExtensions([]string{constants.ConfigExtension}),
 		ListFlag:    filehelpers.FilesRecursive,
 		EventMask:   fsnotify.Create | fsnotify.Remove | fsnotify.Rename | fsnotify.Write | fsnotify.Chmod,
 		OnChange: func(events []fsnotify.Event) {
+			log.Printf("[INFO] ConnectionWatcher detected %d file events", len(events))
+			for _, event := range events {
+				log.Printf("[INFO] ConnectionWatcher event: %s - %s", event.Op, event.Name)
+			}
 			w.handleFileWatcherEvent(events)
 		},
 	}
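Note on the hunk above: WatcherOptions, filehelpers, and filepaths are Steampipe helper packages, but EventMask and OnChange map directly onto the underlying github.com/fsnotify/fsnotify API. A minimal standalone sketch of that raw mechanism, with a hypothetical directory path standing in for the config dir:

package main

import (
	"log"

	"github.com/fsnotify/fsnotify"
)

func main() {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Fatal(err)
	}
	defer watcher.Close()

	// hypothetical directory standing in for the Steampipe config dir
	if err := watcher.Add("/tmp/steampipe-config"); err != nil {
		log.Fatal(err)
	}

	// the same op mask as EventMask in the hunk above
	mask := fsnotify.Create | fsnotify.Remove | fsnotify.Rename | fsnotify.Write | fsnotify.Chmod
	for event := range watcher.Events {
		if event.Op&mask != 0 {
			log.Printf("event: %s - %s", event.Op, event.Name)
		}
	}
}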
@@ -80,13 +87,17 @@ func (w *ConnectionWatcher) handleFileWatcherEvent([]fsnotify.Event) {
 	// as these are both used by RefreshConnectionAndSearchPathsWithLocalClient

 	// set the global steampipe config
+	log.Printf("[DEBUG] ConnectionWatcher: setting GlobalConfig")
 	steampipeconfig.GlobalConfig = config

 	// call on changed callback - we must call this BEFORE calling refresh connections
 	// convert config to format expected by plugin manager
 	// (plugin manager cannot reference steampipe config to avoid circular deps)
+	log.Printf("[DEBUG] ConnectionWatcher: creating connection config map")
 	configMap := NewConnectionConfigMap(config.Connections)
+	log.Printf("[DEBUG] ConnectionWatcher: calling OnConnectionConfigChanged with %d connections", len(configMap))
 	w.pluginManager.OnConnectionConfigChanged(ctx, configMap, config.PluginsInstances)
+	log.Printf("[DEBUG] ConnectionWatcher: OnConnectionConfigChanged complete")

 	// The only configurations from GlobalConfig which have
 	// impact during Refresh are Database options and the Connections
@@ -99,7 +110,9 @@ func (w *ConnectionWatcher) handleFileWatcherEvent([]fsnotify.Event) {
 	// Workspace Profile does not have any setting which can alter
 	// behavior in service mode (namely search path). Therefore, it is safe
 	// to use the GlobalConfig here and ignore Workspace Profile in general
+	log.Printf("[DEBUG] ConnectionWatcher: calling SetDefaultsFromConfig")
 	cmdconfig.SetDefaultsFromConfig(steampipeconfig.GlobalConfig.ConfigMap())
+	log.Printf("[DEBUG] ConnectionWatcher: SetDefaultsFromConfig complete")

 	log.Printf("[INFO] calling RefreshConnections asyncronously")

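The "circular deps" comment above is the reason the conversion step exists at all: the plugin manager consumes a data-only map type rather than importing the steampipe config package. A sketch of that dependency shape, with hypothetical package and field names:

// Package pmconfig is a hypothetical leaf package with no imports of its
// own, so both the plugin manager and the config package can depend on it
// without forming an import cycle.
package pmconfig

// ConnectionConfig is a data-only view of a connection definition.
type ConnectionConfig struct {
	Connection string // connection name
	Plugin     string // plugin image ref
	Config     string // raw config content
}

// ConnectionConfigMap is keyed by connection name; the watcher builds one
// of these from the rich config before handing it to the plugin manager.
type ConnectionConfigMap map[string]*ConnectionConfig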
@@ -50,12 +50,13 @@ type DbClient struct {

 	// allows locked access to the 'sessions' map
 	sessionsMutex *sync.Mutex
+	sessionsLockFlag atomic.Bool

 	// if a custom search path or a prefix is used, store it here
 	customSearchPath []string
 	searchPathPrefix []string
 	// allows locked access to customSearchPath and searchPathPrefix
-	searchPathMutex *sync.Mutex
+	searchPathMutex *sync.RWMutex
 	// the default user search path
 	userSearchPath []string
 	// disable timing - set whilst in process of querying the timing
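The new sessionsLockFlag pairs an atomic.Bool with the existing *sync.Mutex to emulate try-lock semantics. Worth noting: since Go 1.18, sync.Mutex ships a built-in TryLock, so a sketch of the same best-effort guard without the extra flag would be:

package main

import "sync"

var mu sync.Mutex

// bestEffortCleanup is a hypothetical stand-in for the cleanup hook.
func bestEffortCleanup() {
	// TryLock (Go 1.18+) returns false instead of blocking when contended
	if !mu.TryLock() {
		return // skip the cleanup rather than risk stalling shutdown
	}
	defer mu.Unlock()
	// ... protected work on the sessions map ...
}

func main() {
	bestEffortCleanup()
}

The flag-based variant this commit adds keeps the helpers nil-safe and lets the blocking and non-blocking paths share one ownership signal, at the cost of being more conservative than a plain TryLock.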
@@ -73,7 +74,7 @@ func NewDbClient(ctx context.Context, connectionString string, opts ...ClientOpt
 		parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
 		sessions:                make(map[uint32]*db_common.DatabaseSession),
 		sessionsMutex:           &sync.Mutex{},
-		searchPathMutex:         &sync.Mutex{},
+		searchPathMutex:         &sync.RWMutex{},
 		connectionString:        connectionString,
 	}

@@ -152,6 +153,37 @@ func (c *DbClient) shouldFetchVerboseTiming() bool {
 		(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
 }

+// lockSessions acquires the sessionsMutex and tracks ownership for tryLock compatibility.
+func (c *DbClient) lockSessions() {
+	if c.sessionsMutex == nil {
+		return
+	}
+	c.sessionsLockFlag.Store(true)
+	c.sessionsMutex.Lock()
+}
+
+// sessionsTryLock attempts to acquire the sessionsMutex without blocking.
+// Returns false if the lock is already held.
+func (c *DbClient) sessionsTryLock() bool {
+	if c.sessionsMutex == nil {
+		return false
+	}
+	// best-effort: only one contender sets the flag and proceeds to lock
+	if !c.sessionsLockFlag.CompareAndSwap(false, true) {
+		return false
+	}
+	c.sessionsMutex.Lock()
+	return true
+}
+
+func (c *DbClient) sessionsUnlock() {
+	if c.sessionsMutex == nil {
+		return
+	}
+	c.sessionsMutex.Unlock()
+	c.sessionsLockFlag.Store(false)
+}
+
 // ServerSettings returns the settings of the steampipe service that this DbClient is connected to
 //
 // Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected.
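How these helpers are meant to compose (an illustrative sketch; evictSession is hypothetical): blocking paths take lockSessions/sessionsUnlock, while shutdown-sensitive paths take sessionsTryLock and skip the work on contention. Note the deliberate asymmetry: sessionsUnlock releases the mutex before clearing the flag, so sessionsTryLock can report failure while the mutex is momentarily free. That is acceptable for best-effort cleanup, but this is not a general-purpose try-lock.

// evictSession is a hypothetical caller showing the intended pattern.
func (c *DbClient) evictSession(pid uint32) {
	if !c.sessionsTryLock() {
		// another goroutine owns the sessions map; skip rather than block
		return
	}
	defer c.sessionsUnlock()
	if c.sessions != nil {
		delete(c.sessions, pid)
	}
}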
@@ -173,9 +205,9 @@ func (c *DbClient) Close(context.Context) error {
 	// nullify active sessions, since with the closing of the pools
 	// none of the sessions will be valid anymore
 	// Acquire mutex to prevent concurrent access to sessions map
-	c.sessionsMutex.Lock()
+	c.lockSessions()
 	c.sessions = nil
-	c.sessionsMutex.Unlock()
+	c.sessionsUnlock()

 	return nil
 }
@@ -66,12 +66,14 @@ func (c *DbClient) establishConnectionPool(ctx context.Context, overrides client
 	config.BeforeClose = func(conn *pgx.Conn) {
 		if conn != nil && conn.PgConn() != nil {
 			backendPid := conn.PgConn().PID()
-			c.sessionsMutex.Lock()
+			// Best-effort cleanup: do not block pool.Close() if sessions lock is busy.
+			if c.sessionsTryLock() {
 			// Check if sessions map has been nil'd by Close()
 			if c.sessions != nil {
 				delete(c.sessions, backendPid)
 			}
-			c.sessionsMutex.Unlock()
+			c.sessionsUnlock()
+			}
 		}
 	}
 	// set an app name so that we can track database connections from this Steampipe execution
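The hunk above is the core of the fix. The failure mode it removes, per the test added later in this commit, looks roughly like this (illustrative shape, not Steampipe code): one goroutine holds the sessions mutex while the pool is shutting down, every connection teardown runs a BeforeClose hook that blocks on that same mutex, and pool.Close() cannot complete.

package main

import (
	"sync"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

var sessionsMu sync.Mutex

// configureHook shows the blocking shape removed by this commit: if another
// goroutine holds sessionsMu for the duration of shutdown, connection
// teardown blocks here and pool.Close() hangs.
func configureHook(config *pgxpool.Config) {
	config.BeforeClose = func(conn *pgx.Conn) {
		sessionsMu.Lock() // may wait forever during shutdown
		defer sessionsMu.Unlock()
		// ... remove the session entry for this connection ...
	}
}

func main() {}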
@@ -78,8 +78,8 @@ func (c *DbClient) loadUserSearchPath(ctx context.Context, connection *pgx.Conn)

 // GetRequiredSessionSearchPath implements Client
 func (c *DbClient) GetRequiredSessionSearchPath() []string {
-	c.searchPathMutex.Lock()
-	defer c.searchPathMutex.Unlock()
+	c.searchPathMutex.RLock()
+	defer c.searchPathMutex.RUnlock()

 	if c.customSearchPath != nil {
 		return c.customSearchPath
@@ -89,8 +89,8 @@ func (c *DbClient) GetRequiredSessionSearchPath() []string {
 }

 func (c *DbClient) GetCustomSearchPath() []string {
-	c.searchPathMutex.Lock()
-	defer c.searchPathMutex.Unlock()
+	c.searchPathMutex.RLock()
+	defer c.searchPathMutex.RUnlock()

 	return c.customSearchPath
 }
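The switch from Lock to RLock in these getters is what lets concurrent query setup proceed: any number of readers can hold an RWMutex simultaneously, and only writers serialize. A minimal standalone illustration of that property:

package main

import (
	"fmt"
	"sync"
)

var (
	pathMu     sync.RWMutex
	searchPath = []string{"public"}
)

func readPath() []string {
	pathMu.RLock() // many goroutines may hold the read lock at once
	defer pathMu.RUnlock()
	return searchPath
}

func writePath(p []string) {
	pathMu.Lock() // writers exclude readers and other writers
	defer pathMu.Unlock()
	searchPath = p
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = readPath() // reads do not serialize against each other
		}()
	}
	wg.Wait()
	writePath([]string{"public", "internal"})
	fmt.Println(readPath())
}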
@@ -37,10 +37,10 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common
 	}
 	backendPid := databaseConnection.Conn().PgConn().PID()

-	c.sessionsMutex.Lock()
+	c.lockSessions()
 	// Check if client has been closed (sessions set to nil)
 	if c.sessions == nil {
-		c.sessionsMutex.Unlock()
+		c.sessionsUnlock()
 		sessionResult.Error = fmt.Errorf("client has been closed")
 		return sessionResult
 	}
@@ -52,7 +52,7 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common
 	// we get a new *sql.Conn everytime. USE IT!
 	session.Connection = databaseConnection
 	sessionResult.Session = session
-	c.sessionsMutex.Unlock()
+	c.sessionsUnlock()

 	// make sure that we close the acquired session, in case of error
 	defer func() {
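AcquireSession re-checks c.sessions after taking the lock because Close() nils the map under the same lock, which turns the map itself into a cheap "closed" sentinel. Any other writer to the map should follow the same shape; a hypothetical helper to show the idiom:

// registerSession is a hypothetical helper illustrating the check-under-lock idiom.
func (c *DbClient) registerSession(pid uint32, session *db_common.DatabaseSession) error {
	c.lockSessions()
	defer c.sessionsUnlock()
	// Close() sets sessions = nil under this lock, so a nil map means the
	// client has shut down and no new sessions may be registered.
	if c.sessions == nil {
		return fmt.Errorf("client has been closed")
	}
	c.sessions[pid] = session
	return nil
}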
@@ -2,10 +2,13 @@ package db_client

 import (
 	"context"
+	"os"
+	"strings"
 	"sync"
 	"testing"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 	"github.com/turbot/steampipe/v2/pkg/db/db_common"
 )

@@ -160,6 +163,28 @@ func TestDbClient_SearchPathUpdates(t *testing.T) {
 	assert.Len(t, client.customSearchPath, 2, "Should have 2 schemas in search path")
 }

+// TestSearchPathAccessShouldUseReadLocks checks that search path access does not block other goroutines unnecessarily.
+//
+// Holding an exclusive mutex during search-path reads in concurrent query setup can deadlock when
+// another goroutine is setting the path. The current code uses Lock/Unlock; this test documents
+// the expectation to move to a read/non-blocking lock so concurrent reads are safe.
+func TestSearchPathAccessShouldUseReadLocks(t *testing.T) {
+	content, err := os.ReadFile("db_client_search_path.go")
+	require.NoError(t, err, "should be able to read db_client_search_path.go")

+	source := string(content)

+	assert.Contains(t, source, "GetRequiredSessionSearchPath", "getter must exist")
+	assert.Contains(t, source, "searchPathMutex", "getter must guard access to searchPath state")

+	// Expect a read or non-blocking lock in getters; fail if only full Lock/Unlock is present.
+	hasRLock := strings.Contains(source, "RLock")
+	hasTry := strings.Contains(source, "TryLock") || strings.Contains(source, "tryLock")
+	if !hasRLock && !hasTry {
+		t.Fatalf("GetRequiredSessionSearchPath should avoid exclusive Lock/Unlock to prevent deadlocks under concurrent query setup")
+	}
+}
+
 // TestDbClient_SessionConnectionNilSafety verifies handling of nil connections
 func TestDbClient_SessionConnectionNilSafety(t *testing.T) {
 	session := db_common.NewDBSession(12345)
@@ -181,7 +206,7 @@ func TestDbClient_SessionSearchPathUpdatesThreadSafe(t *testing.T) {
 	client := &DbClient{
 		customSearchPath: []string{"public", "internal"},
 		userSearchPath:   []string{"public"},
-		searchPathMutex:  &sync.Mutex{},
+		searchPathMutex:  &sync.RWMutex{},
 	}

 	// Number of concurrent operations to test
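The tests above assert on the source text, which guards the pattern but not the behavior. A complementary runtime check, sketched here as a hypothetical addition, is to hammer the getters from many goroutines and let the race detector verify the RLock-based access:

// TestConcurrentSearchPathReads is a hypothetical companion test; run with
// `go test -race` so the race detector can flag unsynchronized access.
func TestConcurrentSearchPathReads(t *testing.T) {
	client := &DbClient{
		customSearchPath: []string{"public", "internal"},
		userSearchPath:   []string{"public"},
		searchPathMutex:  &sync.RWMutex{},
	}

	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(2)
		go func() {
			defer wg.Done()
			_ = client.GetRequiredSessionSearchPath()
		}()
		go func() {
			defer wg.Done()
			_ = client.GetCustomSearchPath()
		}()
	}
	wg.Wait()
}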
@@ -52,6 +52,36 @@ func TestSessionMapCleanupImplemented(t *testing.T) {
 		"Comment should document automatic cleanup mechanism")
 }

+// TestBeforeCloseCleanupShouldBeNonBlocking ensures the cleanup hook does not take a blocking lock.
+//
+// A blocking mutex in the BeforeClose hook can deadlock pool.Close() when another goroutine
+// holds sessionsMutex (service stop/restart hangs). This test is intentionally strict and
+// will fail until the hook uses a non-blocking strategy (e.g., TryLock or similar).
+func TestBeforeCloseCleanupShouldBeNonBlocking(t *testing.T) {
+	content, err := os.ReadFile("db_client_connect.go")
+	require.NoError(t, err, "should be able to read db_client_connect.go")

+	source := string(content)

+	// Guardrail: the BeforeClose hook should avoid unconditionally blocking on sessionsMutex.
+	assert.Contains(t, source, "config.BeforeClose", "BeforeClose cleanup hook must exist")
+	assert.Contains(t, source, "sessionsTryLock", "BeforeClose cleanup should use non-blocking lock helper")

+	// Expect a non-blocking lock pattern; if we only find Lock()/Unlock, this fails.
+	nonBlockingPatterns := []string{"TryLock", "tryLock", "non-block", "select {"}
+	foundNonBlocking := false
+	for _, p := range nonBlockingPatterns {
+		if strings.Contains(source, p) {
+			foundNonBlocking = true
+			break
+		}
+	}

+	if !foundNonBlocking {
+		t.Fatalf("BeforeClose cleanup appears to take a blocking lock on sessionsMutex; add a non-blocking guard to prevent pool.Close deadlocks")
+	}
+}
+
 // TestDbClient_Close_Idempotent verifies that calling Close() multiple times does not cause issues
 // Reference: Similar to bug #4712 (Result.Close() idempotency)
 //
@@ -284,13 +314,14 @@ func TestDbClient_SessionsMutexProtectsMap(t *testing.T) {

 	sourceCode := string(content)

-	// Count occurrences of mutex locks
-	mutexLocks := strings.Count(sourceCode, "c.sessionsMutex.Lock()")
+	// Count occurrences of mutex lock helpers
+	mutexLocks := strings.Count(sourceCode, "lockSessions()") +
+		strings.Count(sourceCode, "sessionsTryLock()")

 	// This is a heuristic check - in practice, we'd need more sophisticated analysis
 	// But it serves as a reminder to use the mutex
 	assert.True(t, mutexLocks > 0,
-		"sessionsMutex.Lock() should be used when accessing sessions map")
+		"sessions lock helpers should be used when accessing sessions map")
 }

 // TestDbClient_SessionMapDocumentation verifies that session lifecycle is documented
@@ -51,7 +51,29 @@ type PluginManager struct {
 	// map of max cache size, keyed by plugin instance
 	pluginCacheSizeMap map[string]int64

-	// map lock
+	// mut protects concurrent access to plugin manager state (runningPluginMap, connectionConfigMap, etc.)
+	//
+	// LOCKING PATTERN TO PREVENT DEADLOCKS:
+	// - Functions that acquire mut.Lock() and call other methods MUST only call *Internal versions
+	// - Public methods that need locking: acquire lock → call internal version → release lock
+	// - Internal methods: assume caller holds lock, never acquire lock themselves
+	//
+	// Example:
+	//   func (m *PluginManager) SomeMethod() {
+	//       m.mut.Lock()
+	//       defer m.mut.Unlock()
+	//       return m.someMethodInternal()
+	//   }
+	//   func (m *PluginManager) someMethodInternal() {
+	//       // NOTE: caller must hold m.mut lock
+	//       // ... implementation without locking ...
+	//   }
+	//
+	// Functions with internal/external versions:
+	// - refreshRateLimiterTable / refreshRateLimiterTableInternal
+	// - updateRateLimiterStatus / updateRateLimiterStatusInternal
+	// - setRateLimiters / setRateLimitersInternal
+	// - getPluginsWithChangedLimiters / getPluginsWithChangedLimitersInternal
 	mut sync.RWMutex

 	// shutdown synchronization
@@ -231,23 +253,32 @@ func (m *PluginManager) doRefresh() {

 // OnConnectionConfigChanged is the callback function invoked by the connection watcher when the config changed
 func (m *PluginManager) OnConnectionConfigChanged(ctx context.Context, configMap connection.ConnectionConfigMap, plugins map[string]*plugin.Plugin) {
+	log.Printf("[DEBUG] OnConnectionConfigChanged: acquiring lock")
 	m.mut.Lock()
 	defer m.mut.Unlock()
+	log.Printf("[DEBUG] OnConnectionConfigChanged: lock acquired")

 	log.Printf("[TRACE] OnConnectionConfigChanged: connections: %s plugin instances: %s", strings.Join(utils.SortedMapKeys(configMap), ","), strings.Join(utils.SortedMapKeys(plugins), ","))

+	log.Printf("[DEBUG] OnConnectionConfigChanged: calling handleConnectionConfigChanges")
 	if err := m.handleConnectionConfigChanges(ctx, configMap); err != nil {
 		log.Printf("[WARN] handleConnectionConfigChanges failed: %s", err.Error())
 	}
+	log.Printf("[DEBUG] OnConnectionConfigChanged: handleConnectionConfigChanges complete")

 	// update our plugin configs
+	log.Printf("[DEBUG] OnConnectionConfigChanged: calling handlePluginInstanceChanges")
 	if err := m.handlePluginInstanceChanges(ctx, plugins); err != nil {
 		log.Printf("[WARN] handlePluginInstanceChanges failed: %s", err.Error())
 	}
+	log.Printf("[DEBUG] OnConnectionConfigChanged: handlePluginInstanceChanges complete")

+	log.Printf("[DEBUG] OnConnectionConfigChanged: calling handleUserLimiterChanges")
 	if err := m.handleUserLimiterChanges(ctx, plugins); err != nil {
 		log.Printf("[WARN] handleUserLimiterChanges failed: %s", err.Error())
 	}
+	log.Printf("[DEBUG] OnConnectionConfigChanged: handleUserLimiterChanges complete")
+	log.Printf("[DEBUG] OnConnectionConfigChanged: about to release lock and return")
 }

 func (m *PluginManager) GetConnectionConfig() connection.ConnectionConfigMap {
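The *Internal convention documented in the struct comment exists because Go locks are not reentrant: sync.Mutex and sync.RWMutex have no owner tracking, so a method that holds m.mut and calls a public method that locks m.mut again deadlocks on itself. A stripped-down illustration with a hypothetical type:

package main

import "sync"

type manager struct {
	mu    sync.Mutex
	state int
}

// Refresh is the public entry point: it owns the locking.
func (m *manager) Refresh() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.refreshInternal()
}

// refreshInternal assumes the caller holds m.mu and never locks itself.
func (m *manager) refreshInternal() {
	// Calling m.Refresh() from here would block forever: the second
	// m.mu.Lock() waits on a lock this goroutine already holds.
	m.state++
}

func main() {
	(&manager{}).Refresh()
}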
@@ -776,14 +807,19 @@ func (m *PluginManager) setCacheOptions(pluginClient *sdkgrpc.PluginClient) erro
 }

 func (m *PluginManager) setRateLimiters(pluginInstance string, pluginClient *sdkgrpc.PluginClient) error {
+	m.mut.RLock()
+	defer m.mut.RUnlock()
+	return m.setRateLimitersInternal(pluginInstance, pluginClient)
+}
+
+func (m *PluginManager) setRateLimitersInternal(pluginInstance string, pluginClient *sdkgrpc.PluginClient) error {
+	// NOTE: caller must hold m.mut lock (at least RLock)
 	log.Printf("[INFO] setRateLimiters for plugin '%s'", pluginInstance)
 	var defs []*sdkproto.RateLimiterDefinition

-	m.mut.RLock()
 	for _, l := range m.userLimiters[pluginInstance] {
 		defs = append(defs, RateLimiterAsProto(l))
 	}
-	m.mut.RUnlock()

 	req := &sdkproto.SetRateLimitersRequest{Definitions: defs}

@@ -29,6 +29,9 @@ func (m *PluginManager) ShouldFetchRateLimiterDefs() bool {
 // update the stored limiters, refrresh the rate limiter table and call `setRateLimiters`
 // for all plugins with changed limiters
 func (m *PluginManager) HandlePluginLimiterChanges(newLimiters connection.PluginLimiterMap) error {
+	m.mut.Lock()
+	defer m.mut.Unlock()
+
 	if m.pluginLimiters == nil {
 		// this must be the first time we have populated them
 		m.pluginLimiters = make(connection.PluginLimiterMap)
@@ -38,13 +41,22 @@ func (m *PluginManager) HandlePluginLimiterChanges(newLimiters connection.Plugin
 	}

 	// update the steampipe_plugin_limiters table
-	if err := m.refreshRateLimiterTable(context.Background()); err != nil {
+	// NOTE: we hold m.mut lock, so call internal version
+	if err := m.refreshRateLimiterTableInternal(context.Background()); err != nil {
 		log.Println("[WARN] could not refresh rate limiter table", err)
 	}
 	return nil
 }

 func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
+	m.mut.Lock()
+	defer m.mut.Unlock()
+	return m.refreshRateLimiterTableInternal(ctx)
+}
+
+func (m *PluginManager) refreshRateLimiterTableInternal(ctx context.Context) error {
+	// NOTE: caller must hold m.mut lock
+
 	// if we have not yet populated the rate limiter table, do nothing
 	if m.pluginLimiters == nil {
 		return nil
@@ -56,7 +68,7 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
 	}

 	// update the status of the plugin rate limiters (determine which are overriden and set state accordingly)
-	m.updateRateLimiterStatus()
+	m.updateRateLimiterStatusInternal()

 	queries := []db_common.QueryWithArgs{
 		introspection.GetRateLimiterTableDropSql(),
@@ -70,13 +82,12 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
 		}
 	}

-	m.mut.RLock()
+	// NOTE: no lock needed here, caller already holds m.mut
 	for _, limitersForPlugin := range m.userLimiters {
 		for _, l := range limitersForPlugin {
 			queries = append(queries, introspection.GetRateLimiterTablePopulateSql(l))
 		}
 	}
-	m.mut.RUnlock()

 	conn, err := m.pool.Acquire(ctx)
 	if err != nil {
@@ -92,30 +103,42 @@ func (m *PluginManager) refreshRateLimiterTable(ctx context.Context) error {
 // update the stored limiters, refresh the rate limiter table and call `setRateLimiters`
 // for all plugins with changed limiters
 func (m *PluginManager) handleUserLimiterChanges(_ context.Context, plugins connection.PluginMap) error {
+	log.Printf("[DEBUG] handleUserLimiterChanges: start")
 	limiterPluginMap := plugins.ToPluginLimiterMap()
-	pluginsWithChangedLimiters := m.getPluginsWithChangedLimiters(limiterPluginMap)
+	log.Printf("[DEBUG] handleUserLimiterChanges: got limiter plugin map")
+	// NOTE: caller (OnConnectionConfigChanged) already holds m.mut lock, so use internal version
+	pluginsWithChangedLimiters := m.getPluginsWithChangedLimitersInternal(limiterPluginMap)
+	log.Printf("[DEBUG] handleUserLimiterChanges: found %d plugins with changed limiters", len(pluginsWithChangedLimiters))

 	if len(pluginsWithChangedLimiters) == 0 {
+		log.Printf("[DEBUG] handleUserLimiterChanges: no changes, returning")
 		return nil
 	}

 	// update stored limiters to the new map
-	m.mut.Lock()
+	// NOTE: caller (OnConnectionConfigChanged) already holds m.mut lock, so we don't lock here
+	log.Printf("[DEBUG] handleUserLimiterChanges: updating user limiters")
 	m.userLimiters = limiterPluginMap
-	m.mut.Unlock()

 	// update the steampipe_plugin_limiters table
-	if err := m.refreshRateLimiterTable(context.Background()); err != nil {
+	// NOTE: caller already holds m.mut lock, so call internal version
+	log.Printf("[DEBUG] handleUserLimiterChanges: calling refreshRateLimiterTableInternal")
+	if err := m.refreshRateLimiterTableInternal(context.Background()); err != nil {
 		log.Println("[WARN] could not refresh rate limiter table", err)
 	}
+	log.Printf("[DEBUG] handleUserLimiterChanges: refreshRateLimiterTableInternal complete")

 	// now update the plugins - call setRateLimiters for any plugin with updated user limiters
+	log.Printf("[DEBUG] handleUserLimiterChanges: setting rate limiters for plugins")
 	for p := range pluginsWithChangedLimiters {
+		log.Printf("[DEBUG] handleUserLimiterChanges: calling setRateLimitersForPlugin for %s", p)
 		if err := m.setRateLimitersForPlugin(p); err != nil {
 			return err
 		}
+		log.Printf("[DEBUG] handleUserLimiterChanges: setRateLimitersForPlugin complete for %s", p)
 	}

+	log.Printf("[DEBUG] handleUserLimiterChanges: complete")
 	return nil
 }

@@ -138,17 +161,22 @@ func (m *PluginManager) setRateLimitersForPlugin(pluginShortName string) error {
 		return sperr.WrapWithMessage(err, "failed to create a plugin client when updating the rate limiter for plugin '%s'", imageRef)
 	}

-	if err := m.setRateLimiters(pluginShortName, pluginClient); err != nil {
+	// NOTE: caller (handleUserLimiterChanges via OnConnectionConfigChanged) already holds m.mut lock
+	if err := m.setRateLimitersInternal(pluginShortName, pluginClient); err != nil {
 		return sperr.WrapWithMessage(err, "failed to update rate limiters for plugin '%s'", imageRef)
 	}
 	return nil
 }

 func (m *PluginManager) getPluginsWithChangedLimiters(newLimiters connection.PluginLimiterMap) map[string]struct{} {
-	var pluginsWithChangedLimiters = make(map[string]struct{})

 	m.mut.RLock()
 	defer m.mut.RUnlock()
+	return m.getPluginsWithChangedLimitersInternal(newLimiters)
+}
+
+func (m *PluginManager) getPluginsWithChangedLimitersInternal(newLimiters connection.PluginLimiterMap) map[string]struct{} {
+	// NOTE: caller must hold m.mut lock (at least RLock)
+	var pluginsWithChangedLimiters = make(map[string]struct{})

 	for plugin, limitersForPlugin := range m.userLimiters {
 		newLimitersForPlugin := newLimiters[plugin]
@@ -169,7 +197,11 @@ func (m *PluginManager) getPluginsWithChangedLimiters(newLimiters connection.Plu
 func (m *PluginManager) updateRateLimiterStatus() {
 	m.mut.Lock()
 	defer m.mut.Unlock()
+	m.updateRateLimiterStatusInternal()
+}
+
+func (m *PluginManager) updateRateLimiterStatusInternal() {
+	// NOTE: caller must hold m.mut lock
 	// iterate through limiters for each plug
 	for p, pluginDefinedLimiters := range m.pluginLimiters {
 		// get user limiters for this plugin (already holding lock, so call internal version)