Decouple spinner display code from database and execution layer. Closes #1290

This commit is contained in:
kaidaguerre
2022-01-06 11:54:18 +00:00
committed by GitHub
parent 079e0fc584
commit efbebd99ee
51 changed files with 679 additions and 603 deletions

View File

@@ -10,12 +10,12 @@ import (
"sync"
"time"
"github.com/briandowns/spinner"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/cmdconfig"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/contexthelpers"
"github.com/turbot/steampipe/control"
"github.com/turbot/steampipe/control/controldisplay"
"github.com/turbot/steampipe/control/controlexecute"
@@ -24,6 +24,7 @@ import (
"github.com/turbot/steampipe/db/db_local"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/modinstaller"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/utils"
"github.com/turbot/steampipe/workspace"
)
@@ -88,16 +89,21 @@ You may specify one or more benchmarks or controls to run (separated by a space)
func runCheckCmd(cmd *cobra.Command, args []string) {
utils.LogTime("runCheckCmd start")
initData := &control.InitData{}
// setup a cancel context and start cancel handler
ctx, cancel := context.WithCancel(cmd.Context())
contexthelpers.StartCancelHandler(cancel)
defer func() {
utils.LogTime("runCheckCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
if initData.Client != nil {
log.Printf("[TRACE] close client")
initData.Client.Close()
initData.Client.Close(ctx)
}
if initData.Workspace != nil {
initData.Workspace.Close()
@@ -105,24 +111,22 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
}()
// verify we have an argument
if !validateArgs(cmd, args) {
if !validateArgs(ctx, cmd, args) {
return
}
var spinner *spinner.Spinner
if viper.GetBool(constants.ArgProgress) {
spinner = display.ShowSpinner("Initializing...")
// if progress is disabled, update context to contain a null status hooks object
if !viper.GetBool(constants.ArgProgress) {
statushooks.DisableStatusHooks(ctx)
}
// initialise
initData = initialiseCheck(cmd.Context(), spinner)
display.StopSpinner(spinner)
if shouldExit := handleCheckInitResult(initData); shouldExit {
initData = initialiseCheck(ctx)
if shouldExit := handleCheckInitResult(ctx, initData); shouldExit {
return
}
// pull out useful properties
ctx := initData.Ctx
workspace := initData.Workspace
client := initData.Client
failures := 0
@@ -165,7 +169,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
exportWaitGroup.Wait()
if len(exportErrors) > 0 {
utils.ShowError(utils.CombineErrors(exportErrors...))
utils.ShowError(ctx, utils.CombineErrors(exportErrors...))
}
if shouldPrintTiming() {
@@ -176,10 +180,10 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
exitCode = failures
}
func validateArgs(cmd *cobra.Command, args []string) bool {
func validateArgs(ctx context.Context, cmd *cobra.Command, args []string) bool {
if len(args) == 0 {
fmt.Println()
utils.ShowError(fmt.Errorf("you must provide at least one argument"))
utils.ShowError(ctx, fmt.Errorf("you must provide at least one argument"))
fmt.Println()
cmd.Help()
fmt.Println()
@@ -189,11 +193,13 @@ func validateArgs(cmd *cobra.Command, args []string) bool {
return true
}
func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.InitData {
func initialiseCheck(ctx context.Context) *control.InitData {
statushooks.SetStatus(ctx, "Initializing...")
defer statushooks.Done(ctx)
initData := &control.InitData{
Result: &db_common.InitResult{},
}
if viper.GetBool(constants.ArgModInstall) {
opts := &modinstaller.InstallOpts{WorkspacePath: viper.GetString(constants.ArgWorkspaceChDir)}
_, err := modinstaller.InstallWorkspaceDependencies(opts)
@@ -202,9 +208,6 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
return initData
}
}
cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, false)
err := validateOutputFormat()
if err != nil {
initData.Result.Error = err
@@ -217,10 +220,6 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
return initData
}
ctx, cancel := context.WithCancel(ctx)
startCancelHandler(cancel)
initData.Ctx = ctx
// set color schema
err = initialiseColorScheme()
if err != nil {
@@ -228,7 +227,7 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
return initData
}
// load workspace
initData.Workspace, err = loadWorkspacePromptingForVariables(ctx, spinner)
initData.Workspace, err = loadWorkspacePromptingForVariables(ctx)
if err != nil {
if !utils.IsCancelledError(err) {
err = utils.PrefixError(err, "failed to load workspace")
@@ -248,18 +247,14 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
initData.Result.AddWarnings("no controls found in current workspace")
}
display.UpdateSpinnerMessage(spinner, "Connecting to service...")
statushooks.SetStatus(ctx, "Connecting to service...")
// get a client
var client db_common.Client
if connectionString := viper.GetString(constants.ArgConnectionString); connectionString != "" {
client, err = db_client.NewDbClient(ctx, connectionString)
} else {
// stop the spinner
display.StopSpinner(spinner)
// when starting the database, installers may trigger their own spinners
client, err = db_local.GetLocalClient(ctx, constants.InvokerCheck)
// resume the spinner
display.ResumeSpinner(spinner)
}
if err != nil {
@@ -288,13 +283,13 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
return initData
}
func handleCheckInitResult(initData *control.InitData) bool {
func handleCheckInitResult(ctx context.Context, initData *control.InitData) bool {
// if there is an error or cancellation we bomb out
// check for the various kinds of failures
utils.FailOnError(initData.Result.Error)
// cancelled?
if initData.Ctx != nil {
utils.FailOnError(initData.Ctx.Err())
if ctx != nil {
utils.FailOnError(ctx.Err())
}
// if there is a usage warning we display it

View File

@@ -53,11 +53,12 @@ func modInstallCmd() *cobra.Command {
}
func runModInstallCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("cmd.runModInstallCmd")
defer func() {
utils.LogTime("cmd.runModInstallCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
@@ -88,17 +89,18 @@ func modUninstallCmd() *cobra.Command {
}
func runModUninstallCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("cmd.runModInstallCmd")
defer func() {
utils.LogTime("cmd.runModInstallCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
opts := newInstallOpts(cmd, args...)
installData, err := modinstaller.UninstallWorkspaceDependencies(opts)
installData, err := modinstaller.UninstallWorkspaceDependencies(ctx, opts)
utils.FailOnError(err)
fmt.Println(modinstaller.BuildUninstallSummary(installData))
@@ -122,11 +124,12 @@ func modUpdateCmd() *cobra.Command {
}
func runModUpdateCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("cmd.runModUpdateCmd")
defer func() {
utils.LogTime("cmd.runModUpdateCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
@@ -153,11 +156,12 @@ func modListCmd() *cobra.Command {
}
func runModListCmd(cmd *cobra.Command, _ []string) {
ctx := cmd.Context()
utils.LogTime("cmd.runModListCmd")
defer func() {
utils.LogTime("cmd.runModListCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
@@ -186,11 +190,12 @@ func modInitCmd() *cobra.Command {
}
func runModInitCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("cmd.runModInitCmd")
defer func() {
utils.LogTime("cmd.runModInitCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()

View File

@@ -16,6 +16,7 @@ import (
"github.com/turbot/steampipe/ociinstaller/versionfile"
"github.com/turbot/steampipe/plugin"
"github.com/turbot/steampipe/statefile"
"github.com/turbot/steampipe/statusspinner"
"github.com/turbot/steampipe/steampipeconfig"
"github.com/turbot/steampipe/steampipeconfig/modconfig"
"github.com/turbot/steampipe/utils"
@@ -176,11 +177,12 @@ Example:
// exitCode=3 For errors related to loading state, loading version data or an issue contacting the update server.
// exitCode=4 For plugin listing failures
func runPluginInstallCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runPluginInstallCmd install")
defer func() {
utils.LogTime("runPluginInstallCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
@@ -193,7 +195,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
if len(plugins) == 0 {
fmt.Println()
utils.ShowError(fmt.Errorf("you need to provide at least one plugin to install"))
utils.ShowError(ctx, fmt.Errorf("you need to provide at least one plugin to install"))
fmt.Println()
cmd.Help()
fmt.Println()
@@ -204,7 +206,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
// a leading blank line - since we always output multiple lines
fmt.Println()
spinner := display.ShowSpinner("")
statusSpinner := statusspinner.NewStatusSpinner()
for _, p := range plugins {
isPluginExists, _ := plugin.Exists(p)
@@ -217,7 +219,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
})
continue
}
display.UpdateSpinnerMessage(spinner, fmt.Sprintf("Installing plugin: %s", p))
statusSpinner.SetStatus(fmt.Sprintf("Installing plugin: %s", p))
image, err := plugin.Install(cmd.Context(), p)
if err != nil {
msg := ""
@@ -250,9 +252,9 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
})
}
display.StopSpinner(spinner)
statusSpinner.Done()
refreshConnectionsIfNecessary(cmd.Context(), installReports, false)
refreshConnectionsIfNecessary(cmd.Context(), installReports, true)
display.PrintInstallReports(installReports, false)
// a concluding blank line - since we always output multiple lines
@@ -260,11 +262,12 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
}
func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runPluginUpdateCmd install")
defer func() {
utils.LogTime("runPluginUpdateCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
@@ -275,7 +278,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
plugins, err := resolveUpdatePluginsFromArgs(args)
if err != nil {
fmt.Println()
utils.ShowError(err)
utils.ShowError(ctx, err)
fmt.Println()
cmd.Help()
fmt.Println()
@@ -285,7 +288,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
state, err := statefile.LoadState()
if err != nil {
utils.ShowError(fmt.Errorf("could not load state"))
utils.ShowError(ctx, fmt.Errorf("could not load state"))
exitCode = 3
return
}
@@ -293,7 +296,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
// load up the version file data
versionData, err := versionfile.LoadPluginVersionFile()
if err != nil {
utils.ShowError(fmt.Errorf("error loading current plugin data"))
utils.ShowError(ctx, fmt.Errorf("error loading current plugin data"))
exitCode = 3
return
}
@@ -340,14 +343,14 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
return
}
spinner := display.ShowSpinner("Checking for available updates")
statusSpinner := statusspinner.NewStatusSpinner(statusspinner.WithMessage("Checking for available updates"))
reports := plugin.GetUpdateReport(state.InstallationID, runUpdatesFor)
display.StopSpinner(spinner)
statusSpinner.Done()
if len(reports) == 0 {
// this happens if for some reason the update server could not be contacted,
// in which case we get back an empty map
utils.ShowError(fmt.Errorf("there was an issue contacting the update server. Please try later"))
utils.ShowError(ctx, fmt.Errorf("there was an issue contacting the update server. Please try later"))
exitCode = 3
return
}
@@ -363,9 +366,9 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
continue
}
spinner := display.ShowSpinner(fmt.Sprintf("Updating plugin %s...", report.CheckResponse.Name))
statusSpinner.SetStatus(fmt.Sprintf("Updating plugin %s...", report.CheckResponse.Name))
image, err := plugin.Install(cmd.Context(), report.Plugin.Name)
display.StopSpinner(spinner)
statusSpinner.Done()
if err != nil {
msg := ""
if strings.HasSuffix(err.Error(), "not found") {
@@ -398,7 +401,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
})
}
refreshConnectionsIfNecessary(cmd.Context(), updateReports, true)
refreshConnectionsIfNecessary(cmd.Context(), updateReports, false)
display.PrintInstallReports(updateReports, true)
// a concluding blank line - since we always output multiple lines
@@ -421,7 +424,7 @@ func resolveUpdatePluginsFromArgs(args []string) ([]string, error) {
}
// start service if necessary and refresh connections
func refreshConnectionsIfNecessary(ctx context.Context, reports []display.InstallReport, isUpdate bool) error {
func refreshConnectionsIfNecessary(ctx context.Context, reports []display.InstallReport, shouldReload bool) error {
// get count of skipped reports
skipped := 0
for _, report := range reports {
@@ -436,7 +439,7 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal
}
// reload the config, since an installation MUST have created a new config file
if !isUpdate {
if shouldReload {
var cmd = viper.Get(constants.ConfigKeyActiveCommand).(*cobra.Command)
config, err := steampipeconfig.LoadSteampipeConfig(viper.GetString(constants.ArgWorkspaceChDir), cmd.Name())
if err != nil {
@@ -449,7 +452,7 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal
if err != nil {
return err
}
defer client.Close()
defer client.Close(ctx)
res := client.RefreshConnectionAndSearchPaths(ctx)
if res.Error != nil {
return res.Error
@@ -460,24 +463,26 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal
}
func runPluginListCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runPluginListCmd list")
defer func() {
utils.LogTime("runPluginListCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
pluginConnectionMap, err := getPluginConnectionMap(cmd.Context())
if err != nil {
utils.ShowErrorWithMessage(err, "Plugin Listing failed")
utils.ShowErrorWithMessage(ctx, err, "Plugin Listing failed")
exitCode = 4
return
}
list, err := plugin.List(pluginConnectionMap)
if err != nil {
utils.ShowErrorWithMessage(err, "Plugin Listing failed")
utils.ShowErrorWithMessage(ctx, err, "Plugin Listing failed")
exitCode = 4
}
headers := []string{"Name", "Version", "Connections"}
@@ -489,35 +494,37 @@ func runPluginListCmd(cmd *cobra.Command, args []string) {
}
func runPluginUninstallCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runPluginUninstallCmd uninstall")
defer func() {
utils.LogTime("runPluginUninstallCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
exitCode = 1
}
}()
if len(args) == 0 {
fmt.Println()
utils.ShowError(fmt.Errorf("you need to provide at least one plugin to uninstall"))
utils.ShowError(ctx, fmt.Errorf("you need to provide at least one plugin to uninstall"))
fmt.Println()
cmd.Help()
fmt.Println()
exitCode = 2
return
}
connectionMap, err := getPluginConnectionMap(cmd.Context())
connectionMap, err := getPluginConnectionMap(ctx)
if err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
exitCode = 4
return
}
for _, p := range args {
if err := plugin.Remove(p, connectionMap); err != nil {
utils.ShowErrorWithMessage(err, fmt.Sprintf("Failed to uninstall plugin '%s'", p))
if err := plugin.Remove(ctx, p, connectionMap); err != nil {
utils.ShowErrorWithMessage(ctx, err, fmt.Sprintf("Failed to uninstall plugin '%s'", p))
}
}
}
@@ -528,7 +535,7 @@ func getPluginConnectionMap(ctx context.Context) (map[string][]modconfig.Connect
if err != nil {
return nil, err
}
defer client.Close()
defer client.Close(ctx)
res := client.RefreshConnectionAndSearchPaths(ctx)
if res.Error != nil {
return nil, res.Error

View File

@@ -33,6 +33,7 @@ func pluginManagerCmd() *cobra.Command {
}
func runPluginManagerCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
logger := createPluginManagerLog()
log.Printf("[INFO] starting plugin manager")
@@ -51,7 +52,7 @@ func runPluginManagerCmd(cmd *cobra.Command, args []string) {
connectionWatcher, err := connectionwatcher.NewConnectionWatcher(pluginManager.SetConnectionConfigMap)
if err != nil {
log.Printf("[WARN] failed to create connection watcher: %s", err.Error())
utils.ShowError(err)
utils.ShowError(ctx, err)
os.Exit(1)
}

View File

@@ -6,10 +6,8 @@ import (
"fmt"
"log"
"os"
"os/signal"
"strings"
"github.com/briandowns/spinner"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/turbot/go-kit/helpers"
@@ -18,6 +16,7 @@ import (
"github.com/turbot/steampipe/interactive"
"github.com/turbot/steampipe/query"
"github.com/turbot/steampipe/query/queryexecute"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/steampipeconfig/modconfig"
"github.com/turbot/steampipe/utils"
"github.com/turbot/steampipe/workspace"
@@ -80,11 +79,12 @@ Examples:
}
func runQueryCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("cmd.runQueryCmd start")
defer func() {
utils.LogTime("cmd.runQueryCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
}()
@@ -97,19 +97,11 @@ func runQueryCmd(cmd *cobra.Command, args []string) {
// enable spinner only in interactive mode
interactiveMode := len(args) == 0
cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, interactiveMode)
// set config to indicate whether we are running an interactive query
viper.Set(constants.ConfigKeyInteractive, interactiveMode)
ctx := cmd.Context()
if !interactiveMode {
c, cancel := context.WithCancel(ctx)
startCancelHandler(cancel)
ctx = c
}
// load the workspace
w, err := loadWorkspacePromptingForVariables(ctx, nil)
w, err := loadWorkspacePromptingForVariables(ctx)
utils.FailOnErrorWithMessage(err, "failed to load workspace")
// se we have loaded a workspace - be sure to close it
@@ -119,7 +111,7 @@ func runQueryCmd(cmd *cobra.Command, args []string) {
initData := query.NewInitData(ctx, w, args)
if interactiveMode {
queryexecute.RunInteractiveSession(initData)
queryexecute.RunInteractiveSession(ctx, initData)
} else {
// set global exit code
exitCode = queryexecute.RunBatchSession(ctx, initData)
@@ -144,10 +136,10 @@ func getPipedStdinData() string {
return stdinData
}
func loadWorkspacePromptingForVariables(ctx context.Context, spinner *spinner.Spinner) (*workspace.Workspace, error) {
func loadWorkspacePromptingForVariables(ctx context.Context) (*workspace.Workspace, error) {
workspacePath := viper.GetString(constants.ArgWorkspaceChDir)
w, err := workspace.Load(workspacePath)
w, err := workspace.Load(ctx, workspacePath)
if err == nil {
return w, nil
}
@@ -156,29 +148,13 @@ func loadWorkspacePromptingForVariables(ctx context.Context, spinner *spinner.Sp
if !ok {
return nil, err
}
if spinner != nil {
spinner.Stop()
}
// so we have missing variables - prompt for them
// first hide spinner if it is there
statushooks.Done(ctx)
if err := interactive.PromptForMissingVariables(ctx, missingVariablesError.MissingVariables); err != nil {
log.Printf("[TRACE] Interactive variables prompting returned error %v", err)
return nil, err
}
if spinner != nil {
spinner.Start()
}
// ok we should have all variables now - reload workspace
return workspace.Load(workspacePath)
}
func startCancelHandler(cancel context.CancelFunc) {
sigIntChannel := make(chan os.Signal, 1)
signal.Notify(sigIntChannel, os.Interrupt)
go func() {
<-sigIntChannel
log.Println("[TRACE] got SIGINT")
// call context cancellation function
cancel()
// leave the channel open - any subsequent interrupts hits will be ignored
}()
return workspace.Load(ctx, workspacePath)
}

View File

@@ -4,11 +4,11 @@ import (
"context"
"github.com/spf13/cobra"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe-plugin-sdk/logging"
"github.com/turbot/steampipe/cmdconfig"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/contexthelpers"
"github.com/turbot/steampipe/report/reportserver"
"github.com/turbot/steampipe/utils"
)
@@ -29,18 +29,17 @@ func reportCmd() *cobra.Command {
}
func runReportCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
logging.LogTime("runReportCmd start")
defer func() {
logging.LogTime("runReportCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
}()
cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, false)
ctx, cancel := context.WithCancel(cmd.Context())
startCancelHandler(cancel)
contexthelpers.StartCancelHandler(cancel)
// start db if necessary
//err := db_local.EnsureDbAndStartService(constants.InvokerReport, true)
@@ -53,7 +52,7 @@ func runReportCmd(cmd *cobra.Command, args []string) {
utils.FailOnError(err)
}
defer server.Shutdown()
defer server.Shutdown(ctx)
server.Start()
}

View File

@@ -1,10 +1,14 @@
package cmd
import (
"context"
"fmt"
"log"
"os"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/statusspinner"
"github.com/hashicorp/go-hclog"
"github.com/mattn/go-isatty"
"github.com/spf13/cobra"
@@ -183,6 +187,19 @@ func Execute() int {
utils.LogTime("cmd.root.Execute start")
defer utils.LogTime("cmd.root.Execute end")
rootCmd.Execute()
ctx := createRootContext()
rootCmd.ExecuteContext(ctx)
return exitCode
}
// create the root context - create a status renderer and set as value
func createRootContext() context.Context {
var statusRenderer statushooks.StatusHooks = statushooks.NullHooks
// if the client is a TTY, inject a status spinner
if isatty.IsTerminal(os.Stdout.Fd()) {
statusRenderer = statusspinner.NewStatusSpinner()
}
ctx := statushooks.AddStatusHooksToContext(context.Background(), statusRenderer)
return ctx
}

View File

@@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"github.com/turbot/steampipe/statushooks"
"os"
"os/signal"
"strings"
@@ -127,11 +129,12 @@ func serviceRestartCmd() *cobra.Command {
}
func runServiceStartCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runServiceStartCmd start")
defer func() {
utils.LogTime("runServiceStartCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
if exitCode == 0 {
// there was an error and the exitcode
// was not set to a non-zero value.
@@ -163,7 +166,7 @@ func runServiceStartCmd(cmd *cobra.Command, args []string) {
utils.FailOnError(startResult.Error)
if startResult.Status == db_local.ServiceFailedToStart {
utils.ShowError(fmt.Errorf("steampipe service failed to start"))
utils.ShowError(ctx, fmt.Errorf("steampipe service failed to start"))
return
}
@@ -191,18 +194,18 @@ func runServiceStartCmd(cmd *cobra.Command, args []string) {
err = db_local.RefreshConnectionAndSearchPaths(ctx, invoker)
if err != nil {
db_local.StopServices(false, constants.InvokerService, nil)
db_local.StopServices(ctx, false, constants.InvokerService)
utils.FailOnError(err)
}
printStatus(startResult.DbState, startResult.PluginManagerState)
printStatus(ctx, startResult.DbState, startResult.PluginManagerState)
if viper.GetBool(constants.ArgForeground) {
runServiceInForeground(invoker)
runServiceInForeground(ctx, invoker)
}
}
func runServiceInForeground(invoker constants.Invoker) {
func runServiceInForeground(ctx context.Context, invoker constants.Invoker) {
fmt.Println("Hit Ctrl+C to stop the service")
sigIntChannel := make(chan os.Signal, 1)
@@ -232,7 +235,7 @@ func runServiceInForeground(invoker constants.Invoker) {
count, err := db_local.GetCountOfThirdPartyClients(context.Background())
if err != nil {
// report the error in the off chance that there's one
utils.ShowError(err)
utils.ShowError(ctx, err)
return
}
@@ -246,7 +249,7 @@ func runServiceInForeground(invoker constants.Invoker) {
}
fmt.Println("Stopping Steampipe service.")
db_local.StopServices(false, invoker, nil)
db_local.StopServices(ctx, false, invoker)
fmt.Println("Steampipe service stopped.")
return
}
@@ -254,11 +257,12 @@ func runServiceInForeground(invoker constants.Invoker) {
}
func runServiceRestartCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runServiceRestartCmd start")
defer func() {
utils.LogTime("runServiceRestartCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
if exitCode == 0 {
// there was an error and the exitcode
// was not set to a non-zero value.
@@ -277,7 +281,7 @@ func runServiceRestartCmd(cmd *cobra.Command, args []string) {
}
// stop db
stopStatus, err := db_local.StopServices(viper.GetBool(constants.ArgForce), constants.InvokerService, nil)
stopStatus, err := db_local.StopServices(ctx, viper.GetBool(constants.ArgForce), constants.InvokerService)
utils.FailOnErrorWithMessage(err, "could not stop current instance")
if stopStatus != db_local.ServiceStopped {
fmt.Println(`
@@ -307,16 +311,17 @@ to force a restart.
utils.FailOnError(err)
fmt.Println("Steampipe service restarted.")
printStatus(startResult.DbState, startResult.PluginManagerState)
printStatus(ctx, startResult.DbState, startResult.PluginManagerState)
}
func runServiceStatusCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runServiceStatusCmd status")
defer func() {
utils.LogTime("runServiceStatusCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
}()
@@ -331,10 +336,10 @@ func runServiceStatusCmd(cmd *cobra.Command, args []string) {
pmState, pmStateErr := pluginmanager.LoadPluginManagerState()
if dbStateErr != nil || pmStateErr != nil {
utils.ShowError(composeStateError(dbStateErr, pmStateErr))
utils.ShowError(ctx, composeStateError(dbStateErr, pmStateErr))
return
}
printStatus(dbState, pmState)
printStatus(ctx, dbState, pmState)
}
}
@@ -354,19 +359,17 @@ func composeStateError(dbStateErr error, pmStateErr error) error {
}
func runServiceStopCmd(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
utils.LogTime("runServiceStopCmd stop")
stoppedChan := make(chan bool, 1)
var status db_local.StopStatus
var err error
var dbState *db_local.RunningDBInstanceInfo
spinner := display.StartSpinnerAfterDelay("", constants.SpinnerShowTimeout, stoppedChan)
defer func() {
utils.LogTime("runServiceStopCmd end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
if exitCode == 0 {
// there was an error and the exitcode
// was not set to a non-zero value.
@@ -378,20 +381,17 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) {
force := cmdconfig.Viper().GetBool(constants.ArgForce)
if force {
status, err = db_local.StopServices(force, constants.InvokerService, spinner)
status, err = db_local.StopServices(ctx, force, constants.InvokerService)
} else {
dbState, err = db_local.GetState()
if err != nil {
display.StopSpinner(spinner)
utils.FailOnErrorWithMessage(err, "could not stop Steampipe service")
}
if dbState == nil {
display.StopSpinner(spinner)
fmt.Println("Steampipe service is not running.")
return
}
if dbState.Invoker != constants.InvokerService {
display.StopSpinner(spinner)
printRunningImplicit(dbState.Invoker)
return
}
@@ -399,27 +399,22 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) {
// check if there are any connected clients to the service
connectedClientCount, err := db_local.GetCountOfThirdPartyClients(cmd.Context())
if err != nil {
display.StopSpinner(spinner)
utils.FailOnErrorWithMessage(err, "error during service stop")
}
if connectedClientCount > 0 {
display.StopSpinner(spinner)
printClientsConnected()
return
}
status, _ = db_local.StopServices(false, constants.InvokerService, spinner)
status, _ = db_local.StopServices(ctx, false, constants.InvokerService)
}
if err != nil {
display.StopSpinner(spinner)
utils.ShowError(err)
utils.ShowError(ctx, err)
return
}
display.StopSpinner(spinner)
switch status {
case db_local.ServiceStopped:
if dbState != nil {
@@ -451,12 +446,9 @@ func showAllStatus(ctx context.Context) {
var processes []*psutils.Process
var err error
doneFetchingDetailsChan := make(chan bool)
sp := display.StartSpinnerAfterDelay("Getting details", constants.SpinnerShowTimeout, doneFetchingDetailsChan)
statushooks.SetStatus(ctx, "Getting details")
processes, err = db_local.FindAllSteampipePostgresInstances(ctx)
close(doneFetchingDetailsChan)
display.StopSpinner(sp)
statushooks.Done(ctx)
utils.FailOnError(err)
@@ -498,7 +490,7 @@ func getServiceProcessDetails(process *psutils.Process) (string, string, string,
return fmt.Sprintf("%d", process.Pid), installDir, port, listenType
}
func printStatus(dbState *db_local.RunningDBInstanceInfo, pmState *pluginmanager.PluginManagerState) {
func printStatus(ctx context.Context, dbState *db_local.RunningDBInstanceInfo, pmState *pluginmanager.PluginManagerState) {
if dbState == nil && !pmState.Running {
fmt.Println("Service is not running")
return
@@ -553,7 +545,7 @@ To keep the service running after the %s session completes, use %s.
// the service is running, but the plugin_manager is not running and there's no state file
// meaning that it cannot be restarted by the FDW
// it's an ERROR
utils.ShowError(fmt.Errorf(`
utils.ShowError(ctx, fmt.Errorf(`
Service is running, but the Plugin Manager cannot be recovered.
Please use %s to recover the service
`,

View File

@@ -86,7 +86,7 @@ func (w *ConnectionWatcher) handleFileWatcherEvent(e []fsnotify.Event) {
if err != nil {
log.Printf("[WARN] Error creating client to handle updated connection config: %s", err.Error())
}
defer client.Close()
defer client.Close(ctx)
log.Printf("[TRACE] loaded updated config")

View File

@@ -2,7 +2,6 @@ package constants
// viper config keys
const (
ConfigKeyShowInteractiveOutput = "show-interactive-output"
// ConfigKeyDatabaseSearchPath is used to store the search path set in the database config in viper
// the viper value will be set via via a call to getScopedKey in steampipeconfig/steampipeconfig.go
ConfigKeyDatabaseSearchPath = "database.search-path"

21
contexthelpers/cancel.go Normal file
View File

@@ -0,0 +1,21 @@
package contexthelpers
import (
"context"
"log"
"os"
"os/signal"
)
func StartCancelHandler(cancel context.CancelFunc) chan os.Signal {
sigIntChannel := make(chan os.Signal, 1)
signal.Notify(sigIntChannel, os.Interrupt)
go func() {
<-sigIntChannel
log.Println("[TRACE] got SIGINT")
// call context cancellation function
cancel()
// leave the channel open - any subsequent interrupts hits will be ignored
}()
return sigIntChannel
}

View File

@@ -0,0 +1,8 @@
package contexthelpers
//https://medium.com/@matryer/context-keys-in-go-5312346a868d
type ContextKey string
func (c ContextKey) String() string {
return "steampipe context key " + string(c)
}

View File

@@ -13,6 +13,7 @@ import (
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/query/queryresult"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/steampipeconfig/modconfig"
"github.com/turbot/steampipe/utils"
)
@@ -92,8 +93,8 @@ func NewControlRun(control *modconfig.Control, group *ResultGroup, executionTree
return res
}
func (r *ControlRun) skip() {
r.setRunStatus(ControlRunComplete)
func (r *ControlRun) skip(ctx context.Context) {
r.setRunStatus(ctx, ControlRunComplete)
}
// set search path for this control run
@@ -196,14 +197,14 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
r.runStatus = ControlRunStarted
// update the current running control in the Progress renderer
r.executionTree.progress.OnControlStart(control)
defer r.executionTree.progress.OnControlFinish()
r.executionTree.progress.OnControlStart(ctx, control)
defer r.executionTree.progress.OnControlFinish(ctx)
// resolve the control query
r.Lifecycle.Add("query_resolution_start")
query, err := r.resolveControlQuery(control)
if err != nil {
r.SetError(err)
r.SetError(ctx, err)
return
}
r.Lifecycle.Add("query_resolution_finish")
@@ -211,7 +212,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
log.Printf("[TRACE] setting search path %s\n", control.Name())
r.Lifecycle.Add("set_search_path_start")
if err := r.setSearchPath(ctx, dbSession, client); err != nil {
r.SetError(err)
r.SetError(ctx, err)
return
}
r.Lifecycle.Add("set_search_path_finish")
@@ -225,7 +226,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
// NOTE no need to pass an OnComplete callback - we are already closing our session after waiting for results
log.Printf("[TRACE] execute start for, %s\n", control.Name())
r.Lifecycle.Add("query_start")
queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil, false)
queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil)
r.Lifecycle.Add("query_finish")
log.Printf("[TRACE] execute finish for, %s\n", control.Name())
@@ -243,7 +244,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
log.Printf("[TRACE] control %s query failed again with plugin connectivity error %s - NOT retrying...", r.Control.Name(), err)
}
}
r.SetError(err)
r.SetError(ctx, err)
return
}
@@ -255,7 +256,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
log.Printf("[TRACE] finish result for, %s\n", control.Name())
}
func (r *ControlRun) SetError(err error) {
func (r *ControlRun) SetError(ctx context.Context, err error) {
if err == nil {
return
}
@@ -263,18 +264,23 @@ func (r *ControlRun) SetError(err error) {
// update error count
r.Summary.Error++
r.setRunStatus(ControlRunError)
r.setRunStatus(ctx, ControlRunError)
}
func (r *ControlRun) GetError() error {
return r.runError
}
// create a context with a deadline, and with status updates disabled (we do not want to show 'loading' results)
func (r *ControlRun) getControlQueryContext(ctx context.Context) context.Context {
// create a context with a deadline
shouldBeDoneBy := time.Now().Add(controlQueryTimeout)
// we don't use this cancel fn because, pgx prematurely cancels the PG connection when this cancel gets called in 'defer'
newCtx, _ := context.WithDeadline(ctx, shouldBeDoneBy)
// disable the status spinner to hide 'loading' results)
newCtx = statushooks.DisableStatusHooks(newCtx)
return newCtx
}
@@ -304,20 +310,20 @@ func (r *ControlRun) waitForResults(ctx context.Context) {
// create a channel to which will be closed when gathering has been done
gatherDoneChan := make(chan string)
go func() {
r.gatherResults()
r.gatherResults(ctx)
close(gatherDoneChan)
}()
select {
// check for cancellation
case <-ctx.Done():
r.SetError(ctx.Err())
r.SetError(ctx, ctx.Err())
case <-gatherDoneChan:
// do nothing
}
}
func (r *ControlRun) gatherResults() {
func (r *ControlRun) gatherResults(ctx context.Context) {
r.Lifecycle.Add("gather_start")
defer func() { r.Lifecycle.Add("gather_finish") }()
for {
@@ -326,14 +332,14 @@ func (r *ControlRun) gatherResults() {
// nil row means control run is complete
if row == nil {
// nil row means we are done
r.setRunStatus(ControlRunComplete)
r.setRunStatus(ctx, ControlRunComplete)
r.createdOrderedResultRows()
return
}
// if the row is in error then we terminate the run
if row.Error != nil {
// set error status and summary
r.SetError(row.Error)
r.SetError(ctx, row.Error)
// update the result group status with our status - this will be passed all the way up the execution tree
r.group.updateSummary(r.Summary)
return
@@ -342,7 +348,7 @@ func (r *ControlRun) gatherResults() {
// so all is ok - create another result row
result, err := NewResultRow(r.Control, row, r.queryResult.ColTypes)
if err != nil {
r.SetError(err)
r.SetError(ctx, err)
return
}
r.addResultRow(result)
@@ -380,7 +386,7 @@ func (r *ControlRun) createdOrderedResultRows() {
}
}
func (r *ControlRun) setRunStatus(status ControlRunStatus) {
func (r *ControlRun) setRunStatus(ctx context.Context, status ControlRunStatus) {
r.stateLock.Lock()
r.runStatus = status
r.stateLock.Unlock()
@@ -388,12 +394,11 @@ func (r *ControlRun) setRunStatus(status ControlRunStatus) {
if r.Finished() {
// update Progress
if status == ControlRunError {
r.executionTree.progress.OnControlError()
r.executionTree.progress.OnControlError(ctx)
} else {
r.executionTree.progress.OnControlComplete()
r.executionTree.progress.OnControlComplete(ctx)
}
// TODO CANCEL QUERY IF NEEDED
r.doneChan <- true
}
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/query/queryresult"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/steampipeconfig/modconfig"
"github.com/turbot/steampipe/workspace"
"golang.org/x/sync/semaphore"
@@ -40,9 +41,10 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien
workspace: workspace,
client: client,
}
// if a "--where" or "--tag" parameter was passed, build a map of control manes used to filter the controls to run
// NOTE: not enabled yet
err := executionTree.populateControlFilterMap(ctx)
// if a "--where" or "--tag" parameter was passed, build a map of control names used to filter the controls to run
// create a context with status hooks disabled
noStatusCtx := statushooks.DisableStatusHooks(ctx)
err := executionTree.populateControlFilterMap(noStatusCtx)
if err != nil {
return nil, err
@@ -55,7 +57,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien
}
// build tree of result groups, starting with a synthetic 'root' node
executionTree.Root = NewRootResultGroup(executionTree, rootItem)
executionTree.Root = NewRootResultGroup(ctx, executionTree, rootItem)
// after tree has built, ControlCount will be set - create progress rendered
executionTree.progress = NewControlProgressRenderer(len(executionTree.controlRuns))
@@ -65,7 +67,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien
// AddControl checks whether control should be included in the tree
// if so, creates a ControlRun, which is added to the parent group
func (e *ExecutionTree) AddControl(control *modconfig.Control, group *ResultGroup) {
func (e *ExecutionTree) AddControl(ctx context.Context, control *modconfig.Control, group *ResultGroup) {
// note we use short name to determine whether to include a control
if e.ShouldIncludeControl(control.ShortName) {
// create new ControlRun with treeItem as the parent
@@ -81,11 +83,11 @@ func (e *ExecutionTree) Execute(ctx context.Context, client db_common.Client) in
log.Println("[TRACE]", "begin ExecutionTree.Execute")
defer log.Println("[TRACE]", "end ExecutionTree.Execute")
e.StartTime = time.Now()
e.progress.Start()
e.progress.Start(ctx)
defer func() {
e.EndTime = time.Now()
e.progress.Finish()
e.progress.Finish(ctx)
}()
// the number of goroutines parallel to start
@@ -247,7 +249,7 @@ func (e *ExecutionTree) getControlMapFromWhereClause(ctx context.Context, whereC
query = fmt.Sprintf("select resource_name from %s where %s", constants.IntrospectionTableControl, whereClause)
}
res, err := e.client.ExecuteSync(ctx, query, false)
res, err := e.client.ExecuteSync(ctx, query)
if err != nil {
return nil, err
}

View File

@@ -1,16 +1,17 @@
package controlexecute
import (
"context"
"fmt"
"sync"
"github.com/turbot/steampipe/statushooks"
"github.com/spf13/viper"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/steampipeconfig/modconfig"
"github.com/briandowns/spinner"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/utils"
)
@@ -34,16 +35,16 @@ func NewControlProgressRenderer(total int) *ControlProgressRenderer {
}
}
func (p *ControlProgressRenderer) Start() {
func (p *ControlProgressRenderer) Start(ctx context.Context) {
p.updateLock.Lock()
defer p.updateLock.Unlock()
if p.enabled {
p.spinner = display.ShowSpinner("Starting controls...")
statushooks.SetStatus(ctx, "Starting controls...")
}
}
func (p *ControlProgressRenderer) OnControlStart(control *modconfig.Control) {
func (p *ControlProgressRenderer) OnControlStart(ctx context.Context, control *modconfig.Control) {
p.updateLock.Lock()
defer p.updateLock.Unlock()
@@ -54,43 +55,43 @@ func (p *ControlProgressRenderer) OnControlStart(control *modconfig.Control) {
p.pending--
if p.enabled {
display.UpdateSpinnerMessage(p.spinner, p.message())
statushooks.SetStatus(ctx, p.message())
}
}
func (p *ControlProgressRenderer) OnControlFinish() {
func (p *ControlProgressRenderer) OnControlFinish(ctx context.Context) {
p.updateLock.Lock()
defer p.updateLock.Unlock()
// decrement the parallel execution count
p.executing--
if p.enabled {
display.UpdateSpinnerMessage(p.spinner, p.message())
statushooks.SetStatus(ctx, p.message())
}
}
func (p *ControlProgressRenderer) OnControlComplete() {
func (p *ControlProgressRenderer) OnControlComplete(ctx context.Context) {
p.updateLock.Lock()
defer p.updateLock.Unlock()
p.complete++
if p.enabled {
display.UpdateSpinnerMessage(p.spinner, p.message())
statushooks.SetStatus(ctx, p.message())
}
}
func (p *ControlProgressRenderer) OnControlError() {
func (p *ControlProgressRenderer) OnControlError(ctx context.Context) {
p.updateLock.Lock()
defer p.updateLock.Unlock()
p.error++
if p.enabled {
display.UpdateSpinnerMessage(p.spinner, p.message())
statushooks.SetStatus(ctx, p.message())
}
}
func (p *ControlProgressRenderer) Finish() {
func (p *ControlProgressRenderer) Finish(ctx context.Context) {
if p.enabled {
display.StopSpinner(p.spinner)
statushooks.Done(ctx)
}
}

View File

@@ -49,7 +49,7 @@ func NewGroupSummary() *GroupSummary {
}
// NewRootResultGroup creates a ResultGroup to act as the root node of a control execution tree
func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.ModTreeItem) *ResultGroup {
func NewRootResultGroup(ctx context.Context, executionTree *ExecutionTree, rootItems ...modconfig.ModTreeItem) *ResultGroup {
root := &ResultGroup{
GroupId: RootResultGroupName,
Groups: []*ResultGroup{},
@@ -62,10 +62,10 @@ func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.Mod
// if root item is a benchmark, create new result group with root as parent
if control, ok := item.(*modconfig.Control); ok {
// if root item is a control, add control run
executionTree.AddControl(control, root)
executionTree.AddControl(ctx, control, root)
} else {
// create a result group for this item
itemGroup := NewResultGroup(executionTree, item, root)
itemGroup := NewResultGroup(ctx, executionTree, item, root)
root.Groups = append(root.Groups, itemGroup)
}
}
@@ -73,7 +73,7 @@ func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.Mod
}
// NewResultGroup creates a result group from a ModTreeItem
func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem, parent *ResultGroup) *ResultGroup {
func NewResultGroup(ctx context.Context, executionTree *ExecutionTree, treeItem modconfig.ModTreeItem, parent *ResultGroup) *ResultGroup {
// only show qualified group names for controls from dependent mods
groupId := treeItem.Name()
if mod := treeItem.GetMod(); mod != nil && mod.Name() == executionTree.workspace.Mod.Name() {
@@ -96,7 +96,7 @@ func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem
for _, c := range treeItem.GetChildren() {
if benchmark, ok := c.(*modconfig.Benchmark); ok {
// create a result group for this item
benchmarkGroup := NewResultGroup(executionTree, benchmark, group)
benchmarkGroup := NewResultGroup(ctx, executionTree, benchmark, group)
// if the group has any control runs, add to tree
if benchmarkGroup.ControlRunCount() > 0 {
// create a new result group with 'group' as the parent
@@ -104,7 +104,7 @@ func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem
}
}
if control, ok := c.(*modconfig.Control); ok {
executionTree.AddControl(control, group)
executionTree.AddControl(ctx, control, group)
}
}
@@ -180,18 +180,18 @@ func (r *ResultGroup) execute(ctx context.Context, client db_common.Client, para
for _, controlRun := range r.ControlRuns {
if utils.IsContextCancelled(ctx) {
controlRun.SetError(ctx.Err())
controlRun.SetError(ctx, ctx.Err())
continue
}
if viper.GetBool(constants.ArgDryRun) {
controlRun.skip()
controlRun.skip(ctx)
continue
}
err := parallelismLock.Acquire(ctx, 1)
if err != nil {
controlRun.SetError(err)
controlRun.SetError(ctx, err)
continue
}
@@ -199,7 +199,7 @@ func (r *ResultGroup) execute(ctx context.Context, client db_common.Client, para
defer func() {
if r := recover(); r != nil {
// if the Execute panic'ed, set it as an error
run.SetError(helpers.ToError(r))
run.SetError(ctx, helpers.ToError(r))
}
// Release in defer, so that we don't retain the lock even if there's a panic inside
parallelismLock.Release(1)

View File

@@ -1,14 +1,11 @@
package control
import (
"context"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/workspace"
)
type InitData struct {
Ctx context.Context
Workspace *workspace.Workspace
Client db_common.Client
Result *db_common.InitResult

View File

@@ -68,7 +68,7 @@ func NewDbClient(ctx context.Context, connectionString string) (*DbClient, error
client.connectionString = connectionString
if err := client.LoadForeignSchemaNames(ctx); err != nil {
client.Close()
client.Close(ctx)
return nil, err
}
return client, nil
@@ -114,7 +114,7 @@ func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallba
// Close implements Client
// closes the connection to the database and shuts down the backend
func (c *DbClient) Close() error {
func (c *DbClient) Close(context.Context) error {
log.Printf("[TRACE] DbClient.Close %v", c.dbClient)
if c.dbClient != nil {
c.sessionInitWaitGroup.Wait()

View File

@@ -8,12 +8,9 @@ import (
"fmt"
"time"
"github.com/briandowns/spinner"
"github.com/turbot/steampipe/cmdconfig"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/query/queryresult"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/utils"
"golang.org/x/text/language"
"golang.org/x/text/message"
@@ -21,7 +18,7 @@ import (
// ExecuteSync implements Client
// execute a query against this client and wait for the result
func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
func (c *DbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) {
// acquire a session
sessionResult := c.AcquireSession(ctx)
if sessionResult.Error != nil {
@@ -32,17 +29,17 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner
// and not in call-time
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
}()
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, disableSpinner)
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query)
}
// ExecuteSyncInSession implements Client
// execute a query against this client and wait for the result
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) {
if query == "" {
return &queryresult.SyncQueryResult{}, nil
}
result, err := c.ExecuteInSession(ctx, session, query, nil, disableSpinner)
result, err := c.ExecuteInSession(ctx, session, query, nil)
if err != nil {
return nil, err
}
@@ -61,7 +58,7 @@ func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.
// Execute implements Client
// execute the query in the given Context
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner bool) (*queryresult.Result, error) {
func (c *DbClient) Execute(ctx context.Context, query string) (*queryresult.Result, error) {
// acquire a session
sessionResult := c.AcquireSession(ctx)
if sessionResult.Error != nil {
@@ -70,26 +67,29 @@ func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner boo
// define callback to close session when the async execution is complete
closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) }
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback, disableSpinner)
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback)
}
// ExecuteInSession implements Client
// execute the query in the given Context using the provided DatabaseSession
// ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) {
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) {
if query == "" {
return queryresult.NewQueryResult(nil), nil
}
startTime := time.Now()
// channel to flag to spinner that the query has run
var spinner *spinner.Spinner
var tx *sql.Tx
defer func() {
if err != nil {
// stop spinner in case of error
display.StopSpinner(spinner)
statushooks.Done(ctx)
// error - rollback transaction if we have one
if tx != nil {
tx.Rollback()
}
// call the completion callback - if one was provided
if onComplete != nil {
onComplete()
@@ -97,11 +97,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data
}
}()
if !disableSpinner && cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) {
// if `show-interactive-output` is false, the spinner gets created, but is never shown
// so the s.Active() will always come back false . . .
spinner = display.ShowSpinner("Loading results...")
}
statushooks.SetStatus(ctx, "Loading results...")
// start query
var rows *sql.Rows
@@ -122,7 +118,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data
// read the rows in a go routine
go func() {
// read in the rows and stream to the query result object
c.readRows(ctx, startTime, rows, result, spinner)
c.readRows(ctx, startTime, rows, result)
if onComplete != nil {
onComplete()
}
@@ -158,11 +154,11 @@ func (c *DbClient) startQuery(ctx context.Context, query string, conn *sql.Conn)
return
}
func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows, result *queryresult.Result, activeSpinner *spinner.Spinner) {
func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows, result *queryresult.Result) {
// defer this, so that these get cleaned up even if there is an unforeseen error
defer func() {
// we are done fetching results. time for display. remove the spinner
display.StopSpinner(activeSpinner)
// we are done fetching results. time for display. clear the status indication
statushooks.Done(ctx)
// close the sql rows object
rows.Close()
if err := rows.Err(); err != nil {
@@ -191,7 +187,7 @@ func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows
continueToNext := true
select {
case <-ctx.Done():
display.UpdateSpinnerMessage(activeSpinner, "Cancelling query")
statushooks.SetStatus(ctx, "Cancelling query")
continueToNext = false
default:
if rowResult, err := readRowContext(ctx, rows, cols, colTypes); err != nil {
@@ -200,9 +196,9 @@ func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows
} else {
result.StreamRow(rowResult)
}
// update the spinner message with the count of rows that have already been fetched
// update the status message with the count of rows that have already been fetched
// this will not show if the spinner is not active
display.UpdateSpinnerMessage(activeSpinner, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
statushooks.SetStatus(ctx, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
rowCount++
}
if !continueToNext {

View File

@@ -18,7 +18,7 @@ import (
func (c *DbClient) GetCurrentSearchPath(ctx context.Context) ([]string, error) {
var currentSearchPath []string
var pathAsString string
rows, err := c.ExecuteSync(ctx, "show search_path", true)
rows, err := c.ExecuteSync(ctx, "show search_path")
if err != nil {
return nil, err
}

View File

@@ -11,7 +11,7 @@ import (
type EnsureSessionStateCallback = func(context.Context, *DatabaseSession) (err error, warnings []string)
type Client interface {
Close() error
Close(ctx context.Context) error
ForeignSchemas() []string
ConnectionMap() *steampipeconfig.ConnectionDataMap
@@ -22,11 +22,11 @@ type Client interface {
AcquireSession(context.Context) *AcquireSessionResult
ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error)
Execute(ctx context.Context, query string, disableSpinner bool) (res *queryresult.Result, err error)
ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error)
Execute(ctx context.Context, query string) (res *queryresult.Result, err error)
ExecuteSyncInSession(ctx context.Context, session *DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error)
ExecuteInSession(ctx context.Context, session *DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error)
ExecuteSyncInSession(ctx context.Context, session *DatabaseSession, query string) (*queryresult.SyncQueryResult, error)
ExecuteInSession(ctx context.Context, session *DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error)
CacheOn(context.Context) error
CacheOff(context.Context) error

View File

@@ -13,7 +13,7 @@ func ExecuteQuery(ctx context.Context, queryString string, client Client) (*quer
defer utils.LogTime("db.ExecuteQuery end")
resultsStreamer := queryresult.NewResultStreamer()
result, err := client.Execute(ctx, queryString, false)
result, err := client.Execute(ctx, queryString)
if err != nil {
return nil, err
}

View File

@@ -19,25 +19,6 @@ import (
// TagColumn is the tag used to specify the column name and type in the introspection tables
const TagColumn = "column"
func UpdateIntrospectionTables(workspaceResources *modconfig.WorkspaceResourceMaps, client Client) error {
utils.LogTime("db.UpdateIntrospectionTables start")
defer utils.LogTime("db.UpdateIntrospectionTables end")
// get the create sql for each table type
clearSql := getClearTablesSql()
// now get sql to populate the tables
insertSql := getTableInsertSql(workspaceResources)
sql := []string{clearSql, insertSql}
// execute the query, passing 'true' to disable the spinner
_, err := client.ExecuteSync(context.Background(), strings.Join(sql, "\n"), true)
if err != nil {
return fmt.Errorf("failed to update introspection tables: %v", err)
}
return nil
}
func CreateIntrospectionTables(ctx context.Context, workspaceResources *modconfig.WorkspaceResourceMaps, session *DatabaseSession) error {
utils.LogTime("db.CreateIntrospectionTables start")
defer utils.LogTime("db.CreateIntrospectionTables end")

View File

@@ -10,15 +10,14 @@ import (
"os/exec"
"sync"
"github.com/briandowns/spinner"
psutils "github.com/shirou/gopsutil/process"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/filepaths"
"github.com/turbot/steampipe/ociinstaller"
"github.com/turbot/steampipe/ociinstaller/versionfile"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/utils"
)
@@ -41,66 +40,57 @@ func EnsureDBInstalled(ctx context.Context) (err error) {
close(doneChan)
}()
spinner := display.StartSpinnerAfterDelay("", constants.SpinnerShowTimeout, doneChan)
if IsInstalled() {
// check if the FDW need updating, and init the db id required
err := prepareDb(ctx, spinner)
display.StopSpinner(spinner)
err := prepareDb(ctx)
return err
}
log.Println("[TRACE] calling removeRunningInstanceInfo")
err = removeRunningInstanceInfo()
if err != nil && !os.IsNotExist(err) {
display.StopSpinner(spinner)
log.Printf("[TRACE] removeRunningInstanceInfo failed: %v", err)
return fmt.Errorf("Cleanup any Steampipe processes... FAILED!")
}
log.Println("[TRACE] removing previous installation")
display.UpdateSpinnerMessage(spinner, "Prepare database install location...")
statushooks.SetStatus(ctx, "Prepare database install location...")
defer statushooks.Done(ctx)
err = os.RemoveAll(getDatabaseLocation())
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] %v", err)
return fmt.Errorf("Prepare database install location... FAILED!")
}
display.UpdateSpinnerMessage(spinner, "Download & install embedded PostgreSQL database...")
statushooks.SetStatus(ctx, "Download & install embedded PostgreSQL database...")
_, err = ociinstaller.InstallDB(ctx, constants.DefaultEmbeddedPostgresImage, getDatabaseLocation())
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] %v", err)
return fmt.Errorf("Download & install embedded PostgreSQL database... FAILED!")
}
// installFDW takes care of the spinner, since it may need to run independently
_, err = installFDW(ctx, true, spinner)
_, err = installFDW(ctx, true)
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] installFDW failed: %v", err)
return fmt.Errorf("Download & install steampipe-postgres-fdw... FAILED!")
}
// run the database installation
err = runInstall(ctx, true, spinner)
err = runInstall(ctx, true)
if err != nil {
display.StopSpinner(spinner)
return err
}
// write a signature after everything gets done!
// so that we can check for this later on
display.UpdateSpinnerMessage(spinner, "Updating install records...")
statushooks.SetStatus(ctx, "Updating install records...")
err = updateDownloadedBinarySignature()
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] updateDownloadedBinarySignature failed: %v", err)
return fmt.Errorf("Updating install records... FAILED!")
}
display.StopSpinner(spinner)
return nil
}
@@ -137,11 +127,10 @@ func IsInstalled() bool {
}
// prepareDb updates the FDW if needed, and inits the database if required
func prepareDb(ctx context.Context, spinner *spinner.Spinner) error {
func prepareDb(ctx context.Context) error {
// check if FDW needs to be updated
if fdwNeedsUpdate() {
_, err := installFDW(ctx, false, spinner)
spinner.Stop()
_, err := installFDW(ctx, false)
if err != nil {
log.Printf("[TRACE] installFDW failed: %v", err)
return fmt.Errorf("Update steampipe-postgres-fdw... FAILED!")
@@ -153,10 +142,9 @@ func prepareDb(ctx context.Context, spinner *spinner.Spinner) error {
}
if needsInit() {
spinner.Start()
display.UpdateSpinnerMessage(spinner, "Cleanup any Steampipe processes...")
statushooks.SetStatus(ctx, "Cleanup any Steampipe processes...")
killInstanceIfAny(ctx)
if err := runInstall(ctx, false, spinner); err != nil {
if err := runInstall(ctx, false); err != nil {
return err
}
}
@@ -175,15 +163,15 @@ func fdwNeedsUpdate() bool {
return versionInfo.FdwExtension.Version != constants.FdwVersion
}
func installFDW(ctx context.Context, firstSetup bool, spinner *spinner.Spinner) (string, error) {
func installFDW(ctx context.Context, firstSetup bool) (string, error) {
utils.LogTime("db_local.installFDW start")
defer utils.LogTime("db_local.installFDW end")
status, err := GetState()
state, err := GetState()
if err != nil {
return "", err
}
if status != nil {
if state != nil {
defer func() {
if !firstSetup {
// update the signature
@@ -191,7 +179,7 @@ func installFDW(ctx context.Context, firstSetup bool, spinner *spinner.Spinner)
}
}()
}
display.UpdateSpinnerMessage(spinner, fmt.Sprintf("Download & install %s...", constants.Bold("steampipe-postgres-fdw")))
statushooks.SetStatus(ctx, fmt.Sprintf("Download & install %s...", constants.Bold("steampipe-postgres-fdw")))
return ociinstaller.InstallFdw(ctx, constants.DefaultFdwImage, getDatabaseLocation())
}
@@ -203,58 +191,54 @@ func needsInit() bool {
return !helpers.FileExists(getPgHbaConfLocation())
}
func runInstall(ctx context.Context, firstInstall bool, spinner *spinner.Spinner) error {
func runInstall(ctx context.Context, firstInstall bool) error {
utils.LogTime("db_local.runInstall start")
defer utils.LogTime("db_local.runInstall end")
display.UpdateSpinnerMessage(spinner, "Cleaning up...")
statushooks.SetStatus(ctx, "Cleaning up...")
defer statushooks.Done(ctx)
err := utils.RemoveDirectoryContents(getDataLocation())
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] %v", err)
return fmt.Errorf("Prepare database install location... FAILED!")
}
display.UpdateSpinnerMessage(spinner, "Initializing database...")
statushooks.SetStatus(ctx, "Initializing database...")
err = initDatabase()
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] initDatabase failed: %v", err)
return fmt.Errorf("Initializing database... FAILED!")
}
display.UpdateSpinnerMessage(spinner, "Starting database...")
statushooks.SetStatus(ctx, "Starting database...")
port, err := getNextFreePort()
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] getNextFreePort failed: %v", err)
return fmt.Errorf("Starting database... FAILED!")
}
process, err := startServiceForInstall(port)
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] startServiceForInstall failed: %v", err)
return fmt.Errorf("Starting database... FAILED!")
}
display.UpdateSpinnerMessage(spinner, "Connection to database...")
statushooks.SetStatus(ctx, "Connection to database...")
client, err := createMaintenanceClient(ctx, port)
if err != nil {
display.StopSpinner(spinner)
return fmt.Errorf("Connection to database... FAILED!")
}
defer func() {
display.UpdateSpinnerMessage(spinner, "Completing configuration")
statushooks.SetStatus(ctx, "Completing configuration")
client.Close()
doThreeStepPostgresExit(process)
doThreeStepPostgresExit(ctx, process)
}()
display.UpdateSpinnerMessage(spinner, "Generating database passwords...")
statushooks.SetStatus(ctx, "Generating database passwords...")
// generate a password file for use later
_, err = readPasswordFile()
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] readPassword failed: %v", err)
return fmt.Errorf("Generating database passwords... FAILED!")
}
@@ -274,23 +258,21 @@ func runInstall(ctx context.Context, firstInstall bool, spinner *spinner.Spinner
return fmt.Errorf("Invalid database name '%s' - must start with either a lowercase character or an underscore", databaseName)
}
display.UpdateSpinnerMessage(spinner, "Configuring database...")
statushooks.SetStatus(ctx, "Configuring database...")
err = installDatabaseWithPermissions(ctx, databaseName, client)
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] installSteampipeDatabaseAndUser failed: %v", err)
return fmt.Errorf("Configuring database... FAILED!")
}
display.UpdateSpinnerMessage(spinner, "Configuring Steampipe...")
statushooks.SetStatus(ctx, "Configuring Steampipe...")
err = installForeignServer(ctx, client)
if err != nil {
display.StopSpinner(spinner)
log.Printf("[TRACE] installForeignServer failed: %v", err)
return fmt.Errorf("Configuring Steampipe... FAILED!")
}
return err
return nil
}
func resolveDatabaseName() string {

View File

@@ -13,6 +13,7 @@ import (
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/query/queryresult"
"github.com/turbot/steampipe/schema"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/steampipeconfig"
"github.com/turbot/steampipe/utils"
)
@@ -41,13 +42,12 @@ func GetLocalClient(ctx context.Context, invoker constants.Invoker) (db_common.C
client, err := NewLocalClient(ctx, invoker)
if err != nil {
ShutdownService(invoker)
ShutdownService(ctx, invoker)
}
return client, err
}
// NewLocalClient ensures that the database instance is running
// and returns a `Client` to interact with it
// NewLocalClient verifies that the local database instance is running and returns a Client to interact with it
func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbClient, error) {
utils.LogTime("db.NewLocalClient start")
defer utils.LogTime("db.NewLocalClient end")
@@ -69,19 +69,17 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbCli
// Close implements Client
// close the connection to the database and shuts down the backend
func (c *LocalDbClient) Close() error {
func (c *LocalDbClient) Close(ctx context.Context) error {
log.Printf("[TRACE] close local client %p", c)
if c.client != nil {
log.Printf("[TRACE] local client not NIL")
if err := c.client.Close(); err != nil {
if err := c.client.Close(ctx); err != nil {
return err
}
log.Printf("[TRACE] local client close complete")
}
log.Printf("[TRACE] shutdown local service %v", c.invoker)
// no context to pass on - use background
// we shouldn't do this in a context that can be cancelled anyway
ShutdownService(c.invoker)
ShutdownService(ctx, c.invoker)
return nil
}
@@ -108,23 +106,23 @@ func (c *LocalDbClient) AcquireSession(ctx context.Context) *db_common.AcquireSe
}
// ExecuteSync implements Client
func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSync(ctx, query, disableSpinner)
func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSync(ctx, query)
}
// ExecuteSyncInSession implements Client
func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSyncInSession(ctx, session, query, disableSpinner)
func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) {
return c.client.ExecuteSyncInSession(ctx, session, query)
}
// ExecuteInSession implements Client
func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) {
return c.client.ExecuteInSession(ctx, session, query, onComplete, disableSpinner)
func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) {
return c.client.ExecuteInSession(ctx, session, query, onComplete)
}
// Execute implements Client
func (c *LocalDbClient) Execute(ctx context.Context, query string, disableSpinner bool) (res *queryresult.Result, err error) {
return c.client.Execute(ctx, query, disableSpinner)
func (c *LocalDbClient) Execute(ctx context.Context, query string) (res *queryresult.Result, err error) {
return c.client.Execute(ctx, query)
}
// CacheOn implements Client
@@ -167,6 +165,9 @@ func (c *LocalDbClient) LoadForeignSchemaNames(ctx context.Context) error {
// local only functions
func (c *LocalDbClient) RefreshConnectionAndSearchPaths(ctx context.Context) *steampipeconfig.RefreshConnectionResult {
// NOTE: disable any status updates - we do not want 'loading' output from any queries
ctx = statushooks.DisableStatusHooks(ctx)
res := c.refreshConnections(ctx)
if res.Error != nil {
return res
@@ -221,7 +222,7 @@ func (c *LocalDbClient) setUserSearchPath(ctx context.Context) ([]string, error)
// get all roles which are a member of steampipe_users
query := fmt.Sprintf(`select usename from pg_user where pg_has_role(usename, '%s', 'member')`, constants.DatabaseUsersRole)
res, err := c.ExecuteSync(context.Background(), query, true)
res, err := c.ExecuteSync(context.Background(), query)
if err != nil {
return nil, err
}

View File

@@ -12,7 +12,7 @@ func RefreshConnectionAndSearchPaths(ctx context.Context, invoker constants.Invo
if err != nil {
return err
}
defer client.Close()
defer client.Close(ctx)
refreshResult := client.RefreshConnectionAndSearchPaths(ctx)
// display any initialisation warnings
refreshResult.ShowWarnings()

View File

@@ -136,7 +136,7 @@ func startDB(ctx context.Context, port int, listen StartListenType, invoker cons
// if there was an error and we started the service, stop it again
if res.Error != nil {
if res.Status == ServiceStarted {
StopServices(false, invoker, nil)
StopServices(ctx, false, invoker)
}
// remove the state file if we are going back with an error
removeRunningInstanceInfo()
@@ -554,7 +554,7 @@ func killInstanceIfAny(ctx context.Context) bool {
for _, process := range processes {
wg.Add(1)
go func(p *psutils.Process) {
doThreeStepPostgresExit(p)
doThreeStepPostgresExit(ctx, p)
wg.Done()
}(process)
}

View File

@@ -9,14 +9,12 @@ import (
"syscall"
"time"
"github.com/briandowns/spinner"
psutils "github.com/shirou/gopsutil/process"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/constants/runtime"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/filepaths"
"github.com/turbot/steampipe/pluginmanager"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/utils"
)
@@ -32,7 +30,7 @@ const (
)
// ShutdownService stops the database instance if the given 'invoker' matches
func ShutdownService(invoker constants.Invoker) {
func ShutdownService(ctx context.Context, invoker constants.Invoker) {
utils.LogTime("db_local.ShutdownService start")
defer utils.LogTime("db_local.ShutdownService end")
@@ -54,18 +52,18 @@ func ShutdownService(invoker constants.Invoker) {
}
// we can shut down the database
stopStatus, err := StopServices(false, invoker, nil)
stopStatus, err := StopServices(ctx, false, invoker)
if err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
}
if stopStatus == ServiceStopped {
return
}
// shutdown failed - try to force stop
_, err = StopServices(true, invoker, nil)
_, err = StopServices(ctx, true, invoker)
if err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
}
}
@@ -92,7 +90,8 @@ func GetCountOfThirdPartyClients(ctx context.Context) (i int, e error) {
}
// StopServices searches for and stops the running instance. Does nothing if an instance was not found
func StopServices(force bool, invoker constants.Invoker, spinner *spinner.Spinner) (status StopStatus, e error) {
func StopServices(ctx context.Context, force bool, invoker constants.Invoker) (status StopStatus, e error) {
log.Printf("[TRACE] StopDB invoker %s, force %v", invoker, force)
utils.LogTime("db_local.StopDB start")
@@ -108,15 +107,16 @@ func StopServices(force bool, invoker constants.Invoker, spinner *spinner.Spinne
pluginManagerStopError := pluginmanager.Stop()
// stop the DB Service
stopResult, dbStopError := stopDBService(spinner, force)
stopResult, dbStopError := stopDBService(ctx, force)
return stopResult, utils.CombineErrors(dbStopError, pluginManagerStopError)
}
func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) {
func stopDBService(ctx context.Context, force bool) (StopStatus, error) {
if force {
// check if we have a process from another install-dir
display.UpdateSpinnerMessage(spinner, "Checking for running instances...")
statushooks.SetStatus(ctx, "Checking for running instances...")
defer statushooks.Done(ctx)
// do not use a context that can be cancelled
killInstanceIfAny(context.Background())
return ServiceStopped, nil
@@ -139,9 +139,7 @@ func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) {
return ServiceStopFailed, err
}
display.UpdateSpinnerMessage(spinner, "Shutting down...")
err = doThreeStepPostgresExit(process)
err = doThreeStepPostgresExit(ctx, process)
if err != nil {
// we couldn't stop it still.
// timeout
@@ -176,7 +174,7 @@ func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) {
checked that the service can indeed shutdown gracefully,
the sequence is there only as a backup.
**/
func doThreeStepPostgresExit(process *psutils.Process) error {
func doThreeStepPostgresExit(ctx context.Context, process *psutils.Process) error {
utils.LogTime("db_local.doThreeStepPostgresExit start")
defer utils.LogTime("db_local.doThreeStepPostgresExit end")
@@ -191,6 +189,11 @@ func doThreeStepPostgresExit(process *psutils.Process) error {
exitSuccessful = waitForProcessExit(process, 2*time.Second)
if !exitSuccessful {
// process didn't quit
// set status, as this is taking time
statushooks.SetStatus(ctx, "Shutting down...")
defer statushooks.Done(ctx)
// try a SIGINT
err = process.SendSignal(syscall.SIGINT)
if err != nil {

View File

@@ -2,6 +2,7 @@ package display
import (
"bytes"
"context"
"encoding/csv"
"encoding/json"
"fmt"
@@ -20,17 +21,17 @@ import (
)
// ShowOutput :: displays the output using the proper formatter as applicable
func ShowOutput(result *queryresult.Result) {
func ShowOutput(ctx context.Context, result *queryresult.Result) {
output := cmdconfig.Viper().GetString(constants.ArgOutput)
if output == constants.OutputFormatJSON {
displayJSON(result)
displayJSON(ctx, result)
} else if output == constants.OutputFormatCSV {
displayCSV(result)
displayCSV(ctx, result)
} else if output == constants.OutputFormatLine {
displayLine(result)
displayLine(ctx, result)
} else {
// default
displayTable(result)
displayTable(ctx, result)
}
}
@@ -102,7 +103,7 @@ func getColumnSettings(headers []string, rows [][]string) ([]table.ColumnConfig,
return colConfigs, headerRow
}
func displayLine(result *queryresult.Result) {
func displayLine(ctx context.Context, result *queryresult.Result) {
colNames := ColumnNames(result.ColTypes)
maxColNameLength := 0
for _, colName := range colNames {
@@ -158,7 +159,7 @@ func displayLine(result *queryresult.Result) {
// call this function for each row
if err := iterateResults(result, rowFunc); err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
return
}
}
@@ -173,7 +174,7 @@ func getTerminalColumnsRequiredForString(str string) int {
return colsRequired
}
func displayJSON(result *queryresult.Result) {
func displayJSON(ctx context.Context, result *queryresult.Result) {
var jsonOutput []map[string]interface{}
// define function to add each row to the JSON output
@@ -188,7 +189,7 @@ func displayJSON(result *queryresult.Result) {
// call this function for each row
if err := iterateResults(result, rowFunc); err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
return
}
// display the JSON
@@ -202,7 +203,7 @@ func displayJSON(result *queryresult.Result) {
fmt.Println()
}
func displayCSV(result *queryresult.Result) {
func displayCSV(ctx context.Context, result *queryresult.Result) {
csvWriter := csv.NewWriter(os.Stdout)
csvWriter.Comma = []rune(cmdconfig.Viper().GetString(constants.ArgSeparator))[0]
@@ -219,17 +220,17 @@ func displayCSV(result *queryresult.Result) {
// call this function for each row
if err := iterateResults(result, rowFunc); err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
return
}
csvWriter.Flush()
if csvWriter.Error() != nil {
utils.ShowErrorWithMessage(csvWriter.Error(), "unable to print csv")
utils.ShowErrorWithMessage(ctx, csvWriter.Error(), "unable to print csv")
}
}
func displayTable(result *queryresult.Result) {
func displayTable(ctx context.Context, result *queryresult.Result) {
// the buffer to put the output data in
outbuf := bytes.NewBufferString("")
@@ -271,14 +272,14 @@ func displayTable(result *queryresult.Result) {
if err != nil {
// display the error
fmt.Println()
utils.ShowError(err)
utils.ShowError(ctx, err)
fmt.Println()
}
// write out the table to the buffer
t.Render()
// page out the table
ShowPaged(outbuf.String())
ShowPaged(ctx, outbuf.String())
// if timer is turned on
if cmdconfig.Viper().GetBool(constants.ArgTimer) {

View File

@@ -2,6 +2,7 @@ package display
import (
"bufio"
"context"
"fmt"
"os"
"os/exec"
@@ -15,9 +16,9 @@ import (
)
// ShowPaged displays the `content` in a system dependent pager
func ShowPaged(content string) {
func ShowPaged(ctx context.Context, content string) {
if isPagerNeeded(content) && (runtime.GOOS == "darwin" || runtime.GOOS == "linux") {
nixPager(content)
nixPager(ctx, content)
} else {
nullPager(content)
}
@@ -59,11 +60,11 @@ func nullPager(content string) {
fmt.Print(content)
}
func nixPager(content string) {
func nixPager(ctx context.Context, content string) {
if isLessAvailable() {
execPager(exec.Command("less", "-SRXF"), content)
execPager(ctx, exec.Command("less", "-SRXF"), content)
} else if isMoreAvailable() {
execPager(exec.Command("more"), content)
execPager(ctx, exec.Command("more"), content)
} else {
nullPager(content)
}
@@ -79,13 +80,13 @@ func isMoreAvailable() bool {
return err == nil
}
func execPager(cmd *exec.Cmd, content string) {
func execPager(ctx context.Context, cmd *exec.Cmd, content string) {
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
cmd.Stdin = strings.NewReader(content)
// run the command - it will block until the pager is exited
err := cmd.Run()
if err != nil {
utils.ShowErrorWithMessage(err, "could not display results")
utils.ShowErrorWithMessage(ctx, err, "could not display results")
}
}

View File

@@ -1,124 +0,0 @@
package display
import (
"fmt"
"os"
"strings"
"time"
"github.com/briandowns/spinner"
"github.com/karrick/gows"
"github.com/spf13/viper"
"github.com/turbot/steampipe/constants"
)
//
// spinner format:
// <spinner><space><message><space><dot><dot><dot><cursor>
// 1 1 [.......] 1 1 1 1 1
// We need at least seven characters to show the spinner properly
//
// Not using the (…) character, since it is too small
//
const minSpinnerWidth = 7
func truncateSpinnerMessageToScreen(msg string) string {
if len(strings.TrimSpace(msg)) == 0 {
// if this is a blank message, return it as is
return msg
}
maxCols, _, _ := gows.GetWinSize()
// if the screen is smaller than the minimum spinner width, we cannot truncate
if maxCols < minSpinnerWidth {
return msg
}
availableColumns := maxCols - minSpinnerWidth
if len(msg) > availableColumns {
msg = msg[:availableColumns]
msg = fmt.Sprintf("%s ...", msg)
}
return msg
}
// StartSpinnerAfterDelay shows the spinner with the given `msg` if and only if `cancelStartIf` resolves
// after `delay`.
//
// Example: if delay is 2 seconds and `cancelStartIf` resolves after 2.5 seconds, the spinner
// will show for 0.5 seconds. If `cancelStartIf` resolves after 1.5 seconds, the spinner will
// NOT be shown at all
//
func StartSpinnerAfterDelay(msg string, delay time.Duration, cancelStartIf chan bool) *spinner.Spinner {
if !viper.GetBool(constants.ConfigKeyIsTerminalTTY) {
return nil
}
msg = truncateSpinnerMessageToScreen(msg)
spinner := spinner.New(
spinner.CharSets[14],
100*time.Millisecond,
spinner.WithHiddenCursor(true),
spinner.WithWriter(os.Stdout),
spinner.WithSuffix(fmt.Sprintf(" %s", msg)),
)
go func() {
select {
case <-cancelStartIf:
case <-time.After(delay):
if spinner != nil && !spinner.Active() {
spinner.Start()
}
}
time.Sleep(50 * time.Millisecond)
}()
return spinner
}
// ShowSpinner shows a spinner with the given message
func ShowSpinner(msg string) *spinner.Spinner {
if !viper.GetBool(constants.ConfigKeyIsTerminalTTY) {
return nil
}
msg = truncateSpinnerMessageToScreen(msg)
s := spinner.New(
spinner.CharSets[14],
100*time.Millisecond,
spinner.WithHiddenCursor(true),
spinner.WithWriter(os.Stdout),
spinner.WithSuffix(fmt.Sprintf(" %s", msg)),
)
s.Start()
return s
}
// StopSpinnerWithMessage stops a spinner instance and clears it, after writing `finalMsg`
func StopSpinnerWithMessage(spinner *spinner.Spinner, finalMsg string) {
if spinner != nil {
spinner.FinalMSG = finalMsg
spinner.Stop()
}
}
// StopSpinner stops a spinner instance and clears it
func StopSpinner(spinner *spinner.Spinner) {
if spinner != nil {
spinner.Stop()
}
}
func ResumeSpinner(spinner *spinner.Spinner) {
if spinner != nil && !spinner.Active() {
spinner.Restart()
}
}
// UpdateSpinnerMessage updates the message of the given spinner
func UpdateSpinnerMessage(spinner *spinner.Spinner, newMessage string) {
if spinner != nil {
newMessage = truncateSpinnerMessageToScreen(newMessage)
spinner.Suffix = fmt.Sprintf(" %s", newMessage)
}
}

View File

@@ -3,6 +3,8 @@ package executionlayer
import (
"context"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/report/reportevents"
"github.com/turbot/steampipe/report/reportexecute"
@@ -11,6 +13,9 @@ import (
)
func ExecuteReportNode(ctx context.Context, reportName string, workspace *workspace.Workspace, client db_common.Client) error {
// create context for the report execution
// (for now just disable all status messages - replace with event based? )
reportCtx := statushooks.DisableStatusHooks(ctx)
executionTree, err := reportexecute.NewReportExecutionTree(reportName, client, workspace)
if err != nil {
return err
@@ -19,7 +24,7 @@ func ExecuteReportNode(ctx context.Context, reportName string, workspace *worksp
go func() {
workspace.PublishReportEvent(&reportevents.ExecutionStarted{ReportNode: executionTree.Root})
if err := executionTree.Execute(ctx); err != nil {
if err := executionTree.Execute(reportCtx); err != nil {
if executionTree.Root.GetRunStatus() == reportinterfaces.ReportRunError {
// set error state on the root node
executionTree.Root.SetError(err)

View File

@@ -17,13 +17,14 @@ import (
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/cmdconfig"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/contexthelpers"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/query"
"github.com/turbot/steampipe/query/metaquery"
"github.com/turbot/steampipe/query/queryhistory"
"github.com/turbot/steampipe/query/queryresult"
"github.com/turbot/steampipe/schema"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/steampipeconfig"
"github.com/turbot/steampipe/utils"
"github.com/turbot/steampipe/version"
@@ -54,7 +55,11 @@ type InteractiveClient struct {
// lock while execution is occurring to avoid errors/warnings being shown
executionLock sync.Mutex
schemaMetadata *schema.Metadata
highlighter *Highlighter
highlighter *Highlighter
// status update hooks
statusHook statushooks.StatusHooks
}
func getHighlighter(theme string) *Highlighter {
@@ -65,7 +70,7 @@ func getHighlighter(theme string) *Highlighter {
)
}
func newInteractiveClient(initData *query.InitData, resultsStreamer *queryresult.ResultStreamer) (*InteractiveClient, error) {
func newInteractiveClient(ctx context.Context, initData *query.InitData, resultsStreamer *queryresult.ResultStreamer) (*InteractiveClient, error) {
c := &InteractiveClient{
initData: initData,
resultsStreamer: resultsStreamer,
@@ -75,30 +80,33 @@ func newInteractiveClient(initData *query.InitData, resultsStreamer *queryresult
initResultChan: make(chan *db_common.InitResult, 1),
highlighter: getHighlighter(viper.GetString(constants.ArgTheme)),
}
// asynchronously wait for init to complete
// we start this immediately rather than lazy loading as we want to handle errors asap
go c.readInitDataStream()
go c.readInitDataStream(ctx)
return c, nil
}
// InteractivePrompt starts an interactive prompt and return
func (c *InteractiveClient) InteractivePrompt() {
func (c *InteractiveClient) InteractivePrompt(ctx context.Context) {
// start a cancel handler for the interactive client - this will call activeQueryCancelFunc if it is set
// (registered when we call createQueryContext)
interruptSignalChannel := c.startCancelHandler()
interruptSignalChannel := contexthelpers.StartCancelHandler(c.cancelActiveQueryIfAny)
// create a cancel context for the prompt - this will set c.cancelPrompt
promptCtx := c.createPromptContext()
parentContext := ctx
ctx = c.createPromptContext(parentContext)
defer func() {
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
// close up the SIGINT channel so that the receiver goroutine can quit
signal.Stop(interruptSignalChannel)
close(interruptSignalChannel)
// cleanup the init data to ensure any services we started are stopped
c.initData.Cleanup()
c.initData.Cleanup(ctx)
// close the result stream
// this needs to be the last thing we do,
@@ -106,18 +114,19 @@ func (c *InteractiveClient) InteractivePrompt() {
c.resultsStreamer.Close()
}()
fmt.Printf("Welcome to Steampipe v%s\n", version.SteampipeVersion.String())
fmt.Printf("For more information, type %s\n", constants.Bold(".help"))
statushooks.Message(ctx,
fmt.Sprintf("Welcome to Steampipe v%s", version.SteampipeVersion.String()),
fmt.Sprintf("For more information, type %s", constants.Bold(".help")))
// run the prompt in a goroutine, so we can also detect async initialisation errors
promptResultChan := make(chan utils.InteractiveExitStatus, 1)
c.runInteractivePromptAsync(promptCtx, &promptResultChan)
c.runInteractivePromptAsync(ctx, &promptResultChan)
// select results
for {
select {
case initResult := <-c.initResultChan:
c.handleInitResult(promptCtx, initResult)
c.handleInitResult(ctx, initResult)
// if there was an error, handleInitResult will shut down the prompt
// - we must wait for it to shut down and not return immediately
@@ -129,9 +138,9 @@ func (c *InteractiveClient) InteractivePrompt() {
return
}
// create new context
promptCtx = c.createPromptContext()
ctx = c.createPromptContext(parentContext)
// now run it again
c.runInteractivePromptAsync(promptCtx, &promptResultChan)
c.runInteractivePromptAsync(ctx, &promptResultChan)
}
}
}
@@ -184,7 +193,7 @@ func (c *InteractiveClient) handleInitResult(ctx context.Context, initResult *db
c.ClosePrompt(AfterPromptCloseExit)
// add newline to ensure error is not printed at end of current prompt line
fmt.Println()
utils.ShowError(initResult.Error)
utils.ShowError(ctx, initResult.Error)
return
}
@@ -231,7 +240,7 @@ func (c *InteractiveClient) runInteractivePrompt(ctx context.Context) (ret utils
}()
callExecutor := func(line string) {
c.executor(line)
c.executor(ctx, line)
}
completer := func(d prompt.Document) []prompt.Suggest {
return c.queryCompleter(d)
@@ -333,7 +342,7 @@ func (c *InteractiveClient) breakMultilinePrompt(buffer *prompt.Buffer) {
c.interactiveBuffer = []string{}
}
func (c *InteractiveClient) executor(line string) {
func (c *InteractiveClient) executor(ctx context.Context, line string) {
// take an execution lock, so that errors and warnings don't show up while
// we are underway
c.executionLock.Lock()
@@ -347,10 +356,10 @@ func (c *InteractiveClient) executor(line string) {
// we want to store even if we fail to resolve a query
c.interactiveQueryHistory.Push(line)
query, err := c.getQuery(line)
query, err := c.getQuery(ctx, line)
if query == "" {
if err != nil {
utils.ShowError(utils.HandleCancelError(err))
utils.ShowError(ctx, utils.HandleCancelError(err))
}
// restart the prompt
c.restartInteractiveSession()
@@ -358,20 +367,20 @@ func (c *InteractiveClient) executor(line string) {
}
// create a context for the execution of the query
queryContext := c.createQueryContext()
queryContext := c.createQueryContext(ctx)
if metaquery.IsMetaQuery(query) {
if err := c.executeMetaquery(queryContext, query); err != nil {
utils.ShowError(err)
utils.ShowError(ctx, err)
}
// cancel the context
c.cancelActiveQueryIfAny()
} else {
// otherwise execute query
result, err := c.client().Execute(queryContext, query, false)
result, err := c.client().Execute(queryContext, query)
if err != nil {
utils.ShowError(utils.HandleCancelError(err))
utils.ShowError(ctx, utils.HandleCancelError(err))
} else {
c.resultsStreamer.StreamResult(result)
}
@@ -381,7 +390,7 @@ func (c *InteractiveClient) executor(line string) {
c.restartInteractiveSession()
}
func (c *InteractiveClient) getQuery(line string) (string, error) {
func (c *InteractiveClient) getQuery(ctx context.Context, line string) (string, error) {
// if it's an empty line, then we don't need to do anything
if line == "" {
return "", nil
@@ -391,23 +400,20 @@ func (c *InteractiveClient) getQuery(line string) (string, error) {
if !c.isInitialised() {
// create a context used purely to detect cancellation during initialisation
// this will also set c.cancelActiveQuery
queryContext := c.createQueryContext()
queryContext := c.createQueryContext(ctx)
defer func() {
// cancel this context
c.cancelActiveQueryIfAny()
}()
initDoneChan := make(chan bool)
sp := display.StartSpinnerAfterDelay("Initializing...", constants.SpinnerShowTimeout, initDoneChan)
statushooks.SetStatus(ctx, "Initializing...")
// wait for client initialisation to complete
if err := c.waitForInitData(queryContext); err != nil {
err := c.waitForInitData(queryContext)
statushooks.Done(ctx)
if err != nil {
// if it failed, report error and quit
close(initDoneChan)
display.StopSpinner(sp)
return "", err
}
close(initDoneChan)
display.StopSpinner(sp)
}
// push the current line into the buffer
@@ -420,7 +426,7 @@ func (c *InteractiveClient) getQuery(line string) (string, error) {
query, _, err := c.workspace().ResolveQueryAndArgs(queryString)
if err != nil {
// if we fail to resolve, show error but do not return it - we want to stay in the prompt
utils.ShowError(err)
utils.ShowError(ctx, err)
return "", nil
}
isNamedQuery := query != queryString

View File

@@ -2,31 +2,21 @@ package interactive
import (
"context"
"os"
"os/signal"
"syscall"
)
func (c *InteractiveClient) startCancelHandler() chan os.Signal {
interruptSignalChannel := make(chan os.Signal, 10)
signal.Notify(interruptSignalChannel, syscall.SIGINT, syscall.SIGTERM)
go func() {
for range interruptSignalChannel {
c.cancelActiveQueryIfAny()
}
}()
return interruptSignalChannel
}
// create a cancel context for the interactive prompt, and set c.cancelFunc
func (c *InteractiveClient) createPromptContext() context.Context {
ctx, cancel := context.WithCancel(context.Background())
func (c *InteractiveClient) createPromptContext(parentContext context.Context) context.Context {
// ensure previous prompt is cleaned up
if c.cancelPrompt != nil {
c.cancelPrompt()
}
ctx, cancel := context.WithCancel(parentContext)
c.cancelPrompt = cancel
return ctx
}
func (c *InteractiveClient) createQueryContext() context.Context {
ctx, cancel := context.WithCancel(context.Background())
func (c *InteractiveClient) createQueryContext(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
c.cancelActiveQuery = cancel
return ctx
}

View File

@@ -16,11 +16,11 @@ import (
var initTimeout = 40 * time.Second
func (c *InteractiveClient) readInitDataStream() {
func (c *InteractiveClient) readInitDataStream(ctx context.Context) {
defer func() {
if r := recover(); r != nil {
c.interactivePrompt.ClearScreen()
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
}()
@@ -40,15 +40,15 @@ func (c *InteractiveClient) readInitDataStream() {
// start the workspace file watcher
if viper.GetBool(constants.ArgWatch) {
// provide an explicit error handler which re-renders the prompt after displaying the error
c.initData.Result.Error = c.initData.Workspace.SetupWatcher(c.initData.Client, c.workspaceWatcherErrorHandler)
c.initData.Result.Error = c.initData.Workspace.SetupWatcher(ctx, c.initData.Client, c.workspaceWatcherErrorHandler)
}
c.initResultChan <- c.initData.Result
}
func (c *InteractiveClient) workspaceWatcherErrorHandler(err error) {
func (c *InteractiveClient) workspaceWatcherErrorHandler(ctx context.Context, err error) {
fmt.Println()
utils.ShowError(err)
utils.ShowError(ctx, err)
c.interactivePrompt.Render()
}

View File

@@ -1,6 +1,8 @@
package interactive
import (
"context"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/db/db_local"
"github.com/turbot/steampipe/query"
@@ -9,19 +11,19 @@ import (
)
// RunInteractivePrompt starts the interactive query prompt
func RunInteractivePrompt(initData *query.InitData) (*queryresult.ResultStreamer, error) {
func RunInteractivePrompt(ctx context.Context, initData *query.InitData) (*queryresult.ResultStreamer, error) {
resultsStreamer := queryresult.NewResultStreamer()
interactiveClient, err := newInteractiveClient(initData, resultsStreamer)
interactiveClient, err := newInteractiveClient(ctx, initData, resultsStreamer)
if err != nil {
utils.ShowErrorWithMessage(err, "interactive client failed to initialize")
utils.ShowErrorWithMessage(ctx, err, "interactive client failed to initialize")
// do not bind shutdown to any cancellable context
db_local.ShutdownService(constants.InvokerQuery)
db_local.ShutdownService(ctx, constants.InvokerQuery)
return nil, err
}
// start the interactive prompt in a go routine
go interactiveClient.InteractivePrompt()
go interactiveClient.InteractivePrompt(ctx)
return resultsStreamer, nil
}

12
main.go
View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"fmt"
"os"
@@ -17,11 +18,12 @@ var Logger hclog.Logger
var exitCode int
func main() {
ctx := context.Background()
utils.LogTime("main start")
exitCode := 0
defer func() {
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
utils.LogTime("main end")
utils.DisplayProfileData()
@@ -29,7 +31,7 @@ func main() {
}()
// ensure steampipe is not being run as root
checkRoot()
checkRoot(ctx)
// increase the soft ULIMIT to match the hard limit
err := setULimit()
@@ -60,10 +62,10 @@ func setULimit() error {
// this is to replicate the user security mechanism of out underlying
// postgresql engine.
func checkRoot() {
func checkRoot(ctx context.Context) {
if os.Geteuid() == 0 {
exitCode = 1
utils.ShowError(fmt.Errorf(`Steampipe cannot be run as the "root" user.
utils.ShowError(ctx, fmt.Errorf(`Steampipe cannot be run as the "root" user.
To reduce security risk, use an unprivileged user account instead.`))
os.Exit(exitCode)
}
@@ -79,7 +81,7 @@ To reduce security risk, use an unprivileged user account instead.`))
if os.Geteuid() != os.Getuid() {
exitCode = 1
utils.ShowError(fmt.Errorf("real and effective user IDs must match."))
utils.ShowError(ctx, fmt.Errorf("real and effective user IDs must match."))
os.Exit(exitCode)
}
}

View File

@@ -1,16 +1,18 @@
package modinstaller
import (
"context"
"github.com/turbot/go-kit/helpers"
"github.com/turbot/steampipe/utils"
)
func UninstallWorkspaceDependencies(opts *InstallOpts) (*InstallData, error) {
func UninstallWorkspaceDependencies(ctx context.Context, opts *InstallOpts) (*InstallData, error) {
utils.LogTime("cmd.UninstallWorkspaceDependencies")
defer func() {
utils.LogTime("cmd.UninstallWorkspaceDependencies end")
if r := recover(); r != nil {
utils.ShowError(helpers.ToError(r))
utils.ShowError(ctx, helpers.ToError(r))
}
}()

View File

@@ -7,10 +7,10 @@ import (
"path/filepath"
"strings"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/filepaths"
"github.com/turbot/steampipe/ociinstaller"
"github.com/turbot/steampipe/ociinstaller/versionfile"
"github.com/turbot/steampipe/statushooks"
"github.com/turbot/steampipe/steampipeconfig/modconfig"
"github.com/turbot/steampipe/utils"
)
@@ -22,9 +22,9 @@ const (
)
// Remove removes an installed plugin
func Remove(image string, pluginConnections map[string][]modconfig.Connection) error {
spinner := display.ShowSpinner(fmt.Sprintf("Removing plugin %s", image))
defer display.StopSpinner(spinner)
func Remove(ctx context.Context, image string, pluginConnections map[string][]modconfig.Connection) error {
statushooks.SetStatus(ctx, fmt.Sprintf("Removing plugin %s", image))
defer statushooks.Done(ctx)
fullPluginName := ociinstaller.NewSteampipeImageRef(image).DisplayImageRef()
@@ -60,7 +60,7 @@ func Remove(image string, pluginConnections map[string][]modconfig.Connection) e
connFiles := Unique(files)
if len(connFiles) > 0 {
display.StopSpinner(spinner)
str := []string{fmt.Sprintf("\nUninstalled plugin %s\n\nNote: the following %s %s %s steampipe %s using the '%s' plugin:", image, utils.Pluralize("file", len(connFiles)), utils.Pluralize("has", len(connFiles)), utils.Pluralize("a", len(conns)), utils.Pluralize("connection", len(conns)), image)}
for _, file := range connFiles {
str = append(str, fmt.Sprintf("\n \t* file: %s", file))
@@ -78,7 +78,7 @@ func Remove(image string, pluginConnections map[string][]modconfig.Connection) e
}
}
str = append(str, fmt.Sprintf("\nPlease remove %s to continue using steampipe", utils.Pluralize("it", len(connFiles))))
fmt.Println(strings.Join(str, "\n"))
statushooks.Message(ctx, str...)
fmt.Println()
}
return err

View File

@@ -37,7 +37,7 @@ func NewInitData(ctx context.Context, w *workspace.Workspace, args []string) *In
return i
}
func (i *InitData) Cleanup() {
func (i *InitData) Cleanup(ctx context.Context) {
// cancel any ongoing operation
if i.cancel != nil {
i.cancel()
@@ -50,7 +50,7 @@ func (i *InitData) Cleanup() {
// if a client was initialised, close it
if i.Client != nil {
i.Client.Close()
i.Client.Close(ctx)
}
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/spf13/viper"
"github.com/turbot/steampipe/constants"
"github.com/turbot/steampipe/contexthelpers"
"github.com/turbot/steampipe/db/db_common"
"github.com/turbot/steampipe/display"
"github.com/turbot/steampipe/interactive"
@@ -13,34 +14,36 @@ import (
"github.com/turbot/steampipe/utils"
)
func RunInteractiveSession(initData *query.InitData) {
func RunInteractiveSession(ctx context.Context, initData *query.InitData) {
utils.LogTime("execute.RunInteractiveSession start")
defer utils.LogTime("execute.RunInteractiveSession end")
// the db executor sends result data over resultsStreamer
resultsStreamer, err := interactive.RunInteractivePrompt(initData)
resultsStreamer, err := interactive.RunInteractivePrompt(ctx, initData)
utils.FailOnError(err)
// print the data as it comes
for r := range resultsStreamer.Results {
display.ShowOutput(r)
display.ShowOutput(ctx, r)
// signal to the resultStreamer that we are done with this chunk of the stream
resultsStreamer.AllResultsRead()
}
}
func RunBatchSession(ctx context.Context, initData *query.InitData) int {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
// start cancel handler to intercept interurpts and cancel the context
contexthelpers.StartCancelHandler(cancel)
// wait for init
<-initData.Loaded
if err := initData.Result.Error; err != nil {
utils.FailOnError(err)
}
// ensure we close client
defer func() {
if initData.Client != nil {
initData.Client.Close()
}
}()
defer initData.Cleanup(ctx)
// display any initialisation messages/warnings
initData.Result.DisplayMessages()
@@ -86,7 +89,7 @@ func executeQuery(ctx context.Context, queryString string, client db_common.Clie
// print the data as it comes
for r := range resultsStreamer.Results {
display.ShowOutput(r)
display.ShowOutput(ctx, r)
// signal to the resultStreamer that we are done with this result
resultsStreamer.AllResultsRead()
}

View File

@@ -166,7 +166,7 @@ func (e *ReportExecutionTree) ExecuteNode(ctx context.Context, name string) erro
}
func (e *ReportExecutionTree) executePanelSQL(ctx context.Context, query string) ([][]interface{}, error) {
queryResult, err := e.client.ExecuteSync(ctx, query, true)
queryResult, err := e.client.ExecuteSync(ctx, query)
if err != nil {
return nil, err
}

View File

@@ -46,7 +46,7 @@ type ReportClientInfo struct {
}
func NewServer(ctx context.Context) (*Server, error) {
dbClient, err := db_local.GetLocalClient(ctx, constants.InvokerReport)
var dbClient, err = db_local.GetLocalClient(ctx, constants.InvokerReport)
if err != nil {
return nil, err
}
@@ -57,7 +57,7 @@ func NewServer(ctx context.Context) (*Server, error) {
}
refreshResult.ShowWarnings()
loadedWorkspace, err := workspace.Load(viper.GetString(constants.ArgWorkspaceChDir))
loadedWorkspace, err := workspace.Load(ctx, viper.GetString(constants.ArgWorkspaceChDir))
if err != nil {
return nil, err
}
@@ -78,7 +78,7 @@ func NewServer(ctx context.Context) (*Server, error) {
}
loadedWorkspace.RegisterReportEventHandler(server.HandleWorkspaceUpdate)
err = loadedWorkspace.SetupWatcher(dbClient, nil)
err = loadedWorkspace.SetupWatcher(ctx, dbClient, nil)
return server, err
}
@@ -129,10 +129,10 @@ func (s *Server) Start() {
StartAPI(s.context, s.webSocket)
}
func (s *Server) Shutdown() {
func (s *Server) Shutdown(ctx context.Context) {
// Close the DB client
if s.dbClient != nil {
s.dbClient.Close()
s.dbClient.Close(ctx)
}
if s.webSocket != nil {

42
statushooks/context.go Normal file
View File

@@ -0,0 +1,42 @@
package statushooks
import (
"context"
"github.com/turbot/steampipe/contexthelpers"
)
var (
contextKeyStatusHook = contexthelpers.ContextKey("status_hook")
)
func DisableStatusHooks(ctx context.Context) context.Context {
return AddStatusHooksToContext(ctx, NullHooks)
}
func AddStatusHooksToContext(ctx context.Context, statusHooks StatusHooks) context.Context {
return context.WithValue(ctx, contextKeyStatusHook, statusHooks)
}
func StatusHooksFromContext(ctx context.Context) StatusHooks {
if ctx == nil {
return NullHooks
}
if val, ok := ctx.Value(contextKeyStatusHook).(StatusHooks); ok {
return val
}
// no status hook in context - return null status hook
return NullHooks
}
func SetStatus(ctx context.Context, msg string) {
StatusHooksFromContext(ctx).SetStatus(msg)
}
func Done(ctx context.Context) {
StatusHooksFromContext(ctx).Done()
}
func Message(ctx context.Context, msgs ...string) {
StatusHooksFromContext(ctx).Message(msgs...)
}

View File

@@ -0,0 +1,15 @@
package statushooks
type StatusHooks interface {
SetStatus(string)
Done()
Message(...string)
}
var NullHooks = &NullStatusHook{}
type NullStatusHook struct{}
func (*NullStatusHook) SetStatus(string) {}
func (*NullStatusHook) Done() {}
func (*NullStatusHook) Message(...string) {}

137
statusspinner/spinner.go Normal file
View File

@@ -0,0 +1,137 @@
package statusspinner
import (
"fmt"
"os"
"strings"
"time"
"github.com/briandowns/spinner"
"github.com/karrick/gows"
)
//
// spinner format:
// <spinner><space><message><space><dot><dot><dot><cursor>
// 1 1 [.......] 1 1 1 1 1
// We need at least seven characters to show the spinner properly
//
// Not using the (…) character, since it is too small
//
const minSpinnerWidth = 7
// StatusSpinner is a struct which implements StatusHooks, and uses a spinner to display status messages
type StatusSpinner struct {
spinner *spinner.Spinner
delay time.Duration
cancel chan struct{}
}
type StatusSpinnerOpt func(*StatusSpinner)
func WithMessage(msg string) StatusSpinnerOpt {
return func(s *StatusSpinner) {
s.UpdateSpinnerMessage(msg)
}
}
func WithDelay(delay time.Duration) StatusSpinnerOpt {
return func(s *StatusSpinner) {
s.delay = delay
}
}
func NewStatusSpinner(opts ...StatusSpinnerOpt) *StatusSpinner {
res := &StatusSpinner{}
res.spinner = spinner.New(
spinner.CharSets[14],
100*time.Millisecond,
spinner.WithHiddenCursor(true),
spinner.WithWriter(os.Stdout),
)
for _, opt := range opts {
opt(res)
}
return res
}
// SetStatus implements StatusHooks
func (s *StatusSpinner) SetStatus(msg string) {
s.UpdateSpinnerMessage(msg)
if !s.spinner.Active() {
s.startSpinner()
}
}
func (s *StatusSpinner) startSpinner() {
if s.cancel != nil {
// if there is a cancel channel, we are already waiting for the service to start after a delay
return
}
if s.delay == 0 {
s.spinner.Start()
return
}
s.cancel = make(chan struct{}, 1)
go func() {
select {
case <-s.cancel:
case <-time.After(s.delay):
s.spinner.Start()
s.cancel = nil
}
time.Sleep(50 * time.Millisecond)
}()
}
func (s *StatusSpinner) Message(msgs ...string) {
if s.spinner.Active() {
s.spinner.Stop()
defer s.spinner.Start()
}
for _, msg := range msgs {
fmt.Println(msg)
}
}
// Done implements StatusHooks
func (s *StatusSpinner) Done() {
if s.cancel != nil {
close(s.cancel)
}
s.closeSpinner()
}
// UpdateSpinnerMessage updates the message of the given spinner
func (s *StatusSpinner) UpdateSpinnerMessage(newMessage string) {
newMessage = s.truncateSpinnerMessageToScreen(newMessage)
s.spinner.Suffix = fmt.Sprintf(" %s", newMessage)
}
func (s *StatusSpinner) closeSpinner() {
if s.spinner != nil {
s.spinner.Stop()
}
}
func (s *StatusSpinner) truncateSpinnerMessageToScreen(msg string) string {
if len(strings.TrimSpace(msg)) == 0 {
// if this is a blank message, return it as is
return msg
}
maxCols, _, _ := gows.GetWinSize()
// if the screen is smaller than the minimum spinner width, we cannot truncate
if maxCols < minSpinnerWidth {
return msg
}
availableColumns := maxCols - minSpinnerWidth
if len(msg) > availableColumns {
msg = msg[:availableColumns]
msg = fmt.Sprintf("%s ...", msg)
}
return msg
}

View File

@@ -9,11 +9,12 @@ import (
"github.com/fatih/color"
"github.com/shiena/ansicolor"
"github.com/turbot/steampipe/statushooks"
)
var (
colorErr = color.RedString("Error")
colorWarn = color.YellowString("Warning")
colorErr = color.RedString("Error")
colorWarn = color.YellowString("Warning")
)
func init() {
@@ -34,20 +35,22 @@ func FailOnErrorWithMessage(err error, message string) {
}
}
func ShowError(err error) {
func ShowError(ctx context.Context, err error) {
if err == nil {
return
}
err = HandleCancelError(err)
statushooks.Done(ctx)
fmt.Fprintf(color.Output, "%s: %v\n", colorErr, TransformErrorToSteampipe(err))
}
// ShowErrorWithMessage displays the given error nicely with the given message
func ShowErrorWithMessage(err error, message string) {
func ShowErrorWithMessage(ctx context.Context, err error, message string) {
if err == nil {
return
}
err = HandleCancelError(err)
statushooks.Done(ctx)
fmt.Fprintf(color.Output, "%s: %s - %v\n", colorErr, message, TransformErrorToSteampipe(err))
}

View File

@@ -2,6 +2,7 @@ package workspace
import (
"bufio"
"context"
"fmt"
"log"
"os"
@@ -42,7 +43,7 @@ type Workspace struct {
exclusions []string
// should we load/watch files recursively
listFlag filehelpers.ListFlag
fileWatcherErrorHandler func(error)
fileWatcherErrorHandler func(context.Context, error)
watcherError error
// event handlers
reportEventHandlers []reportevents.ReportEventHandler
@@ -53,7 +54,7 @@ type Workspace struct {
}
// Load creates a Workspace and loads the workspace mod
func Load(workspacePath string) (*Workspace, error) {
func Load(ctx context.Context, workspacePath string) (*Workspace, error) {
utils.LogTime("workspace.Load start")
defer utils.LogTime("workspace.Load end")
@@ -72,7 +73,7 @@ func Load(workspacePath string) (*Workspace, error) {
}
// load the workspace mod
if err := workspace.loadWorkspaceMod(); err != nil {
if err := workspace.loadWorkspaceMod(ctx); err != nil {
return nil, err
}
@@ -101,7 +102,7 @@ func LoadResourceNames(workspacePath string) (*modconfig.WorkspaceResources, err
return workspace.loadWorkspaceResourceName()
}
func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(error)) error {
func (w *Workspace) SetupWatcher(ctx context.Context, client db_common.Client, errorHandler func(context.Context, error)) error {
watcherOptions := &utils.WatcherOptions{
Directories: []string{w.Path},
Include: filehelpers.InclusionsFromExtensions(steampipeconfig.GetModFileExtensions()),
@@ -112,7 +113,7 @@ func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(erro
// decide how to handle them
// OnError: errCallback,
OnChange: func(events []fsnotify.Event) {
w.handleFileWatcherEvent(client, events)
w.handleFileWatcherEvent(ctx, client, events)
},
}
watcher, err := utils.NewWatcher(watcherOptions)
@@ -127,9 +128,9 @@ func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(erro
// after a file watcher event
w.fileWatcherErrorHandler = errorHandler
if w.fileWatcherErrorHandler == nil {
w.fileWatcherErrorHandler = func(err error) {
w.fileWatcherErrorHandler = func(ctx context.Context, err error) {
fmt.Println()
utils.ShowErrorWithMessage(err, "Failed to reload mod from file watcher")
utils.ShowErrorWithMessage(ctx, err, "Failed to reload mod from file watcher")
}
}
@@ -249,11 +250,11 @@ func (w *Workspace) setModfileExists() {
}
}
func (w *Workspace) loadWorkspaceMod() error {
func (w *Workspace) loadWorkspaceMod(ctx context.Context) error {
// clear all resource maps
w.reset()
// load and evaluate all variables
inputVariables, err := w.getAllVariables()
inputVariables, err := w.getAllVariables(ctx)
if err != nil {
return err
}

View File

@@ -22,7 +22,7 @@ func (w *Workspace) RegisterReportEventHandler(handler reportevents.ReportEventH
w.reportEventHandlers = append(w.reportEventHandlers, handler)
}
func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsnotify.Event) {
func (w *Workspace) handleFileWatcherEvent(ctx context.Context, client db_common.Client, events []fsnotify.Event) {
w.loadLock.Lock()
defer w.loadLock.Unlock()
@@ -32,11 +32,11 @@ func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsn
prevResourceMaps := w.GetResourceMaps()
// now reload the workspace
err := w.loadWorkspaceMod()
err := w.loadWorkspaceMod(ctx)
if err != nil {
// check the existing watcher error - if we are already in an error state, do not show error
if w.watcherError == nil {
w.fileWatcherErrorHandler(utils.PrefixError(err, "Failed to reload workspace"))
w.fileWatcherErrorHandler(ctx, utils.PrefixError(err, "Failed to reload workspace"))
}
// now set watcher error to new error
w.watcherError = err
@@ -53,7 +53,7 @@ func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsn
res := client.RefreshSessions(context.Background())
if res.Error != nil || len(res.Warnings) > 0 {
fmt.Println()
utils.ShowErrorWithMessage(res.Error, "error when refreshing session data")
utils.ShowErrorWithMessage(ctx, res.Error, "error when refreshing session data")
utils.ShowWarning(strings.Join(res.Warnings, "\n"))
if w.onFileWatcherEventMessages != nil {
w.onFileWatcherEventMessages()

View File

@@ -1,6 +1,7 @@
package workspace
import (
"context"
"fmt"
"path/filepath"
"strings"
@@ -135,7 +136,7 @@ var testCasesLoadWorkspace = map[string]loadWorkspaceTest{
func TestLoadWorkspace(t *testing.T) {
for name, test := range testCasesLoadWorkspace {
workspacePath, err := filepath.Abs(test.source)
workspace, err := Load(workspacePath)
workspace, err := Load(context.Background(), workspacePath)
if err != nil {
if test.expected != "ERROR" {

View File

@@ -1,6 +1,7 @@
package workspace
import (
"context"
"fmt"
"sort"
"strings"
@@ -14,7 +15,7 @@ import (
"github.com/turbot/steampipe/utils"
)
func (w *Workspace) getAllVariables() (map[string]*modconfig.Variable, error) {
func (w *Workspace) getAllVariables(ctx context.Context) (map[string]*modconfig.Variable, error) {
// build options used to load workspace
runCtx, err := w.getRunContext()
if err != nil {
@@ -40,7 +41,7 @@ func (w *Workspace) getAllVariables() (map[string]*modconfig.Variable, error) {
return nil, err
}
if err := validateVariables(variableMap, inputVariables); err != nil {
if err := validateVariables(ctx, variableMap, inputVariables); err != nil {
return nil, err
}
@@ -73,21 +74,21 @@ func (w *Workspace) getInputVariables(variableMap map[string]*modconfig.Variable
return parsedValues, diags.Err()
}
func validateVariables(variableMap map[string]*modconfig.Variable, variables inputvars.InputValues) error {
func validateVariables(ctx context.Context, variableMap map[string]*modconfig.Variable, variables inputvars.InputValues) error {
diags := inputvars.CheckInputVariables(variableMap, variables)
if diags.HasErrors() {
displayValidationErrors(diags)
displayValidationErrors(ctx, diags)
// return empty error
return modconfig.VariableValidationFailedError{}
}
return nil
}
func displayValidationErrors(diags tfdiags.Diagnostics) {
func displayValidationErrors(ctx context.Context, diags tfdiags.Diagnostics) {
fmt.Println()
for i, diag := range diags {
utils.ShowError(fmt.Errorf("%s", constants.Bold(diag.Description().Summary)))
utils.ShowError(ctx, fmt.Errorf("%s", constants.Bold(diag.Description().Summary)))
fmt.Println(diag.Description().Detail)
if i < len(diags)-1 {
fmt.Println()