Use single Steampipe Postgres notification channel (#3191)

This commit is contained in:
kaidaguerre
2023-03-07 17:34:41 +00:00
committed by GitHub
parent 16fd991ee1
commit 075fafec09
16 changed files with 115 additions and 61 deletions

View File

@@ -2,6 +2,8 @@ package cmd
import ( import (
"fmt" "fmt"
"github.com/spf13/viper"
"github.com/turbot/steampipe/pkg/constants/runtime"
"log" "log"
"os" "os"
"path/filepath" "path/filepath"
@@ -27,7 +29,7 @@ func pluginManagerCmd() *cobra.Command {
Run: runPluginManagerCmd, Run: runPluginManagerCmd,
Hidden: true, Hidden: true,
} }
cmdconfig.OnCmd(cmd) cmdconfig.OnCmd(cmd).AddStringFlag(constants.ArgAppName, "", "The app name to use for database connections")
return cmd return cmd
} }
@@ -44,6 +46,12 @@ func runPluginManagerCmd(cmd *cobra.Command, args []string) {
os.Exit(1) os.Exit(1)
} }
// the CLI will pass the Postgress AppName to use for db clients - this is to ensure the CLI does not hold up
// shutting down the DB because of connections we have open (but will close)
if viper.IsSet(constants.ArgAppName) {
runtime.PgClientAppName = viper.GetString(constants.ArgAppName)
}
configMap := connectionwatcher.NewConnectionConfigMap(steampipeConfig.Connections) configMap := connectionwatcher.NewConnectionConfigMap(steampipeConfig.Connections)
log.Printf("[TRACE] loaded config map: %s", strings.Join(steampipeConfig.ConnectionNames(), ",")) log.Printf("[TRACE] loaded config map: %s", strings.Join(steampipeConfig.ConnectionNames(), ","))

View File

@@ -57,8 +57,8 @@ const (
ArgModLocation = "mod-location" ArgModLocation = "mod-location"
ArgSnapshotLocation = "snapshot-location" ArgSnapshotLocation = "snapshot-location"
ArgSnapshotTitle = "snapshot-title" ArgSnapshotTitle = "snapshot-title"
ArgDatabaseStartTimeout = "database-start-timeout" ArgDatabaseStartTimeout = "database-start-timeout"
ArgAppName = "app-name"
) )
// metaquery mode arguments // metaquery mode arguments

View File

@@ -1,5 +1,5 @@
package constants package constants
const ( const (
NotificationConnectionUpdate = "connection_update" PostgresNotificationChannel = "steampipe_notification"
) )

View File

