Add comprehensive passing tests from bug hunting initiative (#4864)

This commit is contained in:
Nathan Wallace
2025-11-13 09:26:46 +08:00
committed by GitHub
parent aaef09b670
commit 2e5f3fda97
17 changed files with 5551 additions and 1 deletions

View File

@@ -0,0 +1,232 @@
package cmdconfig
import (
"testing"
"time"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
func TestPostRunHook_WaitsForTasks(t *testing.T) {
// Test that postRunHook waits for async tasks
cmd := &cobra.Command{
Use: "test",
Run: func(cmd *cobra.Command, args []string) {},
}
// Simulate a task channel
testChannel := make(chan struct{})
oldChannel := waitForTasksChannel
waitForTasksChannel = testChannel
defer func() { waitForTasksChannel = oldChannel }()
// Close the channel after a short delay
go func() {
time.Sleep(10 * time.Millisecond)
close(testChannel)
}()
start := time.Now()
postRunHook(cmd, []string{})
duration := time.Since(start)
// Should have waited for the channel to close
if duration < 10*time.Millisecond {
t.Error("postRunHook did not wait for tasks channel")
}
}
func TestPostRunHook_Timeout(t *testing.T) {
// Test that postRunHook times out if tasks take too long
cmd := &cobra.Command{
Use: "test",
Run: func(cmd *cobra.Command, args []string) {},
}
// Simulate a task channel that never closes
testChannel := make(chan struct{})
oldChannel := waitForTasksChannel
waitForTasksChannel = testChannel
defer func() {
waitForTasksChannel = oldChannel
close(testChannel)
}()
// Mock cancel function
cancelCalled := false
oldCancelFn := tasksCancelFn
tasksCancelFn = func() {
cancelCalled = true
}
defer func() { tasksCancelFn = oldCancelFn }()
start := time.Now()
postRunHook(cmd, []string{})
duration := time.Since(start)
// Should have timed out after 100ms
if duration < 100*time.Millisecond || duration > 150*time.Millisecond {
t.Errorf("postRunHook timeout not working correctly, took %v", duration)
}
if !cancelCalled {
t.Error("Cancel function was not called on timeout")
}
}
func TestCmdBuilder_HookIntegration(t *testing.T) {
// Test that CmdBuilder properly wraps hooks
cmd := &cobra.Command{
Use: "test",
Run: func(cmd *cobra.Command, args []string) {},
}
cmd.PreRun = func(cmd *cobra.Command, args []string) {
// Original PreRun
}
cmd.PostRun = func(cmd *cobra.Command, args []string) {
// Original PostRun
}
cmd.Run = func(cmd *cobra.Command, args []string) {
// Original Run
}
// Build with CmdBuilder
builder := OnCmd(cmd)
if builder == nil {
t.Fatal("OnCmd returned nil")
}
// The hooks should now be wrapped
if cmd.PreRun == nil {
t.Error("PreRun hook was not set")
}
if cmd.PostRun == nil {
t.Error("PostRun hook was not set")
}
if cmd.Run == nil {
t.Error("Run hook was not set")
}
// Note: We can't easily test the wrapped functions without a full cobra execution
// This would require integration tests
t.Log("CmdBuilder successfully wrapped command hooks")
}
func TestCmdBuilder_FlagBinding(t *testing.T) {
// Test that CmdBuilder properly binds flags to viper
viper.Reset()
defer viper.Reset()
cmd := &cobra.Command{
Use: "test",
Run: func(cmd *cobra.Command, args []string) {},
}
builder := OnCmd(cmd)
builder.AddStringFlag("test-flag", "default-value", "Test flag description")
// Verify flag was added
flag := cmd.Flags().Lookup("test-flag")
if flag == nil {
t.Fatal("Flag was not added to command")
}
if flag.DefValue != "default-value" {
t.Errorf("Flag default value incorrect, got %s", flag.DefValue)
}
// Verify binding was stored
if len(builder.bindings) != 1 {
t.Errorf("Expected 1 binding, got %d", len(builder.bindings))
}
if builder.bindings["test-flag"] != flag {
t.Error("Flag binding not stored correctly")
}
}
func TestCmdBuilder_MultipleFlagTypes(t *testing.T) {
// Test that CmdBuilder can handle multiple flag types
cmd := &cobra.Command{
Use: "test",
Run: func(cmd *cobra.Command, args []string) {},
}
builder := OnCmd(cmd)
builder.
AddStringFlag("string-flag", "default", "String flag").
AddIntFlag("int-flag", 42, "Int flag").
AddBoolFlag("bool-flag", true, "Bool flag").
AddStringSliceFlag("slice-flag", []string{"a", "b"}, "Slice flag")
// Verify all flags were added
if cmd.Flags().Lookup("string-flag") == nil {
t.Error("String flag not added")
}
if cmd.Flags().Lookup("int-flag") == nil {
t.Error("Int flag not added")
}
if cmd.Flags().Lookup("bool-flag") == nil {
t.Error("Bool flag not added")
}
if cmd.Flags().Lookup("slice-flag") == nil {
t.Error("Slice flag not added")
}
// Verify all bindings were stored
if len(builder.bindings) != 4 {
t.Errorf("Expected 4 bindings, got %d", len(builder.bindings))
}
}
func TestCmdBuilder_CloudFlags(t *testing.T) {
// Test that AddCloudFlags adds the expected flags
cmd := &cobra.Command{
Use: "test",
Run: func(cmd *cobra.Command, args []string) {},
}
builder := OnCmd(cmd)
builder.AddCloudFlags()
// Verify cloud flags were added
if cmd.Flags().Lookup("pipes-host") == nil {
t.Error("pipes-host flag not added")
}
if cmd.Flags().Lookup("pipes-token") == nil {
t.Error("pipes-token flag not added")
}
}
func TestCmdBuilder_NilFlagPanic(t *testing.T) {
// Test that nil flag causes panic (as documented in builder.go)
cmd := &cobra.Command{
Use: "test",
PreRun: func(cmd *cobra.Command, args []string) {
// This will be called by CmdBuilder's wrapped PreRun
},
Run: func(cmd *cobra.Command, args []string) {},
}
builder := OnCmd(cmd)
builder.AddStringFlag("test-flag", "default", "Test flag")
// Manually corrupt the bindings to test panic
builder.bindings["corrupt-flag"] = nil
// This should panic when PreRun is executed
defer func() {
if r := recover(); r == nil {
t.Error("Expected panic for nil flag binding")
} else {
t.Logf("Correctly panicked with: %v", r)
}
}()
// Execute PreRun which should panic
cmd.PreRun(cmd, []string{})
}

View File

@@ -0,0 +1,481 @@
package connection
import (
"context"
"errors"
"fmt"
"sync"
"testing"
"time"
"github.com/turbot/steampipe/v2/pkg/steampipeconfig"
)
// TestRefreshConnectionState_ExemplarSchemaMapConcurrentWrites tests concurrent writes to exemplarSchemaMap
// This verifies the fix for bug #4757
func TestRefreshConnectionState_ExemplarSchemaMapConcurrentWrites(t *testing.T) {
// ARRANGE: Create state with initialized maps
state := &refreshConnectionState{
exemplarSchemaMap: make(map[string]string),
exemplarSchemaMapMut: sync.Mutex{},
}
numGoroutines := 50
numIterations := 100
plugins := []string{"aws", "azure", "gcp", "github", "slack"}
var wg sync.WaitGroup
// ACT: Launch goroutines that concurrently write to exemplarSchemaMap
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < numIterations; j++ {
plugin := plugins[j%len(plugins)]
connectionName := fmt.Sprintf("conn_%d_%d", id, j)
// Simulate the FIXED pattern from executeUpdateForConnections (lines 600-605)
state.exemplarSchemaMapMut.Lock()
_, haveExemplar := state.exemplarSchemaMap[plugin]
state.exemplarSchemaMapMut.Unlock()
if !haveExemplar {
// This write is now protected by mutex (fix for #4757)
state.exemplarSchemaMapMut.Lock()
state.exemplarSchemaMap[plugin] = connectionName
state.exemplarSchemaMapMut.Unlock()
}
}
}(i)
}
wg.Wait()
// ASSERT: Verify all plugins are in the map
state.exemplarSchemaMapMut.Lock()
defer state.exemplarSchemaMapMut.Unlock()
if len(state.exemplarSchemaMap) != len(plugins) {
t.Errorf("Expected %d plugins in exemplarSchemaMap, got %d", len(plugins), len(state.exemplarSchemaMap))
}
for _, plugin := range plugins {
if _, ok := state.exemplarSchemaMap[plugin]; !ok {
t.Errorf("Expected plugin %s to be in exemplarSchemaMap", plugin)
}
}
}
// TestRefreshConnectionState_ExemplarSchemaMapConcurrentReadWrite tests concurrent reads and writes
func TestRefreshConnectionState_ExemplarSchemaMapConcurrentReadWrite(t *testing.T) {
// ARRANGE: Create state with some pre-populated data
state := &refreshConnectionState{
exemplarSchemaMap: map[string]string{
"aws": "aws_conn_1",
"azure": "azure_conn_1",
},
exemplarSchemaMapMut: sync.Mutex{},
}
numReaders := 30
numWriters := 20
duration := 100 * time.Millisecond
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), duration)
defer cancel()
// ACT: Launch reader goroutines
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
state.exemplarSchemaMapMut.Lock()
_ = state.exemplarSchemaMap["aws"]
state.exemplarSchemaMapMut.Unlock()
}
}
}()
}
// Launch writer goroutines
for i := 0; i < numWriters; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for {
select {
case <-ctx.Done():
return
default:
plugin := fmt.Sprintf("plugin_%d", id)
state.exemplarSchemaMapMut.Lock()
state.exemplarSchemaMap[plugin] = fmt.Sprintf("conn_%d", id)
state.exemplarSchemaMapMut.Unlock()
}
}
}(i)
}
wg.Wait()
// ASSERT: No race conditions should occur (run with -race flag)
state.exemplarSchemaMapMut.Lock()
defer state.exemplarSchemaMapMut.Unlock()
// Basic sanity check
if len(state.exemplarSchemaMap) < 2 {
t.Error("Expected at least 2 entries in exemplarSchemaMap")
}
}
// TestRefreshConnectionState_ExemplarMapRaceCondition tests the exact race condition from bug #4757
func TestRefreshConnectionState_ExemplarMapRaceCondition(t *testing.T) {
// This test verifies that the fix for #4757 works correctly
// The bug was: reading haveExemplarSchema without lock, then writing without lock
// The fix: both read and write are now properly protected by mutex
// ARRANGE
state := &refreshConnectionState{
exemplarSchemaMap: make(map[string]string),
exemplarSchemaMapMut: sync.Mutex{},
}
numGoroutines := 100
pluginName := "aws"
var wg sync.WaitGroup
errChan := make(chan error, numGoroutines)
// ACT: Simulate the exact pattern from executeUpdateForConnections
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
connectionName := fmt.Sprintf("aws_conn_%d", id)
// This is the FIXED pattern from lines 581-604
state.exemplarSchemaMapMut.Lock()
_, haveExemplarSchema := state.exemplarSchemaMap[pluginName]
state.exemplarSchemaMapMut.Unlock()
// Simulate some work
time.Sleep(time.Microsecond)
if !haveExemplarSchema {
// Write is now protected by mutex (fix for #4757)
state.exemplarSchemaMapMut.Lock()
// Check again after acquiring lock (double-check pattern)
if _, exists := state.exemplarSchemaMap[pluginName]; !exists {
state.exemplarSchemaMap[pluginName] = connectionName
}
state.exemplarSchemaMapMut.Unlock()
}
}(i)
}
wg.Wait()
close(errChan)
// ASSERT: Check for errors
for err := range errChan {
t.Error(err)
}
// Verify the map has exactly one entry for the plugin
state.exemplarSchemaMapMut.Lock()
defer state.exemplarSchemaMapMut.Unlock()
if len(state.exemplarSchemaMap) != 1 {
t.Errorf("Expected exactly 1 entry in exemplarSchemaMap, got %d", len(state.exemplarSchemaMap))
}
if _, ok := state.exemplarSchemaMap[pluginName]; !ok {
t.Error("Expected plugin to be in exemplarSchemaMap")
}
}
// TestUpdateSetMapToArray tests the conversion utility function
func TestUpdateSetMapToArray(t *testing.T) {
tests := []struct {
name string
input map[string][]*steampipeconfig.ConnectionState
expected int
}{
{
name: "empty_map",
input: map[string][]*steampipeconfig.ConnectionState{},
expected: 0,
},
{
name: "single_entry_single_state",
input: map[string][]*steampipeconfig.ConnectionState{
"plugin1": {
{ConnectionName: "conn1"},
},
},
expected: 1,
},
{
name: "single_entry_multiple_states",
input: map[string][]*steampipeconfig.ConnectionState{
"plugin1": {
{ConnectionName: "conn1"},
{ConnectionName: "conn2"},
{ConnectionName: "conn3"},
},
},
expected: 3,
},
{
name: "multiple_entries",
input: map[string][]*steampipeconfig.ConnectionState{
"plugin1": {
{ConnectionName: "conn1"},
{ConnectionName: "conn2"},
},
"plugin2": {
{ConnectionName: "conn3"},
},
},
expected: 3,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// ACT
result := updateSetMapToArray(tt.input)
// ASSERT
if len(result) != tt.expected {
t.Errorf("Expected %d connection states, got %d", tt.expected, len(result))
}
})
}
}
// TestGetCloneSchemaQuery tests the schema cloning query generation
func TestGetCloneSchemaQuery(t *testing.T) {
tests := []struct {
name string
exemplarName string
connState *steampipeconfig.ConnectionState
expectedQuery string
}{
{
name: "basic_clone",
exemplarName: "aws_source",
connState: &steampipeconfig.ConnectionState{
ConnectionName: "aws_target",
Plugin: "hub.steampipe.io/plugins/turbot/aws@latest",
},
expectedQuery: "select clone_foreign_schema('aws_source', 'aws_target', 'hub.steampipe.io/plugins/turbot/aws@latest');",
},
{
name: "with_special_characters",
exemplarName: "test-source",
connState: &steampipeconfig.ConnectionState{
ConnectionName: "test-target",
Plugin: "test/plugin@1.0.0",
},
expectedQuery: "select clone_foreign_schema('test-source', 'test-target', 'test/plugin@1.0.0');",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// ACT
result := getCloneSchemaQuery(tt.exemplarName, tt.connState)
// ASSERT
if result != tt.expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", tt.expectedQuery, result)
}
})
}
}
// TestRefreshConnectionState_DeferErrorHandling tests error handling in defer blocks
func TestRefreshConnectionState_DeferErrorHandling(t *testing.T) {
// This tests the defer block at lines 98-108 in refreshConnections
// ARRANGE: Create state with a result that will have an error
state := &refreshConnectionState{
res: &steampipeconfig.RefreshConnectionResult{},
}
// Simulate setting an error
testErr := errors.New("test error")
state.res.Error = testErr
// ACT: The defer block should handle this gracefully
// In the actual code, this is called via defer func()
// We're testing the logic here
// ASSERT: Verify the defer logic works
if state.res != nil && state.res.Error != nil {
// This is what the defer does - it would call setIncompleteConnectionStateToError
// We're just verifying the nil checks work
if state.res.Error != testErr {
t.Error("Error should be preserved")
}
}
}
// TestRefreshConnectionState_NilResInDefer tests nil res handling in defer block
func TestRefreshConnectionState_NilResInDefer(t *testing.T) {
// ARRANGE: Create state with nil res
state := &refreshConnectionState{
res: nil,
}
// ACT & ASSERT: The defer block at line 98-108 checks if res is nil
// This should not panic
if state.res != nil {
t.Error("res should be nil")
}
}
// TestRefreshConnectionState_MultiplePluginsSameExemplar tests that only one exemplar is stored per plugin
func TestRefreshConnectionState_MultiplePluginsSameExemplar(t *testing.T) {
// ARRANGE
state := &refreshConnectionState{
exemplarSchemaMap: make(map[string]string),
exemplarSchemaMapMut: sync.Mutex{},
}
pluginName := "aws"
connections := []string{"aws1", "aws2", "aws3", "aws4", "aws5"}
// ACT: Add connections sequentially (simulating the pattern from the code)
for _, conn := range connections {
state.exemplarSchemaMapMut.Lock()
_, exists := state.exemplarSchemaMap[pluginName]
state.exemplarSchemaMapMut.Unlock()
if !exists {
state.exemplarSchemaMapMut.Lock()
// Double-check pattern
if _, exists := state.exemplarSchemaMap[pluginName]; !exists {
state.exemplarSchemaMap[pluginName] = conn
}
state.exemplarSchemaMapMut.Unlock()
}
}
// ASSERT: Only the first connection should be stored
state.exemplarSchemaMapMut.Lock()
defer state.exemplarSchemaMapMut.Unlock()
if len(state.exemplarSchemaMap) != 1 {
t.Errorf("Expected 1 entry, got %d", len(state.exemplarSchemaMap))
}
if exemplar, ok := state.exemplarSchemaMap[pluginName]; !ok {
t.Error("Expected plugin to be in map")
} else if exemplar != connections[0] {
t.Errorf("Expected first connection %s to be exemplar, got %s", connections[0], exemplar)
}
}
// TestRefreshConnectionState_ErrorChannelBlocking tests that error channel doesn't block
func TestRefreshConnectionState_ErrorChannelBlocking(t *testing.T) {
// This tests a potential bug in executeUpdateSetsInParallel where the error channel
// could block if it's not properly drained
// ARRANGE
errChan := make(chan *connectionError, 10) // Buffered channel
numErrors := 20 // More errors than buffer size
var wg sync.WaitGroup
// Start a consumer goroutine (like in the actual code at line 519-536)
consumerDone := make(chan bool)
go func() {
for {
select {
case err := <-errChan:
if err == nil {
consumerDone <- true
return
}
// Process error
_ = err
}
}
}()
// ACT: Send many errors
for i := 0; i < numErrors; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
errChan <- &connectionError{
name: fmt.Sprintf("conn_%d", id),
err: fmt.Errorf("error %d", id),
}
}(i)
}
wg.Wait()
close(errChan)
// Wait for consumer to finish
select {
case <-consumerDone:
// Good - consumer exited
case <-time.After(1 * time.Second):
t.Error("Error channel consumer did not exit in time")
}
// ASSERT: No goroutines should be blocked
}
// TestRefreshConnectionState_ExemplarMapNilPlugin tests handling of empty plugin names
func TestRefreshConnectionState_ExemplarMapNilPlugin(t *testing.T) {
// ARRANGE
state := &refreshConnectionState{
exemplarSchemaMap: make(map[string]string),
exemplarSchemaMapMut: sync.Mutex{},
}
// ACT: Try to add entry with empty plugin name
state.exemplarSchemaMapMut.Lock()
state.exemplarSchemaMap[""] = "some_connection"
state.exemplarSchemaMapMut.Unlock()
// ASSERT: Map should accept empty string as key (Go maps allow this)
state.exemplarSchemaMapMut.Lock()
defer state.exemplarSchemaMapMut.Unlock()
if _, ok := state.exemplarSchemaMap[""]; !ok {
t.Error("Expected empty string key to be in map")
}
}
// TestConnectionError tests the connectionError struct
func TestConnectionError(t *testing.T) {
// ARRANGE
testErr := errors.New("test error")
connErr := &connectionError{
name: "test_connection",
err: testErr,
}
// ASSERT
if connErr.name != "test_connection" {
t.Errorf("Expected name 'test_connection', got '%s'", connErr.name)
}
if connErr.err != testErr {
t.Error("Error not preserved")
}
}

