Fix issue where service is not shutdown when check is cancelled. Closes #1250

This commit is contained in:
Binaek Sarkar
2021-12-21 18:20:20 +05:30
committed by GitHub
parent 0bc29f1517
commit 0732f8ee4f
15 changed files with 194 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"log"
"os"
"strings"
"sync"
@@ -113,6 +114,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
}
if initData.client != nil {
log.Printf("[TRACE] close client")
initData.client.Close()
}
if initData.workspace != nil {
@@ -245,12 +247,12 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *checkInitDa
// get a client
var client db_common.Client
if connectionString := viper.GetString(constants.ArgConnectionString); connectionString != "" {
client, err = db_client.NewDbClient(initData.ctx, connectionString)
client, err = db_client.NewDbClient(ctx, connectionString)
} else {
// stop the spinner
display.StopSpinner(spinner)
// when starting the database, installers may trigger their own spinners
client, err = db_local.GetLocalClient(initData.ctx, constants.InvokerCheck)
client, err = db_local.GetLocalClient(ctx, constants.InvokerCheck)
// resume the spinner
display.ResumeSpinner(spinner)
}
@@ -261,7 +263,7 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *checkInitDa
}
initData.client = client
refreshResult := initData.client.RefreshConnectionAndSearchPaths(initData.ctx)
refreshResult := initData.client.RefreshConnectionAndSearchPaths(ctx)
if refreshResult.Error != nil {
initData.result.Error = refreshResult.Error
return initData

View File

@@ -237,7 +237,7 @@ func getQueryInitDataAsync(ctx context.Context, w *workspace.Workspace, initData
if err != nil {
initData.Result.Error = fmt.Errorf("error acquiring database connection, %s", err.Error())
} else {
sessionResult.Session.Close()
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
}
}()
@@ -259,6 +259,7 @@ func startCancelHandler(cancel context.CancelFunc) {
signal.Notify(sigIntChannel, os.Interrupt)
go func() {
<-sigIntChannel
log.Println("[TRACE] got SIGINT")
// call context cancellation function
cancel()
// leave the channel open - any subsequent interrupts hits will be ignored

View File

@@ -230,7 +230,7 @@ func runServiceInForeground(invoker constants.Invoker) {
fmt.Print("\r")
// if we have received this signal, then the user probably wants to shut down
// everything. Shutdowns MUST NOT happen in cancellable contexts
count, err := db_local.GetCountOfConnectedClients(context.Background())
count, err := db_local.GetCountOfThirdPartyClients(context.Background())
if err != nil {
// report the error in the off chance that there's one
utils.ShowError(err)
@@ -398,7 +398,7 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) {
}
// check if there are any connected clients to the service
connectedClientCount, err := db_local.GetCountOfConnectedClients(cmd.Context())
connectedClientCount, err := db_local.GetCountOfThirdPartyClients(cmd.Context())
if err != nil {
display.StopSpinner(spinner)
utils.FailOnErrorWithMessage(err, "error during service stop")

View File

@@ -178,7 +178,10 @@ func (r *ControlRun) Execute(ctx context.Context, client db_common.Client) {
}
r.Lifecycle.Add("got_session")
dbSession := sessionResult.Session
defer dbSession.Close()
defer func() {
// do this in a closure, otherwise the argument will not get evaluated in calltime
dbSession.Close(utils.IsContextCancelled(ctx))
}()
// set our status
r.runStatus = ControlRunStarted

View File

@@ -179,7 +179,7 @@ func (r *ResultGroup) Execute(ctx context.Context, client db_common.Client, para
continue
}
go func(run *ControlRun) {
go func(c context.Context, run *ControlRun) {
defer func() {
if r := recover(); r != nil {
// if the Execute panic'ed, set it as an error
@@ -188,8 +188,8 @@ func (r *ResultGroup) Execute(ctx context.Context, client db_common.Client, para
// Release in defer, so that we don't retain the lock even if there's a panic inside
parallelismLock.Release(1)
}()
run.Execute(ctx, client)
}(controlRun)
run.Execute(c, client)
}(ctx, controlRun)
}
for _, child := range r.Groups {
child.Execute(ctx, client, parallelismLock)

View File

@@ -4,11 +4,15 @@ import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"sync"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/stdlib"
"github.com/spf13/viper"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/runtime_constants"
"github.com/turbot/steampipe/schema"
"github.com/turbot/steampipe/steampipeconfig"
"golang.org/x/sync/semaphore"
@@ -74,6 +78,13 @@ func establishConnection(ctx context.Context, connStr string) (*sql.DB, error) {
utils.LogTime("db_client.establishConnection start")
defer utils.LogTime("db_client.establishConnection end")
connConfig, _ := pgx.ParseConfig(connStr)
connConfig.RuntimeParams = map[string]string{
// set an app name so that we can track connections from this execution
"application_name": runtime_constants.PgClientAppName,
}
connStr = stdlib.RegisterConnConfig(connConfig)
db, err := sql.Open("pgx", connStr)
if err != nil {
return nil, err
@@ -104,8 +115,10 @@ func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallba
// Close implements Client
// closes the connection to the database and shuts down the backend
func (c *DbClient) Close() error {
log.Printf("[TRACE] DbClient.Close %v", c.dbClient)
if c.dbClient != nil {
c.sessionInitWaitGroup.Wait()
// clear the map - so that we can't reuse it
c.sessions = nil
return c.dbClient.Close()
@@ -133,7 +146,7 @@ func (c *DbClient) RefreshSessions(ctx context.Context) *db_common.AcquireSessio
}
sessionResult := c.AcquireSession(ctx)
if sessionResult.Session != nil {
sessionResult.Session.Close()
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
}
return sessionResult
}

View File

@@ -27,7 +27,11 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner
if sessionResult.Error != nil {
return nil, sessionResult.Error
}
defer sessionResult.Session.Close()
defer func() {
// we need to do this in a closure, otherwise the ctx will be evaluated immediately
// and not in call-time
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
}()
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, disableSpinner)
}
@@ -67,7 +71,7 @@ func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner boo
}
// define callback to close session when the async execution is complete
closeSessionCallback := func() { sessionResult.Session.Close() }
closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) }
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback, disableSpinner)
}

View File

@@ -131,7 +131,7 @@ func (c *DbClient) getSessionWithRetries(ctx context.Context) (*sql.Conn, uint32
})
if err != nil {
log.Printf("[TRACE] getSessionWithRetries failed after 10 retries: %s", err)
log.Printf("[TRACE] getSessionWithRetries failed after %d retries: %s", retries, err)
return nil, 0, err
}

View File

@@ -1,9 +1,11 @@
package db_common
import (
"context"
"database/sql"
"time"
"github.com/jackc/pgx/v4/stdlib"
"github.com/turbot/steampipe/utils"
)
@@ -33,9 +35,22 @@ func (s *DatabaseSession) UpdateUsage() {
s.LastUsed = time.Now()
}
func (s *DatabaseSession) Close() error {
func (s *DatabaseSession) Close(waitForCleanup bool) error {
var err error
if s.Connection != nil {
err := s.Connection.Close()
if waitForCleanup {
s.Connection.Raw(func(driverConn interface{}) error {
conn := driverConn.(*stdlib.Conn)
select {
case <-time.After(5 * time.Second):
return context.DeadlineExceeded
case <-conn.Conn().PgConn().CleanupDone():
return nil
}
})
}
err = s.Connection.Close()
s.Connection = nil
return err
}

View File

@@ -72,10 +72,13 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbCli
func (c *LocalDbClient) Close() error {
log.Printf("[TRACE] close local client %p", c)
if c.client != nil {
log.Printf("[TRACE] local client not NIL")
if err := c.client.Close(); err != nil {
return err
}
log.Printf("[TRACE] local client close complete")
}
log.Printf("[TRACE] shutdown local service %v", c.invoker)
// no context to pass on - use background
// we shouldn't do this in a context that can be cancelled anyway
ShutdownService(c.invoker)

View File

@@ -14,6 +14,7 @@ import (
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/plugin_manager"
"github.com/turbot/steampipe/runtime_constants"
"github.com/turbot/steampipe/utils"
)
@@ -44,7 +45,7 @@ func ShutdownService(invoker constants.Invoker) {
// how many clients are connected
// under a fresh context
count, _ := GetCountOfConnectedClients(context.Background())
count, _ := GetCountOfThirdPartyClients(context.Background())
if count > 0 {
// there are other clients connected to the database
// we can't stop the DB.
@@ -68,8 +69,8 @@ func ShutdownService(invoker constants.Invoker) {
}
// GetCountOfConnectedClients returns the number of clients currently connected to the service
func GetCountOfConnectedClients(ctx context.Context) (i int, e error) {
// GetCountOfThirdPartyClients returns the number of connections to the service from other thrid party applications
func GetCountOfThirdPartyClients(ctx context.Context) (i int, e error) {
utils.LogTime("db_local.GetCountOfConnectedClients start")
defer utils.LogTime(fmt.Sprintf("db_local.GetCountOfConnectedClients end:%d", i))
@@ -81,7 +82,8 @@ func GetCountOfConnectedClients(ctx context.Context) (i int, e error) {
clientCount := 0
// get the total number of connected clients
row := rootClient.QueryRow("select count(*) from pg_stat_activity where client_port IS NOT NULL and backend_type='client backend';")
// which are not us - determined by the unique application_name client parameter
row := rootClient.QueryRow("select count(*) from pg_stat_activity where client_port IS NOT NULL and backend_type='client backend' and application_name != $1;", runtime_constants.PgClientAppName)
row.Scan(&clientCount)
// clientCount can never be zero, since the client we are using to run the query counts as a client
// deduct the open connections in the pool of this client

View File

@@ -5,10 +5,10 @@ import (
"os"
filehelpers "github.com/turbot/go-kit/files"
"github.com/turbot/go-kit/helpers"
"github.com/hashicorp/go-hclog"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/cmd"
"github.com/turbot/steampipe/utils"
)

View File

@@ -0,0 +1,14 @@
package runtime_constants
import (
"fmt"
"time"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/utils"
)
var (
ExecutionID = utils.GetMD5Hash(fmt.Sprintf("%d", time.Now().Nanosecond()))
PgClientAppName = fmt.Sprintf("%s_%s", constants.AppName, ExecutionID)
)

View File

@@ -0,0 +1,5 @@
// The runtime_constants package contains values which
// are not constants during compilation, but should remain
// constant during the duration of an execution of the binary
package runtime_constants

View File

@@ -1,10 +1,14 @@
package utils
import (
"database/sql"
"encoding/json"
"fmt"
"os"
"strings"
"time"
typeHelpers "github.com/turbot/go-kit/types"
"github.com/spf13/viper"
)
@@ -24,3 +28,110 @@ func DebugDumpViper() {
}
fmt.Println(strings.Repeat("*", 80))
}
func DebugDumpRows(rows *sql.Rows) {
colTypes, err := rows.ColumnTypes()
if err != nil {
// we do not need to stream because
// defer takes care of it!
return
}
cols, err := rows.Columns()
if err != nil {
// we do not need to stream because
// defer takes care of it!
return
}
fmt.Println(cols)
fmt.Println("---------------------------------------")
for rows.Next() {
row, _ := readRow(rows, cols, colTypes)
rowAsString, _ := columnValuesAsString(row, colTypes)
fmt.Println(rowAsString)
}
}
func readRow(rows *sql.Rows, cols []string, colTypes []*sql.ColumnType) ([]interface{}, error) {
// slice of interfaces to receive the row data
columnValues := make([]interface{}, len(cols))
// make a slice of pointers to the result to pass to scan
resultPtrs := make([]interface{}, len(cols)) // A temporary interface{} slice
for i := range columnValues {
resultPtrs[i] = &columnValues[i]
}
rows.Scan(resultPtrs...)
return populateRow(columnValues, colTypes), nil
}
func populateRow(columnValues []interface{}, colTypes []*sql.ColumnType) []interface{} {
result := make([]interface{}, len(columnValues))
for i, columnValue := range columnValues {
if columnValue != nil {
colType := colTypes[i]
dbType := colType.DatabaseTypeName()
switch dbType {
case "JSON", "JSONB":
var val interface{}
if err := json.Unmarshal(columnValue.([]byte), &val); err != nil {
// what???
// TODO how to handle error
}
result[i] = val
default:
result[i] = columnValue
}
}
}
return result
}
// columnValuesAsString converts a slice of columns into strings
func columnValuesAsString(values []interface{}, columns []*sql.ColumnType) ([]string, error) {
rowAsString := make([]string, len(columns))
for idx, val := range values {
val, err := columnValueAsString(val, columns[idx])
if err != nil {
return nil, err
}
rowAsString[idx] = val
}
return rowAsString, nil
}
// columnValueAsString converts column value to string
func columnValueAsString(val interface{}, colType *sql.ColumnType) (result string, err error) {
defer func() {
if r := recover(); r != nil {
result = fmt.Sprintf("%v", val)
}
}()
if val == nil {
return "<null>", nil
}
//log.Printf("[TRACE] ColumnValueAsString type %s", colType.DatabaseTypeName())
// possible types for colType are defined in pq/oid/types.go
switch colType.DatabaseTypeName() {
case "JSON", "JSONB":
bytes, err := json.Marshal(val)
if err != nil {
return "", err
}
return string(bytes), nil
case "TIMESTAMP", "DATE", "TIME", "INTERVAL":
t, ok := val.(time.Time)
if ok {
return t.Format("2006-01-02 15:04:05"), nil
}
fallthrough
case "NAME":
result := string(val.([]uint8))
return result, nil
default:
return typeHelpers.ToString(val), nil
}
}