Files
steampipe/pkg/db/db_client/db_client.go
Nathan Wallace b6aa7fc7d5 Fix #4698: Add nil checks to ResetPools method (#4732)
* Add test for #4698: ResetPools should handle nil pools gracefully

* Fix #4698: Add nil checks to ResetPools method
2025-11-15 11:55:33 -05:00

326 lines
10 KiB
Go

package db_client
import (
"context"
"fmt"
"log"
"strings"
"sync"
"sync/atomic"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/spf13/viper"
pconstants "github.com/turbot/pipe-fittings/v2/constants"
"github.com/turbot/pipe-fittings/v2/utils"
"github.com/turbot/steampipe/v2/pkg/constants"
"github.com/turbot/steampipe/v2/pkg/db/db_common"
"github.com/turbot/steampipe/v2/pkg/serversettings"
"github.com/turbot/steampipe/v2/pkg/steampipeconfig"
"golang.org/x/exp/maps"
"golang.org/x/sync/semaphore"
)
// DbClient wraps over `sql.DB` and gives an interface to the database
type DbClient struct {
connectionString string
// connection userPool for user initiated queries
userPool *pgxpool.Pool
// connection used to run system/plumbing queries (connection state, server settings)
managementPool *pgxpool.Pool
// the settings of the server that this client is connected to
serverSettings *db_common.ServerSettings
// this flag is set if the service that this client
// is connected to is running in the same physical system
isLocalService bool
// concurrency management for db session access
parallelSessionInitLock *semaphore.Weighted
// map of database sessions, keyed to the backend_pid in postgres
// used to update session search path where necessary
// Session lifecycle: entries are added when connections are established and automatically
// removed via a pgxpool BeforeClose callback when connections are closed by the pool.
// This prevents memory accumulation from stale connection entries (see issue #3737)
sessions map[uint32]*db_common.DatabaseSession
// allows locked access to the 'sessions' map
sessionsMutex *sync.Mutex
// if a custom search path or a prefix is used, store it here
customSearchPath []string
searchPathPrefix []string
// the default user search path
userSearchPath []string
// disable timing - set whilst in process of querying the timing
disableTiming atomic.Bool
onConnectionCallback DbConnectionCallback
}
func NewDbClient(ctx context.Context, connectionString string, opts ...ClientOption) (_ *DbClient, err error) {
utils.LogTime("db_client.NewDbClient start")
defer utils.LogTime("db_client.NewDbClient end")
client := &DbClient{
// a weighted semaphore to control the maximum number parallel
// initializations under way
parallelSessionInitLock: semaphore.NewWeighted(constants.MaxParallelClientInits),
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
connectionString: connectionString,
}
defer func() {
if err != nil {
// try closing the client
client.Close(ctx)
}
}()
config := clientConfig{}
for _, o := range opts {
o(&config)
}
if err := client.establishConnectionPool(ctx, config); err != nil {
return nil, err
}
// load up the server settings
if err := client.loadServerSettings(ctx); err != nil {
return nil, err
}
// set user search path
if err := client.LoadUserSearchPath(ctx); err != nil {
return nil, err
}
// populate customSearchPath
if err := client.SetRequiredSessionSearchPath(ctx); err != nil {
return nil, err
}
return client, nil
}
func (c *DbClient) closePools() {
if c.userPool != nil {
c.userPool.Close()
}
if c.managementPool != nil {
c.managementPool.Close()
}
}
func (c *DbClient) loadServerSettings(ctx context.Context) error {
serverSettings, err := serversettings.Load(ctx, c.managementPool)
if err != nil {
if notFound := db_common.IsRelationNotFoundError(err); notFound {
// when connecting to pre-0.21.0 services, the steampipe_server_settings table will not be available.
// this is expected and not an error
// code which uses steampipe_server_settings should handle this
log.Printf("[TRACE] could not find %s.%s table. skipping\n", constants.InternalSchema, constants.ServerSettingsTable)
return nil
}
return err
}
c.serverSettings = serverSettings
log.Println("[TRACE] loaded server settings:", serverSettings)
return nil
}
func (c *DbClient) shouldFetchTiming() bool {
// check for override flag (this is to prevent timing being fetched when we read the timing metadata table)
if c.disableTiming.Load() {
return false
}
// only fetch timing if timing flag is set, or output is JSON
return (viper.GetString(pconstants.ArgTiming) != pconstants.ArgOff) ||
(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
}
func (c *DbClient) shouldFetchVerboseTiming() bool {
return (viper.GetString(pconstants.ArgTiming) == pconstants.ArgVerbose) ||
(viper.GetString(pconstants.ArgOutput) == constants.OutputFormatJSON)
}
// ServerSettings returns the settings of the steampipe service that this DbClient is connected to
//
// Keep in mind that when connecting to pre-0.21.x servers, the server_settings data is not available. This is expected.
// Code which read server_settings should take this into account.
func (c *DbClient) ServerSettings() *db_common.ServerSettings {
return c.serverSettings
}
// RegisterNotificationListener has an empty implementation
// NOTE: we do not (currently) support notifications from remote connections
func (c *DbClient) RegisterNotificationListener(func(notification *pgconn.Notification)) {}
// Close implements Client
// closes the connection to the database and shuts down the backend
func (c *DbClient) Close(context.Context) error {
log.Printf("[TRACE] DbClient.Close %v", c.userPool)
c.closePools()
// nullify active sessions, since with the closing of the pools
// none of the sessions will be valid anymore
// Acquire mutex to prevent concurrent access to sessions map
c.sessionsMutex.Lock()
c.sessions = nil
c.sessionsMutex.Unlock()
return nil
}
// GetSchemaFromDB retrieves schemas for all steampipe connections (EXCEPT DISABLED CONNECTIONS)
// NOTE: it optimises the schema extraction by extracting schema information for
// connections backed by distinct plugins and then fanning back out.
func (c *DbClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) {
log.Printf("[INFO] DbClient GetSchemaFromDB")
mgmtConn, err := c.managementPool.Acquire(ctx)
if err != nil {
return nil, err
}
defer mgmtConn.Release()
// for optimisation purposes, try to load connection state and build a map of schemas to load
// (if we are connected to a remote server running an older CLI,
// this load may fail, in which case bypass the optimisation)
connectionStateMap, err := steampipeconfig.LoadConnectionState(ctx, mgmtConn.Conn(), steampipeconfig.WithWaitUntilLoading())
// NOTE: if we failed to load connection state, this may be because we are connected to an older version of the CLI
// use legacy (v0.19.x) schema loading code
if err != nil {
return c.GetSchemaFromDBLegacy(ctx, mgmtConn)
}
// build a ConnectionSchemaMap object to identify the schemas to load
connectionSchemaMap := steampipeconfig.NewConnectionSchemaMap(ctx, connectionStateMap, c.GetRequiredSessionSearchPath())
if err != nil {
return nil, err
}
// get the unique schema - we use this to limit the schemas we load from the database
schemas := maps.Keys(connectionSchemaMap)
// build a query to retrieve these schemas
query := c.buildSchemasQuery(schemas...)
// build schema metadata from query result
metadata, err := db_common.LoadSchemaMetadata(ctx, mgmtConn.Conn(), query)
if err != nil {
return nil, err
}
// we now need to add in all other schemas which have the same schemas as those we have loaded
for loadedSchema, otherSchemas := range connectionSchemaMap {
// all 'otherSchema's have the same schema as loadedSchema
exemplarSchema, ok := metadata.Schemas[loadedSchema]
if !ok {
// should can happen in the case of a dynamic plugin with no tables - use empty schema
exemplarSchema = make(map[string]db_common.TableSchema)
}
for _, s := range otherSchemas {
metadata.Schemas[s] = exemplarSchema
}
}
return metadata, nil
}
func (c *DbClient) GetSchemaFromDBLegacy(ctx context.Context, conn *pgxpool.Conn) (*db_common.SchemaMetadata, error) {
// build a query to retrieve these schemas
query := c.buildSchemasQueryLegacy()
// build schema metadata from query result
return db_common.LoadSchemaMetadata(ctx, conn.Conn(), query)
}
// refreshDbClient terminates the current connection and opens up a new connection to the service.
func (c *DbClient) ResetPools(ctx context.Context) {
log.Println("[TRACE] db_client.ResetPools start")
defer log.Println("[TRACE] db_client.ResetPools end")
if c.userPool != nil {
c.userPool.Reset()
}
if c.managementPool != nil {
c.managementPool.Reset()
}
}
func (c *DbClient) buildSchemasQuery(schemas ...string) string {
for idx, s := range schemas {
schemas[idx] = fmt.Sprintf("'%s'", s)
}
// build the schemas filter clause
schemaClause := ""
if len(schemas) > 0 {
schemaClause = fmt.Sprintf(`
cols.table_schema in (%s)
OR`, strings.Join(schemas, ","))
}
query := fmt.Sprintf(`
SELECT
table_name,
column_name,
column_default,
is_nullable,
data_type,
udt_name,
table_schema,
(COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment,
(COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment
FROM
information_schema.columns cols
LEFT JOIN
pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema
LEFT JOIN
pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid
WHERE %s
LEFT(cols.table_schema,8) = 'pg_temp_'
`, schemaClause)
return query
}
func (c *DbClient) buildSchemasQueryLegacy() string {
query := `
WITH distinct_schema AS (
SELECT DISTINCT(foreign_table_schema)
FROM
information_schema.foreign_tables
WHERE
foreign_table_schema <> 'steampipe_command'
)
SELECT
table_name,
column_name,
column_default,
is_nullable,
data_type,
udt_name,
table_schema,
(COALESCE(pg_catalog.col_description(c.oid, cols.ordinal_position :: int),'')) as column_comment,
(COALESCE(pg_catalog.obj_description(c.oid),'')) as table_comment
FROM
information_schema.columns cols
LEFT JOIN
pg_catalog.pg_namespace nsp ON nsp.nspname = cols.table_schema
LEFT JOIN
pg_catalog.pg_class c ON c.relname = cols.table_name AND c.relnamespace = nsp.oid
WHERE
cols.table_schema in (select * from distinct_schema)
OR
LEFT(cols.table_schema,8) = 'pg_temp_'
`
return query
}