diff --git a/pkg/db/db_client/db_client.go b/pkg/db/db_client/db_client.go index b21a9fa2f..ef066eda3 100644 --- a/pkg/db/db_client/db_client.go +++ b/pkg/db/db_client/db_client.go @@ -169,7 +169,10 @@ func (c *DbClient) Close(context.Context) error { c.closePools() // nullify active sessions, since with the closing of the pools // none of the sessions will be valid anymore + // Acquire mutex to prevent concurrent access to sessions map + c.sessionsMutex.Lock() c.sessions = nil + c.sessionsMutex.Unlock() return nil } diff --git a/pkg/db/db_client/db_client_session.go b/pkg/db/db_client/db_client_session.go index 435e0073d..d877482ff 100644 --- a/pkg/db/db_client/db_client_session.go +++ b/pkg/db/db_client/db_client_session.go @@ -38,6 +38,12 @@ func (c *DbClient) AcquireSession(ctx context.Context) (sessionResult *db_common backendPid := databaseConnection.Conn().PgConn().PID() c.sessionsMutex.Lock() + // Check if client has been closed (sessions set to nil) + if c.sessions == nil { + c.sessionsMutex.Unlock() + sessionResult.Error = fmt.Errorf("client has been closed") + return sessionResult + } session, found := c.sessions[backendPid] if !found { session = db_common.NewDBSession(backendPid) diff --git a/pkg/db/db_client/db_client_test.go b/pkg/db/db_client/db_client_test.go index fa039eecc..7496993a0 100644 --- a/pkg/db/db_client/db_client_test.go +++ b/pkg/db/db_client/db_client_test.go @@ -161,6 +161,88 @@ func TestDbClient_Close_ClearsSessionsMap(t *testing.T) { assert.Nil(t, client.sessions, "Sessions map should be nil after Close()") } +// TestDbClient_ConcurrentCloseAndRead verifies that concurrent reads don't panic +// when Close() sets sessions to nil +// Reference: https://github.com/turbot/steampipe/issues/4793 +func TestDbClient_ConcurrentCloseAndRead(t *testing.T) { + + // This test simulates the race condition where: + // 1. A goroutine enters AcquireSession, locks the mutex, reads c.sessions + // 2. Close() sets c.sessions = nil WITHOUT holding the mutex + // 3. The goroutine tries to write to c.sessions which is now nil + // This causes a nil map panic or data race + + // Run the test multiple times to increase chance of catching the race + for i := 0; i < 50; i++ { + client := &DbClient{ + sessions: make(map[uint32]*db_common.DatabaseSession), + sessionsMutex: &sync.Mutex{}, + } + + done := make(chan bool, 2) + + // Goroutine 1: Simulates AcquireSession behavior + go func() { + defer func() { done <- true }() + + client.sessionsMutex.Lock() + // After the fix, code should check if sessions is nil + if client.sessions != nil { + _, found := client.sessions[12345] + if !found { + client.sessions[12345] = db_common.NewDBSession(12345) + } + } + client.sessionsMutex.Unlock() + }() + + // Goroutine 2: Calls Close() + go func() { + defer func() { done <- true }() + // Without the fix, Close() sets sessions to nil without mutex protection + // This is the bug - it should acquire the mutex first + client.Close(nil) + }() + + // Wait for both goroutines + <-done + <-done + } + + // With the bug present, running with -race will detect the data race + // After the fix, this test should pass cleanly +} + +// TestDbClient_SessionsMapNilAfterClose verifies that accessing sessions after Close +// doesn't cause a nil pointer panic +// Reference: https://github.com/turbot/steampipe/issues/4793 +func TestDbClient_SessionsMapNilAfterClose(t *testing.T) { + + client := &DbClient{ + sessions: make(map[uint32]*db_common.DatabaseSession), + sessionsMutex: &sync.Mutex{}, + } + + // Add a session + client.sessionsMutex.Lock() + client.sessions[12345] = db_common.NewDBSession(12345) + client.sessionsMutex.Unlock() + + // Close sets sessions to nil (without mutex protection - this is the bug) + client.Close(nil) + + // Attempt to access sessions like AcquireSession does + // After the fix, this should not panic + client.sessionsMutex.Lock() + defer client.sessionsMutex.Unlock() + + // With the bug: this panics because sessions is nil + // After fix: sessions should either not be nil, or code checks for nil + if client.sessions != nil { + client.sessions[67890] = db_common.NewDBSession(67890) + } +} + // TestDbClient_SessionsMutexProtectsMap verifies that sessionsMutex protects all map operations func TestDbClient_SessionsMutexProtectsMap(t *testing.T) { // This is a structural test to verify the sessions map is never accessed without the mutex