View File

@@ -0,0 +1,168 @@
package db_client
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/turbot/steampipe/v2/pkg/db/db_common"
)
// TestDbClient_SessionRegistration verifies session registration in sessions map
func TestDbClient_SessionRegistration(t *testing.T) {
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
// Simulate session registration
backendPid := uint32(12345)
session := db_common.NewDBSession(backendPid)
client.sessionsMutex.Lock()
client.sessions[backendPid] = session
client.sessionsMutex.Unlock()
// Verify session is registered
client.sessionsMutex.Lock()
registeredSession, found := client.sessions[backendPid]
client.sessionsMutex.Unlock()
assert.True(t, found, "Session should be registered")
assert.Equal(t, backendPid, registeredSession.BackendPid, "Backend PID should match")
}
// TestDbClient_SessionUnregistration verifies session cleanup via BeforeClose
func TestDbClient_SessionUnregistration(t *testing.T) {
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
// Add sessions
backendPid1 := uint32(100)
backendPid2 := uint32(200)
client.sessionsMutex.Lock()
client.sessions[backendPid1] = db_common.NewDBSession(backendPid1)
client.sessions[backendPid2] = db_common.NewDBSession(backendPid2)
client.sessionsMutex.Unlock()
assert.Len(t, client.sessions, 2, "Should have 2 sessions")
// Simulate BeforeClose callback for one session
client.sessionsMutex.Lock()
delete(client.sessions, backendPid1)
client.sessionsMutex.Unlock()
// Verify only one session remains
client.sessionsMutex.Lock()
_, found1 := client.sessions[backendPid1]
_, found2 := client.sessions[backendPid2]
client.sessionsMutex.Unlock()
assert.False(t, found1, "First session should be removed")
assert.True(t, found2, "Second session should still exist")
assert.Len(t, client.sessions, 1, "Should have 1 session remaining")
}
// TestDbClient_ConcurrentSessionRegistration tests concurrent session additions
func TestDbClient_ConcurrentSessionRegistration(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrent test in short mode")
}
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
var wg sync.WaitGroup
numGoroutines := 100
// Concurrently add sessions
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id uint32) {
defer wg.Done()
backendPid := id
session := db_common.NewDBSession(backendPid)
client.sessionsMutex.Lock()
client.sessions[backendPid] = session
client.sessionsMutex.Unlock()
}(uint32(i))
}
wg.Wait()
// Verify all sessions were added
assert.Len(t, client.sessions, numGoroutines, "All sessions should be registered")
}
// TestDbClient_SessionMapGrowthUnbounded tests for potential memory leaks
// This verifies that sessions don't accumulate indefinitely
func TestDbClient_SessionMapGrowthUnbounded(t *testing.T) {
if testing.Short() {
t.Skip("Skipping large dataset test in short mode")
}
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
// Simulate many connections
numSessions := 10000
for i := 0; i < numSessions; i++ {
backendPid := uint32(i)
session := db_common.NewDBSession(backendPid)
client.sessionsMutex.Lock()
client.sessions[backendPid] = session
client.sessionsMutex.Unlock()
}
assert.Len(t, client.sessions, numSessions, "Should have all sessions")
// Simulate cleanup (BeforeClose callbacks)
for i := 0; i < numSessions; i++ {
backendPid := uint32(i)
client.sessionsMutex.Lock()
delete(client.sessions, backendPid)
client.sessionsMutex.Unlock()
}
// Verify all sessions are cleaned up
assert.Len(t, client.sessions, 0, "All sessions should be cleaned up")
}
// TestDbClient_SearchPathUpdates verifies session search path management
func TestDbClient_SearchPathUpdates(t *testing.T) {
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
customSearchPath: []string{"schema1", "schema2"},
}
// Add a session
backendPid := uint32(12345)
session := db_common.NewDBSession(backendPid)
client.sessionsMutex.Lock()
client.sessions[backendPid] = session
client.sessionsMutex.Unlock()
// Verify custom search path is set
assert.NotNil(t, client.customSearchPath, "Custom search path should be set")
assert.Len(t, client.customSearchPath, 2, "Should have 2 schemas in search path")
}
// TestDbClient_SessionConnectionNilSafety verifies handling of nil connections
func TestDbClient_SessionConnectionNilSafety(t *testing.T) {
session := db_common.NewDBSession(12345)
// Session is created with nil connection initially
assert.Nil(t, session.Connection, "New session should have nil connection initially")
}

View File

@@ -1,12 +1,15 @@
package db_client
import (
"context"
"os"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/turbot/steampipe/v2/pkg/db/db_common"
)
// TestSessionMapCleanupImplemented verifies that the session map memory leak is fixed
@@ -48,3 +51,220 @@ func TestSessionMapCleanupImplemented(t *testing.T) {
assert.True(t, hasCleanupComment,
"Comment should document automatic cleanup mechanism")
}
// TestDbClient_Close_Idempotent verifies that calling Close() multiple times does not cause issues
// Reference: Similar to bug #4712 (Result.Close() idempotency)
//
// Close() should be safe to call multiple times without panicking or causing errors.
func TestDbClient_Close_Idempotent(t *testing.T) {
ctx := context.Background()
// Create a minimal client (without real connection)
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
// First close
err := client.Close(ctx)
assert.NoError(t, err, "First Close() should not return error")
// Second close - should not panic
err = client.Close(ctx)
assert.NoError(t, err, "Second Close() should not return error")
// Third close - should still not panic
err = client.Close(ctx)
assert.NoError(t, err, "Third Close() should not return error")
// Verify sessions map is nil after close
assert.Nil(t, client.sessions, "Sessions map should be nil after Close()")
}
// TestDbClient_ConcurrentSessionAccess tests concurrent access to the sessions map
// This test should be run with -race flag to detect data races.
//
// The sessions map is protected by sessionsMutex, but we want to verify
// that all access paths properly use the mutex.
func TestDbClient_ConcurrentSessionAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrent access test in short mode")
}
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
var wg sync.WaitGroup
numGoroutines := 50
numOperations := 100
// Track errors in a thread-safe way
errors := make(chan error, numGoroutines*numOperations)
// Simulate concurrent session additions
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id uint32) {
defer wg.Done()
for j := 0; j < numOperations; j++ {
// Add session
client.sessionsMutex.Lock()
backendPid := id*1000 + uint32(j)
client.sessions[backendPid] = db_common.NewDBSession(backendPid)
client.sessionsMutex.Unlock()
// Read session
client.sessionsMutex.Lock()
_ = client.sessions[backendPid]
client.sessionsMutex.Unlock()
// Delete session (simulating BeforeClose callback)
client.sessionsMutex.Lock()
delete(client.sessions, backendPid)
client.sessionsMutex.Unlock()
}
}(uint32(i))
}
wg.Wait()
close(errors)
// Check for errors
for err := range errors {
t.Error(err)
}
}
// TestDbClient_Close_ClearsSessionsMap verifies that Close() properly clears the sessions map
func TestDbClient_Close_ClearsSessionsMap(t *testing.T) {
ctx := context.Background()
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
// Add some sessions
client.sessions[1] = db_common.NewDBSession(1)
client.sessions[2] = db_common.NewDBSession(2)
client.sessions[3] = db_common.NewDBSession(3)
assert.Len(t, client.sessions, 3, "Should have 3 sessions before Close()")
// Close the client
err := client.Close(ctx)
assert.NoError(t, err)
// Sessions should be nil after close
assert.Nil(t, client.sessions, "Sessions map should be nil after Close()")
}
// TestDbClient_SessionsMutexProtectsMap verifies that sessionsMutex protects all map operations
func TestDbClient_SessionsMutexProtectsMap(t *testing.T) {
// This is a structural test to verify the sessions map is never accessed without the mutex
content, err := os.ReadFile("db_client_session.go")
require.NoError(t, err, "should be able to read db_client_session.go")
sourceCode := string(content)
// Count occurrences of mutex locks
mutexLocks := strings.Count(sourceCode, "c.sessionsMutex.Lock()")
// This is a heuristic check - in practice, we'd need more sophisticated analysis
// But it serves as a reminder to use the mutex
assert.True(t, mutexLocks > 0,
"sessionsMutex.Lock() should be used when accessing sessions map")
}
// TestDbClient_SessionMapDocumentation verifies that session lifecycle is documented
func TestDbClient_SessionMapDocumentation(t *testing.T) {
content, err := os.ReadFile("db_client.go")
require.NoError(t, err)
sourceCode := string(content)
// Verify documentation mentions the lifecycle
assert.Contains(t, sourceCode, "Session lifecycle:",
"Sessions map should have lifecycle documentation")
assert.Contains(t, sourceCode, "issue #3737",
"Should reference the memory leak issue")
}
// TestDbClient_ClosePools_NilPoolsHandling verifies closePools handles nil pools
func TestDbClient_ClosePools_NilPoolsHandling(t *testing.T) {
client := &DbClient{
sessions: make(map[uint32]*db_common.DatabaseSession),
sessionsMutex: &sync.Mutex{},
}
// Should not panic with nil pools
assert.NotPanics(t, func() {
client.closePools()
}, "closePools should handle nil pools gracefully")
}
// TestDbClient_SessionsMapInitialized verifies sessions map is initialized in NewDbClient
func TestDbClient_SessionsMapInitialized(t *testing.T) {
// Verify the initialization happens in NewDbClient
content, err := os.ReadFile("db_client.go")
require.NoError(t, err)
sourceCode := string(content)
// Verify sessions map is initialized
assert.Contains(t, sourceCode, "sessions: make(map[uint32]*db_common.DatabaseSession)",
"sessions map should be initialized in NewDbClient")
// Verify mutex is initialized
assert.Contains(t, sourceCode, "sessionsMutex: &sync.Mutex{}",
"sessionsMutex should be initialized in NewDbClient")
}
// TestDbClient_DeferredCleanupInNewDbClient verifies error cleanup in NewDbClient
func TestDbClient_DeferredCleanupInNewDbClient(t *testing.T) {
content, err := os.ReadFile("db_client.go")
require.NoError(t, err)
sourceCode := string(content)
// Verify there's a defer that handles cleanup on error
assert.Contains(t, sourceCode, "defer func() {",
"NewDbClient should have deferred cleanup")
assert.Contains(t, sourceCode, "client.Close(ctx)",
"Deferred cleanup should close the client on error")
}
// TestDbClient_ParallelSessionInitLock verifies parallelSessionInitLock initialization
func TestDbClient_ParallelSessionInitLock(t *testing.T) {
content, err := os.ReadFile("db_client.go")
require.NoError(t, err)
sourceCode := string(content)
// Verify parallelSessionInitLock is initialized
assert.Contains(t, sourceCode, "parallelSessionInitLock:",
"parallelSessionInitLock should be initialized")
// Should use semaphore
assert.Contains(t, sourceCode, "semaphore.NewWeighted",
"parallelSessionInitLock should use weighted semaphore")
}
// TestDbClient_BeforeCloseCallbackNilSafety tests the BeforeClose callback with nil connection
func TestDbClient_BeforeCloseCallbackNilSafety(t *testing.T) {
content, err := os.ReadFile("db_client_connect.go")
require.NoError(t, err)
sourceCode := string(content)
// Verify nil checks in BeforeClose callback
assert.Contains(t, sourceCode, "if conn != nil",
"BeforeClose should check if conn is nil")
assert.Contains(t, sourceCode, "conn.PgConn() != nil",
"BeforeClose should check if PgConn() is nil")
}

View File

