From 4e13e4e6d1c49b188b7742154a5b461cc51fca40 Mon Sep 17 00:00:00 2001 From: kaidaguerre Date: Fri, 17 Feb 2023 18:01:49 +0000 Subject: [PATCH] When a plugin fails to load, remove connections for that plugin from the connection state file. Closes #3124 --- pkg/db/db_client/db_client_search_path.go | 4 ++-- pkg/db/db_common/client.go | 2 +- pkg/db/db_local/local_db_client.go | 4 ++-- pkg/db/db_local/local_db_client_connections.go | 5 +++++ pkg/steampipeconfig/connection_plugin.go | 10 ++++++++-- pkg/steampipeconfig/refresh_connections_result.go | 14 ++++++++++++++ 6 files changed, 32 insertions(+), 7 deletions(-) diff --git a/pkg/db/db_client/db_client_search_path.go b/pkg/db/db_client/db_client_search_path.go index 361aac7ac..7c3b66412 100644 --- a/pkg/db/db_client/db_client_search_path.go +++ b/pkg/db/db_client/db_client_search_path.go @@ -69,7 +69,7 @@ func (c *DbClient) SetRequiredSessionSearchPath(ctx context.Context) error { requiredSearchPath := viper.GetStringSlice(constants.ArgSearchPath) searchPathPrefix := viper.GetStringSlice(constants.ArgSearchPathPrefix) - searchPath, err := c.ContructSearchPath(ctx, requiredSearchPath, searchPathPrefix) + searchPath, err := c.ConstructSearchPath(ctx, requiredSearchPath, searchPathPrefix) if err != nil { return err } @@ -85,7 +85,7 @@ func (c *DbClient) GetRequiredSessionSearchPath() []string { return c.requiredSessionSearchPath } -func (c *DbClient) ContructSearchPath(ctx context.Context, customSearchPath, searchPathPrefix []string) ([]string, error) { +func (c *DbClient) ConstructSearchPath(ctx context.Context, customSearchPath, searchPathPrefix []string) ([]string, error) { // strip empty elements from search path and prefix customSearchPath = helpers.RemoveFromStringSlice(customSearchPath, "") searchPathPrefix = helpers.RemoveFromStringSlice(searchPathPrefix, "") diff --git a/pkg/db/db_common/client.go b/pkg/db/db_common/client.go index 7507e957d..c65843f9d 100644 --- a/pkg/db/db_common/client.go +++ b/pkg/db/db_common/client.go @@ -20,7 +20,7 @@ type Client interface { GetCurrentSearchPathForDbConnection(context.Context, *sql.Conn) ([]string, error) SetRequiredSessionSearchPath(context.Context) error GetRequiredSessionSearchPath() []string - ContructSearchPath(context.Context, []string, []string) ([]string, error) + ConstructSearchPath(context.Context, []string, []string) ([]string, error) AcquireSession(context.Context) *AcquireSessionResult diff --git a/pkg/db/db_local/local_db_client.go b/pkg/db/db_local/local_db_client.go index 56e6f9dea..668b7ed30 100644 --- a/pkg/db/db_local/local_db_client.go +++ b/pkg/db/db_local/local_db_client.go @@ -159,8 +159,8 @@ func (c *LocalDbClient) GetRequiredSessionSearchPath() []string { return c.client.GetRequiredSessionSearchPath() } -func (c *LocalDbClient) ContructSearchPath(ctx context.Context, requiredSearchPath, searchPathPrefix []string) ([]string, error) { - return c.client.ContructSearchPath(ctx, requiredSearchPath, searchPathPrefix) +func (c *LocalDbClient) ConstructSearchPath(ctx context.Context, requiredSearchPath, searchPathPrefix []string) ([]string, error) { + return c.client.ConstructSearchPath(ctx, requiredSearchPath, searchPathPrefix) } // GetSchemaFromDB for LocalDBClient optimises the schema extraction by extracting schema diff --git a/pkg/db/db_local/local_db_client_connections.go b/pkg/db/db_local/local_db_client_connections.go index 6f8e80728..f0026a5d2 100644 --- a/pkg/db/db_local/local_db_client_connections.go +++ b/pkg/db/db_local/local_db_client_connections.go @@ -38,6 +38,11 @@ func (c *LocalDbClient) refreshConnections(ctx context.Context, forceUpdateConne if res.Error == nil && connectionUpdates.ConnectionStateModified || res.UpdatedConnections { // now serialise the connection state + // NOTE: remove any connection which failed + for c := range res.FailedConnections { + delete(connectionUpdates.RequiredConnectionState, c) + } + // update required connections with the schema mode from the connection state and schema hash from the hash map if err := connectionUpdates.RequiredConnectionState.Save(); err != nil { res.Error = err diff --git a/pkg/steampipeconfig/connection_plugin.go b/pkg/steampipeconfig/connection_plugin.go index 01ee79554..e93592ac3 100644 --- a/pkg/steampipeconfig/connection_plugin.go +++ b/pkg/steampipeconfig/connection_plugin.go @@ -111,8 +111,14 @@ func CreateConnectionPlugins(connectionsToCreate []*modconfig.Connection) (reque } // if there were any failures, display them - for plugin, failure := range getResponse.FailureMap { - res.AddWarning(fmt.Sprintf("failed to start plugin '%s': %s", plugin, failure)) + for failedPlugin, failure := range getResponse.FailureMap { + res.AddWarning(fmt.Sprintf("failed to start plugin '%s': %s", failedPlugin, failure)) + // figure out which connections are provided by any failed plugins + for _, c := range connectionsToCreate { + if c.Plugin == failedPlugin { + res.AddFailedConnection(c.Name, failure) + } + } } // now create or retrieve a connection plugin for each connection diff --git a/pkg/steampipeconfig/refresh_connections_result.go b/pkg/steampipeconfig/refresh_connections_result.go index 2d0ce60e2..7ebf5b741 100644 --- a/pkg/steampipeconfig/refresh_connections_result.go +++ b/pkg/steampipeconfig/refresh_connections_result.go @@ -13,6 +13,7 @@ type RefreshConnectionResult struct { modconfig.ErrorAndWarnings UpdatedConnections bool Updates *ConnectionUpdates + FailedConnections map[string]string } func (r *RefreshConnectionResult) Merge(other *RefreshConnectionResult) { @@ -26,6 +27,11 @@ func (r *RefreshConnectionResult) Merge(other *RefreshConnectionResult) { r.Error = other.Error } r.Warnings = append(r.Warnings, other.Warnings...) + for c, err := range other.FailedConnections { + if _, ok := r.FailedConnections[c]; !ok { + r.AddFailedConnection(c, err) + } + } } func (r *RefreshConnectionResult) String() string { @@ -39,3 +45,11 @@ func (r *RefreshConnectionResult) String() string { op.WriteString(fmt.Sprintf("UpdatedConnections: %v\n", r.UpdatedConnections)) return op.String() } + +func (r *RefreshConnectionResult) AddFailedConnection(c string, failure string) { + if r.FailedConnections == nil { + r.FailedConnections = make(map[string]string) + } + + r.FailedConnections[c] = failure +}