diff --git a/cmd/check.go b/cmd/check.go index b0f68c9d0..4e28304ea 100644 --- a/cmd/check.go +++ b/cmd/check.go @@ -10,12 +10,12 @@ import ( "sync" "time" - "github.com/briandowns/spinner" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe/cmdconfig" "github.com/turbot/steampipe/constants" + "github.com/turbot/steampipe/contexthelpers" "github.com/turbot/steampipe/control" "github.com/turbot/steampipe/control/controldisplay" "github.com/turbot/steampipe/control/controlexecute" @@ -24,6 +24,7 @@ import ( "github.com/turbot/steampipe/db/db_local" "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/modinstaller" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/utils" "github.com/turbot/steampipe/workspace" ) @@ -88,16 +89,21 @@ You may specify one or more benchmarks or controls to run (separated by a space) func runCheckCmd(cmd *cobra.Command, args []string) { utils.LogTime("runCheckCmd start") initData := &control.InitData{} + + // setup a cancel context and start cancel handler + ctx, cancel := context.WithCancel(cmd.Context()) + contexthelpers.StartCancelHandler(cancel) + defer func() { utils.LogTime("runCheckCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } if initData.Client != nil { log.Printf("[TRACE] close client") - initData.Client.Close() + initData.Client.Close(ctx) } if initData.Workspace != nil { initData.Workspace.Close() @@ -105,24 +111,22 @@ func runCheckCmd(cmd *cobra.Command, args []string) { }() // verify we have an argument - if !validateArgs(cmd, args) { + if !validateArgs(ctx, cmd, args) { return } - var spinner *spinner.Spinner - if viper.GetBool(constants.ArgProgress) { - spinner = display.ShowSpinner("Initializing...") + // if progress is disabled, update context to contain a null status hooks object + if !viper.GetBool(constants.ArgProgress) { + statushooks.DisableStatusHooks(ctx) } // initialise - initData = initialiseCheck(cmd.Context(), spinner) - display.StopSpinner(spinner) - if shouldExit := handleCheckInitResult(initData); shouldExit { + initData = initialiseCheck(ctx) + if shouldExit := handleCheckInitResult(ctx, initData); shouldExit { return } // pull out useful properties - ctx := initData.Ctx workspace := initData.Workspace client := initData.Client failures := 0 @@ -165,7 +169,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) { exportWaitGroup.Wait() if len(exportErrors) > 0 { - utils.ShowError(utils.CombineErrors(exportErrors...)) + utils.ShowError(ctx, utils.CombineErrors(exportErrors...)) } if shouldPrintTiming() { @@ -176,10 +180,10 @@ func runCheckCmd(cmd *cobra.Command, args []string) { exitCode = failures } -func validateArgs(cmd *cobra.Command, args []string) bool { +func validateArgs(ctx context.Context, cmd *cobra.Command, args []string) bool { if len(args) == 0 { fmt.Println() - utils.ShowError(fmt.Errorf("you must provide at least one argument")) + utils.ShowError(ctx, fmt.Errorf("you must provide at least one argument")) fmt.Println() cmd.Help() fmt.Println() @@ -189,11 +193,13 @@ func validateArgs(cmd *cobra.Command, args []string) bool { return true } -func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.InitData { +func initialiseCheck(ctx context.Context) *control.InitData { + statushooks.SetStatus(ctx, "Initializing...") + defer statushooks.Done(ctx) + initData := &control.InitData{ Result: &db_common.InitResult{}, } - if viper.GetBool(constants.ArgModInstall) { opts := &modinstaller.InstallOpts{WorkspacePath: viper.GetString(constants.ArgWorkspaceChDir)} _, err := modinstaller.InstallWorkspaceDependencies(opts) @@ -202,9 +208,6 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini return initData } } - - cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, false) - err := validateOutputFormat() if err != nil { initData.Result.Error = err @@ -217,10 +220,6 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini return initData } - ctx, cancel := context.WithCancel(ctx) - startCancelHandler(cancel) - initData.Ctx = ctx - // set color schema err = initialiseColorScheme() if err != nil { @@ -228,7 +227,7 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini return initData } // load workspace - initData.Workspace, err = loadWorkspacePromptingForVariables(ctx, spinner) + initData.Workspace, err = loadWorkspacePromptingForVariables(ctx) if err != nil { if !utils.IsCancelledError(err) { err = utils.PrefixError(err, "failed to load workspace") @@ -248,18 +247,14 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini initData.Result.AddWarnings("no controls found in current workspace") } - display.UpdateSpinnerMessage(spinner, "Connecting to service...") + statushooks.SetStatus(ctx, "Connecting to service...") // get a client var client db_common.Client if connectionString := viper.GetString(constants.ArgConnectionString); connectionString != "" { client, err = db_client.NewDbClient(ctx, connectionString) } else { - // stop the spinner - display.StopSpinner(spinner) // when starting the database, installers may trigger their own spinners client, err = db_local.GetLocalClient(ctx, constants.InvokerCheck) - // resume the spinner - display.ResumeSpinner(spinner) } if err != nil { @@ -288,13 +283,13 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini return initData } -func handleCheckInitResult(initData *control.InitData) bool { +func handleCheckInitResult(ctx context.Context, initData *control.InitData) bool { // if there is an error or cancellation we bomb out // check for the various kinds of failures utils.FailOnError(initData.Result.Error) // cancelled? - if initData.Ctx != nil { - utils.FailOnError(initData.Ctx.Err()) + if ctx != nil { + utils.FailOnError(ctx.Err()) } // if there is a usage warning we display it diff --git a/cmd/mod.go b/cmd/mod.go index f382e0fa9..d9fa41c52 100644 --- a/cmd/mod.go +++ b/cmd/mod.go @@ -53,11 +53,12 @@ func modInstallCmd() *cobra.Command { } func runModInstallCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("cmd.runModInstallCmd") defer func() { utils.LogTime("cmd.runModInstallCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() @@ -88,17 +89,18 @@ func modUninstallCmd() *cobra.Command { } func runModUninstallCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("cmd.runModInstallCmd") defer func() { utils.LogTime("cmd.runModInstallCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() opts := newInstallOpts(cmd, args...) - installData, err := modinstaller.UninstallWorkspaceDependencies(opts) + installData, err := modinstaller.UninstallWorkspaceDependencies(ctx, opts) utils.FailOnError(err) fmt.Println(modinstaller.BuildUninstallSummary(installData)) @@ -122,11 +124,12 @@ func modUpdateCmd() *cobra.Command { } func runModUpdateCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("cmd.runModUpdateCmd") defer func() { utils.LogTime("cmd.runModUpdateCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() @@ -153,11 +156,12 @@ func modListCmd() *cobra.Command { } func runModListCmd(cmd *cobra.Command, _ []string) { + ctx := cmd.Context() utils.LogTime("cmd.runModListCmd") defer func() { utils.LogTime("cmd.runModListCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() @@ -186,11 +190,12 @@ func modInitCmd() *cobra.Command { } func runModInitCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("cmd.runModInitCmd") defer func() { utils.LogTime("cmd.runModInitCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() diff --git a/cmd/plugin.go b/cmd/plugin.go index c35937301..10f21cccd 100644 --- a/cmd/plugin.go +++ b/cmd/plugin.go @@ -16,6 +16,7 @@ import ( "github.com/turbot/steampipe/ociinstaller/versionfile" "github.com/turbot/steampipe/plugin" "github.com/turbot/steampipe/statefile" + "github.com/turbot/steampipe/statusspinner" "github.com/turbot/steampipe/steampipeconfig" "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/turbot/steampipe/utils" @@ -176,11 +177,12 @@ Example: // exitCode=3 For errors related to loading state, loading version data or an issue contacting the update server. // exitCode=4 For plugin listing failures func runPluginInstallCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runPluginInstallCmd install") defer func() { utils.LogTime("runPluginInstallCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() @@ -193,7 +195,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) { if len(plugins) == 0 { fmt.Println() - utils.ShowError(fmt.Errorf("you need to provide at least one plugin to install")) + utils.ShowError(ctx, fmt.Errorf("you need to provide at least one plugin to install")) fmt.Println() cmd.Help() fmt.Println() @@ -204,7 +206,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) { // a leading blank line - since we always output multiple lines fmt.Println() - spinner := display.ShowSpinner("") + statusSpinner := statusspinner.NewStatusSpinner() for _, p := range plugins { isPluginExists, _ := plugin.Exists(p) @@ -217,7 +219,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) { }) continue } - display.UpdateSpinnerMessage(spinner, fmt.Sprintf("Installing plugin: %s", p)) + statusSpinner.SetStatus(fmt.Sprintf("Installing plugin: %s", p)) image, err := plugin.Install(cmd.Context(), p) if err != nil { msg := "" @@ -250,9 +252,9 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) { }) } - display.StopSpinner(spinner) + statusSpinner.Done() - refreshConnectionsIfNecessary(cmd.Context(), installReports, false) + refreshConnectionsIfNecessary(cmd.Context(), installReports, true) display.PrintInstallReports(installReports, false) // a concluding blank line - since we always output multiple lines @@ -260,11 +262,12 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) { } func runPluginUpdateCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runPluginUpdateCmd install") defer func() { utils.LogTime("runPluginUpdateCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() @@ -275,7 +278,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) { plugins, err := resolveUpdatePluginsFromArgs(args) if err != nil { fmt.Println() - utils.ShowError(err) + utils.ShowError(ctx, err) fmt.Println() cmd.Help() fmt.Println() @@ -285,7 +288,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) { state, err := statefile.LoadState() if err != nil { - utils.ShowError(fmt.Errorf("could not load state")) + utils.ShowError(ctx, fmt.Errorf("could not load state")) exitCode = 3 return } @@ -293,7 +296,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) { // load up the version file data versionData, err := versionfile.LoadPluginVersionFile() if err != nil { - utils.ShowError(fmt.Errorf("error loading current plugin data")) + utils.ShowError(ctx, fmt.Errorf("error loading current plugin data")) exitCode = 3 return } @@ -340,14 +343,14 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) { return } - spinner := display.ShowSpinner("Checking for available updates") + statusSpinner := statusspinner.NewStatusSpinner(statusspinner.WithMessage("Checking for available updates")) reports := plugin.GetUpdateReport(state.InstallationID, runUpdatesFor) - display.StopSpinner(spinner) + statusSpinner.Done() if len(reports) == 0 { // this happens if for some reason the update server could not be contacted, // in which case we get back an empty map - utils.ShowError(fmt.Errorf("there was an issue contacting the update server. Please try later")) + utils.ShowError(ctx, fmt.Errorf("there was an issue contacting the update server. Please try later")) exitCode = 3 return } @@ -363,9 +366,9 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) { continue } - spinner := display.ShowSpinner(fmt.Sprintf("Updating plugin %s...", report.CheckResponse.Name)) + statusSpinner.SetStatus(fmt.Sprintf("Updating plugin %s...", report.CheckResponse.Name)) image, err := plugin.Install(cmd.Context(), report.Plugin.Name) - display.StopSpinner(spinner) + statusSpinner.Done() if err != nil { msg := "" if strings.HasSuffix(err.Error(), "not found") { @@ -398,7 +401,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) { }) } - refreshConnectionsIfNecessary(cmd.Context(), updateReports, true) + refreshConnectionsIfNecessary(cmd.Context(), updateReports, false) display.PrintInstallReports(updateReports, true) // a concluding blank line - since we always output multiple lines @@ -421,7 +424,7 @@ func resolveUpdatePluginsFromArgs(args []string) ([]string, error) { } // start service if necessary and refresh connections -func refreshConnectionsIfNecessary(ctx context.Context, reports []display.InstallReport, isUpdate bool) error { +func refreshConnectionsIfNecessary(ctx context.Context, reports []display.InstallReport, shouldReload bool) error { // get count of skipped reports skipped := 0 for _, report := range reports { @@ -436,7 +439,7 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal } // reload the config, since an installation MUST have created a new config file - if !isUpdate { + if shouldReload { var cmd = viper.Get(constants.ConfigKeyActiveCommand).(*cobra.Command) config, err := steampipeconfig.LoadSteampipeConfig(viper.GetString(constants.ArgWorkspaceChDir), cmd.Name()) if err != nil { @@ -449,7 +452,7 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal if err != nil { return err } - defer client.Close() + defer client.Close(ctx) res := client.RefreshConnectionAndSearchPaths(ctx) if res.Error != nil { return res.Error @@ -460,24 +463,26 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal } func runPluginListCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runPluginListCmd list") defer func() { utils.LogTime("runPluginListCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() + pluginConnectionMap, err := getPluginConnectionMap(cmd.Context()) if err != nil { - utils.ShowErrorWithMessage(err, "Plugin Listing failed") + utils.ShowErrorWithMessage(ctx, err, "Plugin Listing failed") exitCode = 4 return } list, err := plugin.List(pluginConnectionMap) if err != nil { - utils.ShowErrorWithMessage(err, "Plugin Listing failed") + utils.ShowErrorWithMessage(ctx, err, "Plugin Listing failed") exitCode = 4 } headers := []string{"Name", "Version", "Connections"} @@ -489,35 +494,37 @@ func runPluginListCmd(cmd *cobra.Command, args []string) { } func runPluginUninstallCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runPluginUninstallCmd uninstall") defer func() { utils.LogTime("runPluginUninstallCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) exitCode = 1 } }() if len(args) == 0 { fmt.Println() - utils.ShowError(fmt.Errorf("you need to provide at least one plugin to uninstall")) + utils.ShowError(ctx, fmt.Errorf("you need to provide at least one plugin to uninstall")) fmt.Println() cmd.Help() fmt.Println() exitCode = 2 return } - connectionMap, err := getPluginConnectionMap(cmd.Context()) + + connectionMap, err := getPluginConnectionMap(ctx) if err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) exitCode = 4 return } for _, p := range args { - if err := plugin.Remove(p, connectionMap); err != nil { - utils.ShowErrorWithMessage(err, fmt.Sprintf("Failed to uninstall plugin '%s'", p)) + if err := plugin.Remove(ctx, p, connectionMap); err != nil { + utils.ShowErrorWithMessage(ctx, err, fmt.Sprintf("Failed to uninstall plugin '%s'", p)) } } } @@ -528,7 +535,7 @@ func getPluginConnectionMap(ctx context.Context) (map[string][]modconfig.Connect if err != nil { return nil, err } - defer client.Close() + defer client.Close(ctx) res := client.RefreshConnectionAndSearchPaths(ctx) if res.Error != nil { return nil, res.Error diff --git a/cmd/plugin_manager.go b/cmd/plugin_manager.go index 5cd1c50e8..c0cdf126e 100644 --- a/cmd/plugin_manager.go +++ b/cmd/plugin_manager.go @@ -33,6 +33,7 @@ func pluginManagerCmd() *cobra.Command { } func runPluginManagerCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() logger := createPluginManagerLog() log.Printf("[INFO] starting plugin manager") @@ -51,7 +52,7 @@ func runPluginManagerCmd(cmd *cobra.Command, args []string) { connectionWatcher, err := connectionwatcher.NewConnectionWatcher(pluginManager.SetConnectionConfigMap) if err != nil { log.Printf("[WARN] failed to create connection watcher: %s", err.Error()) - utils.ShowError(err) + utils.ShowError(ctx, err) os.Exit(1) } diff --git a/cmd/query.go b/cmd/query.go index 4835907a6..5c182705b 100644 --- a/cmd/query.go +++ b/cmd/query.go @@ -6,10 +6,8 @@ import ( "fmt" "log" "os" - "os/signal" "strings" - "github.com/briandowns/spinner" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/turbot/go-kit/helpers" @@ -18,6 +16,7 @@ import ( "github.com/turbot/steampipe/interactive" "github.com/turbot/steampipe/query" "github.com/turbot/steampipe/query/queryexecute" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/turbot/steampipe/utils" "github.com/turbot/steampipe/workspace" @@ -80,11 +79,12 @@ Examples: } func runQueryCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("cmd.runQueryCmd start") defer func() { utils.LogTime("cmd.runQueryCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } }() @@ -97,19 +97,11 @@ func runQueryCmd(cmd *cobra.Command, args []string) { // enable spinner only in interactive mode interactiveMode := len(args) == 0 - cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, interactiveMode) // set config to indicate whether we are running an interactive query viper.Set(constants.ConfigKeyInteractive, interactiveMode) - ctx := cmd.Context() - if !interactiveMode { - c, cancel := context.WithCancel(ctx) - startCancelHandler(cancel) - ctx = c - } - // load the workspace - w, err := loadWorkspacePromptingForVariables(ctx, nil) + w, err := loadWorkspacePromptingForVariables(ctx) utils.FailOnErrorWithMessage(err, "failed to load workspace") // se we have loaded a workspace - be sure to close it @@ -119,7 +111,7 @@ func runQueryCmd(cmd *cobra.Command, args []string) { initData := query.NewInitData(ctx, w, args) if interactiveMode { - queryexecute.RunInteractiveSession(initData) + queryexecute.RunInteractiveSession(ctx, initData) } else { // set global exit code exitCode = queryexecute.RunBatchSession(ctx, initData) @@ -144,10 +136,10 @@ func getPipedStdinData() string { return stdinData } -func loadWorkspacePromptingForVariables(ctx context.Context, spinner *spinner.Spinner) (*workspace.Workspace, error) { +func loadWorkspacePromptingForVariables(ctx context.Context) (*workspace.Workspace, error) { workspacePath := viper.GetString(constants.ArgWorkspaceChDir) - w, err := workspace.Load(workspacePath) + w, err := workspace.Load(ctx, workspacePath) if err == nil { return w, nil } @@ -156,29 +148,13 @@ func loadWorkspacePromptingForVariables(ctx context.Context, spinner *spinner.Sp if !ok { return nil, err } - if spinner != nil { - spinner.Stop() - } // so we have missing variables - prompt for them + // first hide spinner if it is there + statushooks.Done(ctx) if err := interactive.PromptForMissingVariables(ctx, missingVariablesError.MissingVariables); err != nil { log.Printf("[TRACE] Interactive variables prompting returned error %v", err) return nil, err } - if spinner != nil { - spinner.Start() - } // ok we should have all variables now - reload workspace - return workspace.Load(workspacePath) -} - -func startCancelHandler(cancel context.CancelFunc) { - sigIntChannel := make(chan os.Signal, 1) - signal.Notify(sigIntChannel, os.Interrupt) - go func() { - <-sigIntChannel - log.Println("[TRACE] got SIGINT") - // call context cancellation function - cancel() - // leave the channel open - any subsequent interrupts hits will be ignored - }() + return workspace.Load(ctx, workspacePath) } diff --git a/cmd/report.go b/cmd/report.go index 190da8b98..40d295e79 100644 --- a/cmd/report.go +++ b/cmd/report.go @@ -4,11 +4,11 @@ import ( "context" "github.com/spf13/cobra" - "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe-plugin-sdk/logging" "github.com/turbot/steampipe/cmdconfig" "github.com/turbot/steampipe/constants" + "github.com/turbot/steampipe/contexthelpers" "github.com/turbot/steampipe/report/reportserver" "github.com/turbot/steampipe/utils" ) @@ -29,18 +29,17 @@ func reportCmd() *cobra.Command { } func runReportCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() logging.LogTime("runReportCmd start") defer func() { logging.LogTime("runReportCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } }() - cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, false) - ctx, cancel := context.WithCancel(cmd.Context()) - startCancelHandler(cancel) + contexthelpers.StartCancelHandler(cancel) // start db if necessary //err := db_local.EnsureDbAndStartService(constants.InvokerReport, true) @@ -53,7 +52,7 @@ func runReportCmd(cmd *cobra.Command, args []string) { utils.FailOnError(err) } - defer server.Shutdown() + defer server.Shutdown(ctx) server.Start() } diff --git a/cmd/root.go b/cmd/root.go index 0ea115520..45c915e96 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,10 +1,14 @@ package cmd import ( + "context" "fmt" "log" "os" + "github.com/turbot/steampipe/statushooks" + "github.com/turbot/steampipe/statusspinner" + "github.com/hashicorp/go-hclog" "github.com/mattn/go-isatty" "github.com/spf13/cobra" @@ -183,6 +187,19 @@ func Execute() int { utils.LogTime("cmd.root.Execute start") defer utils.LogTime("cmd.root.Execute end") - rootCmd.Execute() + ctx := createRootContext() + rootCmd.ExecuteContext(ctx) return exitCode } + +// create the root context - create a status renderer and set as value +func createRootContext() context.Context { + var statusRenderer statushooks.StatusHooks = statushooks.NullHooks + // if the client is a TTY, inject a status spinner + if isatty.IsTerminal(os.Stdout.Fd()) { + statusRenderer = statusspinner.NewStatusSpinner() + } + + ctx := statushooks.AddStatusHooksToContext(context.Background(), statusRenderer) + return ctx +} diff --git a/cmd/service.go b/cmd/service.go index 12234e676..77e21ea7f 100644 --- a/cmd/service.go +++ b/cmd/service.go @@ -5,6 +5,8 @@ import ( "errors" "fmt" + "github.com/turbot/steampipe/statushooks" + "os" "os/signal" "strings" @@ -127,11 +129,12 @@ func serviceRestartCmd() *cobra.Command { } func runServiceStartCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runServiceStartCmd start") defer func() { utils.LogTime("runServiceStartCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) if exitCode == 0 { // there was an error and the exitcode // was not set to a non-zero value. @@ -163,7 +166,7 @@ func runServiceStartCmd(cmd *cobra.Command, args []string) { utils.FailOnError(startResult.Error) if startResult.Status == db_local.ServiceFailedToStart { - utils.ShowError(fmt.Errorf("steampipe service failed to start")) + utils.ShowError(ctx, fmt.Errorf("steampipe service failed to start")) return } @@ -191,18 +194,18 @@ func runServiceStartCmd(cmd *cobra.Command, args []string) { err = db_local.RefreshConnectionAndSearchPaths(ctx, invoker) if err != nil { - db_local.StopServices(false, constants.InvokerService, nil) + db_local.StopServices(ctx, false, constants.InvokerService) utils.FailOnError(err) } - printStatus(startResult.DbState, startResult.PluginManagerState) + printStatus(ctx, startResult.DbState, startResult.PluginManagerState) if viper.GetBool(constants.ArgForeground) { - runServiceInForeground(invoker) + runServiceInForeground(ctx, invoker) } } -func runServiceInForeground(invoker constants.Invoker) { +func runServiceInForeground(ctx context.Context, invoker constants.Invoker) { fmt.Println("Hit Ctrl+C to stop the service") sigIntChannel := make(chan os.Signal, 1) @@ -232,7 +235,7 @@ func runServiceInForeground(invoker constants.Invoker) { count, err := db_local.GetCountOfThirdPartyClients(context.Background()) if err != nil { // report the error in the off chance that there's one - utils.ShowError(err) + utils.ShowError(ctx, err) return } @@ -246,7 +249,7 @@ func runServiceInForeground(invoker constants.Invoker) { } fmt.Println("Stopping Steampipe service.") - db_local.StopServices(false, invoker, nil) + db_local.StopServices(ctx, false, invoker) fmt.Println("Steampipe service stopped.") return } @@ -254,11 +257,12 @@ func runServiceInForeground(invoker constants.Invoker) { } func runServiceRestartCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runServiceRestartCmd start") defer func() { utils.LogTime("runServiceRestartCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) if exitCode == 0 { // there was an error and the exitcode // was not set to a non-zero value. @@ -277,7 +281,7 @@ func runServiceRestartCmd(cmd *cobra.Command, args []string) { } // stop db - stopStatus, err := db_local.StopServices(viper.GetBool(constants.ArgForce), constants.InvokerService, nil) + stopStatus, err := db_local.StopServices(ctx, viper.GetBool(constants.ArgForce), constants.InvokerService) utils.FailOnErrorWithMessage(err, "could not stop current instance") if stopStatus != db_local.ServiceStopped { fmt.Println(` @@ -307,16 +311,17 @@ to force a restart. utils.FailOnError(err) fmt.Println("Steampipe service restarted.") - printStatus(startResult.DbState, startResult.PluginManagerState) + printStatus(ctx, startResult.DbState, startResult.PluginManagerState) } func runServiceStatusCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runServiceStatusCmd status") defer func() { utils.LogTime("runServiceStatusCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } }() @@ -331,10 +336,10 @@ func runServiceStatusCmd(cmd *cobra.Command, args []string) { pmState, pmStateErr := pluginmanager.LoadPluginManagerState() if dbStateErr != nil || pmStateErr != nil { - utils.ShowError(composeStateError(dbStateErr, pmStateErr)) + utils.ShowError(ctx, composeStateError(dbStateErr, pmStateErr)) return } - printStatus(dbState, pmState) + printStatus(ctx, dbState, pmState) } } @@ -354,19 +359,17 @@ func composeStateError(dbStateErr error, pmStateErr error) error { } func runServiceStopCmd(cmd *cobra.Command, args []string) { + ctx := cmd.Context() utils.LogTime("runServiceStopCmd stop") - stoppedChan := make(chan bool, 1) var status db_local.StopStatus var err error var dbState *db_local.RunningDBInstanceInfo - spinner := display.StartSpinnerAfterDelay("", constants.SpinnerShowTimeout, stoppedChan) - defer func() { utils.LogTime("runServiceStopCmd end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) if exitCode == 0 { // there was an error and the exitcode // was not set to a non-zero value. @@ -378,20 +381,17 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) { force := cmdconfig.Viper().GetBool(constants.ArgForce) if force { - status, err = db_local.StopServices(force, constants.InvokerService, spinner) + status, err = db_local.StopServices(ctx, force, constants.InvokerService) } else { dbState, err = db_local.GetState() if err != nil { - display.StopSpinner(spinner) utils.FailOnErrorWithMessage(err, "could not stop Steampipe service") } if dbState == nil { - display.StopSpinner(spinner) fmt.Println("Steampipe service is not running.") return } if dbState.Invoker != constants.InvokerService { - display.StopSpinner(spinner) printRunningImplicit(dbState.Invoker) return } @@ -399,27 +399,22 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) { // check if there are any connected clients to the service connectedClientCount, err := db_local.GetCountOfThirdPartyClients(cmd.Context()) if err != nil { - display.StopSpinner(spinner) utils.FailOnErrorWithMessage(err, "error during service stop") } if connectedClientCount > 0 { - display.StopSpinner(spinner) printClientsConnected() return } - status, _ = db_local.StopServices(false, constants.InvokerService, spinner) + status, _ = db_local.StopServices(ctx, false, constants.InvokerService) } if err != nil { - display.StopSpinner(spinner) - utils.ShowError(err) + utils.ShowError(ctx, err) return } - display.StopSpinner(spinner) - switch status { case db_local.ServiceStopped: if dbState != nil { @@ -451,12 +446,9 @@ func showAllStatus(ctx context.Context) { var processes []*psutils.Process var err error - doneFetchingDetailsChan := make(chan bool) - sp := display.StartSpinnerAfterDelay("Getting details", constants.SpinnerShowTimeout, doneFetchingDetailsChan) - + statushooks.SetStatus(ctx, "Getting details") processes, err = db_local.FindAllSteampipePostgresInstances(ctx) - close(doneFetchingDetailsChan) - display.StopSpinner(sp) + statushooks.Done(ctx) utils.FailOnError(err) @@ -498,7 +490,7 @@ func getServiceProcessDetails(process *psutils.Process) (string, string, string, return fmt.Sprintf("%d", process.Pid), installDir, port, listenType } -func printStatus(dbState *db_local.RunningDBInstanceInfo, pmState *pluginmanager.PluginManagerState) { +func printStatus(ctx context.Context, dbState *db_local.RunningDBInstanceInfo, pmState *pluginmanager.PluginManagerState) { if dbState == nil && !pmState.Running { fmt.Println("Service is not running") return @@ -553,7 +545,7 @@ To keep the service running after the %s session completes, use %s. // the service is running, but the plugin_manager is not running and there's no state file // meaning that it cannot be restarted by the FDW // it's an ERROR - utils.ShowError(fmt.Errorf(` + utils.ShowError(ctx, fmt.Errorf(` Service is running, but the Plugin Manager cannot be recovered. Please use %s to recover the service `, diff --git a/connectionwatcher/connection_watcher.go b/connectionwatcher/connection_watcher.go index 14a4504d7..6a55c31db 100644 --- a/connectionwatcher/connection_watcher.go +++ b/connectionwatcher/connection_watcher.go @@ -86,7 +86,7 @@ func (w *ConnectionWatcher) handleFileWatcherEvent(e []fsnotify.Event) { if err != nil { log.Printf("[WARN] Error creating client to handle updated connection config: %s", err.Error()) } - defer client.Close() + defer client.Close(ctx) log.Printf("[TRACE] loaded updated config") diff --git a/constants/config_keys.go b/constants/config_keys.go index 85920bffd..9d22b4618 100644 --- a/constants/config_keys.go +++ b/constants/config_keys.go @@ -2,7 +2,6 @@ package constants // viper config keys const ( - ConfigKeyShowInteractiveOutput = "show-interactive-output" // ConfigKeyDatabaseSearchPath is used to store the search path set in the database config in viper // the viper value will be set via via a call to getScopedKey in steampipeconfig/steampipeconfig.go ConfigKeyDatabaseSearchPath = "database.search-path" diff --git a/contexthelpers/cancel.go b/contexthelpers/cancel.go new file mode 100644 index 000000000..6aa7ffe22 --- /dev/null +++ b/contexthelpers/cancel.go @@ -0,0 +1,21 @@ +package contexthelpers + +import ( + "context" + "log" + "os" + "os/signal" +) + +func StartCancelHandler(cancel context.CancelFunc) chan os.Signal { + sigIntChannel := make(chan os.Signal, 1) + signal.Notify(sigIntChannel, os.Interrupt) + go func() { + <-sigIntChannel + log.Println("[TRACE] got SIGINT") + // call context cancellation function + cancel() + // leave the channel open - any subsequent interrupts hits will be ignored + }() + return sigIntChannel +} diff --git a/contexthelpers/context_key.go b/contexthelpers/context_key.go new file mode 100644 index 000000000..8d833e083 --- /dev/null +++ b/contexthelpers/context_key.go @@ -0,0 +1,8 @@ +package contexthelpers + +//https://medium.com/@matryer/context-keys-in-go-5312346a868d +type ContextKey string + +func (c ContextKey) String() string { + return "steampipe context key " + string(c) +} diff --git a/control/controlexecute/control_run.go b/control/controlexecute/control_run.go index 1fc4666f9..451f39b08 100644 --- a/control/controlexecute/control_run.go +++ b/control/controlexecute/control_run.go @@ -13,6 +13,7 @@ import ( "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/query/queryresult" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/turbot/steampipe/utils" ) @@ -92,8 +93,8 @@ func NewControlRun(control *modconfig.Control, group *ResultGroup, executionTree return res } -func (r *ControlRun) skip() { - r.setRunStatus(ControlRunComplete) +func (r *ControlRun) skip(ctx context.Context) { + r.setRunStatus(ctx, ControlRunComplete) } // set search path for this control run @@ -196,14 +197,14 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) { r.runStatus = ControlRunStarted // update the current running control in the Progress renderer - r.executionTree.progress.OnControlStart(control) - defer r.executionTree.progress.OnControlFinish() + r.executionTree.progress.OnControlStart(ctx, control) + defer r.executionTree.progress.OnControlFinish(ctx) // resolve the control query r.Lifecycle.Add("query_resolution_start") query, err := r.resolveControlQuery(control) if err != nil { - r.SetError(err) + r.SetError(ctx, err) return } r.Lifecycle.Add("query_resolution_finish") @@ -211,7 +212,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) { log.Printf("[TRACE] setting search path %s\n", control.Name()) r.Lifecycle.Add("set_search_path_start") if err := r.setSearchPath(ctx, dbSession, client); err != nil { - r.SetError(err) + r.SetError(ctx, err) return } r.Lifecycle.Add("set_search_path_finish") @@ -225,7 +226,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) { // NOTE no need to pass an OnComplete callback - we are already closing our session after waiting for results log.Printf("[TRACE] execute start for, %s\n", control.Name()) r.Lifecycle.Add("query_start") - queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil, false) + queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil) r.Lifecycle.Add("query_finish") log.Printf("[TRACE] execute finish for, %s\n", control.Name()) @@ -243,7 +244,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) { log.Printf("[TRACE] control %s query failed again with plugin connectivity error %s - NOT retrying...", r.Control.Name(), err) } } - r.SetError(err) + r.SetError(ctx, err) return } @@ -255,7 +256,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) { log.Printf("[TRACE] finish result for, %s\n", control.Name()) } -func (r *ControlRun) SetError(err error) { +func (r *ControlRun) SetError(ctx context.Context, err error) { if err == nil { return } @@ -263,18 +264,23 @@ func (r *ControlRun) SetError(err error) { // update error count r.Summary.Error++ - r.setRunStatus(ControlRunError) + r.setRunStatus(ctx, ControlRunError) } func (r *ControlRun) GetError() error { return r.runError } +// create a context with a deadline, and with status updates disabled (we do not want to show 'loading' results) func (r *ControlRun) getControlQueryContext(ctx context.Context) context.Context { // create a context with a deadline shouldBeDoneBy := time.Now().Add(controlQueryTimeout) // we don't use this cancel fn because, pgx prematurely cancels the PG connection when this cancel gets called in 'defer' newCtx, _ := context.WithDeadline(ctx, shouldBeDoneBy) + + // disable the status spinner to hide 'loading' results) + newCtx = statushooks.DisableStatusHooks(newCtx) + return newCtx } @@ -304,20 +310,20 @@ func (r *ControlRun) waitForResults(ctx context.Context) { // create a channel to which will be closed when gathering has been done gatherDoneChan := make(chan string) go func() { - r.gatherResults() + r.gatherResults(ctx) close(gatherDoneChan) }() select { // check for cancellation case <-ctx.Done(): - r.SetError(ctx.Err()) + r.SetError(ctx, ctx.Err()) case <-gatherDoneChan: // do nothing } } -func (r *ControlRun) gatherResults() { +func (r *ControlRun) gatherResults(ctx context.Context) { r.Lifecycle.Add("gather_start") defer func() { r.Lifecycle.Add("gather_finish") }() for { @@ -326,14 +332,14 @@ func (r *ControlRun) gatherResults() { // nil row means control run is complete if row == nil { // nil row means we are done - r.setRunStatus(ControlRunComplete) + r.setRunStatus(ctx, ControlRunComplete) r.createdOrderedResultRows() return } // if the row is in error then we terminate the run if row.Error != nil { // set error status and summary - r.SetError(row.Error) + r.SetError(ctx, row.Error) // update the result group status with our status - this will be passed all the way up the execution tree r.group.updateSummary(r.Summary) return @@ -342,7 +348,7 @@ func (r *ControlRun) gatherResults() { // so all is ok - create another result row result, err := NewResultRow(r.Control, row, r.queryResult.ColTypes) if err != nil { - r.SetError(err) + r.SetError(ctx, err) return } r.addResultRow(result) @@ -380,7 +386,7 @@ func (r *ControlRun) createdOrderedResultRows() { } } -func (r *ControlRun) setRunStatus(status ControlRunStatus) { +func (r *ControlRun) setRunStatus(ctx context.Context, status ControlRunStatus) { r.stateLock.Lock() r.runStatus = status r.stateLock.Unlock() @@ -388,12 +394,11 @@ func (r *ControlRun) setRunStatus(status ControlRunStatus) { if r.Finished() { // update Progress if status == ControlRunError { - r.executionTree.progress.OnControlError() + r.executionTree.progress.OnControlError(ctx) } else { - r.executionTree.progress.OnControlComplete() + r.executionTree.progress.OnControlComplete(ctx) } - // TODO CANCEL QUERY IF NEEDED r.doneChan <- true } } diff --git a/control/controlexecute/execution_tree.go b/control/controlexecute/execution_tree.go index de0fe6ae6..e36e74082 100644 --- a/control/controlexecute/execution_tree.go +++ b/control/controlexecute/execution_tree.go @@ -12,6 +12,7 @@ import ( "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/query/queryresult" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/turbot/steampipe/workspace" "golang.org/x/sync/semaphore" @@ -40,9 +41,10 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien workspace: workspace, client: client, } - // if a "--where" or "--tag" parameter was passed, build a map of control manes used to filter the controls to run - // NOTE: not enabled yet - err := executionTree.populateControlFilterMap(ctx) + // if a "--where" or "--tag" parameter was passed, build a map of control names used to filter the controls to run + // create a context with status hooks disabled + noStatusCtx := statushooks.DisableStatusHooks(ctx) + err := executionTree.populateControlFilterMap(noStatusCtx) if err != nil { return nil, err @@ -55,7 +57,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien } // build tree of result groups, starting with a synthetic 'root' node - executionTree.Root = NewRootResultGroup(executionTree, rootItem) + executionTree.Root = NewRootResultGroup(ctx, executionTree, rootItem) // after tree has built, ControlCount will be set - create progress rendered executionTree.progress = NewControlProgressRenderer(len(executionTree.controlRuns)) @@ -65,7 +67,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien // AddControl checks whether control should be included in the tree // if so, creates a ControlRun, which is added to the parent group -func (e *ExecutionTree) AddControl(control *modconfig.Control, group *ResultGroup) { +func (e *ExecutionTree) AddControl(ctx context.Context, control *modconfig.Control, group *ResultGroup) { // note we use short name to determine whether to include a control if e.ShouldIncludeControl(control.ShortName) { // create new ControlRun with treeItem as the parent @@ -81,11 +83,11 @@ func (e *ExecutionTree) Execute(ctx context.Context, client db_common.Client) in log.Println("[TRACE]", "begin ExecutionTree.Execute") defer log.Println("[TRACE]", "end ExecutionTree.Execute") e.StartTime = time.Now() - e.progress.Start() + e.progress.Start(ctx) defer func() { e.EndTime = time.Now() - e.progress.Finish() + e.progress.Finish(ctx) }() // the number of goroutines parallel to start @@ -247,7 +249,7 @@ func (e *ExecutionTree) getControlMapFromWhereClause(ctx context.Context, whereC query = fmt.Sprintf("select resource_name from %s where %s", constants.IntrospectionTableControl, whereClause) } - res, err := e.client.ExecuteSync(ctx, query, false) + res, err := e.client.ExecuteSync(ctx, query) if err != nil { return nil, err } diff --git a/control/controlexecute/progress.go b/control/controlexecute/progress.go index b217decc5..37d77f01d 100644 --- a/control/controlexecute/progress.go +++ b/control/controlexecute/progress.go @@ -1,16 +1,17 @@ package controlexecute import ( + "context" "fmt" "sync" + "github.com/turbot/steampipe/statushooks" + "github.com/spf13/viper" "github.com/turbot/steampipe/constants" - "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/briandowns/spinner" - "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/utils" ) @@ -34,16 +35,16 @@ func NewControlProgressRenderer(total int) *ControlProgressRenderer { } } -func (p *ControlProgressRenderer) Start() { +func (p *ControlProgressRenderer) Start(ctx context.Context) { p.updateLock.Lock() defer p.updateLock.Unlock() if p.enabled { - p.spinner = display.ShowSpinner("Starting controls...") + statushooks.SetStatus(ctx, "Starting controls...") } } -func (p *ControlProgressRenderer) OnControlStart(control *modconfig.Control) { +func (p *ControlProgressRenderer) OnControlStart(ctx context.Context, control *modconfig.Control) { p.updateLock.Lock() defer p.updateLock.Unlock() @@ -54,43 +55,43 @@ func (p *ControlProgressRenderer) OnControlStart(control *modconfig.Control) { p.pending-- if p.enabled { - display.UpdateSpinnerMessage(p.spinner, p.message()) + statushooks.SetStatus(ctx, p.message()) } } -func (p *ControlProgressRenderer) OnControlFinish() { +func (p *ControlProgressRenderer) OnControlFinish(ctx context.Context) { p.updateLock.Lock() defer p.updateLock.Unlock() // decrement the parallel execution count p.executing-- if p.enabled { - display.UpdateSpinnerMessage(p.spinner, p.message()) + statushooks.SetStatus(ctx, p.message()) } } -func (p *ControlProgressRenderer) OnControlComplete() { +func (p *ControlProgressRenderer) OnControlComplete(ctx context.Context) { p.updateLock.Lock() defer p.updateLock.Unlock() p.complete++ if p.enabled { - display.UpdateSpinnerMessage(p.spinner, p.message()) + statushooks.SetStatus(ctx, p.message()) } } -func (p *ControlProgressRenderer) OnControlError() { +func (p *ControlProgressRenderer) OnControlError(ctx context.Context) { p.updateLock.Lock() defer p.updateLock.Unlock() p.error++ if p.enabled { - display.UpdateSpinnerMessage(p.spinner, p.message()) + statushooks.SetStatus(ctx, p.message()) } } -func (p *ControlProgressRenderer) Finish() { +func (p *ControlProgressRenderer) Finish(ctx context.Context) { if p.enabled { - display.StopSpinner(p.spinner) + statushooks.Done(ctx) } } diff --git a/control/controlexecute/result_group.go b/control/controlexecute/result_group.go index 2a82be3b8..7ffaa3ede 100644 --- a/control/controlexecute/result_group.go +++ b/control/controlexecute/result_group.go @@ -49,7 +49,7 @@ func NewGroupSummary() *GroupSummary { } // NewRootResultGroup creates a ResultGroup to act as the root node of a control execution tree -func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.ModTreeItem) *ResultGroup { +func NewRootResultGroup(ctx context.Context, executionTree *ExecutionTree, rootItems ...modconfig.ModTreeItem) *ResultGroup { root := &ResultGroup{ GroupId: RootResultGroupName, Groups: []*ResultGroup{}, @@ -62,10 +62,10 @@ func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.Mod // if root item is a benchmark, create new result group with root as parent if control, ok := item.(*modconfig.Control); ok { // if root item is a control, add control run - executionTree.AddControl(control, root) + executionTree.AddControl(ctx, control, root) } else { // create a result group for this item - itemGroup := NewResultGroup(executionTree, item, root) + itemGroup := NewResultGroup(ctx, executionTree, item, root) root.Groups = append(root.Groups, itemGroup) } } @@ -73,7 +73,7 @@ func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.Mod } // NewResultGroup creates a result group from a ModTreeItem -func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem, parent *ResultGroup) *ResultGroup { +func NewResultGroup(ctx context.Context, executionTree *ExecutionTree, treeItem modconfig.ModTreeItem, parent *ResultGroup) *ResultGroup { // only show qualified group names for controls from dependent mods groupId := treeItem.Name() if mod := treeItem.GetMod(); mod != nil && mod.Name() == executionTree.workspace.Mod.Name() { @@ -96,7 +96,7 @@ func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem for _, c := range treeItem.GetChildren() { if benchmark, ok := c.(*modconfig.Benchmark); ok { // create a result group for this item - benchmarkGroup := NewResultGroup(executionTree, benchmark, group) + benchmarkGroup := NewResultGroup(ctx, executionTree, benchmark, group) // if the group has any control runs, add to tree if benchmarkGroup.ControlRunCount() > 0 { // create a new result group with 'group' as the parent @@ -104,7 +104,7 @@ func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem } } if control, ok := c.(*modconfig.Control); ok { - executionTree.AddControl(control, group) + executionTree.AddControl(ctx, control, group) } } @@ -180,18 +180,18 @@ func (r *ResultGroup) execute(ctx context.Context, client db_common.Client, para for _, controlRun := range r.ControlRuns { if utils.IsContextCancelled(ctx) { - controlRun.SetError(ctx.Err()) + controlRun.SetError(ctx, ctx.Err()) continue } if viper.GetBool(constants.ArgDryRun) { - controlRun.skip() + controlRun.skip(ctx) continue } err := parallelismLock.Acquire(ctx, 1) if err != nil { - controlRun.SetError(err) + controlRun.SetError(ctx, err) continue } @@ -199,7 +199,7 @@ func (r *ResultGroup) execute(ctx context.Context, client db_common.Client, para defer func() { if r := recover(); r != nil { // if the Execute panic'ed, set it as an error - run.SetError(helpers.ToError(r)) + run.SetError(ctx, helpers.ToError(r)) } // Release in defer, so that we don't retain the lock even if there's a panic inside parallelismLock.Release(1) diff --git a/control/init_data.go b/control/init_data.go index 6b23833a3..e2e2b4b72 100644 --- a/control/init_data.go +++ b/control/init_data.go @@ -1,14 +1,11 @@ package control import ( - "context" - "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/workspace" ) type InitData struct { - Ctx context.Context Workspace *workspace.Workspace Client db_common.Client Result *db_common.InitResult diff --git a/db/db_client/db_client.go b/db/db_client/db_client.go index d7b08b930..b3562cc76 100644 --- a/db/db_client/db_client.go +++ b/db/db_client/db_client.go @@ -68,7 +68,7 @@ func NewDbClient(ctx context.Context, connectionString string) (*DbClient, error client.connectionString = connectionString if err := client.LoadForeignSchemaNames(ctx); err != nil { - client.Close() + client.Close(ctx) return nil, err } return client, nil @@ -114,7 +114,7 @@ func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallba // Close implements Client // closes the connection to the database and shuts down the backend -func (c *DbClient) Close() error { +func (c *DbClient) Close(context.Context) error { log.Printf("[TRACE] DbClient.Close %v", c.dbClient) if c.dbClient != nil { c.sessionInitWaitGroup.Wait() diff --git a/db/db_client/db_client_execute.go b/db/db_client/db_client_execute.go index 400931c4d..53364f226 100644 --- a/db/db_client/db_client_execute.go +++ b/db/db_client/db_client_execute.go @@ -8,12 +8,9 @@ import ( "fmt" "time" - "github.com/briandowns/spinner" - "github.com/turbot/steampipe/cmdconfig" - "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/db/db_common" - "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/query/queryresult" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/utils" "golang.org/x/text/language" "golang.org/x/text/message" @@ -21,7 +18,7 @@ import ( // ExecuteSync implements Client // execute a query against this client and wait for the result -func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) { +func (c *DbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) { // acquire a session sessionResult := c.AcquireSession(ctx) if sessionResult.Error != nil { @@ -32,17 +29,17 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner // and not in call-time sessionResult.Session.Close(utils.IsContextCancelled(ctx)) }() - return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, disableSpinner) + return c.ExecuteSyncInSession(ctx, sessionResult.Session, query) } // ExecuteSyncInSession implements Client // execute a query against this client and wait for the result -func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) { +func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) { if query == "" { return &queryresult.SyncQueryResult{}, nil } - result, err := c.ExecuteInSession(ctx, session, query, nil, disableSpinner) + result, err := c.ExecuteInSession(ctx, session, query, nil) if err != nil { return nil, err } @@ -61,7 +58,7 @@ func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common. // Execute implements Client // execute the query in the given Context // NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication -func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner bool) (*queryresult.Result, error) { +func (c *DbClient) Execute(ctx context.Context, query string) (*queryresult.Result, error) { // acquire a session sessionResult := c.AcquireSession(ctx) if sessionResult.Error != nil { @@ -70,26 +67,29 @@ func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner boo // define callback to close session when the async execution is complete closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) } - return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback, disableSpinner) + return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback) } // ExecuteInSession implements Client // execute the query in the given Context using the provided DatabaseSession // ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller // NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication -func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) { +func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) { if query == "" { return queryresult.NewQueryResult(nil), nil } startTime := time.Now() - // channel to flag to spinner that the query has run - var spinner *spinner.Spinner + var tx *sql.Tx defer func() { if err != nil { // stop spinner in case of error - display.StopSpinner(spinner) + statushooks.Done(ctx) + // error - rollback transaction if we have one + if tx != nil { + tx.Rollback() + } // call the completion callback - if one was provided if onComplete != nil { onComplete() @@ -97,11 +97,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data } }() - if !disableSpinner && cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) { - // if `show-interactive-output` is false, the spinner gets created, but is never shown - // so the s.Active() will always come back false . . . - spinner = display.ShowSpinner("Loading results...") - } + statushooks.SetStatus(ctx, "Loading results...") // start query var rows *sql.Rows @@ -122,7 +118,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data // read the rows in a go routine go func() { // read in the rows and stream to the query result object - c.readRows(ctx, startTime, rows, result, spinner) + c.readRows(ctx, startTime, rows, result) if onComplete != nil { onComplete() } @@ -158,11 +154,11 @@ func (c *DbClient) startQuery(ctx context.Context, query string, conn *sql.Conn) return } -func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows, result *queryresult.Result, activeSpinner *spinner.Spinner) { +func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows, result *queryresult.Result) { // defer this, so that these get cleaned up even if there is an unforeseen error defer func() { - // we are done fetching results. time for display. remove the spinner - display.StopSpinner(activeSpinner) + // we are done fetching results. time for display. clear the status indication + statushooks.Done(ctx) // close the sql rows object rows.Close() if err := rows.Err(); err != nil { @@ -191,7 +187,7 @@ func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows continueToNext := true select { case <-ctx.Done(): - display.UpdateSpinnerMessage(activeSpinner, "Cancelling query") + statushooks.SetStatus(ctx, "Cancelling query") continueToNext = false default: if rowResult, err := readRowContext(ctx, rows, cols, colTypes); err != nil { @@ -200,9 +196,9 @@ func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows } else { result.StreamRow(rowResult) } - // update the spinner message with the count of rows that have already been fetched + // update the status message with the count of rows that have already been fetched // this will not show if the spinner is not active - display.UpdateSpinnerMessage(activeSpinner, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount))) + statushooks.SetStatus(ctx, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount))) rowCount++ } if !continueToNext { diff --git a/db/db_client/db_client_search_path.go b/db/db_client/db_client_search_path.go index 105e7e263..7ac6abf29 100644 --- a/db/db_client/db_client_search_path.go +++ b/db/db_client/db_client_search_path.go @@ -18,7 +18,7 @@ import ( func (c *DbClient) GetCurrentSearchPath(ctx context.Context) ([]string, error) { var currentSearchPath []string var pathAsString string - rows, err := c.ExecuteSync(ctx, "show search_path", true) + rows, err := c.ExecuteSync(ctx, "show search_path") if err != nil { return nil, err } diff --git a/db/db_common/client.go b/db/db_common/client.go index 5c51ad340..fe07aeaeb 100644 --- a/db/db_common/client.go +++ b/db/db_common/client.go @@ -11,7 +11,7 @@ import ( type EnsureSessionStateCallback = func(context.Context, *DatabaseSession) (err error, warnings []string) type Client interface { - Close() error + Close(ctx context.Context) error ForeignSchemas() []string ConnectionMap() *steampipeconfig.ConnectionDataMap @@ -22,11 +22,11 @@ type Client interface { AcquireSession(context.Context) *AcquireSessionResult - ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) - Execute(ctx context.Context, query string, disableSpinner bool) (res *queryresult.Result, err error) + ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) + Execute(ctx context.Context, query string) (res *queryresult.Result, err error) - ExecuteSyncInSession(ctx context.Context, session *DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) - ExecuteInSession(ctx context.Context, session *DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) + ExecuteSyncInSession(ctx context.Context, session *DatabaseSession, query string) (*queryresult.SyncQueryResult, error) + ExecuteInSession(ctx context.Context, session *DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) CacheOn(context.Context) error CacheOff(context.Context) error diff --git a/db/db_common/execute.go b/db/db_common/execute.go index 7a5138a73..4140de3fc 100644 --- a/db/db_common/execute.go +++ b/db/db_common/execute.go @@ -13,7 +13,7 @@ func ExecuteQuery(ctx context.Context, queryString string, client Client) (*quer defer utils.LogTime("db.ExecuteQuery end") resultsStreamer := queryresult.NewResultStreamer() - result, err := client.Execute(ctx, queryString, false) + result, err := client.Execute(ctx, queryString) if err != nil { return nil, err } diff --git a/db/db_common/introspection_tables.go b/db/db_common/introspection_tables.go index c9c56982e..605452eba 100644 --- a/db/db_common/introspection_tables.go +++ b/db/db_common/introspection_tables.go @@ -19,25 +19,6 @@ import ( // TagColumn is the tag used to specify the column name and type in the introspection tables const TagColumn = "column" -func UpdateIntrospectionTables(workspaceResources *modconfig.WorkspaceResourceMaps, client Client) error { - utils.LogTime("db.UpdateIntrospectionTables start") - defer utils.LogTime("db.UpdateIntrospectionTables end") - - // get the create sql for each table type - clearSql := getClearTablesSql() - - // now get sql to populate the tables - insertSql := getTableInsertSql(workspaceResources) - - sql := []string{clearSql, insertSql} - // execute the query, passing 'true' to disable the spinner - _, err := client.ExecuteSync(context.Background(), strings.Join(sql, "\n"), true) - if err != nil { - return fmt.Errorf("failed to update introspection tables: %v", err) - } - return nil -} - func CreateIntrospectionTables(ctx context.Context, workspaceResources *modconfig.WorkspaceResourceMaps, session *DatabaseSession) error { utils.LogTime("db.CreateIntrospectionTables start") defer utils.LogTime("db.CreateIntrospectionTables end") diff --git a/db/db_local/install.go b/db/db_local/install.go index 84e46fd8b..c690d27fd 100644 --- a/db/db_local/install.go +++ b/db/db_local/install.go @@ -10,15 +10,14 @@ import ( "os/exec" "sync" - "github.com/briandowns/spinner" psutils "github.com/shirou/gopsutil/process" "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/db/db_common" - "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/filepaths" "github.com/turbot/steampipe/ociinstaller" "github.com/turbot/steampipe/ociinstaller/versionfile" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/utils" ) @@ -41,66 +40,57 @@ func EnsureDBInstalled(ctx context.Context) (err error) { close(doneChan) }() - spinner := display.StartSpinnerAfterDelay("", constants.SpinnerShowTimeout, doneChan) - if IsInstalled() { // check if the FDW need updating, and init the db id required - err := prepareDb(ctx, spinner) - display.StopSpinner(spinner) + err := prepareDb(ctx) return err } log.Println("[TRACE] calling removeRunningInstanceInfo") err = removeRunningInstanceInfo() if err != nil && !os.IsNotExist(err) { - display.StopSpinner(spinner) log.Printf("[TRACE] removeRunningInstanceInfo failed: %v", err) return fmt.Errorf("Cleanup any Steampipe processes... FAILED!") } log.Println("[TRACE] removing previous installation") - display.UpdateSpinnerMessage(spinner, "Prepare database install location...") + statushooks.SetStatus(ctx, "Prepare database install location...") + defer statushooks.Done(ctx) + err = os.RemoveAll(getDatabaseLocation()) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] %v", err) return fmt.Errorf("Prepare database install location... FAILED!") } - display.UpdateSpinnerMessage(spinner, "Download & install embedded PostgreSQL database...") + statushooks.SetStatus(ctx, "Download & install embedded PostgreSQL database...") _, err = ociinstaller.InstallDB(ctx, constants.DefaultEmbeddedPostgresImage, getDatabaseLocation()) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] %v", err) return fmt.Errorf("Download & install embedded PostgreSQL database... FAILED!") } - // installFDW takes care of the spinner, since it may need to run independently - _, err = installFDW(ctx, true, spinner) + _, err = installFDW(ctx, true) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] installFDW failed: %v", err) return fmt.Errorf("Download & install steampipe-postgres-fdw... FAILED!") } // run the database installation - err = runInstall(ctx, true, spinner) + err = runInstall(ctx, true) if err != nil { - display.StopSpinner(spinner) return err } // write a signature after everything gets done! // so that we can check for this later on - display.UpdateSpinnerMessage(spinner, "Updating install records...") + statushooks.SetStatus(ctx, "Updating install records...") err = updateDownloadedBinarySignature() if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] updateDownloadedBinarySignature failed: %v", err) return fmt.Errorf("Updating install records... FAILED!") } - display.StopSpinner(spinner) return nil } @@ -137,11 +127,10 @@ func IsInstalled() bool { } // prepareDb updates the FDW if needed, and inits the database if required -func prepareDb(ctx context.Context, spinner *spinner.Spinner) error { +func prepareDb(ctx context.Context) error { // check if FDW needs to be updated if fdwNeedsUpdate() { - _, err := installFDW(ctx, false, spinner) - spinner.Stop() + _, err := installFDW(ctx, false) if err != nil { log.Printf("[TRACE] installFDW failed: %v", err) return fmt.Errorf("Update steampipe-postgres-fdw... FAILED!") @@ -153,10 +142,9 @@ func prepareDb(ctx context.Context, spinner *spinner.Spinner) error { } if needsInit() { - spinner.Start() - display.UpdateSpinnerMessage(spinner, "Cleanup any Steampipe processes...") + statushooks.SetStatus(ctx, "Cleanup any Steampipe processes...") killInstanceIfAny(ctx) - if err := runInstall(ctx, false, spinner); err != nil { + if err := runInstall(ctx, false); err != nil { return err } } @@ -175,15 +163,15 @@ func fdwNeedsUpdate() bool { return versionInfo.FdwExtension.Version != constants.FdwVersion } -func installFDW(ctx context.Context, firstSetup bool, spinner *spinner.Spinner) (string, error) { +func installFDW(ctx context.Context, firstSetup bool) (string, error) { utils.LogTime("db_local.installFDW start") defer utils.LogTime("db_local.installFDW end") - status, err := GetState() + state, err := GetState() if err != nil { return "", err } - if status != nil { + if state != nil { defer func() { if !firstSetup { // update the signature @@ -191,7 +179,7 @@ func installFDW(ctx context.Context, firstSetup bool, spinner *spinner.Spinner) } }() } - display.UpdateSpinnerMessage(spinner, fmt.Sprintf("Download & install %s...", constants.Bold("steampipe-postgres-fdw"))) + statushooks.SetStatus(ctx, fmt.Sprintf("Download & install %s...", constants.Bold("steampipe-postgres-fdw"))) return ociinstaller.InstallFdw(ctx, constants.DefaultFdwImage, getDatabaseLocation()) } @@ -203,58 +191,54 @@ func needsInit() bool { return !helpers.FileExists(getPgHbaConfLocation()) } -func runInstall(ctx context.Context, firstInstall bool, spinner *spinner.Spinner) error { +func runInstall(ctx context.Context, firstInstall bool) error { utils.LogTime("db_local.runInstall start") defer utils.LogTime("db_local.runInstall end") - display.UpdateSpinnerMessage(spinner, "Cleaning up...") + statushooks.SetStatus(ctx, "Cleaning up...") + defer statushooks.Done(ctx) + err := utils.RemoveDirectoryContents(getDataLocation()) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] %v", err) return fmt.Errorf("Prepare database install location... FAILED!") } - display.UpdateSpinnerMessage(spinner, "Initializing database...") + statushooks.SetStatus(ctx, "Initializing database...") err = initDatabase() if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] initDatabase failed: %v", err) return fmt.Errorf("Initializing database... FAILED!") } - display.UpdateSpinnerMessage(spinner, "Starting database...") + statushooks.SetStatus(ctx, "Starting database...") port, err := getNextFreePort() if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] getNextFreePort failed: %v", err) return fmt.Errorf("Starting database... FAILED!") } process, err := startServiceForInstall(port) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] startServiceForInstall failed: %v", err) return fmt.Errorf("Starting database... FAILED!") } - display.UpdateSpinnerMessage(spinner, "Connection to database...") + statushooks.SetStatus(ctx, "Connection to database...") client, err := createMaintenanceClient(ctx, port) if err != nil { - display.StopSpinner(spinner) return fmt.Errorf("Connection to database... FAILED!") } defer func() { - display.UpdateSpinnerMessage(spinner, "Completing configuration") + statushooks.SetStatus(ctx, "Completing configuration") client.Close() - doThreeStepPostgresExit(process) + doThreeStepPostgresExit(ctx, process) }() - display.UpdateSpinnerMessage(spinner, "Generating database passwords...") + statushooks.SetStatus(ctx, "Generating database passwords...") // generate a password file for use later _, err = readPasswordFile() if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] readPassword failed: %v", err) return fmt.Errorf("Generating database passwords... FAILED!") } @@ -274,23 +258,21 @@ func runInstall(ctx context.Context, firstInstall bool, spinner *spinner.Spinner return fmt.Errorf("Invalid database name '%s' - must start with either a lowercase character or an underscore", databaseName) } - display.UpdateSpinnerMessage(spinner, "Configuring database...") + statushooks.SetStatus(ctx, "Configuring database...") err = installDatabaseWithPermissions(ctx, databaseName, client) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] installSteampipeDatabaseAndUser failed: %v", err) return fmt.Errorf("Configuring database... FAILED!") } - display.UpdateSpinnerMessage(spinner, "Configuring Steampipe...") + statushooks.SetStatus(ctx, "Configuring Steampipe...") err = installForeignServer(ctx, client) if err != nil { - display.StopSpinner(spinner) log.Printf("[TRACE] installForeignServer failed: %v", err) return fmt.Errorf("Configuring Steampipe... FAILED!") } - return err + return nil } func resolveDatabaseName() string { diff --git a/db/db_local/local_db_client.go b/db/db_local/local_db_client.go index 6cfa0616c..7d6b8da6b 100644 --- a/db/db_local/local_db_client.go +++ b/db/db_local/local_db_client.go @@ -13,6 +13,7 @@ import ( "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/query/queryresult" "github.com/turbot/steampipe/schema" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/steampipeconfig" "github.com/turbot/steampipe/utils" ) @@ -41,13 +42,12 @@ func GetLocalClient(ctx context.Context, invoker constants.Invoker) (db_common.C client, err := NewLocalClient(ctx, invoker) if err != nil { - ShutdownService(invoker) + ShutdownService(ctx, invoker) } return client, err } -// NewLocalClient ensures that the database instance is running -// and returns a `Client` to interact with it +// NewLocalClient verifies that the local database instance is running and returns a Client to interact with it func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbClient, error) { utils.LogTime("db.NewLocalClient start") defer utils.LogTime("db.NewLocalClient end") @@ -69,19 +69,17 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbCli // Close implements Client // close the connection to the database and shuts down the backend -func (c *LocalDbClient) Close() error { +func (c *LocalDbClient) Close(ctx context.Context) error { log.Printf("[TRACE] close local client %p", c) if c.client != nil { log.Printf("[TRACE] local client not NIL") - if err := c.client.Close(); err != nil { + if err := c.client.Close(ctx); err != nil { return err } log.Printf("[TRACE] local client close complete") } log.Printf("[TRACE] shutdown local service %v", c.invoker) - // no context to pass on - use background - // we shouldn't do this in a context that can be cancelled anyway - ShutdownService(c.invoker) + ShutdownService(ctx, c.invoker) return nil } @@ -108,23 +106,23 @@ func (c *LocalDbClient) AcquireSession(ctx context.Context) *db_common.AcquireSe } // ExecuteSync implements Client -func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) { - return c.client.ExecuteSync(ctx, query, disableSpinner) +func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) { + return c.client.ExecuteSync(ctx, query) } // ExecuteSyncInSession implements Client -func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) { - return c.client.ExecuteSyncInSession(ctx, session, query, disableSpinner) +func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) { + return c.client.ExecuteSyncInSession(ctx, session, query) } // ExecuteInSession implements Client -func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) { - return c.client.ExecuteInSession(ctx, session, query, onComplete, disableSpinner) +func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) { + return c.client.ExecuteInSession(ctx, session, query, onComplete) } // Execute implements Client -func (c *LocalDbClient) Execute(ctx context.Context, query string, disableSpinner bool) (res *queryresult.Result, err error) { - return c.client.Execute(ctx, query, disableSpinner) +func (c *LocalDbClient) Execute(ctx context.Context, query string) (res *queryresult.Result, err error) { + return c.client.Execute(ctx, query) } // CacheOn implements Client @@ -167,6 +165,9 @@ func (c *LocalDbClient) LoadForeignSchemaNames(ctx context.Context) error { // local only functions func (c *LocalDbClient) RefreshConnectionAndSearchPaths(ctx context.Context) *steampipeconfig.RefreshConnectionResult { + // NOTE: disable any status updates - we do not want 'loading' output from any queries + ctx = statushooks.DisableStatusHooks(ctx) + res := c.refreshConnections(ctx) if res.Error != nil { return res @@ -221,7 +222,7 @@ func (c *LocalDbClient) setUserSearchPath(ctx context.Context) ([]string, error) // get all roles which are a member of steampipe_users query := fmt.Sprintf(`select usename from pg_user where pg_has_role(usename, '%s', 'member')`, constants.DatabaseUsersRole) - res, err := c.ExecuteSync(context.Background(), query, true) + res, err := c.ExecuteSync(context.Background(), query) if err != nil { return nil, err } diff --git a/db/db_local/refresh_connections.go b/db/db_local/refresh_connections.go index ca975e58a..187773974 100644 --- a/db/db_local/refresh_connections.go +++ b/db/db_local/refresh_connections.go @@ -12,7 +12,7 @@ func RefreshConnectionAndSearchPaths(ctx context.Context, invoker constants.Invo if err != nil { return err } - defer client.Close() + defer client.Close(ctx) refreshResult := client.RefreshConnectionAndSearchPaths(ctx) // display any initialisation warnings refreshResult.ShowWarnings() diff --git a/db/db_local/start_services.go b/db/db_local/start_services.go index bdb837726..b356ece36 100644 --- a/db/db_local/start_services.go +++ b/db/db_local/start_services.go @@ -136,7 +136,7 @@ func startDB(ctx context.Context, port int, listen StartListenType, invoker cons // if there was an error and we started the service, stop it again if res.Error != nil { if res.Status == ServiceStarted { - StopServices(false, invoker, nil) + StopServices(ctx, false, invoker) } // remove the state file if we are going back with an error removeRunningInstanceInfo() @@ -554,7 +554,7 @@ func killInstanceIfAny(ctx context.Context) bool { for _, process := range processes { wg.Add(1) go func(p *psutils.Process) { - doThreeStepPostgresExit(p) + doThreeStepPostgresExit(ctx, p) wg.Done() }(process) } diff --git a/db/db_local/stop_services.go b/db/db_local/stop_services.go index 427ae3c46..bf96f954a 100644 --- a/db/db_local/stop_services.go +++ b/db/db_local/stop_services.go @@ -9,14 +9,12 @@ import ( "syscall" "time" - "github.com/briandowns/spinner" psutils "github.com/shirou/gopsutil/process" "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/constants/runtime" - "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/filepaths" "github.com/turbot/steampipe/pluginmanager" - + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/utils" ) @@ -32,7 +30,7 @@ const ( ) // ShutdownService stops the database instance if the given 'invoker' matches -func ShutdownService(invoker constants.Invoker) { +func ShutdownService(ctx context.Context, invoker constants.Invoker) { utils.LogTime("db_local.ShutdownService start") defer utils.LogTime("db_local.ShutdownService end") @@ -54,18 +52,18 @@ func ShutdownService(invoker constants.Invoker) { } // we can shut down the database - stopStatus, err := StopServices(false, invoker, nil) + stopStatus, err := StopServices(ctx, false, invoker) if err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) } if stopStatus == ServiceStopped { return } // shutdown failed - try to force stop - _, err = StopServices(true, invoker, nil) + _, err = StopServices(ctx, true, invoker) if err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) } } @@ -92,7 +90,8 @@ func GetCountOfThirdPartyClients(ctx context.Context) (i int, e error) { } // StopServices searches for and stops the running instance. Does nothing if an instance was not found -func StopServices(force bool, invoker constants.Invoker, spinner *spinner.Spinner) (status StopStatus, e error) { +func StopServices(ctx context.Context, force bool, invoker constants.Invoker) (status StopStatus, e error) { + log.Printf("[TRACE] StopDB invoker %s, force %v", invoker, force) utils.LogTime("db_local.StopDB start") @@ -108,15 +107,16 @@ func StopServices(force bool, invoker constants.Invoker, spinner *spinner.Spinne pluginManagerStopError := pluginmanager.Stop() // stop the DB Service - stopResult, dbStopError := stopDBService(spinner, force) + stopResult, dbStopError := stopDBService(ctx, force) return stopResult, utils.CombineErrors(dbStopError, pluginManagerStopError) } -func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) { +func stopDBService(ctx context.Context, force bool) (StopStatus, error) { if force { // check if we have a process from another install-dir - display.UpdateSpinnerMessage(spinner, "Checking for running instances...") + statushooks.SetStatus(ctx, "Checking for running instances...") + defer statushooks.Done(ctx) // do not use a context that can be cancelled killInstanceIfAny(context.Background()) return ServiceStopped, nil @@ -139,9 +139,7 @@ func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) { return ServiceStopFailed, err } - display.UpdateSpinnerMessage(spinner, "Shutting down...") - - err = doThreeStepPostgresExit(process) + err = doThreeStepPostgresExit(ctx, process) if err != nil { // we couldn't stop it still. // timeout @@ -176,7 +174,7 @@ func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) { checked that the service can indeed shutdown gracefully, the sequence is there only as a backup. **/ -func doThreeStepPostgresExit(process *psutils.Process) error { +func doThreeStepPostgresExit(ctx context.Context, process *psutils.Process) error { utils.LogTime("db_local.doThreeStepPostgresExit start") defer utils.LogTime("db_local.doThreeStepPostgresExit end") @@ -191,6 +189,11 @@ func doThreeStepPostgresExit(process *psutils.Process) error { exitSuccessful = waitForProcessExit(process, 2*time.Second) if !exitSuccessful { // process didn't quit + + // set status, as this is taking time + statushooks.SetStatus(ctx, "Shutting down...") + defer statushooks.Done(ctx) + // try a SIGINT err = process.SendSignal(syscall.SIGINT) if err != nil { diff --git a/display/display.go b/display/display.go index c4f31c9c2..61b57addc 100644 --- a/display/display.go +++ b/display/display.go @@ -2,6 +2,7 @@ package display import ( "bytes" + "context" "encoding/csv" "encoding/json" "fmt" @@ -20,17 +21,17 @@ import ( ) // ShowOutput :: displays the output using the proper formatter as applicable -func ShowOutput(result *queryresult.Result) { +func ShowOutput(ctx context.Context, result *queryresult.Result) { output := cmdconfig.Viper().GetString(constants.ArgOutput) if output == constants.OutputFormatJSON { - displayJSON(result) + displayJSON(ctx, result) } else if output == constants.OutputFormatCSV { - displayCSV(result) + displayCSV(ctx, result) } else if output == constants.OutputFormatLine { - displayLine(result) + displayLine(ctx, result) } else { // default - displayTable(result) + displayTable(ctx, result) } } @@ -102,7 +103,7 @@ func getColumnSettings(headers []string, rows [][]string) ([]table.ColumnConfig, return colConfigs, headerRow } -func displayLine(result *queryresult.Result) { +func displayLine(ctx context.Context, result *queryresult.Result) { colNames := ColumnNames(result.ColTypes) maxColNameLength := 0 for _, colName := range colNames { @@ -158,7 +159,7 @@ func displayLine(result *queryresult.Result) { // call this function for each row if err := iterateResults(result, rowFunc); err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) return } } @@ -173,7 +174,7 @@ func getTerminalColumnsRequiredForString(str string) int { return colsRequired } -func displayJSON(result *queryresult.Result) { +func displayJSON(ctx context.Context, result *queryresult.Result) { var jsonOutput []map[string]interface{} // define function to add each row to the JSON output @@ -188,7 +189,7 @@ func displayJSON(result *queryresult.Result) { // call this function for each row if err := iterateResults(result, rowFunc); err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) return } // display the JSON @@ -202,7 +203,7 @@ func displayJSON(result *queryresult.Result) { fmt.Println() } -func displayCSV(result *queryresult.Result) { +func displayCSV(ctx context.Context, result *queryresult.Result) { csvWriter := csv.NewWriter(os.Stdout) csvWriter.Comma = []rune(cmdconfig.Viper().GetString(constants.ArgSeparator))[0] @@ -219,17 +220,17 @@ func displayCSV(result *queryresult.Result) { // call this function for each row if err := iterateResults(result, rowFunc); err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) return } csvWriter.Flush() if csvWriter.Error() != nil { - utils.ShowErrorWithMessage(csvWriter.Error(), "unable to print csv") + utils.ShowErrorWithMessage(ctx, csvWriter.Error(), "unable to print csv") } } -func displayTable(result *queryresult.Result) { +func displayTable(ctx context.Context, result *queryresult.Result) { // the buffer to put the output data in outbuf := bytes.NewBufferString("") @@ -271,14 +272,14 @@ func displayTable(result *queryresult.Result) { if err != nil { // display the error fmt.Println() - utils.ShowError(err) + utils.ShowError(ctx, err) fmt.Println() } // write out the table to the buffer t.Render() // page out the table - ShowPaged(outbuf.String()) + ShowPaged(ctx, outbuf.String()) // if timer is turned on if cmdconfig.Viper().GetBool(constants.ArgTimer) { diff --git a/display/pager.go b/display/pager.go index 8dabc51e1..d91767abe 100644 --- a/display/pager.go +++ b/display/pager.go @@ -2,6 +2,7 @@ package display import ( "bufio" + "context" "fmt" "os" "os/exec" @@ -15,9 +16,9 @@ import ( ) // ShowPaged displays the `content` in a system dependent pager -func ShowPaged(content string) { +func ShowPaged(ctx context.Context, content string) { if isPagerNeeded(content) && (runtime.GOOS == "darwin" || runtime.GOOS == "linux") { - nixPager(content) + nixPager(ctx, content) } else { nullPager(content) } @@ -59,11 +60,11 @@ func nullPager(content string) { fmt.Print(content) } -func nixPager(content string) { +func nixPager(ctx context.Context, content string) { if isLessAvailable() { - execPager(exec.Command("less", "-SRXF"), content) + execPager(ctx, exec.Command("less", "-SRXF"), content) } else if isMoreAvailable() { - execPager(exec.Command("more"), content) + execPager(ctx, exec.Command("more"), content) } else { nullPager(content) } @@ -79,13 +80,13 @@ func isMoreAvailable() bool { return err == nil } -func execPager(cmd *exec.Cmd, content string) { +func execPager(ctx context.Context, cmd *exec.Cmd, content string) { cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Stdin = strings.NewReader(content) // run the command - it will block until the pager is exited err := cmd.Run() if err != nil { - utils.ShowErrorWithMessage(err, "could not display results") + utils.ShowErrorWithMessage(ctx, err, "could not display results") } } diff --git a/display/spinner.go b/display/spinner.go deleted file mode 100644 index 76e454cc3..000000000 --- a/display/spinner.go +++ /dev/null @@ -1,124 +0,0 @@ -package display - -import ( - "fmt" - "os" - "strings" - "time" - - "github.com/briandowns/spinner" - "github.com/karrick/gows" - "github.com/spf13/viper" - "github.com/turbot/steampipe/constants" -) - -// -// spinner format: -// -// 1 1 [.......] 1 1 1 1 1 -// We need at least seven characters to show the spinner properly -// -// Not using the (…) character, since it is too small -// -const minSpinnerWidth = 7 - -func truncateSpinnerMessageToScreen(msg string) string { - if len(strings.TrimSpace(msg)) == 0 { - // if this is a blank message, return it as is - return msg - } - - maxCols, _, _ := gows.GetWinSize() - // if the screen is smaller than the minimum spinner width, we cannot truncate - if maxCols < minSpinnerWidth { - return msg - } - availableColumns := maxCols - minSpinnerWidth - if len(msg) > availableColumns { - msg = msg[:availableColumns] - msg = fmt.Sprintf("%s ...", msg) - } - return msg -} - -// StartSpinnerAfterDelay shows the spinner with the given `msg` if and only if `cancelStartIf` resolves -// after `delay`. -// -// Example: if delay is 2 seconds and `cancelStartIf` resolves after 2.5 seconds, the spinner -// will show for 0.5 seconds. If `cancelStartIf` resolves after 1.5 seconds, the spinner will -// NOT be shown at all -// -func StartSpinnerAfterDelay(msg string, delay time.Duration, cancelStartIf chan bool) *spinner.Spinner { - if !viper.GetBool(constants.ConfigKeyIsTerminalTTY) { - return nil - } - - msg = truncateSpinnerMessageToScreen(msg) - spinner := spinner.New( - spinner.CharSets[14], - 100*time.Millisecond, - spinner.WithHiddenCursor(true), - spinner.WithWriter(os.Stdout), - spinner.WithSuffix(fmt.Sprintf(" %s", msg)), - ) - - go func() { - select { - case <-cancelStartIf: - case <-time.After(delay): - if spinner != nil && !spinner.Active() { - spinner.Start() - } - } - time.Sleep(50 * time.Millisecond) - }() - - return spinner -} - -// ShowSpinner shows a spinner with the given message -func ShowSpinner(msg string) *spinner.Spinner { - if !viper.GetBool(constants.ConfigKeyIsTerminalTTY) { - return nil - } - - msg = truncateSpinnerMessageToScreen(msg) - s := spinner.New( - spinner.CharSets[14], - 100*time.Millisecond, - spinner.WithHiddenCursor(true), - spinner.WithWriter(os.Stdout), - spinner.WithSuffix(fmt.Sprintf(" %s", msg)), - ) - s.Start() - return s -} - -// StopSpinnerWithMessage stops a spinner instance and clears it, after writing `finalMsg` -func StopSpinnerWithMessage(spinner *spinner.Spinner, finalMsg string) { - if spinner != nil { - spinner.FinalMSG = finalMsg - spinner.Stop() - } -} - -// StopSpinner stops a spinner instance and clears it -func StopSpinner(spinner *spinner.Spinner) { - if spinner != nil { - spinner.Stop() - } -} - -func ResumeSpinner(spinner *spinner.Spinner) { - if spinner != nil && !spinner.Active() { - spinner.Restart() - } -} - -// UpdateSpinnerMessage updates the message of the given spinner -func UpdateSpinnerMessage(spinner *spinner.Spinner, newMessage string) { - if spinner != nil { - newMessage = truncateSpinnerMessageToScreen(newMessage) - spinner.Suffix = fmt.Sprintf(" %s", newMessage) - } -} diff --git a/executionlayer/execute.go b/executionlayer/execute.go index fa0c07fed..7015ad981 100644 --- a/executionlayer/execute.go +++ b/executionlayer/execute.go @@ -3,6 +3,8 @@ package executionlayer import ( "context" + "github.com/turbot/steampipe/statushooks" + "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/report/reportevents" "github.com/turbot/steampipe/report/reportexecute" @@ -11,6 +13,9 @@ import ( ) func ExecuteReportNode(ctx context.Context, reportName string, workspace *workspace.Workspace, client db_common.Client) error { + // create context for the report execution + // (for now just disable all status messages - replace with event based? ) + reportCtx := statushooks.DisableStatusHooks(ctx) executionTree, err := reportexecute.NewReportExecutionTree(reportName, client, workspace) if err != nil { return err @@ -19,7 +24,7 @@ func ExecuteReportNode(ctx context.Context, reportName string, workspace *worksp go func() { workspace.PublishReportEvent(&reportevents.ExecutionStarted{ReportNode: executionTree.Root}) - if err := executionTree.Execute(ctx); err != nil { + if err := executionTree.Execute(reportCtx); err != nil { if executionTree.Root.GetRunStatus() == reportinterfaces.ReportRunError { // set error state on the root node executionTree.Root.SetError(err) diff --git a/interactive/interactive_client.go b/interactive/interactive_client.go index 415bf5599..b48204459 100644 --- a/interactive/interactive_client.go +++ b/interactive/interactive_client.go @@ -17,13 +17,14 @@ import ( "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe/cmdconfig" "github.com/turbot/steampipe/constants" + "github.com/turbot/steampipe/contexthelpers" "github.com/turbot/steampipe/db/db_common" - "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/query" "github.com/turbot/steampipe/query/metaquery" "github.com/turbot/steampipe/query/queryhistory" "github.com/turbot/steampipe/query/queryresult" "github.com/turbot/steampipe/schema" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/steampipeconfig" "github.com/turbot/steampipe/utils" "github.com/turbot/steampipe/version" @@ -54,7 +55,11 @@ type InteractiveClient struct { // lock while execution is occurring to avoid errors/warnings being shown executionLock sync.Mutex schemaMetadata *schema.Metadata - highlighter *Highlighter + + highlighter *Highlighter + + // status update hooks + statusHook statushooks.StatusHooks } func getHighlighter(theme string) *Highlighter { @@ -65,7 +70,7 @@ func getHighlighter(theme string) *Highlighter { ) } -func newInteractiveClient(initData *query.InitData, resultsStreamer *queryresult.ResultStreamer) (*InteractiveClient, error) { +func newInteractiveClient(ctx context.Context, initData *query.InitData, resultsStreamer *queryresult.ResultStreamer) (*InteractiveClient, error) { c := &InteractiveClient{ initData: initData, resultsStreamer: resultsStreamer, @@ -75,30 +80,33 @@ func newInteractiveClient(initData *query.InitData, resultsStreamer *queryresult initResultChan: make(chan *db_common.InitResult, 1), highlighter: getHighlighter(viper.GetString(constants.ArgTheme)), } + // asynchronously wait for init to complete // we start this immediately rather than lazy loading as we want to handle errors asap - go c.readInitDataStream() + go c.readInitDataStream(ctx) return c, nil } // InteractivePrompt starts an interactive prompt and return -func (c *InteractiveClient) InteractivePrompt() { +func (c *InteractiveClient) InteractivePrompt(ctx context.Context) { // start a cancel handler for the interactive client - this will call activeQueryCancelFunc if it is set // (registered when we call createQueryContext) - interruptSignalChannel := c.startCancelHandler() + interruptSignalChannel := contexthelpers.StartCancelHandler(c.cancelActiveQueryIfAny) + // create a cancel context for the prompt - this will set c.cancelPrompt - promptCtx := c.createPromptContext() + parentContext := ctx + ctx = c.createPromptContext(parentContext) defer func() { if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } // close up the SIGINT channel so that the receiver goroutine can quit signal.Stop(interruptSignalChannel) close(interruptSignalChannel) // cleanup the init data to ensure any services we started are stopped - c.initData.Cleanup() + c.initData.Cleanup(ctx) // close the result stream // this needs to be the last thing we do, @@ -106,18 +114,19 @@ func (c *InteractiveClient) InteractivePrompt() { c.resultsStreamer.Close() }() - fmt.Printf("Welcome to Steampipe v%s\n", version.SteampipeVersion.String()) - fmt.Printf("For more information, type %s\n", constants.Bold(".help")) + statushooks.Message(ctx, + fmt.Sprintf("Welcome to Steampipe v%s", version.SteampipeVersion.String()), + fmt.Sprintf("For more information, type %s", constants.Bold(".help"))) // run the prompt in a goroutine, so we can also detect async initialisation errors promptResultChan := make(chan utils.InteractiveExitStatus, 1) - c.runInteractivePromptAsync(promptCtx, &promptResultChan) + c.runInteractivePromptAsync(ctx, &promptResultChan) // select results for { select { case initResult := <-c.initResultChan: - c.handleInitResult(promptCtx, initResult) + c.handleInitResult(ctx, initResult) // if there was an error, handleInitResult will shut down the prompt // - we must wait for it to shut down and not return immediately @@ -129,9 +138,9 @@ func (c *InteractiveClient) InteractivePrompt() { return } // create new context - promptCtx = c.createPromptContext() + ctx = c.createPromptContext(parentContext) // now run it again - c.runInteractivePromptAsync(promptCtx, &promptResultChan) + c.runInteractivePromptAsync(ctx, &promptResultChan) } } } @@ -184,7 +193,7 @@ func (c *InteractiveClient) handleInitResult(ctx context.Context, initResult *db c.ClosePrompt(AfterPromptCloseExit) // add newline to ensure error is not printed at end of current prompt line fmt.Println() - utils.ShowError(initResult.Error) + utils.ShowError(ctx, initResult.Error) return } @@ -231,7 +240,7 @@ func (c *InteractiveClient) runInteractivePrompt(ctx context.Context) (ret utils }() callExecutor := func(line string) { - c.executor(line) + c.executor(ctx, line) } completer := func(d prompt.Document) []prompt.Suggest { return c.queryCompleter(d) @@ -333,7 +342,7 @@ func (c *InteractiveClient) breakMultilinePrompt(buffer *prompt.Buffer) { c.interactiveBuffer = []string{} } -func (c *InteractiveClient) executor(line string) { +func (c *InteractiveClient) executor(ctx context.Context, line string) { // take an execution lock, so that errors and warnings don't show up while // we are underway c.executionLock.Lock() @@ -347,10 +356,10 @@ func (c *InteractiveClient) executor(line string) { // we want to store even if we fail to resolve a query c.interactiveQueryHistory.Push(line) - query, err := c.getQuery(line) + query, err := c.getQuery(ctx, line) if query == "" { if err != nil { - utils.ShowError(utils.HandleCancelError(err)) + utils.ShowError(ctx, utils.HandleCancelError(err)) } // restart the prompt c.restartInteractiveSession() @@ -358,20 +367,20 @@ func (c *InteractiveClient) executor(line string) { } // create a context for the execution of the query - queryContext := c.createQueryContext() + queryContext := c.createQueryContext(ctx) if metaquery.IsMetaQuery(query) { if err := c.executeMetaquery(queryContext, query); err != nil { - utils.ShowError(err) + utils.ShowError(ctx, err) } // cancel the context c.cancelActiveQueryIfAny() } else { // otherwise execute query - result, err := c.client().Execute(queryContext, query, false) + result, err := c.client().Execute(queryContext, query) if err != nil { - utils.ShowError(utils.HandleCancelError(err)) + utils.ShowError(ctx, utils.HandleCancelError(err)) } else { c.resultsStreamer.StreamResult(result) } @@ -381,7 +390,7 @@ func (c *InteractiveClient) executor(line string) { c.restartInteractiveSession() } -func (c *InteractiveClient) getQuery(line string) (string, error) { +func (c *InteractiveClient) getQuery(ctx context.Context, line string) (string, error) { // if it's an empty line, then we don't need to do anything if line == "" { return "", nil @@ -391,23 +400,20 @@ func (c *InteractiveClient) getQuery(line string) (string, error) { if !c.isInitialised() { // create a context used purely to detect cancellation during initialisation // this will also set c.cancelActiveQuery - queryContext := c.createQueryContext() + queryContext := c.createQueryContext(ctx) defer func() { // cancel this context c.cancelActiveQueryIfAny() }() - initDoneChan := make(chan bool) - sp := display.StartSpinnerAfterDelay("Initializing...", constants.SpinnerShowTimeout, initDoneChan) + statushooks.SetStatus(ctx, "Initializing...") // wait for client initialisation to complete - if err := c.waitForInitData(queryContext); err != nil { + err := c.waitForInitData(queryContext) + statushooks.Done(ctx) + if err != nil { // if it failed, report error and quit - close(initDoneChan) - display.StopSpinner(sp) return "", err } - close(initDoneChan) - display.StopSpinner(sp) } // push the current line into the buffer @@ -420,7 +426,7 @@ func (c *InteractiveClient) getQuery(line string) (string, error) { query, _, err := c.workspace().ResolveQueryAndArgs(queryString) if err != nil { // if we fail to resolve, show error but do not return it - we want to stay in the prompt - utils.ShowError(err) + utils.ShowError(ctx, err) return "", nil } isNamedQuery := query != queryString diff --git a/interactive/interactive_client_cancel.go b/interactive/interactive_client_cancel.go index 9357a1837..35831d750 100644 --- a/interactive/interactive_client_cancel.go +++ b/interactive/interactive_client_cancel.go @@ -2,31 +2,21 @@ package interactive import ( "context" - "os" - "os/signal" - "syscall" ) -func (c *InteractiveClient) startCancelHandler() chan os.Signal { - interruptSignalChannel := make(chan os.Signal, 10) - signal.Notify(interruptSignalChannel, syscall.SIGINT, syscall.SIGTERM) - go func() { - for range interruptSignalChannel { - c.cancelActiveQueryIfAny() - } - }() - return interruptSignalChannel -} - // create a cancel context for the interactive prompt, and set c.cancelFunc -func (c *InteractiveClient) createPromptContext() context.Context { - ctx, cancel := context.WithCancel(context.Background()) +func (c *InteractiveClient) createPromptContext(parentContext context.Context) context.Context { + // ensure previous prompt is cleaned up + if c.cancelPrompt != nil { + c.cancelPrompt() + } + ctx, cancel := context.WithCancel(parentContext) c.cancelPrompt = cancel return ctx } -func (c *InteractiveClient) createQueryContext() context.Context { - ctx, cancel := context.WithCancel(context.Background()) +func (c *InteractiveClient) createQueryContext(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) c.cancelActiveQuery = cancel return ctx } diff --git a/interactive/interactive_client_init.go b/interactive/interactive_client_init.go index 1c7e57cdb..9f7266833 100644 --- a/interactive/interactive_client_init.go +++ b/interactive/interactive_client_init.go @@ -16,11 +16,11 @@ import ( var initTimeout = 40 * time.Second -func (c *InteractiveClient) readInitDataStream() { +func (c *InteractiveClient) readInitDataStream(ctx context.Context) { defer func() { if r := recover(); r != nil { c.interactivePrompt.ClearScreen() - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } }() @@ -40,15 +40,15 @@ func (c *InteractiveClient) readInitDataStream() { // start the workspace file watcher if viper.GetBool(constants.ArgWatch) { // provide an explicit error handler which re-renders the prompt after displaying the error - c.initData.Result.Error = c.initData.Workspace.SetupWatcher(c.initData.Client, c.workspaceWatcherErrorHandler) + c.initData.Result.Error = c.initData.Workspace.SetupWatcher(ctx, c.initData.Client, c.workspaceWatcherErrorHandler) } c.initResultChan <- c.initData.Result } -func (c *InteractiveClient) workspaceWatcherErrorHandler(err error) { +func (c *InteractiveClient) workspaceWatcherErrorHandler(ctx context.Context, err error) { fmt.Println() - utils.ShowError(err) + utils.ShowError(ctx, err) c.interactivePrompt.Render() } diff --git a/interactive/run.go b/interactive/run.go index b610e1731..6ca3f9521 100644 --- a/interactive/run.go +++ b/interactive/run.go @@ -1,6 +1,8 @@ package interactive import ( + "context" + "github.com/turbot/steampipe/constants" "github.com/turbot/steampipe/db/db_local" "github.com/turbot/steampipe/query" @@ -9,19 +11,19 @@ import ( ) // RunInteractivePrompt starts the interactive query prompt -func RunInteractivePrompt(initData *query.InitData) (*queryresult.ResultStreamer, error) { +func RunInteractivePrompt(ctx context.Context, initData *query.InitData) (*queryresult.ResultStreamer, error) { resultsStreamer := queryresult.NewResultStreamer() - interactiveClient, err := newInteractiveClient(initData, resultsStreamer) + interactiveClient, err := newInteractiveClient(ctx, initData, resultsStreamer) if err != nil { - utils.ShowErrorWithMessage(err, "interactive client failed to initialize") + utils.ShowErrorWithMessage(ctx, err, "interactive client failed to initialize") // do not bind shutdown to any cancellable context - db_local.ShutdownService(constants.InvokerQuery) + db_local.ShutdownService(ctx, constants.InvokerQuery) return nil, err } // start the interactive prompt in a go routine - go interactiveClient.InteractivePrompt() + go interactiveClient.InteractivePrompt(ctx) return resultsStreamer, nil } diff --git a/main.go b/main.go index 20d42c16d..ac838f243 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" @@ -17,11 +18,12 @@ var Logger hclog.Logger var exitCode int func main() { + ctx := context.Background() utils.LogTime("main start") exitCode := 0 defer func() { if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } utils.LogTime("main end") utils.DisplayProfileData() @@ -29,7 +31,7 @@ func main() { }() // ensure steampipe is not being run as root - checkRoot() + checkRoot(ctx) // increase the soft ULIMIT to match the hard limit err := setULimit() @@ -60,10 +62,10 @@ func setULimit() error { // this is to replicate the user security mechanism of out underlying // postgresql engine. -func checkRoot() { +func checkRoot(ctx context.Context) { if os.Geteuid() == 0 { exitCode = 1 - utils.ShowError(fmt.Errorf(`Steampipe cannot be run as the "root" user. + utils.ShowError(ctx, fmt.Errorf(`Steampipe cannot be run as the "root" user. To reduce security risk, use an unprivileged user account instead.`)) os.Exit(exitCode) } @@ -79,7 +81,7 @@ To reduce security risk, use an unprivileged user account instead.`)) if os.Geteuid() != os.Getuid() { exitCode = 1 - utils.ShowError(fmt.Errorf("real and effective user IDs must match.")) + utils.ShowError(ctx, fmt.Errorf("real and effective user IDs must match.")) os.Exit(exitCode) } } diff --git a/modinstaller/uninstall.go b/modinstaller/uninstall.go index 381aec744..9f8ea6cbf 100644 --- a/modinstaller/uninstall.go +++ b/modinstaller/uninstall.go @@ -1,16 +1,18 @@ package modinstaller import ( + "context" + "github.com/turbot/go-kit/helpers" "github.com/turbot/steampipe/utils" ) -func UninstallWorkspaceDependencies(opts *InstallOpts) (*InstallData, error) { +func UninstallWorkspaceDependencies(ctx context.Context, opts *InstallOpts) (*InstallData, error) { utils.LogTime("cmd.UninstallWorkspaceDependencies") defer func() { utils.LogTime("cmd.UninstallWorkspaceDependencies end") if r := recover(); r != nil { - utils.ShowError(helpers.ToError(r)) + utils.ShowError(ctx, helpers.ToError(r)) } }() diff --git a/plugin/actions.go b/plugin/actions.go index 2b9e45d20..b2e80cdb3 100644 --- a/plugin/actions.go +++ b/plugin/actions.go @@ -7,10 +7,10 @@ import ( "path/filepath" "strings" - "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/filepaths" "github.com/turbot/steampipe/ociinstaller" "github.com/turbot/steampipe/ociinstaller/versionfile" + "github.com/turbot/steampipe/statushooks" "github.com/turbot/steampipe/steampipeconfig/modconfig" "github.com/turbot/steampipe/utils" ) @@ -22,9 +22,9 @@ const ( ) // Remove removes an installed plugin -func Remove(image string, pluginConnections map[string][]modconfig.Connection) error { - spinner := display.ShowSpinner(fmt.Sprintf("Removing plugin %s", image)) - defer display.StopSpinner(spinner) +func Remove(ctx context.Context, image string, pluginConnections map[string][]modconfig.Connection) error { + statushooks.SetStatus(ctx, fmt.Sprintf("Removing plugin %s", image)) + defer statushooks.Done(ctx) fullPluginName := ociinstaller.NewSteampipeImageRef(image).DisplayImageRef() @@ -60,7 +60,7 @@ func Remove(image string, pluginConnections map[string][]modconfig.Connection) e connFiles := Unique(files) if len(connFiles) > 0 { - display.StopSpinner(spinner) + str := []string{fmt.Sprintf("\nUninstalled plugin %s\n\nNote: the following %s %s %s steampipe %s using the '%s' plugin:", image, utils.Pluralize("file", len(connFiles)), utils.Pluralize("has", len(connFiles)), utils.Pluralize("a", len(conns)), utils.Pluralize("connection", len(conns)), image)} for _, file := range connFiles { str = append(str, fmt.Sprintf("\n \t* file: %s", file)) @@ -78,7 +78,7 @@ func Remove(image string, pluginConnections map[string][]modconfig.Connection) e } } str = append(str, fmt.Sprintf("\nPlease remove %s to continue using steampipe", utils.Pluralize("it", len(connFiles)))) - fmt.Println(strings.Join(str, "\n")) + statushooks.Message(ctx, str...) fmt.Println() } return err diff --git a/query/init_data.go b/query/init_data.go index 2e35375c2..6841241db 100644 --- a/query/init_data.go +++ b/query/init_data.go @@ -37,7 +37,7 @@ func NewInitData(ctx context.Context, w *workspace.Workspace, args []string) *In return i } -func (i *InitData) Cleanup() { +func (i *InitData) Cleanup(ctx context.Context) { // cancel any ongoing operation if i.cancel != nil { i.cancel() @@ -50,7 +50,7 @@ func (i *InitData) Cleanup() { // if a client was initialised, close it if i.Client != nil { - i.Client.Close() + i.Client.Close(ctx) } } diff --git a/query/queryexecute/execute.go b/query/queryexecute/execute.go index 3f59bee7a..13742bb5d 100644 --- a/query/queryexecute/execute.go +++ b/query/queryexecute/execute.go @@ -6,6 +6,7 @@ import ( "github.com/spf13/viper" "github.com/turbot/steampipe/constants" + "github.com/turbot/steampipe/contexthelpers" "github.com/turbot/steampipe/db/db_common" "github.com/turbot/steampipe/display" "github.com/turbot/steampipe/interactive" @@ -13,34 +14,36 @@ import ( "github.com/turbot/steampipe/utils" ) -func RunInteractiveSession(initData *query.InitData) { +func RunInteractiveSession(ctx context.Context, initData *query.InitData) { utils.LogTime("execute.RunInteractiveSession start") defer utils.LogTime("execute.RunInteractiveSession end") // the db executor sends result data over resultsStreamer - resultsStreamer, err := interactive.RunInteractivePrompt(initData) + resultsStreamer, err := interactive.RunInteractivePrompt(ctx, initData) utils.FailOnError(err) // print the data as it comes for r := range resultsStreamer.Results { - display.ShowOutput(r) + display.ShowOutput(ctx, r) // signal to the resultStreamer that we are done with this chunk of the stream resultsStreamer.AllResultsRead() } } func RunBatchSession(ctx context.Context, initData *query.InitData) int { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + + // start cancel handler to intercept interurpts and cancel the context + contexthelpers.StartCancelHandler(cancel) + // wait for init <-initData.Loaded if err := initData.Result.Error; err != nil { utils.FailOnError(err) } // ensure we close client - defer func() { - if initData.Client != nil { - initData.Client.Close() - } - }() + defer initData.Cleanup(ctx) // display any initialisation messages/warnings initData.Result.DisplayMessages() @@ -86,7 +89,7 @@ func executeQuery(ctx context.Context, queryString string, client db_common.Clie // print the data as it comes for r := range resultsStreamer.Results { - display.ShowOutput(r) + display.ShowOutput(ctx, r) // signal to the resultStreamer that we are done with this result resultsStreamer.AllResultsRead() } diff --git a/report/reportexecute/report_execution_tree.go b/report/reportexecute/report_execution_tree.go index 1b3070412..514b46447 100644 --- a/report/reportexecute/report_execution_tree.go +++ b/report/reportexecute/report_execution_tree.go @@ -166,7 +166,7 @@ func (e *ReportExecutionTree) ExecuteNode(ctx context.Context, name string) erro } func (e *ReportExecutionTree) executePanelSQL(ctx context.Context, query string) ([][]interface{}, error) { - queryResult, err := e.client.ExecuteSync(ctx, query, true) + queryResult, err := e.client.ExecuteSync(ctx, query) if err != nil { return nil, err } diff --git a/report/reportserver/server.go b/report/reportserver/server.go index 4148f6fe9..59fb0a1c8 100644 --- a/report/reportserver/server.go +++ b/report/reportserver/server.go @@ -46,7 +46,7 @@ type ReportClientInfo struct { } func NewServer(ctx context.Context) (*Server, error) { - dbClient, err := db_local.GetLocalClient(ctx, constants.InvokerReport) + var dbClient, err = db_local.GetLocalClient(ctx, constants.InvokerReport) if err != nil { return nil, err } @@ -57,7 +57,7 @@ func NewServer(ctx context.Context) (*Server, error) { } refreshResult.ShowWarnings() - loadedWorkspace, err := workspace.Load(viper.GetString(constants.ArgWorkspaceChDir)) + loadedWorkspace, err := workspace.Load(ctx, viper.GetString(constants.ArgWorkspaceChDir)) if err != nil { return nil, err } @@ -78,7 +78,7 @@ func NewServer(ctx context.Context) (*Server, error) { } loadedWorkspace.RegisterReportEventHandler(server.HandleWorkspaceUpdate) - err = loadedWorkspace.SetupWatcher(dbClient, nil) + err = loadedWorkspace.SetupWatcher(ctx, dbClient, nil) return server, err } @@ -129,10 +129,10 @@ func (s *Server) Start() { StartAPI(s.context, s.webSocket) } -func (s *Server) Shutdown() { +func (s *Server) Shutdown(ctx context.Context) { // Close the DB client if s.dbClient != nil { - s.dbClient.Close() + s.dbClient.Close(ctx) } if s.webSocket != nil { diff --git a/statushooks/context.go b/statushooks/context.go new file mode 100644 index 000000000..eb695e51a --- /dev/null +++ b/statushooks/context.go @@ -0,0 +1,42 @@ +package statushooks + +import ( + "context" + + "github.com/turbot/steampipe/contexthelpers" +) + +var ( + contextKeyStatusHook = contexthelpers.ContextKey("status_hook") +) + +func DisableStatusHooks(ctx context.Context) context.Context { + return AddStatusHooksToContext(ctx, NullHooks) +} + +func AddStatusHooksToContext(ctx context.Context, statusHooks StatusHooks) context.Context { + return context.WithValue(ctx, contextKeyStatusHook, statusHooks) +} + +func StatusHooksFromContext(ctx context.Context) StatusHooks { + if ctx == nil { + return NullHooks + } + if val, ok := ctx.Value(contextKeyStatusHook).(StatusHooks); ok { + return val + } + // no status hook in context - return null status hook + return NullHooks +} + +func SetStatus(ctx context.Context, msg string) { + StatusHooksFromContext(ctx).SetStatus(msg) +} + +func Done(ctx context.Context) { + StatusHooksFromContext(ctx).Done() +} + +func Message(ctx context.Context, msgs ...string) { + StatusHooksFromContext(ctx).Message(msgs...) +} diff --git a/statushooks/status_hooks.go b/statushooks/status_hooks.go new file mode 100644 index 000000000..12dfaf4aa --- /dev/null +++ b/statushooks/status_hooks.go @@ -0,0 +1,15 @@ +package statushooks + +type StatusHooks interface { + SetStatus(string) + Done() + Message(...string) +} + +var NullHooks = &NullStatusHook{} + +type NullStatusHook struct{} + +func (*NullStatusHook) SetStatus(string) {} +func (*NullStatusHook) Done() {} +func (*NullStatusHook) Message(...string) {} diff --git a/statusspinner/spinner.go b/statusspinner/spinner.go new file mode 100644 index 000000000..ad953567e --- /dev/null +++ b/statusspinner/spinner.go @@ -0,0 +1,137 @@ +package statusspinner + +import ( + "fmt" + "os" + "strings" + "time" + + "github.com/briandowns/spinner" + "github.com/karrick/gows" +) + +// +// spinner format: +// +// 1 1 [.......] 1 1 1 1 1 +// We need at least seven characters to show the spinner properly +// +// Not using the (…) character, since it is too small +// +const minSpinnerWidth = 7 + +// StatusSpinner is a struct which implements StatusHooks, and uses a spinner to display status messages +type StatusSpinner struct { + spinner *spinner.Spinner + delay time.Duration + cancel chan struct{} +} + +type StatusSpinnerOpt func(*StatusSpinner) + +func WithMessage(msg string) StatusSpinnerOpt { + return func(s *StatusSpinner) { + s.UpdateSpinnerMessage(msg) + } +} + +func WithDelay(delay time.Duration) StatusSpinnerOpt { + return func(s *StatusSpinner) { + s.delay = delay + } +} + +func NewStatusSpinner(opts ...StatusSpinnerOpt) *StatusSpinner { + res := &StatusSpinner{} + + res.spinner = spinner.New( + spinner.CharSets[14], + 100*time.Millisecond, + spinner.WithHiddenCursor(true), + spinner.WithWriter(os.Stdout), + ) + for _, opt := range opts { + opt(res) + } + + return res +} + +// SetStatus implements StatusHooks +func (s *StatusSpinner) SetStatus(msg string) { + s.UpdateSpinnerMessage(msg) + if !s.spinner.Active() { + s.startSpinner() + } +} + +func (s *StatusSpinner) startSpinner() { + if s.cancel != nil { + // if there is a cancel channel, we are already waiting for the service to start after a delay + return + } + if s.delay == 0 { + s.spinner.Start() + return + } + + s.cancel = make(chan struct{}, 1) + go func() { + select { + case <-s.cancel: + case <-time.After(s.delay): + s.spinner.Start() + s.cancel = nil + } + time.Sleep(50 * time.Millisecond) + }() +} + +func (s *StatusSpinner) Message(msgs ...string) { + if s.spinner.Active() { + s.spinner.Stop() + defer s.spinner.Start() + } + for _, msg := range msgs { + fmt.Println(msg) + } +} + +// Done implements StatusHooks +func (s *StatusSpinner) Done() { + if s.cancel != nil { + close(s.cancel) + } + s.closeSpinner() +} + +// UpdateSpinnerMessage updates the message of the given spinner +func (s *StatusSpinner) UpdateSpinnerMessage(newMessage string) { + newMessage = s.truncateSpinnerMessageToScreen(newMessage) + s.spinner.Suffix = fmt.Sprintf(" %s", newMessage) +} + +func (s *StatusSpinner) closeSpinner() { + if s.spinner != nil { + s.spinner.Stop() + } +} + +func (s *StatusSpinner) truncateSpinnerMessageToScreen(msg string) string { + if len(strings.TrimSpace(msg)) == 0 { + // if this is a blank message, return it as is + return msg + } + + maxCols, _, _ := gows.GetWinSize() + // if the screen is smaller than the minimum spinner width, we cannot truncate + if maxCols < minSpinnerWidth { + return msg + } + availableColumns := maxCols - minSpinnerWidth + if len(msg) > availableColumns { + msg = msg[:availableColumns] + msg = fmt.Sprintf("%s ...", msg) + } + return msg +} diff --git a/utils/errors.go b/utils/errors.go index 9c3fb4420..3aa105adc 100644 --- a/utils/errors.go +++ b/utils/errors.go @@ -9,11 +9,12 @@ import ( "github.com/fatih/color" "github.com/shiena/ansicolor" + "github.com/turbot/steampipe/statushooks" ) var ( - colorErr = color.RedString("Error") - colorWarn = color.YellowString("Warning") + colorErr = color.RedString("Error") + colorWarn = color.YellowString("Warning") ) func init() { @@ -34,20 +35,22 @@ func FailOnErrorWithMessage(err error, message string) { } } -func ShowError(err error) { +func ShowError(ctx context.Context, err error) { if err == nil { return } err = HandleCancelError(err) + statushooks.Done(ctx) fmt.Fprintf(color.Output, "%s: %v\n", colorErr, TransformErrorToSteampipe(err)) } // ShowErrorWithMessage displays the given error nicely with the given message -func ShowErrorWithMessage(err error, message string) { +func ShowErrorWithMessage(ctx context.Context, err error, message string) { if err == nil { return } err = HandleCancelError(err) + statushooks.Done(ctx) fmt.Fprintf(color.Output, "%s: %s - %v\n", colorErr, message, TransformErrorToSteampipe(err)) } diff --git a/workspace/workspace.go b/workspace/workspace.go index 9cc53caf0..22677df9b 100644 --- a/workspace/workspace.go +++ b/workspace/workspace.go @@ -2,6 +2,7 @@ package workspace import ( "bufio" + "context" "fmt" "log" "os" @@ -42,7 +43,7 @@ type Workspace struct { exclusions []string // should we load/watch files recursively listFlag filehelpers.ListFlag - fileWatcherErrorHandler func(error) + fileWatcherErrorHandler func(context.Context, error) watcherError error // event handlers reportEventHandlers []reportevents.ReportEventHandler @@ -53,7 +54,7 @@ type Workspace struct { } // Load creates a Workspace and loads the workspace mod -func Load(workspacePath string) (*Workspace, error) { +func Load(ctx context.Context, workspacePath string) (*Workspace, error) { utils.LogTime("workspace.Load start") defer utils.LogTime("workspace.Load end") @@ -72,7 +73,7 @@ func Load(workspacePath string) (*Workspace, error) { } // load the workspace mod - if err := workspace.loadWorkspaceMod(); err != nil { + if err := workspace.loadWorkspaceMod(ctx); err != nil { return nil, err } @@ -101,7 +102,7 @@ func LoadResourceNames(workspacePath string) (*modconfig.WorkspaceResources, err return workspace.loadWorkspaceResourceName() } -func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(error)) error { +func (w *Workspace) SetupWatcher(ctx context.Context, client db_common.Client, errorHandler func(context.Context, error)) error { watcherOptions := &utils.WatcherOptions{ Directories: []string{w.Path}, Include: filehelpers.InclusionsFromExtensions(steampipeconfig.GetModFileExtensions()), @@ -112,7 +113,7 @@ func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(erro // decide how to handle them // OnError: errCallback, OnChange: func(events []fsnotify.Event) { - w.handleFileWatcherEvent(client, events) + w.handleFileWatcherEvent(ctx, client, events) }, } watcher, err := utils.NewWatcher(watcherOptions) @@ -127,9 +128,9 @@ func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(erro // after a file watcher event w.fileWatcherErrorHandler = errorHandler if w.fileWatcherErrorHandler == nil { - w.fileWatcherErrorHandler = func(err error) { + w.fileWatcherErrorHandler = func(ctx context.Context, err error) { fmt.Println() - utils.ShowErrorWithMessage(err, "Failed to reload mod from file watcher") + utils.ShowErrorWithMessage(ctx, err, "Failed to reload mod from file watcher") } } @@ -249,11 +250,11 @@ func (w *Workspace) setModfileExists() { } } -func (w *Workspace) loadWorkspaceMod() error { +func (w *Workspace) loadWorkspaceMod(ctx context.Context) error { // clear all resource maps w.reset() // load and evaluate all variables - inputVariables, err := w.getAllVariables() + inputVariables, err := w.getAllVariables(ctx) if err != nil { return err } diff --git a/workspace/workspace_events.go b/workspace/workspace_events.go index 9dcb9a722..2b77560f5 100644 --- a/workspace/workspace_events.go +++ b/workspace/workspace_events.go @@ -22,7 +22,7 @@ func (w *Workspace) RegisterReportEventHandler(handler reportevents.ReportEventH w.reportEventHandlers = append(w.reportEventHandlers, handler) } -func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsnotify.Event) { +func (w *Workspace) handleFileWatcherEvent(ctx context.Context, client db_common.Client, events []fsnotify.Event) { w.loadLock.Lock() defer w.loadLock.Unlock() @@ -32,11 +32,11 @@ func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsn prevResourceMaps := w.GetResourceMaps() // now reload the workspace - err := w.loadWorkspaceMod() + err := w.loadWorkspaceMod(ctx) if err != nil { // check the existing watcher error - if we are already in an error state, do not show error if w.watcherError == nil { - w.fileWatcherErrorHandler(utils.PrefixError(err, "Failed to reload workspace")) + w.fileWatcherErrorHandler(ctx, utils.PrefixError(err, "Failed to reload workspace")) } // now set watcher error to new error w.watcherError = err @@ -53,7 +53,7 @@ func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsn res := client.RefreshSessions(context.Background()) if res.Error != nil || len(res.Warnings) > 0 { fmt.Println() - utils.ShowErrorWithMessage(res.Error, "error when refreshing session data") + utils.ShowErrorWithMessage(ctx, res.Error, "error when refreshing session data") utils.ShowWarning(strings.Join(res.Warnings, "\n")) if w.onFileWatcherEventMessages != nil { w.onFileWatcherEventMessages() diff --git a/workspace/workspace_test.go b/workspace/workspace_test.go index baf1f7ddb..684f9ba55 100644 --- a/workspace/workspace_test.go +++ b/workspace/workspace_test.go @@ -1,6 +1,7 @@ package workspace import ( + "context" "fmt" "path/filepath" "strings" @@ -135,7 +136,7 @@ var testCasesLoadWorkspace = map[string]loadWorkspaceTest{ func TestLoadWorkspace(t *testing.T) { for name, test := range testCasesLoadWorkspace { workspacePath, err := filepath.Abs(test.source) - workspace, err := Load(workspacePath) + workspace, err := Load(context.Background(), workspacePath) if err != nil { if test.expected != "ERROR" { diff --git a/workspace/workspace_variables.go b/workspace/workspace_variables.go index b3e435922..99cbe40e1 100644 --- a/workspace/workspace_variables.go +++ b/workspace/workspace_variables.go @@ -1,6 +1,7 @@ package workspace import ( + "context" "fmt" "sort" "strings" @@ -14,7 +15,7 @@ import ( "github.com/turbot/steampipe/utils" ) -func (w *Workspace) getAllVariables() (map[string]*modconfig.Variable, error) { +func (w *Workspace) getAllVariables(ctx context.Context) (map[string]*modconfig.Variable, error) { // build options used to load workspace runCtx, err := w.getRunContext() if err != nil { @@ -40,7 +41,7 @@ func (w *Workspace) getAllVariables() (map[string]*modconfig.Variable, error) { return nil, err } - if err := validateVariables(variableMap, inputVariables); err != nil { + if err := validateVariables(ctx, variableMap, inputVariables); err != nil { return nil, err } @@ -73,21 +74,21 @@ func (w *Workspace) getInputVariables(variableMap map[string]*modconfig.Variable return parsedValues, diags.Err() } -func validateVariables(variableMap map[string]*modconfig.Variable, variables inputvars.InputValues) error { +func validateVariables(ctx context.Context, variableMap map[string]*modconfig.Variable, variables inputvars.InputValues) error { diags := inputvars.CheckInputVariables(variableMap, variables) if diags.HasErrors() { - displayValidationErrors(diags) + displayValidationErrors(ctx, diags) // return empty error return modconfig.VariableValidationFailedError{} } return nil } -func displayValidationErrors(diags tfdiags.Diagnostics) { +func displayValidationErrors(ctx context.Context, diags tfdiags.Diagnostics) { fmt.Println() for i, diag := range diags { - utils.ShowError(fmt.Errorf("%s", constants.Bold(diag.Description().Summary))) + utils.ShowError(ctx, fmt.Errorf("%s", constants.Bold(diag.Description().Summary))) fmt.Println(diag.Description().Detail) if i < len(diags)-1 { fmt.Println()