@@ -1,6 +1,8 @@
package db_local
import "testing"
import (
"testing"
)
func TestIsValidDatabaseName(t *testing.T) {
tests := map[string]bool{
@@ -26,3 +28,4 @@ func TestIsValidDatabaseName_EmptyString(t *testing.T) {
t.Errorf("Expected false for empty string, got %v", result)
}
}

View File

@@ -0,0 +1,271 @@
package interactive
import (
"testing"
"github.com/c-bata/go-prompt"
)
// TestNewAutocompleteSuggestions tests the creation of autocomplete suggestions
func TestNewAutocompleteSuggestions(t *testing.T) {
s := newAutocompleteSuggestions()
if s == nil {
t.Fatal("newAutocompleteSuggestions returned nil")
}
if s.tablesBySchema == nil {
t.Error("tablesBySchema map is nil")
}
if s.queriesByMod == nil {
t.Error("queriesByMod map is nil")
}
// Note: slices are not initialized (nil is valid for slices in Go)
// We just verify the struct itself is created
}
// TestAutocompleteSuggestionsSort tests the sorting of suggestions
func TestAutocompleteSuggestionsSort(t *testing.T) {
s := newAutocompleteSuggestions()
// Add unsorted suggestions
s.schemas = []prompt.Suggest{
{Text: "zebra", Description: "Schema"},
{Text: "apple", Description: "Schema"},
{Text: "mango", Description: "Schema"},
}
s.unqualifiedTables = []prompt.Suggest{
{Text: "users", Description: "Table"},
{Text: "accounts", Description: "Table"},
{Text: "posts", Description: "Table"},
}
s.tablesBySchema["test"] = []prompt.Suggest{
{Text: "z_table", Description: "Table"},
{Text: "a_table", Description: "Table"},
}
// Sort
s.sort()
// Verify schemas are sorted
if len(s.schemas) > 1 {
for i := 1; i < len(s.schemas); i++ {
if s.schemas[i-1].Text > s.schemas[i].Text {
t.Errorf("schemas not sorted: %s > %s", s.schemas[i-1].Text, s.schemas[i].Text)
}
}
}
// Verify tables are sorted
if len(s.unqualifiedTables) > 1 {
for i := 1; i < len(s.unqualifiedTables); i++ {
if s.unqualifiedTables[i-1].Text > s.unqualifiedTables[i].Text {
t.Errorf("unqualifiedTables not sorted: %s > %s", s.unqualifiedTables[i-1].Text, s.unqualifiedTables[i].Text)
}
}
}
// Verify tablesBySchema are sorted
tables := s.tablesBySchema["test"]
if len(tables) > 1 {
for i := 1; i < len(tables); i++ {
if tables[i-1].Text > tables[i].Text {
t.Errorf("tablesBySchema not sorted: %s > %s", tables[i-1].Text, tables[i].Text)
}
}
}
}
// TestAutocompleteSuggestionsEmptySort tests sorting with empty suggestions
func TestAutocompleteSuggestionsEmptySort(t *testing.T) {
s := newAutocompleteSuggestions()
// Should not panic with empty suggestions
defer func() {
if r := recover(); r != nil {
t.Errorf("sort() panicked with empty suggestions: %v", r)
}
}()
s.sort()
}
// TestAutocompleteSuggestionsSortWithDuplicates tests sorting with duplicate entries
func TestAutocompleteSuggestionsSortWithDuplicates(t *testing.T) {
s := newAutocompleteSuggestions()
// Add duplicate suggestions
s.schemas = []prompt.Suggest{
{Text: "apple", Description: "Schema"},
{Text: "apple", Description: "Schema"},
{Text: "banana", Description: "Schema"},
}
// Should not panic with duplicates
defer func() {
if r := recover(); r != nil {
t.Errorf("sort() panicked with duplicates: %v", r)
}
}()
s.sort()
// Verify duplicates are preserved (not removed)
if len(s.schemas) != 3 {
t.Errorf("sort() removed duplicates, got %d entries, want 3", len(s.schemas))
}
}
// TestAutocompleteSuggestionsWithUnicode tests suggestions with unicode characters
func TestAutocompleteSuggestionsWithUnicode(t *testing.T) {
s := newAutocompleteSuggestions()
s.schemas = []prompt.Suggest{
{Text: "用户", Description: "Schema"},
{Text: "数据库", Description: "Schema"},
{Text: "🔥", Description: "Schema"},
}
defer func() {
if r := recover(); r != nil {
t.Errorf("sort() panicked with unicode: %v", r)
}
}()
s.sort()
// Just verify it doesn't crash
if len(s.schemas) != 3 {
t.Errorf("sort() lost unicode entries, got %d entries, want 3", len(s.schemas))
}
}
// TestAutocompleteSuggestionsLargeDataset tests with a large number of suggestions
func TestAutocompleteSuggestionsLargeDataset(t *testing.T) {
if testing.Short() {
t.Skip("Skipping large dataset test in short mode")
}
s := newAutocompleteSuggestions()
// Add 10,000 schemas
for i := 0; i < 10000; i++ {
s.schemas = append(s.schemas, prompt.Suggest{
Text: "schema_" + string(rune(i)),
Description: "Schema",
})
}
// Add 10,000 tables
for i := 0; i < 10000; i++ {
s.unqualifiedTables = append(s.unqualifiedTables, prompt.Suggest{
Text: "table_" + string(rune(i)),
Description: "Table",
})
}
// Should not hang or crash
defer func() {
if r := recover(); r != nil {
t.Errorf("sort() panicked with large dataset: %v", r)
}
}()
s.sort()
}
// TestAutocompleteSuggestionsMemoryUsage tests memory usage with many suggestions
func TestAutocompleteSuggestionsMemoryUsage(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory usage test in short mode")
}
// Create 100 suggestion sets
suggestions := make([]*autoCompleteSuggestions, 100)
for i := 0; i < 100; i++ {
s := newAutocompleteSuggestions()
// Add many suggestions
for j := 0; j < 1000; j++ {
s.schemas = append(s.schemas, prompt.Suggest{
Text: "schema",
Description: "Schema",
})
}
suggestions[i] = s
}
// If we get here without OOM, the test passes
// Clear suggestions to allow GC
suggestions = nil
}
// TestAutocompleteSuggestionsEdgeCases tests various edge cases
func TestAutocompleteSuggestionsEdgeCases(t *testing.T) {
tests := []struct {
name string
test func(*testing.T)
}{
{
name: "empty text suggestion",
test: func(t *testing.T) {
s := newAutocompleteSuggestions()
s.schemas = []prompt.Suggest{
{Text: "", Description: "Empty"},
}
s.sort() // Should not panic
},
},
{
name: "very long text suggestion",
test: func(t *testing.T) {
s := newAutocompleteSuggestions()
longText := make([]byte, 10000)
for i := range longText {
longText[i] = 'a'
}
s.schemas = []prompt.Suggest{
{Text: string(longText), Description: "Long"},
}
s.sort() // Should not panic
},
},
{
name: "null bytes in text",
test: func(t *testing.T) {
s := newAutocompleteSuggestions()
s.schemas = []prompt.Suggest{
{Text: "schema\x00name", Description: "Null"},
}
s.sort() // Should not panic
},
},
{
name: "special characters in text",
test: func(t *testing.T) {
s := newAutocompleteSuggestions()
s.schemas = []prompt.Suggest{
{Text: "schema!@#$%^&*()", Description: "Special"},
}
s.sort() // Should not panic
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("Test panicked: %v", r)
}
}()
tt.test(t)
})
}
}

View File

@@ -0,0 +1,399 @@
package interactive
import (
"context"
"testing"
"time"
)
// TestCreatePromptContext tests prompt context creation
func TestCreatePromptContext(t *testing.T) {
c := &InteractiveClient{}
parentCtx := context.Background()
ctx := c.createPromptContext(parentCtx)
if ctx == nil {
t.Fatal("createPromptContext returned nil context")
}
if c.cancelPrompt == nil {
t.Fatal("createPromptContext didn't set cancelPrompt")
}
// Verify context can be cancelled
c.cancelPrompt()
select {
case <-ctx.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("Context was not cancelled after calling cancelPrompt")
}
}
// TestCreatePromptContextReplacesOld tests that creating a new context cancels the old one
func TestCreatePromptContextReplacesOld(t *testing.T) {
c := &InteractiveClient{}
parentCtx := context.Background()
// Create first context
ctx1 := c.createPromptContext(parentCtx)
cancel1 := c.cancelPrompt
// Create second context (should cancel first)
ctx2 := c.createPromptContext(parentCtx)
// First context should be cancelled
select {
case <-ctx1.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("First context was not cancelled when creating second context")
}
// Second context should still be active
select {
case <-ctx2.Done():
t.Error("Second context should not be cancelled yet")
case <-time.After(10 * time.Millisecond):
// Expected
}
// First cancel function should be different from second
if &cancel1 == &c.cancelPrompt {
t.Error("cancelPrompt was not replaced")
}
}
// TestCreateQueryContext tests query context creation
func TestCreateQueryContext(t *testing.T) {
c := &InteractiveClient{}
parentCtx := context.Background()
ctx := c.createQueryContext(parentCtx)
if ctx == nil {
t.Fatal("createQueryContext returned nil context")
}
if c.cancelActiveQuery == nil {
t.Fatal("createQueryContext didn't set cancelActiveQuery")
}
// Verify context can be cancelled
c.cancelActiveQuery()
select {
case <-ctx.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("Context was not cancelled after calling cancelActiveQuery")
}
}
// TestCreateQueryContextDoesNotCancelOld tests that creating a new query context doesn't cancel the old one
func TestCreateQueryContextDoesNotCancelOld(t *testing.T) {
c := &InteractiveClient{}
parentCtx := context.Background()
// Create first context
ctx1 := c.createQueryContext(parentCtx)
cancel1 := c.cancelActiveQuery
// Create second context (should NOT cancel first, just replace the reference)
ctx2 := c.createQueryContext(parentCtx)
// First context should still be active (not automatically cancelled)
select {
case <-ctx1.Done():
t.Error("First context was cancelled when creating second context (should not auto-cancel)")
case <-time.After(10 * time.Millisecond):
// Expected - first context is NOT cancelled
}
// Cancel using the first cancel function
cancel1()
// Now first context should be cancelled
select {
case <-ctx1.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("First context was not cancelled after calling its cancel function")
}
// Second context should still be active
select {
case <-ctx2.Done():
t.Error("Second context should not be cancelled yet")
case <-time.After(10 * time.Millisecond):
// Expected
}
}
// TestCancelActiveQueryIfAnyIdempotent tests that cancellation is idempotent
func TestCancelActiveQueryIfAnyIdempotent(t *testing.T) {
callCount := 0
cancelFunc := func() {
callCount++
}
c := &InteractiveClient{
cancelActiveQuery: cancelFunc,
}
// Call multiple times
for i := 0; i < 5; i++ {
c.cancelActiveQueryIfAny()
}
// Should only be called once
if callCount != 1 {
t.Errorf("cancelActiveQueryIfAny() called cancel function %d times, want 1 (should be idempotent)", callCount)
}
// Should be nil after first call
if c.cancelActiveQuery != nil {
t.Error("cancelActiveQueryIfAny() didn't set cancelActiveQuery to nil")
}
}
// TestCancelActiveQueryIfAnyNil tests behavior with nil cancel function
func TestCancelActiveQueryIfAnyNil(t *testing.T) {
c := &InteractiveClient{
cancelActiveQuery: nil,
}
defer func() {
if r := recover(); r != nil {
t.Errorf("cancelActiveQueryIfAny() panicked with nil cancel function: %v", r)
}
}()
// Should not panic
c.cancelActiveQueryIfAny()
// Should remain nil
if c.cancelActiveQuery != nil {
t.Error("cancelActiveQueryIfAny() set cancelActiveQuery when it was nil")
}
}
// TestClosePrompt tests the ClosePrompt method
func TestClosePrompt(t *testing.T) {
tests := []struct {
name string
afterClose AfterPromptCloseAction
}{
{
name: "close with exit",
afterClose: AfterPromptCloseExit,
},
{
name: "close with restart",
afterClose: AfterPromptCloseRestart,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cancelled := false
c := &InteractiveClient{
cancelPrompt: func() {
cancelled = true
},
}
c.ClosePrompt(tt.afterClose)
if !cancelled {
t.Error("ClosePrompt didn't call cancelPrompt")
}
if c.afterClose != tt.afterClose {
t.Errorf("ClosePrompt set afterClose to %v, want %v", c.afterClose, tt.afterClose)
}
})
}
}
// TestContextCancellationPropagation tests that parent context cancellation propagates
func TestContextCancellationPropagation(t *testing.T) {
c := &InteractiveClient{}
parentCtx, parentCancel := context.WithCancel(context.Background())
// Create child context
childCtx := c.createPromptContext(parentCtx)
// Cancel parent
parentCancel()
// Child should be cancelled too
select {
case <-childCtx.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("Child context was not cancelled when parent was cancelled")
}
}
// TestContextCancellationTimeout tests context with timeout
func TestContextCancellationTimeout(t *testing.T) {
c := &InteractiveClient{}
parentCtx, parentCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer parentCancel()
// Create child context
childCtx := c.createPromptContext(parentCtx)
// Wait for timeout
select {
case <-childCtx.Done():
// Expected after ~50ms
if childCtx.Err() != context.DeadlineExceeded && childCtx.Err() != context.Canceled {
t.Errorf("Expected DeadlineExceeded or Canceled error, got %v", childCtx.Err())
}
case <-time.After(200 * time.Millisecond):
t.Error("Context did not timeout as expected")
}
}
// TestRapidContextCreation tests rapid context creation and cancellation
func TestRapidContextCreation(t *testing.T) {
c := &InteractiveClient{}
parentCtx := context.Background()
// Rapidly create and cancel contexts
for i := 0; i < 1000; i++ {
ctx := c.createPromptContext(parentCtx)
// Immediately cancel
if c.cancelPrompt != nil {
c.cancelPrompt()
}
// Verify cancellation
select {
case <-ctx.Done():
// Expected
case <-time.After(10 * time.Millisecond):
t.Errorf("Context %d was not cancelled", i)
return
}
}
}
// TestCancelAfterContextAlreadyCancelled tests cancelling after context is already cancelled
func TestCancelAfterContextAlreadyCancelled(t *testing.T) {
c := &InteractiveClient{}
parentCtx, parentCancel := context.WithCancel(context.Background())
// Create child context
ctx := c.createQueryContext(parentCtx)
// Cancel parent first
parentCancel()
// Wait for child to be cancelled
<-ctx.Done()
// Now try to cancel via cancelActiveQueryIfAny
// Should not panic even though context is already cancelled
defer func() {
if r := recover(); r != nil {
t.Errorf("cancelActiveQueryIfAny panicked when context already cancelled: %v", r)
}
}()
c.cancelActiveQueryIfAny()
}
// TestQueryContextLeakage tests for context leakage
func TestQueryContextLeakage(t *testing.T) {
if testing.Short() {
t.Skip("Skipping leak test in short mode")
}
c := &InteractiveClient{}
parentCtx := context.Background()
// Create many query contexts
for i := 0; i < 10000; i++ {
ctx := c.createQueryContext(parentCtx)
// Cancel immediately
if c.cancelActiveQuery != nil {
c.cancelActiveQuery()
}
// Verify context is cancelled
select {
case <-ctx.Done():
// Good
case <-time.After(1 * time.Millisecond):
t.Errorf("Context %d not cancelled", i)
return
}
}
// If we get here without hanging or OOM, test passes
}
// TestCancelFuncReplacement tests that cancel functions are properly replaced
func TestCancelFuncReplacement(t *testing.T) {
c := &InteractiveClient{}
parentCtx := context.Background()
// Track which cancel function was called
firstCalled := false
secondCalled := false
// Create first query context
ctx1 := c.createQueryContext(parentCtx)
firstCancel := c.cancelActiveQuery
// Wrap the first cancel to track calls
c.cancelActiveQuery = func() {
firstCalled = true
firstCancel()
}
// Create second query context (replaces cancelActiveQuery)
ctx2 := c.createQueryContext(parentCtx)
secondCancel := c.cancelActiveQuery
// Wrap the second cancel to track calls
c.cancelActiveQuery = func() {
secondCalled = true
secondCancel()
}
// Call cancelActiveQueryIfAny
c.cancelActiveQueryIfAny()
// Only the second cancel should be called
if firstCalled {
t.Error("First cancel function was called (should have been replaced)")
}
if !secondCalled {
t.Error("Second cancel function was not called")
}
// Second context should be cancelled
select {
case <-ctx2.Done():
// Expected
case <-time.After(100 * time.Millisecond):
t.Error("Second context was not cancelled")
}
// First context is NOT automatically cancelled (different from prompt context)
select {
case <-ctx1.Done():
// This might happen if parent was cancelled, but shouldn't happen from our cancel
case <-time.After(10 * time.Millisecond):
// Expected - first context remains active
}
}

View File

