mirror of
https://github.com/turbot/steampipe.git
synced 2026-03-24 11:00:34 -04:00
Decouple spinner display code from database and execution layer. Closes #1290
This commit is contained in:
59
cmd/check.go
59
cmd/check.go
@@ -10,12 +10,12 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
"github.com/turbot/steampipe/cmdconfig"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/contexthelpers"
|
||||
"github.com/turbot/steampipe/control"
|
||||
"github.com/turbot/steampipe/control/controldisplay"
|
||||
"github.com/turbot/steampipe/control/controlexecute"
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/turbot/steampipe/db/db_local"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/modinstaller"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
"github.com/turbot/steampipe/workspace"
|
||||
)
|
||||
@@ -88,16 +89,21 @@ You may specify one or more benchmarks or controls to run (separated by a space)
|
||||
func runCheckCmd(cmd *cobra.Command, args []string) {
|
||||
utils.LogTime("runCheckCmd start")
|
||||
initData := &control.InitData{}
|
||||
|
||||
// setup a cancel context and start cancel handler
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
contexthelpers.StartCancelHandler(cancel)
|
||||
|
||||
defer func() {
|
||||
utils.LogTime("runCheckCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
|
||||
if initData.Client != nil {
|
||||
log.Printf("[TRACE] close client")
|
||||
initData.Client.Close()
|
||||
initData.Client.Close(ctx)
|
||||
}
|
||||
if initData.Workspace != nil {
|
||||
initData.Workspace.Close()
|
||||
@@ -105,24 +111,22 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
|
||||
}()
|
||||
|
||||
// verify we have an argument
|
||||
if !validateArgs(cmd, args) {
|
||||
if !validateArgs(ctx, cmd, args) {
|
||||
return
|
||||
}
|
||||
|
||||
var spinner *spinner.Spinner
|
||||
if viper.GetBool(constants.ArgProgress) {
|
||||
spinner = display.ShowSpinner("Initializing...")
|
||||
// if progress is disabled, update context to contain a null status hooks object
|
||||
if !viper.GetBool(constants.ArgProgress) {
|
||||
statushooks.DisableStatusHooks(ctx)
|
||||
}
|
||||
|
||||
// initialise
|
||||
initData = initialiseCheck(cmd.Context(), spinner)
|
||||
display.StopSpinner(spinner)
|
||||
if shouldExit := handleCheckInitResult(initData); shouldExit {
|
||||
initData = initialiseCheck(ctx)
|
||||
if shouldExit := handleCheckInitResult(ctx, initData); shouldExit {
|
||||
return
|
||||
}
|
||||
|
||||
// pull out useful properties
|
||||
ctx := initData.Ctx
|
||||
workspace := initData.Workspace
|
||||
client := initData.Client
|
||||
failures := 0
|
||||
@@ -165,7 +169,7 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
|
||||
exportWaitGroup.Wait()
|
||||
|
||||
if len(exportErrors) > 0 {
|
||||
utils.ShowError(utils.CombineErrors(exportErrors...))
|
||||
utils.ShowError(ctx, utils.CombineErrors(exportErrors...))
|
||||
}
|
||||
|
||||
if shouldPrintTiming() {
|
||||
@@ -176,10 +180,10 @@ func runCheckCmd(cmd *cobra.Command, args []string) {
|
||||
exitCode = failures
|
||||
}
|
||||
|
||||
func validateArgs(cmd *cobra.Command, args []string) bool {
|
||||
func validateArgs(ctx context.Context, cmd *cobra.Command, args []string) bool {
|
||||
if len(args) == 0 {
|
||||
fmt.Println()
|
||||
utils.ShowError(fmt.Errorf("you must provide at least one argument"))
|
||||
utils.ShowError(ctx, fmt.Errorf("you must provide at least one argument"))
|
||||
fmt.Println()
|
||||
cmd.Help()
|
||||
fmt.Println()
|
||||
@@ -189,11 +193,13 @@ func validateArgs(cmd *cobra.Command, args []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.InitData {
|
||||
func initialiseCheck(ctx context.Context) *control.InitData {
|
||||
statushooks.SetStatus(ctx, "Initializing...")
|
||||
defer statushooks.Done(ctx)
|
||||
|
||||
initData := &control.InitData{
|
||||
Result: &db_common.InitResult{},
|
||||
}
|
||||
|
||||
if viper.GetBool(constants.ArgModInstall) {
|
||||
opts := &modinstaller.InstallOpts{WorkspacePath: viper.GetString(constants.ArgWorkspaceChDir)}
|
||||
_, err := modinstaller.InstallWorkspaceDependencies(opts)
|
||||
@@ -202,9 +208,6 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
|
||||
return initData
|
||||
}
|
||||
}
|
||||
|
||||
cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, false)
|
||||
|
||||
err := validateOutputFormat()
|
||||
if err != nil {
|
||||
initData.Result.Error = err
|
||||
@@ -217,10 +220,6 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
|
||||
return initData
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
startCancelHandler(cancel)
|
||||
initData.Ctx = ctx
|
||||
|
||||
// set color schema
|
||||
err = initialiseColorScheme()
|
||||
if err != nil {
|
||||
@@ -228,7 +227,7 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
|
||||
return initData
|
||||
}
|
||||
// load workspace
|
||||
initData.Workspace, err = loadWorkspacePromptingForVariables(ctx, spinner)
|
||||
initData.Workspace, err = loadWorkspacePromptingForVariables(ctx)
|
||||
if err != nil {
|
||||
if !utils.IsCancelledError(err) {
|
||||
err = utils.PrefixError(err, "failed to load workspace")
|
||||
@@ -248,18 +247,14 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
|
||||
initData.Result.AddWarnings("no controls found in current workspace")
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Connecting to service...")
|
||||
statushooks.SetStatus(ctx, "Connecting to service...")
|
||||
// get a client
|
||||
var client db_common.Client
|
||||
if connectionString := viper.GetString(constants.ArgConnectionString); connectionString != "" {
|
||||
client, err = db_client.NewDbClient(ctx, connectionString)
|
||||
} else {
|
||||
// stop the spinner
|
||||
display.StopSpinner(spinner)
|
||||
// when starting the database, installers may trigger their own spinners
|
||||
client, err = db_local.GetLocalClient(ctx, constants.InvokerCheck)
|
||||
// resume the spinner
|
||||
display.ResumeSpinner(spinner)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -288,13 +283,13 @@ func initialiseCheck(ctx context.Context, spinner *spinner.Spinner) *control.Ini
|
||||
return initData
|
||||
}
|
||||
|
||||
func handleCheckInitResult(initData *control.InitData) bool {
|
||||
func handleCheckInitResult(ctx context.Context, initData *control.InitData) bool {
|
||||
// if there is an error or cancellation we bomb out
|
||||
// check for the various kinds of failures
|
||||
utils.FailOnError(initData.Result.Error)
|
||||
// cancelled?
|
||||
if initData.Ctx != nil {
|
||||
utils.FailOnError(initData.Ctx.Err())
|
||||
if ctx != nil {
|
||||
utils.FailOnError(ctx.Err())
|
||||
}
|
||||
|
||||
// if there is a usage warning we display it
|
||||
|
||||
17
cmd/mod.go
17
cmd/mod.go
@@ -53,11 +53,12 @@ func modInstallCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runModInstallCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("cmd.runModInstallCmd")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.runModInstallCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
@@ -88,17 +89,18 @@ func modUninstallCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runModUninstallCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("cmd.runModInstallCmd")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.runModInstallCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
|
||||
opts := newInstallOpts(cmd, args...)
|
||||
installData, err := modinstaller.UninstallWorkspaceDependencies(opts)
|
||||
installData, err := modinstaller.UninstallWorkspaceDependencies(ctx, opts)
|
||||
utils.FailOnError(err)
|
||||
|
||||
fmt.Println(modinstaller.BuildUninstallSummary(installData))
|
||||
@@ -122,11 +124,12 @@ func modUpdateCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runModUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("cmd.runModUpdateCmd")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.runModUpdateCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
@@ -153,11 +156,12 @@ func modListCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runModListCmd(cmd *cobra.Command, _ []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("cmd.runModListCmd")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.runModListCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
@@ -186,11 +190,12 @@ func modInitCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runModInitCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("cmd.runModInitCmd")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.runModInitCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/turbot/steampipe/ociinstaller/versionfile"
|
||||
"github.com/turbot/steampipe/plugin"
|
||||
"github.com/turbot/steampipe/statefile"
|
||||
"github.com/turbot/steampipe/statusspinner"
|
||||
"github.com/turbot/steampipe/steampipeconfig"
|
||||
"github.com/turbot/steampipe/steampipeconfig/modconfig"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
@@ -176,11 +177,12 @@ Example:
|
||||
// exitCode=3 For errors related to loading state, loading version data or an issue contacting the update server.
|
||||
// exitCode=4 For plugin listing failures
|
||||
func runPluginInstallCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runPluginInstallCmd install")
|
||||
defer func() {
|
||||
utils.LogTime("runPluginInstallCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
@@ -193,7 +195,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
|
||||
|
||||
if len(plugins) == 0 {
|
||||
fmt.Println()
|
||||
utils.ShowError(fmt.Errorf("you need to provide at least one plugin to install"))
|
||||
utils.ShowError(ctx, fmt.Errorf("you need to provide at least one plugin to install"))
|
||||
fmt.Println()
|
||||
cmd.Help()
|
||||
fmt.Println()
|
||||
@@ -204,7 +206,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
|
||||
// a leading blank line - since we always output multiple lines
|
||||
fmt.Println()
|
||||
|
||||
spinner := display.ShowSpinner("")
|
||||
statusSpinner := statusspinner.NewStatusSpinner()
|
||||
|
||||
for _, p := range plugins {
|
||||
isPluginExists, _ := plugin.Exists(p)
|
||||
@@ -217,7 +219,7 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
|
||||
})
|
||||
continue
|
||||
}
|
||||
display.UpdateSpinnerMessage(spinner, fmt.Sprintf("Installing plugin: %s", p))
|
||||
statusSpinner.SetStatus(fmt.Sprintf("Installing plugin: %s", p))
|
||||
image, err := plugin.Install(cmd.Context(), p)
|
||||
if err != nil {
|
||||
msg := ""
|
||||
@@ -250,9 +252,9 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
|
||||
})
|
||||
}
|
||||
|
||||
display.StopSpinner(spinner)
|
||||
statusSpinner.Done()
|
||||
|
||||
refreshConnectionsIfNecessary(cmd.Context(), installReports, false)
|
||||
refreshConnectionsIfNecessary(cmd.Context(), installReports, true)
|
||||
display.PrintInstallReports(installReports, false)
|
||||
|
||||
// a concluding blank line - since we always output multiple lines
|
||||
@@ -260,11 +262,12 @@ func runPluginInstallCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runPluginUpdateCmd install")
|
||||
defer func() {
|
||||
utils.LogTime("runPluginUpdateCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
@@ -275,7 +278,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
plugins, err := resolveUpdatePluginsFromArgs(args)
|
||||
if err != nil {
|
||||
fmt.Println()
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
fmt.Println()
|
||||
cmd.Help()
|
||||
fmt.Println()
|
||||
@@ -285,7 +288,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
|
||||
state, err := statefile.LoadState()
|
||||
if err != nil {
|
||||
utils.ShowError(fmt.Errorf("could not load state"))
|
||||
utils.ShowError(ctx, fmt.Errorf("could not load state"))
|
||||
exitCode = 3
|
||||
return
|
||||
}
|
||||
@@ -293,7 +296,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
// load up the version file data
|
||||
versionData, err := versionfile.LoadPluginVersionFile()
|
||||
if err != nil {
|
||||
utils.ShowError(fmt.Errorf("error loading current plugin data"))
|
||||
utils.ShowError(ctx, fmt.Errorf("error loading current plugin data"))
|
||||
exitCode = 3
|
||||
return
|
||||
}
|
||||
@@ -340,14 +343,14 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
return
|
||||
}
|
||||
|
||||
spinner := display.ShowSpinner("Checking for available updates")
|
||||
statusSpinner := statusspinner.NewStatusSpinner(statusspinner.WithMessage("Checking for available updates"))
|
||||
reports := plugin.GetUpdateReport(state.InstallationID, runUpdatesFor)
|
||||
display.StopSpinner(spinner)
|
||||
statusSpinner.Done()
|
||||
|
||||
if len(reports) == 0 {
|
||||
// this happens if for some reason the update server could not be contacted,
|
||||
// in which case we get back an empty map
|
||||
utils.ShowError(fmt.Errorf("there was an issue contacting the update server. Please try later"))
|
||||
utils.ShowError(ctx, fmt.Errorf("there was an issue contacting the update server. Please try later"))
|
||||
exitCode = 3
|
||||
return
|
||||
}
|
||||
@@ -363,9 +366,9 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
continue
|
||||
}
|
||||
|
||||
spinner := display.ShowSpinner(fmt.Sprintf("Updating plugin %s...", report.CheckResponse.Name))
|
||||
statusSpinner.SetStatus(fmt.Sprintf("Updating plugin %s...", report.CheckResponse.Name))
|
||||
image, err := plugin.Install(cmd.Context(), report.Plugin.Name)
|
||||
display.StopSpinner(spinner)
|
||||
statusSpinner.Done()
|
||||
if err != nil {
|
||||
msg := ""
|
||||
if strings.HasSuffix(err.Error(), "not found") {
|
||||
@@ -398,7 +401,7 @@ func runPluginUpdateCmd(cmd *cobra.Command, args []string) {
|
||||
})
|
||||
}
|
||||
|
||||
refreshConnectionsIfNecessary(cmd.Context(), updateReports, true)
|
||||
refreshConnectionsIfNecessary(cmd.Context(), updateReports, false)
|
||||
display.PrintInstallReports(updateReports, true)
|
||||
|
||||
// a concluding blank line - since we always output multiple lines
|
||||
@@ -421,7 +424,7 @@ func resolveUpdatePluginsFromArgs(args []string) ([]string, error) {
|
||||
}
|
||||
|
||||
// start service if necessary and refresh connections
|
||||
func refreshConnectionsIfNecessary(ctx context.Context, reports []display.InstallReport, isUpdate bool) error {
|
||||
func refreshConnectionsIfNecessary(ctx context.Context, reports []display.InstallReport, shouldReload bool) error {
|
||||
// get count of skipped reports
|
||||
skipped := 0
|
||||
for _, report := range reports {
|
||||
@@ -436,7 +439,7 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal
|
||||
}
|
||||
|
||||
// reload the config, since an installation MUST have created a new config file
|
||||
if !isUpdate {
|
||||
if shouldReload {
|
||||
var cmd = viper.Get(constants.ConfigKeyActiveCommand).(*cobra.Command)
|
||||
config, err := steampipeconfig.LoadSteampipeConfig(viper.GetString(constants.ArgWorkspaceChDir), cmd.Name())
|
||||
if err != nil {
|
||||
@@ -449,7 +452,7 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
defer client.Close(ctx)
|
||||
res := client.RefreshConnectionAndSearchPaths(ctx)
|
||||
if res.Error != nil {
|
||||
return res.Error
|
||||
@@ -460,24 +463,26 @@ func refreshConnectionsIfNecessary(ctx context.Context, reports []display.Instal
|
||||
}
|
||||
|
||||
func runPluginListCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runPluginListCmd list")
|
||||
defer func() {
|
||||
utils.LogTime("runPluginListCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
|
||||
pluginConnectionMap, err := getPluginConnectionMap(cmd.Context())
|
||||
if err != nil {
|
||||
utils.ShowErrorWithMessage(err, "Plugin Listing failed")
|
||||
utils.ShowErrorWithMessage(ctx, err, "Plugin Listing failed")
|
||||
exitCode = 4
|
||||
return
|
||||
}
|
||||
|
||||
list, err := plugin.List(pluginConnectionMap)
|
||||
if err != nil {
|
||||
utils.ShowErrorWithMessage(err, "Plugin Listing failed")
|
||||
utils.ShowErrorWithMessage(ctx, err, "Plugin Listing failed")
|
||||
exitCode = 4
|
||||
}
|
||||
headers := []string{"Name", "Version", "Connections"}
|
||||
@@ -489,35 +494,37 @@ func runPluginListCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
func runPluginUninstallCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runPluginUninstallCmd uninstall")
|
||||
|
||||
defer func() {
|
||||
utils.LogTime("runPluginUninstallCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
exitCode = 1
|
||||
}
|
||||
}()
|
||||
|
||||
if len(args) == 0 {
|
||||
fmt.Println()
|
||||
utils.ShowError(fmt.Errorf("you need to provide at least one plugin to uninstall"))
|
||||
utils.ShowError(ctx, fmt.Errorf("you need to provide at least one plugin to uninstall"))
|
||||
fmt.Println()
|
||||
cmd.Help()
|
||||
fmt.Println()
|
||||
exitCode = 2
|
||||
return
|
||||
}
|
||||
connectionMap, err := getPluginConnectionMap(cmd.Context())
|
||||
|
||||
connectionMap, err := getPluginConnectionMap(ctx)
|
||||
if err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
exitCode = 4
|
||||
return
|
||||
}
|
||||
|
||||
for _, p := range args {
|
||||
if err := plugin.Remove(p, connectionMap); err != nil {
|
||||
utils.ShowErrorWithMessage(err, fmt.Sprintf("Failed to uninstall plugin '%s'", p))
|
||||
if err := plugin.Remove(ctx, p, connectionMap); err != nil {
|
||||
utils.ShowErrorWithMessage(ctx, err, fmt.Sprintf("Failed to uninstall plugin '%s'", p))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -528,7 +535,7 @@ func getPluginConnectionMap(ctx context.Context) (map[string][]modconfig.Connect
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer client.Close()
|
||||
defer client.Close(ctx)
|
||||
res := client.RefreshConnectionAndSearchPaths(ctx)
|
||||
if res.Error != nil {
|
||||
return nil, res.Error
|
||||
|
||||
@@ -33,6 +33,7 @@ func pluginManagerCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runPluginManagerCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
logger := createPluginManagerLog()
|
||||
|
||||
log.Printf("[INFO] starting plugin manager")
|
||||
@@ -51,7 +52,7 @@ func runPluginManagerCmd(cmd *cobra.Command, args []string) {
|
||||
connectionWatcher, err := connectionwatcher.NewConnectionWatcher(pluginManager.SetConnectionConfigMap)
|
||||
if err != nil {
|
||||
log.Printf("[WARN] failed to create connection watcher: %s", err.Error())
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
|
||||
44
cmd/query.go
44
cmd/query.go
@@ -6,10 +6,8 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
@@ -18,6 +16,7 @@ import (
|
||||
"github.com/turbot/steampipe/interactive"
|
||||
"github.com/turbot/steampipe/query"
|
||||
"github.com/turbot/steampipe/query/queryexecute"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/steampipeconfig/modconfig"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
"github.com/turbot/steampipe/workspace"
|
||||
@@ -80,11 +79,12 @@ Examples:
|
||||
}
|
||||
|
||||
func runQueryCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("cmd.runQueryCmd start")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.runQueryCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -97,19 +97,11 @@ func runQueryCmd(cmd *cobra.Command, args []string) {
|
||||
|
||||
// enable spinner only in interactive mode
|
||||
interactiveMode := len(args) == 0
|
||||
cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, interactiveMode)
|
||||
// set config to indicate whether we are running an interactive query
|
||||
viper.Set(constants.ConfigKeyInteractive, interactiveMode)
|
||||
|
||||
ctx := cmd.Context()
|
||||
if !interactiveMode {
|
||||
c, cancel := context.WithCancel(ctx)
|
||||
startCancelHandler(cancel)
|
||||
ctx = c
|
||||
}
|
||||
|
||||
// load the workspace
|
||||
w, err := loadWorkspacePromptingForVariables(ctx, nil)
|
||||
w, err := loadWorkspacePromptingForVariables(ctx)
|
||||
utils.FailOnErrorWithMessage(err, "failed to load workspace")
|
||||
|
||||
// se we have loaded a workspace - be sure to close it
|
||||
@@ -119,7 +111,7 @@ func runQueryCmd(cmd *cobra.Command, args []string) {
|
||||
initData := query.NewInitData(ctx, w, args)
|
||||
|
||||
if interactiveMode {
|
||||
queryexecute.RunInteractiveSession(initData)
|
||||
queryexecute.RunInteractiveSession(ctx, initData)
|
||||
} else {
|
||||
// set global exit code
|
||||
exitCode = queryexecute.RunBatchSession(ctx, initData)
|
||||
@@ -144,10 +136,10 @@ func getPipedStdinData() string {
|
||||
return stdinData
|
||||
}
|
||||
|
||||
func loadWorkspacePromptingForVariables(ctx context.Context, spinner *spinner.Spinner) (*workspace.Workspace, error) {
|
||||
func loadWorkspacePromptingForVariables(ctx context.Context) (*workspace.Workspace, error) {
|
||||
workspacePath := viper.GetString(constants.ArgWorkspaceChDir)
|
||||
|
||||
w, err := workspace.Load(workspacePath)
|
||||
w, err := workspace.Load(ctx, workspacePath)
|
||||
if err == nil {
|
||||
return w, nil
|
||||
}
|
||||
@@ -156,29 +148,13 @@ func loadWorkspacePromptingForVariables(ctx context.Context, spinner *spinner.Sp
|
||||
if !ok {
|
||||
return nil, err
|
||||
}
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
// so we have missing variables - prompt for them
|
||||
// first hide spinner if it is there
|
||||
statushooks.Done(ctx)
|
||||
if err := interactive.PromptForMissingVariables(ctx, missingVariablesError.MissingVariables); err != nil {
|
||||
log.Printf("[TRACE] Interactive variables prompting returned error %v", err)
|
||||
return nil, err
|
||||
}
|
||||
if spinner != nil {
|
||||
spinner.Start()
|
||||
}
|
||||
// ok we should have all variables now - reload workspace
|
||||
return workspace.Load(workspacePath)
|
||||
}
|
||||
|
||||
func startCancelHandler(cancel context.CancelFunc) {
|
||||
sigIntChannel := make(chan os.Signal, 1)
|
||||
signal.Notify(sigIntChannel, os.Interrupt)
|
||||
go func() {
|
||||
<-sigIntChannel
|
||||
log.Println("[TRACE] got SIGINT")
|
||||
// call context cancellation function
|
||||
cancel()
|
||||
// leave the channel open - any subsequent interrupts hits will be ignored
|
||||
}()
|
||||
return workspace.Load(ctx, workspacePath)
|
||||
}
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
"github.com/turbot/steampipe-plugin-sdk/logging"
|
||||
"github.com/turbot/steampipe/cmdconfig"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/contexthelpers"
|
||||
"github.com/turbot/steampipe/report/reportserver"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
@@ -29,18 +29,17 @@ func reportCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runReportCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
logging.LogTime("runReportCmd start")
|
||||
defer func() {
|
||||
logging.LogTime("runReportCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
}
|
||||
}()
|
||||
|
||||
cmdconfig.Viper().Set(constants.ConfigKeyShowInteractiveOutput, false)
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
startCancelHandler(cancel)
|
||||
contexthelpers.StartCancelHandler(cancel)
|
||||
|
||||
// start db if necessary
|
||||
//err := db_local.EnsureDbAndStartService(constants.InvokerReport, true)
|
||||
@@ -53,7 +52,7 @@ func runReportCmd(cmd *cobra.Command, args []string) {
|
||||
utils.FailOnError(err)
|
||||
}
|
||||
|
||||
defer server.Shutdown()
|
||||
defer server.Shutdown(ctx)
|
||||
|
||||
server.Start()
|
||||
}
|
||||
|
||||
19
cmd/root.go
19
cmd/root.go
@@ -1,10 +1,14 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/statusspinner"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/mattn/go-isatty"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -183,6 +187,19 @@ func Execute() int {
|
||||
utils.LogTime("cmd.root.Execute start")
|
||||
defer utils.LogTime("cmd.root.Execute end")
|
||||
|
||||
rootCmd.Execute()
|
||||
ctx := createRootContext()
|
||||
rootCmd.ExecuteContext(ctx)
|
||||
return exitCode
|
||||
}
|
||||
|
||||
// create the root context - create a status renderer and set as value
|
||||
func createRootContext() context.Context {
|
||||
var statusRenderer statushooks.StatusHooks = statushooks.NullHooks
|
||||
// if the client is a TTY, inject a status spinner
|
||||
if isatty.IsTerminal(os.Stdout.Fd()) {
|
||||
statusRenderer = statusspinner.NewStatusSpinner()
|
||||
}
|
||||
|
||||
ctx := statushooks.AddStatusHooksToContext(context.Background(), statusRenderer)
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
@@ -127,11 +129,12 @@ func serviceRestartCmd() *cobra.Command {
|
||||
}
|
||||
|
||||
func runServiceStartCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runServiceStartCmd start")
|
||||
defer func() {
|
||||
utils.LogTime("runServiceStartCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
if exitCode == 0 {
|
||||
// there was an error and the exitcode
|
||||
// was not set to a non-zero value.
|
||||
@@ -163,7 +166,7 @@ func runServiceStartCmd(cmd *cobra.Command, args []string) {
|
||||
utils.FailOnError(startResult.Error)
|
||||
|
||||
if startResult.Status == db_local.ServiceFailedToStart {
|
||||
utils.ShowError(fmt.Errorf("steampipe service failed to start"))
|
||||
utils.ShowError(ctx, fmt.Errorf("steampipe service failed to start"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -191,18 +194,18 @@ func runServiceStartCmd(cmd *cobra.Command, args []string) {
|
||||
|
||||
err = db_local.RefreshConnectionAndSearchPaths(ctx, invoker)
|
||||
if err != nil {
|
||||
db_local.StopServices(false, constants.InvokerService, nil)
|
||||
db_local.StopServices(ctx, false, constants.InvokerService)
|
||||
utils.FailOnError(err)
|
||||
}
|
||||
|
||||
printStatus(startResult.DbState, startResult.PluginManagerState)
|
||||
printStatus(ctx, startResult.DbState, startResult.PluginManagerState)
|
||||
|
||||
if viper.GetBool(constants.ArgForeground) {
|
||||
runServiceInForeground(invoker)
|
||||
runServiceInForeground(ctx, invoker)
|
||||
}
|
||||
}
|
||||
|
||||
func runServiceInForeground(invoker constants.Invoker) {
|
||||
func runServiceInForeground(ctx context.Context, invoker constants.Invoker) {
|
||||
fmt.Println("Hit Ctrl+C to stop the service")
|
||||
|
||||
sigIntChannel := make(chan os.Signal, 1)
|
||||
@@ -232,7 +235,7 @@ func runServiceInForeground(invoker constants.Invoker) {
|
||||
count, err := db_local.GetCountOfThirdPartyClients(context.Background())
|
||||
if err != nil {
|
||||
// report the error in the off chance that there's one
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -246,7 +249,7 @@ func runServiceInForeground(invoker constants.Invoker) {
|
||||
}
|
||||
fmt.Println("Stopping Steampipe service.")
|
||||
|
||||
db_local.StopServices(false, invoker, nil)
|
||||
db_local.StopServices(ctx, false, invoker)
|
||||
fmt.Println("Steampipe service stopped.")
|
||||
return
|
||||
}
|
||||
@@ -254,11 +257,12 @@ func runServiceInForeground(invoker constants.Invoker) {
|
||||
}
|
||||
|
||||
func runServiceRestartCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runServiceRestartCmd start")
|
||||
defer func() {
|
||||
utils.LogTime("runServiceRestartCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
if exitCode == 0 {
|
||||
// there was an error and the exitcode
|
||||
// was not set to a non-zero value.
|
||||
@@ -277,7 +281,7 @@ func runServiceRestartCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
// stop db
|
||||
stopStatus, err := db_local.StopServices(viper.GetBool(constants.ArgForce), constants.InvokerService, nil)
|
||||
stopStatus, err := db_local.StopServices(ctx, viper.GetBool(constants.ArgForce), constants.InvokerService)
|
||||
utils.FailOnErrorWithMessage(err, "could not stop current instance")
|
||||
if stopStatus != db_local.ServiceStopped {
|
||||
fmt.Println(`
|
||||
@@ -307,16 +311,17 @@ to force a restart.
|
||||
utils.FailOnError(err)
|
||||
fmt.Println("Steampipe service restarted.")
|
||||
|
||||
printStatus(startResult.DbState, startResult.PluginManagerState)
|
||||
printStatus(ctx, startResult.DbState, startResult.PluginManagerState)
|
||||
|
||||
}
|
||||
|
||||
func runServiceStatusCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runServiceStatusCmd status")
|
||||
defer func() {
|
||||
utils.LogTime("runServiceStatusCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -331,10 +336,10 @@ func runServiceStatusCmd(cmd *cobra.Command, args []string) {
|
||||
pmState, pmStateErr := pluginmanager.LoadPluginManagerState()
|
||||
|
||||
if dbStateErr != nil || pmStateErr != nil {
|
||||
utils.ShowError(composeStateError(dbStateErr, pmStateErr))
|
||||
utils.ShowError(ctx, composeStateError(dbStateErr, pmStateErr))
|
||||
return
|
||||
}
|
||||
printStatus(dbState, pmState)
|
||||
printStatus(ctx, dbState, pmState)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -354,19 +359,17 @@ func composeStateError(dbStateErr error, pmStateErr error) error {
|
||||
}
|
||||
|
||||
func runServiceStopCmd(cmd *cobra.Command, args []string) {
|
||||
ctx := cmd.Context()
|
||||
utils.LogTime("runServiceStopCmd stop")
|
||||
|
||||
stoppedChan := make(chan bool, 1)
|
||||
var status db_local.StopStatus
|
||||
var err error
|
||||
var dbState *db_local.RunningDBInstanceInfo
|
||||
|
||||
spinner := display.StartSpinnerAfterDelay("", constants.SpinnerShowTimeout, stoppedChan)
|
||||
|
||||
defer func() {
|
||||
utils.LogTime("runServiceStopCmd end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
if exitCode == 0 {
|
||||
// there was an error and the exitcode
|
||||
// was not set to a non-zero value.
|
||||
@@ -378,20 +381,17 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) {
|
||||
|
||||
force := cmdconfig.Viper().GetBool(constants.ArgForce)
|
||||
if force {
|
||||
status, err = db_local.StopServices(force, constants.InvokerService, spinner)
|
||||
status, err = db_local.StopServices(ctx, force, constants.InvokerService)
|
||||
} else {
|
||||
dbState, err = db_local.GetState()
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
utils.FailOnErrorWithMessage(err, "could not stop Steampipe service")
|
||||
}
|
||||
if dbState == nil {
|
||||
display.StopSpinner(spinner)
|
||||
fmt.Println("Steampipe service is not running.")
|
||||
return
|
||||
}
|
||||
if dbState.Invoker != constants.InvokerService {
|
||||
display.StopSpinner(spinner)
|
||||
printRunningImplicit(dbState.Invoker)
|
||||
return
|
||||
}
|
||||
@@ -399,27 +399,22 @@ func runServiceStopCmd(cmd *cobra.Command, args []string) {
|
||||
// check if there are any connected clients to the service
|
||||
connectedClientCount, err := db_local.GetCountOfThirdPartyClients(cmd.Context())
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
utils.FailOnErrorWithMessage(err, "error during service stop")
|
||||
}
|
||||
|
||||
if connectedClientCount > 0 {
|
||||
display.StopSpinner(spinner)
|
||||
printClientsConnected()
|
||||
return
|
||||
}
|
||||
|
||||
status, _ = db_local.StopServices(false, constants.InvokerService, spinner)
|
||||
status, _ = db_local.StopServices(ctx, false, constants.InvokerService)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
display.StopSpinner(spinner)
|
||||
|
||||
switch status {
|
||||
case db_local.ServiceStopped:
|
||||
if dbState != nil {
|
||||
@@ -451,12 +446,9 @@ func showAllStatus(ctx context.Context) {
|
||||
var processes []*psutils.Process
|
||||
var err error
|
||||
|
||||
doneFetchingDetailsChan := make(chan bool)
|
||||
sp := display.StartSpinnerAfterDelay("Getting details", constants.SpinnerShowTimeout, doneFetchingDetailsChan)
|
||||
|
||||
statushooks.SetStatus(ctx, "Getting details")
|
||||
processes, err = db_local.FindAllSteampipePostgresInstances(ctx)
|
||||
close(doneFetchingDetailsChan)
|
||||
display.StopSpinner(sp)
|
||||
statushooks.Done(ctx)
|
||||
|
||||
utils.FailOnError(err)
|
||||
|
||||
@@ -498,7 +490,7 @@ func getServiceProcessDetails(process *psutils.Process) (string, string, string,
|
||||
return fmt.Sprintf("%d", process.Pid), installDir, port, listenType
|
||||
}
|
||||
|
||||
func printStatus(dbState *db_local.RunningDBInstanceInfo, pmState *pluginmanager.PluginManagerState) {
|
||||
func printStatus(ctx context.Context, dbState *db_local.RunningDBInstanceInfo, pmState *pluginmanager.PluginManagerState) {
|
||||
if dbState == nil && !pmState.Running {
|
||||
fmt.Println("Service is not running")
|
||||
return
|
||||
@@ -553,7 +545,7 @@ To keep the service running after the %s session completes, use %s.
|
||||
// the service is running, but the plugin_manager is not running and there's no state file
|
||||
// meaning that it cannot be restarted by the FDW
|
||||
// it's an ERROR
|
||||
utils.ShowError(fmt.Errorf(`
|
||||
utils.ShowError(ctx, fmt.Errorf(`
|
||||
Service is running, but the Plugin Manager cannot be recovered.
|
||||
Please use %s to recover the service
|
||||
`,
|
||||
|
||||
@@ -86,7 +86,7 @@ func (w *ConnectionWatcher) handleFileWatcherEvent(e []fsnotify.Event) {
|
||||
if err != nil {
|
||||
log.Printf("[WARN] Error creating client to handle updated connection config: %s", err.Error())
|
||||
}
|
||||
defer client.Close()
|
||||
defer client.Close(ctx)
|
||||
|
||||
log.Printf("[TRACE] loaded updated config")
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package constants
|
||||
|
||||
// viper config keys
|
||||
const (
|
||||
ConfigKeyShowInteractiveOutput = "show-interactive-output"
|
||||
// ConfigKeyDatabaseSearchPath is used to store the search path set in the database config in viper
|
||||
// the viper value will be set via via a call to getScopedKey in steampipeconfig/steampipeconfig.go
|
||||
ConfigKeyDatabaseSearchPath = "database.search-path"
|
||||
|
||||
21
contexthelpers/cancel.go
Normal file
21
contexthelpers/cancel.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package contexthelpers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
)
|
||||
|
||||
func StartCancelHandler(cancel context.CancelFunc) chan os.Signal {
|
||||
sigIntChannel := make(chan os.Signal, 1)
|
||||
signal.Notify(sigIntChannel, os.Interrupt)
|
||||
go func() {
|
||||
<-sigIntChannel
|
||||
log.Println("[TRACE] got SIGINT")
|
||||
// call context cancellation function
|
||||
cancel()
|
||||
// leave the channel open - any subsequent interrupts hits will be ignored
|
||||
}()
|
||||
return sigIntChannel
|
||||
}
|
||||
8
contexthelpers/context_key.go
Normal file
8
contexthelpers/context_key.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package contexthelpers
|
||||
|
||||
//https://medium.com/@matryer/context-keys-in-go-5312346a868d
|
||||
type ContextKey string
|
||||
|
||||
func (c ContextKey) String() string {
|
||||
return "steampipe context key " + string(c)
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/query/queryresult"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/steampipeconfig/modconfig"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
@@ -92,8 +93,8 @@ func NewControlRun(control *modconfig.Control, group *ResultGroup, executionTree
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *ControlRun) skip() {
|
||||
r.setRunStatus(ControlRunComplete)
|
||||
func (r *ControlRun) skip(ctx context.Context) {
|
||||
r.setRunStatus(ctx, ControlRunComplete)
|
||||
}
|
||||
|
||||
// set search path for this control run
|
||||
@@ -196,14 +197,14 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
|
||||
r.runStatus = ControlRunStarted
|
||||
|
||||
// update the current running control in the Progress renderer
|
||||
r.executionTree.progress.OnControlStart(control)
|
||||
defer r.executionTree.progress.OnControlFinish()
|
||||
r.executionTree.progress.OnControlStart(ctx, control)
|
||||
defer r.executionTree.progress.OnControlFinish(ctx)
|
||||
|
||||
// resolve the control query
|
||||
r.Lifecycle.Add("query_resolution_start")
|
||||
query, err := r.resolveControlQuery(control)
|
||||
if err != nil {
|
||||
r.SetError(err)
|
||||
r.SetError(ctx, err)
|
||||
return
|
||||
}
|
||||
r.Lifecycle.Add("query_resolution_finish")
|
||||
@@ -211,7 +212,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
|
||||
log.Printf("[TRACE] setting search path %s\n", control.Name())
|
||||
r.Lifecycle.Add("set_search_path_start")
|
||||
if err := r.setSearchPath(ctx, dbSession, client); err != nil {
|
||||
r.SetError(err)
|
||||
r.SetError(ctx, err)
|
||||
return
|
||||
}
|
||||
r.Lifecycle.Add("set_search_path_finish")
|
||||
@@ -225,7 +226,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
|
||||
// NOTE no need to pass an OnComplete callback - we are already closing our session after waiting for results
|
||||
log.Printf("[TRACE] execute start for, %s\n", control.Name())
|
||||
r.Lifecycle.Add("query_start")
|
||||
queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil, false)
|
||||
queryResult, err := client.ExecuteInSession(controlExecutionCtx, dbSession, query, nil)
|
||||
r.Lifecycle.Add("query_finish")
|
||||
log.Printf("[TRACE] execute finish for, %s\n", control.Name())
|
||||
|
||||
@@ -243,7 +244,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
|
||||
log.Printf("[TRACE] control %s query failed again with plugin connectivity error %s - NOT retrying...", r.Control.Name(), err)
|
||||
}
|
||||
}
|
||||
r.SetError(err)
|
||||
r.SetError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -255,7 +256,7 @@ func (r *ControlRun) execute(ctx context.Context, client db_common.Client) {
|
||||
log.Printf("[TRACE] finish result for, %s\n", control.Name())
|
||||
}
|
||||
|
||||
func (r *ControlRun) SetError(err error) {
|
||||
func (r *ControlRun) SetError(ctx context.Context, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
@@ -263,18 +264,23 @@ func (r *ControlRun) SetError(err error) {
|
||||
|
||||
// update error count
|
||||
r.Summary.Error++
|
||||
r.setRunStatus(ControlRunError)
|
||||
r.setRunStatus(ctx, ControlRunError)
|
||||
}
|
||||
|
||||
func (r *ControlRun) GetError() error {
|
||||
return r.runError
|
||||
}
|
||||
|
||||
// create a context with a deadline, and with status updates disabled (we do not want to show 'loading' results)
|
||||
func (r *ControlRun) getControlQueryContext(ctx context.Context) context.Context {
|
||||
// create a context with a deadline
|
||||
shouldBeDoneBy := time.Now().Add(controlQueryTimeout)
|
||||
// we don't use this cancel fn because, pgx prematurely cancels the PG connection when this cancel gets called in 'defer'
|
||||
newCtx, _ := context.WithDeadline(ctx, shouldBeDoneBy)
|
||||
|
||||
// disable the status spinner to hide 'loading' results)
|
||||
newCtx = statushooks.DisableStatusHooks(newCtx)
|
||||
|
||||
return newCtx
|
||||
}
|
||||
|
||||
@@ -304,20 +310,20 @@ func (r *ControlRun) waitForResults(ctx context.Context) {
|
||||
// create a channel to which will be closed when gathering has been done
|
||||
gatherDoneChan := make(chan string)
|
||||
go func() {
|
||||
r.gatherResults()
|
||||
r.gatherResults(ctx)
|
||||
close(gatherDoneChan)
|
||||
}()
|
||||
|
||||
select {
|
||||
// check for cancellation
|
||||
case <-ctx.Done():
|
||||
r.SetError(ctx.Err())
|
||||
r.SetError(ctx, ctx.Err())
|
||||
case <-gatherDoneChan:
|
||||
// do nothing
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ControlRun) gatherResults() {
|
||||
func (r *ControlRun) gatherResults(ctx context.Context) {
|
||||
r.Lifecycle.Add("gather_start")
|
||||
defer func() { r.Lifecycle.Add("gather_finish") }()
|
||||
for {
|
||||
@@ -326,14 +332,14 @@ func (r *ControlRun) gatherResults() {
|
||||
// nil row means control run is complete
|
||||
if row == nil {
|
||||
// nil row means we are done
|
||||
r.setRunStatus(ControlRunComplete)
|
||||
r.setRunStatus(ctx, ControlRunComplete)
|
||||
r.createdOrderedResultRows()
|
||||
return
|
||||
}
|
||||
// if the row is in error then we terminate the run
|
||||
if row.Error != nil {
|
||||
// set error status and summary
|
||||
r.SetError(row.Error)
|
||||
r.SetError(ctx, row.Error)
|
||||
// update the result group status with our status - this will be passed all the way up the execution tree
|
||||
r.group.updateSummary(r.Summary)
|
||||
return
|
||||
@@ -342,7 +348,7 @@ func (r *ControlRun) gatherResults() {
|
||||
// so all is ok - create another result row
|
||||
result, err := NewResultRow(r.Control, row, r.queryResult.ColTypes)
|
||||
if err != nil {
|
||||
r.SetError(err)
|
||||
r.SetError(ctx, err)
|
||||
return
|
||||
}
|
||||
r.addResultRow(result)
|
||||
@@ -380,7 +386,7 @@ func (r *ControlRun) createdOrderedResultRows() {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ControlRun) setRunStatus(status ControlRunStatus) {
|
||||
func (r *ControlRun) setRunStatus(ctx context.Context, status ControlRunStatus) {
|
||||
r.stateLock.Lock()
|
||||
r.runStatus = status
|
||||
r.stateLock.Unlock()
|
||||
@@ -388,12 +394,11 @@ func (r *ControlRun) setRunStatus(status ControlRunStatus) {
|
||||
if r.Finished() {
|
||||
// update Progress
|
||||
if status == ControlRunError {
|
||||
r.executionTree.progress.OnControlError()
|
||||
r.executionTree.progress.OnControlError(ctx)
|
||||
} else {
|
||||
r.executionTree.progress.OnControlComplete()
|
||||
r.executionTree.progress.OnControlComplete(ctx)
|
||||
}
|
||||
|
||||
// TODO CANCEL QUERY IF NEEDED
|
||||
r.doneChan <- true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/query/queryresult"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/steampipeconfig/modconfig"
|
||||
"github.com/turbot/steampipe/workspace"
|
||||
"golang.org/x/sync/semaphore"
|
||||
@@ -40,9 +41,10 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien
|
||||
workspace: workspace,
|
||||
client: client,
|
||||
}
|
||||
// if a "--where" or "--tag" parameter was passed, build a map of control manes used to filter the controls to run
|
||||
// NOTE: not enabled yet
|
||||
err := executionTree.populateControlFilterMap(ctx)
|
||||
// if a "--where" or "--tag" parameter was passed, build a map of control names used to filter the controls to run
|
||||
// create a context with status hooks disabled
|
||||
noStatusCtx := statushooks.DisableStatusHooks(ctx)
|
||||
err := executionTree.populateControlFilterMap(noStatusCtx)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -55,7 +57,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien
|
||||
}
|
||||
|
||||
// build tree of result groups, starting with a synthetic 'root' node
|
||||
executionTree.Root = NewRootResultGroup(executionTree, rootItem)
|
||||
executionTree.Root = NewRootResultGroup(ctx, executionTree, rootItem)
|
||||
|
||||
// after tree has built, ControlCount will be set - create progress rendered
|
||||
executionTree.progress = NewControlProgressRenderer(len(executionTree.controlRuns))
|
||||
@@ -65,7 +67,7 @@ func NewExecutionTree(ctx context.Context, workspace *workspace.Workspace, clien
|
||||
|
||||
// AddControl checks whether control should be included in the tree
|
||||
// if so, creates a ControlRun, which is added to the parent group
|
||||
func (e *ExecutionTree) AddControl(control *modconfig.Control, group *ResultGroup) {
|
||||
func (e *ExecutionTree) AddControl(ctx context.Context, control *modconfig.Control, group *ResultGroup) {
|
||||
// note we use short name to determine whether to include a control
|
||||
if e.ShouldIncludeControl(control.ShortName) {
|
||||
// create new ControlRun with treeItem as the parent
|
||||
@@ -81,11 +83,11 @@ func (e *ExecutionTree) Execute(ctx context.Context, client db_common.Client) in
|
||||
log.Println("[TRACE]", "begin ExecutionTree.Execute")
|
||||
defer log.Println("[TRACE]", "end ExecutionTree.Execute")
|
||||
e.StartTime = time.Now()
|
||||
e.progress.Start()
|
||||
e.progress.Start(ctx)
|
||||
|
||||
defer func() {
|
||||
e.EndTime = time.Now()
|
||||
e.progress.Finish()
|
||||
e.progress.Finish(ctx)
|
||||
}()
|
||||
|
||||
// the number of goroutines parallel to start
|
||||
@@ -247,7 +249,7 @@ func (e *ExecutionTree) getControlMapFromWhereClause(ctx context.Context, whereC
|
||||
query = fmt.Sprintf("select resource_name from %s where %s", constants.IntrospectionTableControl, whereClause)
|
||||
}
|
||||
|
||||
res, err := e.client.ExecuteSync(ctx, query, false)
|
||||
res, err := e.client.ExecuteSync(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package controlexecute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
|
||||
"github.com/turbot/steampipe/steampipeconfig/modconfig"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
@@ -34,16 +35,16 @@ func NewControlProgressRenderer(total int) *ControlProgressRenderer {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControlProgressRenderer) Start() {
|
||||
func (p *ControlProgressRenderer) Start(ctx context.Context) {
|
||||
p.updateLock.Lock()
|
||||
defer p.updateLock.Unlock()
|
||||
|
||||
if p.enabled {
|
||||
p.spinner = display.ShowSpinner("Starting controls...")
|
||||
statushooks.SetStatus(ctx, "Starting controls...")
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControlProgressRenderer) OnControlStart(control *modconfig.Control) {
|
||||
func (p *ControlProgressRenderer) OnControlStart(ctx context.Context, control *modconfig.Control) {
|
||||
p.updateLock.Lock()
|
||||
defer p.updateLock.Unlock()
|
||||
|
||||
@@ -54,43 +55,43 @@ func (p *ControlProgressRenderer) OnControlStart(control *modconfig.Control) {
|
||||
p.pending--
|
||||
|
||||
if p.enabled {
|
||||
display.UpdateSpinnerMessage(p.spinner, p.message())
|
||||
statushooks.SetStatus(ctx, p.message())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControlProgressRenderer) OnControlFinish() {
|
||||
func (p *ControlProgressRenderer) OnControlFinish(ctx context.Context) {
|
||||
p.updateLock.Lock()
|
||||
defer p.updateLock.Unlock()
|
||||
// decrement the parallel execution count
|
||||
p.executing--
|
||||
if p.enabled {
|
||||
display.UpdateSpinnerMessage(p.spinner, p.message())
|
||||
statushooks.SetStatus(ctx, p.message())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControlProgressRenderer) OnControlComplete() {
|
||||
func (p *ControlProgressRenderer) OnControlComplete(ctx context.Context) {
|
||||
p.updateLock.Lock()
|
||||
defer p.updateLock.Unlock()
|
||||
p.complete++
|
||||
|
||||
if p.enabled {
|
||||
display.UpdateSpinnerMessage(p.spinner, p.message())
|
||||
statushooks.SetStatus(ctx, p.message())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControlProgressRenderer) OnControlError() {
|
||||
func (p *ControlProgressRenderer) OnControlError(ctx context.Context) {
|
||||
p.updateLock.Lock()
|
||||
defer p.updateLock.Unlock()
|
||||
p.error++
|
||||
|
||||
if p.enabled {
|
||||
display.UpdateSpinnerMessage(p.spinner, p.message())
|
||||
statushooks.SetStatus(ctx, p.message())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ControlProgressRenderer) Finish() {
|
||||
func (p *ControlProgressRenderer) Finish(ctx context.Context) {
|
||||
if p.enabled {
|
||||
display.StopSpinner(p.spinner)
|
||||
statushooks.Done(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ func NewGroupSummary() *GroupSummary {
|
||||
}
|
||||
|
||||
// NewRootResultGroup creates a ResultGroup to act as the root node of a control execution tree
|
||||
func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.ModTreeItem) *ResultGroup {
|
||||
func NewRootResultGroup(ctx context.Context, executionTree *ExecutionTree, rootItems ...modconfig.ModTreeItem) *ResultGroup {
|
||||
root := &ResultGroup{
|
||||
GroupId: RootResultGroupName,
|
||||
Groups: []*ResultGroup{},
|
||||
@@ -62,10 +62,10 @@ func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.Mod
|
||||
// if root item is a benchmark, create new result group with root as parent
|
||||
if control, ok := item.(*modconfig.Control); ok {
|
||||
// if root item is a control, add control run
|
||||
executionTree.AddControl(control, root)
|
||||
executionTree.AddControl(ctx, control, root)
|
||||
} else {
|
||||
// create a result group for this item
|
||||
itemGroup := NewResultGroup(executionTree, item, root)
|
||||
itemGroup := NewResultGroup(ctx, executionTree, item, root)
|
||||
root.Groups = append(root.Groups, itemGroup)
|
||||
}
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func NewRootResultGroup(executionTree *ExecutionTree, rootItems ...modconfig.Mod
|
||||
}
|
||||
|
||||
// NewResultGroup creates a result group from a ModTreeItem
|
||||
func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem, parent *ResultGroup) *ResultGroup {
|
||||
func NewResultGroup(ctx context.Context, executionTree *ExecutionTree, treeItem modconfig.ModTreeItem, parent *ResultGroup) *ResultGroup {
|
||||
// only show qualified group names for controls from dependent mods
|
||||
groupId := treeItem.Name()
|
||||
if mod := treeItem.GetMod(); mod != nil && mod.Name() == executionTree.workspace.Mod.Name() {
|
||||
@@ -96,7 +96,7 @@ func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem
|
||||
for _, c := range treeItem.GetChildren() {
|
||||
if benchmark, ok := c.(*modconfig.Benchmark); ok {
|
||||
// create a result group for this item
|
||||
benchmarkGroup := NewResultGroup(executionTree, benchmark, group)
|
||||
benchmarkGroup := NewResultGroup(ctx, executionTree, benchmark, group)
|
||||
// if the group has any control runs, add to tree
|
||||
if benchmarkGroup.ControlRunCount() > 0 {
|
||||
// create a new result group with 'group' as the parent
|
||||
@@ -104,7 +104,7 @@ func NewResultGroup(executionTree *ExecutionTree, treeItem modconfig.ModTreeItem
|
||||
}
|
||||
}
|
||||
if control, ok := c.(*modconfig.Control); ok {
|
||||
executionTree.AddControl(control, group)
|
||||
executionTree.AddControl(ctx, control, group)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -180,18 +180,18 @@ func (r *ResultGroup) execute(ctx context.Context, client db_common.Client, para
|
||||
|
||||
for _, controlRun := range r.ControlRuns {
|
||||
if utils.IsContextCancelled(ctx) {
|
||||
controlRun.SetError(ctx.Err())
|
||||
controlRun.SetError(ctx, ctx.Err())
|
||||
continue
|
||||
}
|
||||
|
||||
if viper.GetBool(constants.ArgDryRun) {
|
||||
controlRun.skip()
|
||||
controlRun.skip(ctx)
|
||||
continue
|
||||
}
|
||||
|
||||
err := parallelismLock.Acquire(ctx, 1)
|
||||
if err != nil {
|
||||
controlRun.SetError(err)
|
||||
controlRun.SetError(ctx, err)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -199,7 +199,7 @@ func (r *ResultGroup) execute(ctx context.Context, client db_common.Client, para
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
// if the Execute panic'ed, set it as an error
|
||||
run.SetError(helpers.ToError(r))
|
||||
run.SetError(ctx, helpers.ToError(r))
|
||||
}
|
||||
// Release in defer, so that we don't retain the lock even if there's a panic inside
|
||||
parallelismLock.Release(1)
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/workspace"
|
||||
)
|
||||
|
||||
type InitData struct {
|
||||
Ctx context.Context
|
||||
Workspace *workspace.Workspace
|
||||
Client db_common.Client
|
||||
Result *db_common.InitResult
|
||||
|
||||
@@ -68,7 +68,7 @@ func NewDbClient(ctx context.Context, connectionString string) (*DbClient, error
|
||||
client.connectionString = connectionString
|
||||
|
||||
if err := client.LoadForeignSchemaNames(ctx); err != nil {
|
||||
client.Close()
|
||||
client.Close(ctx)
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
@@ -114,7 +114,7 @@ func (c *DbClient) SetEnsureSessionDataFunc(f db_common.EnsureSessionStateCallba
|
||||
|
||||
// Close implements Client
|
||||
// closes the connection to the database and shuts down the backend
|
||||
func (c *DbClient) Close() error {
|
||||
func (c *DbClient) Close(context.Context) error {
|
||||
log.Printf("[TRACE] DbClient.Close %v", c.dbClient)
|
||||
if c.dbClient != nil {
|
||||
c.sessionInitWaitGroup.Wait()
|
||||
|
||||
@@ -8,12 +8,9 @@ import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/turbot/steampipe/cmdconfig"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/query/queryresult"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
"golang.org/x/text/language"
|
||||
"golang.org/x/text/message"
|
||||
@@ -21,7 +18,7 @@ import (
|
||||
|
||||
// ExecuteSync implements Client
|
||||
// execute a query against this client and wait for the result
|
||||
func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
|
||||
func (c *DbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) {
|
||||
// acquire a session
|
||||
sessionResult := c.AcquireSession(ctx)
|
||||
if sessionResult.Error != nil {
|
||||
@@ -32,17 +29,17 @@ func (c *DbClient) ExecuteSync(ctx context.Context, query string, disableSpinner
|
||||
// and not in call-time
|
||||
sessionResult.Session.Close(utils.IsContextCancelled(ctx))
|
||||
}()
|
||||
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query, disableSpinner)
|
||||
return c.ExecuteSyncInSession(ctx, sessionResult.Session, query)
|
||||
}
|
||||
|
||||
// ExecuteSyncInSession implements Client
|
||||
// execute a query against this client and wait for the result
|
||||
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
|
||||
func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) {
|
||||
if query == "" {
|
||||
return &queryresult.SyncQueryResult{}, nil
|
||||
}
|
||||
|
||||
result, err := c.ExecuteInSession(ctx, session, query, nil, disableSpinner)
|
||||
result, err := c.ExecuteInSession(ctx, session, query, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -61,7 +58,7 @@ func (c *DbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.
|
||||
// Execute implements Client
|
||||
// execute the query in the given Context
|
||||
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
|
||||
func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner bool) (*queryresult.Result, error) {
|
||||
func (c *DbClient) Execute(ctx context.Context, query string) (*queryresult.Result, error) {
|
||||
// acquire a session
|
||||
sessionResult := c.AcquireSession(ctx)
|
||||
if sessionResult.Error != nil {
|
||||
@@ -70,26 +67,29 @@ func (c *DbClient) Execute(ctx context.Context, query string, disableSpinner boo
|
||||
|
||||
// define callback to close session when the async execution is complete
|
||||
closeSessionCallback := func() { sessionResult.Session.Close(utils.IsContextCancelled(ctx)) }
|
||||
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback, disableSpinner)
|
||||
return c.ExecuteInSession(ctx, sessionResult.Session, query, closeSessionCallback)
|
||||
}
|
||||
|
||||
// ExecuteInSession implements Client
|
||||
// execute the query in the given Context using the provided DatabaseSession
|
||||
// ExecuteInSession assumes no responsibility over the lifecycle of the DatabaseSession - that is the responsibility of the caller
|
||||
// NOTE: The returned Result MUST be fully read - otherwise the connection will block and will prevent further communication
|
||||
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) {
|
||||
func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) {
|
||||
if query == "" {
|
||||
return queryresult.NewQueryResult(nil), nil
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
// channel to flag to spinner that the query has run
|
||||
var spinner *spinner.Spinner
|
||||
var tx *sql.Tx
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// stop spinner in case of error
|
||||
display.StopSpinner(spinner)
|
||||
statushooks.Done(ctx)
|
||||
// error - rollback transaction if we have one
|
||||
if tx != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
// call the completion callback - if one was provided
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
@@ -97,11 +97,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data
|
||||
}
|
||||
}()
|
||||
|
||||
if !disableSpinner && cmdconfig.Viper().GetBool(constants.ConfigKeyShowInteractiveOutput) {
|
||||
// if `show-interactive-output` is false, the spinner gets created, but is never shown
|
||||
// so the s.Active() will always come back false . . .
|
||||
spinner = display.ShowSpinner("Loading results...")
|
||||
}
|
||||
statushooks.SetStatus(ctx, "Loading results...")
|
||||
|
||||
// start query
|
||||
var rows *sql.Rows
|
||||
@@ -122,7 +118,7 @@ func (c *DbClient) ExecuteInSession(ctx context.Context, session *db_common.Data
|
||||
// read the rows in a go routine
|
||||
go func() {
|
||||
// read in the rows and stream to the query result object
|
||||
c.readRows(ctx, startTime, rows, result, spinner)
|
||||
c.readRows(ctx, startTime, rows, result)
|
||||
if onComplete != nil {
|
||||
onComplete()
|
||||
}
|
||||
@@ -158,11 +154,11 @@ func (c *DbClient) startQuery(ctx context.Context, query string, conn *sql.Conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows, result *queryresult.Result, activeSpinner *spinner.Spinner) {
|
||||
func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows, result *queryresult.Result) {
|
||||
// defer this, so that these get cleaned up even if there is an unforeseen error
|
||||
defer func() {
|
||||
// we are done fetching results. time for display. remove the spinner
|
||||
display.StopSpinner(activeSpinner)
|
||||
// we are done fetching results. time for display. clear the status indication
|
||||
statushooks.Done(ctx)
|
||||
// close the sql rows object
|
||||
rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
@@ -191,7 +187,7 @@ func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows
|
||||
continueToNext := true
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
display.UpdateSpinnerMessage(activeSpinner, "Cancelling query")
|
||||
statushooks.SetStatus(ctx, "Cancelling query")
|
||||
continueToNext = false
|
||||
default:
|
||||
if rowResult, err := readRowContext(ctx, rows, cols, colTypes); err != nil {
|
||||
@@ -200,9 +196,9 @@ func (c *DbClient) readRows(ctx context.Context, start time.Time, rows *sql.Rows
|
||||
} else {
|
||||
result.StreamRow(rowResult)
|
||||
}
|
||||
// update the spinner message with the count of rows that have already been fetched
|
||||
// update the status message with the count of rows that have already been fetched
|
||||
// this will not show if the spinner is not active
|
||||
display.UpdateSpinnerMessage(activeSpinner, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
|
||||
statushooks.SetStatus(ctx, fmt.Sprintf("Loading results: %3s", humanizeRowCount(rowCount)))
|
||||
rowCount++
|
||||
}
|
||||
if !continueToNext {
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
func (c *DbClient) GetCurrentSearchPath(ctx context.Context) ([]string, error) {
|
||||
var currentSearchPath []string
|
||||
var pathAsString string
|
||||
rows, err := c.ExecuteSync(ctx, "show search_path", true)
|
||||
rows, err := c.ExecuteSync(ctx, "show search_path")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
type EnsureSessionStateCallback = func(context.Context, *DatabaseSession) (err error, warnings []string)
|
||||
|
||||
type Client interface {
|
||||
Close() error
|
||||
Close(ctx context.Context) error
|
||||
|
||||
ForeignSchemas() []string
|
||||
ConnectionMap() *steampipeconfig.ConnectionDataMap
|
||||
@@ -22,11 +22,11 @@ type Client interface {
|
||||
|
||||
AcquireSession(context.Context) *AcquireSessionResult
|
||||
|
||||
ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error)
|
||||
Execute(ctx context.Context, query string, disableSpinner bool) (res *queryresult.Result, err error)
|
||||
ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error)
|
||||
Execute(ctx context.Context, query string) (res *queryresult.Result, err error)
|
||||
|
||||
ExecuteSyncInSession(ctx context.Context, session *DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error)
|
||||
ExecuteInSession(ctx context.Context, session *DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error)
|
||||
ExecuteSyncInSession(ctx context.Context, session *DatabaseSession, query string) (*queryresult.SyncQueryResult, error)
|
||||
ExecuteInSession(ctx context.Context, session *DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error)
|
||||
|
||||
CacheOn(context.Context) error
|
||||
CacheOff(context.Context) error
|
||||
|
||||
@@ -13,7 +13,7 @@ func ExecuteQuery(ctx context.Context, queryString string, client Client) (*quer
|
||||
defer utils.LogTime("db.ExecuteQuery end")
|
||||
|
||||
resultsStreamer := queryresult.NewResultStreamer()
|
||||
result, err := client.Execute(ctx, queryString, false)
|
||||
result, err := client.Execute(ctx, queryString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -19,25 +19,6 @@ import (
|
||||
// TagColumn is the tag used to specify the column name and type in the introspection tables
|
||||
const TagColumn = "column"
|
||||
|
||||
func UpdateIntrospectionTables(workspaceResources *modconfig.WorkspaceResourceMaps, client Client) error {
|
||||
utils.LogTime("db.UpdateIntrospectionTables start")
|
||||
defer utils.LogTime("db.UpdateIntrospectionTables end")
|
||||
|
||||
// get the create sql for each table type
|
||||
clearSql := getClearTablesSql()
|
||||
|
||||
// now get sql to populate the tables
|
||||
insertSql := getTableInsertSql(workspaceResources)
|
||||
|
||||
sql := []string{clearSql, insertSql}
|
||||
// execute the query, passing 'true' to disable the spinner
|
||||
_, err := client.ExecuteSync(context.Background(), strings.Join(sql, "\n"), true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update introspection tables: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateIntrospectionTables(ctx context.Context, workspaceResources *modconfig.WorkspaceResourceMaps, session *DatabaseSession) error {
|
||||
utils.LogTime("db.CreateIntrospectionTables start")
|
||||
defer utils.LogTime("db.CreateIntrospectionTables end")
|
||||
|
||||
@@ -10,15 +10,14 @@ import (
|
||||
"os/exec"
|
||||
"sync"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
psutils "github.com/shirou/gopsutil/process"
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/filepaths"
|
||||
"github.com/turbot/steampipe/ociinstaller"
|
||||
"github.com/turbot/steampipe/ociinstaller/versionfile"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
@@ -41,66 +40,57 @@ func EnsureDBInstalled(ctx context.Context) (err error) {
|
||||
close(doneChan)
|
||||
}()
|
||||
|
||||
spinner := display.StartSpinnerAfterDelay("", constants.SpinnerShowTimeout, doneChan)
|
||||
|
||||
if IsInstalled() {
|
||||
// check if the FDW need updating, and init the db id required
|
||||
err := prepareDb(ctx, spinner)
|
||||
display.StopSpinner(spinner)
|
||||
err := prepareDb(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Println("[TRACE] calling removeRunningInstanceInfo")
|
||||
err = removeRunningInstanceInfo()
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] removeRunningInstanceInfo failed: %v", err)
|
||||
return fmt.Errorf("Cleanup any Steampipe processes... FAILED!")
|
||||
}
|
||||
|
||||
log.Println("[TRACE] removing previous installation")
|
||||
display.UpdateSpinnerMessage(spinner, "Prepare database install location...")
|
||||
statushooks.SetStatus(ctx, "Prepare database install location...")
|
||||
defer statushooks.Done(ctx)
|
||||
|
||||
err = os.RemoveAll(getDatabaseLocation())
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] %v", err)
|
||||
return fmt.Errorf("Prepare database install location... FAILED!")
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Download & install embedded PostgreSQL database...")
|
||||
statushooks.SetStatus(ctx, "Download & install embedded PostgreSQL database...")
|
||||
_, err = ociinstaller.InstallDB(ctx, constants.DefaultEmbeddedPostgresImage, getDatabaseLocation())
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] %v", err)
|
||||
return fmt.Errorf("Download & install embedded PostgreSQL database... FAILED!")
|
||||
}
|
||||
|
||||
// installFDW takes care of the spinner, since it may need to run independently
|
||||
_, err = installFDW(ctx, true, spinner)
|
||||
_, err = installFDW(ctx, true)
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] installFDW failed: %v", err)
|
||||
return fmt.Errorf("Download & install steampipe-postgres-fdw... FAILED!")
|
||||
}
|
||||
|
||||
// run the database installation
|
||||
err = runInstall(ctx, true, spinner)
|
||||
err = runInstall(ctx, true)
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
return err
|
||||
}
|
||||
|
||||
// write a signature after everything gets done!
|
||||
// so that we can check for this later on
|
||||
display.UpdateSpinnerMessage(spinner, "Updating install records...")
|
||||
statushooks.SetStatus(ctx, "Updating install records...")
|
||||
err = updateDownloadedBinarySignature()
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] updateDownloadedBinarySignature failed: %v", err)
|
||||
return fmt.Errorf("Updating install records... FAILED!")
|
||||
}
|
||||
|
||||
display.StopSpinner(spinner)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,11 +127,10 @@ func IsInstalled() bool {
|
||||
}
|
||||
|
||||
// prepareDb updates the FDW if needed, and inits the database if required
|
||||
func prepareDb(ctx context.Context, spinner *spinner.Spinner) error {
|
||||
func prepareDb(ctx context.Context) error {
|
||||
// check if FDW needs to be updated
|
||||
if fdwNeedsUpdate() {
|
||||
_, err := installFDW(ctx, false, spinner)
|
||||
spinner.Stop()
|
||||
_, err := installFDW(ctx, false)
|
||||
if err != nil {
|
||||
log.Printf("[TRACE] installFDW failed: %v", err)
|
||||
return fmt.Errorf("Update steampipe-postgres-fdw... FAILED!")
|
||||
@@ -153,10 +142,9 @@ func prepareDb(ctx context.Context, spinner *spinner.Spinner) error {
|
||||
}
|
||||
|
||||
if needsInit() {
|
||||
spinner.Start()
|
||||
display.UpdateSpinnerMessage(spinner, "Cleanup any Steampipe processes...")
|
||||
statushooks.SetStatus(ctx, "Cleanup any Steampipe processes...")
|
||||
killInstanceIfAny(ctx)
|
||||
if err := runInstall(ctx, false, spinner); err != nil {
|
||||
if err := runInstall(ctx, false); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -175,15 +163,15 @@ func fdwNeedsUpdate() bool {
|
||||
return versionInfo.FdwExtension.Version != constants.FdwVersion
|
||||
}
|
||||
|
||||
func installFDW(ctx context.Context, firstSetup bool, spinner *spinner.Spinner) (string, error) {
|
||||
func installFDW(ctx context.Context, firstSetup bool) (string, error) {
|
||||
utils.LogTime("db_local.installFDW start")
|
||||
defer utils.LogTime("db_local.installFDW end")
|
||||
|
||||
status, err := GetState()
|
||||
state, err := GetState()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if status != nil {
|
||||
if state != nil {
|
||||
defer func() {
|
||||
if !firstSetup {
|
||||
// update the signature
|
||||
@@ -191,7 +179,7 @@ func installFDW(ctx context.Context, firstSetup bool, spinner *spinner.Spinner)
|
||||
}
|
||||
}()
|
||||
}
|
||||
display.UpdateSpinnerMessage(spinner, fmt.Sprintf("Download & install %s...", constants.Bold("steampipe-postgres-fdw")))
|
||||
statushooks.SetStatus(ctx, fmt.Sprintf("Download & install %s...", constants.Bold("steampipe-postgres-fdw")))
|
||||
return ociinstaller.InstallFdw(ctx, constants.DefaultFdwImage, getDatabaseLocation())
|
||||
}
|
||||
|
||||
@@ -203,58 +191,54 @@ func needsInit() bool {
|
||||
return !helpers.FileExists(getPgHbaConfLocation())
|
||||
}
|
||||
|
||||
func runInstall(ctx context.Context, firstInstall bool, spinner *spinner.Spinner) error {
|
||||
func runInstall(ctx context.Context, firstInstall bool) error {
|
||||
utils.LogTime("db_local.runInstall start")
|
||||
defer utils.LogTime("db_local.runInstall end")
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Cleaning up...")
|
||||
statushooks.SetStatus(ctx, "Cleaning up...")
|
||||
defer statushooks.Done(ctx)
|
||||
|
||||
err := utils.RemoveDirectoryContents(getDataLocation())
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] %v", err)
|
||||
return fmt.Errorf("Prepare database install location... FAILED!")
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Initializing database...")
|
||||
statushooks.SetStatus(ctx, "Initializing database...")
|
||||
err = initDatabase()
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] initDatabase failed: %v", err)
|
||||
return fmt.Errorf("Initializing database... FAILED!")
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Starting database...")
|
||||
statushooks.SetStatus(ctx, "Starting database...")
|
||||
port, err := getNextFreePort()
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] getNextFreePort failed: %v", err)
|
||||
return fmt.Errorf("Starting database... FAILED!")
|
||||
}
|
||||
|
||||
process, err := startServiceForInstall(port)
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] startServiceForInstall failed: %v", err)
|
||||
return fmt.Errorf("Starting database... FAILED!")
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Connection to database...")
|
||||
statushooks.SetStatus(ctx, "Connection to database...")
|
||||
client, err := createMaintenanceClient(ctx, port)
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
return fmt.Errorf("Connection to database... FAILED!")
|
||||
}
|
||||
defer func() {
|
||||
display.UpdateSpinnerMessage(spinner, "Completing configuration")
|
||||
statushooks.SetStatus(ctx, "Completing configuration")
|
||||
client.Close()
|
||||
doThreeStepPostgresExit(process)
|
||||
doThreeStepPostgresExit(ctx, process)
|
||||
}()
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Generating database passwords...")
|
||||
statushooks.SetStatus(ctx, "Generating database passwords...")
|
||||
// generate a password file for use later
|
||||
_, err = readPasswordFile()
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] readPassword failed: %v", err)
|
||||
return fmt.Errorf("Generating database passwords... FAILED!")
|
||||
}
|
||||
@@ -274,23 +258,21 @@ func runInstall(ctx context.Context, firstInstall bool, spinner *spinner.Spinner
|
||||
return fmt.Errorf("Invalid database name '%s' - must start with either a lowercase character or an underscore", databaseName)
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Configuring database...")
|
||||
statushooks.SetStatus(ctx, "Configuring database...")
|
||||
err = installDatabaseWithPermissions(ctx, databaseName, client)
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] installSteampipeDatabaseAndUser failed: %v", err)
|
||||
return fmt.Errorf("Configuring database... FAILED!")
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Configuring Steampipe...")
|
||||
statushooks.SetStatus(ctx, "Configuring Steampipe...")
|
||||
err = installForeignServer(ctx, client)
|
||||
if err != nil {
|
||||
display.StopSpinner(spinner)
|
||||
log.Printf("[TRACE] installForeignServer failed: %v", err)
|
||||
return fmt.Errorf("Configuring Steampipe... FAILED!")
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveDatabaseName() string {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/query/queryresult"
|
||||
"github.com/turbot/steampipe/schema"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/steampipeconfig"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
@@ -41,13 +42,12 @@ func GetLocalClient(ctx context.Context, invoker constants.Invoker) (db_common.C
|
||||
|
||||
client, err := NewLocalClient(ctx, invoker)
|
||||
if err != nil {
|
||||
ShutdownService(invoker)
|
||||
ShutdownService(ctx, invoker)
|
||||
}
|
||||
return client, err
|
||||
}
|
||||
|
||||
// NewLocalClient ensures that the database instance is running
|
||||
// and returns a `Client` to interact with it
|
||||
// NewLocalClient verifies that the local database instance is running and returns a Client to interact with it
|
||||
func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbClient, error) {
|
||||
utils.LogTime("db.NewLocalClient start")
|
||||
defer utils.LogTime("db.NewLocalClient end")
|
||||
@@ -69,19 +69,17 @@ func NewLocalClient(ctx context.Context, invoker constants.Invoker) (*LocalDbCli
|
||||
|
||||
// Close implements Client
|
||||
// close the connection to the database and shuts down the backend
|
||||
func (c *LocalDbClient) Close() error {
|
||||
func (c *LocalDbClient) Close(ctx context.Context) error {
|
||||
log.Printf("[TRACE] close local client %p", c)
|
||||
if c.client != nil {
|
||||
log.Printf("[TRACE] local client not NIL")
|
||||
if err := c.client.Close(); err != nil {
|
||||
if err := c.client.Close(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[TRACE] local client close complete")
|
||||
}
|
||||
log.Printf("[TRACE] shutdown local service %v", c.invoker)
|
||||
// no context to pass on - use background
|
||||
// we shouldn't do this in a context that can be cancelled anyway
|
||||
ShutdownService(c.invoker)
|
||||
ShutdownService(ctx, c.invoker)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -108,23 +106,23 @@ func (c *LocalDbClient) AcquireSession(ctx context.Context) *db_common.AcquireSe
|
||||
}
|
||||
|
||||
// ExecuteSync implements Client
|
||||
func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
|
||||
return c.client.ExecuteSync(ctx, query, disableSpinner)
|
||||
func (c *LocalDbClient) ExecuteSync(ctx context.Context, query string) (*queryresult.SyncQueryResult, error) {
|
||||
return c.client.ExecuteSync(ctx, query)
|
||||
}
|
||||
|
||||
// ExecuteSyncInSession implements Client
|
||||
func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, disableSpinner bool) (*queryresult.SyncQueryResult, error) {
|
||||
return c.client.ExecuteSyncInSession(ctx, session, query, disableSpinner)
|
||||
func (c *LocalDbClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string) (*queryresult.SyncQueryResult, error) {
|
||||
return c.client.ExecuteSyncInSession(ctx, session, query)
|
||||
}
|
||||
|
||||
// ExecuteInSession implements Client
|
||||
func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func(), disableSpinner bool) (res *queryresult.Result, err error) {
|
||||
return c.client.ExecuteInSession(ctx, session, query, onComplete, disableSpinner)
|
||||
func (c *LocalDbClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, query string, onComplete func()) (res *queryresult.Result, err error) {
|
||||
return c.client.ExecuteInSession(ctx, session, query, onComplete)
|
||||
}
|
||||
|
||||
// Execute implements Client
|
||||
func (c *LocalDbClient) Execute(ctx context.Context, query string, disableSpinner bool) (res *queryresult.Result, err error) {
|
||||
return c.client.Execute(ctx, query, disableSpinner)
|
||||
func (c *LocalDbClient) Execute(ctx context.Context, query string) (res *queryresult.Result, err error) {
|
||||
return c.client.Execute(ctx, query)
|
||||
}
|
||||
|
||||
// CacheOn implements Client
|
||||
@@ -167,6 +165,9 @@ func (c *LocalDbClient) LoadForeignSchemaNames(ctx context.Context) error {
|
||||
// local only functions
|
||||
|
||||
func (c *LocalDbClient) RefreshConnectionAndSearchPaths(ctx context.Context) *steampipeconfig.RefreshConnectionResult {
|
||||
// NOTE: disable any status updates - we do not want 'loading' output from any queries
|
||||
ctx = statushooks.DisableStatusHooks(ctx)
|
||||
|
||||
res := c.refreshConnections(ctx)
|
||||
if res.Error != nil {
|
||||
return res
|
||||
@@ -221,7 +222,7 @@ func (c *LocalDbClient) setUserSearchPath(ctx context.Context) ([]string, error)
|
||||
|
||||
// get all roles which are a member of steampipe_users
|
||||
query := fmt.Sprintf(`select usename from pg_user where pg_has_role(usename, '%s', 'member')`, constants.DatabaseUsersRole)
|
||||
res, err := c.ExecuteSync(context.Background(), query, true)
|
||||
res, err := c.ExecuteSync(context.Background(), query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ func RefreshConnectionAndSearchPaths(ctx context.Context, invoker constants.Invo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
defer client.Close(ctx)
|
||||
refreshResult := client.RefreshConnectionAndSearchPaths(ctx)
|
||||
// display any initialisation warnings
|
||||
refreshResult.ShowWarnings()
|
||||
|
||||
@@ -136,7 +136,7 @@ func startDB(ctx context.Context, port int, listen StartListenType, invoker cons
|
||||
// if there was an error and we started the service, stop it again
|
||||
if res.Error != nil {
|
||||
if res.Status == ServiceStarted {
|
||||
StopServices(false, invoker, nil)
|
||||
StopServices(ctx, false, invoker)
|
||||
}
|
||||
// remove the state file if we are going back with an error
|
||||
removeRunningInstanceInfo()
|
||||
@@ -554,7 +554,7 @@ func killInstanceIfAny(ctx context.Context) bool {
|
||||
for _, process := range processes {
|
||||
wg.Add(1)
|
||||
go func(p *psutils.Process) {
|
||||
doThreeStepPostgresExit(p)
|
||||
doThreeStepPostgresExit(ctx, p)
|
||||
wg.Done()
|
||||
}(process)
|
||||
}
|
||||
|
||||
@@ -9,14 +9,12 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
psutils "github.com/shirou/gopsutil/process"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/constants/runtime"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/filepaths"
|
||||
"github.com/turbot/steampipe/pluginmanager"
|
||||
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
@@ -32,7 +30,7 @@ const (
|
||||
)
|
||||
|
||||
// ShutdownService stops the database instance if the given 'invoker' matches
|
||||
func ShutdownService(invoker constants.Invoker) {
|
||||
func ShutdownService(ctx context.Context, invoker constants.Invoker) {
|
||||
utils.LogTime("db_local.ShutdownService start")
|
||||
defer utils.LogTime("db_local.ShutdownService end")
|
||||
|
||||
@@ -54,18 +52,18 @@ func ShutdownService(invoker constants.Invoker) {
|
||||
}
|
||||
|
||||
// we can shut down the database
|
||||
stopStatus, err := StopServices(false, invoker, nil)
|
||||
stopStatus, err := StopServices(ctx, false, invoker)
|
||||
if err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
}
|
||||
if stopStatus == ServiceStopped {
|
||||
return
|
||||
}
|
||||
|
||||
// shutdown failed - try to force stop
|
||||
_, err = StopServices(true, invoker, nil)
|
||||
_, err = StopServices(ctx, true, invoker)
|
||||
if err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
}
|
||||
|
||||
}
|
||||
@@ -92,7 +90,8 @@ func GetCountOfThirdPartyClients(ctx context.Context) (i int, e error) {
|
||||
}
|
||||
|
||||
// StopServices searches for and stops the running instance. Does nothing if an instance was not found
|
||||
func StopServices(force bool, invoker constants.Invoker, spinner *spinner.Spinner) (status StopStatus, e error) {
|
||||
func StopServices(ctx context.Context, force bool, invoker constants.Invoker) (status StopStatus, e error) {
|
||||
|
||||
log.Printf("[TRACE] StopDB invoker %s, force %v", invoker, force)
|
||||
utils.LogTime("db_local.StopDB start")
|
||||
|
||||
@@ -108,15 +107,16 @@ func StopServices(force bool, invoker constants.Invoker, spinner *spinner.Spinne
|
||||
pluginManagerStopError := pluginmanager.Stop()
|
||||
|
||||
// stop the DB Service
|
||||
stopResult, dbStopError := stopDBService(spinner, force)
|
||||
stopResult, dbStopError := stopDBService(ctx, force)
|
||||
|
||||
return stopResult, utils.CombineErrors(dbStopError, pluginManagerStopError)
|
||||
}
|
||||
|
||||
func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) {
|
||||
func stopDBService(ctx context.Context, force bool) (StopStatus, error) {
|
||||
if force {
|
||||
// check if we have a process from another install-dir
|
||||
display.UpdateSpinnerMessage(spinner, "Checking for running instances...")
|
||||
statushooks.SetStatus(ctx, "Checking for running instances...")
|
||||
defer statushooks.Done(ctx)
|
||||
// do not use a context that can be cancelled
|
||||
killInstanceIfAny(context.Background())
|
||||
return ServiceStopped, nil
|
||||
@@ -139,9 +139,7 @@ func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) {
|
||||
return ServiceStopFailed, err
|
||||
}
|
||||
|
||||
display.UpdateSpinnerMessage(spinner, "Shutting down...")
|
||||
|
||||
err = doThreeStepPostgresExit(process)
|
||||
err = doThreeStepPostgresExit(ctx, process)
|
||||
if err != nil {
|
||||
// we couldn't stop it still.
|
||||
// timeout
|
||||
@@ -176,7 +174,7 @@ func stopDBService(spinner *spinner.Spinner, force bool) (StopStatus, error) {
|
||||
checked that the service can indeed shutdown gracefully,
|
||||
the sequence is there only as a backup.
|
||||
**/
|
||||
func doThreeStepPostgresExit(process *psutils.Process) error {
|
||||
func doThreeStepPostgresExit(ctx context.Context, process *psutils.Process) error {
|
||||
utils.LogTime("db_local.doThreeStepPostgresExit start")
|
||||
defer utils.LogTime("db_local.doThreeStepPostgresExit end")
|
||||
|
||||
@@ -191,6 +189,11 @@ func doThreeStepPostgresExit(process *psutils.Process) error {
|
||||
exitSuccessful = waitForProcessExit(process, 2*time.Second)
|
||||
if !exitSuccessful {
|
||||
// process didn't quit
|
||||
|
||||
// set status, as this is taking time
|
||||
statushooks.SetStatus(ctx, "Shutting down...")
|
||||
defer statushooks.Done(ctx)
|
||||
|
||||
// try a SIGINT
|
||||
err = process.SendSignal(syscall.SIGINT)
|
||||
if err != nil {
|
||||
|
||||
@@ -2,6 +2,7 @@ package display
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -20,17 +21,17 @@ import (
|
||||
)
|
||||
|
||||
// ShowOutput :: displays the output using the proper formatter as applicable
|
||||
func ShowOutput(result *queryresult.Result) {
|
||||
func ShowOutput(ctx context.Context, result *queryresult.Result) {
|
||||
output := cmdconfig.Viper().GetString(constants.ArgOutput)
|
||||
if output == constants.OutputFormatJSON {
|
||||
displayJSON(result)
|
||||
displayJSON(ctx, result)
|
||||
} else if output == constants.OutputFormatCSV {
|
||||
displayCSV(result)
|
||||
displayCSV(ctx, result)
|
||||
} else if output == constants.OutputFormatLine {
|
||||
displayLine(result)
|
||||
displayLine(ctx, result)
|
||||
} else {
|
||||
// default
|
||||
displayTable(result)
|
||||
displayTable(ctx, result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -102,7 +103,7 @@ func getColumnSettings(headers []string, rows [][]string) ([]table.ColumnConfig,
|
||||
return colConfigs, headerRow
|
||||
}
|
||||
|
||||
func displayLine(result *queryresult.Result) {
|
||||
func displayLine(ctx context.Context, result *queryresult.Result) {
|
||||
colNames := ColumnNames(result.ColTypes)
|
||||
maxColNameLength := 0
|
||||
for _, colName := range colNames {
|
||||
@@ -158,7 +159,7 @@ func displayLine(result *queryresult.Result) {
|
||||
|
||||
// call this function for each row
|
||||
if err := iterateResults(result, rowFunc); err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -173,7 +174,7 @@ func getTerminalColumnsRequiredForString(str string) int {
|
||||
return colsRequired
|
||||
}
|
||||
|
||||
func displayJSON(result *queryresult.Result) {
|
||||
func displayJSON(ctx context.Context, result *queryresult.Result) {
|
||||
var jsonOutput []map[string]interface{}
|
||||
|
||||
// define function to add each row to the JSON output
|
||||
@@ -188,7 +189,7 @@ func displayJSON(result *queryresult.Result) {
|
||||
|
||||
// call this function for each row
|
||||
if err := iterateResults(result, rowFunc); err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
return
|
||||
}
|
||||
// display the JSON
|
||||
@@ -202,7 +203,7 @@ func displayJSON(result *queryresult.Result) {
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func displayCSV(result *queryresult.Result) {
|
||||
func displayCSV(ctx context.Context, result *queryresult.Result) {
|
||||
csvWriter := csv.NewWriter(os.Stdout)
|
||||
csvWriter.Comma = []rune(cmdconfig.Viper().GetString(constants.ArgSeparator))[0]
|
||||
|
||||
@@ -219,17 +220,17 @@ func displayCSV(result *queryresult.Result) {
|
||||
|
||||
// call this function for each row
|
||||
if err := iterateResults(result, rowFunc); err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
return
|
||||
}
|
||||
|
||||
csvWriter.Flush()
|
||||
if csvWriter.Error() != nil {
|
||||
utils.ShowErrorWithMessage(csvWriter.Error(), "unable to print csv")
|
||||
utils.ShowErrorWithMessage(ctx, csvWriter.Error(), "unable to print csv")
|
||||
}
|
||||
}
|
||||
|
||||
func displayTable(result *queryresult.Result) {
|
||||
func displayTable(ctx context.Context, result *queryresult.Result) {
|
||||
// the buffer to put the output data in
|
||||
outbuf := bytes.NewBufferString("")
|
||||
|
||||
@@ -271,14 +272,14 @@ func displayTable(result *queryresult.Result) {
|
||||
if err != nil {
|
||||
// display the error
|
||||
fmt.Println()
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
fmt.Println()
|
||||
}
|
||||
// write out the table to the buffer
|
||||
t.Render()
|
||||
|
||||
// page out the table
|
||||
ShowPaged(outbuf.String())
|
||||
ShowPaged(ctx, outbuf.String())
|
||||
|
||||
// if timer is turned on
|
||||
if cmdconfig.Viper().GetBool(constants.ArgTimer) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package display
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
@@ -15,9 +16,9 @@ import (
|
||||
)
|
||||
|
||||
// ShowPaged displays the `content` in a system dependent pager
|
||||
func ShowPaged(content string) {
|
||||
func ShowPaged(ctx context.Context, content string) {
|
||||
if isPagerNeeded(content) && (runtime.GOOS == "darwin" || runtime.GOOS == "linux") {
|
||||
nixPager(content)
|
||||
nixPager(ctx, content)
|
||||
} else {
|
||||
nullPager(content)
|
||||
}
|
||||
@@ -59,11 +60,11 @@ func nullPager(content string) {
|
||||
fmt.Print(content)
|
||||
}
|
||||
|
||||
func nixPager(content string) {
|
||||
func nixPager(ctx context.Context, content string) {
|
||||
if isLessAvailable() {
|
||||
execPager(exec.Command("less", "-SRXF"), content)
|
||||
execPager(ctx, exec.Command("less", "-SRXF"), content)
|
||||
} else if isMoreAvailable() {
|
||||
execPager(exec.Command("more"), content)
|
||||
execPager(ctx, exec.Command("more"), content)
|
||||
} else {
|
||||
nullPager(content)
|
||||
}
|
||||
@@ -79,13 +80,13 @@ func isMoreAvailable() bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func execPager(cmd *exec.Cmd, content string) {
|
||||
func execPager(ctx context.Context, cmd *exec.Cmd, content string) {
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = strings.NewReader(content)
|
||||
// run the command - it will block until the pager is exited
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
utils.ShowErrorWithMessage(err, "could not display results")
|
||||
utils.ShowErrorWithMessage(ctx, err, "could not display results")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
package display
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/karrick/gows"
|
||||
"github.com/spf13/viper"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
)
|
||||
|
||||
//
|
||||
// spinner format:
|
||||
// <spinner><space><message><space><dot><dot><dot><cursor>
|
||||
// 1 1 [.......] 1 1 1 1 1
|
||||
// We need at least seven characters to show the spinner properly
|
||||
//
|
||||
// Not using the (…) character, since it is too small
|
||||
//
|
||||
const minSpinnerWidth = 7
|
||||
|
||||
func truncateSpinnerMessageToScreen(msg string) string {
|
||||
if len(strings.TrimSpace(msg)) == 0 {
|
||||
// if this is a blank message, return it as is
|
||||
return msg
|
||||
}
|
||||
|
||||
maxCols, _, _ := gows.GetWinSize()
|
||||
// if the screen is smaller than the minimum spinner width, we cannot truncate
|
||||
if maxCols < minSpinnerWidth {
|
||||
return msg
|
||||
}
|
||||
availableColumns := maxCols - minSpinnerWidth
|
||||
if len(msg) > availableColumns {
|
||||
msg = msg[:availableColumns]
|
||||
msg = fmt.Sprintf("%s ...", msg)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// StartSpinnerAfterDelay shows the spinner with the given `msg` if and only if `cancelStartIf` resolves
|
||||
// after `delay`.
|
||||
//
|
||||
// Example: if delay is 2 seconds and `cancelStartIf` resolves after 2.5 seconds, the spinner
|
||||
// will show for 0.5 seconds. If `cancelStartIf` resolves after 1.5 seconds, the spinner will
|
||||
// NOT be shown at all
|
||||
//
|
||||
func StartSpinnerAfterDelay(msg string, delay time.Duration, cancelStartIf chan bool) *spinner.Spinner {
|
||||
if !viper.GetBool(constants.ConfigKeyIsTerminalTTY) {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg = truncateSpinnerMessageToScreen(msg)
|
||||
spinner := spinner.New(
|
||||
spinner.CharSets[14],
|
||||
100*time.Millisecond,
|
||||
spinner.WithHiddenCursor(true),
|
||||
spinner.WithWriter(os.Stdout),
|
||||
spinner.WithSuffix(fmt.Sprintf(" %s", msg)),
|
||||
)
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-cancelStartIf:
|
||||
case <-time.After(delay):
|
||||
if spinner != nil && !spinner.Active() {
|
||||
spinner.Start()
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}()
|
||||
|
||||
return spinner
|
||||
}
|
||||
|
||||
// ShowSpinner shows a spinner with the given message
|
||||
func ShowSpinner(msg string) *spinner.Spinner {
|
||||
if !viper.GetBool(constants.ConfigKeyIsTerminalTTY) {
|
||||
return nil
|
||||
}
|
||||
|
||||
msg = truncateSpinnerMessageToScreen(msg)
|
||||
s := spinner.New(
|
||||
spinner.CharSets[14],
|
||||
100*time.Millisecond,
|
||||
spinner.WithHiddenCursor(true),
|
||||
spinner.WithWriter(os.Stdout),
|
||||
spinner.WithSuffix(fmt.Sprintf(" %s", msg)),
|
||||
)
|
||||
s.Start()
|
||||
return s
|
||||
}
|
||||
|
||||
// StopSpinnerWithMessage stops a spinner instance and clears it, after writing `finalMsg`
|
||||
func StopSpinnerWithMessage(spinner *spinner.Spinner, finalMsg string) {
|
||||
if spinner != nil {
|
||||
spinner.FinalMSG = finalMsg
|
||||
spinner.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// StopSpinner stops a spinner instance and clears it
|
||||
func StopSpinner(spinner *spinner.Spinner) {
|
||||
if spinner != nil {
|
||||
spinner.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func ResumeSpinner(spinner *spinner.Spinner) {
|
||||
if spinner != nil && !spinner.Active() {
|
||||
spinner.Restart()
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSpinnerMessage updates the message of the given spinner
|
||||
func UpdateSpinnerMessage(spinner *spinner.Spinner, newMessage string) {
|
||||
if spinner != nil {
|
||||
newMessage = truncateSpinnerMessageToScreen(newMessage)
|
||||
spinner.Suffix = fmt.Sprintf(" %s", newMessage)
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package executionlayer
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/report/reportevents"
|
||||
"github.com/turbot/steampipe/report/reportexecute"
|
||||
@@ -11,6 +13,9 @@ import (
|
||||
)
|
||||
|
||||
func ExecuteReportNode(ctx context.Context, reportName string, workspace *workspace.Workspace, client db_common.Client) error {
|
||||
// create context for the report execution
|
||||
// (for now just disable all status messages - replace with event based? )
|
||||
reportCtx := statushooks.DisableStatusHooks(ctx)
|
||||
executionTree, err := reportexecute.NewReportExecutionTree(reportName, client, workspace)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -19,7 +24,7 @@ func ExecuteReportNode(ctx context.Context, reportName string, workspace *worksp
|
||||
go func() {
|
||||
workspace.PublishReportEvent(&reportevents.ExecutionStarted{ReportNode: executionTree.Root})
|
||||
|
||||
if err := executionTree.Execute(ctx); err != nil {
|
||||
if err := executionTree.Execute(reportCtx); err != nil {
|
||||
if executionTree.Root.GetRunStatus() == reportinterfaces.ReportRunError {
|
||||
// set error state on the root node
|
||||
executionTree.Root.SetError(err)
|
||||
|
||||
@@ -17,13 +17,14 @@ import (
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
"github.com/turbot/steampipe/cmdconfig"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/contexthelpers"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/query"
|
||||
"github.com/turbot/steampipe/query/metaquery"
|
||||
"github.com/turbot/steampipe/query/queryhistory"
|
||||
"github.com/turbot/steampipe/query/queryresult"
|
||||
"github.com/turbot/steampipe/schema"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/steampipeconfig"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
"github.com/turbot/steampipe/version"
|
||||
@@ -54,7 +55,11 @@ type InteractiveClient struct {
|
||||
// lock while execution is occurring to avoid errors/warnings being shown
|
||||
executionLock sync.Mutex
|
||||
schemaMetadata *schema.Metadata
|
||||
highlighter *Highlighter
|
||||
|
||||
highlighter *Highlighter
|
||||
|
||||
// status update hooks
|
||||
statusHook statushooks.StatusHooks
|
||||
}
|
||||
|
||||
func getHighlighter(theme string) *Highlighter {
|
||||
@@ -65,7 +70,7 @@ func getHighlighter(theme string) *Highlighter {
|
||||
)
|
||||
}
|
||||
|
||||
func newInteractiveClient(initData *query.InitData, resultsStreamer *queryresult.ResultStreamer) (*InteractiveClient, error) {
|
||||
func newInteractiveClient(ctx context.Context, initData *query.InitData, resultsStreamer *queryresult.ResultStreamer) (*InteractiveClient, error) {
|
||||
c := &InteractiveClient{
|
||||
initData: initData,
|
||||
resultsStreamer: resultsStreamer,
|
||||
@@ -75,30 +80,33 @@ func newInteractiveClient(initData *query.InitData, resultsStreamer *queryresult
|
||||
initResultChan: make(chan *db_common.InitResult, 1),
|
||||
highlighter: getHighlighter(viper.GetString(constants.ArgTheme)),
|
||||
}
|
||||
|
||||
// asynchronously wait for init to complete
|
||||
// we start this immediately rather than lazy loading as we want to handle errors asap
|
||||
go c.readInitDataStream()
|
||||
go c.readInitDataStream(ctx)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// InteractivePrompt starts an interactive prompt and return
|
||||
func (c *InteractiveClient) InteractivePrompt() {
|
||||
func (c *InteractiveClient) InteractivePrompt(ctx context.Context) {
|
||||
// start a cancel handler for the interactive client - this will call activeQueryCancelFunc if it is set
|
||||
// (registered when we call createQueryContext)
|
||||
interruptSignalChannel := c.startCancelHandler()
|
||||
interruptSignalChannel := contexthelpers.StartCancelHandler(c.cancelActiveQueryIfAny)
|
||||
|
||||
// create a cancel context for the prompt - this will set c.cancelPrompt
|
||||
promptCtx := c.createPromptContext()
|
||||
parentContext := ctx
|
||||
ctx = c.createPromptContext(parentContext)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
}
|
||||
// close up the SIGINT channel so that the receiver goroutine can quit
|
||||
signal.Stop(interruptSignalChannel)
|
||||
close(interruptSignalChannel)
|
||||
|
||||
// cleanup the init data to ensure any services we started are stopped
|
||||
c.initData.Cleanup()
|
||||
c.initData.Cleanup(ctx)
|
||||
|
||||
// close the result stream
|
||||
// this needs to be the last thing we do,
|
||||
@@ -106,18 +114,19 @@ func (c *InteractiveClient) InteractivePrompt() {
|
||||
c.resultsStreamer.Close()
|
||||
}()
|
||||
|
||||
fmt.Printf("Welcome to Steampipe v%s\n", version.SteampipeVersion.String())
|
||||
fmt.Printf("For more information, type %s\n", constants.Bold(".help"))
|
||||
statushooks.Message(ctx,
|
||||
fmt.Sprintf("Welcome to Steampipe v%s", version.SteampipeVersion.String()),
|
||||
fmt.Sprintf("For more information, type %s", constants.Bold(".help")))
|
||||
|
||||
// run the prompt in a goroutine, so we can also detect async initialisation errors
|
||||
promptResultChan := make(chan utils.InteractiveExitStatus, 1)
|
||||
c.runInteractivePromptAsync(promptCtx, &promptResultChan)
|
||||
c.runInteractivePromptAsync(ctx, &promptResultChan)
|
||||
|
||||
// select results
|
||||
for {
|
||||
select {
|
||||
case initResult := <-c.initResultChan:
|
||||
c.handleInitResult(promptCtx, initResult)
|
||||
c.handleInitResult(ctx, initResult)
|
||||
// if there was an error, handleInitResult will shut down the prompt
|
||||
// - we must wait for it to shut down and not return immediately
|
||||
|
||||
@@ -129,9 +138,9 @@ func (c *InteractiveClient) InteractivePrompt() {
|
||||
return
|
||||
}
|
||||
// create new context
|
||||
promptCtx = c.createPromptContext()
|
||||
ctx = c.createPromptContext(parentContext)
|
||||
// now run it again
|
||||
c.runInteractivePromptAsync(promptCtx, &promptResultChan)
|
||||
c.runInteractivePromptAsync(ctx, &promptResultChan)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -184,7 +193,7 @@ func (c *InteractiveClient) handleInitResult(ctx context.Context, initResult *db
|
||||
c.ClosePrompt(AfterPromptCloseExit)
|
||||
// add newline to ensure error is not printed at end of current prompt line
|
||||
fmt.Println()
|
||||
utils.ShowError(initResult.Error)
|
||||
utils.ShowError(ctx, initResult.Error)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -231,7 +240,7 @@ func (c *InteractiveClient) runInteractivePrompt(ctx context.Context) (ret utils
|
||||
}()
|
||||
|
||||
callExecutor := func(line string) {
|
||||
c.executor(line)
|
||||
c.executor(ctx, line)
|
||||
}
|
||||
completer := func(d prompt.Document) []prompt.Suggest {
|
||||
return c.queryCompleter(d)
|
||||
@@ -333,7 +342,7 @@ func (c *InteractiveClient) breakMultilinePrompt(buffer *prompt.Buffer) {
|
||||
c.interactiveBuffer = []string{}
|
||||
}
|
||||
|
||||
func (c *InteractiveClient) executor(line string) {
|
||||
func (c *InteractiveClient) executor(ctx context.Context, line string) {
|
||||
// take an execution lock, so that errors and warnings don't show up while
|
||||
// we are underway
|
||||
c.executionLock.Lock()
|
||||
@@ -347,10 +356,10 @@ func (c *InteractiveClient) executor(line string) {
|
||||
// we want to store even if we fail to resolve a query
|
||||
c.interactiveQueryHistory.Push(line)
|
||||
|
||||
query, err := c.getQuery(line)
|
||||
query, err := c.getQuery(ctx, line)
|
||||
if query == "" {
|
||||
if err != nil {
|
||||
utils.ShowError(utils.HandleCancelError(err))
|
||||
utils.ShowError(ctx, utils.HandleCancelError(err))
|
||||
}
|
||||
// restart the prompt
|
||||
c.restartInteractiveSession()
|
||||
@@ -358,20 +367,20 @@ func (c *InteractiveClient) executor(line string) {
|
||||
}
|
||||
|
||||
// create a context for the execution of the query
|
||||
queryContext := c.createQueryContext()
|
||||
queryContext := c.createQueryContext(ctx)
|
||||
|
||||
if metaquery.IsMetaQuery(query) {
|
||||
if err := c.executeMetaquery(queryContext, query); err != nil {
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
}
|
||||
// cancel the context
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
} else {
|
||||
// otherwise execute query
|
||||
result, err := c.client().Execute(queryContext, query, false)
|
||||
result, err := c.client().Execute(queryContext, query)
|
||||
if err != nil {
|
||||
utils.ShowError(utils.HandleCancelError(err))
|
||||
utils.ShowError(ctx, utils.HandleCancelError(err))
|
||||
} else {
|
||||
c.resultsStreamer.StreamResult(result)
|
||||
}
|
||||
@@ -381,7 +390,7 @@ func (c *InteractiveClient) executor(line string) {
|
||||
c.restartInteractiveSession()
|
||||
}
|
||||
|
||||
func (c *InteractiveClient) getQuery(line string) (string, error) {
|
||||
func (c *InteractiveClient) getQuery(ctx context.Context, line string) (string, error) {
|
||||
// if it's an empty line, then we don't need to do anything
|
||||
if line == "" {
|
||||
return "", nil
|
||||
@@ -391,23 +400,20 @@ func (c *InteractiveClient) getQuery(line string) (string, error) {
|
||||
if !c.isInitialised() {
|
||||
// create a context used purely to detect cancellation during initialisation
|
||||
// this will also set c.cancelActiveQuery
|
||||
queryContext := c.createQueryContext()
|
||||
queryContext := c.createQueryContext(ctx)
|
||||
defer func() {
|
||||
// cancel this context
|
||||
c.cancelActiveQueryIfAny()
|
||||
}()
|
||||
|
||||
initDoneChan := make(chan bool)
|
||||
sp := display.StartSpinnerAfterDelay("Initializing...", constants.SpinnerShowTimeout, initDoneChan)
|
||||
statushooks.SetStatus(ctx, "Initializing...")
|
||||
// wait for client initialisation to complete
|
||||
if err := c.waitForInitData(queryContext); err != nil {
|
||||
err := c.waitForInitData(queryContext)
|
||||
statushooks.Done(ctx)
|
||||
if err != nil {
|
||||
// if it failed, report error and quit
|
||||
close(initDoneChan)
|
||||
display.StopSpinner(sp)
|
||||
return "", err
|
||||
}
|
||||
close(initDoneChan)
|
||||
display.StopSpinner(sp)
|
||||
}
|
||||
|
||||
// push the current line into the buffer
|
||||
@@ -420,7 +426,7 @@ func (c *InteractiveClient) getQuery(line string) (string, error) {
|
||||
query, _, err := c.workspace().ResolveQueryAndArgs(queryString)
|
||||
if err != nil {
|
||||
// if we fail to resolve, show error but do not return it - we want to stay in the prompt
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
return "", nil
|
||||
}
|
||||
isNamedQuery := query != queryString
|
||||
|
||||
@@ -2,31 +2,21 @@ package interactive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func (c *InteractiveClient) startCancelHandler() chan os.Signal {
|
||||
interruptSignalChannel := make(chan os.Signal, 10)
|
||||
signal.Notify(interruptSignalChannel, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
for range interruptSignalChannel {
|
||||
c.cancelActiveQueryIfAny()
|
||||
}
|
||||
}()
|
||||
return interruptSignalChannel
|
||||
}
|
||||
|
||||
// create a cancel context for the interactive prompt, and set c.cancelFunc
|
||||
func (c *InteractiveClient) createPromptContext() context.Context {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func (c *InteractiveClient) createPromptContext(parentContext context.Context) context.Context {
|
||||
// ensure previous prompt is cleaned up
|
||||
if c.cancelPrompt != nil {
|
||||
c.cancelPrompt()
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parentContext)
|
||||
c.cancelPrompt = cancel
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (c *InteractiveClient) createQueryContext() context.Context {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func (c *InteractiveClient) createQueryContext(ctx context.Context) context.Context {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
c.cancelActiveQuery = cancel
|
||||
return ctx
|
||||
}
|
||||
|
||||
@@ -16,11 +16,11 @@ import (
|
||||
|
||||
var initTimeout = 40 * time.Second
|
||||
|
||||
func (c *InteractiveClient) readInitDataStream() {
|
||||
func (c *InteractiveClient) readInitDataStream(ctx context.Context) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
c.interactivePrompt.ClearScreen()
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
|
||||
}
|
||||
}()
|
||||
@@ -40,15 +40,15 @@ func (c *InteractiveClient) readInitDataStream() {
|
||||
// start the workspace file watcher
|
||||
if viper.GetBool(constants.ArgWatch) {
|
||||
// provide an explicit error handler which re-renders the prompt after displaying the error
|
||||
c.initData.Result.Error = c.initData.Workspace.SetupWatcher(c.initData.Client, c.workspaceWatcherErrorHandler)
|
||||
c.initData.Result.Error = c.initData.Workspace.SetupWatcher(ctx, c.initData.Client, c.workspaceWatcherErrorHandler)
|
||||
|
||||
}
|
||||
c.initResultChan <- c.initData.Result
|
||||
}
|
||||
|
||||
func (c *InteractiveClient) workspaceWatcherErrorHandler(err error) {
|
||||
func (c *InteractiveClient) workspaceWatcherErrorHandler(ctx context.Context, err error) {
|
||||
fmt.Println()
|
||||
utils.ShowError(err)
|
||||
utils.ShowError(ctx, err)
|
||||
c.interactivePrompt.Render()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package interactive
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/db/db_local"
|
||||
"github.com/turbot/steampipe/query"
|
||||
@@ -9,19 +11,19 @@ import (
|
||||
)
|
||||
|
||||
// RunInteractivePrompt starts the interactive query prompt
|
||||
func RunInteractivePrompt(initData *query.InitData) (*queryresult.ResultStreamer, error) {
|
||||
func RunInteractivePrompt(ctx context.Context, initData *query.InitData) (*queryresult.ResultStreamer, error) {
|
||||
resultsStreamer := queryresult.NewResultStreamer()
|
||||
|
||||
interactiveClient, err := newInteractiveClient(initData, resultsStreamer)
|
||||
interactiveClient, err := newInteractiveClient(ctx, initData, resultsStreamer)
|
||||
if err != nil {
|
||||
utils.ShowErrorWithMessage(err, "interactive client failed to initialize")
|
||||
utils.ShowErrorWithMessage(ctx, err, "interactive client failed to initialize")
|
||||
// do not bind shutdown to any cancellable context
|
||||
db_local.ShutdownService(constants.InvokerQuery)
|
||||
db_local.ShutdownService(ctx, constants.InvokerQuery)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// start the interactive prompt in a go routine
|
||||
go interactiveClient.InteractivePrompt()
|
||||
go interactiveClient.InteractivePrompt(ctx)
|
||||
|
||||
return resultsStreamer, nil
|
||||
}
|
||||
|
||||
12
main.go
12
main.go
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -17,11 +18,12 @@ var Logger hclog.Logger
|
||||
var exitCode int
|
||||
|
||||
func main() {
|
||||
ctx := context.Background()
|
||||
utils.LogTime("main start")
|
||||
exitCode := 0
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
}
|
||||
utils.LogTime("main end")
|
||||
utils.DisplayProfileData()
|
||||
@@ -29,7 +31,7 @@ func main() {
|
||||
}()
|
||||
|
||||
// ensure steampipe is not being run as root
|
||||
checkRoot()
|
||||
checkRoot(ctx)
|
||||
|
||||
// increase the soft ULIMIT to match the hard limit
|
||||
err := setULimit()
|
||||
@@ -60,10 +62,10 @@ func setULimit() error {
|
||||
|
||||
// this is to replicate the user security mechanism of out underlying
|
||||
// postgresql engine.
|
||||
func checkRoot() {
|
||||
func checkRoot(ctx context.Context) {
|
||||
if os.Geteuid() == 0 {
|
||||
exitCode = 1
|
||||
utils.ShowError(fmt.Errorf(`Steampipe cannot be run as the "root" user.
|
||||
utils.ShowError(ctx, fmt.Errorf(`Steampipe cannot be run as the "root" user.
|
||||
To reduce security risk, use an unprivileged user account instead.`))
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
@@ -79,7 +81,7 @@ To reduce security risk, use an unprivileged user account instead.`))
|
||||
|
||||
if os.Geteuid() != os.Getuid() {
|
||||
exitCode = 1
|
||||
utils.ShowError(fmt.Errorf("real and effective user IDs must match."))
|
||||
utils.ShowError(ctx, fmt.Errorf("real and effective user IDs must match."))
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
package modinstaller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/turbot/go-kit/helpers"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
func UninstallWorkspaceDependencies(opts *InstallOpts) (*InstallData, error) {
|
||||
func UninstallWorkspaceDependencies(ctx context.Context, opts *InstallOpts) (*InstallData, error) {
|
||||
utils.LogTime("cmd.UninstallWorkspaceDependencies")
|
||||
defer func() {
|
||||
utils.LogTime("cmd.UninstallWorkspaceDependencies end")
|
||||
if r := recover(); r != nil {
|
||||
utils.ShowError(helpers.ToError(r))
|
||||
utils.ShowError(ctx, helpers.ToError(r))
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/filepaths"
|
||||
"github.com/turbot/steampipe/ociinstaller"
|
||||
"github.com/turbot/steampipe/ociinstaller/versionfile"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
"github.com/turbot/steampipe/steampipeconfig/modconfig"
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
@@ -22,9 +22,9 @@ const (
|
||||
)
|
||||
|
||||
// Remove removes an installed plugin
|
||||
func Remove(image string, pluginConnections map[string][]modconfig.Connection) error {
|
||||
spinner := display.ShowSpinner(fmt.Sprintf("Removing plugin %s", image))
|
||||
defer display.StopSpinner(spinner)
|
||||
func Remove(ctx context.Context, image string, pluginConnections map[string][]modconfig.Connection) error {
|
||||
statushooks.SetStatus(ctx, fmt.Sprintf("Removing plugin %s", image))
|
||||
defer statushooks.Done(ctx)
|
||||
|
||||
fullPluginName := ociinstaller.NewSteampipeImageRef(image).DisplayImageRef()
|
||||
|
||||
@@ -60,7 +60,7 @@ func Remove(image string, pluginConnections map[string][]modconfig.Connection) e
|
||||
connFiles := Unique(files)
|
||||
|
||||
if len(connFiles) > 0 {
|
||||
display.StopSpinner(spinner)
|
||||
|
||||
str := []string{fmt.Sprintf("\nUninstalled plugin %s\n\nNote: the following %s %s %s steampipe %s using the '%s' plugin:", image, utils.Pluralize("file", len(connFiles)), utils.Pluralize("has", len(connFiles)), utils.Pluralize("a", len(conns)), utils.Pluralize("connection", len(conns)), image)}
|
||||
for _, file := range connFiles {
|
||||
str = append(str, fmt.Sprintf("\n \t* file: %s", file))
|
||||
@@ -78,7 +78,7 @@ func Remove(image string, pluginConnections map[string][]modconfig.Connection) e
|
||||
}
|
||||
}
|
||||
str = append(str, fmt.Sprintf("\nPlease remove %s to continue using steampipe", utils.Pluralize("it", len(connFiles))))
|
||||
fmt.Println(strings.Join(str, "\n"))
|
||||
statushooks.Message(ctx, str...)
|
||||
fmt.Println()
|
||||
}
|
||||
return err
|
||||
|
||||
@@ -37,7 +37,7 @@ func NewInitData(ctx context.Context, w *workspace.Workspace, args []string) *In
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *InitData) Cleanup() {
|
||||
func (i *InitData) Cleanup(ctx context.Context) {
|
||||
// cancel any ongoing operation
|
||||
if i.cancel != nil {
|
||||
i.cancel()
|
||||
@@ -50,7 +50,7 @@ func (i *InitData) Cleanup() {
|
||||
|
||||
// if a client was initialised, close it
|
||||
if i.Client != nil {
|
||||
i.Client.Close()
|
||||
i.Client.Close(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/spf13/viper"
|
||||
"github.com/turbot/steampipe/constants"
|
||||
"github.com/turbot/steampipe/contexthelpers"
|
||||
"github.com/turbot/steampipe/db/db_common"
|
||||
"github.com/turbot/steampipe/display"
|
||||
"github.com/turbot/steampipe/interactive"
|
||||
@@ -13,34 +14,36 @@ import (
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
func RunInteractiveSession(initData *query.InitData) {
|
||||
func RunInteractiveSession(ctx context.Context, initData *query.InitData) {
|
||||
utils.LogTime("execute.RunInteractiveSession start")
|
||||
defer utils.LogTime("execute.RunInteractiveSession end")
|
||||
|
||||
// the db executor sends result data over resultsStreamer
|
||||
resultsStreamer, err := interactive.RunInteractivePrompt(initData)
|
||||
resultsStreamer, err := interactive.RunInteractivePrompt(ctx, initData)
|
||||
utils.FailOnError(err)
|
||||
|
||||
// print the data as it comes
|
||||
for r := range resultsStreamer.Results {
|
||||
display.ShowOutput(r)
|
||||
display.ShowOutput(ctx, r)
|
||||
// signal to the resultStreamer that we are done with this chunk of the stream
|
||||
resultsStreamer.AllResultsRead()
|
||||
}
|
||||
}
|
||||
|
||||
func RunBatchSession(ctx context.Context, initData *query.InitData) int {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(ctx)
|
||||
|
||||
// start cancel handler to intercept interurpts and cancel the context
|
||||
contexthelpers.StartCancelHandler(cancel)
|
||||
|
||||
// wait for init
|
||||
<-initData.Loaded
|
||||
if err := initData.Result.Error; err != nil {
|
||||
utils.FailOnError(err)
|
||||
}
|
||||
// ensure we close client
|
||||
defer func() {
|
||||
if initData.Client != nil {
|
||||
initData.Client.Close()
|
||||
}
|
||||
}()
|
||||
defer initData.Cleanup(ctx)
|
||||
|
||||
// display any initialisation messages/warnings
|
||||
initData.Result.DisplayMessages()
|
||||
@@ -86,7 +89,7 @@ func executeQuery(ctx context.Context, queryString string, client db_common.Clie
|
||||
|
||||
// print the data as it comes
|
||||
for r := range resultsStreamer.Results {
|
||||
display.ShowOutput(r)
|
||||
display.ShowOutput(ctx, r)
|
||||
// signal to the resultStreamer that we are done with this result
|
||||
resultsStreamer.AllResultsRead()
|
||||
}
|
||||
|
||||
@@ -166,7 +166,7 @@ func (e *ReportExecutionTree) ExecuteNode(ctx context.Context, name string) erro
|
||||
}
|
||||
|
||||
func (e *ReportExecutionTree) executePanelSQL(ctx context.Context, query string) ([][]interface{}, error) {
|
||||
queryResult, err := e.client.ExecuteSync(ctx, query, true)
|
||||
queryResult, err := e.client.ExecuteSync(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ type ReportClientInfo struct {
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context) (*Server, error) {
|
||||
dbClient, err := db_local.GetLocalClient(ctx, constants.InvokerReport)
|
||||
var dbClient, err = db_local.GetLocalClient(ctx, constants.InvokerReport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func NewServer(ctx context.Context) (*Server, error) {
|
||||
}
|
||||
refreshResult.ShowWarnings()
|
||||
|
||||
loadedWorkspace, err := workspace.Load(viper.GetString(constants.ArgWorkspaceChDir))
|
||||
loadedWorkspace, err := workspace.Load(ctx, viper.GetString(constants.ArgWorkspaceChDir))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -78,7 +78,7 @@ func NewServer(ctx context.Context) (*Server, error) {
|
||||
}
|
||||
|
||||
loadedWorkspace.RegisterReportEventHandler(server.HandleWorkspaceUpdate)
|
||||
err = loadedWorkspace.SetupWatcher(dbClient, nil)
|
||||
err = loadedWorkspace.SetupWatcher(ctx, dbClient, nil)
|
||||
|
||||
return server, err
|
||||
}
|
||||
@@ -129,10 +129,10 @@ func (s *Server) Start() {
|
||||
StartAPI(s.context, s.webSocket)
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown() {
|
||||
func (s *Server) Shutdown(ctx context.Context) {
|
||||
// Close the DB client
|
||||
if s.dbClient != nil {
|
||||
s.dbClient.Close()
|
||||
s.dbClient.Close(ctx)
|
||||
}
|
||||
|
||||
if s.webSocket != nil {
|
||||
|
||||
42
statushooks/context.go
Normal file
42
statushooks/context.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package statushooks
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/turbot/steampipe/contexthelpers"
|
||||
)
|
||||
|
||||
var (
|
||||
contextKeyStatusHook = contexthelpers.ContextKey("status_hook")
|
||||
)
|
||||
|
||||
func DisableStatusHooks(ctx context.Context) context.Context {
|
||||
return AddStatusHooksToContext(ctx, NullHooks)
|
||||
}
|
||||
|
||||
func AddStatusHooksToContext(ctx context.Context, statusHooks StatusHooks) context.Context {
|
||||
return context.WithValue(ctx, contextKeyStatusHook, statusHooks)
|
||||
}
|
||||
|
||||
func StatusHooksFromContext(ctx context.Context) StatusHooks {
|
||||
if ctx == nil {
|
||||
return NullHooks
|
||||
}
|
||||
if val, ok := ctx.Value(contextKeyStatusHook).(StatusHooks); ok {
|
||||
return val
|
||||
}
|
||||
// no status hook in context - return null status hook
|
||||
return NullHooks
|
||||
}
|
||||
|
||||
func SetStatus(ctx context.Context, msg string) {
|
||||
StatusHooksFromContext(ctx).SetStatus(msg)
|
||||
}
|
||||
|
||||
func Done(ctx context.Context) {
|
||||
StatusHooksFromContext(ctx).Done()
|
||||
}
|
||||
|
||||
func Message(ctx context.Context, msgs ...string) {
|
||||
StatusHooksFromContext(ctx).Message(msgs...)
|
||||
}
|
||||
15
statushooks/status_hooks.go
Normal file
15
statushooks/status_hooks.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package statushooks
|
||||
|
||||
type StatusHooks interface {
|
||||
SetStatus(string)
|
||||
Done()
|
||||
Message(...string)
|
||||
}
|
||||
|
||||
var NullHooks = &NullStatusHook{}
|
||||
|
||||
type NullStatusHook struct{}
|
||||
|
||||
func (*NullStatusHook) SetStatus(string) {}
|
||||
func (*NullStatusHook) Done() {}
|
||||
func (*NullStatusHook) Message(...string) {}
|
||||
137
statusspinner/spinner.go
Normal file
137
statusspinner/spinner.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package statusspinner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/briandowns/spinner"
|
||||
"github.com/karrick/gows"
|
||||
)
|
||||
|
||||
//
|
||||
// spinner format:
|
||||
// <spinner><space><message><space><dot><dot><dot><cursor>
|
||||
// 1 1 [.......] 1 1 1 1 1
|
||||
// We need at least seven characters to show the spinner properly
|
||||
//
|
||||
// Not using the (…) character, since it is too small
|
||||
//
|
||||
const minSpinnerWidth = 7
|
||||
|
||||
// StatusSpinner is a struct which implements StatusHooks, and uses a spinner to display status messages
|
||||
type StatusSpinner struct {
|
||||
spinner *spinner.Spinner
|
||||
delay time.Duration
|
||||
cancel chan struct{}
|
||||
}
|
||||
|
||||
type StatusSpinnerOpt func(*StatusSpinner)
|
||||
|
||||
func WithMessage(msg string) StatusSpinnerOpt {
|
||||
return func(s *StatusSpinner) {
|
||||
s.UpdateSpinnerMessage(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func WithDelay(delay time.Duration) StatusSpinnerOpt {
|
||||
return func(s *StatusSpinner) {
|
||||
s.delay = delay
|
||||
}
|
||||
}
|
||||
|
||||
func NewStatusSpinner(opts ...StatusSpinnerOpt) *StatusSpinner {
|
||||
res := &StatusSpinner{}
|
||||
|
||||
res.spinner = spinner.New(
|
||||
spinner.CharSets[14],
|
||||
100*time.Millisecond,
|
||||
spinner.WithHiddenCursor(true),
|
||||
spinner.WithWriter(os.Stdout),
|
||||
)
|
||||
for _, opt := range opts {
|
||||
opt(res)
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
// SetStatus implements StatusHooks
|
||||
func (s *StatusSpinner) SetStatus(msg string) {
|
||||
s.UpdateSpinnerMessage(msg)
|
||||
if !s.spinner.Active() {
|
||||
s.startSpinner()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatusSpinner) startSpinner() {
|
||||
if s.cancel != nil {
|
||||
// if there is a cancel channel, we are already waiting for the service to start after a delay
|
||||
return
|
||||
}
|
||||
if s.delay == 0 {
|
||||
s.spinner.Start()
|
||||
return
|
||||
}
|
||||
|
||||
s.cancel = make(chan struct{}, 1)
|
||||
go func() {
|
||||
select {
|
||||
case <-s.cancel:
|
||||
case <-time.After(s.delay):
|
||||
s.spinner.Start()
|
||||
s.cancel = nil
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *StatusSpinner) Message(msgs ...string) {
|
||||
if s.spinner.Active() {
|
||||
s.spinner.Stop()
|
||||
defer s.spinner.Start()
|
||||
}
|
||||
for _, msg := range msgs {
|
||||
fmt.Println(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Done implements StatusHooks
|
||||
func (s *StatusSpinner) Done() {
|
||||
if s.cancel != nil {
|
||||
close(s.cancel)
|
||||
}
|
||||
s.closeSpinner()
|
||||
}
|
||||
|
||||
// UpdateSpinnerMessage updates the message of the given spinner
|
||||
func (s *StatusSpinner) UpdateSpinnerMessage(newMessage string) {
|
||||
newMessage = s.truncateSpinnerMessageToScreen(newMessage)
|
||||
s.spinner.Suffix = fmt.Sprintf(" %s", newMessage)
|
||||
}
|
||||
|
||||
func (s *StatusSpinner) closeSpinner() {
|
||||
if s.spinner != nil {
|
||||
s.spinner.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *StatusSpinner) truncateSpinnerMessageToScreen(msg string) string {
|
||||
if len(strings.TrimSpace(msg)) == 0 {
|
||||
// if this is a blank message, return it as is
|
||||
return msg
|
||||
}
|
||||
|
||||
maxCols, _, _ := gows.GetWinSize()
|
||||
// if the screen is smaller than the minimum spinner width, we cannot truncate
|
||||
if maxCols < minSpinnerWidth {
|
||||
return msg
|
||||
}
|
||||
availableColumns := maxCols - minSpinnerWidth
|
||||
if len(msg) > availableColumns {
|
||||
msg = msg[:availableColumns]
|
||||
msg = fmt.Sprintf("%s ...", msg)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/shiena/ansicolor"
|
||||
"github.com/turbot/steampipe/statushooks"
|
||||
)
|
||||
|
||||
var (
|
||||
colorErr = color.RedString("Error")
|
||||
colorWarn = color.YellowString("Warning")
|
||||
colorErr = color.RedString("Error")
|
||||
colorWarn = color.YellowString("Warning")
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -34,20 +35,22 @@ func FailOnErrorWithMessage(err error, message string) {
|
||||
}
|
||||
}
|
||||
|
||||
func ShowError(err error) {
|
||||
func ShowError(ctx context.Context, err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = HandleCancelError(err)
|
||||
statushooks.Done(ctx)
|
||||
fmt.Fprintf(color.Output, "%s: %v\n", colorErr, TransformErrorToSteampipe(err))
|
||||
}
|
||||
|
||||
// ShowErrorWithMessage displays the given error nicely with the given message
|
||||
func ShowErrorWithMessage(err error, message string) {
|
||||
func ShowErrorWithMessage(ctx context.Context, err error, message string) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
err = HandleCancelError(err)
|
||||
statushooks.Done(ctx)
|
||||
fmt.Fprintf(color.Output, "%s: %s - %v\n", colorErr, message, TransformErrorToSteampipe(err))
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ package workspace
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@@ -42,7 +43,7 @@ type Workspace struct {
|
||||
exclusions []string
|
||||
// should we load/watch files recursively
|
||||
listFlag filehelpers.ListFlag
|
||||
fileWatcherErrorHandler func(error)
|
||||
fileWatcherErrorHandler func(context.Context, error)
|
||||
watcherError error
|
||||
// event handlers
|
||||
reportEventHandlers []reportevents.ReportEventHandler
|
||||
@@ -53,7 +54,7 @@ type Workspace struct {
|
||||
}
|
||||
|
||||
// Load creates a Workspace and loads the workspace mod
|
||||
func Load(workspacePath string) (*Workspace, error) {
|
||||
func Load(ctx context.Context, workspacePath string) (*Workspace, error) {
|
||||
utils.LogTime("workspace.Load start")
|
||||
defer utils.LogTime("workspace.Load end")
|
||||
|
||||
@@ -72,7 +73,7 @@ func Load(workspacePath string) (*Workspace, error) {
|
||||
}
|
||||
|
||||
// load the workspace mod
|
||||
if err := workspace.loadWorkspaceMod(); err != nil {
|
||||
if err := workspace.loadWorkspaceMod(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -101,7 +102,7 @@ func LoadResourceNames(workspacePath string) (*modconfig.WorkspaceResources, err
|
||||
return workspace.loadWorkspaceResourceName()
|
||||
}
|
||||
|
||||
func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(error)) error {
|
||||
func (w *Workspace) SetupWatcher(ctx context.Context, client db_common.Client, errorHandler func(context.Context, error)) error {
|
||||
watcherOptions := &utils.WatcherOptions{
|
||||
Directories: []string{w.Path},
|
||||
Include: filehelpers.InclusionsFromExtensions(steampipeconfig.GetModFileExtensions()),
|
||||
@@ -112,7 +113,7 @@ func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(erro
|
||||
// decide how to handle them
|
||||
// OnError: errCallback,
|
||||
OnChange: func(events []fsnotify.Event) {
|
||||
w.handleFileWatcherEvent(client, events)
|
||||
w.handleFileWatcherEvent(ctx, client, events)
|
||||
},
|
||||
}
|
||||
watcher, err := utils.NewWatcher(watcherOptions)
|
||||
@@ -127,9 +128,9 @@ func (w *Workspace) SetupWatcher(client db_common.Client, errorHandler func(erro
|
||||
// after a file watcher event
|
||||
w.fileWatcherErrorHandler = errorHandler
|
||||
if w.fileWatcherErrorHandler == nil {
|
||||
w.fileWatcherErrorHandler = func(err error) {
|
||||
w.fileWatcherErrorHandler = func(ctx context.Context, err error) {
|
||||
fmt.Println()
|
||||
utils.ShowErrorWithMessage(err, "Failed to reload mod from file watcher")
|
||||
utils.ShowErrorWithMessage(ctx, err, "Failed to reload mod from file watcher")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -249,11 +250,11 @@ func (w *Workspace) setModfileExists() {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Workspace) loadWorkspaceMod() error {
|
||||
func (w *Workspace) loadWorkspaceMod(ctx context.Context) error {
|
||||
// clear all resource maps
|
||||
w.reset()
|
||||
// load and evaluate all variables
|
||||
inputVariables, err := w.getAllVariables()
|
||||
inputVariables, err := w.getAllVariables(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -22,7 +22,7 @@ func (w *Workspace) RegisterReportEventHandler(handler reportevents.ReportEventH
|
||||
w.reportEventHandlers = append(w.reportEventHandlers, handler)
|
||||
}
|
||||
|
||||
func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsnotify.Event) {
|
||||
func (w *Workspace) handleFileWatcherEvent(ctx context.Context, client db_common.Client, events []fsnotify.Event) {
|
||||
w.loadLock.Lock()
|
||||
defer w.loadLock.Unlock()
|
||||
|
||||
@@ -32,11 +32,11 @@ func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsn
|
||||
prevResourceMaps := w.GetResourceMaps()
|
||||
|
||||
// now reload the workspace
|
||||
err := w.loadWorkspaceMod()
|
||||
err := w.loadWorkspaceMod(ctx)
|
||||
if err != nil {
|
||||
// check the existing watcher error - if we are already in an error state, do not show error
|
||||
if w.watcherError == nil {
|
||||
w.fileWatcherErrorHandler(utils.PrefixError(err, "Failed to reload workspace"))
|
||||
w.fileWatcherErrorHandler(ctx, utils.PrefixError(err, "Failed to reload workspace"))
|
||||
}
|
||||
// now set watcher error to new error
|
||||
w.watcherError = err
|
||||
@@ -53,7 +53,7 @@ func (w *Workspace) handleFileWatcherEvent(client db_common.Client, events []fsn
|
||||
res := client.RefreshSessions(context.Background())
|
||||
if res.Error != nil || len(res.Warnings) > 0 {
|
||||
fmt.Println()
|
||||
utils.ShowErrorWithMessage(res.Error, "error when refreshing session data")
|
||||
utils.ShowErrorWithMessage(ctx, res.Error, "error when refreshing session data")
|
||||
utils.ShowWarning(strings.Join(res.Warnings, "\n"))
|
||||
if w.onFileWatcherEventMessages != nil {
|
||||
w.onFileWatcherEventMessages()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -135,7 +136,7 @@ var testCasesLoadWorkspace = map[string]loadWorkspaceTest{
|
||||
func TestLoadWorkspace(t *testing.T) {
|
||||
for name, test := range testCasesLoadWorkspace {
|
||||
workspacePath, err := filepath.Abs(test.source)
|
||||
workspace, err := Load(workspacePath)
|
||||
workspace, err := Load(context.Background(), workspacePath)
|
||||
|
||||
if err != nil {
|
||||
if test.expected != "ERROR" {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package workspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/turbot/steampipe/utils"
|
||||
)
|
||||
|
||||
func (w *Workspace) getAllVariables() (map[string]*modconfig.Variable, error) {
|
||||
func (w *Workspace) getAllVariables(ctx context.Context) (map[string]*modconfig.Variable, error) {
|
||||
// build options used to load workspace
|
||||
runCtx, err := w.getRunContext()
|
||||
if err != nil {
|
||||
@@ -40,7 +41,7 @@ func (w *Workspace) getAllVariables() (map[string]*modconfig.Variable, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := validateVariables(variableMap, inputVariables); err != nil {
|
||||
if err := validateVariables(ctx, variableMap, inputVariables); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -73,21 +74,21 @@ func (w *Workspace) getInputVariables(variableMap map[string]*modconfig.Variable
|
||||
return parsedValues, diags.Err()
|
||||
}
|
||||
|
||||
func validateVariables(variableMap map[string]*modconfig.Variable, variables inputvars.InputValues) error {
|
||||
func validateVariables(ctx context.Context, variableMap map[string]*modconfig.Variable, variables inputvars.InputValues) error {
|
||||
diags := inputvars.CheckInputVariables(variableMap, variables)
|
||||
if diags.HasErrors() {
|
||||
displayValidationErrors(diags)
|
||||
displayValidationErrors(ctx, diags)
|
||||
// return empty error
|
||||
return modconfig.VariableValidationFailedError{}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func displayValidationErrors(diags tfdiags.Diagnostics) {
|
||||
func displayValidationErrors(ctx context.Context, diags tfdiags.Diagnostics) {
|
||||
fmt.Println()
|
||||
for i, diag := range diags {
|
||||
|
||||
utils.ShowError(fmt.Errorf("%s", constants.Bold(diag.Description().Summary)))
|
||||
utils.ShowError(ctx, fmt.Errorf("%s", constants.Bold(diag.Description().Summary)))
|
||||
fmt.Println(diag.Description().Detail)
|
||||
if i < len(diags)-1 {
|
||||
fmt.Println()
|
||||
|
||||
Reference in New Issue
Block a user