From 0732f8ee4fbcd5d045d0e2ccab16bd3749c1416d Mon Sep 17 00:00:00 2001 From: Binaek Sarkar Date: Tue, 21 Dec 2021 18:20:20 +0530 Subject: [PATCH] Fix issue where service is not shutdown when check is cancelled. Closes #1250 --- cmd/check.go | 8 +- cmd/query.go | 3 +- cmd/service.go | 4 +- control/controlexecute/control_run.go | 5 +- control/controlexecute/result_group.go | 6 +- db/db_client/db_client.go | 15 +++- db/db_client/db_client_execute.go | 8 +- db/db_client/db_client_session.go | 2 +- db/db_common/db_session.go | 19 ++++- db/db_local/local_db_client.go | 3 + db/db_local/stop_services.go | 10 ++- main.go | 2 +- runtime_constants/execution_id.go | 14 ++++ runtime_constants/runtime_constants.go | 5 ++ utils/debugdump.go | 111 +++++++++++++++++++++++++ 15 files changed, 194 insertions(+), 21 deletions(-) create mode 100644 runtime_constants/execution_id.go create mode 100644 runtime_constants/runtime_constants.go diff --git a/cmd/check.go b/cmd/check.go index 719593a3d..570ef1193 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "log" "os" "strings" "sync" @@ -113,6 +114,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) { } if initData.client != nil { + log.Printf("[TRACE] close client") initData.client.Close() } if initData.workspace != nil { @@ -245,12 +247,12 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *checkInitDa // get a client var client db_common.Client if connectionString := viper.GetString(constants.ArgConnectionString); connectionString != "" { - client, err = db_client.NewDbClient(initData.ctx, connectionString) + client, err = db_client.NewDbClient(ctx, connectionString) } else { // stop the spinner display.StopSpinner(spinner) // when starting the database, installers may trigger their own spinners - client, err = db_local.GetLocalClient(initData.ctx, constants.InvokerCheck) + client, err = db_local.GetLocalClient(ctx, constants.InvokerCheck) // resume the spinner display.ResumeSpinner(spinner) } @@ -261,7 +263,7 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *checkInitDa } initData.client = client - refreshResult := initData.client.RefreshConnectionAndSearchPaths(initData.ctx) + refreshResult := initData.client.RefreshConnectionAndSearchPaths(ctx) if refreshResult.Error != nil { initData.result.Error = refreshResult.Error return initData diff --git a/cmd/query.go b/cmd/query.go index b73dfb911..4b71c5766 100644 --- a/cmd/query.go +++ b/cmd/query.go @@ -237,7 +237,7 @@ func getQueryInitDataAsync(ctx context.Context, w *workspace.Workspace, initData if err != nil { initData.Result.Error = fmt.Errorf("error acquiring database connection, %s", err.Error()) } else { - sessionResult.Session.Close() + sessionResult.Session.Close(utils.IsContextCancelled(ctx)) } }() @@ -259,6 +259,7 @@ func startCancelHandler(cancel context.CancelFunc) { signal.Notify(sigIntChannel, os.Interrupt) go func() { <-sigIntChannel + log.Println("[TRACE] got SIGINT") // call context cancellation function cancel() // leave the channel open - any subsequent interrupts hits will be ignored diff --git a/cmd/service.go b/cmd/service.go index 35d8699e0..040b00b13 100644 --- a/cmd/service.go +++ b/cmd/service.go @@ -230,7 +230,7 @@ func runServiceInForeground(invoker constants.Invoker) { fmt.Print("\r") // if we have received this signal, then the user probably wants to shut down // everything. Shutdowns MUST NOT happen in cancellable contexts - count, err := db_local.GetCountOfConnectedClients(context.Background()) + count, err := db_local.GetCountOfThirdPartyClients(context.Background()) if err != nil { // report the error in the off chance that there's one utils.ShowError(err) @@ -398,7 +398,7 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) { } // check if there are any connected clients to the service - connectedClientCount, err := db_local.GetCountOfConnectedClients(cmd.Context()) + connectedClientCount, err := db_local.GetCountOfThirdPartyClients(cmd.Context()) if err != nil { display.StopSpinner(spinner) utils.FailOnErrorWithMessage(err, "error during service stop") diff --git a/control/controlexecute/control_run.go b/control/controlexecute/control_run.go index 1de3ed488..9e47d2169 100644 --- a/control/controlexecute/control_run.go +++ b/control/controlexecute/control_run.go @@ -178,7 +178,10 @@ func (r *ControlRun) Execute(ctx context.Context, client db_common.Client) { } r.Lifecycle.Add("got_session") dbSession := sessionResult.Session - defer dbSession.Close() + defer func() { + // do this in a closure, otherwise the argument will not get evaluated in calltime + dbSession.Close(utils.IsContextCancelled(ctx)) + }() // set our status r.runStatus = ControlRunStarted diff --git a/control/controlexecute/result_group.go b/control/controlexecute/result_group.go index 34d7dfe9d..8e9831757 100644 --- a/control/controlexecute/result_group.go +++ b/control/controlexecute/result_group.go @@ -179,7 +179,7 @@ func (r *ResultGroup) Execute(ctx context.Context, client db_common.Client, para continue } - go func(run *ControlRun) { + go func(c context.Context, run *ControlRun) { defer func() { if r := recover(); r != nil { // if the Execute panic'ed, set it as an error @@ -188,8 +188,8 @@ func (r *ResultGroup) Execute(ctx context.Context, client db_common.Client, para // Release in defer, so that we don't retain the lock even if there's a panic inside parallelismLock.Release(1) }() - run.Execute(ctx, client) - }(controlRun) + run.Execute(c, client) + }(ctx, controlRun) } for _, child := range r.Groups { child.Execute(ctx, client, parallelismLock) diff --git a/db/db_client/db_client.go b/db/db_client/db_client.go index f5c06c729..6a44816b5 100644 --- a/db/db_client/db_client.go +++ b/db/db_client/db_client.go @@ -4,11 +4,15 @@ import ( "context" "database/sql" "fmt" + "log" "strings" "sync" + "github.com/jackc/pgx/v4" + "github.com/jackc/pgx/v4/stdlib" "github.com/spf13/viper" "github.com/turbot/steampipe/db/db_common" + "github.com/turbot/steampipe/runtime_constants" "github.com/turbot/steampipe/schema" "github.com/turbot/steampipe/steampipeconfig" "golang.org/x/sync/semaphore" @@ -74,6 +78,13 @@ func establishConnection(ctx context.Context, connStr string) (*sql.DB, error) { utils.LogTime("db_client.establishConnection start") defer utils.LogTime("db_client.establishConnection end") + connConfig, _ := pgx.ParseConfig(connStr) + connConfig.RuntimeParams = map[string]string{ + // set an app name so that we can track connections from this execution + "application_name": runtime_constants.PgClientAppName, + } + connStr = stdlib.RegisterConnConfig(connConfig) + db, err := sql.Open("pgx", connStr) if err != nil { return nil, err @@ -104,8 +115,10 @@ func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallba // Close implements Client // closes the connection to the database and shuts down the backend func (c *DbClient) Close() error { + log.Printf("[TRACE] DbClient.Close %v", c.dbClient) if c.dbClient != nil { c.sessionInitWaitGroup.Wait() + // clear the map - so that we can't reuse it c.sessions = nil return c.dbClient.Close() @@ -133,7 +146,7 @@ func (c *DbClient) RefreshSessions(ctx context.Context) *db_common.AcquireSessio } sessionResult := c.AcquireSession(ctx) if sessionResult.Session != nil { - sessionResult.Session.Close() + sessionResult.Session.Close(utils.IsContextCancelled(ctx)) } return sessionResult } diff --git a/db/db_client/db_client_execute.go b/db/db_client/db_client_execute.go index d5f3f9b61..8ae26d113 100644 --- a/db/db_client/db_client_execute.go +++ b/db/db_client/db_client_execute.go @@ -27,7 +27,11 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner if sessionResult.Error != nil { return nil, sessionResult.Error } - defer sessionResult.Session.Close() + defer func() { + // we need to do this in a closure, otherwise the ctx will be evaluated immediately + // and not in call-time + sessionResult.Session.Close(utils.IsContextCancelled(ctx)) + }() return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, disableSpinner) } @@ -67,7 +71,7 @@ func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner boo } // define callback to close session when the async execution is complete - closeSessionCallback := func() { sessionResult.Session.Close() } + closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) } return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback, disableSpinner) } diff --git a/db/db_client/db_client_session.go b/db/db_client/db_client_session.go index c52605a3e..05f0e1b1a 100644 --- a/db/db_client/db_client_session.go +++ b/db/db_client/db_client_session.go @@ -131,7 +131,7 @@ func (c *DbClient) getSessionWithRetries(ctx context.Context) (*sql.Conn, uint32 }) if err != nil { - log.Printf("[TRACE] getSessionWithRetries failed after 10 retries: %s", err) + log.Printf("[TRACE] getSessionWithRetries failed after %d retries: %s", retries, err) return nil, 0, err } diff --git a/db/db_common/db_session.go b/db/db_common/db_session.go index 1f0023f17..5f6081da1 100644 --- a/db/db_common/db_session.go +++ b/db/db_common/db_session.go @@ -1,9 +1,11 @@ package db_common import ( + "context" "database/sql" "time" + "github.com/jackc/pgx/v4/stdlib" "github.com/turbot/steampipe/utils" ) @@ -33,9 +35,22 @@ func (s *DatabaseSession) UpdateUsage() { s.LastUsed = time.Now() } -func (s *DatabaseSession) Close() error { +func (s *DatabaseSession) Close(waitForCleanup bool) error { + var err error if s.Connection != nil { - err := s.Connection.Close() + if waitForCleanup { + s.Connection.Raw(func(driverConn interface{}) error { + conn := driverConn.(*stdlib.Conn) + select { + case <-time.After(5 * time.Second): + return context.DeadlineExceeded + case <-conn.Conn().PgConn().CleanupDone(): + return nil + } + }) + } + + err = s.Connection.Close() s.Connection = nil return err } diff --git a/db/db_local/local_db_client.go b/db/db_local/local_db_client.go index 921e4f4b0..6cfa0616c 100644 --- a/db/db_local/local_db_client.go +++ b/db/db_local/local_db_client.go @@ -72,10 +72,13 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbCli func (c *LocalDbClient) Close() error { log.Printf("[TRACE] close local client %p", c) if c.client != nil { + log.Printf("[TRACE] local client not NIL") if err := c.client.Close(); err != nil { return err } + log.Printf("[TRACE] local client close complete") } + log.Printf("[TRACE] shutdown local service %v", c.invoker) // no context to pass on - use background // we shouldn't do this in a context that can be cancelled anyway ShutdownService(c.invoker) diff --git a/db/db_local/stop_services.go b/db/db_local/stop_services.go index dffaf17f6..c890cfd3b 100644 --- a/db/db_local/stop_services.go +++ b/db/db_local/stop_services.go @@ -14,6 +14,7 @@ import ( "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/plugin_manager" + "github.com/turbot/steampipe/runtime_constants" "github.com/turbot/steampipe/utils" ) @@ -44,7 +45,7 @@ func ShutdownService(invoker constants.Invoker) { // how many clients are connected // under a fresh context - count, _ := GetCountOfConnectedClients(context.Background()) + count, _ := GetCountOfThirdPartyClients(context.Background()) if count > 0 { // there are other clients connected to the database // we can't stop the DB. @@ -68,8 +69,8 @@ func ShutdownService(invoker constants.Invoker) { } -// GetCountOfConnectedClients returns the number of clients currently connected to the service -func GetCountOfConnectedClients(ctx context.Context) (i int, e error) { +// GetCountOfThirdPartyClients returns the number of connections to the service from other thrid party applications +func GetCountOfThirdPartyClients(ctx context.Context) (i int, e error) { utils.LogTime("db_local.GetCountOfConnectedClients start") defer utils.LogTime(fmt.Sprintf("db_local.GetCountOfConnectedClients end:%d", i)) @@ -81,7 +82,8 @@ func GetCountOfConnectedClients(ctx context.Context) (i int, e error) { clientCount := 0 // get the total number of connected clients - row := rootClient.QueryRow("select count(*) from pg_stat_activity where client_port IS NOT NULL and backend_type='client backend';") + // which are not us - determined by the unique application_name client parameter + row := rootClient.QueryRow("select count(*) from pg_stat_activity where client_port IS NOT NULL and backend_type='client backend' and application_name != $1;", runtime_constants.PgClientAppName) row.Scan(&clientCount) // clientCount can never be zero, since the client we are using to run the query counts as a client // deduct the open connections in the pool of this client diff --git a/main.go b/main.go index f1e1b8d17..20d42c16d 100644 --- a/main.go +++ b/main.go @@ -5,10 +5,10 @@ import ( "os" filehelpers "github.com/turbot/go-kit/files" + "github.com/turbot/go-kit/helpers" "github.com/hashicorp/go-hclog" _ "github.com/jackc/pgx/v4/stdlib" - "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe/cmd" "github.com/turbot/steampipe/utils" ) diff --git a/runtime_constants/execution_id.go b/runtime_constants/execution_id.go new file mode 100644 index 000000000..e42ad778f --- /dev/null +++ b/runtime_constants/execution_id.go @@ -0,0 +1,14 @@ +package runtime_constants + +import ( + "fmt" + "time" + + "github.com/turbot/steampipe/constants" + "github.com/turbot/steampipe/utils" +) + +var ( + ExecutionID = utils.GetMD5Hash(fmt.Sprintf("%d", time.Now().Nanosecond())) + PgClientAppName = fmt.Sprintf("%s_%s", constants.AppName, ExecutionID) +) diff --git a/runtime_constants/runtime_constants.go b/runtime_constants/runtime_constants.go new file mode 100644 index 000000000..6c3cc5cf4 --- /dev/null +++ b/runtime_constants/runtime_constants.go @@ -0,0 +1,5 @@ +// The runtime_constants package contains values which +// are not constants during compilation, but should remain +// constant during the duration of an execution of the binary + +package runtime_constants diff --git a/utils/debugdump.go b/utils/debugdump.go index 68a9c9070..19925b2b5 100644 --- a/utils/debugdump.go +++ b/utils/debugdump.go @@ -1,10 +1,14 @@ package utils import ( + "database/sql" "encoding/json" "fmt" "os" "strings" + "time" + + typeHelpers "github.com/turbot/go-kit/types" "github.com/spf13/viper" ) @@ -24,3 +28,110 @@ func DebugDumpViper() { } fmt.Println(strings.Repeat("*", 80)) } + +func DebugDumpRows(rows *sql.Rows) { + colTypes, err := rows.ColumnTypes() + if err != nil { + // we do not need to stream because + // defer takes care of it! + return + } + cols, err := rows.Columns() + if err != nil { + // we do not need to stream because + // defer takes care of it! + return + } + fmt.Println(cols) + fmt.Println("---------------------------------------") + for rows.Next() { + row, _ := readRow(rows, cols, colTypes) + rowAsString, _ := columnValuesAsString(row, colTypes) + fmt.Println(rowAsString) + } +} + +func readRow(rows *sql.Rows, cols []string, colTypes []*sql.ColumnType) ([]interface{}, error) { + // slice of interfaces to receive the row data + columnValues := make([]interface{}, len(cols)) + // make a slice of pointers to the result to pass to scan + resultPtrs := make([]interface{}, len(cols)) // A temporary interface{} slice + for i := range columnValues { + resultPtrs[i] = &columnValues[i] + } + rows.Scan(resultPtrs...) + + return populateRow(columnValues, colTypes), nil +} + +func populateRow(columnValues []interface{}, colTypes []*sql.ColumnType) []interface{} { + result := make([]interface{}, len(columnValues)) + for i, columnValue := range columnValues { + if columnValue != nil { + colType := colTypes[i] + dbType := colType.DatabaseTypeName() + switch dbType { + case "JSON", "JSONB": + var val interface{} + if err := json.Unmarshal(columnValue.([]byte), &val); err != nil { + // what??? + // TODO how to handle error + } + result[i] = val + default: + result[i] = columnValue + } + } + } + return result +} + +// columnValuesAsString converts a slice of columns into strings +func columnValuesAsString(values []interface{}, columns []*sql.ColumnType) ([]string, error) { + rowAsString := make([]string, len(columns)) + for idx, val := range values { + val, err := columnValueAsString(val, columns[idx]) + if err != nil { + return nil, err + } + rowAsString[idx] = val + } + return rowAsString, nil +} + +// columnValueAsString converts column value to string +func columnValueAsString(val interface{}, colType *sql.ColumnType) (result string, err error) { + defer func() { + if r := recover(); r != nil { + result = fmt.Sprintf("%v", val) + } + }() + + if val == nil { + return "", nil + } + + //log.Printf("[TRACE] ColumnValueAsString type %s", colType.DatabaseTypeName()) + // possible types for colType are defined in pq/oid/types.go + switch colType.DatabaseTypeName() { + case "JSON", "JSONB": + bytes, err := json.Marshal(val) + if err != nil { + return "", err + } + return string(bytes), nil + case "TIMESTAMP", "DATE", "TIME", "INTERVAL": + t, ok := val.(time.Time) + if ok { + return t.Format("2006-01-02 15:04:05"), nil + } + fallthrough + case "NAME": + result := string(val.([]uint8)) + return result, nil + + default: + return typeHelpers.ToString(val), nil + } + +}