@@ -0,0 +1,239 @@
package interactive
import (
"strings"
"testing"
"github.com/alecthomas/chroma/formatters"
"github.com/alecthomas/chroma/lexers"
"github.com/alecthomas/chroma/styles"
"github.com/c-bata/go-prompt"
)
// TestNewHighlighter tests highlighter creation
func TestNewHighlighter(t *testing.T) {
lexer := lexers.Get("sql")
formatter := formatters.Get("terminal256")
style := styles.Native
h := newHighlighter(lexer, formatter, style)
if h == nil {
t.Fatal("newHighlighter returned nil")
}
if h.lexer == nil {
t.Error("highlighter lexer is nil")
}
if h.formatter == nil {
t.Error("highlighter formatter is nil")
}
if h.style == nil {
t.Error("highlighter style is nil")
}
}
// TestHighlighterHighlight tests the Highlight function
func TestHighlighterHighlight(t *testing.T) {
h := newHighlighter(
lexers.Get("sql"),
formatters.Get("terminal256"),
styles.Native,
)
tests := []struct {
name string
input string
wantErr bool
}{
{
name: "simple select",
input: "SELECT * FROM users",
wantErr: false,
},
{
name: "empty string",
input: "",
wantErr: false,
},
{
name: "multiline query",
input: "SELECT *\nFROM users\nWHERE id = 1",
wantErr: false,
},
{
name: "unicode characters",
input: "SELECT '你好世界'",
wantErr: false,
},
{
name: "emoji",
input: "SELECT '🔥💥✨'",
wantErr: false,
},
{
name: "null bytes",
input: "SELECT '\x00'",
wantErr: false,
},
{
name: "control characters",
input: "SELECT '\n\r\t'",
wantErr: false,
},
{
name: "very long query",
input: "SELECT " + strings.Repeat("a, ", 1000) + "* FROM users",
wantErr: false,
},
{
name: "SQL injection attempt",
input: "'; DROP TABLE users; --",
wantErr: false,
},
{
name: "malformed SQL",
input: "SELECT FROM WHERE",
wantErr: false,
},
{
name: "special characters",
input: "SELECT '\\', '/', '\"', '`'",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
doc := prompt.Document{
Text: tt.input,
}
result, err := h.Highlight(doc)
if (err != nil) != tt.wantErr {
t.Errorf("Highlight() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && result == nil {
t.Error("Highlight() returned nil result without error")
}
// Verify result is not empty for non-empty input
if !tt.wantErr && tt.input != "" && len(result) == 0 {
t.Error("Highlight() returned empty result for non-empty input")
}
})
}
}
// TestGetHighlighter tests the getHighlighter function
func TestGetHighlighter(t *testing.T) {
tests := []struct {
name string
theme string
}{
{
name: "default theme",
theme: "",
},
{
name: "dark theme",
theme: "dark",
},
{
name: "light theme",
theme: "light",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
h := getHighlighter(tt.theme)
if h == nil {
t.Fatal("getHighlighter returned nil")
}
if h.lexer == nil {
t.Error("highlighter lexer is nil")
}
if h.formatter == nil {
t.Error("highlighter formatter is nil")
}
})
}
}
// TestHighlighterConcurrency tests concurrent highlighting
func TestHighlighterConcurrency(t *testing.T) {
h := newHighlighter(
lexers.Get("sql"),
formatters.Get("terminal256"),
styles.Native,
)
queries := []string{
"SELECT * FROM users",
"SELECT id FROM posts",
"SELECT name FROM companies",
}
done := make(chan bool)
for i := 0; i < 10; i++ {
go func(idx int) {
defer func() {
if r := recover(); r != nil {
t.Errorf("Concurrent Highlight panicked: %v", r)
}
done <- true
}()
doc := prompt.Document{
Text: queries[idx%len(queries)],
}
_, err := h.Highlight(doc)
if err != nil {
t.Errorf("Concurrent Highlight error: %v", err)
}
}(i)
}
// Wait for all goroutines
for i := 0; i < 10; i++ {
<-done
}
}
// TestHighlighterMemoryLeak tests for memory leaks with repeated highlighting
func TestHighlighterMemoryLeak(t *testing.T) {
if testing.Short() {
t.Skip("Skipping memory leak test in short mode")
}
h := newHighlighter(
lexers.Get("sql"),
formatters.Get("terminal256"),
styles.Native,
)
// Highlight the same query many times to check for memory leaks
doc := prompt.Document{
Text: "SELECT * FROM users WHERE id = 1",
}
for i := 0; i < 10000; i++ {
_, err := h.Highlight(doc)
if err != nil {
t.Fatalf("Highlight failed at iteration %d: %v", i, err)
}
}
// If we get here without OOM, the test passes
}

View File

@@ -1,9 +1,12 @@
package interactive
import (
"strings"
"testing"
"github.com/c-bata/go-prompt"
pconstants "github.com/turbot/pipe-fittings/v2/constants"
"github.com/turbot/steampipe/v2/pkg/cmdconfig"
)
// TestGetTableAndConnectionSuggestions_ReturnsEmptySliceNotNil tests that
@@ -68,3 +71,464 @@ func TestGetTableAndConnectionSuggestions_ReturnsEmptySliceNotNil(t *testing.T)
})
}
}
// TestShouldExecute tests the shouldExecute logic for query execution
func TestShouldExecute(t *testing.T) {
// Save and restore viper settings
originalMultiline := cmdconfig.Viper().GetBool(pconstants.ArgMultiLine)
defer func() {
cmdconfig.Viper().Set(pconstants.ArgMultiLine, originalMultiline)
}()
tests := []struct {
name string
query string
multiline bool
shouldExec bool
description string
}{
{
name: "simple query without semicolon in non-multiline",
query: "SELECT * FROM users",
multiline: false,
shouldExec: true,
description: "In non-multiline mode, execute without semicolon",
},
{
name: "simple query with semicolon in non-multiline",
query: "SELECT * FROM users;",
multiline: false,
shouldExec: true,
description: "In non-multiline mode, execute with semicolon",
},
{
name: "simple query without semicolon in multiline",
query: "SELECT * FROM users",
multiline: true,
shouldExec: false,
description: "In multiline mode, don't execute without semicolon",
},
{
name: "simple query with semicolon in multiline",
query: "SELECT * FROM users;",
multiline: true,
shouldExec: true,
description: "In multiline mode, execute with semicolon",
},
{
name: "metaquery without semicolon in multiline",
query: ".help",
multiline: true,
shouldExec: true,
description: "Metaqueries execute without semicolon even in multiline",
},
{
name: "metaquery with semicolon in multiline",
query: ".help;",
multiline: true,
shouldExec: true,
description: "Metaqueries execute with semicolon in multiline",
},
{
name: "empty query",
query: "",
multiline: false,
shouldExec: true,
description: "Empty query executes in non-multiline",
},
{
name: "empty query in multiline",
query: "",
multiline: true,
shouldExec: false,
description: "Empty query doesn't execute in multiline",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &InteractiveClient{}
cmdconfig.Viper().Set(pconstants.ArgMultiLine, tt.multiline)
result := c.shouldExecute(tt.query)
if result != tt.shouldExec {
t.Errorf("shouldExecute(%q) in multiline=%v = %v, want %v\nReason: %s",
tt.query, tt.multiline, result, tt.shouldExec, tt.description)
}
})
}
}
// TestShouldExecuteEdgeCases tests edge cases for shouldExecute
func TestShouldExecuteEdgeCases(t *testing.T) {
originalMultiline := cmdconfig.Viper().GetBool(pconstants.ArgMultiLine)
defer func() {
cmdconfig.Viper().Set(pconstants.ArgMultiLine, originalMultiline)
}()
c := &InteractiveClient{}
cmdconfig.Viper().Set(pconstants.ArgMultiLine, true)
tests := []struct {
name string
query string
}{
{
name: "very long query with semicolon",
query: strings.Repeat("SELECT * FROM users WHERE id = 1 AND ", 100) + "1=1;",
},
{
name: "unicode characters with semicolon",
query: "SELECT '你好世界';",
},
{
name: "emoji with semicolon",
query: "SELECT '🔥💥';",
},
{
name: "null bytes",
query: "SELECT '\x00';",
},
{
name: "control characters",
query: "SELECT '\n\r\t';",
},
{
name: "SQL injection with semicolon",
query: "'; DROP TABLE users; --",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("shouldExecute(%q) panicked: %v", tt.query, r)
}
}()
_ = c.shouldExecute(tt.query)
})
}
}
// TestBreakMultilinePrompt tests the breakMultilinePrompt function
func TestBreakMultilinePrompt(t *testing.T) {
c := &InteractiveClient{
interactiveBuffer: []string{"SELECT *", "FROM users", "WHERE"},
}
c.breakMultilinePrompt(nil)
if len(c.interactiveBuffer) != 0 {
t.Errorf("breakMultilinePrompt() didn't clear buffer, got %d items, want 0", len(c.interactiveBuffer))
}
}
// TestBreakMultilinePromptEmpty tests breaking an already empty buffer
func TestBreakMultilinePromptEmpty(t *testing.T) {
c := &InteractiveClient{
interactiveBuffer: []string{},
}
defer func() {
if r := recover(); r != nil {
t.Errorf("breakMultilinePrompt() panicked on empty buffer: %v", r)
}
}()
c.breakMultilinePrompt(nil)
if len(c.interactiveBuffer) != 0 {
t.Errorf("breakMultilinePrompt() didn't maintain empty buffer, got %d items, want 0", len(c.interactiveBuffer))
}
}
// TestBreakMultilinePromptNil tests breaking with nil buffer
func TestBreakMultilinePromptNil(t *testing.T) {
c := &InteractiveClient{
interactiveBuffer: nil,
}
defer func() {
if r := recover(); r != nil {
t.Errorf("breakMultilinePrompt() panicked on nil buffer: %v", r)
}
}()
c.breakMultilinePrompt(nil)
if c.interactiveBuffer == nil {
t.Error("breakMultilinePrompt() didn't initialize nil buffer")
}
if len(c.interactiveBuffer) != 0 {
t.Errorf("breakMultilinePrompt() didn't create empty buffer, got %d items, want 0", len(c.interactiveBuffer))
}
}
// TestIsInitialised tests the isInitialised method
func TestIsInitialised(t *testing.T) {
tests := []struct {
name string
initialisationComplete bool
expected bool
}{
{
name: "initialized",
initialisationComplete: true,
expected: true,
},
{
name: "not initialized",
initialisationComplete: false,
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &InteractiveClient{
initialisationComplete: tt.initialisationComplete,
}
result := c.isInitialised()
if result != tt.expected {
t.Errorf("isInitialised() = %v, want %v", result, tt.expected)
}
})
}
}
// TestClientNil tests the client() method when initData is nil
func TestClientNil(t *testing.T) {
c := &InteractiveClient{
initData: nil,
}
client := c.client()
if client != nil {
t.Errorf("client() with nil initData should return nil, got %v", client)
}
}
// TestAfterPromptCloseAction tests the AfterPromptCloseAction enum
func TestAfterPromptCloseAction(t *testing.T) {
// Test that the enum values are distinct
if AfterPromptCloseExit == AfterPromptCloseRestart {
t.Error("AfterPromptCloseExit and AfterPromptCloseRestart should have different values")
}
// Test that they have the expected values
if AfterPromptCloseExit != 0 {
t.Errorf("AfterPromptCloseExit should be 0, got %d", AfterPromptCloseExit)
}
if AfterPromptCloseRestart != 1 {
t.Errorf("AfterPromptCloseRestart should be 1, got %d", AfterPromptCloseRestart)
}
}
// TestGetFirstWordSuggestionsEmptyWord tests getFirstWordSuggestions with empty input
func TestGetFirstWordSuggestionsEmptyWord(t *testing.T) {
c := &InteractiveClient{
suggestions: newAutocompleteSuggestions(),
}
defer func() {
if r := recover(); r != nil {
t.Errorf("getFirstWordSuggestions panicked on empty input: %v", r)
}
}()
suggestions := c.getFirstWordSuggestions("")
// Should return suggestions (select, with, metaqueries)
if len(suggestions) == 0 {
t.Error("getFirstWordSuggestions(\"\") should return suggestions")
}
}
// TestGetFirstWordSuggestionsQualifiedQuery tests qualified query suggestions
func TestGetFirstWordSuggestionsQualifiedQuery(t *testing.T) {
c := &InteractiveClient{
suggestions: newAutocompleteSuggestions(),
}
// Add mock data
c.suggestions.queriesByMod = map[string][]prompt.Suggest{
"mymod": {
{Text: "mymod.query1", Description: "Query"},
},
}
tests := []struct {
name string
input string
}{
{
name: "qualified with known mod",
input: "mymod.",
},
{
name: "qualified with unknown mod",
input: "unknownmod.",
},
{
name: "single word",
input: "select",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("getFirstWordSuggestions(%q) panicked: %v", tt.input, r)
}
}()
suggestions := c.getFirstWordSuggestions(tt.input)
if suggestions == nil {
t.Errorf("getFirstWordSuggestions(%q) returned nil", tt.input)
}
})
}
}
// TestGetTableAndConnectionSuggestionsEdgeCases tests edge cases
func TestGetTableAndConnectionSuggestionsEdgeCases(t *testing.T) {
c := &InteractiveClient{
suggestions: newAutocompleteSuggestions(),
}
// Add mock data
c.suggestions.schemas = []prompt.Suggest{
{Text: "public", Description: "Schema"},
}
c.suggestions.unqualifiedTables = []prompt.Suggest{
{Text: "users", Description: "Table"},
}
c.suggestions.tablesBySchema = map[string][]prompt.Suggest{
"public": {
{Text: "public.users", Description: "Table"},
},
}
tests := []struct {
name string
input string
}{
{
name: "unqualified",
input: "users",
},
{
name: "qualified with known schema",
input: "public.users",
},
{
name: "empty string",
input: "",
},
{
name: "just dot",
input: ".",
},
{
name: "unicode",
input: "用户.表",
},
{
name: "emoji",
input: "schema🔥.table",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("getTableAndConnectionSuggestions(%q) panicked: %v", tt.input, r)
}
}()
suggestions := c.getTableAndConnectionSuggestions(tt.input)
if suggestions == nil {
t.Errorf("getTableAndConnectionSuggestions(%q) returned nil", tt.input)
}
})
}
}
// TestCancelActiveQueryIfAny tests the cancellation logic
func TestCancelActiveQueryIfAny(t *testing.T) {
t.Run("no active query", func(t *testing.T) {
c := &InteractiveClient{
cancelActiveQuery: nil,
}
defer func() {
if r := recover(); r != nil {
t.Errorf("cancelActiveQueryIfAny() panicked with nil cancelFunc: %v", r)
}
}()
c.cancelActiveQueryIfAny()
if c.cancelActiveQuery != nil {
t.Error("cancelActiveQueryIfAny() set cancelActiveQuery when it was nil")
}
})
t.Run("with active query", func(t *testing.T) {
cancelled := false
cancelFunc := func() {
cancelled = true
}
c := &InteractiveClient{
cancelActiveQuery: cancelFunc,
}
c.cancelActiveQueryIfAny()
if !cancelled {
t.Error("cancelActiveQueryIfAny() didn't call the cancel function")
}
if c.cancelActiveQuery != nil {
t.Error("cancelActiveQueryIfAny() didn't set cancelActiveQuery to nil")
}
})
t.Run("multiple calls", func(t *testing.T) {
callCount := 0
cancelFunc := func() {
callCount++
}
c := &InteractiveClient{
cancelActiveQuery: cancelFunc,
}
// First call should cancel
c.cancelActiveQueryIfAny()
if callCount != 1 {
t.Errorf("First cancelActiveQueryIfAny() call count = %d, want 1", callCount)
}
// Second call should be a no-op
c.cancelActiveQueryIfAny()
if callCount != 1 {
t.Errorf("Second cancelActiveQueryIfAny() call count = %d, want 1 (should be idempotent)", callCount)
}
})
}

View File

