mirror of
https://github.com/turbot/steampipe.git
synced 2025-12-19 18:12:43 -05:00
Add comprehensive passing tests from bug hunting initiative (#4864)
This commit is contained in:
232
pkg/cmdconfig/cmd_hooks_test.go
Normal file
232
pkg/cmdconfig/cmd_hooks_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package cmdconfig
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func TestPostRunHook_WaitsForTasks(t *testing.T) {
|
||||
// Test that postRunHook waits for async tasks
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
// Simulate a task channel
|
||||
testChannel := make(chan struct{})
|
||||
oldChannel := waitForTasksChannel
|
||||
waitForTasksChannel = testChannel
|
||||
defer func() { waitForTasksChannel = oldChannel }()
|
||||
|
||||
// Close the channel after a short delay
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
close(testChannel)
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
postRunHook(cmd, []string{})
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should have waited for the channel to close
|
||||
if duration < 10*time.Millisecond {
|
||||
t.Error("postRunHook did not wait for tasks channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostRunHook_Timeout(t *testing.T) {
|
||||
// Test that postRunHook times out if tasks take too long
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
// Simulate a task channel that never closes
|
||||
testChannel := make(chan struct{})
|
||||
oldChannel := waitForTasksChannel
|
||||
waitForTasksChannel = testChannel
|
||||
defer func() {
|
||||
waitForTasksChannel = oldChannel
|
||||
close(testChannel)
|
||||
}()
|
||||
|
||||
// Mock cancel function
|
||||
cancelCalled := false
|
||||
oldCancelFn := tasksCancelFn
|
||||
tasksCancelFn = func() {
|
||||
cancelCalled = true
|
||||
}
|
||||
defer func() { tasksCancelFn = oldCancelFn }()
|
||||
|
||||
start := time.Now()
|
||||
postRunHook(cmd, []string{})
|
||||
duration := time.Since(start)
|
||||
|
||||
// Should have timed out after 100ms
|
||||
if duration < 100*time.Millisecond || duration > 150*time.Millisecond {
|
||||
t.Errorf("postRunHook timeout not working correctly, took %v", duration)
|
||||
}
|
||||
|
||||
if !cancelCalled {
|
||||
t.Error("Cancel function was not called on timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdBuilder_HookIntegration(t *testing.T) {
|
||||
// Test that CmdBuilder properly wraps hooks
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
cmd.PreRun = func(cmd *cobra.Command, args []string) {
|
||||
// Original PreRun
|
||||
}
|
||||
|
||||
cmd.PostRun = func(cmd *cobra.Command, args []string) {
|
||||
// Original PostRun
|
||||
}
|
||||
|
||||
cmd.Run = func(cmd *cobra.Command, args []string) {
|
||||
// Original Run
|
||||
}
|
||||
|
||||
// Build with CmdBuilder
|
||||
builder := OnCmd(cmd)
|
||||
if builder == nil {
|
||||
t.Fatal("OnCmd returned nil")
|
||||
}
|
||||
|
||||
// The hooks should now be wrapped
|
||||
if cmd.PreRun == nil {
|
||||
t.Error("PreRun hook was not set")
|
||||
}
|
||||
if cmd.PostRun == nil {
|
||||
t.Error("PostRun hook was not set")
|
||||
}
|
||||
if cmd.Run == nil {
|
||||
t.Error("Run hook was not set")
|
||||
}
|
||||
|
||||
// Note: We can't easily test the wrapped functions without a full cobra execution
|
||||
// This would require integration tests
|
||||
t.Log("CmdBuilder successfully wrapped command hooks")
|
||||
}
|
||||
|
||||
func TestCmdBuilder_FlagBinding(t *testing.T) {
|
||||
// Test that CmdBuilder properly binds flags to viper
|
||||
viper.Reset()
|
||||
defer viper.Reset()
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
builder := OnCmd(cmd)
|
||||
builder.AddStringFlag("test-flag", "default-value", "Test flag description")
|
||||
|
||||
// Verify flag was added
|
||||
flag := cmd.Flags().Lookup("test-flag")
|
||||
if flag == nil {
|
||||
t.Fatal("Flag was not added to command")
|
||||
}
|
||||
|
||||
if flag.DefValue != "default-value" {
|
||||
t.Errorf("Flag default value incorrect, got %s", flag.DefValue)
|
||||
}
|
||||
|
||||
// Verify binding was stored
|
||||
if len(builder.bindings) != 1 {
|
||||
t.Errorf("Expected 1 binding, got %d", len(builder.bindings))
|
||||
}
|
||||
|
||||
if builder.bindings["test-flag"] != flag {
|
||||
t.Error("Flag binding not stored correctly")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdBuilder_MultipleFlagTypes(t *testing.T) {
|
||||
// Test that CmdBuilder can handle multiple flag types
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
builder := OnCmd(cmd)
|
||||
builder.
|
||||
AddStringFlag("string-flag", "default", "String flag").
|
||||
AddIntFlag("int-flag", 42, "Int flag").
|
||||
AddBoolFlag("bool-flag", true, "Bool flag").
|
||||
AddStringSliceFlag("slice-flag", []string{"a", "b"}, "Slice flag")
|
||||
|
||||
// Verify all flags were added
|
||||
if cmd.Flags().Lookup("string-flag") == nil {
|
||||
t.Error("String flag not added")
|
||||
}
|
||||
if cmd.Flags().Lookup("int-flag") == nil {
|
||||
t.Error("Int flag not added")
|
||||
}
|
||||
if cmd.Flags().Lookup("bool-flag") == nil {
|
||||
t.Error("Bool flag not added")
|
||||
}
|
||||
if cmd.Flags().Lookup("slice-flag") == nil {
|
||||
t.Error("Slice flag not added")
|
||||
}
|
||||
|
||||
// Verify all bindings were stored
|
||||
if len(builder.bindings) != 4 {
|
||||
t.Errorf("Expected 4 bindings, got %d", len(builder.bindings))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdBuilder_CloudFlags(t *testing.T) {
|
||||
// Test that AddCloudFlags adds the expected flags
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
builder := OnCmd(cmd)
|
||||
builder.AddCloudFlags()
|
||||
|
||||
// Verify cloud flags were added
|
||||
if cmd.Flags().Lookup("pipes-host") == nil {
|
||||
t.Error("pipes-host flag not added")
|
||||
}
|
||||
if cmd.Flags().Lookup("pipes-token") == nil {
|
||||
t.Error("pipes-token flag not added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCmdBuilder_NilFlagPanic(t *testing.T) {
|
||||
// Test that nil flag causes panic (as documented in builder.go)
|
||||
cmd := &cobra.Command{
|
||||
Use: "test",
|
||||
PreRun: func(cmd *cobra.Command, args []string) {
|
||||
// This will be called by CmdBuilder's wrapped PreRun
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {},
|
||||
}
|
||||
|
||||
builder := OnCmd(cmd)
|
||||
builder.AddStringFlag("test-flag", "default", "Test flag")
|
||||
|
||||
// Manually corrupt the bindings to test panic
|
||||
builder.bindings["corrupt-flag"] = nil
|
||||
|
||||
// This should panic when PreRun is executed
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Error("Expected panic for nil flag binding")
|
||||
} else {
|
||||
t.Logf("Correctly panicked with: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Execute PreRun which should panic
|
||||
cmd.PreRun(cmd, []string{})
|
||||
}
|
||||
481
pkg/connection/refresh_connections_state_test.go
Normal file
481
pkg/connection/refresh_connections_state_test.go
Normal file
@@ -0,0 +1,481 @@
|
||||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/turbot/steampipe/v2/pkg/steampipeconfig"
|
||||
)
|
||||
|
||||
// TestRefreshConnectionState_ExemplarSchemaMapConcurrentWrites tests concurrent writes to exemplarSchemaMap
|
||||
// This verifies the fix for bug #4757
|
||||
func TestRefreshConnectionState_ExemplarSchemaMapConcurrentWrites(t *testing.T) {
|
||||
// ARRANGE: Create state with initialized maps
|
||||
state := &refreshConnectionState{
|
||||
exemplarSchemaMap: make(map[string]string),
|
||||
exemplarSchemaMapMut: sync.Mutex{},
|
||||
}
|
||||
|
||||
numGoroutines := 50
|
||||
numIterations := 100
|
||||
plugins := []string{"aws", "azure", "gcp", "github", "slack"}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// ACT: Launch goroutines that concurrently write to exemplarSchemaMap
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numIterations; j++ {
|
||||
plugin := plugins[j%len(plugins)]
|
||||
connectionName := fmt.Sprintf("conn_%d_%d", id, j)
|
||||
|
||||
// Simulate the FIXED pattern from executeUpdateForConnections (lines 600-605)
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
_, haveExemplar := state.exemplarSchemaMap[plugin]
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
if !haveExemplar {
|
||||
// This write is now protected by mutex (fix for #4757)
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
state.exemplarSchemaMap[plugin] = connectionName
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// ASSERT: Verify all plugins are in the map
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
defer state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
if len(state.exemplarSchemaMap) != len(plugins) {
|
||||
t.Errorf("Expected %d plugins in exemplarSchemaMap, got %d", len(plugins), len(state.exemplarSchemaMap))
|
||||
}
|
||||
|
||||
for _, plugin := range plugins {
|
||||
if _, ok := state.exemplarSchemaMap[plugin]; !ok {
|
||||
t.Errorf("Expected plugin %s to be in exemplarSchemaMap", plugin)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_ExemplarSchemaMapConcurrentReadWrite tests concurrent reads and writes
|
||||
func TestRefreshConnectionState_ExemplarSchemaMapConcurrentReadWrite(t *testing.T) {
|
||||
// ARRANGE: Create state with some pre-populated data
|
||||
state := &refreshConnectionState{
|
||||
exemplarSchemaMap: map[string]string{
|
||||
"aws": "aws_conn_1",
|
||||
"azure": "azure_conn_1",
|
||||
},
|
||||
exemplarSchemaMapMut: sync.Mutex{},
|
||||
}
|
||||
|
||||
numReaders := 30
|
||||
numWriters := 20
|
||||
duration := 100 * time.Millisecond
|
||||
|
||||
var wg sync.WaitGroup
|
||||
ctx, cancel := context.WithTimeout(context.Background(), duration)
|
||||
defer cancel()
|
||||
|
||||
// ACT: Launch reader goroutines
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
_ = state.exemplarSchemaMap["aws"]
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Launch writer goroutines
|
||||
for i := 0; i < numWriters; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
plugin := fmt.Sprintf("plugin_%d", id)
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
state.exemplarSchemaMap[plugin] = fmt.Sprintf("conn_%d", id)
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// ASSERT: No race conditions should occur (run with -race flag)
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
defer state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
// Basic sanity check
|
||||
if len(state.exemplarSchemaMap) < 2 {
|
||||
t.Error("Expected at least 2 entries in exemplarSchemaMap")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_ExemplarMapRaceCondition tests the exact race condition from bug #4757
|
||||
func TestRefreshConnectionState_ExemplarMapRaceCondition(t *testing.T) {
|
||||
// This test verifies that the fix for #4757 works correctly
|
||||
// The bug was: reading haveExemplarSchema without lock, then writing without lock
|
||||
// The fix: both read and write are now properly protected by mutex
|
||||
|
||||
// ARRANGE
|
||||
state := &refreshConnectionState{
|
||||
exemplarSchemaMap: make(map[string]string),
|
||||
exemplarSchemaMapMut: sync.Mutex{},
|
||||
}
|
||||
|
||||
numGoroutines := 100
|
||||
pluginName := "aws"
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, numGoroutines)
|
||||
|
||||
// ACT: Simulate the exact pattern from executeUpdateForConnections
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
|
||||
connectionName := fmt.Sprintf("aws_conn_%d", id)
|
||||
|
||||
// This is the FIXED pattern from lines 581-604
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
_, haveExemplarSchema := state.exemplarSchemaMap[pluginName]
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
// Simulate some work
|
||||
time.Sleep(time.Microsecond)
|
||||
|
||||
if !haveExemplarSchema {
|
||||
// Write is now protected by mutex (fix for #4757)
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
// Check again after acquiring lock (double-check pattern)
|
||||
if _, exists := state.exemplarSchemaMap[pluginName]; !exists {
|
||||
state.exemplarSchemaMap[pluginName] = connectionName
|
||||
}
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// ASSERT: Check for errors
|
||||
for err := range errChan {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// Verify the map has exactly one entry for the plugin
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
defer state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
if len(state.exemplarSchemaMap) != 1 {
|
||||
t.Errorf("Expected exactly 1 entry in exemplarSchemaMap, got %d", len(state.exemplarSchemaMap))
|
||||
}
|
||||
|
||||
if _, ok := state.exemplarSchemaMap[pluginName]; !ok {
|
||||
t.Error("Expected plugin to be in exemplarSchemaMap")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSetMapToArray tests the conversion utility function
|
||||
func TestUpdateSetMapToArray(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input map[string][]*steampipeconfig.ConnectionState
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "empty_map",
|
||||
input: map[string][]*steampipeconfig.ConnectionState{},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "single_entry_single_state",
|
||||
input: map[string][]*steampipeconfig.ConnectionState{
|
||||
"plugin1": {
|
||||
{ConnectionName: "conn1"},
|
||||
},
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "single_entry_multiple_states",
|
||||
input: map[string][]*steampipeconfig.ConnectionState{
|
||||
"plugin1": {
|
||||
{ConnectionName: "conn1"},
|
||||
{ConnectionName: "conn2"},
|
||||
{ConnectionName: "conn3"},
|
||||
},
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "multiple_entries",
|
||||
input: map[string][]*steampipeconfig.ConnectionState{
|
||||
"plugin1": {
|
||||
{ConnectionName: "conn1"},
|
||||
{ConnectionName: "conn2"},
|
||||
},
|
||||
"plugin2": {
|
||||
{ConnectionName: "conn3"},
|
||||
},
|
||||
},
|
||||
expected: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// ACT
|
||||
result := updateSetMapToArray(tt.input)
|
||||
|
||||
// ASSERT
|
||||
if len(result) != tt.expected {
|
||||
t.Errorf("Expected %d connection states, got %d", tt.expected, len(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCloneSchemaQuery tests the schema cloning query generation
|
||||
func TestGetCloneSchemaQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
exemplarName string
|
||||
connState *steampipeconfig.ConnectionState
|
||||
expectedQuery string
|
||||
}{
|
||||
{
|
||||
name: "basic_clone",
|
||||
exemplarName: "aws_source",
|
||||
connState: &steampipeconfig.ConnectionState{
|
||||
ConnectionName: "aws_target",
|
||||
Plugin: "hub.steampipe.io/plugins/turbot/aws@latest",
|
||||
},
|
||||
expectedQuery: "select clone_foreign_schema('aws_source', 'aws_target', 'hub.steampipe.io/plugins/turbot/aws@latest');",
|
||||
},
|
||||
{
|
||||
name: "with_special_characters",
|
||||
exemplarName: "test-source",
|
||||
connState: &steampipeconfig.ConnectionState{
|
||||
ConnectionName: "test-target",
|
||||
Plugin: "test/plugin@1.0.0",
|
||||
},
|
||||
expectedQuery: "select clone_foreign_schema('test-source', 'test-target', 'test/plugin@1.0.0');",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// ACT
|
||||
result := getCloneSchemaQuery(tt.exemplarName, tt.connState)
|
||||
|
||||
// ASSERT
|
||||
if result != tt.expectedQuery {
|
||||
t.Errorf("Expected query:\n%s\nGot:\n%s", tt.expectedQuery, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_DeferErrorHandling tests error handling in defer blocks
|
||||
func TestRefreshConnectionState_DeferErrorHandling(t *testing.T) {
|
||||
// This tests the defer block at lines 98-108 in refreshConnections
|
||||
|
||||
// ARRANGE: Create state with a result that will have an error
|
||||
state := &refreshConnectionState{
|
||||
res: &steampipeconfig.RefreshConnectionResult{},
|
||||
}
|
||||
|
||||
// Simulate setting an error
|
||||
testErr := errors.New("test error")
|
||||
state.res.Error = testErr
|
||||
|
||||
// ACT: The defer block should handle this gracefully
|
||||
// In the actual code, this is called via defer func()
|
||||
// We're testing the logic here
|
||||
|
||||
// ASSERT: Verify the defer logic works
|
||||
if state.res != nil && state.res.Error != nil {
|
||||
// This is what the defer does - it would call setIncompleteConnectionStateToError
|
||||
// We're just verifying the nil checks work
|
||||
if state.res.Error != testErr {
|
||||
t.Error("Error should be preserved")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_NilResInDefer tests nil res handling in defer block
|
||||
func TestRefreshConnectionState_NilResInDefer(t *testing.T) {
|
||||
// ARRANGE: Create state with nil res
|
||||
state := &refreshConnectionState{
|
||||
res: nil,
|
||||
}
|
||||
|
||||
// ACT & ASSERT: The defer block at line 98-108 checks if res is nil
|
||||
// This should not panic
|
||||
if state.res != nil {
|
||||
t.Error("res should be nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_MultiplePluginsSameExemplar tests that only one exemplar is stored per plugin
|
||||
func TestRefreshConnectionState_MultiplePluginsSameExemplar(t *testing.T) {
|
||||
// ARRANGE
|
||||
state := &refreshConnectionState{
|
||||
exemplarSchemaMap: make(map[string]string),
|
||||
exemplarSchemaMapMut: sync.Mutex{},
|
||||
}
|
||||
|
||||
pluginName := "aws"
|
||||
connections := []string{"aws1", "aws2", "aws3", "aws4", "aws5"}
|
||||
|
||||
// ACT: Add connections sequentially (simulating the pattern from the code)
|
||||
for _, conn := range connections {
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
_, exists := state.exemplarSchemaMap[pluginName]
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
if !exists {
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
// Double-check pattern
|
||||
if _, exists := state.exemplarSchemaMap[pluginName]; !exists {
|
||||
state.exemplarSchemaMap[pluginName] = conn
|
||||
}
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// ASSERT: Only the first connection should be stored
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
defer state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
if len(state.exemplarSchemaMap) != 1 {
|
||||
t.Errorf("Expected 1 entry, got %d", len(state.exemplarSchemaMap))
|
||||
}
|
||||
|
||||
if exemplar, ok := state.exemplarSchemaMap[pluginName]; !ok {
|
||||
t.Error("Expected plugin to be in map")
|
||||
} else if exemplar != connections[0] {
|
||||
t.Errorf("Expected first connection %s to be exemplar, got %s", connections[0], exemplar)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_ErrorChannelBlocking tests that error channel doesn't block
|
||||
func TestRefreshConnectionState_ErrorChannelBlocking(t *testing.T) {
|
||||
// This tests a potential bug in executeUpdateSetsInParallel where the error channel
|
||||
// could block if it's not properly drained
|
||||
|
||||
// ARRANGE
|
||||
errChan := make(chan *connectionError, 10) // Buffered channel
|
||||
numErrors := 20 // More errors than buffer size
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Start a consumer goroutine (like in the actual code at line 519-536)
|
||||
consumerDone := make(chan bool)
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case err := <-errChan:
|
||||
if err == nil {
|
||||
consumerDone <- true
|
||||
return
|
||||
}
|
||||
// Process error
|
||||
_ = err
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// ACT: Send many errors
|
||||
for i := 0; i < numErrors; i++ {
|
||||
wg.Add(1)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
errChan <- &connectionError{
|
||||
name: fmt.Sprintf("conn_%d", id),
|
||||
err: fmt.Errorf("error %d", id),
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
// Wait for consumer to finish
|
||||
select {
|
||||
case <-consumerDone:
|
||||
// Good - consumer exited
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("Error channel consumer did not exit in time")
|
||||
}
|
||||
|
||||
// ASSERT: No goroutines should be blocked
|
||||
}
|
||||
|
||||
// TestRefreshConnectionState_ExemplarMapNilPlugin tests handling of empty plugin names
|
||||
func TestRefreshConnectionState_ExemplarMapNilPlugin(t *testing.T) {
|
||||
// ARRANGE
|
||||
state := &refreshConnectionState{
|
||||
exemplarSchemaMap: make(map[string]string),
|
||||
exemplarSchemaMapMut: sync.Mutex{},
|
||||
}
|
||||
|
||||
// ACT: Try to add entry with empty plugin name
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
state.exemplarSchemaMap[""] = "some_connection"
|
||||
state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
// ASSERT: Map should accept empty string as key (Go maps allow this)
|
||||
state.exemplarSchemaMapMut.Lock()
|
||||
defer state.exemplarSchemaMapMut.Unlock()
|
||||
|
||||
if _, ok := state.exemplarSchemaMap[""]; !ok {
|
||||
t.Error("Expected empty string key to be in map")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConnectionError tests the connectionError struct
|
||||
func TestConnectionError(t *testing.T) {
|
||||
// ARRANGE
|
||||
testErr := errors.New("test error")
|
||||
connErr := &connectionError{
|
||||
name: "test_connection",
|
||||
err: testErr,
|
||||
}
|
||||
|
||||
// ASSERT
|
||||
if connErr.name != "test_connection" {
|
||||
t.Errorf("Expected name 'test_connection', got '%s'", connErr.name)
|
||||
}
|
||||
|
||||
if connErr.err != testErr {
|
||||
t.Error("Error not preserved")
|
||||
}
|
||||
}
|
||||
168
pkg/db/db_client/db_client_session_test.go
Normal file
168
pkg/db/db_client/db_client_session_test.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package db_client
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/turbot/steampipe/v2/pkg/db/db_common"
|
||||
)
|
||||
|
||||
// TestDbClient_SessionRegistration verifies session registration in sessions map
|
||||
func TestDbClient_SessionRegistration(t *testing.T) {
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Simulate session registration
|
||||
backendPid := uint32(12345)
|
||||
session := db_common.NewDBSession(backendPid)
|
||||
|
||||
client.sessionsMutex.Lock()
|
||||
client.sessions[backendPid] = session
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
// Verify session is registered
|
||||
client.sessionsMutex.Lock()
|
||||
registeredSession, found := client.sessions[backendPid]
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
assert.True(t, found, "Session should be registered")
|
||||
assert.Equal(t, backendPid, registeredSession.BackendPid, "Backend PID should match")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionUnregistration verifies session cleanup via BeforeClose
|
||||
func TestDbClient_SessionUnregistration(t *testing.T) {
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Add sessions
|
||||
backendPid1 := uint32(100)
|
||||
backendPid2 := uint32(200)
|
||||
|
||||
client.sessionsMutex.Lock()
|
||||
client.sessions[backendPid1] = db_common.NewDBSession(backendPid1)
|
||||
client.sessions[backendPid2] = db_common.NewDBSession(backendPid2)
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
assert.Len(t, client.sessions, 2, "Should have 2 sessions")
|
||||
|
||||
// Simulate BeforeClose callback for one session
|
||||
client.sessionsMutex.Lock()
|
||||
delete(client.sessions, backendPid1)
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
// Verify only one session remains
|
||||
client.sessionsMutex.Lock()
|
||||
_, found1 := client.sessions[backendPid1]
|
||||
_, found2 := client.sessions[backendPid2]
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
assert.False(t, found1, "First session should be removed")
|
||||
assert.True(t, found2, "Second session should still exist")
|
||||
assert.Len(t, client.sessions, 1, "Should have 1 session remaining")
|
||||
}
|
||||
|
||||
// TestDbClient_ConcurrentSessionRegistration tests concurrent session additions
|
||||
func TestDbClient_ConcurrentSessionRegistration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent test in short mode")
|
||||
}
|
||||
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
// Concurrently add sessions
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id uint32) {
|
||||
defer wg.Done()
|
||||
backendPid := id
|
||||
session := db_common.NewDBSession(backendPid)
|
||||
|
||||
client.sessionsMutex.Lock()
|
||||
client.sessions[backendPid] = session
|
||||
client.sessionsMutex.Unlock()
|
||||
}(uint32(i))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Verify all sessions were added
|
||||
assert.Len(t, client.sessions, numGoroutines, "All sessions should be registered")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionMapGrowthUnbounded tests for potential memory leaks
|
||||
// This verifies that sessions don't accumulate indefinitely
|
||||
func TestDbClient_SessionMapGrowthUnbounded(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping large dataset test in short mode")
|
||||
}
|
||||
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Simulate many connections
|
||||
numSessions := 10000
|
||||
for i := 0; i < numSessions; i++ {
|
||||
backendPid := uint32(i)
|
||||
session := db_common.NewDBSession(backendPid)
|
||||
|
||||
client.sessionsMutex.Lock()
|
||||
client.sessions[backendPid] = session
|
||||
client.sessionsMutex.Unlock()
|
||||
}
|
||||
|
||||
assert.Len(t, client.sessions, numSessions, "Should have all sessions")
|
||||
|
||||
// Simulate cleanup (BeforeClose callbacks)
|
||||
for i := 0; i < numSessions; i++ {
|
||||
backendPid := uint32(i)
|
||||
|
||||
client.sessionsMutex.Lock()
|
||||
delete(client.sessions, backendPid)
|
||||
client.sessionsMutex.Unlock()
|
||||
}
|
||||
|
||||
// Verify all sessions are cleaned up
|
||||
assert.Len(t, client.sessions, 0, "All sessions should be cleaned up")
|
||||
}
|
||||
|
||||
// TestDbClient_SearchPathUpdates verifies session search path management
|
||||
func TestDbClient_SearchPathUpdates(t *testing.T) {
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
customSearchPath: []string{"schema1", "schema2"},
|
||||
}
|
||||
|
||||
// Add a session
|
||||
backendPid := uint32(12345)
|
||||
session := db_common.NewDBSession(backendPid)
|
||||
|
||||
client.sessionsMutex.Lock()
|
||||
client.sessions[backendPid] = session
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
// Verify custom search path is set
|
||||
assert.NotNil(t, client.customSearchPath, "Custom search path should be set")
|
||||
assert.Len(t, client.customSearchPath, 2, "Should have 2 schemas in search path")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionConnectionNilSafety verifies handling of nil connections
|
||||
func TestDbClient_SessionConnectionNilSafety(t *testing.T) {
|
||||
session := db_common.NewDBSession(12345)
|
||||
|
||||
// Session is created with nil connection initially
|
||||
assert.Nil(t, session.Connection, "New session should have nil connection initially")
|
||||
}
|
||||
@@ -1,12 +1,15 @@
|
||||
package db_client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/turbot/steampipe/v2/pkg/db/db_common"
|
||||
)
|
||||
|
||||
// TestSessionMapCleanupImplemented verifies that the session map memory leak is fixed
|
||||
@@ -48,3 +51,220 @@ func TestSessionMapCleanupImplemented(t *testing.T) {
|
||||
assert.True(t, hasCleanupComment,
|
||||
"Comment should document automatic cleanup mechanism")
|
||||
}
|
||||
|
||||
// TestDbClient_Close_Idempotent verifies that calling Close() multiple times does not cause issues
|
||||
// Reference: Similar to bug #4712 (Result.Close() idempotency)
|
||||
//
|
||||
// Close() should be safe to call multiple times without panicking or causing errors.
|
||||
func TestDbClient_Close_Idempotent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a minimal client (without real connection)
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// First close
|
||||
err := client.Close(ctx)
|
||||
assert.NoError(t, err, "First Close() should not return error")
|
||||
|
||||
// Second close - should not panic
|
||||
err = client.Close(ctx)
|
||||
assert.NoError(t, err, "Second Close() should not return error")
|
||||
|
||||
// Third close - should still not panic
|
||||
err = client.Close(ctx)
|
||||
assert.NoError(t, err, "Third Close() should not return error")
|
||||
|
||||
// Verify sessions map is nil after close
|
||||
assert.Nil(t, client.sessions, "Sessions map should be nil after Close()")
|
||||
}
|
||||
|
||||
// TestDbClient_ConcurrentSessionAccess tests concurrent access to the sessions map
|
||||
// This test should be run with -race flag to detect data races.
|
||||
//
|
||||
// The sessions map is protected by sessionsMutex, but we want to verify
|
||||
// that all access paths properly use the mutex.
|
||||
func TestDbClient_ConcurrentSessionAccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent access test in short mode")
|
||||
}
|
||||
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
numOperations := 100
|
||||
|
||||
// Track errors in a thread-safe way
|
||||
errors := make(chan error, numGoroutines*numOperations)
|
||||
|
||||
// Simulate concurrent session additions
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(id uint32) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < numOperations; j++ {
|
||||
// Add session
|
||||
client.sessionsMutex.Lock()
|
||||
backendPid := id*1000 + uint32(j)
|
||||
client.sessions[backendPid] = db_common.NewDBSession(backendPid)
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
// Read session
|
||||
client.sessionsMutex.Lock()
|
||||
_ = client.sessions[backendPid]
|
||||
client.sessionsMutex.Unlock()
|
||||
|
||||
// Delete session (simulating BeforeClose callback)
|
||||
client.sessionsMutex.Lock()
|
||||
delete(client.sessions, backendPid)
|
||||
client.sessionsMutex.Unlock()
|
||||
}
|
||||
}(uint32(i))
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
// Check for errors
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDbClient_Close_ClearsSessionsMap verifies that Close() properly clears the sessions map
|
||||
func TestDbClient_Close_ClearsSessionsMap(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Add some sessions
|
||||
client.sessions[1] = db_common.NewDBSession(1)
|
||||
client.sessions[2] = db_common.NewDBSession(2)
|
||||
client.sessions[3] = db_common.NewDBSession(3)
|
||||
|
||||
assert.Len(t, client.sessions, 3, "Should have 3 sessions before Close()")
|
||||
|
||||
// Close the client
|
||||
err := client.Close(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Sessions should be nil after close
|
||||
assert.Nil(t, client.sessions, "Sessions map should be nil after Close()")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionsMutexProtectsMap verifies that sessionsMutex protects all map operations
|
||||
func TestDbClient_SessionsMutexProtectsMap(t *testing.T) {
|
||||
// This is a structural test to verify the sessions map is never accessed without the mutex
|
||||
content, err := os.ReadFile("db_client_session.go")
|
||||
require.NoError(t, err, "should be able to read db_client_session.go")
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Count occurrences of mutex locks
|
||||
mutexLocks := strings.Count(sourceCode, "c.sessionsMutex.Lock()")
|
||||
|
||||
// This is a heuristic check - in practice, we'd need more sophisticated analysis
|
||||
// But it serves as a reminder to use the mutex
|
||||
assert.True(t, mutexLocks > 0,
|
||||
"sessionsMutex.Lock() should be used when accessing sessions map")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionMapDocumentation verifies that session lifecycle is documented
|
||||
func TestDbClient_SessionMapDocumentation(t *testing.T) {
|
||||
content, err := os.ReadFile("db_client.go")
|
||||
require.NoError(t, err)
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Verify documentation mentions the lifecycle
|
||||
assert.Contains(t, sourceCode, "Session lifecycle:",
|
||||
"Sessions map should have lifecycle documentation")
|
||||
|
||||
assert.Contains(t, sourceCode, "issue #3737",
|
||||
"Should reference the memory leak issue")
|
||||
}
|
||||
|
||||
// TestDbClient_ClosePools_NilPoolsHandling verifies closePools handles nil pools
|
||||
func TestDbClient_ClosePools_NilPoolsHandling(t *testing.T) {
|
||||
client := &DbClient{
|
||||
sessions: make(map[uint32]*db_common.DatabaseSession),
|
||||
sessionsMutex: &sync.Mutex{},
|
||||
}
|
||||
|
||||
// Should not panic with nil pools
|
||||
assert.NotPanics(t, func() {
|
||||
client.closePools()
|
||||
}, "closePools should handle nil pools gracefully")
|
||||
}
|
||||
|
||||
// TestDbClient_SessionsMapInitialized verifies sessions map is initialized in NewDbClient
|
||||
func TestDbClient_SessionsMapInitialized(t *testing.T) {
|
||||
// Verify the initialization happens in NewDbClient
|
||||
content, err := os.ReadFile("db_client.go")
|
||||
require.NoError(t, err)
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Verify sessions map is initialized
|
||||
assert.Contains(t, sourceCode, "sessions: make(map[uint32]*db_common.DatabaseSession)",
|
||||
"sessions map should be initialized in NewDbClient")
|
||||
|
||||
// Verify mutex is initialized
|
||||
assert.Contains(t, sourceCode, "sessionsMutex: &sync.Mutex{}",
|
||||
"sessionsMutex should be initialized in NewDbClient")
|
||||
}
|
||||
|
||||
// TestDbClient_DeferredCleanupInNewDbClient verifies error cleanup in NewDbClient
|
||||
func TestDbClient_DeferredCleanupInNewDbClient(t *testing.T) {
|
||||
content, err := os.ReadFile("db_client.go")
|
||||
require.NoError(t, err)
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Verify there's a defer that handles cleanup on error
|
||||
assert.Contains(t, sourceCode, "defer func() {",
|
||||
"NewDbClient should have deferred cleanup")
|
||||
|
||||
assert.Contains(t, sourceCode, "client.Close(ctx)",
|
||||
"Deferred cleanup should close the client on error")
|
||||
}
|
||||
|
||||
// TestDbClient_ParallelSessionInitLock verifies parallelSessionInitLock initialization
|
||||
func TestDbClient_ParallelSessionInitLock(t *testing.T) {
|
||||
content, err := os.ReadFile("db_client.go")
|
||||
require.NoError(t, err)
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Verify parallelSessionInitLock is initialized
|
||||
assert.Contains(t, sourceCode, "parallelSessionInitLock:",
|
||||
"parallelSessionInitLock should be initialized")
|
||||
|
||||
// Should use semaphore
|
||||
assert.Contains(t, sourceCode, "semaphore.NewWeighted",
|
||||
"parallelSessionInitLock should use weighted semaphore")
|
||||
}
|
||||
|
||||
// TestDbClient_BeforeCloseCallbackNilSafety tests the BeforeClose callback with nil connection
|
||||
func TestDbClient_BeforeCloseCallbackNilSafety(t *testing.T) {
|
||||
content, err := os.ReadFile("db_client_connect.go")
|
||||
require.NoError(t, err)
|
||||
|
||||
sourceCode := string(content)
|
||||
|
||||
// Verify nil checks in BeforeClose callback
|
||||
assert.Contains(t, sourceCode, "if conn != nil",
|
||||
"BeforeClose should check if conn is nil")
|
||||
|
||||
assert.Contains(t, sourceCode, "conn.PgConn() != nil",
|
||||
"BeforeClose should check if PgConn() is nil")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package db_local
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsValidDatabaseName(t *testing.T) {
|
||||
tests := map[string]bool{
|
||||
@@ -26,3 +28,4 @@ func TestIsValidDatabaseName_EmptyString(t *testing.T) {
|
||||
t.Errorf("Expected false for empty string, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
271
pkg/interactive/autocomplete_test.go
Normal file
271
pkg/interactive/autocomplete_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package interactive
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/c-bata/go-prompt"
|
||||
)
|
||||
|
||||
// TestNewAutocompleteSuggestions tests the creation of autocomplete suggestions
|
||||
func TestNewAutocompleteSuggestions(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
if s == nil {
|
||||
t.Fatal("newAutocompleteSuggestions returned nil")
|
||||
}
|
||||
|
||||
if s.tablesBySchema == nil {
|
||||
t.Error("tablesBySchema map is nil")
|
||||
}
|
||||
|
||||
if s.queriesByMod == nil {
|
||||
t.Error("queriesByMod map is nil")
|
||||
}
|
||||
|
||||
// Note: slices are not initialized (nil is valid for slices in Go)
|
||||
// We just verify the struct itself is created
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsSort tests the sorting of suggestions
|
||||
func TestAutocompleteSuggestionsSort(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
// Add unsorted suggestions
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: "zebra", Description: "Schema"},
|
||||
{Text: "apple", Description: "Schema"},
|
||||
{Text: "mango", Description: "Schema"},
|
||||
}
|
||||
|
||||
s.unqualifiedTables = []prompt.Suggest{
|
||||
{Text: "users", Description: "Table"},
|
||||
{Text: "accounts", Description: "Table"},
|
||||
{Text: "posts", Description: "Table"},
|
||||
}
|
||||
|
||||
s.tablesBySchema["test"] = []prompt.Suggest{
|
||||
{Text: "z_table", Description: "Table"},
|
||||
{Text: "a_table", Description: "Table"},
|
||||
}
|
||||
|
||||
// Sort
|
||||
s.sort()
|
||||
|
||||
// Verify schemas are sorted
|
||||
if len(s.schemas) > 1 {
|
||||
for i := 1; i < len(s.schemas); i++ {
|
||||
if s.schemas[i-1].Text > s.schemas[i].Text {
|
||||
t.Errorf("schemas not sorted: %s > %s", s.schemas[i-1].Text, s.schemas[i].Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify tables are sorted
|
||||
if len(s.unqualifiedTables) > 1 {
|
||||
for i := 1; i < len(s.unqualifiedTables); i++ {
|
||||
if s.unqualifiedTables[i-1].Text > s.unqualifiedTables[i].Text {
|
||||
t.Errorf("unqualifiedTables not sorted: %s > %s", s.unqualifiedTables[i-1].Text, s.unqualifiedTables[i].Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Verify tablesBySchema are sorted
|
||||
tables := s.tablesBySchema["test"]
|
||||
if len(tables) > 1 {
|
||||
for i := 1; i < len(tables); i++ {
|
||||
if tables[i-1].Text > tables[i].Text {
|
||||
t.Errorf("tablesBySchema not sorted: %s > %s", tables[i-1].Text, tables[i].Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsEmptySort tests sorting with empty suggestions
|
||||
func TestAutocompleteSuggestionsEmptySort(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
// Should not panic with empty suggestions
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("sort() panicked with empty suggestions: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
s.sort()
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsSortWithDuplicates tests sorting with duplicate entries
|
||||
func TestAutocompleteSuggestionsSortWithDuplicates(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
// Add duplicate suggestions
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: "apple", Description: "Schema"},
|
||||
{Text: "apple", Description: "Schema"},
|
||||
{Text: "banana", Description: "Schema"},
|
||||
}
|
||||
|
||||
// Should not panic with duplicates
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("sort() panicked with duplicates: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
s.sort()
|
||||
|
||||
// Verify duplicates are preserved (not removed)
|
||||
if len(s.schemas) != 3 {
|
||||
t.Errorf("sort() removed duplicates, got %d entries, want 3", len(s.schemas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsWithUnicode tests suggestions with unicode characters
|
||||
func TestAutocompleteSuggestionsWithUnicode(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: "用户", Description: "Schema"},
|
||||
{Text: "数据库", Description: "Schema"},
|
||||
{Text: "🔥", Description: "Schema"},
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("sort() panicked with unicode: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
s.sort()
|
||||
|
||||
// Just verify it doesn't crash
|
||||
if len(s.schemas) != 3 {
|
||||
t.Errorf("sort() lost unicode entries, got %d entries, want 3", len(s.schemas))
|
||||
}
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsLargeDataset tests with a large number of suggestions
|
||||
func TestAutocompleteSuggestionsLargeDataset(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping large dataset test in short mode")
|
||||
}
|
||||
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
// Add 10,000 schemas
|
||||
for i := 0; i < 10000; i++ {
|
||||
s.schemas = append(s.schemas, prompt.Suggest{
|
||||
Text: "schema_" + string(rune(i)),
|
||||
Description: "Schema",
|
||||
})
|
||||
}
|
||||
|
||||
// Add 10,000 tables
|
||||
for i := 0; i < 10000; i++ {
|
||||
s.unqualifiedTables = append(s.unqualifiedTables, prompt.Suggest{
|
||||
Text: "table_" + string(rune(i)),
|
||||
Description: "Table",
|
||||
})
|
||||
}
|
||||
|
||||
// Should not hang or crash
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("sort() panicked with large dataset: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
s.sort()
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsMemoryUsage tests memory usage with many suggestions
|
||||
func TestAutocompleteSuggestionsMemoryUsage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory usage test in short mode")
|
||||
}
|
||||
|
||||
// Create 100 suggestion sets
|
||||
suggestions := make([]*autoCompleteSuggestions, 100)
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
s := newAutocompleteSuggestions()
|
||||
|
||||
// Add many suggestions
|
||||
for j := 0; j < 1000; j++ {
|
||||
s.schemas = append(s.schemas, prompt.Suggest{
|
||||
Text: "schema",
|
||||
Description: "Schema",
|
||||
})
|
||||
}
|
||||
|
||||
suggestions[i] = s
|
||||
}
|
||||
|
||||
// If we get here without OOM, the test passes
|
||||
// Clear suggestions to allow GC
|
||||
suggestions = nil
|
||||
}
|
||||
|
||||
// TestAutocompleteSuggestionsEdgeCases tests various edge cases
|
||||
func TestAutocompleteSuggestionsEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
test func(*testing.T)
|
||||
}{
|
||||
{
|
||||
name: "empty text suggestion",
|
||||
test: func(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: "", Description: "Empty"},
|
||||
}
|
||||
s.sort() // Should not panic
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "very long text suggestion",
|
||||
test: func(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
longText := make([]byte, 10000)
|
||||
for i := range longText {
|
||||
longText[i] = 'a'
|
||||
}
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: string(longText), Description: "Long"},
|
||||
}
|
||||
s.sort() // Should not panic
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null bytes in text",
|
||||
test: func(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: "schema\x00name", Description: "Null"},
|
||||
}
|
||||
s.sort() // Should not panic
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "special characters in text",
|
||||
test: func(t *testing.T) {
|
||||
s := newAutocompleteSuggestions()
|
||||
s.schemas = []prompt.Suggest{
|
||||
{Text: "schema!@#$%^&*()", Description: "Special"},
|
||||
}
|
||||
s.sort() // Should not panic
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Test panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
tt.test(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
399
pkg/interactive/cancel_test.go
Normal file
399
pkg/interactive/cancel_test.go
Normal file
@@ -0,0 +1,399 @@
|
||||
package interactive
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestCreatePromptContext tests prompt context creation
|
||||
func TestCreatePromptContext(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
ctx := c.createPromptContext(parentCtx)
|
||||
|
||||
if ctx == nil {
|
||||
t.Fatal("createPromptContext returned nil context")
|
||||
}
|
||||
|
||||
if c.cancelPrompt == nil {
|
||||
t.Fatal("createPromptContext didn't set cancelPrompt")
|
||||
}
|
||||
|
||||
// Verify context can be cancelled
|
||||
c.cancelPrompt()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Context was not cancelled after calling cancelPrompt")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreatePromptContextReplacesOld tests that creating a new context cancels the old one
|
||||
func TestCreatePromptContextReplacesOld(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
// Create first context
|
||||
ctx1 := c.createPromptContext(parentCtx)
|
||||
cancel1 := c.cancelPrompt
|
||||
|
||||
// Create second context (should cancel first)
|
||||
ctx2 := c.createPromptContext(parentCtx)
|
||||
|
||||
// First context should be cancelled
|
||||
select {
|
||||
case <-ctx1.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("First context was not cancelled when creating second context")
|
||||
}
|
||||
|
||||
// Second context should still be active
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
t.Error("Second context should not be cancelled yet")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// Expected
|
||||
}
|
||||
|
||||
// First cancel function should be different from second
|
||||
if &cancel1 == &c.cancelPrompt {
|
||||
t.Error("cancelPrompt was not replaced")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateQueryContext tests query context creation
|
||||
func TestCreateQueryContext(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
ctx := c.createQueryContext(parentCtx)
|
||||
|
||||
if ctx == nil {
|
||||
t.Fatal("createQueryContext returned nil context")
|
||||
}
|
||||
|
||||
if c.cancelActiveQuery == nil {
|
||||
t.Fatal("createQueryContext didn't set cancelActiveQuery")
|
||||
}
|
||||
|
||||
// Verify context can be cancelled
|
||||
c.cancelActiveQuery()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Context was not cancelled after calling cancelActiveQuery")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCreateQueryContextDoesNotCancelOld tests that creating a new query context doesn't cancel the old one
|
||||
func TestCreateQueryContextDoesNotCancelOld(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
// Create first context
|
||||
ctx1 := c.createQueryContext(parentCtx)
|
||||
cancel1 := c.cancelActiveQuery
|
||||
|
||||
// Create second context (should NOT cancel first, just replace the reference)
|
||||
ctx2 := c.createQueryContext(parentCtx)
|
||||
|
||||
// First context should still be active (not automatically cancelled)
|
||||
select {
|
||||
case <-ctx1.Done():
|
||||
t.Error("First context was cancelled when creating second context (should not auto-cancel)")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// Expected - first context is NOT cancelled
|
||||
}
|
||||
|
||||
// Cancel using the first cancel function
|
||||
cancel1()
|
||||
|
||||
// Now first context should be cancelled
|
||||
select {
|
||||
case <-ctx1.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("First context was not cancelled after calling its cancel function")
|
||||
}
|
||||
|
||||
// Second context should still be active
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
t.Error("Second context should not be cancelled yet")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// Expected
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelActiveQueryIfAnyIdempotent tests that cancellation is idempotent
|
||||
func TestCancelActiveQueryIfAnyIdempotent(t *testing.T) {
|
||||
callCount := 0
|
||||
cancelFunc := func() {
|
||||
callCount++
|
||||
}
|
||||
|
||||
c := &InteractiveClient{
|
||||
cancelActiveQuery: cancelFunc,
|
||||
}
|
||||
|
||||
// Call multiple times
|
||||
for i := 0; i < 5; i++ {
|
||||
c.cancelActiveQueryIfAny()
|
||||
}
|
||||
|
||||
// Should only be called once
|
||||
if callCount != 1 {
|
||||
t.Errorf("cancelActiveQueryIfAny() called cancel function %d times, want 1 (should be idempotent)", callCount)
|
||||
}
|
||||
|
||||
// Should be nil after first call
|
||||
if c.cancelActiveQuery != nil {
|
||||
t.Error("cancelActiveQueryIfAny() didn't set cancelActiveQuery to nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelActiveQueryIfAnyNil tests behavior with nil cancel function
|
||||
func TestCancelActiveQueryIfAnyNil(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
cancelActiveQuery: nil,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("cancelActiveQueryIfAny() panicked with nil cancel function: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Should not panic
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
// Should remain nil
|
||||
if c.cancelActiveQuery != nil {
|
||||
t.Error("cancelActiveQueryIfAny() set cancelActiveQuery when it was nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestClosePrompt tests the ClosePrompt method
|
||||
func TestClosePrompt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
afterClose AfterPromptCloseAction
|
||||
}{
|
||||
{
|
||||
name: "close with exit",
|
||||
afterClose: AfterPromptCloseExit,
|
||||
},
|
||||
{
|
||||
name: "close with restart",
|
||||
afterClose: AfterPromptCloseRestart,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cancelled := false
|
||||
c := &InteractiveClient{
|
||||
cancelPrompt: func() {
|
||||
cancelled = true
|
||||
},
|
||||
}
|
||||
|
||||
c.ClosePrompt(tt.afterClose)
|
||||
|
||||
if !cancelled {
|
||||
t.Error("ClosePrompt didn't call cancelPrompt")
|
||||
}
|
||||
|
||||
if c.afterClose != tt.afterClose {
|
||||
t.Errorf("ClosePrompt set afterClose to %v, want %v", c.afterClose, tt.afterClose)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextCancellationPropagation tests that parent context cancellation propagates
|
||||
func TestContextCancellationPropagation(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx, parentCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create child context
|
||||
childCtx := c.createPromptContext(parentCtx)
|
||||
|
||||
// Cancel parent
|
||||
parentCancel()
|
||||
|
||||
// Child should be cancelled too
|
||||
select {
|
||||
case <-childCtx.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Child context was not cancelled when parent was cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// TestContextCancellationTimeout tests context with timeout
|
||||
func TestContextCancellationTimeout(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx, parentCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer parentCancel()
|
||||
|
||||
// Create child context
|
||||
childCtx := c.createPromptContext(parentCtx)
|
||||
|
||||
// Wait for timeout
|
||||
select {
|
||||
case <-childCtx.Done():
|
||||
// Expected after ~50ms
|
||||
if childCtx.Err() != context.DeadlineExceeded && childCtx.Err() != context.Canceled {
|
||||
t.Errorf("Expected DeadlineExceeded or Canceled error, got %v", childCtx.Err())
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
t.Error("Context did not timeout as expected")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRapidContextCreation tests rapid context creation and cancellation
|
||||
func TestRapidContextCreation(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
// Rapidly create and cancel contexts
|
||||
for i := 0; i < 1000; i++ {
|
||||
ctx := c.createPromptContext(parentCtx)
|
||||
|
||||
// Immediately cancel
|
||||
if c.cancelPrompt != nil {
|
||||
c.cancelPrompt()
|
||||
}
|
||||
|
||||
// Verify cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Expected
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
t.Errorf("Context %d was not cancelled", i)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelAfterContextAlreadyCancelled tests cancelling after context is already cancelled
|
||||
func TestCancelAfterContextAlreadyCancelled(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx, parentCancel := context.WithCancel(context.Background())
|
||||
|
||||
// Create child context
|
||||
ctx := c.createQueryContext(parentCtx)
|
||||
|
||||
// Cancel parent first
|
||||
parentCancel()
|
||||
|
||||
// Wait for child to be cancelled
|
||||
<-ctx.Done()
|
||||
|
||||
// Now try to cancel via cancelActiveQueryIfAny
|
||||
// Should not panic even though context is already cancelled
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("cancelActiveQueryIfAny panicked when context already cancelled: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
c.cancelActiveQueryIfAny()
|
||||
}
|
||||
|
||||
// TestQueryContextLeakage tests for context leakage
|
||||
func TestQueryContextLeakage(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping leak test in short mode")
|
||||
}
|
||||
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
// Create many query contexts
|
||||
for i := 0; i < 10000; i++ {
|
||||
ctx := c.createQueryContext(parentCtx)
|
||||
|
||||
// Cancel immediately
|
||||
if c.cancelActiveQuery != nil {
|
||||
c.cancelActiveQuery()
|
||||
}
|
||||
|
||||
// Verify context is cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// Good
|
||||
case <-time.After(1 * time.Millisecond):
|
||||
t.Errorf("Context %d not cancelled", i)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here without hanging or OOM, test passes
|
||||
}
|
||||
|
||||
// TestCancelFuncReplacement tests that cancel functions are properly replaced
|
||||
func TestCancelFuncReplacement(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
parentCtx := context.Background()
|
||||
|
||||
// Track which cancel function was called
|
||||
firstCalled := false
|
||||
secondCalled := false
|
||||
|
||||
// Create first query context
|
||||
ctx1 := c.createQueryContext(parentCtx)
|
||||
firstCancel := c.cancelActiveQuery
|
||||
|
||||
// Wrap the first cancel to track calls
|
||||
c.cancelActiveQuery = func() {
|
||||
firstCalled = true
|
||||
firstCancel()
|
||||
}
|
||||
|
||||
// Create second query context (replaces cancelActiveQuery)
|
||||
ctx2 := c.createQueryContext(parentCtx)
|
||||
secondCancel := c.cancelActiveQuery
|
||||
|
||||
// Wrap the second cancel to track calls
|
||||
c.cancelActiveQuery = func() {
|
||||
secondCalled = true
|
||||
secondCancel()
|
||||
}
|
||||
|
||||
// Call cancelActiveQueryIfAny
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
// Only the second cancel should be called
|
||||
if firstCalled {
|
||||
t.Error("First cancel function was called (should have been replaced)")
|
||||
}
|
||||
|
||||
if !secondCalled {
|
||||
t.Error("Second cancel function was not called")
|
||||
}
|
||||
|
||||
// Second context should be cancelled
|
||||
select {
|
||||
case <-ctx2.Done():
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Error("Second context was not cancelled")
|
||||
}
|
||||
|
||||
// First context is NOT automatically cancelled (different from prompt context)
|
||||
select {
|
||||
case <-ctx1.Done():
|
||||
// This might happen if parent was cancelled, but shouldn't happen from our cancel
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// Expected - first context remains active
|
||||
}
|
||||
}
|
||||
239
pkg/interactive/highlighter_test.go
Normal file
239
pkg/interactive/highlighter_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package interactive
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/alecthomas/chroma/formatters"
|
||||
"github.com/alecthomas/chroma/lexers"
|
||||
"github.com/alecthomas/chroma/styles"
|
||||
"github.com/c-bata/go-prompt"
|
||||
)
|
||||
|
||||
// TestNewHighlighter tests highlighter creation
|
||||
func TestNewHighlighter(t *testing.T) {
|
||||
lexer := lexers.Get("sql")
|
||||
formatter := formatters.Get("terminal256")
|
||||
style := styles.Native
|
||||
|
||||
h := newHighlighter(lexer, formatter, style)
|
||||
|
||||
if h == nil {
|
||||
t.Fatal("newHighlighter returned nil")
|
||||
}
|
||||
|
||||
if h.lexer == nil {
|
||||
t.Error("highlighter lexer is nil")
|
||||
}
|
||||
|
||||
if h.formatter == nil {
|
||||
t.Error("highlighter formatter is nil")
|
||||
}
|
||||
|
||||
if h.style == nil {
|
||||
t.Error("highlighter style is nil")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHighlighterHighlight tests the Highlight function
|
||||
func TestHighlighterHighlight(t *testing.T) {
|
||||
h := newHighlighter(
|
||||
lexers.Get("sql"),
|
||||
formatters.Get("terminal256"),
|
||||
styles.Native,
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
input: "SELECT * FROM users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "multiline query",
|
||||
input: "SELECT *\nFROM users\nWHERE id = 1",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
input: "SELECT '你好世界'",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "SELECT '🔥💥✨'",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "null bytes",
|
||||
input: "SELECT '\x00'",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "control characters",
|
||||
input: "SELECT '\n\r\t'",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "very long query",
|
||||
input: "SELECT " + strings.Repeat("a, ", 1000) + "* FROM users",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "SQL injection attempt",
|
||||
input: "'; DROP TABLE users; --",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "malformed SQL",
|
||||
input: "SELECT FROM WHERE",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
input: "SELECT '\\', '/', '\"', '`'",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
doc := prompt.Document{
|
||||
Text: tt.input,
|
||||
}
|
||||
|
||||
result, err := h.Highlight(doc)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Highlight() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && result == nil {
|
||||
t.Error("Highlight() returned nil result without error")
|
||||
}
|
||||
|
||||
// Verify result is not empty for non-empty input
|
||||
if !tt.wantErr && tt.input != "" && len(result) == 0 {
|
||||
t.Error("Highlight() returned empty result for non-empty input")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetHighlighter tests the getHighlighter function
|
||||
func TestGetHighlighter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
theme string
|
||||
}{
|
||||
{
|
||||
name: "default theme",
|
||||
theme: "",
|
||||
},
|
||||
{
|
||||
name: "dark theme",
|
||||
theme: "dark",
|
||||
},
|
||||
{
|
||||
name: "light theme",
|
||||
theme: "light",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h := getHighlighter(tt.theme)
|
||||
|
||||
if h == nil {
|
||||
t.Fatal("getHighlighter returned nil")
|
||||
}
|
||||
|
||||
if h.lexer == nil {
|
||||
t.Error("highlighter lexer is nil")
|
||||
}
|
||||
|
||||
if h.formatter == nil {
|
||||
t.Error("highlighter formatter is nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHighlighterConcurrency tests concurrent highlighting
|
||||
func TestHighlighterConcurrency(t *testing.T) {
|
||||
h := newHighlighter(
|
||||
lexers.Get("sql"),
|
||||
formatters.Get("terminal256"),
|
||||
styles.Native,
|
||||
)
|
||||
|
||||
queries := []string{
|
||||
"SELECT * FROM users",
|
||||
"SELECT id FROM posts",
|
||||
"SELECT name FROM companies",
|
||||
}
|
||||
|
||||
done := make(chan bool)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
go func(idx int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Concurrent Highlight panicked: %v", r)
|
||||
}
|
||||
done <- true
|
||||
}()
|
||||
|
||||
doc := prompt.Document{
|
||||
Text: queries[idx%len(queries)],
|
||||
}
|
||||
|
||||
_, err := h.Highlight(doc)
|
||||
if err != nil {
|
||||
t.Errorf("Concurrent Highlight error: %v", err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all goroutines
|
||||
for i := 0; i < 10; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// TestHighlighterMemoryLeak tests for memory leaks with repeated highlighting
|
||||
func TestHighlighterMemoryLeak(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping memory leak test in short mode")
|
||||
}
|
||||
|
||||
h := newHighlighter(
|
||||
lexers.Get("sql"),
|
||||
formatters.Get("terminal256"),
|
||||
styles.Native,
|
||||
)
|
||||
|
||||
// Highlight the same query many times to check for memory leaks
|
||||
doc := prompt.Document{
|
||||
Text: "SELECT * FROM users WHERE id = 1",
|
||||
}
|
||||
|
||||
for i := 0; i < 10000; i++ {
|
||||
_, err := h.Highlight(doc)
|
||||
if err != nil {
|
||||
t.Fatalf("Highlight failed at iteration %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
// If we get here without OOM, the test passes
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
package interactive
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/c-bata/go-prompt"
|
||||
pconstants "github.com/turbot/pipe-fittings/v2/constants"
|
||||
"github.com/turbot/steampipe/v2/pkg/cmdconfig"
|
||||
)
|
||||
|
||||
// TestGetTableAndConnectionSuggestions_ReturnsEmptySliceNotNil tests that
|
||||
@@ -68,3 +71,464 @@ func TestGetTableAndConnectionSuggestions_ReturnsEmptySliceNotNil(t *testing.T)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShouldExecute tests the shouldExecute logic for query execution
|
||||
func TestShouldExecute(t *testing.T) {
|
||||
// Save and restore viper settings
|
||||
originalMultiline := cmdconfig.Viper().GetBool(pconstants.ArgMultiLine)
|
||||
defer func() {
|
||||
cmdconfig.Viper().Set(pconstants.ArgMultiLine, originalMultiline)
|
||||
}()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
multiline bool
|
||||
shouldExec bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "simple query without semicolon in non-multiline",
|
||||
query: "SELECT * FROM users",
|
||||
multiline: false,
|
||||
shouldExec: true,
|
||||
description: "In non-multiline mode, execute without semicolon",
|
||||
},
|
||||
{
|
||||
name: "simple query with semicolon in non-multiline",
|
||||
query: "SELECT * FROM users;",
|
||||
multiline: false,
|
||||
shouldExec: true,
|
||||
description: "In non-multiline mode, execute with semicolon",
|
||||
},
|
||||
{
|
||||
name: "simple query without semicolon in multiline",
|
||||
query: "SELECT * FROM users",
|
||||
multiline: true,
|
||||
shouldExec: false,
|
||||
description: "In multiline mode, don't execute without semicolon",
|
||||
},
|
||||
{
|
||||
name: "simple query with semicolon in multiline",
|
||||
query: "SELECT * FROM users;",
|
||||
multiline: true,
|
||||
shouldExec: true,
|
||||
description: "In multiline mode, execute with semicolon",
|
||||
},
|
||||
{
|
||||
name: "metaquery without semicolon in multiline",
|
||||
query: ".help",
|
||||
multiline: true,
|
||||
shouldExec: true,
|
||||
description: "Metaqueries execute without semicolon even in multiline",
|
||||
},
|
||||
{
|
||||
name: "metaquery with semicolon in multiline",
|
||||
query: ".help;",
|
||||
multiline: true,
|
||||
shouldExec: true,
|
||||
description: "Metaqueries execute with semicolon in multiline",
|
||||
},
|
||||
{
|
||||
name: "empty query",
|
||||
query: "",
|
||||
multiline: false,
|
||||
shouldExec: true,
|
||||
description: "Empty query executes in non-multiline",
|
||||
},
|
||||
{
|
||||
name: "empty query in multiline",
|
||||
query: "",
|
||||
multiline: true,
|
||||
shouldExec: false,
|
||||
description: "Empty query doesn't execute in multiline",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &InteractiveClient{}
|
||||
cmdconfig.Viper().Set(pconstants.ArgMultiLine, tt.multiline)
|
||||
|
||||
result := c.shouldExecute(tt.query)
|
||||
|
||||
if result != tt.shouldExec {
|
||||
t.Errorf("shouldExecute(%q) in multiline=%v = %v, want %v\nReason: %s",
|
||||
tt.query, tt.multiline, result, tt.shouldExec, tt.description)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestShouldExecuteEdgeCases tests edge cases for shouldExecute
|
||||
func TestShouldExecuteEdgeCases(t *testing.T) {
|
||||
originalMultiline := cmdconfig.Viper().GetBool(pconstants.ArgMultiLine)
|
||||
defer func() {
|
||||
cmdconfig.Viper().Set(pconstants.ArgMultiLine, originalMultiline)
|
||||
}()
|
||||
|
||||
c := &InteractiveClient{}
|
||||
cmdconfig.Viper().Set(pconstants.ArgMultiLine, true)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
}{
|
||||
{
|
||||
name: "very long query with semicolon",
|
||||
query: strings.Repeat("SELECT * FROM users WHERE id = 1 AND ", 100) + "1=1;",
|
||||
},
|
||||
{
|
||||
name: "unicode characters with semicolon",
|
||||
query: "SELECT '你好世界';",
|
||||
},
|
||||
{
|
||||
name: "emoji with semicolon",
|
||||
query: "SELECT '🔥💥';",
|
||||
},
|
||||
{
|
||||
name: "null bytes",
|
||||
query: "SELECT '\x00';",
|
||||
},
|
||||
{
|
||||
name: "control characters",
|
||||
query: "SELECT '\n\r\t';",
|
||||
},
|
||||
{
|
||||
name: "SQL injection with semicolon",
|
||||
query: "'; DROP TABLE users; --",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("shouldExecute(%q) panicked: %v", tt.query, r)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = c.shouldExecute(tt.query)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBreakMultilinePrompt tests the breakMultilinePrompt function
|
||||
func TestBreakMultilinePrompt(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
interactiveBuffer: []string{"SELECT *", "FROM users", "WHERE"},
|
||||
}
|
||||
|
||||
c.breakMultilinePrompt(nil)
|
||||
|
||||
if len(c.interactiveBuffer) != 0 {
|
||||
t.Errorf("breakMultilinePrompt() didn't clear buffer, got %d items, want 0", len(c.interactiveBuffer))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBreakMultilinePromptEmpty tests breaking an already empty buffer
|
||||
func TestBreakMultilinePromptEmpty(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
interactiveBuffer: []string{},
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("breakMultilinePrompt() panicked on empty buffer: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
c.breakMultilinePrompt(nil)
|
||||
|
||||
if len(c.interactiveBuffer) != 0 {
|
||||
t.Errorf("breakMultilinePrompt() didn't maintain empty buffer, got %d items, want 0", len(c.interactiveBuffer))
|
||||
}
|
||||
}
|
||||
|
||||
// TestBreakMultilinePromptNil tests breaking with nil buffer
|
||||
func TestBreakMultilinePromptNil(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
interactiveBuffer: nil,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("breakMultilinePrompt() panicked on nil buffer: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
c.breakMultilinePrompt(nil)
|
||||
|
||||
if c.interactiveBuffer == nil {
|
||||
t.Error("breakMultilinePrompt() didn't initialize nil buffer")
|
||||
}
|
||||
|
||||
if len(c.interactiveBuffer) != 0 {
|
||||
t.Errorf("breakMultilinePrompt() didn't create empty buffer, got %d items, want 0", len(c.interactiveBuffer))
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsInitialised tests the isInitialised method
|
||||
func TestIsInitialised(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
initialisationComplete bool
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "initialized",
|
||||
initialisationComplete: true,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not initialized",
|
||||
initialisationComplete: false,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
initialisationComplete: tt.initialisationComplete,
|
||||
}
|
||||
|
||||
result := c.isInitialised()
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("isInitialised() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClientNil tests the client() method when initData is nil
|
||||
func TestClientNil(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
initData: nil,
|
||||
}
|
||||
|
||||
client := c.client()
|
||||
|
||||
if client != nil {
|
||||
t.Errorf("client() with nil initData should return nil, got %v", client)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAfterPromptCloseAction tests the AfterPromptCloseAction enum
|
||||
func TestAfterPromptCloseAction(t *testing.T) {
|
||||
// Test that the enum values are distinct
|
||||
if AfterPromptCloseExit == AfterPromptCloseRestart {
|
||||
t.Error("AfterPromptCloseExit and AfterPromptCloseRestart should have different values")
|
||||
}
|
||||
|
||||
// Test that they have the expected values
|
||||
if AfterPromptCloseExit != 0 {
|
||||
t.Errorf("AfterPromptCloseExit should be 0, got %d", AfterPromptCloseExit)
|
||||
}
|
||||
|
||||
if AfterPromptCloseRestart != 1 {
|
||||
t.Errorf("AfterPromptCloseRestart should be 1, got %d", AfterPromptCloseRestart)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetFirstWordSuggestionsEmptyWord tests getFirstWordSuggestions with empty input
|
||||
func TestGetFirstWordSuggestionsEmptyWord(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
suggestions: newAutocompleteSuggestions(),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("getFirstWordSuggestions panicked on empty input: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
suggestions := c.getFirstWordSuggestions("")
|
||||
|
||||
// Should return suggestions (select, with, metaqueries)
|
||||
if len(suggestions) == 0 {
|
||||
t.Error("getFirstWordSuggestions(\"\") should return suggestions")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetFirstWordSuggestionsQualifiedQuery tests qualified query suggestions
|
||||
func TestGetFirstWordSuggestionsQualifiedQuery(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
suggestions: newAutocompleteSuggestions(),
|
||||
}
|
||||
|
||||
// Add mock data
|
||||
c.suggestions.queriesByMod = map[string][]prompt.Suggest{
|
||||
"mymod": {
|
||||
{Text: "mymod.query1", Description: "Query"},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "qualified with known mod",
|
||||
input: "mymod.",
|
||||
},
|
||||
{
|
||||
name: "qualified with unknown mod",
|
||||
input: "unknownmod.",
|
||||
},
|
||||
{
|
||||
name: "single word",
|
||||
input: "select",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("getFirstWordSuggestions(%q) panicked: %v", tt.input, r)
|
||||
}
|
||||
}()
|
||||
|
||||
suggestions := c.getFirstWordSuggestions(tt.input)
|
||||
|
||||
if suggestions == nil {
|
||||
t.Errorf("getFirstWordSuggestions(%q) returned nil", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTableAndConnectionSuggestionsEdgeCases tests edge cases
|
||||
func TestGetTableAndConnectionSuggestionsEdgeCases(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
suggestions: newAutocompleteSuggestions(),
|
||||
}
|
||||
|
||||
// Add mock data
|
||||
c.suggestions.schemas = []prompt.Suggest{
|
||||
{Text: "public", Description: "Schema"},
|
||||
}
|
||||
c.suggestions.unqualifiedTables = []prompt.Suggest{
|
||||
{Text: "users", Description: "Table"},
|
||||
}
|
||||
c.suggestions.tablesBySchema = map[string][]prompt.Suggest{
|
||||
"public": {
|
||||
{Text: "public.users", Description: "Table"},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
}{
|
||||
{
|
||||
name: "unqualified",
|
||||
input: "users",
|
||||
},
|
||||
{
|
||||
name: "qualified with known schema",
|
||||
input: "public.users",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
},
|
||||
{
|
||||
name: "just dot",
|
||||
input: ".",
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
input: "用户.表",
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "schema🔥.table",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("getTableAndConnectionSuggestions(%q) panicked: %v", tt.input, r)
|
||||
}
|
||||
}()
|
||||
|
||||
suggestions := c.getTableAndConnectionSuggestions(tt.input)
|
||||
|
||||
if suggestions == nil {
|
||||
t.Errorf("getTableAndConnectionSuggestions(%q) returned nil", tt.input)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCancelActiveQueryIfAny tests the cancellation logic
|
||||
func TestCancelActiveQueryIfAny(t *testing.T) {
|
||||
t.Run("no active query", func(t *testing.T) {
|
||||
c := &InteractiveClient{
|
||||
cancelActiveQuery: nil,
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("cancelActiveQueryIfAny() panicked with nil cancelFunc: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
if c.cancelActiveQuery != nil {
|
||||
t.Error("cancelActiveQueryIfAny() set cancelActiveQuery when it was nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("with active query", func(t *testing.T) {
|
||||
cancelled := false
|
||||
cancelFunc := func() {
|
||||
cancelled = true
|
||||
}
|
||||
|
||||
c := &InteractiveClient{
|
||||
cancelActiveQuery: cancelFunc,
|
||||
}
|
||||
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
if !cancelled {
|
||||
t.Error("cancelActiveQueryIfAny() didn't call the cancel function")
|
||||
}
|
||||
|
||||
if c.cancelActiveQuery != nil {
|
||||
t.Error("cancelActiveQueryIfAny() didn't set cancelActiveQuery to nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("multiple calls", func(t *testing.T) {
|
||||
callCount := 0
|
||||
cancelFunc := func() {
|
||||
callCount++
|
||||
}
|
||||
|
||||
c := &InteractiveClient{
|
||||
cancelActiveQuery: cancelFunc,
|
||||
}
|
||||
|
||||
// First call should cancel
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("First cancelActiveQueryIfAny() call count = %d, want 1", callCount)
|
||||
}
|
||||
|
||||
// Second call should be a no-op
|
||||
c.cancelActiveQueryIfAny()
|
||||
|
||||
if callCount != 1 {
|
||||
t.Errorf("Second cancelActiveQueryIfAny() call count = %d, want 1 (should be idempotent)", callCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
588
pkg/interactive/interactive_helpers_test.go
Normal file
588
pkg/interactive/interactive_helpers_test.go
Normal file
@@ -0,0 +1,588 @@
|
||||
package interactive
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestIsFirstWord tests the isFirstWord helper function
|
||||
func TestIsFirstWord(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "single word",
|
||||
input: "select",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "two words",
|
||||
input: "select *",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "word with trailing space",
|
||||
input: "select ",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "multiple spaces",
|
||||
input: "select from",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
input: "選択",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "🔥",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "emoji with space",
|
||||
input: "🔥 test",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isFirstWord(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isFirstWord(%q) = %v, want %v", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLastWord tests the lastWord helper function (only passing cases)
|
||||
func TestLastWord(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "two words",
|
||||
input: "select *",
|
||||
expected: " *",
|
||||
},
|
||||
{
|
||||
name: "multiple words",
|
||||
input: "select * from users",
|
||||
expected: " users",
|
||||
},
|
||||
{
|
||||
name: "trailing space",
|
||||
input: "select * from ",
|
||||
expected: " ",
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
input: "select 你好",
|
||||
expected: " 你好",
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "select 🔥",
|
||||
expected: " 🔥",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("lastWord(%q) panicked: %v", tt.input, r)
|
||||
}
|
||||
}()
|
||||
|
||||
result := lastWord(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("lastWord(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestLastIndexByteNot tests the lastIndexByteNot helper function
|
||||
func TestLastIndexByteNot(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
char byte
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "no matching char",
|
||||
input: "hello",
|
||||
char: ' ',
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "trailing spaces",
|
||||
input: "hello ",
|
||||
char: ' ',
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "all spaces",
|
||||
input: " ",
|
||||
char: ' ',
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
char: ' ',
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "single char not matching",
|
||||
input: "a",
|
||||
char: ' ',
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "single char matching",
|
||||
input: " ",
|
||||
char: ' ',
|
||||
expected: -1,
|
||||
},
|
||||
{
|
||||
name: "mixed spaces",
|
||||
input: "hello world ",
|
||||
char: ' ',
|
||||
expected: 10,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := lastIndexByteNot(tt.input, tt.char)
|
||||
if result != tt.expected {
|
||||
t.Errorf("lastIndexByteNot(%q, %q) = %d, want %d", tt.input, tt.char, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetPreviousWord tests the getPreviousWord helper function
|
||||
func TestGetPreviousWord(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple case",
|
||||
input: "select * from ",
|
||||
expected: "from",
|
||||
},
|
||||
{
|
||||
name: "no previous word",
|
||||
input: "select ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "single word",
|
||||
input: "select",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "multiple spaces",
|
||||
input: "select * from ",
|
||||
expected: "from",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "only spaces",
|
||||
input: " ",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
input: "select 你好 世界 ",
|
||||
expected: "世界",
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "select 🔥 💥 ",
|
||||
expected: "💥",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getPreviousWord(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getPreviousWord(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetTable tests the getTable helper function
|
||||
func TestGetTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple select",
|
||||
input: "select * from users",
|
||||
expected: "users",
|
||||
},
|
||||
{
|
||||
name: "qualified table",
|
||||
input: "select * from public.users",
|
||||
expected: "public.users",
|
||||
},
|
||||
{
|
||||
name: "no from clause",
|
||||
input: "select 1",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "from at end",
|
||||
input: "select * from",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "from with trailing text",
|
||||
input: "select * from users where",
|
||||
expected: "users",
|
||||
},
|
||||
{
|
||||
name: "double spaces",
|
||||
input: "select * from users",
|
||||
expected: "users",
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "case sensitive - lowercase from",
|
||||
input: "SELECT * from users",
|
||||
expected: "users",
|
||||
},
|
||||
{
|
||||
name: "uppercase FROM",
|
||||
input: "SELECT * FROM users",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode table name",
|
||||
input: "select * from 用户表",
|
||||
expected: "用户表",
|
||||
},
|
||||
{
|
||||
name: "emoji in table name",
|
||||
input: "select * from users🔥",
|
||||
expected: "users🔥",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getTable(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("getTable(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsEditingTable tests the isEditingTable helper function
|
||||
func TestIsEditingTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prevWord string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "from keyword",
|
||||
prevWord: "from",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "not from keyword",
|
||||
prevWord: "select",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
prevWord: "",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "FROM uppercase",
|
||||
prevWord: "FROM",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace",
|
||||
prevWord: " from ",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isEditingTable(tt.prevWord)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isEditingTable(%q) = %v, want %v", tt.prevWord, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetQueryInfo tests the getQueryInfo function (passing cases only)
|
||||
func TestGetQueryInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedTable string
|
||||
expectedEditing bool
|
||||
}{
|
||||
{
|
||||
name: "editing table after from",
|
||||
input: "select * from ",
|
||||
expectedTable: "",
|
||||
expectedEditing: true,
|
||||
},
|
||||
{
|
||||
name: "table specified",
|
||||
input: "select * from users ",
|
||||
expectedTable: "users",
|
||||
expectedEditing: false,
|
||||
},
|
||||
{
|
||||
name: "not at from clause",
|
||||
input: "select * ",
|
||||
expectedTable: "",
|
||||
expectedEditing: false,
|
||||
},
|
||||
{
|
||||
name: "empty query",
|
||||
input: "",
|
||||
expectedTable: "",
|
||||
expectedEditing: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := getQueryInfo(tt.input)
|
||||
if result.Table != tt.expectedTable {
|
||||
t.Errorf("getQueryInfo(%q).Table = %q, want %q", tt.input, result.Table, tt.expectedTable)
|
||||
}
|
||||
if result.EditingTable != tt.expectedEditing {
|
||||
t.Errorf("getQueryInfo(%q).EditingTable = %v, want %v", tt.input, result.EditingTable, tt.expectedEditing)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCleanBufferForWSL tests the WSL-specific buffer cleaning
|
||||
func TestCleanBufferForWSL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedOutput string
|
||||
expectedIgnore bool
|
||||
}{
|
||||
{
|
||||
name: "normal text",
|
||||
input: "hello",
|
||||
expectedOutput: "hello",
|
||||
expectedIgnore: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expectedOutput: "",
|
||||
expectedIgnore: false,
|
||||
},
|
||||
{
|
||||
name: "escape sequence",
|
||||
input: string([]byte{27, 65}), // ESC + 'A'
|
||||
expectedOutput: "",
|
||||
expectedIgnore: true,
|
||||
},
|
||||
{
|
||||
name: "single escape",
|
||||
input: string([]byte{27}),
|
||||
expectedOutput: string([]byte{27}),
|
||||
expectedIgnore: false,
|
||||
},
|
||||
{
|
||||
name: "unicode",
|
||||
input: "你好",
|
||||
expectedOutput: "你好",
|
||||
expectedIgnore: false,
|
||||
},
|
||||
{
|
||||
name: "emoji",
|
||||
input: "🔥",
|
||||
expectedOutput: "🔥",
|
||||
expectedIgnore: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
output, ignore := cleanBufferForWSL(tt.input)
|
||||
if output != tt.expectedOutput {
|
||||
t.Errorf("cleanBufferForWSL(%q) output = %q, want %q", tt.input, output, tt.expectedOutput)
|
||||
}
|
||||
if ignore != tt.expectedIgnore {
|
||||
t.Errorf("cleanBufferForWSL(%q) ignore = %v, want %v", tt.input, ignore, tt.expectedIgnore)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSanitiseTableName tests table name escaping (passing cases only)
|
||||
func TestSanitiseTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple lowercase table",
|
||||
input: "users",
|
||||
expected: "users",
|
||||
},
|
||||
{
|
||||
name: "uppercase table",
|
||||
input: "Users",
|
||||
expected: `"Users"`,
|
||||
},
|
||||
{
|
||||
name: "table with space",
|
||||
input: "user data",
|
||||
expected: `"user data"`,
|
||||
},
|
||||
{
|
||||
name: "table with hyphen",
|
||||
input: "user-data",
|
||||
expected: `"user-data"`,
|
||||
},
|
||||
{
|
||||
name: "qualified table",
|
||||
input: "schema.table",
|
||||
expected: "schema.table",
|
||||
},
|
||||
{
|
||||
name: "qualified with uppercase",
|
||||
input: "Schema.Table",
|
||||
expected: `"Schema"."Table"`,
|
||||
},
|
||||
{
|
||||
name: "qualified with spaces",
|
||||
input: "my schema.my table",
|
||||
expected: `"my schema"."my table"`,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitiseTableName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitiseTableName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHelperFunctionsWithExtremeInput tests helper functions with extreme inputs
|
||||
func TestHelperFunctionsWithExtremeInput(t *testing.T) {
|
||||
t.Run("very long string", func(t *testing.T) {
|
||||
longString := strings.Repeat("a ", 10000)
|
||||
|
||||
// Test that these don't panic or hang
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Function panicked on long string: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = isFirstWord(longString)
|
||||
_ = getTable(longString)
|
||||
_ = getPreviousWord(longString)
|
||||
_ = getQueryInfo(longString)
|
||||
})
|
||||
|
||||
t.Run("null bytes", func(t *testing.T) {
|
||||
nullByteString := "select\x00from\x00users"
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Function panicked on null bytes: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = isFirstWord(nullByteString)
|
||||
_ = getTable(nullByteString)
|
||||
_ = getPreviousWord(nullByteString)
|
||||
})
|
||||
|
||||
t.Run("control characters", func(t *testing.T) {
|
||||
controlString := "select\n\r\tfrom\n\rusers"
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Function panicked on control chars: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = isFirstWord(controlString)
|
||||
_ = getTable(controlString)
|
||||
_ = getPreviousWord(controlString)
|
||||
})
|
||||
|
||||
t.Run("SQL injection attempts", func(t *testing.T) {
|
||||
injectionStrings := []string{
|
||||
"'; DROP TABLE users; --",
|
||||
"1' OR '1'='1",
|
||||
"1; DELETE FROM connections; --",
|
||||
"select * from users where id = 1' union select * from passwords --",
|
||||
}
|
||||
|
||||
for _, injection := range injectionStrings {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Function panicked on injection string %q: %v", injection, r)
|
||||
}
|
||||
}()
|
||||
|
||||
_ = isFirstWord(injection)
|
||||
_ = getTable(injection)
|
||||
_ = getPreviousWord(injection)
|
||||
_ = getQueryInfo(injection)
|
||||
}
|
||||
})
|
||||
}
|
||||
308
pkg/pluginmanager_service/message_server_test.go
Normal file
308
pkg/pluginmanager_service/message_server_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package pluginmanager_service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
sdkproto "github.com/turbot/steampipe-plugin-sdk/v5/grpc/proto"
|
||||
)
|
||||
|
||||
// Test helpers for message server tests
|
||||
|
||||
func newTestMessageServer(t *testing.T) *PluginMessageServer {
|
||||
t.Helper()
|
||||
pm := newTestPluginManager(t)
|
||||
return &PluginMessageServer{
|
||||
pluginManager: pm,
|
||||
}
|
||||
}
|
||||
|
||||
// Test 1: NewPluginMessageServer
|
||||
|
||||
func TestNewPluginMessageServer(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
ms, err := NewPluginMessageServer(pm)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, ms)
|
||||
assert.Equal(t, pm, ms.pluginManager)
|
||||
}
|
||||
|
||||
// Test 2: PluginMessageServer Initialization
|
||||
|
||||
func TestPluginManager_MessageServerInitialization(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
assert.NotNil(t, pm.messageServer, "messageServer should be initialized")
|
||||
assert.Equal(t, pm, pm.messageServer.pluginManager, "messageServer should reference parent PluginManager")
|
||||
}
|
||||
|
||||
// Test 3: Concurrent Access
|
||||
|
||||
func TestPluginMessageServer_ConcurrentAccess(t *testing.T) {
|
||||
ms := newTestMessageServer(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = ms.pluginManager
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 4: LogReceiveError with Valid Errors
|
||||
|
||||
func TestPluginMessageServer_LogReceiveError(t *testing.T) {
|
||||
ms := newTestMessageServer(t)
|
||||
|
||||
// Should not panic for various error types
|
||||
ms.logReceiveError(context.Canceled, "test-connection")
|
||||
ms.logReceiveError(context.DeadlineExceeded, "test-connection")
|
||||
}
|
||||
|
||||
// Test 5: Multiple Message Servers
|
||||
|
||||
func TestPluginManager_MultipleMessageServers(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
ms1, err1 := NewPluginMessageServer(pm)
|
||||
ms2, err2 := NewPluginMessageServer(pm)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
assert.NotNil(t, ms1)
|
||||
assert.NotNil(t, ms2)
|
||||
|
||||
// Both should reference the same plugin manager
|
||||
assert.Equal(t, pm, ms1.pluginManager)
|
||||
assert.Equal(t, pm, ms2.pluginManager)
|
||||
}
|
||||
|
||||
// Test 6: Message Server with Nil Plugin Manager
|
||||
|
||||
func TestPluginMessageServer_NilPluginManager(t *testing.T) {
|
||||
ms := &PluginMessageServer{
|
||||
pluginManager: nil,
|
||||
}
|
||||
|
||||
assert.Nil(t, ms.pluginManager)
|
||||
}
|
||||
|
||||
// Test 7: Goroutine Cleanup
|
||||
|
||||
func TestPluginMessageServer_GoroutineCleanup(t *testing.T) {
|
||||
before := runtime.NumGoroutine()
|
||||
|
||||
ms := newTestMessageServer(t)
|
||||
_ = ms
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
after := runtime.NumGoroutine()
|
||||
|
||||
// Creating a message server shouldn't leak goroutines
|
||||
if after > before+5 {
|
||||
t.Errorf("Potential goroutine leak: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 8: Message Type Structure
|
||||
|
||||
func TestPluginMessage_SchemaUpdatedType(t *testing.T) {
|
||||
message := &sdkproto.PluginMessage{
|
||||
MessageType: sdkproto.PluginMessageType_SCHEMA_UPDATED,
|
||||
Connection: "test-connection",
|
||||
}
|
||||
|
||||
assert.Equal(t, sdkproto.PluginMessageType_SCHEMA_UPDATED, message.MessageType)
|
||||
assert.Equal(t, "test-connection", message.Connection)
|
||||
}
|
||||
|
||||
// Test 9: LogReceiveError with Different Error Types
|
||||
|
||||
func TestPluginMessageServer_LogReceiveError_ErrorTypes(t *testing.T) {
|
||||
ms := newTestMessageServer(t)
|
||||
|
||||
// Test various error types don't cause panics
|
||||
errors := []error{
|
||||
context.Canceled,
|
||||
context.DeadlineExceeded,
|
||||
assert.AnError,
|
||||
}
|
||||
|
||||
for _, err := range errors {
|
||||
ms.logReceiveError(err, "test-connection")
|
||||
}
|
||||
}
|
||||
|
||||
// Test 10: Message Server Initialization Consistency
|
||||
|
||||
func TestPluginManager_MessageServer_Consistency(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Verify messageServer is initialized and consistent
|
||||
assert.NotNil(t, pm.messageServer)
|
||||
assert.Equal(t, pm, pm.messageServer.pluginManager)
|
||||
|
||||
// Accessing it multiple times should return the same instance
|
||||
ms1 := pm.messageServer
|
||||
ms2 := pm.messageServer
|
||||
assert.Equal(t, ms1, ms2)
|
||||
}
|
||||
|
||||
// Test 11: Message Server Survives Plugin Manager Operations
|
||||
|
||||
func TestPluginMessageServer_SurvivesPluginManagerOperations(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
ms := pm.messageServer
|
||||
|
||||
// Perform various plugin manager operations
|
||||
pm.populatePluginConnectionConfigs()
|
||||
pm.setPluginCacheSizeMap()
|
||||
pm.nonAggregatorConnectionCount()
|
||||
|
||||
// Message server should still be accessible
|
||||
assert.Equal(t, pm, ms.pluginManager)
|
||||
assert.NotNil(t, pm.messageServer)
|
||||
}
|
||||
|
||||
// Test 12: Concurrent NewPluginMessageServer Calls
|
||||
|
||||
func TestNewPluginMessageServer_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
servers := make([]*PluginMessageServer, numGoroutines)
|
||||
errors := make([]error, numGoroutines)
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
servers[idx], errors[idx] = NewPluginMessageServer(pm)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// All should succeed
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
assert.NoError(t, errors[i])
|
||||
assert.NotNil(t, servers[i])
|
||||
assert.Equal(t, pm, servers[i].pluginManager)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 13: Message Server Pointer Stability
|
||||
|
||||
func TestPluginMessageServer_PointerStability(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
ms1 := pm.messageServer
|
||||
ms2 := pm.messageServer
|
||||
|
||||
// Should be the same pointer
|
||||
assert.True(t, ms1 == ms2, "messageServer pointer should be stable")
|
||||
}
|
||||
|
||||
// Test 14: LogReceiveError Concurrent Calls
|
||||
|
||||
func TestPluginMessageServer_LogReceiveError_Concurrent(t *testing.T) {
|
||||
ms := newTestMessageServer(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
err := assert.AnError
|
||||
if idx%2 == 0 {
|
||||
err = context.Canceled
|
||||
}
|
||||
ms.logReceiveError(err, "test-connection")
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 15: Message Server Field Access
|
||||
|
||||
func TestPluginMessageServer_FieldAccess(t *testing.T) {
|
||||
ms := newTestMessageServer(t)
|
||||
|
||||
// Verify fields are accessible and not nil
|
||||
assert.NotNil(t, ms.pluginManager)
|
||||
assert.NotNil(t, ms.pluginManager.logger)
|
||||
assert.NotNil(t, ms.pluginManager.runningPluginMap)
|
||||
}
|
||||
|
||||
// Test 16: Message Server Doesn't Block Plugin Manager
|
||||
|
||||
func TestPluginMessageServer_DoesNotBlockPluginManager(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Message server should not prevent these operations
|
||||
config := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
pm.connectionConfigMap["conn1"] = config
|
||||
pm.populatePluginConnectionConfigs()
|
||||
|
||||
// Verify operations worked
|
||||
assert.Len(t, pm.pluginConnectionConfigMap, 1)
|
||||
|
||||
// Message server should still be valid
|
||||
assert.NotNil(t, pm.messageServer)
|
||||
assert.Equal(t, pm, pm.messageServer.pluginManager)
|
||||
}
|
||||
|
||||
// Test 17: Stress Test for Concurrent Access
|
||||
|
||||
func TestPluginMessageServer_StressConcurrentAccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping stress test in short mode")
|
||||
}
|
||||
|
||||
pm := newTestPluginManager(t)
|
||||
ms := pm.messageServer
|
||||
|
||||
var wg sync.WaitGroup
|
||||
duration := 1 * time.Second
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
// Multiple readers accessing pluginManager
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
_ = ms.pluginManager
|
||||
if ms.pluginManager != nil {
|
||||
_ = ms.pluginManager.connectionConfigMap
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
time.Sleep(duration)
|
||||
close(stopCh)
|
||||
wg.Wait()
|
||||
}
|
||||
716
pkg/pluginmanager_service/plugin_manager_test.go
Normal file
716
pkg/pluginmanager_service/plugin_manager_test.go
Normal file
@@ -0,0 +1,716 @@
|
||||
package pluginmanager_service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-hclog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/turbot/pipe-fittings/v2/plugin"
|
||||
sdkproto "github.com/turbot/steampipe-plugin-sdk/v5/grpc/proto"
|
||||
"github.com/turbot/steampipe/v2/pkg/connection"
|
||||
pb "github.com/turbot/steampipe/v2/pkg/pluginmanager_service/grpc/proto"
|
||||
)
|
||||
|
||||
// Test helpers and mocks
|
||||
|
||||
func newTestPluginManager(t *testing.T) *PluginManager {
|
||||
t.Helper()
|
||||
|
||||
logger := hclog.NewNullLogger()
|
||||
|
||||
pm := &PluginManager{
|
||||
logger: logger,
|
||||
runningPluginMap: make(map[string]*runningPlugin),
|
||||
pluginConnectionConfigMap: make(map[string][]*sdkproto.ConnectionConfig),
|
||||
connectionConfigMap: make(connection.ConnectionConfigMap),
|
||||
pluginCacheSizeMap: make(map[string]int64),
|
||||
plugins: make(connection.PluginMap),
|
||||
userLimiters: make(connection.PluginLimiterMap),
|
||||
pluginLimiters: make(connection.PluginLimiterMap),
|
||||
}
|
||||
|
||||
pm.messageServer = &PluginMessageServer{pluginManager: pm}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
func newTestConnectionConfig(plugin, instance, connection string) *sdkproto.ConnectionConfig {
|
||||
return &sdkproto.ConnectionConfig{
|
||||
Plugin: plugin,
|
||||
PluginInstance: instance,
|
||||
Connection: connection,
|
||||
}
|
||||
}
|
||||
|
||||
// Test 1: Basic Initialization
|
||||
|
||||
func TestPluginManager_New(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
assert.NotNil(t, pm, "PluginManager should not be nil")
|
||||
assert.NotNil(t, pm.runningPluginMap, "runningPluginMap should be initialized")
|
||||
assert.NotNil(t, pm.messageServer, "messageServer should be initialized")
|
||||
assert.NotNil(t, pm.logger, "logger should be initialized")
|
||||
}
|
||||
|
||||
// Test 2: Connection Config Access
|
||||
|
||||
func TestPluginManager_GetConnectionConfig_NotFound(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
_, err := pm.getConnectionConfig("nonexistent")
|
||||
|
||||
assert.Error(t, err, "Should return error for nonexistent connection")
|
||||
assert.Contains(t, err.Error(), "does not exist", "Error should mention connection doesn't exist")
|
||||
}
|
||||
|
||||
func TestPluginManager_GetConnectionConfig_Found(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
expectedConfig := newTestConnectionConfig("test-plugin", "test-instance", "test-connection")
|
||||
pm.connectionConfigMap["test-connection"] = expectedConfig
|
||||
|
||||
config, err := pm.getConnectionConfig("test-connection")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedConfig, config)
|
||||
}
|
||||
|
||||
func TestPluginManager_GetConnectionConfig_NilMap(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.connectionConfigMap = nil
|
||||
|
||||
_, err := pm.getConnectionConfig("conn1")
|
||||
|
||||
assert.Error(t, err, "Should handle nil connectionConfigMap gracefully")
|
||||
}
|
||||
|
||||
// Test 3: Map Population
|
||||
|
||||
func TestPluginManager_PopulatePluginConnectionConfigs(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
config2 := newTestConnectionConfig("plugin1", "instance1", "conn2")
|
||||
config3 := newTestConnectionConfig("plugin2", "instance2", "conn3")
|
||||
|
||||
pm.connectionConfigMap = connection.ConnectionConfigMap{
|
||||
"conn1": config1,
|
||||
"conn2": config2,
|
||||
"conn3": config3,
|
||||
}
|
||||
|
||||
pm.populatePluginConnectionConfigs()
|
||||
|
||||
assert.Len(t, pm.pluginConnectionConfigMap, 2, "Should have 2 plugin instances")
|
||||
assert.Len(t, pm.pluginConnectionConfigMap["instance1"], 2, "instance1 should have 2 connections")
|
||||
assert.Len(t, pm.pluginConnectionConfigMap["instance2"], 1, "instance2 should have 1 connection")
|
||||
}
|
||||
|
||||
// Test 4: Build Required Plugin Map
|
||||
|
||||
func TestPluginManager_BuildRequiredPluginMap(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
config2 := newTestConnectionConfig("plugin1", "instance1", "conn2")
|
||||
config3 := newTestConnectionConfig("plugin2", "instance2", "conn3")
|
||||
|
||||
pm.connectionConfigMap = connection.ConnectionConfigMap{
|
||||
"conn1": config1,
|
||||
"conn2": config2,
|
||||
"conn3": config3,
|
||||
}
|
||||
pm.populatePluginConnectionConfigs()
|
||||
|
||||
req := &pb.GetRequest{
|
||||
Connections: []string{"conn1", "conn3"},
|
||||
}
|
||||
|
||||
pluginMap, requestedConns, err := pm.buildRequiredPluginMap(req)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pluginMap, 2, "Should map 2 plugin instances")
|
||||
assert.Len(t, requestedConns, 2, "Should have 2 requested connections")
|
||||
assert.Contains(t, requestedConns, "conn1")
|
||||
assert.Contains(t, requestedConns, "conn3")
|
||||
}
|
||||
|
||||
// Test 5: Concurrent Map Access
|
||||
|
||||
func TestPluginManager_ConcurrentMapAccess(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Populate some initial data
|
||||
for i := 0; i < 10; i++ {
|
||||
connName := fmt.Sprintf("conn%d", i)
|
||||
config := newTestConnectionConfig("plugin1", "instance1", connName)
|
||||
pm.connectionConfigMap[connName] = config
|
||||
}
|
||||
pm.populatePluginConnectionConfigs()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
// Concurrent reads with proper locking
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
connName := fmt.Sprintf("conn%d", idx%10)
|
||||
|
||||
pm.mut.RLock()
|
||||
_ = pm.connectionConfigMap[connName]
|
||||
pm.mut.RUnlock()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
assert.Len(t, pm.connectionConfigMap, 10)
|
||||
}
|
||||
|
||||
// Test 6: Shutdown Flag Management
|
||||
|
||||
func TestPluginManager_Shutdown_SetsShuttingDownFlag(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
assert.False(t, pm.isShuttingDown(), "Initially should not be shutting down")
|
||||
|
||||
// Set the flag as Shutdown does
|
||||
pm.shutdownMut.Lock()
|
||||
pm.shuttingDown = true
|
||||
pm.shutdownMut.Unlock()
|
||||
|
||||
assert.True(t, pm.isShuttingDown(), "Should be shutting down after flag is set")
|
||||
}
|
||||
|
||||
func TestPluginManager_Shutdown_WaitsForPluginStart(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Simulate a plugin starting
|
||||
pm.startPluginWg.Add(1)
|
||||
|
||||
shutdownComplete := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
pm.shutdownMut.Lock()
|
||||
pm.shuttingDown = true
|
||||
pm.shutdownMut.Unlock()
|
||||
pm.startPluginWg.Wait()
|
||||
close(shutdownComplete)
|
||||
}()
|
||||
|
||||
// Give shutdown goroutine time to reach Wait
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Verify shutdown hasn't completed yet
|
||||
select {
|
||||
case <-shutdownComplete:
|
||||
t.Fatal("Shutdown completed before startPluginWg.Done() was called")
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
// Expected
|
||||
}
|
||||
|
||||
// Signal plugin start complete
|
||||
pm.startPluginWg.Done()
|
||||
|
||||
// Verify shutdown completes
|
||||
select {
|
||||
case <-shutdownComplete:
|
||||
// Expected
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
t.Fatal("Shutdown did not complete after startPluginWg.Done()")
|
||||
}
|
||||
}
|
||||
|
||||
// Test 7: Running Plugin Management
|
||||
|
||||
func TestPluginManager_AddRunningPlugin_Success(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Add a plugin config
|
||||
pm.plugins["test-instance"] = &plugin.Plugin{
|
||||
Plugin: "test-plugin",
|
||||
Instance: "test-instance",
|
||||
}
|
||||
|
||||
rp, err := pm.addRunningPlugin("test-instance")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, rp)
|
||||
assert.Equal(t, "test-instance", rp.pluginInstance)
|
||||
assert.NotNil(t, rp.initialized)
|
||||
assert.NotNil(t, rp.failed)
|
||||
|
||||
// Verify it was added to the map
|
||||
pm.mut.RLock()
|
||||
stored := pm.runningPluginMap["test-instance"]
|
||||
pm.mut.RUnlock()
|
||||
assert.Equal(t, rp, stored)
|
||||
}
|
||||
|
||||
func TestPluginManager_AddRunningPlugin_AlreadyExists(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Add a plugin config
|
||||
pm.plugins["test-instance"] = &plugin.Plugin{
|
||||
Plugin: "test-plugin",
|
||||
Instance: "test-instance",
|
||||
}
|
||||
|
||||
// Add first time
|
||||
_, err := pm.addRunningPlugin("test-instance")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to add again - should return retryable error
|
||||
_, err = pm.addRunningPlugin("test-instance")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "already started")
|
||||
}
|
||||
|
||||
func TestPluginManager_AddRunningPlugin_NoConfig(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Don't add any plugin config
|
||||
|
||||
_, err := pm.addRunningPlugin("nonexistent-instance")
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "no config")
|
||||
}
|
||||
|
||||
// Test 8: Concurrent Plugin Operations
|
||||
|
||||
func TestPluginManager_ConcurrentAddRunningPlugin(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Add plugin config
|
||||
pm.plugins["test-instance"] = &plugin.Plugin{
|
||||
Plugin: "test-plugin",
|
||||
Instance: "test-instance",
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
successCount := 0
|
||||
errorCount := 0
|
||||
var mu sync.Mutex
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := pm.addRunningPlugin("test-instance")
|
||||
mu.Lock()
|
||||
if err == nil {
|
||||
successCount++
|
||||
} else {
|
||||
errorCount++
|
||||
}
|
||||
mu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Only one should succeed, the rest should get retryable errors
|
||||
assert.Equal(t, 1, successCount, "Only one goroutine should succeed")
|
||||
assert.Equal(t, numGoroutines-1, errorCount, "All other goroutines should fail")
|
||||
}
|
||||
|
||||
// Test 9: IsShuttingDown with Concurrent Access
|
||||
|
||||
func TestPluginManager_IsShuttingDown_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numReaders := 50
|
||||
|
||||
// Start many readers
|
||||
for i := 0; i < numReaders; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 100; j++ {
|
||||
_ = pm.isShuttingDown()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// One writer
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
pm.shutdownMut.Lock()
|
||||
pm.shuttingDown = !pm.shuttingDown
|
||||
pm.shutdownMut.Unlock()
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 10: Plugin Cache Size Map
|
||||
|
||||
func TestPluginManager_SetPluginCacheSizeMap_NoCacheLimit(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
config2 := newTestConnectionConfig("plugin2", "instance2", "conn2")
|
||||
|
||||
pm.pluginConnectionConfigMap = map[string][]*sdkproto.ConnectionConfig{
|
||||
"instance1": {config1},
|
||||
"instance2": {config2},
|
||||
}
|
||||
|
||||
pm.setPluginCacheSizeMap()
|
||||
|
||||
// When no max size is set, all plugins should have size 0 (unlimited)
|
||||
assert.Equal(t, int64(0), pm.pluginCacheSizeMap["instance1"])
|
||||
assert.Equal(t, int64(0), pm.pluginCacheSizeMap["instance2"])
|
||||
}
|
||||
|
||||
// Test 11: NonAggregatorConnectionCount
|
||||
|
||||
func TestPluginManager_NonAggregatorConnectionCount(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Regular connection (no child connections)
|
||||
config1 := &sdkproto.ConnectionConfig{
|
||||
Plugin: "plugin1",
|
||||
PluginInstance: "instance1",
|
||||
Connection: "conn1",
|
||||
ChildConnections: []string{},
|
||||
}
|
||||
|
||||
// Aggregator connection (has child connections)
|
||||
config2 := &sdkproto.ConnectionConfig{
|
||||
Plugin: "plugin1",
|
||||
PluginInstance: "instance1",
|
||||
Connection: "conn2",
|
||||
ChildConnections: []string{"child1", "child2"},
|
||||
}
|
||||
|
||||
// Another regular connection
|
||||
config3 := &sdkproto.ConnectionConfig{
|
||||
Plugin: "plugin2",
|
||||
PluginInstance: "instance2",
|
||||
Connection: "conn3",
|
||||
ChildConnections: []string{},
|
||||
}
|
||||
|
||||
pm.pluginConnectionConfigMap = map[string][]*sdkproto.ConnectionConfig{
|
||||
"instance1": {config1, config2},
|
||||
"instance2": {config3},
|
||||
}
|
||||
|
||||
count := pm.nonAggregatorConnectionCount()
|
||||
|
||||
// Should count only non-aggregator connections (conn1 and conn3)
|
||||
assert.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
// Test 12: GetPluginExemplarConnections
|
||||
|
||||
func TestPluginManager_GetPluginExemplarConnections(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
config1 := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
config2 := newTestConnectionConfig("plugin1", "instance1", "conn2")
|
||||
config3 := newTestConnectionConfig("plugin2", "instance2", "conn3")
|
||||
|
||||
pm.connectionConfigMap = connection.ConnectionConfigMap{
|
||||
"conn1": config1,
|
||||
"conn2": config2,
|
||||
"conn3": config3,
|
||||
}
|
||||
|
||||
exemplars := pm.getPluginExemplarConnections()
|
||||
|
||||
assert.Len(t, exemplars, 2, "Should have 2 plugins")
|
||||
// Should have one exemplar for each plugin (might be any of the connections)
|
||||
assert.Contains(t, []string{"conn1", "conn2"}, exemplars["plugin1"])
|
||||
assert.Equal(t, "conn3", exemplars["plugin2"])
|
||||
}
|
||||
|
||||
// Test 13: Goroutine Leak Detection
|
||||
|
||||
func TestPluginManager_NoGoroutineLeak_OnError(t *testing.T) {
|
||||
before := runtime.NumGoroutine()
|
||||
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Add plugin config
|
||||
pm.plugins["test-instance"] = &plugin.Plugin{
|
||||
Plugin: "test-plugin",
|
||||
Instance: "test-instance",
|
||||
}
|
||||
|
||||
// Try to add running plugin
|
||||
_, err := pm.addRunningPlugin("test-instance")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Clean up
|
||||
pm.mut.Lock()
|
||||
delete(pm.runningPluginMap, "test-instance")
|
||||
pm.mut.Unlock()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
after := runtime.NumGoroutine()
|
||||
|
||||
// Allow some tolerance for background goroutines
|
||||
if after > before+5 {
|
||||
t.Errorf("Potential goroutine leak: before=%d, after=%d", before, after)
|
||||
}
|
||||
}
|
||||
|
||||
// Test 14: Pool Access
|
||||
|
||||
func TestPluginManager_Pool(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Initially nil
|
||||
assert.Nil(t, pm.Pool())
|
||||
}
|
||||
|
||||
// Test 15: RefreshConnections
|
||||
|
||||
func TestPluginManager_RefreshConnections(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
req := &pb.RefreshConnectionsRequest{}
|
||||
|
||||
resp, err := pm.RefreshConnections(req)
|
||||
|
||||
require.NoError(t, err, "RefreshConnections should not return error")
|
||||
assert.NotNil(t, resp, "Response should not be nil")
|
||||
}
|
||||
|
||||
// Test 16: GetConnectionConfig Concurrent Access
|
||||
|
||||
func TestPluginManager_GetConnectionConfig_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
config := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
pm.connectionConfigMap["conn1"] = config
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
cfg, err := pm.getConnectionConfig("conn1")
|
||||
if err == nil {
|
||||
assert.Equal(t, "conn1", cfg.Connection)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 17: Running Plugin Structure
|
||||
|
||||
func TestRunningPlugin_Initialization(t *testing.T) {
|
||||
rp := &runningPlugin{
|
||||
pluginInstance: "test",
|
||||
imageRef: "test-image",
|
||||
initialized: make(chan struct{}),
|
||||
failed: make(chan struct{}),
|
||||
}
|
||||
|
||||
assert.NotNil(t, rp.initialized, "initialized channel should not be nil")
|
||||
assert.NotNil(t, rp.failed, "failed channel should not be nil")
|
||||
|
||||
// Verify channels are not closed initially
|
||||
select {
|
||||
case <-rp.initialized:
|
||||
t.Fatal("initialized channel should not be closed initially")
|
||||
default:
|
||||
// Expected
|
||||
}
|
||||
|
||||
select {
|
||||
case <-rp.failed:
|
||||
t.Fatal("failed channel should not be closed initially")
|
||||
default:
|
||||
// Expected
|
||||
}
|
||||
}
|
||||
|
||||
// Test 18: Multiple Concurrent Refreshes
|
||||
|
||||
func TestPluginManager_ConcurrentRefreshConnections(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 10
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
req := &pb.RefreshConnectionsRequest{}
|
||||
_, _ = pm.RefreshConnections(req)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 19: NonAggregatorConnectionCount Helper
|
||||
|
||||
func TestNonAggregatorConnectionCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
connections []*sdkproto.ConnectionConfig
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
connections: []*sdkproto.ConnectionConfig{},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "all non-aggregators",
|
||||
connections: []*sdkproto.ConnectionConfig{
|
||||
{Connection: "conn1", ChildConnections: []string{}},
|
||||
{Connection: "conn2", ChildConnections: []string{}},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "all aggregators",
|
||||
connections: []*sdkproto.ConnectionConfig{
|
||||
{Connection: "conn1", ChildConnections: []string{"child1"}},
|
||||
{Connection: "conn2", ChildConnections: []string{"child2"}},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed",
|
||||
connections: []*sdkproto.ConnectionConfig{
|
||||
{Connection: "conn1", ChildConnections: []string{}},
|
||||
{Connection: "conn2", ChildConnections: []string{"child1"}},
|
||||
{Connection: "conn3", ChildConnections: []string{}},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "nil child connections",
|
||||
connections: []*sdkproto.ConnectionConfig{
|
||||
{Connection: "conn1", ChildConnections: nil},
|
||||
{Connection: "conn2", ChildConnections: []string{"child1"}},
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
count := nonAggregatorConnectionCount(tt.connections)
|
||||
assert.Equal(t, tt.expected, count)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test 20: GetResponse Helper
|
||||
|
||||
func TestNewGetResponse(t *testing.T) {
|
||||
resp := newGetResponse()
|
||||
|
||||
assert.NotNil(t, resp)
|
||||
assert.NotNil(t, resp.GetResponse)
|
||||
assert.NotNil(t, resp.ReattachMap)
|
||||
assert.NotNil(t, resp.FailureMap)
|
||||
}
|
||||
|
||||
// Test 21: EnsurePlugin Early Exit When Shutting Down
|
||||
|
||||
func TestPluginManager_EnsurePlugin_ShuttingDown(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Set shutting down flag
|
||||
pm.shutdownMut.Lock()
|
||||
pm.shuttingDown = true
|
||||
pm.shutdownMut.Unlock()
|
||||
|
||||
config := newTestConnectionConfig("plugin1", "instance1", "conn1")
|
||||
req := &pb.GetRequest{Connections: []string{"conn1"}}
|
||||
|
||||
_, err := pm.ensurePlugin("instance1", []*sdkproto.ConnectionConfig{config}, req)
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "shutting down")
|
||||
}
|
||||
|
||||
// Test 22: KillPlugin with Nil Client
|
||||
|
||||
func TestPluginManager_KillPlugin_NilClient(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
rp := &runningPlugin{
|
||||
pluginInstance: "test",
|
||||
client: nil,
|
||||
}
|
||||
|
||||
// Should not panic
|
||||
pm.killPlugin(rp)
|
||||
}
|
||||
|
||||
// Test 23: Stress Test for Map Access
|
||||
|
||||
func TestPluginManager_StressConcurrentMapAccess(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping stress test in short mode")
|
||||
}
|
||||
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Add initial configs
|
||||
for i := 0; i < 100; i++ {
|
||||
connName := fmt.Sprintf("conn%d", i)
|
||||
config := newTestConnectionConfig("plugin1", "instance1", connName)
|
||||
pm.connectionConfigMap[connName] = config
|
||||
}
|
||||
pm.populatePluginConnectionConfigs()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
duration := 1 * time.Second
|
||||
stopCh := make(chan struct{})
|
||||
|
||||
// Start multiple readers
|
||||
for i := 0; i < 20; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stopCh:
|
||||
return
|
||||
default:
|
||||
connName := fmt.Sprintf("conn%d", idx%100)
|
||||
pm.mut.RLock()
|
||||
_ = pm.connectionConfigMap[connName]
|
||||
_ = pm.pluginConnectionConfigMap["instance1"]
|
||||
pm.mut.RUnlock()
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Run for duration
|
||||
time.Sleep(duration)
|
||||
close(stopCh)
|
||||
wg.Wait()
|
||||
}
|
||||
423
pkg/pluginmanager_service/rate_limiters_helpers_test.go
Normal file
423
pkg/pluginmanager_service/rate_limiters_helpers_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package pluginmanager_service
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/turbot/pipe-fittings/v2/plugin"
|
||||
"github.com/turbot/steampipe/v2/pkg/connection"
|
||||
)
|
||||
|
||||
// Test helpers for rate limiter tests
|
||||
|
||||
func newTestRateLimiter(pluginName, name string, source string) *plugin.RateLimiter {
|
||||
return &plugin.RateLimiter{
|
||||
Plugin: pluginName,
|
||||
Name: name,
|
||||
Source: source,
|
||||
Status: plugin.LimiterStatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
// Test 1: ShouldFetchRateLimiterDefs
|
||||
|
||||
func TestPluginManager_ShouldFetchRateLimiterDefs_Nil(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.pluginLimiters = nil
|
||||
|
||||
should := pm.ShouldFetchRateLimiterDefs()
|
||||
|
||||
assert.True(t, should, "Should fetch when pluginLimiters is nil")
|
||||
}
|
||||
|
||||
func TestPluginManager_ShouldFetchRateLimiterDefs_NotNil(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.pluginLimiters = make(connection.PluginLimiterMap)
|
||||
|
||||
should := pm.ShouldFetchRateLimiterDefs()
|
||||
|
||||
assert.False(t, should, "Should not fetch when pluginLimiters is initialized")
|
||||
}
|
||||
|
||||
// Test 2: GetPluginsWithChangedLimiters
|
||||
|
||||
func TestPluginManager_GetPluginsWithChangedLimiters_NoChanges(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
limiter1 := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig)
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": limiter1,
|
||||
},
|
||||
}
|
||||
|
||||
newLimiters := connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": limiter1,
|
||||
},
|
||||
}
|
||||
|
||||
changed := pm.getPluginsWithChangedLimiters(newLimiters)
|
||||
|
||||
assert.Len(t, changed, 0, "No plugins should have changed limiters")
|
||||
}
|
||||
|
||||
func TestPluginManager_GetPluginsWithChangedLimiters_NewPlugin(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.userLimiters = connection.PluginLimiterMap{}
|
||||
|
||||
newLimiters := connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
changed := pm.getPluginsWithChangedLimiters(newLimiters)
|
||||
|
||||
assert.Len(t, changed, 1, "Should detect new plugin")
|
||||
assert.Contains(t, changed, "plugin1")
|
||||
}
|
||||
|
||||
func TestPluginManager_GetPluginsWithChangedLimiters_RemovedPlugin(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
newLimiters := connection.PluginLimiterMap{}
|
||||
|
||||
changed := pm.getPluginsWithChangedLimiters(newLimiters)
|
||||
|
||||
assert.Len(t, changed, 1, "Should detect removed plugin")
|
||||
assert.Contains(t, changed, "plugin1")
|
||||
}
|
||||
|
||||
// Test 3: UpdateRateLimiterStatus
|
||||
|
||||
func TestPluginManager_UpdateRateLimiterStatus_NoOverride(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
pluginLimiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
|
||||
pluginLimiter.Status = plugin.LimiterStatusActive
|
||||
|
||||
pm.pluginLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": pluginLimiter,
|
||||
},
|
||||
}
|
||||
pm.userLimiters = connection.PluginLimiterMap{}
|
||||
|
||||
pm.updateRateLimiterStatus()
|
||||
|
||||
assert.Equal(t, plugin.LimiterStatusActive, pluginLimiter.Status)
|
||||
}
|
||||
|
||||
func TestPluginManager_UpdateRateLimiterStatus_WithOverride(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
pluginLimiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
|
||||
pluginLimiter.Status = plugin.LimiterStatusActive
|
||||
|
||||
userLimiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig)
|
||||
|
||||
pm.pluginLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": pluginLimiter,
|
||||
},
|
||||
}
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": userLimiter,
|
||||
},
|
||||
}
|
||||
|
||||
pm.updateRateLimiterStatus()
|
||||
|
||||
assert.Equal(t, plugin.LimiterStatusOverridden, pluginLimiter.Status)
|
||||
}
|
||||
|
||||
func TestPluginManager_UpdateRateLimiterStatus_MultiplePlugins(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
plugin1Limiter1 := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
|
||||
plugin1Limiter2 := newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourcePlugin)
|
||||
plugin2Limiter1 := newTestRateLimiter("plugin2", "limiter1", plugin.LimiterSourcePlugin)
|
||||
|
||||
pm.pluginLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": plugin1Limiter1,
|
||||
"limiter2": plugin1Limiter2,
|
||||
},
|
||||
"plugin2": connection.LimiterMap{
|
||||
"limiter1": plugin2Limiter1,
|
||||
},
|
||||
}
|
||||
|
||||
// Only override plugin1/limiter1
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
pm.updateRateLimiterStatus()
|
||||
|
||||
assert.Equal(t, plugin.LimiterStatusOverridden, plugin1Limiter1.Status)
|
||||
assert.Equal(t, plugin.LimiterStatusActive, plugin1Limiter2.Status)
|
||||
assert.Equal(t, plugin.LimiterStatusActive, plugin2Limiter1.Status)
|
||||
}
|
||||
|
||||
// Test 4: GetUserDefinedLimitersForPlugin
|
||||
|
||||
func TestPluginManager_GetUserDefinedLimitersForPlugin_Exists(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
limiter := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig)
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": limiter,
|
||||
},
|
||||
}
|
||||
|
||||
result := pm.getUserDefinedLimitersForPlugin("plugin1")
|
||||
|
||||
assert.Len(t, result, 1)
|
||||
assert.Equal(t, limiter, result["limiter1"])
|
||||
}
|
||||
|
||||
func TestPluginManager_GetUserDefinedLimitersForPlugin_NotExists(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.userLimiters = connection.PluginLimiterMap{}
|
||||
|
||||
result := pm.getUserDefinedLimitersForPlugin("plugin1")
|
||||
|
||||
assert.NotNil(t, result, "Should return empty map, not nil")
|
||||
assert.Len(t, result, 0)
|
||||
}
|
||||
|
||||
// Test 5: GetUserAndPluginLimitersFromTableResult
|
||||
|
||||
func TestPluginManager_GetUserAndPluginLimitersFromTableResult(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
rateLimiters := []*plugin.RateLimiter{
|
||||
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin),
|
||||
newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig),
|
||||
newTestRateLimiter("plugin2", "limiter1", plugin.LimiterSourcePlugin),
|
||||
}
|
||||
|
||||
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
|
||||
|
||||
// Check plugin limiters
|
||||
assert.Len(t, pluginLimiters, 2)
|
||||
assert.NotNil(t, pluginLimiters["plugin1"]["limiter1"])
|
||||
assert.NotNil(t, pluginLimiters["plugin2"]["limiter1"])
|
||||
|
||||
// Check user limiters
|
||||
assert.Len(t, userLimiters, 1)
|
||||
assert.NotNil(t, userLimiters["plugin1"]["limiter2"])
|
||||
}
|
||||
|
||||
func TestPluginManager_GetUserAndPluginLimitersFromTableResult_Empty(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
rateLimiters := []*plugin.RateLimiter{}
|
||||
|
||||
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
|
||||
|
||||
assert.NotNil(t, pluginLimiters)
|
||||
assert.NotNil(t, userLimiters)
|
||||
assert.Len(t, pluginLimiters, 0)
|
||||
assert.Len(t, userLimiters, 0)
|
||||
}
|
||||
|
||||
// Test 6: GetPluginsWithChangedLimiters Concurrent
|
||||
|
||||
func TestPluginManager_GetPluginsWithChangedLimiters_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
newLimiters := connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
if idx%2 == 0 {
|
||||
// Add a new limiter
|
||||
newLimiters["plugin1"]["limiter2"] = newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig)
|
||||
}
|
||||
|
||||
_ = pm.getPluginsWithChangedLimiters(newLimiters)
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 7: UpdateRateLimiterStatus with Multiple Limiters
|
||||
|
||||
func TestPluginManager_UpdateRateLimiterStatus_MultipleLimiters(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
limiter1 := newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin)
|
||||
limiter2 := newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourcePlugin)
|
||||
limiter3 := newTestRateLimiter("plugin1", "limiter3", plugin.LimiterSourcePlugin)
|
||||
|
||||
pm.pluginLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": limiter1,
|
||||
"limiter2": limiter2,
|
||||
"limiter3": limiter3,
|
||||
},
|
||||
}
|
||||
|
||||
// Override only limiter2
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter2": newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
pm.updateRateLimiterStatus()
|
||||
|
||||
assert.Equal(t, plugin.LimiterStatusActive, limiter1.Status)
|
||||
assert.Equal(t, plugin.LimiterStatusOverridden, limiter2.Status)
|
||||
assert.Equal(t, plugin.LimiterStatusActive, limiter3.Status)
|
||||
}
|
||||
|
||||
// Test 8: GetUserAndPluginLimitersFromTableResult with Duplicate Names
|
||||
|
||||
func TestPluginManager_GetUserAndPluginLimitersFromTableResult_DuplicateNames(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
// Same limiter name, different sources
|
||||
rateLimiters := []*plugin.RateLimiter{
|
||||
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin),
|
||||
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
}
|
||||
|
||||
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
|
||||
|
||||
assert.NotNil(t, pluginLimiters["plugin1"]["limiter1"])
|
||||
assert.NotNil(t, userLimiters["plugin1"]["limiter1"])
|
||||
assert.NotEqual(t, pluginLimiters["plugin1"]["limiter1"], userLimiters["plugin1"]["limiter1"])
|
||||
}
|
||||
|
||||
// Test 9: UpdateRateLimiterStatus with Empty Maps
|
||||
|
||||
func TestPluginManager_UpdateRateLimiterStatus_EmptyMaps(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.pluginLimiters = connection.PluginLimiterMap{}
|
||||
pm.userLimiters = connection.PluginLimiterMap{}
|
||||
|
||||
// Should not panic
|
||||
pm.updateRateLimiterStatus()
|
||||
}
|
||||
|
||||
// Test 10: GetPluginsWithChangedLimiters with Nil Comparison
|
||||
|
||||
func TestPluginManager_GetPluginsWithChangedLimiters_NilComparison(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": nil,
|
||||
}
|
||||
|
||||
newLimiters := connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
changed := pm.getPluginsWithChangedLimiters(newLimiters)
|
||||
|
||||
assert.Contains(t, changed, "plugin1", "Should detect change from nil to non-nil")
|
||||
}
|
||||
|
||||
// Test 11: ShouldFetchRateLimiterDefs Concurrent
|
||||
|
||||
func TestPluginManager_ShouldFetchRateLimiterDefs_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.pluginLimiters = nil
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = pm.ShouldFetchRateLimiterDefs()
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 12: GetUserDefinedLimitersForPlugin Concurrent
|
||||
|
||||
func TestPluginManager_GetUserDefinedLimitersForPlugin_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
pm.userLimiters = connection.PluginLimiterMap{
|
||||
"plugin1": connection.LimiterMap{
|
||||
"limiter1": newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourceConfig),
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 100
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result := pm.getUserDefinedLimitersForPlugin("plugin1")
|
||||
assert.NotNil(t, result)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Test 13: GetUserAndPluginLimitersFromTableResult Concurrent
|
||||
|
||||
func TestPluginManager_GetUserAndPluginLimitersFromTableResult_Concurrent(t *testing.T) {
|
||||
pm := newTestPluginManager(t)
|
||||
|
||||
rateLimiters := []*plugin.RateLimiter{
|
||||
newTestRateLimiter("plugin1", "limiter1", plugin.LimiterSourcePlugin),
|
||||
newTestRateLimiter("plugin1", "limiter2", plugin.LimiterSourceConfig),
|
||||
newTestRateLimiter("plugin2", "limiter1", plugin.LimiterSourcePlugin),
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
numGoroutines := 50
|
||||
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
pluginLimiters, userLimiters := pm.getUserAndPluginLimitersFromTableResult(rateLimiters)
|
||||
assert.NotNil(t, pluginLimiters)
|
||||
assert.NotNil(t, userLimiters)
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
253
pkg/query/queryexecute/execute_test.go
Normal file
253
pkg/query/queryexecute/execute_test.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package queryexecute
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/turbot/pipe-fittings/v2/modconfig"
|
||||
pqueryresult "github.com/turbot/pipe-fittings/v2/queryresult"
|
||||
"github.com/turbot/steampipe/v2/pkg/db/db_common"
|
||||
"github.com/turbot/steampipe/v2/pkg/export"
|
||||
"github.com/turbot/steampipe/v2/pkg/initialisation"
|
||||
"github.com/turbot/steampipe/v2/pkg/query"
|
||||
"github.com/turbot/steampipe/v2/pkg/query/queryresult"
|
||||
)
|
||||
|
||||
// Test Helpers
|
||||
|
||||
// createMockInitData creates a mock InitData for testing
|
||||
func createMockInitData(t *testing.T) *query.InitData {
|
||||
t.Helper()
|
||||
|
||||
initData := &query.InitData{
|
||||
InitData: initialisation.InitData{
|
||||
Result: &db_common.InitResult{},
|
||||
ExportManager: export.NewManager(),
|
||||
Client: &mockClient{}, // Add mock client to prevent nil pointer panics
|
||||
},
|
||||
Loaded: make(chan struct{}),
|
||||
StartTime: time.Now(),
|
||||
Queries: []*modconfig.ResolvedQuery{},
|
||||
}
|
||||
|
||||
return initData
|
||||
}
|
||||
|
||||
// closeInitDataLoaded closes the Loaded channel to simulate initialization completion
|
||||
func closeInitDataLoaded(initData *query.InitData) {
|
||||
select {
|
||||
case <-initData.Loaded:
|
||||
// already closed
|
||||
default:
|
||||
close(initData.Loaded)
|
||||
}
|
||||
}
|
||||
|
||||
// Test Suite: RunBatchSession
|
||||
|
||||
func TestRunBatchSession_EmptyQueries(t *testing.T) {
|
||||
// ARRANGE: Create initData with no queries
|
||||
ctx := context.Background()
|
||||
initData := createMockInitData(t)
|
||||
initData.Queries = []*modconfig.ResolvedQuery{} // explicitly empty
|
||||
|
||||
// Simulate successful initialization
|
||||
closeInitDataLoaded(initData)
|
||||
|
||||
// ACT: Run batch session
|
||||
failures, err := RunBatchSession(ctx, initData)
|
||||
|
||||
// ASSERT: Should return 0 failures and no error
|
||||
assert.NoError(t, err, "RunBatchSession should not error with empty queries")
|
||||
assert.Equal(t, 0, failures, "Should return 0 failures when no queries to execute")
|
||||
}
|
||||
|
||||
func TestRunBatchSession_InitError(t *testing.T) {
|
||||
// ARRANGE: Create initData with an initialization error
|
||||
ctx := context.Background()
|
||||
initData := createMockInitData(t)
|
||||
|
||||
// Simulate initialization error
|
||||
expectedErr := assert.AnError
|
||||
initData.Result.Error = expectedErr
|
||||
closeInitDataLoaded(initData)
|
||||
|
||||
// ACT: Run batch session
|
||||
failures, err := RunBatchSession(ctx, initData)
|
||||
|
||||
// ASSERT: Should return the init error immediately
|
||||
assert.Equal(t, expectedErr, err, "Should return initialization error")
|
||||
assert.Equal(t, 0, failures, "Should return 0 failures when init fails")
|
||||
}
|
||||
|
||||
// Test Suite: Helper Functions
|
||||
|
||||
func TestNeedSnapshot_DefaultValues(t *testing.T) {
|
||||
// This test verifies the needSnapshot function behavior with default config
|
||||
// Note: This is a simple test but ensures the function doesn't panic
|
||||
|
||||
// ACT: Call needSnapshot with default viper config
|
||||
result := needSnapshot()
|
||||
|
||||
// ASSERT: Should return false with default settings
|
||||
assert.False(t, result, "needSnapshot should return false with default settings")
|
||||
}
|
||||
|
||||
func TestShowBlankLineBetweenResults_DefaultValues(t *testing.T) {
|
||||
// This test verifies showBlankLineBetweenResults function with default config
|
||||
|
||||
// ACT: Call function with default viper config
|
||||
result := showBlankLineBetweenResults()
|
||||
|
||||
// ASSERT: Should return true with default settings (not CSV without header)
|
||||
assert.True(t, result, "Should show blank lines with default settings")
|
||||
}
|
||||
|
||||
func TestHandlePublishSnapshotError_PaymentRequired(t *testing.T) {
|
||||
// ARRANGE: Create a 402 Payment Required error
|
||||
err := assert.AnError
|
||||
err = &mockError{msg: "402 Payment Required"}
|
||||
|
||||
// ACT: Handle the error
|
||||
result := handlePublishSnapshotError(err)
|
||||
|
||||
// ASSERT: Should reword the error message
|
||||
assert.Error(t, result)
|
||||
assert.Contains(t, result.Error(), "maximum number of snapshots reached")
|
||||
}
|
||||
|
||||
func TestHandlePublishSnapshotError_OtherError(t *testing.T) {
|
||||
// ARRANGE: Create a different error
|
||||
err := assert.AnError
|
||||
|
||||
// ACT: Handle the error
|
||||
result := handlePublishSnapshotError(err)
|
||||
|
||||
// ASSERT: Should return the error unchanged
|
||||
assert.Equal(t, err, result)
|
||||
}
|
||||
|
||||
// Test Suite: Edge Cases and Resource Management
|
||||
|
||||
func TestExecuteQueries_EmptyQueriesList(t *testing.T) {
|
||||
// ARRANGE: InitData with empty queries list
|
||||
ctx := context.Background()
|
||||
initData := createMockInitData(t)
|
||||
initData.Queries = []*modconfig.ResolvedQuery{}
|
||||
|
||||
// ACT: Execute queries directly
|
||||
failures := executeQueries(ctx, initData)
|
||||
|
||||
// ASSERT: Should return 0 failures
|
||||
assert.Equal(t, 0, failures, "Should return 0 failures for empty queries list")
|
||||
}
|
||||
|
||||
// Test Suite: Context and Cancellation
|
||||
|
||||
func TestRunBatchSession_CancelHandlerSetup(t *testing.T) {
|
||||
// This test verifies that the cancel handler doesn't cause panics
|
||||
// We can't easily test the actual cancellation behavior without integration tests
|
||||
|
||||
// ARRANGE
|
||||
ctx := context.Background()
|
||||
initData := createMockInitData(t)
|
||||
closeInitDataLoaded(initData)
|
||||
|
||||
// ACT: Run batch session
|
||||
// Note: This test just verifies no panic occurs when setting up cancel handler
|
||||
assert.NotPanics(t, func() {
|
||||
_, _ = RunBatchSession(ctx, initData)
|
||||
}, "Should not panic when setting up cancel handler")
|
||||
}
|
||||
|
||||
// Test Suite: Result Wrapping
|
||||
|
||||
func TestWrapResult_NotNil(t *testing.T) {
|
||||
// This test ensures WrapResult doesn't panic and returns a valid wrapper
|
||||
|
||||
// ARRANGE: Create a basic result from pipe-fittings
|
||||
// Note: We need to use the pipe-fittings queryresult package
|
||||
// This test verifies the wrapper functionality exists and doesn't panic
|
||||
wrapped := queryresult.NewResult(nil)
|
||||
|
||||
// ASSERT: Should return a valid result
|
||||
assert.NotNil(t, wrapped, "NewResult should not return nil")
|
||||
}
|
||||
|
||||
// Mock Types
|
||||
|
||||
type mockError struct {
|
||||
msg string
|
||||
}
|
||||
|
||||
func (e *mockError) Error() string {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
// mockClient is a minimal mock implementation of db_common.Client for testing
|
||||
type mockClient struct {
|
||||
customSearchPath []string
|
||||
requiredSearchPath []string
|
||||
}
|
||||
|
||||
func (m *mockClient) Close(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) LoadUserSearchPath(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) SetRequiredSessionSearchPath(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) GetRequiredSessionSearchPath() []string {
|
||||
return m.requiredSearchPath
|
||||
}
|
||||
|
||||
func (m *mockClient) GetCustomSearchPath() []string {
|
||||
return m.customSearchPath
|
||||
}
|
||||
|
||||
func (m *mockClient) AcquireManagementConnection(ctx context.Context) (*pgxpool.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) AcquireSession(ctx context.Context) *db_common.AcquireSessionResult {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) ExecuteSync(ctx context.Context, query string, args ...any) (*pqueryresult.SyncQueryResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) Execute(ctx context.Context, query string, args ...any) (*queryresult.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) ExecuteSyncInSession(ctx context.Context, session *db_common.DatabaseSession, query string, args ...any) (*pqueryresult.SyncQueryResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) ExecuteInSession(ctx context.Context, session *db_common.DatabaseSession, onConnectionLost func(), query string, args ...any) (*queryresult.Result, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) ResetPools(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (m *mockClient) GetSchemaFromDB(ctx context.Context) (*db_common.SchemaMetadata, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockClient) ServerSettings() *db_common.ServerSettings {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockClient) RegisterNotificationListener(f func(notification *pgconn.Notification)) {
|
||||
}
|
||||
412
pkg/steampipeconfig/connection_state_map_test.go
Normal file
412
pkg/steampipeconfig/connection_state_map_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package steampipeconfig
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/turbot/steampipe/v2/pkg/constants"
|
||||
)
|
||||
|
||||
func TestConnectionStateMapGetSummary(t *testing.T) {
|
||||
stateMap := ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn3": &ConnectionState{
|
||||
ConnectionName: "conn3",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
"conn4": &ConnectionState{
|
||||
ConnectionName: "conn4",
|
||||
State: constants.ConnectionStatePending,
|
||||
},
|
||||
}
|
||||
|
||||
summary := stateMap.GetSummary()
|
||||
|
||||
if summary[constants.ConnectionStateReady] != 2 {
|
||||
t.Errorf("Expected 2 ready connections, got %d", summary[constants.ConnectionStateReady])
|
||||
}
|
||||
|
||||
if summary[constants.ConnectionStateError] != 1 {
|
||||
t.Errorf("Expected 1 error connection, got %d", summary[constants.ConnectionStateError])
|
||||
}
|
||||
|
||||
if summary[constants.ConnectionStatePending] != 1 {
|
||||
t.Errorf("Expected 1 pending connection, got %d", summary[constants.ConnectionStatePending])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapPending(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
stateMap ConnectionStateMap
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has pending connections",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStatePending,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has pending incomplete connections",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStatePendingIncomplete,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no pending connections",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
stateMap: ConnectionStateMap{},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.stateMap.Pending()
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapLoaded(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
stateMap ConnectionStateMap
|
||||
connections []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "all connections loaded",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
},
|
||||
connections: []string{},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "some connections not loaded",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStatePending,
|
||||
},
|
||||
},
|
||||
connections: []string{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "specific connections loaded",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStatePending,
|
||||
},
|
||||
},
|
||||
connections: []string{"conn1"},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "disabled connections are loaded",
|
||||
stateMap: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateDisabled,
|
||||
},
|
||||
},
|
||||
connections: []string{},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.stateMap.Loaded(testCase.connections...)
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapConnectionsInState(t *testing.T) {
|
||||
stateMap := ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
"conn3": &ConnectionState{
|
||||
ConnectionName: "conn3",
|
||||
State: constants.ConnectionStatePending,
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
states []string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has ready connections",
|
||||
states: []string{constants.ConnectionStateReady},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has error or pending connections",
|
||||
states: []string{constants.ConnectionStateError, constants.ConnectionStatePending},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no updating connections",
|
||||
states: []string{constants.ConnectionStateUpdating},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no deleting connections",
|
||||
states: []string{constants.ConnectionStateDeleting},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := stateMap.ConnectionsInState(testCase.states...)
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapEquals(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
map1 ConnectionStateMap
|
||||
map2 ConnectionStateMap
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "equal maps",
|
||||
map1: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
map2: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different plugins",
|
||||
map1: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
map2: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin2",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different keys",
|
||||
map1: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
map2: ConnectionStateMap{
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
Plugin: "plugin1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil vs non-nil",
|
||||
map1: nil,
|
||||
map2: ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.map1.Equals(testCase.map2)
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapConnectionModTime(t *testing.T) {
|
||||
now := time.Now()
|
||||
earlier := now.Add(-1 * time.Hour)
|
||||
later := now.Add(1 * time.Hour)
|
||||
|
||||
stateMap := ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
ConnectionModTime: earlier,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
ConnectionModTime: later,
|
||||
},
|
||||
"conn3": &ConnectionState{
|
||||
ConnectionName: "conn3",
|
||||
ConnectionModTime: now,
|
||||
},
|
||||
}
|
||||
|
||||
result := stateMap.ConnectionModTime()
|
||||
|
||||
if !result.Equal(later) {
|
||||
t.Errorf("Expected latest mod time %v, got %v", later, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapConnectionModTimeEmpty(t *testing.T) {
|
||||
stateMap := ConnectionStateMap{}
|
||||
|
||||
result := stateMap.ConnectionModTime()
|
||||
|
||||
if !result.IsZero() {
|
||||
t.Errorf("Expected zero time for empty map, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapGetPluginToConnectionMap(t *testing.T) {
|
||||
stateMap := ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
Plugin: "plugin1",
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
Plugin: "plugin1",
|
||||
},
|
||||
"conn3": &ConnectionState{
|
||||
ConnectionName: "conn3",
|
||||
Plugin: "plugin2",
|
||||
},
|
||||
}
|
||||
|
||||
result := stateMap.GetPluginToConnectionMap()
|
||||
|
||||
if len(result["plugin1"]) != 2 {
|
||||
t.Errorf("Expected 2 connections for plugin1, got %d", len(result["plugin1"]))
|
||||
}
|
||||
|
||||
if len(result["plugin2"]) != 1 {
|
||||
t.Errorf("Expected 1 connection for plugin2, got %d", len(result["plugin2"]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateMapSetConnectionsToPendingOrIncomplete(t *testing.T) {
|
||||
stateMap := ConnectionStateMap{
|
||||
"conn1": &ConnectionState{
|
||||
ConnectionName: "conn1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
"conn2": &ConnectionState{
|
||||
ConnectionName: "conn2",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
"conn3": &ConnectionState{
|
||||
ConnectionName: "conn3",
|
||||
State: constants.ConnectionStateDisabled,
|
||||
},
|
||||
}
|
||||
|
||||
stateMap.SetConnectionsToPendingOrIncomplete()
|
||||
|
||||
if stateMap["conn1"].State != constants.ConnectionStatePending {
|
||||
t.Errorf("Expected conn1 to be pending, got %s", stateMap["conn1"].State)
|
||||
}
|
||||
|
||||
if stateMap["conn2"].State != constants.ConnectionStatePendingIncomplete {
|
||||
t.Errorf("Expected conn2 to be pending incomplete, got %s", stateMap["conn2"].State)
|
||||
}
|
||||
|
||||
if stateMap["conn3"].State != constants.ConnectionStateDisabled {
|
||||
t.Errorf("Expected conn3 to remain disabled, got %s", stateMap["conn3"].State)
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,10 @@ package steampipeconfig
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
typehelpers "github.com/turbot/go-kit/types"
|
||||
"github.com/turbot/steampipe/v2/pkg/constants"
|
||||
)
|
||||
|
||||
func TestConnectionsUpdateEqual(t *testing.T) {
|
||||
@@ -25,6 +29,84 @@ func TestConnectionsUpdateEqual(t *testing.T) {
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "different plugin",
|
||||
data1: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
State: "ready",
|
||||
},
|
||||
data2: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "different_plugin",
|
||||
State: "ready",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different type",
|
||||
data1: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
Type: typehelpers.String("aggregator"),
|
||||
State: "ready",
|
||||
},
|
||||
data2: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
Type: nil,
|
||||
State: "ready",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different import schema",
|
||||
data1: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
ImportSchema: "enabled",
|
||||
State: "ready",
|
||||
},
|
||||
data2: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
ImportSchema: "disabled",
|
||||
State: "ready",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "different error",
|
||||
data1: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
ConnectionError: typehelpers.String("error1"),
|
||||
State: "error",
|
||||
},
|
||||
data2: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
ConnectionError: typehelpers.String("error2"),
|
||||
State: "error",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "plugin mod time within tolerance",
|
||||
data1: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
PluginModTime: time.Now(),
|
||||
State: "ready",
|
||||
},
|
||||
data2: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Plugin: "test_plugin",
|
||||
PluginModTime: time.Now().Add(500 * time.Microsecond),
|
||||
State: "ready",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
@@ -36,3 +118,188 @@ func TestConnectionsUpdateEqual(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateLoaded(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
state *ConnectionState
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "ready state is loaded",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "error state is loaded",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "disabled state is loaded",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateDisabled,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "pending state is not loaded",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStatePending,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "updating state is not loaded",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateUpdating,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.state.Loaded()
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateDisabled(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
state *ConnectionState
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "disabled state",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateDisabled,
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "ready state is not disabled",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateReady,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error state is not disabled",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateError,
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.state.Disabled()
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateGetType(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
state *ConnectionState
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "aggregator type",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Type: typehelpers.String("aggregator"),
|
||||
},
|
||||
expected: "aggregator",
|
||||
},
|
||||
{
|
||||
name: "nil type returns empty string",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
Type: nil,
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.state.GetType()
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateError(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
state *ConnectionState
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "error message",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
ConnectionError: typehelpers.String("test error"),
|
||||
},
|
||||
expected: "test error",
|
||||
},
|
||||
{
|
||||
name: "nil error returns empty string",
|
||||
state: &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
ConnectionError: nil,
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.state.Error()
|
||||
if result != testCase.expected {
|
||||
t.Errorf("Expected %v, got %v", testCase.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionStateSetError(t *testing.T) {
|
||||
state := &ConnectionState{
|
||||
ConnectionName: "test1",
|
||||
State: constants.ConnectionStateReady,
|
||||
}
|
||||
|
||||
state.SetError("test error")
|
||||
|
||||
if state.State != constants.ConnectionStateError {
|
||||
t.Errorf("Expected state to be %s, got %s", constants.ConnectionStateError, state.State)
|
||||
}
|
||||
|
||||
if state.Error() != "test error" {
|
||||
t.Errorf("Expected error to be 'test error', got %s", state.Error())
|
||||
}
|
||||
}
|
||||
|
||||
106
pkg/steampipeconfig/validation_failure_test.go
Normal file
106
pkg/steampipeconfig/validation_failure_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package steampipeconfig
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidationFailureString(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
failure ValidationFailure
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "basic validation failure",
|
||||
failure: ValidationFailure{
|
||||
Plugin: "hub.steampipe.io/plugins/turbot/aws@latest",
|
||||
ConnectionName: "aws_prod",
|
||||
Message: "invalid configuration",
|
||||
ShouldDropIfExists: false,
|
||||
},
|
||||
expected: []string{
|
||||
"Connection: aws_prod",
|
||||
"Plugin: hub.steampipe.io/plugins/turbot/aws@latest",
|
||||
"Error: invalid configuration",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "validation failure with drop flag",
|
||||
failure: ValidationFailure{
|
||||
Plugin: "hub.steampipe.io/plugins/turbot/gcp@latest",
|
||||
ConnectionName: "gcp_dev",
|
||||
Message: "missing required field",
|
||||
ShouldDropIfExists: true,
|
||||
},
|
||||
expected: []string{
|
||||
"Connection: gcp_dev",
|
||||
"Plugin: hub.steampipe.io/plugins/turbot/gcp@latest",
|
||||
"Error: missing required field",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "validation failure with empty message",
|
||||
failure: ValidationFailure{
|
||||
Plugin: "test_plugin",
|
||||
ConnectionName: "test_conn",
|
||||
Message: "",
|
||||
ShouldDropIfExists: false,
|
||||
},
|
||||
expected: []string{
|
||||
"Connection: test_conn",
|
||||
"Plugin: test_plugin",
|
||||
"Error: ",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
result := testCase.failure.String()
|
||||
|
||||
for _, expected := range testCase.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected result to contain '%s', got: %s", expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidationFailureStringFormat(t *testing.T) {
|
||||
failure := ValidationFailure{
|
||||
Plugin: "test_plugin",
|
||||
ConnectionName: "test_connection",
|
||||
Message: "test error",
|
||||
ShouldDropIfExists: false,
|
||||
}
|
||||
|
||||
result := failure.String()
|
||||
|
||||
// Verify the format includes the expected labels
|
||||
if !strings.Contains(result, "Connection:") {
|
||||
t.Error("Expected result to contain 'Connection:' label")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Plugin:") {
|
||||
t.Error("Expected result to contain 'Plugin:' label")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "Error:") {
|
||||
t.Error("Expected result to contain 'Error:' label")
|
||||
}
|
||||
|
||||
// Verify the values are present
|
||||
if !strings.Contains(result, "test_connection") {
|
||||
t.Error("Expected result to contain connection name")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "test_plugin") {
|
||||
t.Error("Expected result to contain plugin name")
|
||||
}
|
||||
|
||||
if !strings.Contains(result, "test error") {
|
||||
t.Error("Expected result to contain error message")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user