// mirror of https://github.com/turbot/steampipe.git
// synced 2026-02-23 08:00:51 -05:00
// 154 lines, 4.1 KiB, Go

package db_client
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"log"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v4/stdlib"
|
|
|
|
"github.com/sethvargo/go-retry"
|
|
"github.com/turbot/steampipe/pkg/db/db_common"
|
|
"github.com/turbot/steampipe/pkg/utils"
|
|
)
|
|
|
|
func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common.AcquireSessionResult) {
|
|
sessionResult = &db_common.AcquireSessionResult{}
|
|
c.sessionInitWaitGroup.Add(1)
|
|
defer c.sessionInitWaitGroup.Done()
|
|
|
|
defer func() {
|
|
if sessionResult != nil && sessionResult.Session != nil {
|
|
sessionResult.Session.UpdateUsage()
|
|
|
|
// fail safe - if there is no database connection, ensure we return an error
|
|
// NOTE: this should not be necessary but an occasional crash is occurring with a nil connection
|
|
if sessionResult.Session.Connection == nil && sessionResult.Error == nil {
|
|
sessionResult.Error = fmt.Errorf("nil database connection being returned from AcquireSession but no error was raised")
|
|
}
|
|
}
|
|
}()
|
|
|
|
// reload foreign schema names in case they changed based on a connection watcher event
|
|
if err := c.LoadForeignSchemaNames(ctx); err != nil {
|
|
sessionResult.Error = err
|
|
return
|
|
}
|
|
|
|
// get a database connection and query its backend pid
|
|
// note - this will retry if the connection is bad
|
|
databaseConnection, backendPid, err := c.getDatabaseConnectionWithRetries(ctx)
|
|
if err != nil {
|
|
sessionResult.Error = err
|
|
return sessionResult
|
|
}
|
|
|
|
c.sessionsMutex.Lock()
|
|
session, found := c.sessions[backendPid]
|
|
if !found {
|
|
session = db_common.NewDBSession(backendPid)
|
|
session.LifeCycle.Add("created")
|
|
}
|
|
// we get a new *sql.Conn everytime. USE IT!
|
|
session.Connection = databaseConnection
|
|
sessionResult.Session = session
|
|
c.sessionsMutex.Unlock()
|
|
|
|
// make sure that we close the acquired session, in case of error
|
|
defer func() {
|
|
if sessionResult.Error != nil && databaseConnection != nil {
|
|
databaseConnection.Close()
|
|
}
|
|
}()
|
|
|
|
// if there is no ensure session function, we are done
|
|
if c.ensureSessionFunc == nil {
|
|
return sessionResult
|
|
}
|
|
|
|
// update required session search path if needed
|
|
err = c.ensureSessionSearchPath(ctx, session)
|
|
if err != nil {
|
|
sessionResult.Error = err
|
|
return sessionResult
|
|
}
|
|
|
|
if !session.Initialized {
|
|
session.LifeCycle.Add("queued_for_init")
|
|
|
|
err := c.parallelSessionInitLock.Acquire(ctx, 1)
|
|
if err != nil {
|
|
sessionResult.Error = err
|
|
return sessionResult
|
|
}
|
|
c.sessionInitWaitGroup.Add(1)
|
|
|
|
session.LifeCycle.Add("init_start")
|
|
err, warnings := c.ensureSessionFunc(ctx, session)
|
|
session.LifeCycle.Add("init_finish")
|
|
sessionResult.Warnings = warnings
|
|
c.sessionInitWaitGroup.Done()
|
|
c.parallelSessionInitLock.Release(1)
|
|
if err != nil {
|
|
sessionResult.Error = err
|
|
return sessionResult
|
|
}
|
|
|
|
// if there is no error, mark session as initialized
|
|
session.Initialized = true
|
|
}
|
|
|
|
// now write back to the map
|
|
c.sessionsMutex.Lock()
|
|
c.sessions[backendPid] = session
|
|
c.sessionsMutex.Unlock()
|
|
|
|
return sessionResult
|
|
}
|
|
|
|
func (c *DbClient) getDatabaseConnectionWithRetries(ctx context.Context) (*sql.Conn, uint32, error) {
|
|
backoff, err := retry.NewFibonacci(100 * time.Millisecond)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
|
|
var databaseConnection *sql.Conn
|
|
var backendPid uint32
|
|
|
|
retries := 0
|
|
const getSessionMaxRetries = 10
|
|
err = retry.Do(ctx, retry.WithMaxRetries(getSessionMaxRetries, backoff), func(retryLocalCtx context.Context) (e error) {
|
|
if utils.IsContextCancelled(retryLocalCtx) {
|
|
return retryLocalCtx.Err()
|
|
}
|
|
// get a database connection from the pool
|
|
databaseConnection, err = c.dbClient.Conn(retryLocalCtx)
|
|
if err != nil {
|
|
if databaseConnection != nil {
|
|
databaseConnection.Close()
|
|
}
|
|
retries++
|
|
return retry.RetryableError(err)
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
log.Printf("[TRACE] getDatabaseConnectionWithRetries failed after %d retries: %s", retries, err)
|
|
return nil, 0, err
|
|
}
|
|
|
|
if retries > 0 {
|
|
log.Printf("[TRACE] getDatabaseConnectionWithRetries succeeded after %d retries", retries)
|
|
}
|
|
|
|
databaseConnection.Raw(func(driverConn interface{}) error {
|
|
backendPid = driverConn.(*stdlib.Conn).Conn().PgConn().PID()
|
|
return nil
|
|
})
|
|
|
|
return databaseConnection, uint32(backendPid), nil
|
|
}
|