Race on cancelActiveQuery without synchronization closes #4802 (#4844)

* Add test for #4802: cancelActiveQuery should be safe for concurrent access

* Fix #4802: Add mutex protection to cancelActiveQuery
This commit is contained in:
Nathan Wallace
2025-11-16 13:46:48 -05:00
committed by GitHub
parent bf3092396c
commit 57712dc4df
3 changed files with 75 additions and 0 deletions

View File

@@ -2,6 +2,7 @@ package interactive
import (
"context"
"sync"
"testing"
"time"
@@ -450,3 +451,70 @@ func TestNoGoroutineLeaks(t *testing.T) {
}
}
}
// TestConcurrentCancellation tests that cancelActiveQuery can be accessed
// concurrently without triggering data races.
// This test reproduces the race condition reported in issue #4802.
func TestConcurrentCancellation(t *testing.T) {
// Create a minimal InteractiveClient
client := &InteractiveClient{}
// Simulate concurrent access to cancelActiveQuery from multiple goroutines
// This mirrors real-world usage where:
// - createQueryContext() sets cancelActiveQuery
// - cancelActiveQueryIfAny() reads and clears it
// - signal handlers may also call cancelActiveQueryIfAny()
var wg sync.WaitGroup
numGoroutines := 10
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
// Simulate creating a query context (writes cancelActiveQuery)
ctx := client.createQueryContext(context.Background())
_ = ctx
}()
wg.Add(1)
go func() {
defer wg.Done()
// Simulate cancelling the active query (reads and writes cancelActiveQuery)
client.cancelActiveQueryIfAny()
}()
}
// Wait for all goroutines to complete
wg.Wait()
// If we get here without panicking or race detector errors, the test passes
// Note: This test will fail when run with -race flag if cancelActiveQuery access is not synchronized
}
// TestMultipleConcurrentCancellations tests rapid concurrent cancellations
// to stress test the synchronization.
func TestMultipleConcurrentCancellations(t *testing.T) {
client := &InteractiveClient{}
var wg sync.WaitGroup
numIterations := 100
// Create a query context first
_ = client.createQueryContext(context.Background())
// Now try to cancel it from multiple goroutines simultaneously
for i := 0; i < numIterations; i++ {
wg.Add(1)
go func() {
defer wg.Done()
client.cancelActiveQueryIfAny()
}()
}
wg.Wait()
// Verify the client is in a consistent state
if client.cancelActiveQuery != nil {
t.Error("Expected cancelActiveQuery to be nil after all cancellations")
}
}

View File

@@ -55,6 +55,8 @@ type InteractiveClient struct {
// NOTE: should ONLY be called by cancelActiveQueryIfAny
cancelActiveQuery context.CancelFunc
cancelPrompt context.CancelFunc
// mutex to protect concurrent access to cancelActiveQuery
cancelMutex sync.Mutex
// channel used internally to pass the initialisation result
initResultChan chan *db_common.InitResult

View File

@@ -18,11 +18,16 @@ func (c *InteractiveClient) createPromptContext(parentContext context.Context) c
func (c *InteractiveClient) createQueryContext(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
c.cancelMutex.Lock()
c.cancelActiveQuery = cancel
c.cancelMutex.Unlock()
return ctx
}
func (c *InteractiveClient) cancelActiveQueryIfAny() {
c.cancelMutex.Lock()
defer c.cancelMutex.Unlock()
if c.cancelActiveQuery != nil {
log.Println("[INFO] cancelActiveQueryIfAny CALLING cancelActiveQuery")
c.cancelActiveQuery()