@@ -0,0 +1,588 @@
package interactive
import (
"strings"
"testing"
)
// TestIsFirstWord tests the isFirstWord helper function
func TestIsFirstWord(t *testing.T) {
tests := []struct {
name string
input string
expected bool
}{
{
name: "single word",
input: "select",
expected: true,
},
{
name: "two words",
input: "select *",
expected: false,
},
{
name: "empty string",
input: "",
expected: true,
},
{
name: "word with trailing space",
input: "select ",
expected: false,
},
{
name: "multiple spaces",
input: "select from",
expected: false,
},
{
name: "unicode characters",
input: "選択",
expected: true,
},
{
name: "emoji",
input: "🔥",
expected: true,
},
{
name: "emoji with space",
input: "🔥 test",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isFirstWord(tt.input)
if result != tt.expected {
t.Errorf("isFirstWord(%q) = %v, want %v", tt.input, result, tt.expected)
}
})
}
}
// TestLastWord tests the lastWord helper function (only passing cases)
func TestLastWord(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "two words",
input: "select *",
expected: " *",
},
{
name: "multiple words",
input: "select * from users",
expected: " users",
},
{
name: "trailing space",
input: "select * from ",
expected: " ",
},
{
name: "unicode",
input: "select 你好",
expected: " 你好",
},
{
name: "emoji",
input: "select 🔥",
expected: " 🔥",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
defer func() {
if r := recover(); r != nil {
t.Errorf("lastWord(%q) panicked: %v", tt.input, r)
}
}()
result := lastWord(tt.input)
if result != tt.expected {
t.Errorf("lastWord(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// TestLastIndexByteNot tests the lastIndexByteNot helper function
func TestLastIndexByteNot(t *testing.T) {
tests := []struct {
name string
input string
char byte
expected int
}{
{
name: "no matching char",
input: "hello",
char: ' ',
expected: 4,
},
{
name: "trailing spaces",
input: "hello ",
char: ' ',
expected: 4,
},
{
name: "all spaces",
input: " ",
char: ' ',
expected: -1,
},
{
name: "empty string",
input: "",
char: ' ',
expected: -1,
},
{
name: "single char not matching",
input: "a",
char: ' ',
expected: 0,
},
{
name: "single char matching",
input: " ",
char: ' ',
expected: -1,
},
{
name: "mixed spaces",
input: "hello world ",
char: ' ',
expected: 10,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := lastIndexByteNot(tt.input, tt.char)
if result != tt.expected {
t.Errorf("lastIndexByteNot(%q, %q) = %d, want %d", tt.input, tt.char, result, tt.expected)
}
})
}
}
// TestGetPreviousWord tests the getPreviousWord helper function
func TestGetPreviousWord(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple case",
input: "select * from ",
expected: "from",
},
{
name: "no previous word",
input: "select ",
expected: "",
},
{
name: "single word",
input: "select",
expected: "",
},
{
name: "multiple spaces",
input: "select * from ",
expected: "from",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "only spaces",
input: " ",
expected: "",
},
{
name: "unicode characters",
input: "select 你好 世界 ",
expected: "世界",
},
{
name: "emoji",
input: "select 🔥 💥 ",
expected: "💥",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getPreviousWord(tt.input)
if result != tt.expected {
t.Errorf("getPreviousWord(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// TestGetTable tests the getTable helper function
func TestGetTable(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple select",
input: "select * from users",
expected: "users",
},
{
name: "qualified table",
input: "select * from public.users",
expected: "public.users",
},
{
name: "no from clause",
input: "select 1",
expected: "",
},
{
name: "from at end",
input: "select * from",
expected: "",
},
{
name: "from with trailing text",
input: "select * from users where",
expected: "users",
},
{
name: "double spaces",
input: "select * from users",
expected: "users",
},
{
name: "empty string",
input: "",
expected: "",
},
{
name: "case sensitive - lowercase from",
input: "SELECT * from users",
expected: "users",
},
{
name: "uppercase FROM",
input: "SELECT * FROM users",
expected: "",
},
{
name: "unicode table name",
input: "select * from 用户表",
expected: "用户表",
},
{
name: "emoji in table name",
input: "select * from users🔥",
expected: "users🔥",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getTable(tt.input)
if result != tt.expected {
t.Errorf("getTable(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// TestIsEditingTable tests the isEditingTable helper function
func TestIsEditingTable(t *testing.T) {
tests := []struct {
name string
prevWord string
expected bool
}{
{
name: "from keyword",
prevWord: "from",
expected: true,
},
{
name: "not from keyword",
prevWord: "select",
expected: false,
},
{
name: "empty string",
prevWord: "",
expected: false,
},
{
name: "FROM uppercase",
prevWord: "FROM",
expected: false,
},
{
name: "whitespace",
prevWord: " from ",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isEditingTable(tt.prevWord)
if result != tt.expected {
t.Errorf("isEditingTable(%q) = %v, want %v", tt.prevWord, result, tt.expected)
}
})
}
}
// TestGetQueryInfo tests the getQueryInfo function (passing cases only)
func TestGetQueryInfo(t *testing.T) {
tests := []struct {
name string
input string
expectedTable string
expectedEditing bool
}{
{
name: "editing table after from",
input: "select * from ",
expectedTable: "",
expectedEditing: true,
},
{
name: "table specified",
input: "select * from users ",
expectedTable: "users",
expectedEditing: false,
},
{
name: "not at from clause",
input: "select * ",
expectedTable: "",
expectedEditing: false,
},
{
name: "empty query",
input: "",
expectedTable: "",
expectedEditing: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := getQueryInfo(tt.input)
if result.Table != tt.expectedTable {
t.Errorf("getQueryInfo(%q).Table = %q, want %q", tt.input, result.Table, tt.expectedTable)
}
if result.EditingTable != tt.expectedEditing {
t.Errorf("getQueryInfo(%q).EditingTable = %v, want %v", tt.input, result.EditingTable, tt.expectedEditing)
}
})
}
}
// TestCleanBufferForWSL tests the WSL-specific buffer cleaning
func TestCleanBufferForWSL(t *testing.T) {
tests := []struct {
name string
input string
expectedOutput string
expectedIgnore bool
}{
{
name: "normal text",
input: "hello",
expectedOutput: "hello",
expectedIgnore: false,
},
{
name: "empty string",
input: "",
expectedOutput: "",
expectedIgnore: false,
},
{
name: "escape sequence",
input: string([]byte{27, 65}), // ESC + 'A'
expectedOutput: "",
expectedIgnore: true,
},
{
name: "single escape",
input: string([]byte{27}),
expectedOutput: string([]byte{27}),
expectedIgnore: false,
},
{
name: "unicode",
input: "你好",
expectedOutput: "你好",
expectedIgnore: false,
},
{
name: "emoji",
input: "🔥",
expectedOutput: "🔥",
expectedIgnore: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output, ignore := cleanBufferForWSL(tt.input)
if output != tt.expectedOutput {
t.Errorf("cleanBufferForWSL(%q) output = %q, want %q", tt.input, output, tt.expectedOutput)
}
if ignore != tt.expectedIgnore {
t.Errorf("cleanBufferForWSL(%q) ignore = %v, want %v", tt.input, ignore, tt.expectedIgnore)
}
})
}
}
// TestSanitiseTableName tests table name escaping (passing cases only)
func TestSanitiseTableName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "simple lowercase table",
input: "users",
expected: "users",
},
{
name: "uppercase table",
input: "Users",
expected: `"Users"`,
},
{
name: "table with space",
input: "user data",
expected: `"user data"`,
},
{
name: "table with hyphen",
input: "user-data",
expected: `"user-data"`,
},
{
name: "qualified table",
input: "schema.table",
expected: "schema.table",
},
{
name: "qualified with uppercase",
input: "Schema.Table",
expected: `"Schema"."Table"`,
},
{
name: "qualified with spaces",
input: "my schema.my table",
expected: `"my schema"."my table"`,
},
{
name: "empty string",
input: "",
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitiseTableName(tt.input)
if result != tt.expected {
t.Errorf("sanitiseTableName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// TestHelperFunctionsWithExtremeInput tests helper functions with extreme inputs
func TestHelperFunctionsWithExtremeInput(t *testing.T) {
t.Run("very long string", func(t *testing.T) {
longString := strings.Repeat("a ", 10000)
// Test that these don't panic or hang
defer func() {
if r := recover(); r != nil {
t.Errorf("Function panicked on long string: %v", r)
}
}()
_ = isFirstWord(longString)
_ = getTable(longString)
_ = getPreviousWord(longString)
_ = getQueryInfo(longString)
})
t.Run("null bytes", func(t *testing.T) {
nullByteString := "select\x00from\x00users"
defer func() {
if r := recover(); r != nil {
t.Errorf("Function panicked on null bytes: %v", r)
}
}()
_ = isFirstWord(nullByteString)
_ = getTable(nullByteString)
_ = getPreviousWord(nullByteString)
})
t.Run("control characters", func(t *testing.T) {
controlString := "select\n\r\tfrom\n\rusers"
defer func() {
if r := recover(); r != nil {
t.Errorf("Function panicked on control chars: %v", r)
}
}()
_ = isFirstWord(controlString)
_ = getTable(controlString)
_ = getPreviousWord(controlString)
})
t.Run("SQL injection attempts", func(t *testing.T) {
injectionStrings := []string{
"'; DROP TABLE users; --",
"1' OR '1'='1",
"1; DELETE FROM connections; --",
"select * from users where id = 1' union select * from passwords --",
}
for _, injection := range injectionStrings {
defer func() {
if r := recover(); r != nil {
t.Errorf("Function panicked on injection string %q: %v", injection, r)
}
}()
_ = isFirstWord(injection)
_ = getTable(injection)
_ = getPreviousWord(injection)
_ = getQueryInfo(injection)
}
})
}

View File

@@ -0,0 +1,308 @@
package pluginmanager_service
import (
"context"
"runtime"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
sdkproto "github.com/turbot/steampipe-plugin-sdk/v5/grpc/proto"
)
// Test helpers for message server tests
func newTestMessageServer(t *testing.T) *PluginMessageServer {
t.Helper()
pm := newTestPluginManager(t)
return &PluginMessageServer{
pluginManager: pm,
}
}
// Test 1: NewPluginMessageServer
func TestNewPluginMessageServer(t *testing.T) {
pm := newTestPluginManager(t)
ms, err := NewPluginMessageServer(pm)
require.NoError(t, err)
assert.NotNil(t, ms)
assert.Equal(t, pm, ms.pluginManager)
}
// Test 2: PluginMessageServer Initialization
func TestPluginManager_MessageServerInitialization(t *testing.T) {
pm := newTestPluginManager(t)
assert.NotNil(t, pm.messageServer, "messageServer should be initialized")
assert.Equal(t, pm, pm.messageServer.pluginManager, "messageServer should reference parent PluginManager")
}
// Test 3: Concurrent Access
func TestPluginMessageServer_ConcurrentAccess(t *testing.T) {
ms := newTestMessageServer(t)
var wg sync.WaitGroup
numGoroutines := 50
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = ms.pluginManager
}()
}
wg.Wait()
}
// Test 4: LogReceiveError with Valid Errors
func TestPluginMessageServer_LogReceiveError(t *testing.T) {
ms := newTestMessageServer(t)
// Should not panic for various error types
ms.logReceiveError(context.Canceled, "test-connection")
ms.logReceiveError(context.DeadlineExceeded, "test-connection")
}
// Test 5: Multiple Message Servers
func TestPluginManager_MultipleMessageServers(t *testing.T) {
pm := newTestPluginManager(t)
ms1, err1 := NewPluginMessageServer(pm)
ms2, err2 := NewPluginMessageServer(pm)
require.NoError(t, err1)
require.NoError(t, err2)
assert.NotNil(t, ms1)
assert.NotNil(t, ms2)
// Both should reference the same plugin manager
assert.Equal(t, pm, ms1.pluginManager)
assert.Equal(t, pm, ms2.pluginManager)
}
// Test 6: Message Server with Nil Plugin Manager
func TestPluginMessageServer_NilPluginManager(t *testing.T) {
ms := &PluginMessageServer{
pluginManager: nil,
}
assert.Nil(t, ms.pluginManager)
}
// Test 7: Goroutine Cleanup
func TestPluginMessageServer_GoroutineCleanup(t *testing.T) {
before := runtime.NumGoroutine()
ms := newTestMessageServer(t)
_ = ms
time.Sleep(100 * time.Millisecond)
after := runtime.NumGoroutine()
// Creating a message server shouldn't leak goroutines
if after > before+5 {
t.Errorf("Potential goroutine leak: before=%d, after=%d", before, after)
}
}
// Test 8: Message Type Structure
func TestPluginMessage_SchemaUpdatedType(t *testing.T) {
message := &sdkproto.PluginMessage{
MessageType: sdkproto.PluginMessageType_SCHEMA_UPDATED,
Connection: "test-connection",
}
assert.Equal(t, sdkproto.PluginMessageType_SCHEMA_UPDATED, message.MessageType)
assert.Equal(t, "test-connection", message.Connection)
}
// Test 9: LogReceiveError with Different Error Types
func TestPluginMessageServer_LogReceiveError_ErrorTypes(t *testing.T) {
ms := newTestMessageServer(t)
// Test various error types don't cause panics
errors := []error{
context.Canceled,
context.DeadlineExceeded,
assert.AnError,
}
for _, err := range errors {
ms.logReceiveError(err, "test-connection")
}
}
// Test 10: Message Server Initialization Consistency
func TestPluginManager_MessageServer_Consistency(t *testing.T) {
pm := newTestPluginManager(t)
// Verify messageServer is initialized and consistent
assert.NotNil(t, pm.messageServer)
assert.Equal(t, pm, pm.messageServer.pluginManager)
// Accessing it multiple times should return the same instance
ms1 := pm.messageServer
ms2 := pm.messageServer
assert.Equal(t, ms1, ms2)
}
// Test 11: Message Server Survives Plugin Manager Operations
func TestPluginMessageServer_SurvivesPluginManagerOperations(t *testing.T) {
pm := newTestPluginManager(t)
ms := pm.messageServer
// Perform various plugin manager operations
pm.populatePluginConnectionConfigs()
pm.setPluginCacheSizeMap()
pm.nonAggregatorConnectionCount()
// Message server should still be accessible
assert.Equal(t, pm, ms.pluginManager)
assert.NotNil(t, pm.messageServer)
}
// Test 12: Concurrent NewPluginMessageServer Calls
func TestNewPluginMessageServer_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
var wg sync.WaitGroup
numGoroutines := 50
servers := make([]*PluginMessageServer, numGoroutines)
errors := make([]error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
servers[idx], errors[idx] = NewPluginMessageServer(pm)
}(i)
}
wg.Wait()
// All should succeed
for i := 0; i < numGoroutines; i++ {
assert.NoError(t, errors[i])
assert.NotNil(t, servers[i])
assert.Equal(t, pm, servers[i].pluginManager)
}
}
// Test 13: Message Server Pointer Stability
func TestPluginMessageServer_PointerStability(t *testing.T) {
pm := newTestPluginManager(t)
ms1 := pm.messageServer
ms2 := pm.messageServer
// Should be the same pointer
assert.True(t, ms1 == ms2, "messageServer pointer should be stable")
}
// Test 14: LogReceiveError Concurrent Calls
func TestPluginMessageServer_LogReceiveError_Concurrent(t *testing.T) {
ms := newTestMessageServer(t)
var wg sync.WaitGroup
numGoroutines := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
err := assert.AnError
if idx%2 == 0 {
err = context.Canceled
}
ms.logReceiveError(err, "test-connection")
}(i)
}
wg.Wait()
}
// Test 15: Message Server Field Access
func TestPluginMessageServer_FieldAccess(t *testing.T) {
ms := newTestMessageServer(t)
// Verify fields are accessible and not nil
assert.NotNil(t, ms.pluginManager)
assert.NotNil(t, ms.pluginManager.logger)
assert.NotNil(t, ms.pluginManager.runningPluginMap)
}
// Test 16: Message Server Doesn't Block Plugin Manager
func TestPluginMessageServer_DoesNotBlockPluginManager(t *testing.T) {
pm := newTestPluginManager(t)
// Message server should not prevent these operations
config := newTestConnectionConfig("plugin1", "instance1", "conn1")
pm.connectionConfigMap["conn1"] = config
pm.populatePluginConnectionConfigs()
// Verify operations worked
assert.Len(t, pm.pluginConnectionConfigMap, 1)
// Message server should still be valid
assert.NotNil(t, pm.messageServer)
assert.Equal(t, pm, pm.messageServer.pluginManager)
}
// Test 17: Stress Test for Concurrent Access
func TestPluginMessageServer_StressConcurrentAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping stress test in short mode")
}
pm := newTestPluginManager(t)
ms := pm.messageServer
var wg sync.WaitGroup
duration := 1 * time.Second
stopCh := make(chan struct{})
// Multiple readers accessing pluginManager
for i := 0; i < 20; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
_ = ms.pluginManager
if ms.pluginManager != nil {
_ = ms.pluginManager.connectionConfigMap
}
}
}
}()
}
time.Sleep(duration)
close(stopCh)
wg.Wait()
}

