mirror of
https://github.com/turbot/steampipe.git
synced 2025-12-19 18:12:43 -05:00
* Add test demonstrating validateQueryArgs race condition Add concurrent test that demonstrates the thread-safety issue with validateQueryArgs() using global viper state. The test fails with data races when run with -race flag. Issue #4706 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * Fix validateQueryArgs thread-safety by passing config struct Replace global viper state access with a queryConfig struct parameter in validateQueryArgs(). This eliminates race conditions by reading configuration once in the caller and passing immutable values. Changes: - Add queryConfig struct to hold validation parameters - Update validateQueryArgs to accept config parameter - Modify runQueryCmd to read viper once and create config - Update all tests to pass config struct instead of using viper This makes validateQueryArgs thread-safe and easier to test. Fixes #4706 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
30
cmd/query.go
30
cmd/query.go
@@ -30,6 +30,15 @@ var queryTimingMode = constants.QueryTimingModeOff
|
||||
// variable used to assign the output mode flag
|
||||
var queryOutputMode = constants.QueryOutputModeTable
|
||||
|
||||
// queryConfig holds the configuration needed for query validation
|
||||
// This avoids concurrent access to global viper state
|
||||
type queryConfig struct {
|
||||
snapshot bool
|
||||
share bool
|
||||
export []string
|
||||
output string
|
||||
}
|
||||
|
||||
func queryCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "query",
|
||||
@@ -93,8 +102,16 @@ func runQueryCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
}()
|
||||
|
||||
// Read configuration from viper once to avoid concurrent access issues
|
||||
cfg := &queryConfig{
|
||||
snapshot: viper.IsSet(pconstants.ArgSnapshot),
|
||||
share: viper.IsSet(pconstants.ArgShare),
|
||||
export: viper.GetStringSlice(pconstants.ArgExport),
|
||||
output: viper.GetString(pconstants.ArgOutput),
|
||||
}
|
||||
|
||||
// validate args
|
||||
err := validateQueryArgs(ctx, args)
|
||||
err := validateQueryArgs(ctx, args, cfg)
|
||||
error_helpers.FailOnError(err)
|
||||
|
||||
// if diagnostic mode is set, print out config and return
|
||||
@@ -150,13 +167,13 @@ func runQueryCmd(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
}
|
||||
|
||||
func validateQueryArgs(ctx context.Context, args []string) error {
|
||||
func validateQueryArgs(ctx context.Context, args []string, cfg *queryConfig) error {
|
||||
interactiveMode := len(args) == 0
|
||||
if interactiveMode && (viper.IsSet(pconstants.ArgSnapshot) || viper.IsSet(pconstants.ArgShare)) {
|
||||
if interactiveMode && (cfg.snapshot || cfg.share) {
|
||||
exitCode = constants.ExitCodeInsufficientOrWrongInputs
|
||||
return sperr.New("cannot share snapshots in interactive mode")
|
||||
}
|
||||
if interactiveMode && len(viper.GetStringSlice(pconstants.ArgExport)) > 0 {
|
||||
if interactiveMode && len(cfg.export) > 0 {
|
||||
exitCode = constants.ExitCodeInsufficientOrWrongInputs
|
||||
return sperr.New("cannot export query results in interactive mode")
|
||||
}
|
||||
@@ -168,10 +185,9 @@ func validateQueryArgs(ctx context.Context, args []string) error {
|
||||
}
|
||||
|
||||
validOutputFormats := []string{constants.OutputFormatLine, constants.OutputFormatCSV, constants.OutputFormatTable, constants.OutputFormatJSON, constants.OutputFormatSnapshot, constants.OutputFormatSnapshotShort, constants.OutputFormatNone}
|
||||
output := viper.GetString(pconstants.ArgOutput)
|
||||
if !slices.Contains(validOutputFormats, output) {
|
||||
if !slices.Contains(validOutputFormats, cfg.output) {
|
||||
exitCode = constants.ExitCodeInsufficientOrWrongInputs
|
||||
return sperr.New("invalid output format: '%s', must be one of [%s]", output, strings.Join(validOutputFormats, ", "))
|
||||
return sperr.New("invalid output format: '%s', must be one of [%s]", cfg.output, strings.Join(validOutputFormats, ", "))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/turbot/steampipe/v2/pkg/constants"
|
||||
)
|
||||
|
||||
func TestGetPipedStdinData_PreservesNewlines(t *testing.T) {
|
||||
@@ -52,3 +57,110 @@ func TestGetPipedStdinData_PreservesNewlines(t *testing.T) {
|
||||
t.Logf("Got lines: %v", resultLines)
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateQueryArgs_ConcurrentCalls tests that validateQueryArgs is thread-safe
|
||||
// Bug #4706: validateQueryArgs uses global viper state which is not thread-safe
|
||||
func TestValidateQueryArgs_ConcurrentCalls(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
var wg sync.WaitGroup
|
||||
errors := make(chan error, 100)
|
||||
|
||||
// Run 100 concurrent calls to validateQueryArgs
|
||||
for i := 0; i < 100; i++ {
|
||||
wg.Add(1)
|
||||
go func(iteration int) {
|
||||
defer wg.Done()
|
||||
|
||||
// Create config struct - this is now thread-safe
|
||||
// Each goroutine has its own config instance
|
||||
cfg := &queryConfig{
|
||||
snapshot: false,
|
||||
share: false,
|
||||
export: []string{},
|
||||
output: constants.OutputFormatTable,
|
||||
}
|
||||
|
||||
// Call validateQueryArgs with a query argument (non-interactive mode)
|
||||
err := validateQueryArgs(ctx, []string{"SELECT 1"}, cfg)
|
||||
if err != nil {
|
||||
errors <- err
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check if any errors occurred
|
||||
var errs []error
|
||||
for err := range errors {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
// The test should not panic or produce errors
|
||||
assert.Empty(t, errs, "validateQueryArgs should handle concurrent calls without errors")
|
||||
}
|
||||
|
||||
// TestValidateQueryArgs_InteractiveModeWithSnapshot tests validation in interactive mode with snapshot
|
||||
func TestValidateQueryArgs_InteractiveModeWithSnapshot(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup config with snapshot enabled
|
||||
cfg := &queryConfig{
|
||||
snapshot: true,
|
||||
share: false,
|
||||
export: []string{},
|
||||
output: constants.OutputFormatTable,
|
||||
}
|
||||
|
||||
// Call with no args (interactive mode)
|
||||
err := validateQueryArgs(ctx, []string{}, cfg)
|
||||
|
||||
// Should return error for snapshot in interactive mode
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot share snapshots in interactive mode")
|
||||
}
|
||||
|
||||
// TestValidateQueryArgs_BatchModeWithSnapshot tests validation in batch mode with snapshot
|
||||
func TestValidateQueryArgs_BatchModeWithSnapshot(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup config with snapshot enabled
|
||||
cfg := &queryConfig{
|
||||
snapshot: true,
|
||||
share: false,
|
||||
export: []string{},
|
||||
output: constants.OutputFormatTable,
|
||||
}
|
||||
|
||||
// Call with args (batch mode)
|
||||
err := validateQueryArgs(ctx, []string{"SELECT 1"}, cfg)
|
||||
|
||||
// Should not return error for snapshot in batch mode
|
||||
// (unless there are other validation errors from cmdconfig.ValidateSnapshotArgs)
|
||||
// For this test, we expect it to pass basic validation
|
||||
if err != nil {
|
||||
// If there's an error, it should not be about interactive mode
|
||||
assert.NotContains(t, err.Error(), "cannot share snapshots in interactive mode")
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateQueryArgs_InvalidOutputFormat tests validation with invalid output format
|
||||
func TestValidateQueryArgs_InvalidOutputFormat(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Setup config with invalid output format
|
||||
cfg := &queryConfig{
|
||||
snapshot: false,
|
||||
share: false,
|
||||
export: []string{},
|
||||
output: "invalid-format",
|
||||
}
|
||||
|
||||
// Call with args (batch mode)
|
||||
err := validateQueryArgs(ctx, []string{"SELECT 1"}, cfg)
|
||||
|
||||
// Should return error for invalid output format
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid output format")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user