Fix #4803: Use atomic.Bool for initialisationComplete flag

Replace the plain boolean initialisationComplete field with atomic.Bool
to prevent data races when accessed concurrently by multiple goroutines.

Changes:
- Change field type from bool to atomic.Bool
- Use .Store(true) for writes
- Use .Load() for reads in isInitialised() and handleConnectionUpdateNotification()
- Update test to use atomic operations

The test now passes with -race flag, confirming the race condition is fixed.
This commit is contained in:
Nathan Wallace
2025-11-11 23:26:11 +08:00
parent 1c2a5949b5
commit e70f697d20
3 changed files with 10 additions and 10 deletions

View File

@@ -10,6 +10,7 @@ import (
"os/signal" "os/signal"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/alecthomas/chroma/formatters" "github.com/alecthomas/chroma/formatters"
@@ -58,7 +59,7 @@ type InteractiveClient struct {
// channel used internally to pass the initialisation result // channel used internally to pass the initialisation result
initResultChan chan *db_common.InitResult initResultChan chan *db_common.InitResult
// flag set when initialisation is complete (with or without errors) // flag set when initialisation is complete (with or without errors)
initialisationComplete bool initialisationComplete atomic.Bool
afterClose AfterPromptCloseAction afterClose AfterPromptCloseAction
// lock while execution is occurring to avoid errors/warnings being shown // lock while execution is occurring to avoid errors/warnings being shown
executionLock sync.Mutex executionLock sync.Mutex
@@ -731,7 +732,7 @@ func (c *InteractiveClient) handleConnectionUpdateNotification(ctx context.Conte
// ignore schema update notifications until initialisation is complete // ignore schema update notifications until initialisation is complete
// (we may receive schema update messages from the initial refresh connections, but we do not need to reload // (we may receive schema update messages from the initial refresh connections, but we do not need to reload
// the schema as we will have already loaded the correct schema) // the schema as we will have already loaded the correct schema)
if !c.initialisationComplete { if !c.initialisationComplete.Load() {
log.Printf("[INFO] received schema update notification but ignoring it as we are initializing") log.Printf("[INFO] received schema update notification but ignoring it as we are initializing")
return return
} }

View File

@@ -16,7 +16,7 @@ import (
func (c *InteractiveClient) handleInitResult(ctx context.Context, initResult *db_common.InitResult) { func (c *InteractiveClient) handleInitResult(ctx context.Context, initResult *db_common.InitResult) {
// whatever happens, set initialisationComplete // whatever happens, set initialisationComplete
defer func() { defer func() {
c.initialisationComplete = true c.initialisationComplete.Store(true)
}() }()
if initResult.Error != nil { if initResult.Error != nil {
@@ -127,7 +127,7 @@ func (c *InteractiveClient) readInitDataStream(ctx context.Context) {
// return whether the client is initialises // return whether the client is initialises
// there are 3 conditions> // there are 3 conditions>
func (c *InteractiveClient) isInitialised() bool { func (c *InteractiveClient) isInitialised() bool {
return c.initialisationComplete return c.initialisationComplete.Load()
} }
func (c *InteractiveClient) waitForInitData(ctx context.Context) error { func (c *InteractiveClient) waitForInitData(ctx context.Context) error {

View File

@@ -543,9 +543,8 @@ func TestCancelActiveQueryIfAny(t *testing.T) {
// //
// Bug: #4803 // Bug: #4803
func TestInitialisationComplete_RaceCondition(t *testing.T) { func TestInitialisationComplete_RaceCondition(t *testing.T) {
c := &InteractiveClient{ c := &InteractiveClient{}
initialisationComplete: false, c.initialisationComplete.Store(false)
}
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -554,8 +553,8 @@ func TestInitialisationComplete_RaceCondition(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
c.initialisationComplete = true c.initialisationComplete.Store(true)
c.initialisationComplete = false c.initialisationComplete.Store(false)
} }
}() }()
@@ -574,7 +573,7 @@ func TestInitialisationComplete_RaceCondition(t *testing.T) {
defer wg.Done() defer wg.Done()
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
// Check the flag directly (as handleConnectionUpdateNotification does) // Check the flag directly (as handleConnectionUpdateNotification does)
if !c.initialisationComplete { if !c.initialisationComplete.Load() {
continue continue
} }
} }