View File

@@ -0,0 +1,716 @@
package pluginmanager_service
import (
"fmt"
"runtime"
"sync"
"testing"
"time"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/turbot/pipe-fittings/v2/plugin"
sdkproto "github.com/turbot/steampipe-plugin-sdk/v5/grpc/proto"
"github.com/turbot/steampipe/v2/pkg/connection"
pb "github.com/turbot/steampipe/v2/pkg/pluginmanager_service/grpc/proto"
)
// Test helpers and mocks
func newTestPluginManager(t *testing.T) *PluginManager {
t.Helper()
logger := hclog.NewNullLogger()
pm := &PluginManager{
logger: logger,
runningPluginMap: make(map[string]*runningPlugin),
pluginConnectionConfigMap: make(map[string][]*sdkproto.ConnectionConfig),
connectionConfigMap: make(connection.ConnectionConfigMap),
pluginCacheSizeMap: make(map[string]int64),
plugins: make(connection.PluginMap),
userLimiters: make(connection.PluginLimiterMap),
pluginLimiters: make(connection.PluginLimiterMap),
}
pm.messageServer = &PluginMessageServer{pluginManager: pm}
return pm
}
func newTestConnectionConfig(plugin, instance, connection string) *sdkproto.ConnectionConfig {
return &sdkproto.ConnectionConfig{
Plugin: plugin,
PluginInstance: instance,
Connection: connection,
}
}
// Test 1: Basic Initialization
func TestPluginManager_New(t *testing.T) {
pm := newTestPluginManager(t)
assert.NotNil(t, pm, "PluginManager should not be nil")
assert.NotNil(t, pm.runningPluginMap, "runningPluginMap should be initialized")
assert.NotNil(t, pm.messageServer, "messageServer should be initialized")
assert.NotNil(t, pm.logger, "logger should be initialized")
}
// Test 2: Connection Config Access
func TestPluginManager_GetConnectionConfig_NotFound(t *testing.T) {
pm := newTestPluginManager(t)
_, err := pm.getConnectionConfig("nonexistent")
assert.Error(t, err, "Should return error for nonexistent connection")
assert.Contains(t, err.Error(), "does not exist", "Error should mention connection doesn't exist")
}
func TestPluginManager_GetConnectionConfig_Found(t *testing.T) {
pm := newTestPluginManager(t)
expectedConfig := newTestConnectionConfig("test-plugin", "test-instance", "test-connection")
pm.connectionConfigMap["test-connection"] = expectedConfig
config, err := pm.getConnectionConfig("test-connection")
require.NoError(t, err)
assert.Equal(t, expectedConfig, config)
}
func TestPluginManager_GetConnectionConfig_NilMap(t *testing.T) {
pm := newTestPluginManager(t)
pm.connectionConfigMap = nil
_, err := pm.getConnectionConfig("conn1")
assert.Error(t, err, "Should handle nil connectionConfigMap gracefully")
}
// Test 3: Map Population
func TestPluginManager_PopulatePluginConnectionConfigs(t *testing.T) {
pm := newTestPluginManager(t)
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
config2 := newTestConnectionConfig("plugin1", "instance1", "conn2")
config3 := newTestConnectionConfig("plugin2", "instance2", "conn3")
pm.connectionConfigMap = connection.ConnectionConfigMap{
"conn1": config1,
"conn2": config2,
"conn3": config3,
}
pm.populatePluginConnectionConfigs()
assert.Len(t, pm.pluginConnectionConfigMap, 2, "Should have 2 plugin instances")
assert.Len(t, pm.pluginConnectionConfigMap["instance1"], 2, "instance1 should have 2 connections")
assert.Len(t, pm.pluginConnectionConfigMap["instance2"], 1, "instance2 should have 1 connection")
}
// Test 4: Build Required Plugin Map
func TestPluginManager_BuildRequiredPluginMap(t *testing.T) {
pm := newTestPluginManager(t)
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
config2 := newTestConnectionConfig("plugin1", "instance1", "conn2")
config3 := newTestConnectionConfig("plugin2", "instance2", "conn3")
pm.connectionConfigMap = connection.ConnectionConfigMap{
"conn1": config1,
"conn2": config2,
"conn3": config3,
}
pm.populatePluginConnectionConfigs()
req := &pb.GetRequest{
Connections: []string{"conn1", "conn3"},
}
pluginMap, requestedConns, err := pm.buildRequiredPluginMap(req)
require.NoError(t, err)
assert.Len(t, pluginMap, 2, "Should map 2 plugin instances")
assert.Len(t, requestedConns, 2, "Should have 2 requested connections")
assert.Contains(t, requestedConns, "conn1")
assert.Contains(t, requestedConns, "conn3")
}
// Test 5: Concurrent Map Access
func TestPluginManager_ConcurrentMapAccess(t *testing.T) {
pm := newTestPluginManager(t)
// Populate some initial data
for i := 0; i < 10; i++ {
connName := fmt.Sprintf("conn%d", i)
config := newTestConnectionConfig("plugin1", "instance1", connName)
pm.connectionConfigMap[connName] = config
}
pm.populatePluginConnectionConfigs()
var wg sync.WaitGroup
numGoroutines := 50
// Concurrent reads with proper locking
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
connName := fmt.Sprintf("conn%d", idx%10)
pm.mut.RLock()
_ = pm.connectionConfigMap[connName]
pm.mut.RUnlock()
}(i)
}
wg.Wait()
assert.Len(t, pm.connectionConfigMap, 10)
}
// Test 6: Shutdown Flag Management
func TestPluginManager_Shutdown_SetsShuttingDownFlag(t *testing.T) {
pm := newTestPluginManager(t)
assert.False(t, pm.isShuttingDown(), "Initially should not be shutting down")
// Set the flag as Shutdown does
pm.shutdownMut.Lock()
pm.shuttingDown = true
pm.shutdownMut.Unlock()
assert.True(t, pm.isShuttingDown(), "Should be shutting down after flag is set")
}
func TestPluginManager_Shutdown_WaitsForPluginStart(t *testing.T) {
pm := newTestPluginManager(t)
// Simulate a plugin starting
pm.startPluginWg.Add(1)
shutdownComplete := make(chan struct{})
go func() {
pm.shutdownMut.Lock()
pm.shuttingDown = true
pm.shutdownMut.Unlock()
pm.startPluginWg.Wait()
close(shutdownComplete)
}()
// Give shutdown goroutine time to reach Wait
time.Sleep(50 * time.Millisecond)
// Verify shutdown hasn't completed yet
select {
case <-shutdownComplete:
t.Fatal("Shutdown completed before startPluginWg.Done() was called")
case <-time.After(10 * time.Millisecond):
// Expected
}
// Signal plugin start complete
pm.startPluginWg.Done()
// Verify shutdown completes
select {
case <-shutdownComplete:
// Expected
case <-time.After(100 * time.Millisecond):
t.Fatal("Shutdown did not complete after startPluginWg.Done()")
}
}
// Test 7: Running Plugin Management
func TestPluginManager_AddRunningPlugin_Success(t *testing.T) {
pm := newTestPluginManager(t)
// Add a plugin config
pm.plugins["test-instance"] = &plugin.Plugin{
Plugin: "test-plugin",
Instance: "test-instance",
}
rp, err := pm.addRunningPlugin("test-instance")
require.NoError(t, err)
assert.NotNil(t, rp)
assert.Equal(t, "test-instance", rp.pluginInstance)
assert.NotNil(t, rp.initialized)
assert.NotNil(t, rp.failed)
// Verify it was added to the map
pm.mut.RLock()
stored := pm.runningPluginMap["test-instance"]
pm.mut.RUnlock()
assert.Equal(t, rp, stored)
}
func TestPluginManager_AddRunningPlugin_AlreadyExists(t *testing.T) {
pm := newTestPluginManager(t)
// Add a plugin config
pm.plugins["test-instance"] = &plugin.Plugin{
Plugin: "test-plugin",
Instance: "test-instance",
}
// Add first time
_, err := pm.addRunningPlugin("test-instance")
require.NoError(t, err)
// Try to add again - should return retryable error
_, err = pm.addRunningPlugin("test-instance")
assert.Error(t, err)
assert.Contains(t, err.Error(), "already started")
}
func TestPluginManager_AddRunningPlugin_NoConfig(t *testing.T) {
pm := newTestPluginManager(t)
// Don't add any plugin config
_, err := pm.addRunningPlugin("nonexistent-instance")
assert.Error(t, err)
assert.Contains(t, err.Error(), "no config")
}
// Test 8: Concurrent Plugin Operations
func TestPluginManager_ConcurrentAddRunningPlugin(t *testing.T) {
pm := newTestPluginManager(t)
// Add plugin config
pm.plugins["test-instance"] = &plugin.Plugin{
Plugin: "test-plugin",
Instance: "test-instance",
}
var wg sync.WaitGroup
numGoroutines := 10
successCount := 0
errorCount := 0
var mu sync.Mutex
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := pm.addRunningPlugin("test-instance")
mu.Lock()
if err == nil {
successCount++
} else {
errorCount++
}
mu.Unlock()
}()
}
wg.Wait()
// Only one should succeed, the rest should get retryable errors
assert.Equal(t, 1, successCount, "Only one goroutine should succeed")
assert.Equal(t, numGoroutines-1, errorCount, "All other goroutines should fail")
}
// Test 9: IsShuttingDown with Concurrent Access
func TestPluginManager_IsShuttingDown_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
var wg sync.WaitGroup
numReaders := 50
// Start many readers
for i := 0; i < numReaders; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 100; j++ {
_ = pm.isShuttingDown()
}
}()
}
// One writer
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
pm.shutdownMut.Lock()
pm.shuttingDown = !pm.shuttingDown
pm.shutdownMut.Unlock()
time.Sleep(time.Millisecond)
}
}()
wg.Wait()
}
// Test 10: Plugin Cache Size Map
func TestPluginManager_SetPluginCacheSizeMap_NoCacheLimit(t *testing.T) {
pm := newTestPluginManager(t)
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
config2 := newTestConnectionConfig("plugin2", "instance2", "conn2")
pm.pluginConnectionConfigMap = map[string][]*sdkproto.ConnectionConfig{
"instance1": {config1},
"instance2": {config2},
}
pm.setPluginCacheSizeMap()
// When no max size is set, all plugins should have size 0 (unlimited)
assert.Equal(t, int64(0), pm.pluginCacheSizeMap["instance1"])
assert.Equal(t, int64(0), pm.pluginCacheSizeMap["instance2"])
}
// Test 11: NonAggregatorConnectionCount
func TestPluginManager_NonAggregatorConnectionCount(t *testing.T) {
pm := newTestPluginManager(t)
// Regular connection (no child connections)
config1 := &sdkproto.ConnectionConfig{
Plugin: "plugin1",
PluginInstance: "instance1",
Connection: "conn1",
ChildConnections: []string{},
}
// Aggregator connection (has child connections)
config2 := &sdkproto.ConnectionConfig{
Plugin: "plugin1",
PluginInstance: "instance1",
Connection: "conn2",
ChildConnections: []string{"child1", "child2"},
}
// Another regular connection
config3 := &sdkproto.ConnectionConfig{
Plugin: "plugin2",
PluginInstance: "instance2",
Connection: "conn3",
ChildConnections: []string{},
}
pm.pluginConnectionConfigMap = map[string][]*sdkproto.ConnectionConfig{
"instance1": {config1, config2},
"instance2": {config3},
}
count := pm.nonAggregatorConnectionCount()
// Should count only non-aggregator connections (conn1 and conn3)
assert.Equal(t, 2, count)
}
// Test 12: GetPluginExemplarConnections
func TestPluginManager_GetPluginExemplarConnections(t *testing.T) {
pm := newTestPluginManager(t)
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
config2 := newTestConnectionConfig("plugin1", "instance1", "conn2")
config3 := newTestConnectionConfig("plugin2", "instance2", "conn3")
pm.connectionConfigMap = connection.ConnectionConfigMap{
"conn1": config1,
"conn2": config2,
"conn3": config3,
}
exemplars := pm.getPluginExemplarConnections()
assert.Len(t, exemplars, 2, "Should have 2 plugins")
// Should have one exemplar for each plugin (might be any of the connections)
assert.Contains(t, []string{"conn1", "conn2"}, exemplars["plugin1"])
assert.Equal(t, "conn3", exemplars["plugin2"])
}
// Test 13: Goroutine Leak Detection
func TestPluginManager_NoGoroutineLeak_OnError(t *testing.T) {
before := runtime.NumGoroutine()
pm := newTestPluginManager(t)
// Add plugin config
pm.plugins["test-instance"] = &plugin.Plugin{
Plugin: "test-plugin",
Instance: "test-instance",
}
// Try to add running plugin
_, err := pm.addRunningPlugin("test-instance")
require.NoError(t, err)
// Clean up
pm.mut.Lock()
delete(pm.runningPluginMap, "test-instance")
pm.mut.Unlock()
time.Sleep(100 * time.Millisecond)
after := runtime.NumGoroutine()
// Allow some tolerance for background goroutines
if after > before+5 {
t.Errorf("Potential goroutine leak: before=%d, after=%d", before, after)
}
}
// Test 14: Pool Access
func TestPluginManager_Pool(t *testing.T) {
pm := newTestPluginManager(t)
// Initially nil
assert.Nil(t, pm.Pool())
}
// Test 15: RefreshConnections
func TestPluginManager_RefreshConnections(t *testing.T) {
pm := newTestPluginManager(t)
req := &pb.RefreshConnectionsRequest{}
resp, err := pm.RefreshConnections(req)
require.NoError(t, err, "RefreshConnections should not return error")
assert.NotNil(t, resp, "Response should not be nil")
}
// Test 16: GetConnectionConfig Concurrent Access
func TestPluginManager_GetConnectionConfig_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
config := newTestConnectionConfig("plugin1", "instance1", "conn1")
pm.connectionConfigMap["conn1"] = config
var wg sync.WaitGroup
numGoroutines := 50
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cfg, err := pm.getConnectionConfig("conn1")
if err == nil {
assert.Equal(t, "conn1", cfg.Connection)
}
}()
}
wg.Wait()
}
// Test 17: Running Plugin Structure
func TestRunningPlugin_Initialization(t *testing.T) {
rp := &runningPlugin{
pluginInstance: "test",
imageRef: "test-image",
initialized: make(chan struct{}),
failed: make(chan struct{}),
}
assert.NotNil(t, rp.initialized, "initialized channel should not be nil")
assert.NotNil(t, rp.failed, "failed channel should not be nil")
// Verify channels are not closed initially
select {
case <-rp.initialized:
t.Fatal("initialized channel should not be closed initially")
default:
// Expected
}
select {
case <-rp.failed:
t.Fatal("failed channel should not be closed initially")
default:
// Expected
}
}
// Test 18: Multiple Concurrent Refreshes
func TestPluginManager_ConcurrentRefreshConnections(t *testing.T) {
pm := newTestPluginManager(t)
var wg sync.WaitGroup
numGoroutines := 10
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
req := &pb.RefreshConnectionsRequest{}
_, _ = pm.RefreshConnections(req)
}()
}
wg.Wait()
}
// Test 19: NonAggregatorConnectionCount Helper
func TestNonAggregatorConnectionCount(t *testing.T) {
tests := []struct {
name string
connections []*sdkproto.ConnectionConfig
expected int
}{
{
name: "empty",
connections: []*sdkproto.ConnectionConfig{},
expected: 0,
},
{
name: "all non-aggregators",
connections: []*sdkproto.ConnectionConfig{
{Connection: "conn1", ChildConnections: []string{}},
{Connection: "conn2", ChildConnections: []string{}},
},
expected: 2,
},
{
name: "all aggregators",
connections: []*sdkproto.ConnectionConfig{
{Connection: "conn1", ChildConnections: []string{"child1"}},
{Connection: "conn2", ChildConnections: []string{"child2"}},
},
expected: 0,
},
{
name: "mixed",
connections: []*sdkproto.ConnectionConfig{
{Connection: "conn1", ChildConnections: []string{}},
{Connection: "conn2", ChildConnections: []string{"child1"}},
{Connection: "conn3", ChildConnections: []string{}},
},
expected: 2,
},
{
name: "nil child connections",
connections: []*sdkproto.ConnectionConfig{
{Connection: "conn1", ChildConnections: nil},
{Connection: "conn2", ChildConnections: []string{"child1"}},
},
expected: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
count := nonAggregatorConnectionCount(tt.connections)
assert.Equal(t, tt.expected, count)
})
}
}
// Test 20: GetResponse Helper
func TestNewGetResponse(t *testing.T) {
resp := newGetResponse()
assert.NotNil(t, resp)
assert.NotNil(t, resp.GetResponse)
assert.NotNil(t, resp.ReattachMap)
assert.NotNil(t, resp.FailureMap)
}
// Test 21: EnsurePlugin Early Exit When Shutting Down
func TestPluginManager_EnsurePlugin_ShuttingDown(t *testing.T) {
pm := newTestPluginManager(t)
// Set shutting down flag
pm.shutdownMut.Lock()
pm.shuttingDown = true
pm.shutdownMut.Unlock()
config := newTestConnectionConfig("plugin1", "instance1", "conn1")
req := &pb.GetRequest{Connections: []string{"conn1"}}
_, err := pm.ensurePlugin("instance1", []*sdkproto.ConnectionConfig{config}, req)
assert.Error(t, err)
assert.Contains(t, err.Error(), "shutting down")
}
// Test 22: KillPlugin with Nil Client
func TestPluginManager_KillPlugin_NilClient(t *testing.T) {
pm := newTestPluginManager(t)
rp := &runningPlugin{
pluginInstance: "test",
client: nil,
}
// Should not panic
pm.killPlugin(rp)
}
// Test 23: Stress Test for Map Access
func TestPluginManager_StressConcurrentMapAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping stress test in short mode")
}
pm := newTestPluginManager(t)
// Add initial configs
for i := 0; i < 100; i++ {
connName := fmt.Sprintf("conn%d", i)
config := newTestConnectionConfig("plugin1", "instance1", connName)
pm.connectionConfigMap[connName] = config
}
pm.populatePluginConnectionConfigs()
var wg sync.WaitGroup
duration := 1 * time.Second
stopCh := make(chan struct{})
// Start multiple readers
for i := 0; i < 20; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
for {
select {
case <-stopCh:
return
default:
connName := fmt.Sprintf("conn%d", idx%100)
pm.mut.RLock()
_ = pm.connectionConfigMap[connName]
_ = pm.pluginConnectionConfigMap["instance1"]
pm.mut.RUnlock()
}
}
}(i)
}
// Run for duration
time.Sleep(duration)
close(stopCh)
wg.Wait()
}

