diff --git a/cmd/check.go b/cmd/check.go index e48a929c2..7f0046267 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -143,7 +143,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) { } // create the execution tree - executionTree, err := controlexecute.NewExecutionTree(ctx, w, client, targetName) + executionTree, err := controlexecute.NewExecutionTree(ctx, w, client, targetName, initData.ControlFilterWhereClause) error_helpers.FailOnError(err) // execute controls synchronously (execute returns the number of failures) @@ -191,7 +191,13 @@ func validateCheckArgs(ctx context.Context, cmd *cobra.Command, args []string) b // only 1 of 'share' and 'snapshot' may be set if viper.GetBool(constants.ArgShare) && viper.GetBool(constants.ArgSnapshot) { - error_helpers.ShowError(ctx, fmt.Errorf("only 1 of 'share' and 'snapshot' may be set")) + error_helpers.ShowError(ctx, fmt.Errorf("only 1 of '--%s' and '--%s' may be set", constants.ArgShare, constants.ArgSnapshot)) + return false + } + + // if both '--where' and '--tag' have been used, then it's an error + if viper.IsSet(constants.ArgWhere) && viper.IsSet(constants.ArgTag) { + error_helpers.ShowError(ctx, fmt.Errorf("only 1 of '--%s' and '--%s' may be set", constants.ArgWhere, constants.ArgWhere)) return false } diff --git a/pkg/control/controldisplay/snapshot.go b/pkg/control/controldisplay/snapshot.go index 13a7df588..17723824c 100644 --- a/pkg/control/controldisplay/snapshot.go +++ b/pkg/control/controldisplay/snapshot.go @@ -79,4 +79,5 @@ func PublishSnapshot(ctx context.Context, e *controlexecute.ExecutionTree, shoul fmt.Println(message) } return nil + } diff --git a/pkg/control/controlexecute/execution_tree.go b/pkg/control/controlexecute/execution_tree.go index 5f4924874..16d662b51 100644 --- a/pkg/control/controlexecute/execution_tree.go +++ b/pkg/control/controlexecute/execution_tree.go @@ -4,9 +4,7 @@ import ( "context" "fmt" "log" - "net/url" "sort" - "strings" "time" "github.com/spf13/viper" @@ -39,8 +37,7 @@ type ExecutionTree struct { controlNameFilterMap map[string]bool } -func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, client db_common.Client, arg string) (*ExecutionTree, error) { - // TODO [reports] FAIL IF any resources in the tree have runtime dependencies +func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, client db_common.Client, arg, controlFilterWhereClause string) (*ExecutionTree, error) { // now populate the ExecutionTree executionTree := &ExecutionTree{ Workspace: workspace, @@ -50,7 +47,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien // if a "--where" or "--tag" parameter was passed, build a map of control names used to filter the controls to run // create a context with status hooks disabled noStatusCtx := statushooks.DisableStatusHooks(ctx) - err := executionTree.populateControlFilterMap(noStatusCtx) + err := executionTree.populateControlFilterMap(noStatusCtx, controlFilterWhereClause) if err != nil { return nil, err } @@ -138,29 +135,9 @@ func (e *ExecutionTree) waitForActiveRunsToComplete(ctx context.Context, paralle return parallelismLock.Acquire(waitCtx, maxParallelGoRoutines) } -func (e *ExecutionTree) populateControlFilterMap(ctx context.Context) error { - // if both '--where' and '--tag' have been used, then it's an error - if viper.IsSet(constants.ArgWhere) && viper.IsSet(constants.ArgTag) { - return fmt.Errorf("'--%s' and '--%s' cannot be used together", constants.ArgWhere, constants.ArgTag) - } - - controlFilterWhereClause := "" - - if viper.IsSet(constants.ArgTag) { - // if '--tag' args were used, derive the whereClause from them - tags := viper.GetStringSlice(constants.ArgTag) - controlFilterWhereClause = e.generateWhereClauseFromTags(tags) - } else if viper.IsSet(constants.ArgWhere) { - // if a 'where' arg was used, execute this sql to get a list of control names - // use this list to build a name map used to determine whether to run a particular control - controlFilterWhereClause = viper.GetString(constants.ArgWhere) - } - +func (e *ExecutionTree) populateControlFilterMap(ctx context.Context, controlFilterWhereClause string) error { // if we derived or were passed a where clause, run the filter if len(controlFilterWhereClause) > 0 { - // if we have a control filter where clause, we must create the control introspection tables - viper.Set(constants.ArgIntrospection, constants.IntrospectionControl) - log.Println("[TRACE]", "filtering controls with", controlFilterWhereClause) var err error e.controlNameFilterMap, err = e.getControlMapFromWhereClause(ctx, controlFilterWhereClause) @@ -172,35 +149,6 @@ func (e *ExecutionTree) populateControlFilterMap(ctx context.Context) error { return nil } -func (e *ExecutionTree) generateWhereClauseFromTags(tags []string) string { - whereMap := map[string][]string{} - - // 'tags' should be KV Pairs of the form: 'benchmark=pic' or 'cis_level=1' - for _, tag := range tags { - value, _ := url.ParseQuery(tag) - for k, v := range value { - if _, found := whereMap[k]; !found { - whereMap[k] = []string{} - } - whereMap[k] = append(whereMap[k], v...) - } - } - whereComponents := []string{} - for key, values := range whereMap { - thisComponent := []string{} - for _, x := range values { - if len(x) == 0 { - // ignore - continue - } - thisComponent = append(thisComponent, fmt.Sprintf("tags->>'%s'='%s'", key, x)) - } - whereComponents = append(whereComponents, fmt.Sprintf("(%s)", strings.Join(thisComponent, " OR "))) - } - - return strings.Join(whereComponents, " AND ") -} - func (e *ExecutionTree) ShouldIncludeControl(controlName string) bool { if e.controlNameFilterMap == nil { return true diff --git a/pkg/control/init_data.go b/pkg/control/init_data.go index c77fb60cd..dd3120410 100644 --- a/pkg/control/init_data.go +++ b/pkg/control/init_data.go @@ -6,6 +6,8 @@ import ( "github.com/turbot/steampipe/pkg/control/controldisplay" "github.com/turbot/steampipe/pkg/error_helpers" "github.com/turbot/steampipe/pkg/statushooks" + "net/url" + "strings" "github.com/spf13/viper" "github.com/turbot/steampipe/pkg/constants" @@ -15,7 +17,8 @@ import ( type InitData struct { initialisation.InitData - OutputFormatter controldisplay.Formatter + OutputFormatter controldisplay.Formatter + ControlFilterWhereClause string } // NewInitData returns a new InitData object @@ -82,16 +85,64 @@ func NewInitData(ctx context.Context) *InitData { } i.OutputFormatter = formatter + i.setControlFilterClause() return i } +func (i *InitData) setControlFilterClause() { + if viper.IsSet(constants.ArgTag) { + // if '--tag' args were used, derive the whereClause from them + tags := viper.GetStringSlice(constants.ArgTag) + i.ControlFilterWhereClause = generateWhereClauseFromTags(tags) + } else if viper.IsSet(constants.ArgWhere) { + // if a 'where' arg was used, execute this sql to get a list of control names + // use this list to build a name map used to determine whether to run a particular control + i.ControlFilterWhereClause = viper.GetString(constants.ArgWhere) + } + + // if we derived or were passed a where clause, run the filter + if len(i.ControlFilterWhereClause) > 0 { + // if we have a control filter where clause, we must create the control introspection tables + viper.Set(constants.ArgIntrospection, constants.IntrospectionControl) + } +} + +func generateWhereClauseFromTags(tags []string) string { + whereMap := map[string][]string{} + + // 'tags' should be KV Pairs of the form: 'benchmark=pic' or 'cis_level=1' + for _, tag := range tags { + value, _ := url.ParseQuery(tag) + for k, v := range value { + if _, found := whereMap[k]; !found { + whereMap[k] = []string{} + } + whereMap[k] = append(whereMap[k], v...) + } + } + whereComponents := []string{} + for key, values := range whereMap { + thisComponent := []string{} + for _, x := range values { + if len(x) == 0 { + // ignore + continue + } + thisComponent = append(thisComponent, fmt.Sprintf("tags->>'%s'='%s'", key, x)) + } + whereComponents = append(whereComponents, fmt.Sprintf("(%s)", strings.Join(thisComponent, " OR "))) + } + + return strings.Join(whereComponents, " AND ") +} + // register exporters for each of the supported check formats -func (initData *InitData) registerCheckExporters() { +func (i *InitData) registerCheckExporters() { exporters, err := controldisplay.GetExporters() error_helpers.FailOnErrorWithMessage(err, "failed to load exporters") // register all exporters - initData.RegisterExporters(exporters...) + i.RegisterExporters(exporters...) } // parseOutputArg parses the --output flag value and returns the Formatter that can format the data diff --git a/pkg/dashboard/dashboardexecute/check_run.go b/pkg/dashboard/dashboardexecute/check_run.go index 534363643..86698f734 100644 --- a/pkg/dashboard/dashboardexecute/check_run.go +++ b/pkg/dashboard/dashboardexecute/check_run.go @@ -95,7 +95,7 @@ func NewCheckRun(resource modconfig.DashboardLeafNode, parent dashboardtypes.Das // Initialise implements DashboardRunNode func (r *CheckRun) Initialise(ctx context.Context) { // build control execution tree during init, rather than in Execute, so that it is populated when the ExecutionStarted event is sent - executionTree, err := controlexecute.NewExecutionTree(ctx, r.executionTree.workspace, r.executionTree.client, r.DashboardNode.Name()) + executionTree, err := controlexecute.NewExecutionTree(ctx, r.executionTree.workspace, r.executionTree.client, r.DashboardNode.Name(), "") if err != nil { // set the error status on the counter - this will raise counter error event r.SetError(ctx, err)