Remove usage of prepared statements - instead excute sql directly. Wrap with executions in LeafRuns to support runtime dependency resolution. Closes #2789. #2772

This commit is contained in:
kaidaguerre
2022-11-23 14:11:56 +00:00
committed by GitHub
parent 398334b8f0
commit fe6365b1ef
40 changed files with 398 additions and 408 deletions

View File

@@ -187,11 +187,11 @@ func executeSnapshotQuery(initData *query.InitData, ctx context.Context) int {
var queryNames = utils.SortedMapKeys(initData.Queries)
if len(queryNames) > 0 {
for i, name := range queryNames {
queryString := initData.Queries[name]
for _, name := range queryNames {
resolvedQuery := initData.Queries[name]
// if a manual query is being run (i.e. not a named query), convert into a query and add to workspace
// this is to allow us to use existing dashboard execution code
queryProvider, existingResource := ensureQueryResource(name, queryString, i, len(queryNames), initData.Workspace)
queryProvider, existingResource := ensureQueryResource(name, resolvedQuery, initData.Workspace)
// we need to pass the embedded initData to GenerateSnapshot
baseInitData := &initData.InitData
@@ -274,9 +274,9 @@ func snapshotToQueryResult(snap *dashboardtypes.SteampipeSnapshot, name string)
return res, nil
}
// convert the given command line query intos a query resource and add to workspace
// convert the given command line query into a query resource and add to workspace
// this is to allow us to use existing dashboard execution code
func ensureQueryResource(name string, query string, queryIdx, queryCount int, w *workspace.Workspace) (queryProvider modconfig.HclResource, existingResource bool) {
func ensureQueryResource(name string, resolvedQuery *modconfig.ResolvedQuery, w *workspace.Workspace) (queryProvider modconfig.HclResource, existingResource bool) {
// is this an existing resource?
if parsedName, err := modconfig.ParseResourceName(name); err == nil {
if resource, found := modconfig.GetResource(w, parsedName); found {
@@ -287,9 +287,10 @@ func ensureQueryResource(name string, query string, queryIdx, queryCount int, w
// build name
shortName := "command_line_query"
// create the query
// this is NOT a named query - create the query using RawSql
q := modconfig.NewQuery(&hcl.Block{}, w.Mod, shortName).(*modconfig.Query)
q.SQL = utils.ToStringPointer(query)
q.SQL = utils.ToStringPointer(resolvedQuery.RawSQL)
q.SetArgs(resolvedQuery.QueryArgs())
// add empty metadata
q.SetMetadata(&modconfig.ResourceMetadata{})

View File

@@ -265,7 +265,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
}()
// resolve the control query
query, err := r.resolveControlQuery(control)
resolvedQuery, err := r.resolveControlQuery(control)
if err != nil {
r.setError(ctx, err)
return
@@ -282,7 +282,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
// execute the control query
// NOTE no need to pass an OnComplete callback - we are already closing our session after waiting for results
log.Printf("[TRACE] execute start for, %s\n", control.Name())
queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil)
queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, nil, resolvedQuery.ExecuteSQL, resolvedQuery.Args...)
log.Printf("[TRACE] execute finish for, %s\n", control.Name())
if err != nil {
@@ -334,15 +334,12 @@ func (r *ControlRun) getControlQueryContext(ctx context.Context) context.Context
return newCtx
}
func (r *ControlRun) resolveControlQuery(control *modconfig.Control) (string, error) {
func (r *ControlRun) resolveControlQuery(control *modconfig.Control) (*modconfig.ResolvedQuery, error) {
resolvedQuery, err := r.Tree.Workspace.ResolveQueryFromQueryProvider(control, nil)
if err != nil {
return "", fmt.Errorf(`cannot run %s - failed to resolve query "%s": %s`, control.Name(), typehelpers.SafeString(control.SQL), err.Error())
return nil, fmt.Errorf(`cannot run %s - failed to resolve query "%s": %s`, control.Name(), typehelpers.SafeString(control.SQL), err.Error())
}
if resolvedQuery.ExecuteSQL == "" {
return "", fmt.Errorf(`cannot run %s - failed to resolve query "%s"`, control.Name(), typehelpers.SafeString(control.SQL))
}
return resolvedQuery.ExecuteSQL, nil
return resolvedQuery, nil
}
func (r *ControlRun) waitForResults(ctx context.Context) {

View File

@@ -206,19 +206,19 @@ func (e *ExecutionTree) getExecutionRootFromArg(arg string) (modconfig.ModTreeIt
// This is used to implement the 'where' control filtering
func (e *ExecutionTree) getControlMapFromWhereClause(ctx context.Context, whereClause string) (map[string]bool, error) {
// query may either be a 'where' clause, or a named query
query, _, err := e.Workspace.ResolveQueryAndArgsFromSQLString(whereClause)
resolvedQuery, _, err := e.Workspace.ResolveQueryAndArgsFromSQLString(whereClause)
if err != nil {
return nil, err
}
// did we in fact resolve a named query, or just return the 'name' as the query
isNamedQuery := query != whereClause
isNamedQuery := resolvedQuery.ExecuteSQL != whereClause
// if the query is NOT a named query, we need to construct a full query by adding a select
if !isNamedQuery {
query = fmt.Sprintf("select resource_name from %s where %s", constants.IntrospectionTableControl, whereClause)
resolvedQuery.ExecuteSQL = fmt.Sprintf("select resource_name from %s where %s", constants.IntrospectionTableControl, whereClause)
}
res, err := e.client.ExecuteSync(ctx, query)
res, err := e.client.ExecuteSync(ctx, resolvedQuery.ExecuteSQL, resolvedQuery.Args...)
if err != nil {
return nil, err
}

View File

@@ -2,9 +2,8 @@ package dashboardexecute
import (
"context"
"encoding/json"
"fmt"
typehelpers "github.com/turbot/go-kit/types"
"github.com/turbot/steampipe/pkg/type_conversion"
"log"
"strconv"
"sync"
@@ -26,7 +25,7 @@ type LeafRun struct {
Type string `cty:"type" hcl:"type" column:"type,text" json:"display_type,omitempty"`
Display string `cty:"display" hcl:"display" json:"display,omitempty"`
RawSQL string `json:"sql,omitempty"`
Args []string `json:"args,omitempty"`
Args []any `json:"args,omitempty"`
Params []*modconfig.ParamDef `json:"params,omitempty"`
Data *dashboardtypes.LeafData `json:"data,omitempty"`
ErrorString string `json:"error,omitempty"`
@@ -46,6 +45,7 @@ type LeafRun struct {
childComplete chan dashboardtypes.DashboardNodeRun
withValues map[string]*dashboardtypes.LeafData
withValueMutex sync.Mutex
withRuns []*LeafRun
}
func (r *LeafRun) AsTreeNode() *dashboardtypes.SnapshotTreeNode {
@@ -76,7 +76,8 @@ func NewLeafRun(resource modconfig.DashboardLeafNode, parent dashboardtypes.Dash
executionTree: executionTree,
parent: parent,
runtimeDependencies: make(map[string]*ResolvedRuntimeDependency),
withValues: make(map[string]*dashboardtypes.LeafData),
withValues: make(map[string]*dashboardtypes.LeafData),
}
parsedName, err := modconfig.ParseResourceName(resource.Name())
@@ -99,10 +100,29 @@ func NewLeafRun(resource modconfig.DashboardLeafNode, parent dashboardtypes.Dash
executionTree.runs[r.Name] = r
// if we have children (nodes/edges), create runs for them
if children := resource.GetChildren(); len(children) > 0 {
children := resource.GetChildren()
if len(children) > 0 {
// create the child runs
return r.createChildRuns(children, executionTree)
err := r.createChildRuns(children, executionTree)
if err != nil {
return nil, err
}
}
// if we have with blocks, create runs for them
withBlocks := r.DashboardNode.(modconfig.QueryProvider).GetWiths()
if len(withBlocks) > 0 {
// create the child runs
err := r.createWithRuns(withBlocks, executionTree)
if err != nil {
return nil, err
}
}
// create buffered child complete chan
if childCount := len(children) + len(withBlocks); childCount > 0 {
r.childComplete = make(chan dashboardtypes.DashboardNodeRun, childCount)
}
return r, nil
}
@@ -128,18 +148,19 @@ func (r *LeafRun) addRuntimeDependencies() {
r.runtimeDependencies[name] = NewResolvedRuntimeDependency(dep, getValueFunc)
}
// if the parent is a leaf run, we must be a node or an edge, inherit our parent runtime dependencies
if parentLeafRun, ok := r.parent.(*LeafRun); ok {
for name, dep := range parentLeafRun.runtimeDependencies {
if _, ok := r.runtimeDependencies[name]; !ok {
r.runtimeDependencies[name] = dep
// NOTE: UNLESS we are a 'with' run
if _, isWith := r.DashboardNode.(*modconfig.DashboardWith); !isWith {
if parentLeafRun, ok := r.parent.(*LeafRun); ok {
for name, dep := range parentLeafRun.runtimeDependencies {
if _, ok := r.runtimeDependencies[name]; !ok {
r.runtimeDependencies[name] = dep
}
}
}
}
}
func (r *LeafRun) createChildRuns(children []modconfig.ModTreeItem, executionTree *DashboardExecutionTree) (*LeafRun, error) {
// create buffered child complete chan
r.childComplete = make(chan dashboardtypes.DashboardNodeRun, len(children))
func (r *LeafRun) createChildRuns(children []modconfig.ModTreeItem, executionTree *DashboardExecutionTree) error {
r.children = make([]dashboardtypes.DashboardNodeRun, len(children))
var errors []error
@@ -153,7 +174,20 @@ func (r *LeafRun) createChildRuns(children []modconfig.ModTreeItem, executionTre
}
r.children[i] = childRun
}
return r, error_helpers.CombineErrors(errors...)
return error_helpers.CombineErrors(errors...)
}
func (r *LeafRun) createWithRuns(withs []*modconfig.DashboardWith, executionTree *DashboardExecutionTree) error {
r.withRuns = make([]*LeafRun, len(withs))
for i, w := range withs {
withRun, err := NewLeafRun(w, r, executionTree)
if err != nil {
return err
}
r.withRuns[i] = withRun
}
return nil
}
// if we have a query provider which requires execution OR we have children, set status to ready
@@ -181,16 +215,12 @@ func (r *LeafRun) Execute(ctx context.Context) {
// to get here, we must be a query provider
// start all `with` blocks
for _, w := range r.DashboardNode.(modconfig.QueryProvider).GetWiths() {
queryResult, err := r.executionTree.client.ExecuteSync(ctx, typehelpers.SafeString(w.SQL))
if err != nil {
r.SetError(ctx, err)
return
}
withResult := dashboardtypes.NewLeafData(queryResult)
r.setWithValue(w.UnqualifiedName, withResult)
if len(r.withRuns) > 0 {
r.executeWithRuns(ctx)
}
// TODO KAI handle error in with block
// if there are any unresolved runtime dependencies, wait for them
if len(r.runtimeDependencies) > 0 {
if err := r.waitForRuntimeDependencies(ctx); err != nil {
@@ -287,6 +317,15 @@ func (r *LeafRun) ChildrenComplete() bool {
return true
}
func (r *LeafRun) withComplete() bool {
for _, w := range r.withRuns {
if !w.RunComplete() {
return false
}
}
return true
}
// IsSnapshotPanel implements SnapshotPanel
func (*LeafRun) IsSnapshotPanel() {}
@@ -399,21 +438,22 @@ func (r *LeafRun) buildRuntimeDependencyArgs() (*modconfig.QueryArgs, error) {
// build map of default params
for _, dep := range r.runtimeDependencies {
// format the arg value as a postgres string
formattedVal, err := type_conversion.GoToPostgresString(dep.value)
// TACTICAL
// format the arg value as a JSON string
jsonBytes, err := json.Marshal(dep.value)
valStr := string(jsonBytes)
if err != nil {
return nil, err
}
if dep.dependency.ArgName != nil {
res.ArgMap[*dep.dependency.ArgName] = formattedVal
res.ArgMap[*dep.dependency.ArgName] = valStr
} else {
if dep.dependency.ArgIndex == nil {
return nil, fmt.Errorf("invalid runtime dependency - both ArgName and ArgIndex are nil ")
}
// now add at correct index
res.ArgList[*dep.dependency.ArgIndex] = &formattedVal
res.ArgList[*dep.dependency.ArgIndex] = &valStr
}
}
return res, nil
@@ -423,19 +463,8 @@ func (r *LeafRun) buildRuntimeDependencyArgs() (*modconfig.QueryArgs, error) {
func (r *LeafRun) executeQuery(ctx context.Context) {
log.Printf("[TRACE] LeafRun '%s' SQL resolved, executing", r.DashboardNode.Name())
queryResult, err := r.executionTree.client.ExecuteSync(ctx, r.executeSQL)
queryResult, err := r.executionTree.client.ExecuteSync(ctx, r.executeSQL, r.Args...)
if err != nil {
query := r.DashboardNode.(modconfig.QueryProvider).GetQuery()
if query != nil {
queryName := query.Name()
// get the query and any prepared statement error from the workspace
preparedStatementFailure := r.executionTree.workspace.GetPreparedStatementCreationFailure(queryName)
if preparedStatementFailure != nil {
declRange := preparedStatementFailure.Query.GetDeclRange()
preparedStatementError := preparedStatementFailure.Error
err = error_helpers.EnrichPreparedStatementError(err, queryName, preparedStatementError, declRange)
}
}
log.Printf("[TRACE] LeafRun '%s' query failed: %s", r.DashboardNode.Name(), err.Error())
// set the error status on the counter - this will raise counter error event
r.SetError(ctx, err)
@@ -450,6 +479,40 @@ func (r *LeafRun) executeQuery(ctx context.Context) {
r.SetComplete(ctx)
}
// if this leaf run has with runs), execute them
func (r *LeafRun) executeWithRuns(ctx context.Context) {
for _, w := range r.withRuns {
go w.Execute(ctx)
}
// wait for children to complete
var errors []error
for !r.withComplete() {
log.Printf("[TRACE] run %s waiting for with runs", r.Name)
completeChild := <-r.childComplete
log.Printf("[TRACE] run %s got with complete", r.Name)
if completeChild.GetRunStatus() == dashboardtypes.DashboardRunError {
errors = append(errors, completeChild.GetError())
}
// fall through to recheck ChildrenComplete
}
log.Printf("[TRACE] run %s ALL children complete", r.Name)
// so all with runs have completed - check for errors
err := error_helpers.CombineErrors(errors...)
if err == nil {
r.setWithData()
} else {
r.SetError(ctx, err)
}
}
func (r *LeafRun) setWithData() {
for _, w := range r.withRuns {
r.setWithValue(w.DashboardNode.GetUnqualifiedName(), w.Data)
}
}
// if this leaf run has children (nodes/edges), execute them
func (r *LeafRun) executeChildren(ctx context.Context) {
for _, c := range r.children {
@@ -585,7 +648,6 @@ func columnValuesFromRows(column string, rows []map[string]interface{}) (any, er
}
return res, nil
}
func (r *LeafRun) setWithValue(name string, result *dashboardtypes.LeafData) {
r.withValueMutex.Lock()
defer r.withValueMutex.Unlock()

View File

@@ -26,7 +26,7 @@ import (
// ExecuteSync implements Client
// execute a query against this client and wait for the result
func (c *DbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) {
func (c *DbClient) ExecuteSync(ctx context.Context, query string, args ...any) (*queryresult.SyncQueryResult, error) {
// acquire a session
sessionResult := c.AcquireSession(ctx)
if sessionResult.Error != nil {
@@ -42,17 +42,17 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.
// and not in call-time
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
}()
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query)
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, args...)
}
// ExecuteSyncInSession implements Client
// execute a query against this client and wait for the result
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) {
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*queryresult.SyncQueryResult, error) {
if query == "" {
return &queryresult.SyncQueryResult{}, nil
}
result, err := c.ExecuteInSession(ctx, session, query, nil)
result, err := c.ExecuteInSession(ctx, session, nil, query, args...)
if err != nil {
return nil, error_helpers.WrapError(err)
}
@@ -79,7 +79,7 @@ func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.
// Execute implements Client
// execute the query in the given Context
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) Execute(ctx context.Context, query string) (*queryresult.Result, error) {
func (c *DbClient) Execute(ctx context.Context, query string, args ...any) (*queryresult.Result, error) {
// acquire a session
sessionResult := c.AcquireSession(ctx)
if sessionResult.Error != nil {
@@ -91,14 +91,14 @@ func (c *DbClient) Execute(ctx context.Context, query string) (*queryresult.Resu
// define callback to close session when the async execution is complete
closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) }
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback)
return c.ExecuteInSession(ctx, sessionResult.Session, closeSessionCallback, query, args...)
}
// ExecuteInSession implements Client
// execute the query in the given Context using the provided DatabaseSession
// ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) {
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onComplete func(), query string, args ...any) (res *queryresult.Result, err error) {
if query == "" {
return queryresult.NewResult(nil), nil
}
@@ -138,7 +138,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data
// start query
var rows pgx.Rows
rows, err = c.startQuery(ctxExecute, query, session.Connection)
rows, err = c.startQuery(ctxExecute, session.Connection, query, args...)
if err != nil {
return
}
@@ -250,12 +250,12 @@ func (c *DbClient) updateScanMetadataMaxId(ctx context.Context, session *db_comm
// run query in a goroutine, so we can check for cancellation
// in case the client becomes unresponsive and does not respect context cancellation
func (c *DbClient) startQuery(ctx context.Context, query string, conn *pgxpool.Conn) (rows pgx.Rows, err error) {
func (c *DbClient) startQuery(ctx context.Context, conn *pgxpool.Conn, query string, args ...any) (rows pgx.Rows, err error) {
doneChan := make(chan bool)
go func() {
// start asynchronous query
rows, err = conn.Query(ctx, query)
rows, err = conn.Query(ctx, query, args...)
close(doneChan)
}()

View File

@@ -24,11 +24,11 @@ type Client interface {
AcquireSession(context.Context) *AcquireSessionResult
ExecuteSync(context.Context, string) (*queryresult.SyncQueryResult, error)
Execute(context.Context, string) (*queryresult.Result, error)
ExecuteSync(context.Context, string, ...any) (*queryresult.SyncQueryResult, error)
Execute(context.Context, string, ...any) (*queryresult.Result, error)
ExecuteSyncInSession(context.Context, *DatabaseSession, string) (*queryresult.SyncQueryResult, error)
ExecuteInSession(context.Context, *DatabaseSession, string, func()) (*queryresult.Result, error)
ExecuteSyncInSession(context.Context, *DatabaseSession, string, ...any) (*queryresult.SyncQueryResult, error)
ExecuteInSession(context.Context, *DatabaseSession, func(), string, ...any) (*queryresult.Result, error)
CacheOn(context.Context) error
CacheOff(context.Context) error

View File

@@ -8,12 +8,12 @@ import (
)
// ExecuteQuery executes a single query. If shutdownAfterCompletion is true, shutdown the client after completion
func ExecuteQuery(ctx context.Context, queryString string, client Client) (*queryresult.ResultStreamer, error) {
func ExecuteQuery(ctx context.Context, client Client, queryString string, args ...any) (*queryresult.ResultStreamer, error) {
utils.LogTime("db.ExecuteQuery start")
defer utils.LogTime("db.ExecuteQuery end")
resultsStreamer := queryresult.NewResultStreamer()
result, err := client.Execute(ctx, queryString)
result, err := client.Execute(ctx, queryString, args...)
if err != nil {
return nil, err
}

View File

@@ -6,7 +6,6 @@ import (
"github.com/spf13/viper"
"github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/error_helpers"
"github.com/turbot/steampipe/pkg/steampipeconfig"
)
type InitResult struct {
@@ -54,9 +53,3 @@ func (r *InitResult) DisplayMessages() {
r.DisplayMessage(context.Background(), m)
}
}
func (r *InitResult) AddPreparedStatementFailures(preparedStatementFailures map[string]*steampipeconfig.PreparedStatementFailure) {
for _, failure := range preparedStatementFailures {
r.AddWarnings(failure.String())
}
}

View File

@@ -1,86 +0,0 @@
package db_common
import (
"context"
"fmt"
"golang.org/x/exp/maps"
"log"
"strings"
"github.com/jackc/pgx/v4"
typehelpers "github.com/turbot/go-kit/types"
"github.com/turbot/steampipe/pkg/steampipeconfig/modconfig"
"github.com/turbot/steampipe/pkg/utils"
)
type PrepareStatementFailures struct {
Failures map[string]error
Error error
}
func NewPrepareStatementFailures() *PrepareStatementFailures {
return &PrepareStatementFailures{Failures: make(map[string]error)}
}
func CreatePreparedStatements(ctx context.Context, resourceMaps *modconfig.ResourceMaps, conn *pgx.Conn, combineSql bool) (error, *PrepareStatementFailures) {
log.Printf("[TRACE] CreatePreparedStatements")
utils.LogTime("db.CreatePreparedStatements start")
defer utils.LogTime("db.CreatePreparedStatements end")
// first get the SQL to create all prepared statements
sqlMap := GetPreparedStatementsSQL(resourceMaps)
if len(sqlMap) == 0 {
return nil, nil
}
// map of prepared statement failures, keyed by query name
failureMap := NewPrepareStatementFailures()
if combineSql {
sql := strings.Join(maps.Values(sqlMap), ";\n")
if _, err := conn.Exec(ctx, sql); err != nil {
failureMap.Error = err
}
} else {
for name, sql := range sqlMap {
if _, err := conn.Exec(ctx, sql); err != nil {
failureMap.Failures[name] = err
}
}
}
// return context error - this enables calling code to respond to cancellation
return ctx.Err(), failureMap
}
func GetPreparedStatementsSQL(resourceMaps *modconfig.ResourceMaps) map[string]string {
// make map of resource name to create SQL
sqlMap := make(map[string]string)
for _, queryProvider := range resourceMaps.QueryProviders() {
if createSQL := getPreparedStatementCreateSql(queryProvider); createSQL != nil {
sqlMap[queryProvider.Name()] = *createSQL
}
}
return sqlMap
}
func getPreparedStatementCreateSql(queryProvider modconfig.QueryProvider) *string {
// the query is a prepared statement if it defines its own sql and has parameters or (positional) arguments
if !modconfig.QueryProviderIsParameterised(queryProvider) {
return nil
}
// if the query provider has params, is MUST define SQL
// remove trailing semicolons from sql as this breaks the prepare statement
rawSql := cleanPreparedStatementCreateSQL(typehelpers.SafeString(queryProvider.GetSQL()))
preparedStatementName := queryProvider.GetPreparedStatementName()
createSQL := fmt.Sprintf("PREPARE %s AS (\n%s\n)", preparedStatementName, rawSql)
return &createSQL
}
func cleanPreparedStatementCreateSQL(query string) string {
rawSql := strings.TrimRight(strings.TrimSpace(query), ";")
return rawSql
}

View File

@@ -106,23 +106,23 @@ func (c *LocalDbClient) AcquireSession(ctx context.Context) *db_common.AcquireSe
}
// ExecuteSync implements Client
func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSync(ctx, query)
func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string, args ...any) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSync(ctx, query, args...)
}
// ExecuteSyncInSession implements Client
func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSyncInSession(ctx, session, query)
func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSyncInSession(ctx, session, query, args...)
}
// ExecuteInSession implements Client
func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) {
return c.client.ExecuteInSession(ctx, session, query, onComplete)
func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onComplete func(), query string, args ...any) (res *queryresult.Result, err error) {
return c.client.ExecuteInSession(ctx, session, onComplete, query, args...)
}
// Execute implements Client
func (c *LocalDbClient) Execute(ctx context.Context, query string) (res *queryresult.Result, err error) {
return c.client.Execute(ctx, query)
func (c *LocalDbClient) Execute(ctx context.Context, query string, args ...any) (res *queryresult.Result, err error) {
return c.client.Execute(ctx, query, args...)
}
// CacheOn implements Client

View File

@@ -114,14 +114,7 @@ func (i *InitData) Init(ctx context.Context, invoker constants.Invoker) (res *In
sessionDataSource := workspace.NewSessionDataSource(i.Workspace, i.PreparedStatementSource)
// define db connection callback function
ensureSessionData := func(ctx context.Context, conn *pgx.Conn) error {
// if we are connecting to Steampipecloud, combine prepared statement sql when creating prepared statements
// to optimise performance
combineSql := cloudMetadata != nil
err, preparedStatementFailures := workspace.EnsureSessionData(ctx, sessionDataSource, conn, combineSql)
i.Workspace.HandlePreparedStatementFailures(preparedStatementFailures)
return err
return workspace.EnsureSessionData(ctx, sessionDataSource, conn)
}
// get a client
@@ -155,8 +148,6 @@ func (i *InitData) Init(ctx context.Context, invoker constants.Invoker) (res *In
}
// add refresh connection warnings
i.Result.AddWarnings(refreshResult.Warnings...)
// add warnings from prepared statement creation
i.Result.AddPreparedStatementFailures(i.Workspace.GetPreparedStatementFailures())
return
}

View File

@@ -388,8 +388,8 @@ func (c *InteractiveClient) executor(ctx context.Context, line string) {
line = strings.TrimSpace(line)
query := c.getQuery(ctx, line)
if query == "" {
resolvedQuery := c.getQuery(ctx, line)
if resolvedQuery == nil {
// we failed to resolve a query, or are in the middle of a multi-line entry
// restart the prompt, DO NOT clear the interactive buffer
c.restartInteractiveSession()
@@ -401,8 +401,8 @@ func (c *InteractiveClient) executor(ctx context.Context, line string) {
// create a context for the execution of the query
queryCtx := c.createQueryContext(ctx)
if metaquery.IsMetaQuery(query) {
if err := c.executeMetaquery(queryCtx, query); err != nil {
if metaquery.IsMetaQuery(resolvedQuery.ExecuteSQL) {
if err := c.executeMetaquery(queryCtx, resolvedQuery.ExecuteSQL); err != nil {
error_helpers.ShowError(ctx, err)
}
// cancel the context
@@ -411,7 +411,7 @@ func (c *InteractiveClient) executor(ctx context.Context, line string) {
} else {
// otherwise execute query
t := time.Now()
result, err := c.client().Execute(queryCtx, query)
result, err := c.client().Execute(queryCtx, resolvedQuery.ExecuteSQL, resolvedQuery.Args...)
if err != nil {
error_helpers.ShowError(ctx, error_helpers.HandleCancelError(err))
// if timing flag is enabled, show the time taken for the query to fail
@@ -427,10 +427,10 @@ func (c *InteractiveClient) executor(ctx context.Context, line string) {
c.restartInteractiveSession()
}
func (c *InteractiveClient) getQuery(ctx context.Context, line string) string {
func (c *InteractiveClient) getQuery(ctx context.Context, line string) *modconfig.ResolvedQuery {
// if it's an empty line, then we don't need to do anything
if line == "" {
return ""
return nil
}
// store the history (the raw line which was entered)
@@ -463,7 +463,7 @@ func (c *InteractiveClient) getQuery(ctx context.Context, line string) string {
// clear the interactive buffer
c.interactiveBuffer = nil
// error will have been handled elsewhere
return ""
return nil
}
}
@@ -474,7 +474,7 @@ func (c *InteractiveClient) getQuery(ctx context.Context, line string) string {
queryString := strings.Join(c.interactiveBuffer, "\n")
// in case of a named query call with params, parse the where clause
query, queryProvider, err := c.workspace().ResolveQueryAndArgsFromSQLString(queryString)
resolvedQuery, queryProvider, err := c.workspace().ResolveQueryAndArgsFromSQLString(queryString)
if err != nil {
// if we fail to resolve:
// - show error but do not return it so we stay in the prompt
@@ -482,7 +482,7 @@ func (c *InteractiveClient) getQuery(ctx context.Context, line string) string {
// - clear interactive buffer
c.interactiveBuffer = nil
error_helpers.ShowError(ctx, err)
return ""
return nil
}
isNamedQuery := queryProvider != nil
@@ -493,7 +493,7 @@ func (c *InteractiveClient) getQuery(ctx context.Context, line string) string {
// is we are not executing, do not store history
historyEntry = ""
// do not clear interactive buffer
return ""
return nil
}
// so we need to execute
@@ -503,18 +503,18 @@ func (c *InteractiveClient) getQuery(ctx context.Context, line string) string {
// what are we executing?
// if the line is ONLY a semicolon, do nothing and restart interactive session
if strings.TrimSpace(query) == ";" {
if strings.TrimSpace(resolvedQuery.ExecuteSQL) == ";" {
// do not store in history
historyEntry = ""
c.restartInteractiveSession()
return ""
return nil
}
// if this is a multiline query, update history entry
if !isNamedQuery && len(strings.Split(query, "\n")) > 1 {
historyEntry = query
if !isNamedQuery && len(strings.Split(resolvedQuery.ExecuteSQL, "\n")) > 1 {
historyEntry = resolvedQuery.ExecuteSQL
}
return query
return resolvedQuery
}
func (c *InteractiveClient) executeMetaquery(ctx context.Context, query string) error {

View File

@@ -7,6 +7,7 @@ import (
"github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/export"
"github.com/turbot/steampipe/pkg/initialisation"
"github.com/turbot/steampipe/pkg/steampipeconfig/modconfig"
"github.com/turbot/steampipe/pkg/workspace"
)
@@ -14,8 +15,8 @@ type InitData struct {
initialisation.InitData
cancelInitialisation context.CancelFunc
Loaded chan struct{}
// map of query name to query (key is the query text for command line queries)
Queries map[string]string
// map of query name to resolved query (key is the query text for command line queries)
Queries map[string]*modconfig.ResolvedQuery
}
// NewInitData returns a new InitData object
@@ -90,7 +91,7 @@ func (i *InitData) init(ctx context.Context, w *workspace.Workspace, args []stri
// set max DB connections to 1
viper.Set(constants.ArgMaxParallel, 1)
// convert the query or sql file arg into an array of executable queries - check names queries in the current workspace
queries, preparedStatementSource, err := w.GetQueriesFromArgs(args)
resolvedQueries, preparedStatementSource, err := w.GetQueriesFromArgs(args)
if err != nil {
i.Result.Error = err
return
@@ -99,7 +100,7 @@ func (i *InitData) init(ctx context.Context, w *workspace.Workspace, args []stri
ctx, cancel := context.WithCancel(ctx)
// and store it
i.cancelInitialisation = cancel
i.Queries = queries
i.Queries = resolvedQueries
i.PreparedStatementSource = preparedStatementSource
// and call base init

View File

@@ -14,6 +14,7 @@ import (
"github.com/turbot/steampipe/pkg/error_helpers"
"github.com/turbot/steampipe/pkg/interactive"
"github.com/turbot/steampipe/pkg/query"
"github.com/turbot/steampipe/pkg/steampipeconfig/modconfig"
"github.com/turbot/steampipe/pkg/utils"
)
@@ -65,11 +66,11 @@ func executeQueries(ctx context.Context, initData *query.InitData) int {
t := time.Now()
// build ordered list of queries
// (ordered for testing repeatability)
var queryNames []string = utils.SortedMapKeys(initData.Queries)
var queryNames = utils.SortedMapKeys(initData.Queries)
for i, name := range queryNames {
q := initData.Queries[name]
if err := executeQuery(ctx, q, initData.Client); err != nil {
if err := executeQuery(ctx, initData.Client, q); err != nil {
failures++
error_helpers.ShowWarning(fmt.Sprintf("executeQueries: query %d of %d failed: %v", i+1, len(queryNames), error_helpers.DecodePgError(err)))
// if timing flag is enabled, show the time taken for the query to fail
@@ -87,12 +88,12 @@ func executeQueries(ctx context.Context, initData *query.InitData) int {
return failures
}
func executeQuery(ctx context.Context, queryString string, client db_common.Client) error {
func executeQuery(ctx context.Context, client db_common.Client, resolvedQuery *modconfig.ResolvedQuery) error {
utils.LogTime("query.execute.executeQuery start")
defer utils.LogTime("query.execute.executeQuery end")
// the db executor sends result data over resultsStreamer
resultsStreamer, err := db_common.ExecuteQuery(ctx, queryString, client)
resultsStreamer, err := db_common.ExecuteQuery(ctx, client, resolvedQuery.ExecuteSQL, resolvedQuery.Args...)
if err != nil {
return err
}

View File

@@ -320,10 +320,10 @@ func (c *Control) GetPreparedStatementName() string {
return c.PreparedStatementName
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (c *Control) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// GetResolvedQuery implements QueryProvider
func (c *Control) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return c.getPreparedStatementExecuteSQL(c, runtimeArgs)
return c.getResolvedQuery(c, runtimeArgs)
}
// GetWidth implements DashboardLeafNode

View File

@@ -266,9 +266,9 @@ func (c *DashboardCard) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (c *DashboardCard) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (c *DashboardCard) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return c.getPreparedStatementExecuteSQL(c, runtimeArgs)
return c.getResolvedQuery(c, runtimeArgs)
}
func (c *DashboardCard) setBaseProperties(resourceMapProvider ResourceMapsProvider) {

View File

@@ -289,9 +289,9 @@ func (c *DashboardChart) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (c *DashboardChart) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (c *DashboardChart) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return c.getPreparedStatementExecuteSQL(c, runtimeArgs)
return c.getResolvedQuery(c, runtimeArgs)
}
func (c *DashboardChart) setBaseProperties(resourceMapProvider ResourceMapsProvider) {

View File

@@ -244,9 +244,9 @@ func (e *DashboardEdge) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (e *DashboardEdge) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (e *DashboardEdge) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return e.getPreparedStatementExecuteSQL(e, runtimeArgs)
return e.getResolvedQuery(e, runtimeArgs)
}
// IsSnapshotPanel implements SnapshotPanel

View File

@@ -274,9 +274,9 @@ func (f *DashboardFlow) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (f *DashboardFlow) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (f *DashboardFlow) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return f.getPreparedStatementExecuteSQL(f, runtimeArgs)
return f.getResolvedQuery(f, runtimeArgs)
}
// GetEdges implements EdgeAndNodeProvider

View File

@@ -278,9 +278,9 @@ func (g *DashboardGraph) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (g *DashboardGraph) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (g *DashboardGraph) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return g.getPreparedStatementExecuteSQL(g, runtimeArgs)
return g.getResolvedQuery(g, runtimeArgs)
}
// GetEdges implements EdgeAndNodeProvider

View File

@@ -274,9 +274,9 @@ func (h *DashboardHierarchy) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (h *DashboardHierarchy) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (h *DashboardHierarchy) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return h.getPreparedStatementExecuteSQL(h, runtimeArgs)
return h.getResolvedQuery(h, runtimeArgs)
}
// GetEdges implements EdgeAndNodeProvider

View File

@@ -248,9 +248,9 @@ func (i *DashboardImage) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (i *DashboardImage) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (i *DashboardImage) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return i.getPreparedStatementExecuteSQL(i, runtimeArgs)
return i.getResolvedQuery(i, runtimeArgs)
}
func (i *DashboardImage) setBaseProperties(resourceMapProvider ResourceMapsProvider) {

View File

@@ -301,9 +301,9 @@ func (i *DashboardInput) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (i *DashboardInput) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (i *DashboardInput) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return i.getPreparedStatementExecuteSQL(i, runtimeArgs)
return i.getResolvedQuery(i, runtimeArgs)
}
// DependsOnInput returns whether this input has a runtime dependency on the given input

View File

@@ -244,9 +244,9 @@ func (n *DashboardNode) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (n *DashboardNode) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (n *DashboardNode) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return n.getPreparedStatementExecuteSQL(n, runtimeArgs)
return n.getResolvedQuery(n, runtimeArgs)
}
// IsSnapshotPanel implements SnapshotPanel

View File

@@ -286,9 +286,9 @@ func (t *DashboardTable) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (t *DashboardTable) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (t *DashboardTable) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return t.getPreparedStatementExecuteSQL(t, runtimeArgs)
return t.getResolvedQuery(t, runtimeArgs)
}
func (t *DashboardTable) setBaseProperties(resourceMapProvider ResourceMapsProvider) {

View File

@@ -211,15 +211,30 @@ func (e *DashboardWith) GetPreparedStatementName() string {
return e.PreparedStatementName
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (e *DashboardWith) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// GetResolvedQuery implements QueryProvider
func (e *DashboardWith) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return e.getPreparedStatementExecuteSQL(e, runtimeArgs)
return e.getResolvedQuery(e, runtimeArgs)
}
// IsSnapshotPanel implements SnapshotPanel
func (*DashboardWith) IsSnapshotPanel() {}
// GetWidth implements DashboardLeafNode
func (*DashboardWith) GetWidth() int {
return 0
}
// GetDisplay implements DashboardLeafNode
func (*DashboardWith) GetDisplay() string {
return ""
}
// GetType implements DashboardLeafNode
func (*DashboardWith) GetType() string {
return ""
}
func (e *DashboardWith) setBaseProperties(resourceMapProvider ResourceMapsProvider) {
// not all base properties are stored in the evalContext
// (e.g. resource metadata and runtime dependencies are not stores)

View File

@@ -68,7 +68,7 @@ type QueryProvider interface {
GetMod() *Mod
GetDescription() string
GetPreparedStatementName() string
GetPreparedStatementExecuteSQL(*QueryArgs) (*ResolvedQuery, error)
GetResolvedQuery(*QueryArgs) (*ResolvedQuery, error)
// implemented by QueryProviderBase
AddRuntimeDependencies([]*RuntimeDependency)
GetRuntimeDependencies() map[string]*RuntimeDependency

View File

@@ -254,9 +254,9 @@ func (q *Query) GetPreparedStatementName() string {
}
// GetPreparedStatementExecuteSQL implements QueryProvider
func (q *Query) GetPreparedStatementExecuteSQL(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
func (q *Query) GetResolvedQuery(runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
// defer to base
return q.getPreparedStatementExecuteSQL(q, runtimeArgs)
return q.getResolvedQuery(q, runtimeArgs)
}
// AddParent implements ModTreeItem

View File

@@ -48,6 +48,16 @@ func (q *QueryArgs) ArgsStringList() []string {
return argsStringList
}
// TODO RENAME
// SafeArgsList convert ArgLists into list of strings but return as an interface slice
func (q *QueryArgs) SafeArgsList() []any {
var argsStringList = make([]any, len(q.ArgList))
for i, a := range q.ArgList {
argsStringList[i] = typehelpers.SafeString(a)
}
return argsStringList
}
func NewQueryArgs() *QueryArgs {
return &QueryArgs{
ArgMap: make(map[string]string),

View File

@@ -1,11 +1,11 @@
package modconfig
import (
"encoding/json"
"fmt"
"log"
"strings"
typehelpers "github.com/turbot/go-kit/types"
"github.com/turbot/steampipe/pkg/utils"
)
@@ -24,12 +24,13 @@ func MergeArgs(queryProvider QueryProvider, runtimeArgs *QueryArgs) (*QueryArgs,
return baseArgs.Merge(runtimeArgs, queryProvider)
}
// ResolveArgsAsString resolves the argument values,
// ResolveArgs resolves the argument values,
// falling back on defaults from param definitions in the source (if present)
// it returns the arg values as a csv string which can be used in a prepared statement invocation
// (the arg values and param defaults will already have been converted to postgres format)
func ResolveArgsAsString(source QueryProvider, runtimeArgs *QueryArgs) (string, []string, error) {
var paramStrs, missingParams []string
func ResolveArgs(source QueryProvider, runtimeArgs *QueryArgs) ([]any, error) {
var paramVals []any
var missingParams []string
var err error
// validate args
if runtimeArgs == nil {
@@ -43,56 +44,71 @@ func ResolveArgsAsString(source QueryProvider, runtimeArgs *QueryArgs) (string,
}
mergedArgs, err := sourceArgs.Merge(runtimeArgs, source)
if err != nil {
return "", nil, err
return nil, err
}
if len(mergedArgs.ArgMap) > 0 {
// do params contain named params?
paramStrs, missingParams, err = resolveNamedParameters(source, mergedArgs)
paramVals, missingParams, err = resolveNamedParameters(source, mergedArgs)
} else {
// resolve as positional parameters
// (or fall back to defaults if no positional params are present)
paramStrs, missingParams, err = resolvePositionalParameters(source, mergedArgs)
paramVals, missingParams, err = resolvePositionalParameters(source, mergedArgs)
}
if err != nil {
return "", nil, err
return nil, err
}
// did we resolve them all?
if len(missingParams) > 0 {
// a better error will be constructed by the calling code
return "", nil, fmt.Errorf("%s", strings.Join(missingParams, ","))
return nil, fmt.Errorf("%s", strings.Join(missingParams, ","))
}
// are there any params?
if len(paramStrs) == 0 {
return "", nil, nil
if len(paramVals) == 0 {
return nil, nil
}
// success!
return fmt.Sprintf("(%s)", strings.Join(paramStrs, ",")), paramStrs, nil
return paramVals, nil
}
func resolveNamedParameters(queryProvider QueryProvider, args *QueryArgs) (argStrs []string, missingParams []string, err error) {
func resolveNamedParameters(queryProvider QueryProvider, args *QueryArgs) (argVals []any, missingParams []string, err error) {
// if query params contains both positional and named params, error out
params := queryProvider.GetParams()
argStrs = make([]string, len(params))
argVals = make([]any, len(params))
// iterate through each param def and resolve the value
// build a map of which args have been matched (used to validate all args have param defs)
argsWithParamDef := make(map[string]bool)
for i, param := range params {
// first set default
defaultValue := typehelpers.SafeString(param.Default)
var defaultValue any = nil
if param.Default == nil {
defaultValue = ""
} else {
if param.Default != nil {
err := json.Unmarshal([]byte(*param.Default), &defaultValue)
if err != nil {
return nil, nil, err
}
}
}
// can we resolve a value for this param?
if val, ok := args.ArgMap[param.Name]; ok {
argStrs[i] = val
// convert from json
var argVal any
err := json.Unmarshal([]byte(val), &argVal)
if err != nil {
return nil, nil, err
}
argVals[i] = argVal
argsWithParamDef[param.Name] = true
} else if defaultValue != "" {
} else if defaultValue != nil {
// is there a default
argStrs[i] = defaultValue
argVals[i] = defaultValue
} else {
// no value provided and no default defined - add to missing list
missingParams = append(missingParams, param.Name)
@@ -106,12 +122,11 @@ func resolveNamedParameters(queryProvider QueryProvider, args *QueryArgs) (argSt
}
}
return argStrs, missingParams, nil
return argVals, missingParams, nil
}
func resolvePositionalParameters(queryProvider QueryProvider, args *QueryArgs) (argStrs []string, missingParams []string, err error) {
func resolvePositionalParameters(queryProvider QueryProvider, args *QueryArgs) (argValues []any, missingParams []string, err error) {
// if query params contains both positional and named params, error out
// if there are param defs - we must be able to resolve all params
// if there are MORE defs than provided parameters, all remaining defs MUST provide a default
params := queryProvider.GetParams()
@@ -120,8 +135,8 @@ func resolvePositionalParameters(queryProvider QueryProvider, args *QueryArgs) (
if len(params) == 0 {
// no params defined, so we return as many args as are provided
// (convert from *string to string)
argStrs = args.ArgsStringList()
return argStrs, nil, nil
argValues = args.SafeArgsList()
return argValues, nil, nil
}
// so there are param definitions - use these to populate argStrs
@@ -138,37 +153,35 @@ func resolvePositionalParameters(queryProvider QueryProvider, args *QueryArgs) (
return
}
argStrs = make([]string, len(params))
// so there are param definitions - use these to populate argStrs
argValues = make([]any, len(params))
for i, param := range params {
// first set default
defaultValue := typehelpers.SafeString(param.Default)
var defaultValue any = nil
if param.Default != nil {
err := json.Unmarshal([]byte(*param.Default), &defaultValue)
if err != nil {
return nil, nil, err
}
}
if i < len(args.ArgList) && args.ArgList[i] != nil {
argStrs[i] = typehelpers.SafeString(args.ArgList[i])
} else if defaultValue != "" {
// convert from json
var argVal any
err := json.Unmarshal([]byte(*args.ArgList[i]), &argVal)
if err != nil {
return nil, nil, err
}
argValues[i] = argVal
} else if defaultValue != nil {
// so we have run out of provided params - is there a default?
argStrs[i] = defaultValue
argValues[i] = defaultValue
} else {
// no value provided and no default defined - add to missing list
missingParams = append(missingParams, param.Name)
}
}
return argStrs, missingParams, nil
}
// QueryProviderIsParameterised returns whether the query provider has a parameterised query
// the query is parameterised if either there are any param defintions, or any positional arguments passed,
// or it has runtime dependencies (which must be args)
func QueryProviderIsParameterised(queryProvider QueryProvider) bool {
// no sql, NOT parameterised
if queryProvider.GetSQL() == nil {
return false
}
args := queryProvider.GetArgs()
params := queryProvider.GetParams()
runtimeDependencies := queryProvider.GetRuntimeDependencies()
return args != nil || len(params) > 0 || len(runtimeDependencies) > 0
return argValues, missingParams, nil
}

View File

@@ -2,19 +2,17 @@ package modconfig
import (
"fmt"
"github.com/hashicorp/hcl/v2"
"log"
"strings"
"github.com/hashicorp/hcl/v2"
typehelpers "github.com/turbot/go-kit/types"
"github.com/turbot/steampipe/pkg/constants"
"github.com/turbot/steampipe/pkg/utils"
)
type QueryProviderBase struct {
runtimeDependencies map[string]*RuntimeDependency
Withs []*DashboardWith
withs []*DashboardWith
}
// VerifyQuery returns an error if neither sql or query are set
@@ -64,16 +62,20 @@ func (b *QueryProviderBase) buildPreparedStatementPrefix(modName string) string
}
// return the SQLs to run the query as a prepared statement
func (b *QueryProviderBase) getPreparedStatementExecuteSQL(queryProvider QueryProvider, runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
argsString, argsArray, err := ResolveArgsAsString(queryProvider, runtimeArgs)
func (b *QueryProviderBase) getResolvedQuery(queryProvider QueryProvider, runtimeArgs *QueryArgs) (*ResolvedQuery, error) {
argsArray, err := ResolveArgs(queryProvider, runtimeArgs)
if err != nil {
return nil, fmt.Errorf("failed to resolve args for %s: %s", queryProvider.Name(), err.Error())
}
executeString := fmt.Sprintf("execute %s%s", queryProvider.GetPreparedStatementName(), argsString)
log.Printf("[TRACE] GetPreparedStatementExecuteSQL source: %s, sql: %s, args: %s", queryProvider.Name(), executeString, runtimeArgs)
sql := typehelpers.SafeString(queryProvider.GetSQL())
// we expect there to be sql on the query provider, NOT a Query
if sql == "" {
return nil, fmt.Errorf("getResolvedQuery faiuled - no sql set for '%s'", queryProvider.Name())
}
return &ResolvedQuery{
ExecuteSQL: executeString,
RawSQL: typehelpers.SafeString(queryProvider.GetSQL()),
ExecuteSQL: sql,
RawSQL: sql,
Args: argsArray,
Params: queryProvider.GetParams(),
}, nil
@@ -129,11 +131,11 @@ func (*QueryProviderBase) GetDescription() string {
}
func (b *QueryProviderBase) AddWith(with *DashboardWith) {
b.Withs = append(b.Withs, with)
b.withs = append(b.withs, with)
}
func (b *QueryProviderBase) GetWith(name string) (*DashboardWith, bool) {
for _, w := range b.Withs {
for _, w := range b.withs {
if w.UnqualifiedName == name {
return w, true
}
@@ -141,5 +143,5 @@ func (b *QueryProviderBase) GetWith(name string) (*DashboardWith, bool) {
return nil, false
}
func (b *QueryProviderBase) GetWiths() []*DashboardWith {
return b.Withs
return b.withs
}

View File

@@ -1,9 +1,29 @@
package modconfig
import (
"encoding/json"
)
// ResolvedQuery contains the execute SQL, raw SQL and args string used to execute a query
type ResolvedQuery struct {
ExecuteSQL string
RawSQL string
Args []string
Args []any
Params []*ParamDef
}
func (r ResolvedQuery) QueryArgs() *QueryArgs {
res := NewQueryArgs()
res.ArgList = make([]*string, len(r.Args))
for i, a := range r.Args {
// TODO TACTICAL check/fix
jsonBytes, err := json.Marshal(a)
argStr := string(jsonBytes)
if err != nil {
res.ArgList[i] = &argStr
}
}
return res
}

View File

@@ -2,6 +2,7 @@ package parse
import (
"fmt"
"github.com/turbot/steampipe/pkg/type_conversion"
"reflect"
"strings"
@@ -11,7 +12,6 @@ import (
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/pkg/steampipeconfig/modconfig"
"github.com/turbot/steampipe/pkg/steampipeconfig/modconfig/var_config"
"github.com/turbot/steampipe/pkg/type_conversion"
"github.com/turbot/steampipe/pkg/utils"
)
@@ -300,8 +300,8 @@ func decodeParam(block *hcl.Block, parseCtx *ModParseContext, parentName string)
diags = append(diags, moreDiags...)
if !moreDiags.HasErrors() {
// convert the raw default into a postgres representation
if valStr, err := type_conversion.CtyToPostgresString(v); err == nil {
// convert the raw default into a string representation
if valStr, err := type_conversion.CtyToJSON(v); err == nil {
def.Default = utils.ToStringPointer(valStr)
} else {
diags = append(diags, &hcl.Diagnostic{
@@ -402,6 +402,10 @@ func decodeQueryProviderBlocks(block *hcl.Block, content *hcl.BodyContent, resou
with, withRes := decodeQueryProvider(block, parseCtx)
res.Merge(withRes)
queryProvider.AddWith(with.(*modconfig.DashboardWith))
// TACTICAL
// populate metadata for with block
handleModDecodeResult(with, withRes, block, parseCtx)
}
}

View File

@@ -78,8 +78,8 @@ func ctyTupleToArgArray(attr *hcl.Attribute, val cty.Value) ([]*string, []*modco
runtimeDependencies = append(runtimeDependencies, runtimeDependency)
} else {
// decode the value into a postgres compatible
valStr, err := type_conversion.CtyToPostgresString(v)
// decode the value into a json representation
valStr, err := type_conversion.CtyToJSON(v)
if err != nil {
err := fmt.Errorf("invalid value provided for arg #%d: %v", idx, err)
return nil, nil, err
@@ -112,8 +112,8 @@ func ctyObjectToArgMap(attr *hcl.Attribute, val cty.Value, evalCtx *hcl.EvalCont
}
runtimeDependencies = append(runtimeDependencies, runtimeDependency)
} else {
// decode the value into a postgres compatible
valStr, err := type_conversion.CtyToPostgresString(v)
// decode the value into a json representation
valStr, err := type_conversion.CtyToJSON(v)
if err != nil {
err := fmt.Errorf("invalid value provided for param '%s': %v", key, err)
return nil, nil, err

View File

@@ -137,7 +137,7 @@ func parseArg(v string) (string, error) {
if diags.HasErrors() {
return "", plugin.DiagsToError("bad arg syntax", diags)
}
return type_conversion.CtyToPostgresString(val)
return type_conversion.CtyToJSON(val)
}
func parseNamedArgs(argsList []string) (map[string]string, error) {

View File

@@ -12,6 +12,7 @@ order by
EOQ
}
dashboard "dashboard_named_args" {
title = "dashboard with named arguments"
@@ -22,9 +23,10 @@ dashboard "dashboard_named_args" {
}
table {
sql = "select $1"
sql = "select * from aws_account where arn in ($1)"
with "w1" {
sql = "select * from aws_account"
}
args = {
"with_val" = with.w1.rows[*].arn

View File

@@ -1,5 +1,19 @@
query "q1"{
title ="Q1"
description = "THIS IS QUERY 1"
sql = "select 1"
control "query_params_with_defaults_and_partial_named_args" {
title = "Control to test query param functionality with defaults(and some named args passed in query)"
query = query.query_params_with_no_defaults
args = {
"p1" = "command_parameter_1"
}
}
query "query_params_with_no_defaults"{
description = "query 1 - 3 params with no defaults"
sql = "select $1::text[]"
param "p1"{
description = "First parameter"
default = ["c","d"]
}
}

View File

@@ -10,12 +10,12 @@ import (
// EnsureSessionData determines whether session scoped data (introspection tables and prepared statements)
// exists for this session, and if not, creates it
func EnsureSessionData(ctx context.Context, source *SessionDataSource, conn *pgx.Conn, combineSql bool) (error, *db_common.PrepareStatementFailures) {
func EnsureSessionData(ctx context.Context, source *SessionDataSource, conn *pgx.Conn) (error) {
utils.LogTime("workspace.EnsureSessionData start")
defer utils.LogTime("workspace.EnsureSessionData end")
if conn == nil {
return errors.New("nil conn passed to EnsureSessionData"), nil
return errors.New("nil conn passed to EnsureSessionData")
}
// check for introspection tables
@@ -25,19 +25,14 @@ func EnsureSessionData(ctx context.Context, source *SessionDataSource, conn *pgx
var count int
err := row.Scan(&count)
if err != nil {
return err, nil
return err
}
var preparedStatementFailures *db_common.PrepareStatementFailures
if count == 0 {
err, preparedStatementFailures = db_common.CreatePreparedStatements(ctx, source.PreparedStatementSource(), conn, combineSql)
if err != nil {
return err, preparedStatementFailures
}
err = db_common.CreateIntrospectionTables(ctx, source.IntrospectionTableSource(), conn)
if err != nil {
return err, preparedStatementFailures
return err
}
}
return nil, preparedStatementFailures
return nil
}

View File

@@ -59,9 +59,6 @@ type Workspace struct {
dashboardEventChan chan dashboardevents.DashboardEvent
// count of workspace changed events - used to ignore first event
changeEventCount int
// avoid concurrent map access when multiple db connections may try to access preparedStatementFailures
preparedStatementFailureLock sync.Mutex
preparedStatementFailures map[string]*steampipeconfig.PreparedStatementFailure
}
// Load creates a Workspace and loads the workspace mod
@@ -201,43 +198,6 @@ func (w *Workspace) ModfileExists() bool {
return len(w.modFilePath) > 0
}
func (w *Workspace) HandlePreparedStatementFailures(failures *db_common.PrepareStatementFailures) {
if failures == nil {
return
}
// avoid concurrent map access when multiple db connections may try to access preparedStatementFailures
w.preparedStatementFailureLock.Lock()
defer w.preparedStatementFailureLock.Unlock()
// replace the map of failures with the current map
w.preparedStatementFailures = make(map[string]*steampipeconfig.PreparedStatementFailure)
for queryName, err := range failures.Failures {
if query, ok := w.GetQueryProvider(queryName); ok {
w.preparedStatementFailures[queryName] = &steampipeconfig.PreparedStatementFailure{
Query: query,
Error: err,
}
}
}
if failures.Error != nil {
w.preparedStatementFailures["preparedStatementGlobalError"] = &steampipeconfig.PreparedStatementFailure{
Error: failures.Error,
}
}
}
// GetPreparedStatementCreationFailure looks for a prepared statement error for the given query and if found,
// returns the query and the prepared statement creation error (if any)
func (w *Workspace) GetPreparedStatementCreationFailure(queryName string) *steampipeconfig.PreparedStatementFailure {
return w.preparedStatementFailures[queryName]
}
func (w *Workspace) GetPreparedStatementFailures() map[string]*steampipeconfig.PreparedStatementFailure {
return w.preparedStatementFailures
}
// check whether the workspace contains a modfile
// this will determine whether we load files recursively, and create pseudo resources for sql files
func (w *Workspace) setModfileExists() {

View File

@@ -18,26 +18,26 @@ import (
// GetQueriesFromArgs retrieves queries from args
//
// For each arg check if it is a named query or a file, before falling back to treating it as sql
func (w *Workspace) GetQueriesFromArgs(args []string) (map[string]string, *modconfig.ResourceMaps, error) {
func (w *Workspace) GetQueriesFromArgs(args []string) (map[string]*modconfig.ResolvedQuery, *modconfig.ResourceMaps, error) {
utils.LogTime("execute.GetQueriesFromArgs start")
defer utils.LogTime("execute.GetQueriesFromArgs end")
var queries = make(map[string]string)
var queries = make(map[string]*modconfig.ResolvedQuery)
var queryProviders []modconfig.QueryProvider
// build map of just the required prepared statement providers
for _, arg := range args {
query, queryProvider, err := w.ResolveQueryAndArgsFromSQLString(arg)
resolvedQuery, queryProvider, err := w.ResolveQueryAndArgsFromSQLString(arg)
if err != nil {
return nil, nil, err
}
if len(query) > 0 {
if len(resolvedQuery.ExecuteSQL) > 0 {
// default name to the query text
queryName := query
queryName := resolvedQuery.ExecuteSQL
if queryProvider != nil {
queryName = queryProvider.Name()
queryProviders = append(queryProviders, queryProvider)
}
queries[queryName] = query
queries[queryName] = resolvedQuery
}
}
@@ -49,7 +49,7 @@ func (w *Workspace) GetQueriesFromArgs(args []string) (map[string]string, *modco
}
// ResolveQueryAndArgsFromSQLString attempts to resolve 'arg' to a query and query args
func (w *Workspace) ResolveQueryAndArgsFromSQLString(sqlString string) (string, modconfig.QueryProvider, error) {
func (w *Workspace) ResolveQueryAndArgsFromSQLString(sqlString string) (*modconfig.ResolvedQuery, modconfig.QueryProvider, error) {
var args = &modconfig.QueryArgs{}
var err error
@@ -58,7 +58,7 @@ func (w *Workspace) ResolveQueryAndArgsFromSQLString(sqlString string) (string,
// if this looks like a named query provider invocation, parse the sql string for arguments
resource, args, err := w.extractQueryProviderFromQueryString(sqlString)
if err != nil {
return "", nil, err
return nil, nil, err
}
if resource != nil {
@@ -67,32 +67,32 @@ func (w *Workspace) ResolveQueryAndArgsFromSQLString(sqlString string) (string,
// resolve the query for the query provider and return it
resolvedQuery, err := w.ResolveQueryFromQueryProvider(resource, args)
if err != nil {
return "", nil, err
return nil, nil, err
}
log.Printf("[TRACE] resolved query: %s", sqlString)
return resolvedQuery.ExecuteSQL, resource, nil
return resolvedQuery, resource, nil
}
// 2) is this a file
fileQuery, fileExists, err := w.getQueryFromFile(sqlString)
if fileExists {
if err != nil {
return "", nil, fmt.Errorf("ResolveQueryAndArgsFromSQLString failed: error opening file '%s': %v", sqlString, err)
return nil, nil, fmt.Errorf("ResolveQueryAndArgsFromSQLString failed: error opening file '%s': %v", sqlString, err)
}
if len(fileQuery) == 0 {
if fileQuery == nil {
error_helpers.ShowWarning(fmt.Sprintf("file '%s' does not contain any data", sqlString))
// (just return the empty string - it will be filtered above)
// (just return the empty query - it will be filtered above)
}
return fileQuery, nil, nil
}
// 3) so we have not managed to resolve this - if it looks like a named query or control, return an error
if name, isResource := queryLooksLikeExecutableResource(sqlString); isResource {
return "", nil, fmt.Errorf("'%s' not found in %s (%s)", name, w.Mod.Name(), w.Path)
return nil, nil, fmt.Errorf("'%s' not found in %s (%s)", name, w.Mod.Name(), w.Path)
}
// 4) just use the query string as is and assume it is valid SQL
return sqlString, nil, nil
return &modconfig.ResolvedQuery{RawSQL: sqlString, ExecuteSQL: sqlString}, nil, nil
}
// ResolveQueryFromQueryProvider resolves the query for the given QueryProvider
@@ -108,9 +108,6 @@ func (w *Workspace) ResolveQueryFromQueryProvider(queryProvider modconfig.QueryP
query := queryProvider.GetQuery()
sql := queryProvider.GetSQL()
if query == nil && sql == nil {
return nil, fmt.Errorf("%s does not define either a 'sql' property or a 'query' property\n", queryProvider.Name())
}
params := queryProvider.GetParams()
// merge the base args with the runtime args
@@ -128,54 +125,52 @@ func (w *Workspace) ResolveQueryFromQueryProvider(queryProvider modconfig.QueryP
return w.ResolveQueryFromQueryProvider(query, runtimeArgs)
}
// if the control has SQL set, use that
if sql != nil {
queryProviderSQL := typehelpers.SafeString(sql)
log.Printf("[TRACE] control defines inline SQL")
// if the SQL refers to a named query, this is the same as if the 'Query' property is set
if namedQueryProvider, ok := w.GetQueryProvider(queryProviderSQL); ok {
// in this case, it is NOT valid for the query provider to define its own Param definitions
if params != nil {
return nil, fmt.Errorf("%s has an 'SQL' property which refers to %s, so it cannot define 'param' blocks", queryProvider.Name(), namedQueryProvider.Name())
}
return w.ResolveQueryFromQueryProvider(namedQueryProvider, runtimeArgs)
}
// so the sql is NOT a named query
// determine whether there are any params - there may either be param defs, OR positional args
// if there are NO params OR list args, use the control SQL as is
if !modconfig.QueryProviderIsParameterised(queryProvider) {
return &modconfig.ResolvedQuery{ExecuteSQL: queryProviderSQL, RawSQL: queryProviderSQL}, nil
}
// must have sql is there is no query
if sql == nil {
return nil, fmt.Errorf("%s does not define either a 'sql' property or a 'query' property\n", queryProvider.Name())
}
// so the control defines SQL and has params - it is a prepared statement
return queryProvider.GetPreparedStatementExecuteSQL(runtimeArgs)
queryProviderSQL := typehelpers.SafeString(sql)
log.Printf("[TRACE] control defines inline SQL")
// if the SQL refers to a named query, this is the same as if the 'Query' property is set
if namedQueryProvider, ok := w.GetQueryProvider(queryProviderSQL); ok {
// in this case, it is NOT valid for the query provider to define its own Param definitions
if params != nil {
return nil, fmt.Errorf("%s has an 'SQL' property which refers to %s, so it cannot define 'param' blocks", queryProvider.Name(), namedQueryProvider.Name())
}
return w.ResolveQueryFromQueryProvider(namedQueryProvider, runtimeArgs)
}
// so the sql is NOT a named query
return queryProvider.GetResolvedQuery(runtimeArgs)
}
// try to treat the input string as a file name and if it exists, return its contents
func (w *Workspace) getQueryFromFile(input string) (string, bool, error) {
func (w *Workspace) getQueryFromFile(input string) (*modconfig.ResolvedQuery, bool, error) {
// get absolute filename
path, err := filepath.Abs(input)
if err != nil {
return "", false, nil
return nil, false, nil
}
// does it exist?
if _, err := os.Stat(path); err != nil {
// if this gives any error, return not exist. we may get a not found or a path too long for example
return "", false, nil
return nil, false, nil
}
// read file
fileBytes, err := os.ReadFile(path)
if err != nil {
return "", true, err
return nil, true, err
}
return string(fileBytes), true, nil
res := &modconfig.ResolvedQuery{
RawSQL: string(fileBytes),
ExecuteSQL: string(fileBytes),
}
return res, true, nil
}
// does the input look like a resource which can be executed as a query