View File

@@ -0,0 +1,423 @@
package pluginmanager_service
import (
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/turbot/pipe-fittings/v2/plugin"
"github.com/turbot/steampipe/v2/pkg/connection"
)
// Test helpers for rate limiter tests
func newTestRateLimiter(pluginName, name string, source string) *plugin.RateLimiter {
return &plugin.RateLimiter{
Plugin: pluginName,
Name: name,
Source: source,
Status: plugin.LimiterStatusActive,
}
}
// Test 1: ShouldFetchRateLimiterDefs
func TestPluginManager_ShouldFetchRateLimiterDefs_Nil(t *testing.T) {
pm := newTestPluginManager(t)
pm.pluginLimiters = nil
should := pm.ShouldFetchRateLimiterDefs()
assert.True(t, should, "Should fetch when pluginLimiters is nil")
}
func TestPluginManager_ShouldFetchRateLimiterDefs_NotNil(t *testing.T) {
pm := newTestPluginManager(t)
pm.pluginLimiters = make(connection.PluginLimiterMap)
should := pm.ShouldFetchRateLimiterDefs()
assert.False(t, should, "Should not fetch when pluginLimiters is initialized")
}
// Test 2: GetPluginsWithChangedLimiters
func TestPluginManager_GetPluginsWithChangedLimiters_NoChanges(t *testing.T) {
pm := newTestPluginManager(t)
limiter1 := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig)
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": limiter1,
},
}
newLimiters := connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": limiter1,
},
}
changed := pm.getPluginsWithChangedLimiters(newLimiters)
assert.Len(t, changed, 0, "No plugins should have changed limiters")
}
func TestPluginManager_GetPluginsWithChangedLimiters_NewPlugin(t *testing.T) {
pm := newTestPluginManager(t)
pm.userLimiters = connection.PluginLimiterMap{}
newLimiters := connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
changed := pm.getPluginsWithChangedLimiters(newLimiters)
assert.Len(t, changed, 1, "Should detect new plugin")
assert.Contains(t, changed, "plugin1")
}
func TestPluginManager_GetPluginsWithChangedLimiters_RemovedPlugin(t *testing.T) {
pm := newTestPluginManager(t)
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
newLimiters := connection.PluginLimiterMap{}
changed := pm.getPluginsWithChangedLimiters(newLimiters)
assert.Len(t, changed, 1, "Should detect removed plugin")
assert.Contains(t, changed, "plugin1")
}
// Test 3: UpdateRateLimiterStatus
func TestPluginManager_UpdateRateLimiterStatus_NoOverride(t *testing.T) {
pm := newTestPluginManager(t)
pluginLimiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
pluginLimiter.Status = plugin.LimiterStatusActive
pm.pluginLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": pluginLimiter,
},
}
pm.userLimiters = connection.PluginLimiterMap{}
pm.updateRateLimiterStatus()
assert.Equal(t, plugin.LimiterStatusActive, pluginLimiter.Status)
}
func TestPluginManager_UpdateRateLimiterStatus_WithOverride(t *testing.T) {
pm := newTestPluginManager(t)
pluginLimiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
pluginLimiter.Status = plugin.LimiterStatusActive
userLimiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig)
pm.pluginLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": pluginLimiter,
},
}
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": userLimiter,
},
}
pm.updateRateLimiterStatus()
assert.Equal(t, plugin.LimiterStatusOverridden, pluginLimiter.Status)
}
func TestPluginManager_UpdateRateLimiterStatus_MultiplePlugins(t *testing.T) {
pm := newTestPluginManager(t)
plugin1Limiter1 := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
plugin1Limiter2 := newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourcePlugin)
plugin2Limiter1 := newTestRateLimiter("plugin2", "limiter1", plugin.LimiterSourcePlugin)
pm.pluginLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": plugin1Limiter1,
"limiter2": plugin1Limiter2,
},
"plugin2": connection.LimiterMap{
"limiter1": plugin2Limiter1,
},
}
// Only override plugin1/limiter1
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
pm.updateRateLimiterStatus()
assert.Equal(t, plugin.LimiterStatusOverridden, plugin1Limiter1.Status)
assert.Equal(t, plugin.LimiterStatusActive, plugin1Limiter2.Status)
assert.Equal(t, plugin.LimiterStatusActive, plugin2Limiter1.Status)
}
// Test 4: GetUserDefinedLimitersForPlugin
func TestPluginManager_GetUserDefinedLimitersForPlugin_Exists(t *testing.T) {
pm := newTestPluginManager(t)
limiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig)
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": limiter,
},
}
result := pm.getUserDefinedLimitersForPlugin("plugin1")
assert.Len(t, result, 1)
assert.Equal(t, limiter, result["limiter1"])
}
func TestPluginManager_GetUserDefinedLimitersForPlugin_NotExists(t *testing.T) {
pm := newTestPluginManager(t)
pm.userLimiters = connection.PluginLimiterMap{}
result := pm.getUserDefinedLimitersForPlugin("plugin1")
assert.NotNil(t, result, "Should return empty map, not nil")
assert.Len(t, result, 0)
}
// Test 5: GetUserAndPluginLimitersFromTableResult
func TestPluginManager_GetUserAndPluginLimitersFromTableResult(t *testing.T) {
pm := newTestPluginManager(t)
rateLimiters := []*plugin.RateLimiter{
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin),
newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig),
newTestRateLimiter("plugin2", "limiter1", plugin.LimiterSourcePlugin),
}
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
// Check plugin limiters
assert.Len(t, pluginLimiters, 2)
assert.NotNil(t, pluginLimiters["plugin1"]["limiter1"])
assert.NotNil(t, pluginLimiters["plugin2"]["limiter1"])
// Check user limiters
assert.Len(t, userLimiters, 1)
assert.NotNil(t, userLimiters["plugin1"]["limiter2"])
}
func TestPluginManager_GetUserAndPluginLimitersFromTableResult_Empty(t *testing.T) {
pm := newTestPluginManager(t)
rateLimiters := []*plugin.RateLimiter{}
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
assert.NotNil(t, pluginLimiters)
assert.NotNil(t, userLimiters)
assert.Len(t, pluginLimiters, 0)
assert.Len(t, userLimiters, 0)
}
// Test 6: GetPluginsWithChangedLimiters Concurrent
func TestPluginManager_GetPluginsWithChangedLimiters_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
var wg sync.WaitGroup
numGoroutines := 50
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
newLimiters := connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
if idx%2 == 0 {
// Add a new limiter
newLimiters["plugin1"]["limiter2"] = newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig)
}
_ = pm.getPluginsWithChangedLimiters(newLimiters)
}(i)
}
wg.Wait()
}
// Test 7: UpdateRateLimiterStatus with Multiple Limiters
func TestPluginManager_UpdateRateLimiterStatus_MultipleLimiters(t *testing.T) {
pm := newTestPluginManager(t)
limiter1 := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
limiter2 := newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourcePlugin)
limiter3 := newTestRateLimiter("plugin1", "limiter3", plugin.LimiterSourcePlugin)
pm.pluginLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": limiter1,
"limiter2": limiter2,
"limiter3": limiter3,
},
}
// Override only limiter2
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter2": newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig),
},
}
pm.updateRateLimiterStatus()
assert.Equal(t, plugin.LimiterStatusActive, limiter1.Status)
assert.Equal(t, plugin.LimiterStatusOverridden, limiter2.Status)
assert.Equal(t, plugin.LimiterStatusActive, limiter3.Status)
}
// Test 8: GetUserAndPluginLimitersFromTableResult with Duplicate Names
func TestPluginManager_GetUserAndPluginLimitersFromTableResult_DuplicateNames(t *testing.T) {
pm := newTestPluginManager(t)
// Same limiter name, different sources
rateLimiters := []*plugin.RateLimiter{
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin),
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
}
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
assert.NotNil(t, pluginLimiters["plugin1"]["limiter1"])
assert.NotNil(t, userLimiters["plugin1"]["limiter1"])
assert.NotEqual(t, pluginLimiters["plugin1"]["limiter1"], userLimiters["plugin1"]["limiter1"])
}
// Test 9: UpdateRateLimiterStatus with Empty Maps
func TestPluginManager_UpdateRateLimiterStatus_EmptyMaps(t *testing.T) {
pm := newTestPluginManager(t)
pm.pluginLimiters = connection.PluginLimiterMap{}
pm.userLimiters = connection.PluginLimiterMap{}
// Should not panic
pm.updateRateLimiterStatus()
}
// Test 10: GetPluginsWithChangedLimiters with Nil Comparison
func TestPluginManager_GetPluginsWithChangedLimiters_NilComparison(t *testing.T) {
pm := newTestPluginManager(t)
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": nil,
}
newLimiters := connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
changed := pm.getPluginsWithChangedLimiters(newLimiters)
assert.Contains(t, changed, "plugin1", "Should detect change from nil to non-nil")
}
// Test 11: ShouldFetchRateLimiterDefs Concurrent
func TestPluginManager_ShouldFetchRateLimiterDefs_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
pm.pluginLimiters = nil
var wg sync.WaitGroup
numGoroutines := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
_ = pm.ShouldFetchRateLimiterDefs()
}()
}
wg.Wait()
}
// Test 12: GetUserDefinedLimitersForPlugin Concurrent
func TestPluginManager_GetUserDefinedLimitersForPlugin_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
pm.userLimiters = connection.PluginLimiterMap{
"plugin1": connection.LimiterMap{
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
},
}
var wg sync.WaitGroup
numGoroutines := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
result := pm.getUserDefinedLimitersForPlugin("plugin1")
assert.NotNil(t, result)
}()
}
wg.Wait()
}
// Test 13: GetUserAndPluginLimitersFromTableResult Concurrent
func TestPluginManager_GetUserAndPluginLimitersFromTableResult_Concurrent(t *testing.T) {
pm := newTestPluginManager(t)
rateLimiters := []*plugin.RateLimiter{
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin),
newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig),
newTestRateLimiter("plugin2", "limiter1", plugin.LimiterSourcePlugin),
}
var wg sync.WaitGroup
numGoroutines := 50
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
assert.NotNil(t, pluginLimiters)
assert.NotNil(t, userLimiters)
}()
}
wg.Wait()
}

View File

