mirror of
https://github.com/turbot/steampipe.git
synced 2025-12-19 18:12:43 -05:00
* Add tests demonstrating bug #4793: Close() sets sessions=nil without mutex These tests demonstrate the race condition where Close() sets c.sessions to nil without holding the mutex, while AcquireSession() tries to access the map with the mutex held. Running with -race detects the data race and the test panics with "assignment to entry in nil map". 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix #4793: Protect sessions map access with mutex in Close() Acquire sessionsMutex before setting sessions to nil in Close() to prevent data race with AcquireSession(). Also add nil check in AcquireSession() to handle the case where Close() has been called. This prevents the panic "assignment to entry in nil map" when Close() and AcquireSession() are called concurrently. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
322 lines
10 KiB
Go
322 lines
10 KiB
Go
package db_client
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/spf13/viper"
|
|
pconstants "github.com/turbot/pipe-fittings/v2/constants"
|
|
"github.com/turbot/pipe-fittings/v2/utils"
|
|
"github.com/turbot/steampipe/v2/pkg/constants"
|
|
"github.com/turbot/steampipe/v2/pkg/db/db_common"
|
|
"github.com/turbot/steampipe/v2/pkg/serversettings"
|
|
"github.com/turbot/steampipe/v2/pkg/steampipeconfig"
|
|
"golang.org/x/exp/maps"
|
|
"golang.org/x/sync/semaphore"
|
|
)
|
|
|
|
// DbClient wraps a pair of pgx connection pools (one for user queries, one for
// management queries) and gives an interface to the database.
//
// The struct holds an atomic.Bool by value, so it should be handled through a
// pointer (as NewDbClient returns it) rather than copied.
type DbClient struct {
	// the connection string passed to NewDbClient
	connectionString string

	// connection pool for user initiated queries
	userPool *pgxpool.Pool

	// connection pool used to run system/plumbing queries (connection state, server settings)
	managementPool *pgxpool.Pool

	// the settings of the server that this client is connected to
	// (left nil when the server settings table does not exist - see loadServerSettings)
	serverSettings *db_common.ServerSettings

	// this flag is set if the service that this client
	// is connected to is running in the same physical system
	isLocalService bool

	// concurrency management for db session access:
	// a weighted semaphore limiting the number of parallel session initializations
	parallelSessionInitLock *semaphore.Weighted

	// map of database sessions, keyed to the backend_pid in postgres
	// used to update session search path where necessary
	// Session lifecycle: entries are added when connections are established and automatically
	// removed via a pgxpool BeforeClose callback when connections are closed by the pool.
	// This prevents memory accumulation from stale connection entries (see issue #3737)
	// NOTE: Close sets this map to nil, so it must be nil-checked (under sessionsMutex) before writing.
	sessions map[uint32]*db_common.DatabaseSession

	// allows locked access to the 'sessions' map
	sessionsMutex *sync.Mutex

	// if a custom search path or a prefix is used, store it here
	customSearchPath []string
	searchPathPrefix []string

	// the default user search path
	userSearchPath []string

	// disable timing - set whilst in process of querying the timing
	// (checked by shouldFetchTiming)
	disableTiming atomic.Bool

	// NOTE(review): DbConnectionCallback is declared elsewhere; presumably this is
	// invoked when a new connection is established - confirm against establishConnectionPool
	onConnectionCallback DbConnectionCallback
}
|
|
|
|
func NewDbClient(ctx context.Context, connectionString string, opts ...ClientOption) (_ *DbClient, err error) {
|
|
utils.LogTime("db_client.NewDbClient start")
|
|
defer utils.LogTime("db_client.NewDbClient end")
|
|
|
|
client := &DbClient{
|
|
// a weighted semaphore to control the maximum number parallel
|
|
// initializations under way
|
|
parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
|
|
sessions: make(map[uint32]*db_common.DatabaseSession),
|
|
sessionsMutex: &sync.Mutex{},
|
|
connectionString: connectionString,
|
|
}
|
|
|
|
defer func() {
|
|
if err != nil {
|
|
// try closing the client
|
|
client.Close(ctx)
|
|
}
|
|
}()
|
|
|
|
config := clientConfig{}
|
|
for _, o := range opts {
|
|
o(&config)
|
|
}
|
|
|
|
if err := client.establishConnectionPool(ctx, config); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// load up the server settings
|
|
if err := client.loadServerSettings(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// set user search path
|
|
if err := client.LoadUserSearchPath(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// populate customSearchPath
|
|
if err := client.SetRequiredSessionSearchPath(ctx); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (c *DbClient) closePools() {
|
|
if c.userPool != nil {
|
|
c.userPool.Close()
|
|
}
|
|
if c.managementPool != nil {
|
|
c.managementPool.Close()
|
|
}
|
|
}
|
|
|
|
func (c *DbClient) loadServerSettings(ctx context.Context) error {
|
|
serverSettings, err := serversettings.Load(ctx, c.managementPool)
|
|
if err != nil {
|
|
if notFound := db_common.IsRelationNotFoundError(err); notFound {
|
|
// when connecting to pre-0.21.0 services, the steampipe_server_settings table will not be available.
|
|
// this is expected and not an error
|
|
// code which uses steampipe_server_settings should handle this
|
|
log.Printf("[TRACE] could not find %s.%s table. skipping\n", constants.InternalSchema, constants.ServerSettingsTable)
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
c.serverSettings = serverSettings
|
|
log.Println("[TRACE] loaded server settings:", serverSettings)
|
|
return nil
|
|
}
|
|
|
|
func (c *DbClient) shouldFetchTiming() bool {
|
|
// check for override flag (this is to prevent timing being fetched when we read the timing metadata table)
|
|
if c.disableTiming.Load() {
|
|
return false
|
|
}
|
|
// only fetch timing if timing flag is set, or output is JSON
|
|
return (viper.GetString(pconstants.ArgTiming) != pconstants.ArgOff) ||
|
|
(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
|
|
|
|
}
|
|
func (c *DbClient) shouldFetchVerboseTiming() bool {
|
|
return (viper.GetString(pconstants.ArgTiming) == pconstants.ArgVerbose) ||
|
|
(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
|
|
}
|
|
|
|
// ServerSettings returns the settings of the steampipe service that this DbClient is connected to
|
|
//
|
|
// Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected.
|
|
// Code which read server_settings should take this into account.
|
|
func (c *DbClient) ServerSettings() *db_common.ServerSettings {
|
|
return c.serverSettings
|
|
}
|
|
|
|
// RegisterNotificationListener has an empty implementation
|
|
// NOTE: we do not (currently) support notifications from remote connections
|
|
func (c *DbClient) RegisterNotificationListener(func(notification *pgconn.Notification)) {}
|
|
|
|
// Close implements Client
|
|
|
|
// closes the connection to the database and shuts down the backend
|
|
func (c *DbClient) Close(context.Context) error {
|
|
log.Printf("[TRACE] DbClient.Close %v", c.userPool)
|
|
c.closePools()
|
|
// nullify active sessions, since with the closing of the pools
|
|
// none of the sessions will be valid anymore
|
|
// Acquire mutex to prevent concurrent access to sessions map
|
|
c.sessionsMutex.Lock()
|
|
c.sessions = nil
|
|
c.sessionsMutex.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetSchemaFromDB retrieves schemas for all steampipe connections (EXCEPT DISABLED CONNECTIONS)
|
|
// NOTE: it optimises the schema extraction by extracting schema information for
|
|
// connections backed by distinct plugins and then fanning back out.
|
|
func (c *DbClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) {
|
|
log.Printf("[INFO] DbClient GetSchemaFromDB")
|
|
mgmtConn, err := c.managementPool.Acquire(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer mgmtConn.Release()
|
|
|
|
// for optimisation purposes, try to load connection state and build a map of schemas to load
|
|
// (if we are connected to a remote server running an older CLI,
|
|
// this load may fail, in which case bypass the optimisation)
|
|
connectionStateMap, err := steampipeconfig.LoadConnectionState(ctx, mgmtConn.Conn(), steampipeconfig.WithWaitUntilLoading())
|
|
// NOTE: if we failed to load connection state, this may be because we are connected to an older version of the CLI
|
|
// use legacy (v0.19.x) schema loading code
|
|
if err != nil {
|
|
return c.GetSchemaFromDBLegacy(ctx, mgmtConn)
|
|
}
|
|
|
|
// build a ConnectionSchemaMap object to identify the schemas to load
|
|
connectionSchemaMap := steampipeconfig.NewConnectionSchemaMap(ctx, connectionStateMap, c.GetRequiredSessionSearchPath())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// get the unique schema - we use this to limit the schemas we load from the database
|
|
schemas := maps.Keys(connectionSchemaMap)
|
|
|
|
// build a query to retrieve these schemas
|
|
query := c.buildSchemasQuery(schemas...)
|
|
|
|
// build schema metadata from query result
|
|
metadata, err := db_common.LoadSchemaMetadata(ctx, mgmtConn.Conn(), query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// we now need to add in all other schemas which have the same schemas as those we have loaded
|
|
for loadedSchema, otherSchemas := range connectionSchemaMap {
|
|
// all 'otherSchema's have the same schema as loadedSchema
|
|
exemplarSchema, ok := metadata.Schemas[loadedSchema]
|
|
if !ok {
|
|
// should can happen in the case of a dynamic plugin with no tables - use empty schema
|
|
exemplarSchema = make(map[string]db_common.TableSchema)
|
|
}
|
|
|
|
for _, s := range otherSchemas {
|
|
metadata.Schemas[s] = exemplarSchema
|
|
}
|
|
}
|
|
|
|
return metadata, nil
|
|
}
|
|
|
|
func (c *DbClient) GetSchemaFromDBLegacy(ctx context.Context, conn *pgxpool.Conn) (*db_common.SchemaMetadata, error) {
|
|
// build a query to retrieve these schemas
|
|
query := c.buildSchemasQueryLegacy()
|
|
|
|
// build schema metadata from query result
|
|
return db_common.LoadSchemaMetadata(ctx, conn.Conn(), query)
|
|
}
|
|
|
|
// refreshDbClient terminates the current connection and opens up a new connection to the service.
|
|
func (c *DbClient) ResetPools(ctx context.Context) {
|
|
log.Println("[TRACE] db_client.ResetPools start")
|
|
defer log.Println("[TRACE] db_client.ResetPools end")
|
|
|
|
c.userPool.Reset()
|
|
c.managementPool.Reset()
|
|
}
|
|
|
|
func (c *DbClient) buildSchemasQuery(schemas ...string) string {
|
|
for idx, s := range schemas {
|
|
schemas[idx] = fmt.Sprintf("'%s'", s)
|
|
}
|
|
|
|
// build the schemas filter clause
|
|
schemaClause := ""
|
|
if len(schemas) > 0 {
|
|
schemaClause = fmt.Sprintf(`
|
|
cols.table_schema in (%s)
|
|
OR`, strings.Join(schemas, ","))
|
|
}
|
|
|
|
query := fmt.Sprintf(`
|
|
SELECT
|
|
table_name,
|
|
column_name,
|
|
column_default,
|
|
is_nullable,
|
|
data_type,
|
|
udt_name,
|
|
table_schema,
|
|
(COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment,
|
|
(COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment
|
|
FROM
|
|
information_schema.columns cols
|
|
LEFT JOIN
|
|
pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema
|
|
LEFT JOIN
|
|
pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid
|
|
WHERE %s
|
|
LEFT(cols.table_schema,8) = 'pg_temp_'
|
|
`, schemaClause)
|
|
return query
|
|
}
|
|
func (c *DbClient) buildSchemasQueryLegacy() string {
|
|
|
|
query := `
|
|
WITH distinct_schema AS (
|
|
SELECT DISTINCT(foreign_table_schema)
|
|
FROM
|
|
information_schema.foreign_tables
|
|
WHERE
|
|
foreign_table_schema <> 'steampipe_command'
|
|
)
|
|
SELECT
|
|
table_name,
|
|
column_name,
|
|
column_default,
|
|
is_nullable,
|
|
data_type,
|
|
udt_name,
|
|
table_schema,
|
|
(COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment,
|
|
(COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment
|
|
FROM
|
|
information_schema.columns cols
|
|
LEFT JOIN
|
|
pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema
|
|
LEFT JOIN
|
|
pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid
|
|
WHERE
|
|
cols.table_schema in (select * from distinct_schema)
|
|
OR
|
|
LEFT(cols.table_schema,8) = 'pg_temp_'
|
|
|
|
`
|
|
return query
|
|
}
|