@@ -69,14 +69,14 @@ type CreateDbOptions struct {
DatabaseName, Username string DatabaseName, Username string
} }
// createLocalDbClient connects and returns a connection to the given database using // CreateLocalDbConnection connects and returns a connection to the given database using
// the provided username // the provided username
// if the database is not provided (empty), it connects to the default database in the service // if the database is not provided (empty), it connects to the default database in the service
// that was created during installation. // that was created during installation.
// NOTE: no session data callback is used - no sesison data will be present // NOTE: no session data callback is used - no sesison data will be present
func createLocalDbClient(ctx context.Context, opts *CreateDbOptions) (*pgx.Conn, error) { func CreateLocalDbConnection(ctx context.Context, opts *CreateDbOptions) (*pgx.Conn, error) {
utils.LogTime("db.createLocalDbClient start") utils.LogTime("db.CreateLocalDbConnection start")
defer utils.LogTime("db.createLocalDbClient end") defer utils.LogTime("db.CreateLocalDbConnection end")
psqlInfo, err := getLocalSteampipeConnectionString(opts) psqlInfo, err := getLocalSteampipeConnectionString(opts)
if err != nil { if err != nil {

View File

@@ -8,7 +8,7 @@ import (
) )
func executeSqlAsRoot(ctx context.Context, statements ...string) ([]pgconn.CommandTag, error) { func executeSqlAsRoot(ctx context.Context, statements ...string) ([]pgconn.CommandTag, error) {
rootClient, err := createLocalDbClient(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser}) rootClient, err := CreateLocalDbConnection(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -67,7 +67,7 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker, onConnection
} }
// Close implements Client // Close implements Client
// close the connection to the database and shuts down the backend // close the connection to the database and shuts down the backend if we are the last connection
func (c *LocalDbClient) Close(ctx context.Context) error { func (c *LocalDbClient) Close(ctx context.Context) error {
log.Printf("[TRACE] close local client %p", c) log.Printf("[TRACE] close local client %p", c)
if c.client != nil { if c.client != nil {

View File

@@ -108,7 +108,7 @@ func (c *LocalDbClient) executeConnectionUpdateQueries(ctx context.Context, conn
defer utils.LogTime("db.executeConnectionUpdateQueries start") defer utils.LogTime("db.executeConnectionUpdateQueries start")
res := &steampipeconfig.RefreshConnectionResult{} res := &steampipeconfig.RefreshConnectionResult{}
rootClient, err := createLocalDbClient(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser}) rootClient, err := CreateLocalDbConnection(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser})
if err != nil { if err != nil {
res.Error = err res.Error = err
return res return res
@@ -178,7 +178,8 @@ func executeUpdateQueries(ctx context.Context, rootClient *pgx.Conn, failures []
for _, failure := range failures { for _, failure := range failures {
log.Printf("[TRACE] remove schema for connection failing validation connection %s, plugin Name %s\n ", failure.ConnectionName, failure.Plugin) log.Printf("[TRACE] remove schema for connection failing validation connection %s, plugin Name %s\n ", failure.ConnectionName, failure.Plugin)
if failure.ShouldDropIfExists { if failure.ShouldDropIfExists {
statements := []string{"lock table pg_namespace;", statements := []string{
"lock table pg_namespace;",
getDeleteConnectionQuery(failure.ConnectionName), getDeleteConnectionQuery(failure.ConnectionName),
} }
_, err := executeSqlInTransaction(ctx, rootClient, statements...) _, err := executeSqlInTransaction(ctx, rootClient, statements...)

View File

@@ -79,7 +79,7 @@ func StartServices(ctx context.Context, port int, listen StartListenType, invoke
if res.DbState == nil { if res.DbState == nil {
res = startDB(ctx, port, listen, invoker) res = startDB(ctx, port, listen, invoker)
} else { } else {
rootClient, err := createLocalDbClient(ctx, &CreateDbOptions{DatabaseName: res.DbState.Database, Username: constants.DatabaseSuperUser}) rootClient, err := CreateLocalDbConnection(ctx, &CreateDbOptions{DatabaseName: res.DbState.Database, Username: constants.DatabaseSuperUser})
if err != nil { if err != nil {
res.Error = err res.Error = err
res.Status = ServiceFailedToStart res.Status = ServiceFailedToStart
@@ -221,7 +221,7 @@ func startDB(ctx context.Context, port int, listen StartListenType, invoker cons
} }
func ensureService(ctx context.Context, databaseName string) error { func ensureService(ctx context.Context, databaseName string) error {
rootClient, err := createLocalDbClient(ctx, &CreateDbOptions{DatabaseName: databaseName, Username: constants.DatabaseSuperUser}) rootClient, err := CreateLocalDbConnection(ctx, &CreateDbOptions{DatabaseName: databaseName, Username: constants.DatabaseSuperUser})
if err != nil { if err != nil {
return err return err
} }
@@ -426,7 +426,7 @@ func traceoutServiceLogs(logChannel chan string, stopLogStreamFn func()) {
} }
func setServicePassword(ctx context.Context, password string) error { func setServicePassword(ctx context.Context, password string) error {
connection, err := createLocalDbClient(ctx, &CreateDbOptions{DatabaseName: "postgres", Username: constants.DatabaseSuperUser}) connection, err := CreateLocalDbConnection(ctx, &CreateDbOptions{DatabaseName: "postgres", Username: constants.DatabaseSuperUser})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -99,7 +99,7 @@ func GetClientCount(ctx context.Context) (*ClientCount, error) {
utils.LogTime("db_local.GetClientCount start") utils.LogTime("db_local.GetClientCount start")
defer utils.LogTime(fmt.Sprintf("db_local.GetClientCount end")) defer utils.LogTime(fmt.Sprintf("db_local.GetClientCount end"))
rootClient, err := createLocalDbClient(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser}) rootClient, err := CreateLocalDbConnection(ctx, &CreateDbOptions{Username: constants.DatabaseSuperUser})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -16,13 +16,14 @@ import (
"github.com/alecthomas/chroma/lexers" "github.com/alecthomas/chroma/lexers"
"github.com/alecthomas/chroma/styles" "github.com/alecthomas/chroma/styles"
"github.com/c-bata/go-prompt" "github.com/c-bata/go-prompt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/turbot/go-kit/helpers" "github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/pkg/cmdconfig" "github.com/turbot/steampipe/pkg/cmdconfig"
"github.com/turbot/steampipe/pkg/constants" "github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/db/db_common" "github.com/turbot/steampipe/pkg/db/db_common"
"github.com/turbot/steampipe/pkg/db/db_local"
"github.com/turbot/steampipe/pkg/display" "github.com/turbot/steampipe/pkg/display"
"github.com/turbot/steampipe/pkg/error_helpers" "github.com/turbot/steampipe/pkg/error_helpers"
"github.com/turbot/steampipe/pkg/query" "github.com/turbot/steampipe/pkg/query"
@@ -60,6 +61,7 @@ type InteractiveClient struct {
// this is tied to a context which remaing valid throughout the life of the // this is tied to a context which remaing valid throughout the life of the
// interactive session // interactive session
cancelNotificationListener context.CancelFunc cancelNotificationListener context.CancelFunc
// channel used internally to pass the initialisation result // channel used internally to pass the initialisation result
initResultChan chan *db_common.InitResult initResultChan chan *db_common.InitResult
// flag set when initialisation is complete (with or without errors) // flag set when initialisation is complete (with or without errors)
@@ -629,58 +631,72 @@ func (c *InteractiveClient) startCancelHandler() chan bool {
func (c *InteractiveClient) listenToPgNotifications(ctx context.Context) error { func (c *InteractiveClient) listenToPgNotifications(ctx context.Context) error {
log.Printf("[TRACE] InteractiveClient listenToPgNotifications") log.Printf("[TRACE] InteractiveClient listenToPgNotifications")
conn, err := c.getNotificationConnection(ctx)
for ctx.Err() == nil { for ctx.Err() == nil {
conn, err := c.getNotificationConnection(ctx)
if err != nil { if err != nil {
return err return err
} }
log.Printf("[TRACE] Wait for notification") log.Printf("[TRACE] Wait for notification")
notification, err := conn.Conn().WaitForNotification(ctx) notification, err := conn.WaitForNotification(ctx)
if err != nil && !error_helpers.IsContextCancelledError(err) { if err != nil && !error_helpers.IsContextCancelledError(err) {
log.Printf("[INFO] Error waiting for notification: %s", err) log.Printf("[INFO] Error waiting for notification: %s", err)
} }
conn.Release()
if notification != nil { if notification != nil {
c.handleConnectionUpdateNotification(ctx, notification) c.handlePostgresNotification(ctx, notification)
} }
log.Printf("[TRACE] Handled notification")
} }
log.Printf("[TRACE] InteractiveClient listenToPgNotifications DONE") conn.Close(ctx)
log.Printf("[TRACE] InteractiveClient listenToPgNotifications DONE")
return nil return nil
} }
func (c *InteractiveClient) getNotificationConnection(ctx context.Context) (*pgxpool.Conn, error) { func (c *InteractiveClient) getNotificationConnection(ctx context.Context) (*pgx.Conn, error) {
sessionResult := c.client().AcquireSession(ctx) conn, err := db_local.CreateLocalDbConnection(ctx, &db_local.CreateDbOptions{Username: constants.DatabaseUser})
if err != nil {
if sessionResult.Error != nil { return nil, err
return nil, fmt.Errorf("error acquiring database connection to listenToPgNotifications to notifications, %s", sessionResult.Error.Error())
} }
conn := sessionResult.Session.Connection listenSql := fmt.Sprintf("listen %s", constants.PostgresNotificationChannel)
_, err = conn.Exec(context.Background(), listenSql)
listenSql := fmt.Sprintf("listen %s", constants.NotificationConnectionUpdate)
_, err := conn.Exec(context.Background(), listenSql)
if err != nil { if err != nil {
log.Printf("[INFO] Error listening to schema channel: %s", err) log.Printf("[INFO] Error listening to schema channel: %s", err)
conn.Release() conn.Close(ctx)
return nil, err return nil, err
} }
return conn, nil return conn, nil
} }
func (c *InteractiveClient) handleConnectionUpdateNotification(ctx context.Context, notification *pgconn.Notification) { func (c *InteractiveClient) handlePostgresNotification(ctx context.Context, notification *pgconn.Notification) {
if notification == nil { if notification == nil {
return return
} }
log.Printf("[TRACE] handleConnectionUpdateNotification: %s", notification.Payload) log.Printf("[TRACE] handleConnectionUpdateNotification: %s", notification.Payload)
n := &steampipeconfig.ConnectionUpdateNotification{} n := &steampipeconfig.PostgresNotification{}
err := json.Unmarshal([]byte(notification.Payload), n) err := json.Unmarshal([]byte(notification.Payload), n)
if err != nil { if err != nil {
log.Printf("[INFO] Error unmarshalling notification: %s", err) log.Printf("[INFO] Error unmarshalling notification: %s", err)
return return
} }
switch n.Type {
case steampipeconfig.PgNotificationSchemaUpdate:
// unmarshal the notification again, into the correct type
schemaUpdateNotification := &steampipeconfig.SchemaUpdateNotification{}
if err := json.Unmarshal([]byte(notification.Payload), schemaUpdateNotification); err != nil {
log.Printf("[INFO] Error unmarshalling notification: %s", err)
return
}
c.handleConnectionUpdateNotification(ctx, schemaUpdateNotification)
}
}
func (c *InteractiveClient) handleConnectionUpdateNotification(ctx context.Context, notification *steampipeconfig.SchemaUpdateNotification) {
// at present, we do not actually use the payload, we just do a brute force reload
// as an optimization we could look at the updates and only reload the required schemas
// reload the connection data map // reload the connection data map
// first load foreign schema names // first load foreign schema names

View File

@@ -108,6 +108,9 @@ func (i *InitData) init(parentCtx context.Context, args []string) {
i.Result.AddWarnings(errAndWarnings.Warnings...) i.Result.AddWarnings(errAndWarnings.Warnings...)
i.Workspace = w i.Workspace = w
// set max DB connections to 1
viper.Set(constants.ArgMaxParallel, 1)
statushooks.SetStatus(ctx, "Resolving arguments") statushooks.SetStatus(ctx, "Resolving arguments")
// convert the query or sql file arg into an array of executable queries - check names queries in the current workspace // convert the query or sql file arg into an array of executable queries - check names queries in the current workspace

View File

@@ -1,17 +0,0 @@
package steampipeconfig
import (
"golang.org/x/exp/maps"
)
type ConnectionUpdateNotification struct {
Update []string
Delete []string
}
func NewConnectionUpdateNotification(updates *ConnectionUpdates) *ConnectionUpdateNotification {
return &ConnectionUpdateNotification{
Update: maps.Keys(updates.Update),
Delete: maps.Keys(updates.Delete),
}
}

View File

@@ -2,6 +2,7 @@ package steampipeconfig
import ( import (
"fmt" "fmt"
"golang.org/x/exp/maps"
"log" "log"
"sort" "sort"
"strings" "strings"
@@ -211,6 +212,12 @@ func (u *ConnectionUpdates) String() string {
return op.String() return op.String()
} }
func (u *ConnectionUpdates) AsNotification() *SchemaUpdateNotification {
return NewSchemaUpdateNotification(
maps.Keys(u.Update),
maps.Keys(u.Delete))
}
func getSchemaHashesForDynamicSchemas(requiredConnectionData ConnectionDataMap, connectionState ConnectionDataMap) (map[string]string, map[string]*ConnectionPlugin, error) { func getSchemaHashesForDynamicSchemas(requiredConnectionData ConnectionDataMap, connectionState ConnectionDataMap) (map[string]string, map[string]*ConnectionPlugin, error) {
log.Printf("[TRACE] getSchemaHashesForDynamicSchemas") log.Printf("[TRACE] getSchemaHashesForDynamicSchemas")
// for every required connection, check the connection state to determine whether the schema mode is 'dynamic' // for every required connection, check the connection state to determine whether the schema mode is 'dynamic'

View File

@@ -0,0 +1,30 @@
package steampipeconfig
const PostgresNotificationStructVersion = 20230306
type PostgresNotificationType int
const (
PgNotificationSchemaUpdate PostgresNotificationType = iota + 1
)
type PostgresNotification struct {
StructVersion int
Type PostgresNotificationType
}
type SchemaUpdateNotification struct {
StructVersion int
Type PostgresNotificationType
Update []string
Delete []string
}
func NewSchemaUpdateNotification(update, delete []string) *SchemaUpdateNotification {
return &SchemaUpdateNotification{
StructVersion: PostgresNotificationStructVersion,
Type: PgNotificationSchemaUpdate,
Update: update,
Delete: delete,
}
}

View File

@@ -10,6 +10,8 @@ import (
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/turbot/steampipe-plugin-sdk/v5/logging" "github.com/turbot/steampipe-plugin-sdk/v5/logging"
"github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/constants/runtime"
"github.com/turbot/steampipe/pkg/filepaths" "github.com/turbot/steampipe/pkg/filepaths"
pb "github.com/turbot/steampipe/pluginmanager_service/grpc/proto" pb "github.com/turbot/steampipe/pluginmanager_service/grpc/proto"
pluginshared "github.com/turbot/steampipe/pluginmanager_service/grpc/shared" pluginshared "github.com/turbot/steampipe/pluginmanager_service/grpc/shared"
@@ -42,7 +44,10 @@ func StartNewInstance(steampipeExecutablePath string) error {
func start(steampipeExecutablePath string) error { func start(steampipeExecutablePath string) error {
// note: we assume the install dir has been assigned to file_paths.SteampipeDir // note: we assume the install dir has been assigned to file_paths.SteampipeDir
// - this is done both by the FDW and Steampipe // - this is done both by the FDW and Steampipe
pluginManagerCmd := exec.Command(steampipeExecutablePath, "plugin-manager", "--install-dir", filepaths.SteampipeDir) pluginManagerCmd := exec.Command(steampipeExecutablePath,
"plugin-manager",
"--"+constants.ArgInstallDir, filepaths.SteampipeDir,
"--"+constants.ArgAppName, runtime.PgClientAppName)
// set attributes on the command to ensure the process is not shutdown when its parent terminates // set attributes on the command to ensure the process is not shutdown when its parent terminates
pluginManagerCmd.SysProcAttr = &syscall.SysProcAttr{ pluginManagerCmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true, Setpgid: true,

View File

@@ -158,14 +158,12 @@ func (m *PluginManager) OnSchemaChanged(refreshResult *steampipeconfig.RefreshCo
client, err := db_local.NewLocalClient(ctx, constants.InvokerConnectionWatcher, nil) client, err := db_local.NewLocalClient(ctx, constants.InvokerConnectionWatcher, nil)
if err != nil { if err != nil {
log.Printf("[TRACE] error creating client to handle updated connection config: %s", err.Error()) log.Printf("[TRACE] error creating client to handle updated connection config: %s", err.Error())
return
} }
defer client.Close(ctx) defer client.Close(ctx)
notification := steampipeconfig.NewConnectionUpdateNotification(refreshResult.Updates)
if err != nil { notification := refreshResult.Updates.AsNotification()
log.Printf("[WARN] Error sending notification: %s", err) m.notifySchemaChange(notification, client)
} else {
m.notifySchemaChange(notification, client)
}
} }
func (m *PluginManager) Shutdown(req *proto.ShutdownRequest) (resp *proto.ShutdownResponse, err error) { func (m *PluginManager) Shutdown(req *proto.ShutdownRequest) (resp *proto.ShutdownResponse, err error) {
@@ -736,19 +734,22 @@ func (m *PluginManager) updateConnectionSchema(ctx context.Context, connection s
} }
// also send a postgres notification // also send a postgres notification
m.notifySchemaChange(&steampipeconfig.ConnectionUpdateNotification{Update: []string{connection}}, client) notification := steampipeconfig.NewSchemaUpdateNotification([]string{connection}, nil)
m.notifySchemaChange(notification, client)
} }
// send a postgres notification that the schema has chganged // send a postgres notification that the schema has chganged
func (m *PluginManager) notifySchemaChange(notification *steampipeconfig.ConnectionUpdateNotification, client *db_local.LocalDbClient) { func (m *PluginManager) notifySchemaChange(notification any, client *db_local.LocalDbClient) {
notificationBytes, err := json.Marshal(notification) notificationBytes, err := json.Marshal(notification)
if err != nil { if err != nil {
log.Printf("[WARN] Error marshalling schema change notification notification: %s", err) log.Printf("[TRACE] error marshalling Postgres notification: %s", err.Error())
return return
} }
log.Printf("[WARN] Send update notification") log.Printf("[WARN] Send update notification")
sql := fmt.Sprintf("select pg_notify('%s', $1)", constants.NotificationConnectionUpdate) sql := fmt.Sprintf("select pg_notify('%s', $1)", constants.PostgresNotificationChannel)
_, err = client.ExecuteSync(context.Background(), sql, notificationBytes) _, err = client.ExecuteSync(context.Background(), sql, notificationBytes)
if err != nil { if err != nil {
log.Printf("[WARN] Error sending notification: %s", err) log.Printf("[WARN] Error sending notification: %s", err)