mirror of
https://github.com/turbot/steampipe.git
synced 2026-03-31 18:00:19 -04:00
Fix issue where service is not shutdown when check is cancelled. Closes #1250
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -113,6 +114,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
if initData.client != nil {
|
||||
log.Printf("[TRACE] close client")
|
||||
initData.client.Close()
|
||||
}
|
||||
if initData.workspace != nil {
|
||||
@@ -245,12 +247,12 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *checkInitDa
|
||||
// get a client
|
||||
var client db_common.Client
|
||||
if connectionString := viper.GetString(constants.ArgConnectionString); connectionString != "" {
|
||||
client, err = db_client.NewDbClient(initData.ctx, connectionString)
|
||||
client, err = db_client.NewDbClient(ctx, connectionString)
|
||||
} else {
|
||||
// stop the spinner
|
||||
display.StopSpinner(spinner)
|
||||
// when starting the database, installers may trigger their own spinners
|
||||
client, err = db_local.GetLocalClient(initData.ctx, constants.InvokerCheck)
|
||||
client, err = db_local.GetLocalClient(ctx, constants.InvokerCheck)
|
||||
// resume the spinner
|
||||
display.ResumeSpinner(spinner)
|
||||
}
|
||||
@@ -261,7 +263,7 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *checkInitDa
|
||||
}
|
||||
initData.client = client
|
||||
|
||||
refreshResult := initData.client.RefreshConnectionAndSearchPaths(initData.ctx)
|
||||
refreshResult := initData.client.RefreshConnectionAndSearchPaths(ctx)
|
||||
if refreshResult.Error != nil {
|
||||
initData.result.Error = refreshResult.Error
|
||||
return initData
|
||||
|
||||
@@ -237,7 +237,7 @@ func getQueryInitDataAsync(ctx context.Context, w *workspace.Workspace, initData
|
||||
if err != nil {
|
||||
initData.Result.Error = fmt.Errorf("error acquiring database connection, %s", err.Error())
|
||||
} else {
|
||||
sessionResult.Session.Close()
|
||||
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
|
||||
}
|
||||
|
||||
}()
|
||||
@@ -259,6 +259,7 @@ func startCancelHandler(cancel context.CancelFunc) {
|
||||
signal.Notify(sigIntChannel, os.Interrupt)
|
||||
go func() {
|
||||
<-sigIntChannel
|
||||
log.Println("[TRACE] got SIGINT")
|
||||
// call context cancellation function
|
||||
cancel()
|
||||
// leave the channel open - any subsequent interrupts hits will be ignored
|
||||
|
||||
@@ -230,7 +230,7 @@ func runServiceInForeground(invoker constants.Invoker) {
|
||||
fmt.Print("\r")
|
||||
// if we have received this signal, then the user probably wants to shut down
|
||||
// everything. Shutdowns MUST NOT happen in cancellable contexts
|
||||
count, err := db_local.GetCountOfConnectedClients(context.Background())
|
||||
count, err := db_local.GetCountOfThirdPartyClients(context.Background())
|
||||
if err != nil {
|
||||
// report the error in the off chance that there's one
|
||||
utils.ShowError(err)
|
||||
@@ -398,7 +398,7 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
// check if there are any connected clients to the service
|
||||
connectedClientCount, err := db_local.GetCountOfConnectedClients(cmd.Context())
|
||||
connectedClientCount, err := db_local.GetCountOfThirdPartyClients(cmd.Context())
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
utils.FailOnErrorWithMessage(err, "error during service stop")
|
||||
|
||||
@@ -178,7 +178,10 @@ func (r *ControlRun) Execute(ctx context.Context, client db_common.Client) {
|
||||
}
|
||||
r.Lifecycle.Add("got_session")
|
||||
dbSession := sessionResult.Session
|
||||
defer dbSession.Close()
|
||||
defer func() {
|
||||
// do this in a closure, otherwise the argument will not get evaluated in calltime
|
||||
dbSession.Close(utils.IsContextCancelled(ctx))
|
||||
}()
|
||||
|
||||
// set our status
|
||||
r.runStatus = ControlRunStarted
|
||||
|
||||
@@ -179,7 +179,7 @@ func (r *ResultGroup) Execute(ctx context.Context, client db_common.Client, para
|
||||
continue
|
||||
}
|
||||
|
||||
go func(run *ControlRun) {
|
||||
go func(c context.Context, run *ControlRun) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// if the Execute panic'ed, set it as an error
|
||||
@@ -188,8 +188,8 @@ func (r *ResultGroup) Execute(ctx context.Context, client db_common.Client, para
|
||||
// Release in defer, so that we don't retain the lock even if there's a panic inside
|
||||
parallelismLock.Release(1)
|
||||
}()
|
||||
run.Execute(ctx, client)
|
||||
}(controlRun)
|
||||
run.Execute(c, client)
|
||||
}(ctx, controlRun)
|
||||
}
|
||||
for _, child := range r.Groups {
|
||||
child.Execute(ctx, client, parallelismLock)
|
||||
|
||||
@@ -4,11 +4,15 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/runtime_constants"
|
||||
"github.com/turbot/steampipe/schema"
|
||||
"github.com/turbot/steampipe/steampipeconfig"
|
||||
"golang.org/x/sync/semaphore"
|
||||
@@ -74,6 +78,13 @@ func establishConnection(ctx context.Context, connStr string) (*sql.DB, error) {
|
||||
utils.LogTime("db_client.establishConnection start")
|
||||
defer utils.LogTime("db_client.establishConnection end")
|
||||
|
||||
connConfig, _ := pgx.ParseConfig(connStr)
|
||||
connConfig.RuntimeParams = map[string]string{
|
||||
// set an app name so that we can track connections from this execution
|
||||
"application_name": runtime_constants.PgClientAppName,
|
||||
}
|
||||
connStr = stdlib.RegisterConnConfig(connConfig)
|
||||
|
||||
db, err := sql.Open("pgx", connStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -104,8 +115,10 @@ func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallba
|
||||
// Close implements Client
|
||||
// closes the connection to the database and shuts down the backend
|
||||
func (c *DbClient) Close() error {
|
||||
log.Printf("[TRACE] DbClient.Close %v", c.dbClient)
|
||||
if c.dbClient != nil {
|
||||
c.sessionInitWaitGroup.Wait()
|
||||
|
||||
// clear the map - so that we can't reuse it
|
||||
c.sessions = nil
|
||||
return c.dbClient.Close()
|
||||
@@ -133,7 +146,7 @@ func (c *DbClient) RefreshSessions(ctx context.Context) *db_common.AcquireSessio
|
||||
}
|
||||
sessionResult := c.AcquireSession(ctx)
|
||||
if sessionResult.Session != nil {
|
||||
sessionResult.Session.Close()
|
||||
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
|
||||
}
|
||||
return sessionResult
|
||||
}
|
||||
|
||||
@@ -27,7 +27,11 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner
|
||||
if sessionResult.Error != nil {
|
||||
return nil, sessionResult.Error
|
||||
}
|
||||
defer sessionResult.Session.Close()
|
||||
defer func() {
|
||||
// we need to do this in a closure, otherwise the ctx will be evaluated immediately
|
||||
// and not in call-time
|
||||
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
|
||||
}()
|
||||
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, disableSpinner)
|
||||
}
|
||||
|
||||
@@ -67,7 +71,7 @@ func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner boo
|
||||
}
|
||||
|
||||
// define callback to close session when the async execution is complete
|
||||
closeSessionCallback := func() { sessionResult.Session.Close() }
|
||||
closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) }
|
||||
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback, disableSpinner)
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ func (c *DbClient) getSessionWithRetries(ctx context.Context) (*sql.Conn, uint32
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[TRACE] getSessionWithRetries failed after 10 retries: %s", err)
|
||||
log.Printf("[TRACE] getSessionWithRetries failed after %d retries: %s", retries, err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package db_common
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
@@ -33,9 +35,22 @@ func (s *DatabaseSession) UpdateUsage() {
|
||||
s.LastUsed = time.Now()
|
||||
}
|
||||
|
||||
func (s *DatabaseSession) Close() error {
|
||||
func (s *DatabaseSession) Close(waitForCleanup bool) error {
|
||||
var err error
|
||||
if s.Connection != nil {
|
||||
err := s.Connection.Close()
|
||||
if waitForCleanup {
|
||||
s.Connection.Raw(func(driverConn interface{}) error {
|
||||
conn := driverConn.(*stdlib.Conn)
|
||||
select {
|
||||
case <-time.After(5 * time.Second):
|
||||
return context.DeadlineExceeded
|
||||
case <-conn.Conn().PgConn().CleanupDone():
|
||||
return nil
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
err = s.Connection.Close()
|
||||
s.Connection = nil
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -72,10 +72,13 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbCli
|
||||
func (c *LocalDbClient) Close() error {
|
||||
log.Printf("[TRACE] close local client %p", c)
|
||||
if c.client != nil {
|
||||
log.Printf("[TRACE] local client not NIL")
|
||||
if err := c.client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] local client close complete")
|
||||
}
|
||||
log.Printf("[TRACE] shutdown local service %v", c.invoker)
|
||||
// no context to pass on - use background
|
||||
// we shouldn't do this in a context that can be cancelled anyway
|
||||
ShutdownService(c.invoker)
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/plugin_manager"
|
||||
"github.com/turbot/steampipe/runtime_constants"
|
||||
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
@@ -44,7 +45,7 @@ func ShutdownService(invoker constants.Invoker) {
|
||||
|
||||
// how many clients are connected
|
||||
// under a fresh context
|
||||
count, _ := GetCountOfConnectedClients(context.Background())
|
||||
count, _ := GetCountOfThirdPartyClients(context.Background())
|
||||
if count > 0 {
|
||||
// there are other clients connected to the database
|
||||
// we can't stop the DB.
|
||||
@@ -68,8 +69,8 @@ func ShutdownService(invoker constants.Invoker) {
|
||||
|
||||
}
|
||||
|
||||
// GetCountOfConnectedClients returns the number of clients currently connected to the service
|
||||
func GetCountOfConnectedClients(ctx context.Context) (i int, e error) {
|
||||
// GetCountOfThirdPartyClients returns the number of connections to the service from other thrid party applications
|
||||
func GetCountOfThirdPartyClients(ctx context.Context) (i int, e error) {
|
||||
utils.LogTime("db_local.GetCountOfConnectedClients start")
|
||||
defer utils.LogTime(fmt.Sprintf("db_local.GetCountOfConnectedClients end:%d", i))
|
||||
|
||||
@@ -81,7 +82,8 @@ func GetCountOfConnectedClients(ctx context.Context) (i int, e error) {
|
||||
|
||||
clientCount := 0
|
||||
// get the total number of connected clients
|
||||
row := rootClient.QueryRow("select count(*) from pg_stat_activity where client_port IS NOT NULL and backend_type='client backend';")
|
||||
// which are not us - determined by the unique application_name client parameter
|
||||
row := rootClient.QueryRow("select count(*) from pg_stat_activity where client_port IS NOT NULL and backend_type='client backend' and application_name != $1;", runtime_constants.PgClientAppName)
|
||||
row.Scan(&clientCount)
|
||||
// clientCount can never be zero, since the client we are using to run the query counts as a client
|
||||
// deduct the open connections in the pool of this client
|
||||
|
||||
2
main.go
2
main.go
@@ -5,10 +5,10 @@ import (
|
||||
"os"
|
||||
|
||||
filehelpers "github.com/turbot/go-kit/files"
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
"github.com/turbot/steampipe/cmd"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
14
runtime_constants/execution_id.go
Normal file
14
runtime_constants/execution_id.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package runtime_constants
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
var (
|
||||
ExecutionID = utils.GetMD5Hash(fmt.Sprintf("%d", time.Now().Nanosecond()))
|
||||
PgClientAppName = fmt.Sprintf("%s_%s", constants.AppName, ExecutionID)
|
||||
)
|
||||
5
runtime_constants/runtime_constants.go
Normal file
5
runtime_constants/runtime_constants.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// The runtime_constants package contains values which
|
||||
// are not constants during compilation, but should remain
|
||||
// constant during the duration of an execution of the binary
|
||||
|
||||
package runtime_constants
|
||||
@@ -1,10 +1,14 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
typeHelpers "github.com/turbot/go-kit/types"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@@ -24,3 +28,110 @@ func DebugDumpViper() {
|
||||
}
|
||||
fmt.Println(strings.Repeat("*", 80))
|
||||
}
|
||||
|
||||
func DebugDumpRows(rows *sql.Rows) {
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
// we do not need to stream because
|
||||
// defer takes care of it!
|
||||
return
|
||||
}
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
// we do not need to stream because
|
||||
// defer takes care of it!
|
||||
return
|
||||
}
|
||||
fmt.Println(cols)
|
||||
fmt.Println("---------------------------------------")
|
||||
for rows.Next() {
|
||||
row, _ := readRow(rows, cols, colTypes)
|
||||
rowAsString, _ := columnValuesAsString(row, colTypes)
|
||||
fmt.Println(rowAsString)
|
||||
}
|
||||
}
|
||||
|
||||
func readRow(rows *sql.Rows, cols []string, colTypes []*sql.ColumnType) ([]interface{}, error) {
|
||||
// slice of interfaces to receive the row data
|
||||
columnValues := make([]interface{}, len(cols))
|
||||
// make a slice of pointers to the result to pass to scan
|
||||
resultPtrs := make([]interface{}, len(cols)) // A temporary interface{} slice
|
||||
for i := range columnValues {
|
||||
resultPtrs[i] = &columnValues[i]
|
||||
}
|
||||
rows.Scan(resultPtrs...)
|
||||
|
||||
return populateRow(columnValues, colTypes), nil
|
||||
}
|
||||
|
||||
func populateRow(columnValues []interface{}, colTypes []*sql.ColumnType) []interface{} {
|
||||
result := make([]interface{}, len(columnValues))
|
||||
for i, columnValue := range columnValues {
|
||||
if columnValue != nil {
|
||||
colType := colTypes[i]
|
||||
dbType := colType.DatabaseTypeName()
|
||||
switch dbType {
|
||||
case "JSON", "JSONB":
|
||||
var val interface{}
|
||||
if err := json.Unmarshal(columnValue.([]byte), &val); err != nil {
|
||||
// what???
|
||||
// TODO how to handle error
|
||||
}
|
||||
result[i] = val
|
||||
default:
|
||||
result[i] = columnValue
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// columnValuesAsString converts a slice of columns into strings
|
||||
func columnValuesAsString(values []interface{}, columns []*sql.ColumnType) ([]string, error) {
|
||||
rowAsString := make([]string, len(columns))
|
||||
for idx, val := range values {
|
||||
val, err := columnValueAsString(val, columns[idx])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rowAsString[idx] = val
|
||||
}
|
||||
return rowAsString, nil
|
||||
}
|
||||
|
||||
// columnValueAsString converts column value to string
|
||||
func columnValueAsString(val interface{}, colType *sql.ColumnType) (result string, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
result = fmt.Sprintf("%v", val)
|
||||
}
|
||||
}()
|
||||
|
||||
if val == nil {
|
||||
return "<null>", nil
|
||||
}
|
||||
|
||||
//log.Printf("[TRACE] ColumnValueAsString type %s", colType.DatabaseTypeName())
|
||||
// possible types for colType are defined in pq/oid/types.go
|
||||
switch colType.DatabaseTypeName() {
|
||||
case "JSON", "JSONB":
|
||||
bytes, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(bytes), nil
|
||||
case "TIMESTAMP", "DATE", "TIME", "INTERVAL":
|
||||
t, ok := val.(time.Time)
|
||||
if ok {
|
||||
return t.Format("2006-01-02 15:04:05"), nil
|
||||
}
|
||||
fallthrough
|
||||
case "NAME":
|
||||
result := string(val.([]uint8))
|
||||
return result, nil
|
||||
|
||||
default:
|
||||
return typeHelpers.ToString(val), nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user