@@ -0,0 +1,253 @@
package queryexecute
import (
"context"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/turbot/pipe-fittings/v2/modconfig"
pqueryresult "github.com/turbot/pipe-fittings/v2/queryresult"
"github.com/turbot/steampipe/v2/pkg/db/db_common"
"github.com/turbot/steampipe/v2/pkg/export"
"github.com/turbot/steampipe/v2/pkg/initialisation"
"github.com/turbot/steampipe/v2/pkg/query"
"github.com/turbot/steampipe/v2/pkg/query/queryresult"
)
// Test Helpers
// createMockInitData creates a mock InitData for testing
func createMockInitData(t *testing.T) *query.InitData {
t.Helper()
initData := &query.InitData{
InitData: initialisation.InitData{
Result: &db_common.InitResult{},
ExportManager: export.NewManager(),
Client: &mockClient{}, // Add mock client to prevent nil pointer panics
},
Loaded: make(chan struct{}),
StartTime: time.Now(),
Queries: []*modconfig.ResolvedQuery{},
}
return initData
}
// closeInitDataLoaded closes the Loaded channel to simulate initialization completion
func closeInitDataLoaded(initData *query.InitData) {
select {
case <-initData.Loaded:
// already closed
default:
close(initData.Loaded)
}
}
// Test Suite: RunBatchSession
func TestRunBatchSession_EmptyQueries(t *testing.T) {
// ARRANGE: Create initData with no queries
ctx := context.Background()
initData := createMockInitData(t)
initData.Queries = []*modconfig.ResolvedQuery{} // explicitly empty
// Simulate successful initialization
closeInitDataLoaded(initData)
// ACT: Run batch session
failures, err := RunBatchSession(ctx, initData)
// ASSERT: Should return 0 failures and no error
assert.NoError(t, err, "RunBatchSession should not error with empty queries")
assert.Equal(t, 0, failures, "Should return 0 failures when no queries to execute")
}
func TestRunBatchSession_InitError(t *testing.T) {
// ARRANGE: Create initData with an initialization error
ctx := context.Background()
initData := createMockInitData(t)
// Simulate initialization error
expectedErr := assert.AnError
initData.Result.Error = expectedErr
closeInitDataLoaded(initData)
// ACT: Run batch session
failures, err := RunBatchSession(ctx, initData)
// ASSERT: Should return the init error immediately
assert.Equal(t, expectedErr, err, "Should return initialization error")
assert.Equal(t, 0, failures, "Should return 0 failures when init fails")
}
// Test Suite: Helper Functions
func TestNeedSnapshot_DefaultValues(t *testing.T) {
// This test verifies the needSnapshot function behavior with default config
// Note: This is a simple test but ensures the function doesn't panic
// ACT: Call needSnapshot with default viper config
result := needSnapshot()
// ASSERT: Should return false with default settings
assert.False(t, result, "needSnapshot should return false with default settings")
}
func TestShowBlankLineBetweenResults_DefaultValues(t *testing.T) {
// This test verifies showBlankLineBetweenResults function with default config
// ACT: Call function with default viper config
result := showBlankLineBetweenResults()
// ASSERT: Should return true with default settings (not CSV without header)
assert.True(t, result, "Should show blank lines with default settings")
}
func TestHandlePublishSnapshotError_PaymentRequired(t *testing.T) {
// ARRANGE: Create a 402 Payment Required error
err := assert.AnError
err = &mockError{msg: "402 Payment Required"}
// ACT: Handle the error
result := handlePublishSnapshotError(err)
// ASSERT: Should reword the error message
assert.Error(t, result)
assert.Contains(t, result.Error(), "maximum number of snapshots reached")
}
func TestHandlePublishSnapshotError_OtherError(t *testing.T) {
// ARRANGE: Create a different error
err := assert.AnError
// ACT: Handle the error
result := handlePublishSnapshotError(err)
// ASSERT: Should return the error unchanged
assert.Equal(t, err, result)
}
// Test Suite: Edge Cases and Resource Management
func TestExecuteQueries_EmptyQueriesList(t *testing.T) {
// ARRANGE: InitData with empty queries list
ctx := context.Background()
initData := createMockInitData(t)
initData.Queries = []*modconfig.ResolvedQuery{}
// ACT: Execute queries directly
failures := executeQueries(ctx, initData)
// ASSERT: Should return 0 failures
assert.Equal(t, 0, failures, "Should return 0 failures for empty queries list")
}
// Test Suite: Context and Cancellation
func TestRunBatchSession_CancelHandlerSetup(t *testing.T) {
// This test verifies that the cancel handler doesn't cause panics
// We can't easily test the actual cancellation behavior without integration tests
// ARRANGE
ctx := context.Background()
initData := createMockInitData(t)
closeInitDataLoaded(initData)
// ACT: Run batch session
// Note: This test just verifies no panic occurs when setting up cancel handler
assert.NotPanics(t, func() {
_, _ = RunBatchSession(ctx, initData)
}, "Should not panic when setting up cancel handler")
}
// Test Suite: Result Wrapping
func TestWrapResult_NotNil(t *testing.T) {
// This test ensures WrapResult doesn't panic and returns a valid wrapper
// ARRANGE: Create a basic result from pipe-fittings
// Note: We need to use the pipe-fittings queryresult package
// This test verifies the wrapper functionality exists and doesn't panic
wrapped := queryresult.NewResult(nil)
// ASSERT: Should return a valid result
assert.NotNil(t, wrapped, "NewResult should not return nil")
}
// Mock Types
type mockError struct {
msg string
}
func (e *mockError) Error() string {
return e.msg
}
// mockClient is a minimal mock implementation of db_common.Client for testing
type mockClient struct {
customSearchPath []string
requiredSearchPath []string
}
func (m *mockClient) Close(ctx context.Context) error {
return nil
}
func (m *mockClient) LoadUserSearchPath(ctx context.Context) error {
return nil
}
func (m *mockClient) SetRequiredSessionSearchPath(ctx context.Context) error {
return nil
}
func (m *mockClient) GetRequiredSessionSearchPath() []string {
return m.requiredSearchPath
}
func (m *mockClient) GetCustomSearchPath() []string {
return m.customSearchPath
}
func (m *mockClient) AcquireManagementConnection(ctx context.Context) (*pgxpool.Conn, error) {
return nil, nil
}
func (m *mockClient) AcquireSession(ctx context.Context) *db_common.AcquireSessionResult {
return nil
}
func (m *mockClient) ExecuteSync(ctx context.Context, query string, args ...any) (*pqueryresult.SyncQueryResult, error) {
return nil, nil
}
func (m *mockClient) Execute(ctx context.Context, query string, args ...any) (*queryresult.Result, error) {
return nil, nil
}
func (m *mockClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*pqueryresult.SyncQueryResult, error) {
return nil, nil
}
func (m *mockClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onConnectionLost func(), query string, args ...any) (*queryresult.Result, error) {
return nil, nil
}
func (m *mockClient) ResetPools(ctx context.Context) {
}
func (m *mockClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) {
return nil, nil
}
func (m *mockClient) ServerSettings() *db_common.ServerSettings {
return nil
}
func (m *mockClient) RegisterNotificationListener(f func(notification *pgconn.Notification)) {
}

View File

@@ -0,0 +1,412 @@
package steampipeconfig
import (
"testing"
"time"
"github.com/turbot/steampipe/v2/pkg/constants"
)
func TestConnectionStateMapGetSummary(t *testing.T) {
stateMap := ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStateReady,
},
"conn3": &ConnectionState{
ConnectionName: "conn3",
State: constants.ConnectionStateError,
},
"conn4": &ConnectionState{
ConnectionName: "conn4",
State: constants.ConnectionStatePending,
},
}
summary := stateMap.GetSummary()
if summary[constants.ConnectionStateReady] != 2 {
t.Errorf("Expected 2 ready connections, got %d", summary[constants.ConnectionStateReady])
}
if summary[constants.ConnectionStateError] != 1 {
t.Errorf("Expected 1 error connection, got %d", summary[constants.ConnectionStateError])
}
if summary[constants.ConnectionStatePending] != 1 {
t.Errorf("Expected 1 pending connection, got %d", summary[constants.ConnectionStatePending])
}
}
func TestConnectionStateMapPending(t *testing.T) {
testCases := []struct {
name string
stateMap ConnectionStateMap
expected bool
}{
{
name: "has pending connections",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStatePending,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStateReady,
},
},
expected: true,
},
{
name: "has pending incomplete connections",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStatePendingIncomplete,
},
},
expected: true,
},
{
name: "no pending connections",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStateError,
},
},
expected: false,
},
{
name: "empty map",
stateMap: ConnectionStateMap{},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.stateMap.Pending()
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateMapLoaded(t *testing.T) {
testCases := []struct {
name string
stateMap ConnectionStateMap
connections []string
expected bool
}{
{
name: "all connections loaded",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStateError,
},
},
connections: []string{},
expected: true,
},
{
name: "some connections not loaded",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStatePending,
},
},
connections: []string{},
expected: false,
},
{
name: "specific connections loaded",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStatePending,
},
},
connections: []string{"conn1"},
expected: true,
},
{
name: "disabled connections are loaded",
stateMap: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateDisabled,
},
},
connections: []string{},
expected: true,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.stateMap.Loaded(testCase.connections...)
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateMapConnectionsInState(t *testing.T) {
stateMap := ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStateError,
},
"conn3": &ConnectionState{
ConnectionName: "conn3",
State: constants.ConnectionStatePending,
},
}
testCases := []struct {
name string
states []string
expected bool
}{
{
name: "has ready connections",
states: []string{constants.ConnectionStateReady},
expected: true,
},
{
name: "has error or pending connections",
states: []string{constants.ConnectionStateError, constants.ConnectionStatePending},
expected: true,
},
{
name: "no updating connections",
states: []string{constants.ConnectionStateUpdating},
expected: false,
},
{
name: "no deleting connections",
states: []string{constants.ConnectionStateDeleting},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := stateMap.ConnectionsInState(testCase.states...)
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateMapEquals(t *testing.T) {
testCases := []struct {
name string
map1 ConnectionStateMap
map2 ConnectionStateMap
expected bool
}{
{
name: "equal maps",
map1: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin1",
State: constants.ConnectionStateReady,
},
},
map2: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin1",
State: constants.ConnectionStateReady,
},
},
expected: true,
},
{
name: "different plugins",
map1: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin1",
State: constants.ConnectionStateReady,
},
},
map2: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin2",
State: constants.ConnectionStateReady,
},
},
expected: false,
},
{
name: "different keys",
map1: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin1",
State: constants.ConnectionStateReady,
},
},
map2: ConnectionStateMap{
"conn2": &ConnectionState{
ConnectionName: "conn2",
Plugin: "plugin1",
State: constants.ConnectionStateReady,
},
},
expected: false,
},
{
name: "nil vs non-nil",
map1: nil,
map2: ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin1",
State: constants.ConnectionStateReady,
},
},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.map1.Equals(testCase.map2)
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateMapConnectionModTime(t *testing.T) {
now := time.Now()
earlier := now.Add(-1 * time.Hour)
later := now.Add(1 * time.Hour)
stateMap := ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
ConnectionModTime: earlier,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
ConnectionModTime: later,
},
"conn3": &ConnectionState{
ConnectionName: "conn3",
ConnectionModTime: now,
},
}
result := stateMap.ConnectionModTime()
if !result.Equal(later) {
t.Errorf("Expected latest mod time %v, got %v", later, result)
}
}
func TestConnectionStateMapConnectionModTimeEmpty(t *testing.T) {
stateMap := ConnectionStateMap{}
result := stateMap.ConnectionModTime()
if !result.IsZero() {
t.Errorf("Expected zero time for empty map, got %v", result)
}
}
func TestConnectionStateMapGetPluginToConnectionMap(t *testing.T) {
stateMap := ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
Plugin: "plugin1",
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
Plugin: "plugin1",
},
"conn3": &ConnectionState{
ConnectionName: "conn3",
Plugin: "plugin2",
},
}
result := stateMap.GetPluginToConnectionMap()
if len(result["plugin1"]) != 2 {
t.Errorf("Expected 2 connections for plugin1, got %d", len(result["plugin1"]))
}
if len(result["plugin2"]) != 1 {
t.Errorf("Expected 1 connection for plugin2, got %d", len(result["plugin2"]))
}
}
func TestConnectionStateMapSetConnectionsToPendingOrIncomplete(t *testing.T) {
stateMap := ConnectionStateMap{
"conn1": &ConnectionState{
ConnectionName: "conn1",
State: constants.ConnectionStateReady,
},
"conn2": &ConnectionState{
ConnectionName: "conn2",
State: constants.ConnectionStateError,
},
"conn3": &ConnectionState{
ConnectionName: "conn3",
State: constants.ConnectionStateDisabled,
},
}
stateMap.SetConnectionsToPendingOrIncomplete()
if stateMap["conn1"].State != constants.ConnectionStatePending {
t.Errorf("Expected conn1 to be pending, got %s", stateMap["conn1"].State)
}
if stateMap["conn2"].State != constants.ConnectionStatePendingIncomplete {
t.Errorf("Expected conn2 to be pending incomplete, got %s", stateMap["conn2"].State)
}
if stateMap["conn3"].State != constants.ConnectionStateDisabled {
t.Errorf("Expected conn3 to remain disabled, got %s", stateMap["conn3"].State)
}
}

View File

@@ -2,6 +2,10 @@ package steampipeconfig
import (
"testing"
"time"
typehelpers "github.com/turbot/go-kit/types"
"github.com/turbot/steampipe/v2/pkg/constants"
)
func TestConnectionsUpdateEqual(t *testing.T) {
@@ -25,6 +29,84 @@ func TestConnectionsUpdateEqual(t *testing.T) {
},
expected: true,
},
{
name: "different plugin",
data1: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
State: "ready",
},
data2: &ConnectionState{
ConnectionName: "test1",
Plugin: "different_plugin",
State: "ready",
},
expected: false,
},
{
name: "different type",
data1: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
Type: typehelpers.String("aggregator"),
State: "ready",
},
data2: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
Type: nil,
State: "ready",
},
expected: false,
},
{
name: "different import schema",
data1: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
ImportSchema: "enabled",
State: "ready",
},
data2: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
ImportSchema: "disabled",
State: "ready",
},
expected: false,
},
{
name: "different error",
data1: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
ConnectionError: typehelpers.String("error1"),
State: "error",
},
data2: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
ConnectionError: typehelpers.String("error2"),
State: "error",
},
expected: false,
},
{
name: "plugin mod time within tolerance",
data1: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
PluginModTime: time.Now(),
State: "ready",
},
data2: &ConnectionState{
ConnectionName: "test1",
Plugin: "test_plugin",
PluginModTime: time.Now().Add(500 * time.Microsecond),
State: "ready",
},
expected: true,
},
}
for _, testCase := range testCases {
@@ -36,3 +118,188 @@ func TestConnectionsUpdateEqual(t *testing.T) {
})
}
}
func TestConnectionStateLoaded(t *testing.T) {
testCases := []struct {
name string
state *ConnectionState
expected bool
}{
{
name: "ready state is loaded",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateReady,
},
expected: true,
},
{
name: "error state is loaded",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateError,
},
expected: true,
},
{
name: "disabled state is loaded",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateDisabled,
},
expected: true,
},
{
name: "pending state is not loaded",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStatePending,
},
expected: false,
},
{
name: "updating state is not loaded",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateUpdating,
},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.state.Loaded()
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateDisabled(t *testing.T) {
testCases := []struct {
name string
state *ConnectionState
expected bool
}{
{
name: "disabled state",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateDisabled,
},
expected: true,
},
{
name: "ready state is not disabled",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateReady,
},
expected: false,
},
{
name: "error state is not disabled",
state: &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateError,
},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.state.Disabled()
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateGetType(t *testing.T) {
testCases := []struct {
name string
state *ConnectionState
expected string
}{
{
name: "aggregator type",
state: &ConnectionState{
ConnectionName: "test1",
Type: typehelpers.String("aggregator"),
},
expected: "aggregator",
},
{
name: "nil type returns empty string",
state: &ConnectionState{
ConnectionName: "test1",
Type: nil,
},
expected: "",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.state.GetType()
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateError(t *testing.T) {
testCases := []struct {
name string
state *ConnectionState
expected string
}{
{
name: "error message",
state: &ConnectionState{
ConnectionName: "test1",
ConnectionError: typehelpers.String("test error"),
},
expected: "test error",
},
{
name: "nil error returns empty string",
state: &ConnectionState{
ConnectionName: "test1",
ConnectionError: nil,
},
expected: "",
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.state.Error()
if result != testCase.expected {
t.Errorf("Expected %v, got %v", testCase.expected, result)
}
})
}
}
func TestConnectionStateSetError(t *testing.T) {
state := &ConnectionState{
ConnectionName: "test1",
State: constants.ConnectionStateReady,
}
state.SetError("test error")
if state.State != constants.ConnectionStateError {
t.Errorf("Expected state to be %s, got %s", constants.ConnectionStateError, state.State)
}
if state.Error() != "test error" {
t.Errorf("Expected error to be 'test error', got %s", state.Error())
}
}

View File

@@ -0,0 +1,106 @@
package steampipeconfig
import (
"strings"
"testing"
)
func TestValidationFailureString(t *testing.T) {
testCases := []struct {
name string
failure ValidationFailure
expected []string
}{
{
name: "basic validation failure",
failure: ValidationFailure{
Plugin: "hub.steampipe.io/plugins/turbot/aws@latest",
ConnectionName: "aws_prod",
Message: "invalid configuration",
ShouldDropIfExists: false,
},
expected: []string{
"Connection: aws_prod",
"Plugin: hub.steampipe.io/plugins/turbot/aws@latest",
"Error: invalid configuration",
},
},
{
name: "validation failure with drop flag",
failure: ValidationFailure{
Plugin: "hub.steampipe.io/plugins/turbot/gcp@latest",
ConnectionName: "gcp_dev",
Message: "missing required field",
ShouldDropIfExists: true,
},
expected: []string{
"Connection: gcp_dev",
"Plugin: hub.steampipe.io/plugins/turbot/gcp@latest",
"Error: missing required field",
},
},
{
name: "validation failure with empty message",
failure: ValidationFailure{
Plugin: "test_plugin",
ConnectionName: "test_conn",
Message: "",
ShouldDropIfExists: false,
},
expected: []string{
"Connection: test_conn",
"Plugin: test_plugin",
"Error: ",
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := testCase.failure.String()
for _, expected := range testCase.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected result to contain '%s', got: %s", expected, result)
}
}
})
}
}
func TestValidationFailureStringFormat(t *testing.T) {
failure := ValidationFailure{
Plugin: "test_plugin",
ConnectionName: "test_connection",
Message: "test error",
ShouldDropIfExists: false,
}
result := failure.String()
// Verify the format includes the expected labels
if !strings.Contains(result, "Connection:") {
t.Error("Expected result to contain 'Connection:' label")
}
if !strings.Contains(result, "Plugin:") {
t.Error("Expected result to contain 'Plugin:' label")
}
if !strings.Contains(result, "Error:") {
t.Error("Expected result to contain 'Error:' label")
}
// Verify the values are present
if !strings.Contains(result, "test_connection") {
t.Error("Expected result to contain connection name")
}
if !strings.Contains(result, "test_plugin") {
t.Error("Expected result to contain plugin name")
}
if !strings.Contains(result, "test error") {
t.Error("Expected result to contain error message")
}
}