From 0da4feb8941765ea25bd7d918aa28a6fd386bc8d Mon Sep 17 00:00:00 2001 From: kaidaguerre Date: Mon, 12 Apr 2021 15:32:48 +0100 Subject: [PATCH] Add search-path and search-path-prefix arguments to query command. Closes #358 --- cmd/plugin.go | 7 +- cmd/query.go | 6 +- cmd/service.go | 4 +- cmdconfig/viper.go | 2 +- constants/args.go | 16 ++--- constants/config_keys.go | 5 +- db/client.go | 132 +++++++++-------------------------- db/client_connections.go | 45 ++++++------ db/client_execute.go | 2 +- db/client_search_path.go | 144 +++++++++++++++++++++++++++++++++++++++ db/install.go | 2 +- db/query.go | 14 +--- db/start.go | 7 +- db/stop.go | 2 +- 14 files changed, 228 insertions(+), 160 deletions(-) create mode 100644 db/client_search_path.go diff --git a/cmd/plugin.go b/cmd/plugin.go index 34200493d..6a24caa8e 100644 --- a/cmd/plugin.go +++ b/cmd/plugin.go @@ -444,13 +444,14 @@ func refreshConnectionsIfNecessary(reports []display.InstallReport, isUpdate boo defer func() { db.Shutdown(client, db.InvokerPlugin) }() } - client, err = db.GetClient(false) + // TODO i think we can pass true here and not refresh below + client, err = db.NewClient(false) if err != nil { return err } // refresh connections - if err = client.RefreshConnections(); err != nil { + if _, err = client.RefreshConnections(); err != nil { return err } @@ -532,7 +533,7 @@ func getPluginConnectionMap() (map[string][]string, error) { defer func() { db.Shutdown(client, db.InvokerPlugin) }() } - client, err = db.GetClient(true) + client, err = db.NewClient(true) if err != nil { return nil, fmt.Errorf("Could not connect with steampipe service") } diff --git a/cmd/query.go b/cmd/query.go index 49106a5d3..f9efa77f6 100644 --- a/cmd/query.go +++ b/cmd/query.go @@ -48,7 +48,9 @@ Examples: AddStringFlag(constants.ArgSeparator, "", ",", "Separator string for csv output"). AddStringFlag(constants.ArgOutput, "", "table", "Output format: line, csv, json or table"). AddBoolFlag(constants.ArgTimer, "", false, "Turn on the timer which reports query time."). - AddStringSliceFlag(constants.ArgSqlFile, "", nil, "Specifies one or more sql files to execute.") + AddStringSliceFlag(constants.ArgSqlFile, "", nil, "Specifies one or more sql files to execute."). + AddStringSliceFlag(constants.ArgSearchPath, "", []string{}, "Set a custom search_path for the steampipe user for a query session (comma-separated)"). + AddStringSliceFlag(constants.ArgSearchPathPrefix, "", []string{}, "Set a prefix to the current search path for a query session (comma-separated)") return cmd } @@ -99,7 +101,7 @@ func getQueries(args []string) ([]string, error) { func runQuery(queryString string) { // set the flag to not show spinner showSpinner := queryString == "" - cmdconfig.Viper().Set(constants.ShowInteractiveOutputConfigKey, showSpinner) + cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, showSpinner) // the db executor sends result data over resultsStreamer resultsStreamer, err := db.ExecuteQuery(queryString) diff --git a/cmd/service.go b/cmd/service.go index b5b16c7c2..92ee3c398 100644 --- a/cmd/service.go +++ b/cmd/service.go @@ -7,9 +7,9 @@ import ( "strings" "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe-plugin-sdk/logging" - "github.com/turbot/steampipe/cmdconfig" "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/db" @@ -184,7 +184,7 @@ func runServiceRestartCmd(cmd *cobra.Command, args []string) { return } - stopStatus, err := db.StopDB(cmdconfig.Viper().GetBool(constants.ArgForce), db.InvokerService) + stopStatus, err := db.StopDB(viper.GetBool(constants.ArgForce), db.InvokerService) if err != nil { utils.ShowErrorWithMessage(err, "could not stop current instance") diff --git a/cmdconfig/viper.go b/cmdconfig/viper.go index 815dfd165..87d8b2152 100644 --- a/cmdconfig/viper.go +++ b/cmdconfig/viper.go @@ -14,7 +14,7 @@ import ( func InitViper() { v := viper.GetViper() // set defaults - v.Set(constants.ShowInteractiveOutputConfigKey, true) + v.Set(constants.ConfigKeyShowInteractiveOutput, true) if installDir, isSet := os.LookupEnv("STEAMPIPE_INSTALL_DIR"); isSet { v.SetDefault(constants.ArgInstallDir, installDir) diff --git a/constants/args.go b/constants/args.go index 96a41437d..198b7a327 100644 --- a/constants/args.go +++ b/constants/args.go @@ -18,16 +18,12 @@ const ( ArgListenAddress = "database-listen" ArgSearchPath = "search-path" ArgSearchPathPrefix = "search-path-prefix" - // search path set in the database config - ArgServiceSearchPath = "database.search-path" - // search path set in the terminal config - ArgSearchPathTerminal = "terminal.search-path" - ArgInvoker = "invoker" - ArgRefresh = "refresh" - ArgLogLevel = "log-level" - ArgUpdateCheck = "update-check" - ArgInstallDir = "install-dir" - ArgSqlFile = "sql-file" + ArgInvoker = "invoker" + ArgRefresh = "refresh" + ArgLogLevel = "log-level" + ArgUpdateCheck = "update-check" + ArgInstallDir = "install-dir" + ArgSqlFile = "sql-file" ) /// metaquery mode arguments diff --git a/constants/config_keys.go b/constants/config_keys.go index cbfb6233c..ba73f6686 100644 --- a/constants/config_keys.go +++ b/constants/config_keys.go @@ -1,6 +1,7 @@ package constants +// viper config keys const ( - // ShowInteractiveOutputConfigKey :: viper key - ShowInteractiveOutputConfigKey = "show-interactive-output" + ConfigKeyShowInteractiveOutput = "show-interactive-output" + ConfigKeyDatabaseSearchPath = "database.search-path" ) diff --git a/db/client.go b/db/client.go index 01fea91c9..3a46d8ee6 100644 --- a/db/client.go +++ b/db/client.go @@ -4,12 +4,8 @@ import ( "database/sql" "fmt" "log" - "sort" - "strings" "time" - "github.com/spf13/viper" - "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/schema" "github.com/turbot/steampipe/steampipeconfig" @@ -30,9 +26,9 @@ func (c *Client) close() { } } -// GetClient ensures that the database instance is running +// NewClient ensures that the database instance is running // and returns a `Client` to interact with it -func GetClient(autoRefreshConnections bool) (*Client, error) { +func NewClient(autoRefreshConnections bool) (*Client, error) { db, err := createSteampipeDbClient() if err != nil { return nil, err @@ -41,19 +37,29 @@ func GetClient(autoRefreshConnections bool) (*Client, error) { client.dbClient = db client.loadSchema() + var updatedConnections bool if autoRefreshConnections { - client.RefreshConnections() - refreshFunctions() + if updatedConnections, err = client.RefreshConnections(); err != nil { + client.close() + return nil, err + } + if err := refreshFunctions(); err != nil { + client.close() + return nil, err + } } - // load the connection state and cache it! - connectionMap, err := steampipeconfig.GetConnectionState(client.schemaMetadata.GetSchemas()) - if err != nil { - return nil, err - } - client.connectionMap = &connectionMap - if err := client.setClientSearchPath(); err != nil { - utils.ShowError(err) + // if we did NOT update connections, initialise the connection map and search path + if !updatedConnections { + // load the connection state and cache it! + connectionMap, err := steampipeconfig.GetConnectionState(client.schemaMetadata.GetSchemas()) + if err != nil { + return nil, err + } + client.connectionMap = &connectionMap + if err := client.setClientSearchPath(); err != nil { + utils.ShowError(err) + } } return client, nil @@ -63,6 +69,17 @@ func createSteampipeDbClient() (*sql.DB, error) { return createDbClient(constants.DatabaseName, constants.DatabaseUser) } +// close and reopen db client +func (c *Client) refreshDbClient() error { + c.dbClient.Close() + db, err := createSteampipeDbClient() + if err != nil { + return err + } + c.dbClient = db + return nil +} + func createSteampipeRootDbClient() (*sql.DB, error) { return createDbClient(constants.DatabaseName, constants.DatabaseSuperUser) } @@ -71,89 +88,6 @@ func createPostgresDbClient() (*sql.DB, error) { return createDbClient("postgres", constants.DatabaseSuperUser) } -// set the search path for this client -func (c *Client) setClientSearchPath() error { - var searchPath []string - - if viper.IsSet(constants.ArgSearchPath) { - searchPath = viper.GetStringSlice(constants.ArgSearchPath) - // add 'internal' schema as last schema in the search path - searchPath = append(searchPath, constants.FunctionSchema) - } else { - // so no search path was set in config - build a search path from the connection schemas - searchPath = c.getDefaultSearchPath(searchPath) - } - - // if a prefix was specified, add it - if viper.IsSet(constants.ArgSearchPathPrefix) { - prefixedSearchPath := viper.GetStringSlice(constants.ArgSearchPathPrefix) - for _, p := range searchPath { - if !helpers.StringSliceContains(prefixedSearchPath, p) { - prefixedSearchPath = append(prefixedSearchPath, p) - } - } - searchPath = prefixedSearchPath - } - - // escape the schema - for idx, path := range searchPath { - searchPath[idx] = PgEscapeName(path) - } - q := fmt.Sprintf("set search_path to %s", strings.Join(searchPath, ",")) - _, err := c.ExecuteSync(q) - - if err != nil { - return err - } - - c.schemaMetadata.SearchPath = searchPath - return nil -} - -// set the search path for the db service (by setting it on the steampipe user) -func (c *Client) setServiceSearchPath() error { - // set the search_path to the available foreign schemas - // or the one set by the user in config - var searchPath []string - - // since this is the service starting up, use the ArgServiceSearchPath config - // (this is the value specified in the database config) - if viper.IsSet(constants.ArgServiceSearchPath) { - searchPath = viper.GetStringSlice(constants.ArgServiceSearchPath) - } else { - // so no search path was set in config - build a search path from the connection schemas - searchPath = c.getDefaultSearchPath(searchPath) - } - - // escape the schema names - for idx, path := range searchPath { - searchPath[idx] = PgEscapeName(path) - } - - log.Println("[TRACE] setting service search path to", searchPath) - query := fmt.Sprintf( - "alter user %s set search_path to %s;", - constants.DatabaseUser, - strings.Join(searchPath, ","), - ) - _, err := c.ExecuteSync(query) - return err -} - -// build default search path from the connection schemas, bookended with public and internal -func (c *Client) getDefaultSearchPath(searchPath []string) []string { - searchPath = c.schemaMetadata.GetSchemas() - sort.Strings(searchPath) - // add the 'public' schema as the first schema in the search_path. This makes it - // easier for users to build and work with their own tables, and since it's normally - // empty, doesn't make using steampipe tables any more difficult. - searchPath = append([]string{"public"}, searchPath...) - // add 'internal' schema as last schema in the search path - searchPath = append(searchPath, constants.FunctionSchema) - - return searchPath -} - func createDbClient(dbname string, username string) (*sql.DB, error) { log.Println("[TRACE] createDbClient") diff --git a/db/client_connections.go b/db/client_connections.go index ba6a7b3aa..2c8cd64e3 100644 --- a/db/client_connections.go +++ b/db/client_connections.go @@ -15,7 +15,7 @@ import ( // RefreshConnections :: load required connections from config // and update the database schema and search path to reflect the required connections // return whether any changes have been mde -func (c *Client) RefreshConnections() error { +func (c *Client) RefreshConnections() (bool, error) { // load required connection from globab config requiredConnections := steampipeconfig.Config.Connections @@ -25,14 +25,14 @@ func (c *Client) RefreshConnections() error { // refresh the connection state file - the removes any connections which do not exist in the list of current schema updates, err := steampipeconfig.GetConnectionsToUpdate(schemas, requiredConnections) if err != nil { - return err + return false, err } - log.Printf("[TRACE] updates: %+v\n", updates) + log.Printf("[TRACE] RefreshConnections, updates: %+v\n", updates) missingCount := len(updates.MissingPlugins) if missingCount > 0 { // if any plugins are missing, error for now but we could prompt for an install - return fmt.Errorf("%d %s referenced in the connection config not installed: \n %v", + return false, fmt.Errorf("%d %s referenced in the connection config not installed: \n %v", missingCount, utils.Pluralize("plugin", missingCount), strings.Join(updates.MissingPlugins, "\n ")) @@ -50,7 +50,7 @@ func (c *Client) RefreshConnections() error { } }() // in query, this can only start when in interactive - if cmdconfig.Viper().GetBool(constants.ShowInteractiveOutputConfigKey) { + if cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) { spin := utils.ShowSpinner("Refreshing connections...") defer utils.StopSpinner(spin) } @@ -58,7 +58,7 @@ func (c *Client) RefreshConnections() error { // first instantiate connection plugins for all updates connectionPlugins, err := getConnectionPlugins(updates.Update) if err != nil { - return err + return false, err } // find any plugins which use a newer sdk version than steampipe. validationFailures, validatedUpdates, validatedPlugins := steampipeconfig.ValidatePlugins(updates.Update, connectionPlugins) @@ -71,37 +71,38 @@ func (c *Client) RefreshConnections() error { } for c := range updates.Delete { - log.Printf("[TRACE] delete %s\n ", c) + log.Printf("[TRACE] delete connection %s\n ", c) connectionQueries = append(connectionQueries, deleteConnectionQuery(c)...) } - connectionsToUpdate := len(connectionQueries) > 0 - if connectionsToUpdate { - // execute the connection queries - if err = executeConnectionQueries(connectionQueries, updates); err != nil { - return err - } - } else { - log.Println("[DEBUG] no connections to update") + if len(connectionQueries) == 0 { + log.Println("[TRACE] no connections to update") + return false, nil } - // reload the database schemas, since they have changed - // otherwise we wouldn't be here + // execute the connection queries + if err = executeConnectionQueries(connectionQueries, updates); err != nil { + return false, err + } + + // so there ARE connections to update + + // reload the database schemas, since they have changed - otherwise we wouldn't be here log.Println("[TRACE] reloading schema") c.loadSchema() - // set the search path with the updates + // update the service and client search paths (as long as they have NOT been explicitly set) log.Println("[TRACE] setting search path") c.setServiceSearchPath() c.setClientSearchPath() - // tell client to refresh schemas, connection map and set the search path + // finally update the connection map if err = c.updateConnectionMap(); err != nil { - return err + return false, err } - // indicate whether we have updated connections - return nil + return true, nil + } func (c *Client) updateConnectionMap() error { diff --git a/db/client_execute.go b/db/client_execute.go index 1232e6172..97b091253 100644 --- a/db/client_execute.go +++ b/db/client_execute.go @@ -41,7 +41,7 @@ func (c *Client) executeQuery(query string, countStream bool) (*results.QueryRes // start spinner after a short delay var spinner *spinner.Spinner - if cmdconfig.Viper().GetBool(constants.ShowInteractiveOutputConfigKey) { + if cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) { // if showspinner is false, the spinner gets created, but is never shown // so the s.Active() will always come back false . . . spinner = utils.StartSpinnerAfterDelay("Loading results...", constants.SpinnerShowTimeout, queryDone) diff --git a/db/client_search_path.go b/db/client_search_path.go new file mode 100644 index 000000000..4911e571e --- /dev/null +++ b/db/client_search_path.go @@ -0,0 +1,144 @@ +package db + +import ( + "fmt" + "log" + "sort" + "strings" + + "github.com/spf13/viper" + "github.com/turbot/go-kit/helpers" + "github.com/turbot/steampipe/constants" +) + +// set the search path for this client +// if either a search-path or search-path-prefix is set in config, set the search path +func (c *Client) setClientSearchPath() error { + searchPath := viper.GetStringSlice(constants.ArgSearchPath) + searchPathPrefix := viper.GetStringSlice(constants.ArgSearchPathPrefix) + + // HACK reopen db client so we take into account recent changes to service search path + if err := c.refreshDbClient(); err != nil { + return err + } + // if neither search-path or search-path-prefix are set in config, we do not need to do anything + // - just fall back to the service search path + if len(searchPath) == 0 && len(searchPathPrefix) == 0 { + return nil + } + + // if a search path was passed, add 'internal' to the end + if len(searchPath) > 0 { + // add 'internal' schema as last schema in the search path + searchPath = append(searchPath, constants.FunctionSchema) + } else { + // so no search path was set in config + // in this case we need to load the existing service search path + searchPath, _ = c.getCurrentSearchPath() + } + + // add in the prefix if present + searchPath = c.addSearchPathPrefix(searchPathPrefix, searchPath) + + // escape the schema + searchPath = escapeSearchPath(searchPath) + + // now construct and execute the query + q := fmt.Sprintf("set search_path to %s", strings.Join(searchPath, ",")) + _, err := c.ExecuteSync(q) + if err != nil { + return err + } + + // store search path on the client + c.schemaMetadata.SearchPath = searchPath + return nil +} + +// set the search path for the db service (by setting it on the steampipe user) +func (c *Client) setServiceSearchPath() error { + var searchPath []string + + // is there a service search path in the config? + // check ConfigKeyDatabaseSearchPath config (this is the value specified in the database config) + if viper.IsSet(constants.ConfigKeyDatabaseSearchPath) { + searchPath = viper.GetStringSlice(constants.ConfigKeyDatabaseSearchPath) + // add 'internal' schema as last schema in the search path + searchPath = append(searchPath, constants.FunctionSchema) + } else { + // no config set - set service search path to default + searchPath = c.getDefaultSearchPath() + } + + // escape the schema names + searchPath = escapeSearchPath(searchPath) + + log.Println("[TRACE] setting service search path to", searchPath) + + // now construct and execute the query + query := fmt.Sprintf( + "alter user %s set search_path to %s;", + constants.DatabaseUser, + strings.Join(searchPath, ","), + ) + _, err := c.ExecuteSync(query) + return err +} + +func (c *Client) addSearchPathPrefix(searchPathPrefix []string, searchPath []string) []string { + if len(searchPathPrefix) > 0 { + prefixedSearchPath := searchPathPrefix + for _, p := range searchPath { + if !helpers.StringSliceContains(prefixedSearchPath, p) { + prefixedSearchPath = append(prefixedSearchPath, p) + } + } + searchPath = prefixedSearchPath + } + return searchPath +} + +// build default search path from the connection schemas, bookended with public and internal +func (c *Client) getDefaultSearchPath() []string { + searchPath := c.schemaMetadata.GetSchemas() + sort.Strings(searchPath) + // add the 'public' schema as the first schema in the search_path. This makes it + // easier for users to build and work with their own tables, and since it's normally + // empty, doesn't make using steampipe tables any more difficult. + searchPath = append([]string{"public"}, searchPath...) + // add 'internal' schema as last schema in the search path + searchPath = append(searchPath, constants.FunctionSchema) + + return searchPath +} + +// query the database to get the current search path +func (c *Client) getCurrentSearchPath() ([]string, error) { + var currentSearchPath []string + var pathAsString string + row := c.dbClient.QueryRow("show search_path") + if row.Err() != nil { + return nil, row.Err() + } + err := row.Scan(&pathAsString) + if err != nil { + return nil, err + } + currentSearchPath = strings.Split(pathAsString, ",") + // unescape search path + for idx, p := range currentSearchPath { + p = strings.Join(strings.Split(p, "\""), "") + p = strings.TrimSpace(p) + currentSearchPath[idx] = p + } + return currentSearchPath, nil +} + +// apply postgres escaping to search path and remove whitespace +func escapeSearchPath(searchPath []string) []string { + res := make([]string, len(searchPath)) + for idx, path := range searchPath { + res[idx] = PgEscapeName(strings.TrimSpace(path)) + } + return res +} diff --git a/db/install.go b/db/install.go index cfde6860b..cb302f682 100644 --- a/db/install.go +++ b/db/install.go @@ -261,7 +261,7 @@ func StartService(invoker Invoker) { break } if time.Since(startedAt) > constants.SpinnerShowTimeout && !spinnerShown { - if cmdconfig.Viper().GetBool(constants.ShowInteractiveOutputConfigKey) { + if cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) { s := utils.ShowSpinner("Waiting for database to start...") defer utils.StopSpinner(s) } diff --git a/db/query.go b/db/query.go index 90ff2a9d8..53568119e 100644 --- a/db/query.go +++ b/db/query.go @@ -34,21 +34,9 @@ func ExecuteQuery(queryString string) (*results.ResultStreamer, error) { StartService(InvokerQuery) } - client, err := GetClient(false) + client, err := NewClient(true) utils.FailOnErrorWithMessage(err, "client failed to initialize") - // refresh connections - err = client.RefreshConnections() - if err != nil { - // shutdown the service if something went wrong!!! - Shutdown(client, InvokerQuery) - return nil, fmt.Errorf("failed to refresh connections: %v", err.Error()) - } - if err = refreshFunctions(); err != nil { - // shutdown the service if something went wrong!!! - Shutdown(client, InvokerQuery) - return nil, fmt.Errorf("failed to add functions: %v", err) - } resultsStreamer := results.NewResultStreamer() // this is a callback to close the db et-al. when things get done - no matter the mode diff --git a/db/start.go b/db/start.go index b528062a9..c93162e17 100644 --- a/db/start.go +++ b/db/start.go @@ -194,7 +194,8 @@ func StartDB(port int, listen StartListenType, invoker Invoker) (StartResult, er return ServiceStarted, err } - client, err := GetClient(false) + // TODO ADD COMMENT EXPLAINING WHY WE ARE NOT AUTO-REFRESHING + client, err := NewClient(false) if err != nil { return ServiceFailedToStart, handleStartFailure(err) } @@ -212,7 +213,7 @@ func StartDB(port int, listen StartListenType, invoker Invoker) (StartResult, er // refresh plugin connections - ensure db schemas are in sync with connection config // NOTE: refresh defaults to true but will be set to false if this service start command has been invoked by a query command if cmdconfig.Viper().GetBool(constants.ArgRefresh) { - if err = client.RefreshConnections(); err != nil { + if _, err = client.RefreshConnections(); err != nil { return ServiceStarted, err } if err = refreshFunctions(); err != nil { @@ -254,7 +255,7 @@ func handleStartFailure(err error) error { return fmt.Errorf("Another Steampipe service is already running. Use %s to kill all running instances before continuing.", constants.Bold("steampipe service stop --force")) } - // there was nothing to kill.9 + // there was nothing to kill. // this is some other problem that we are not accounting for return err } diff --git a/db/stop.go b/db/stop.go index ca87e52f7..1459c8e01 100644 --- a/db/stop.go +++ b/db/stop.go @@ -121,7 +121,7 @@ func StopDB(force bool, invoker Invoker) (StopStatus, error) { break } if time.Since(signalSentAt) > constants.SpinnerShowTimeout && !spinnerShown { - if cmdconfig.Viper().GetBool(constants.ShowInteractiveOutputConfigKey) { + if cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) { s := utils.ShowSpinner("Shutting down...") defer utils.StopSpinner(s) spinnerShown = true