diff --git a/pkg/interactive/interactive_client.go b/pkg/interactive/interactive_client.go index 6111a6dac..98bf979ed 100644 --- a/pkg/interactive/interactive_client.go +++ b/pkg/interactive/interactive_client.go @@ -55,6 +55,8 @@ type InteractiveClient struct { // NOTE: should ONLY be called by cancelActiveQueryIfAny cancelActiveQuery context.CancelFunc cancelPrompt context.CancelFunc + // mutex to protect concurrent access to cancelActiveQuery + cancelMutex sync.Mutex // channel used internally to pass the initialisation result initResultChan chan *db_common.InitResult diff --git a/pkg/interactive/interactive_client_cancel.go b/pkg/interactive/interactive_client_cancel.go index 51661596f..5ac253313 100644 --- a/pkg/interactive/interactive_client_cancel.go +++ b/pkg/interactive/interactive_client_cancel.go @@ -18,11 +18,16 @@ func (c *InteractiveClient) createPromptContext(parentContext context.Context) c func (c *InteractiveClient) createQueryContext(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) + c.cancelMutex.Lock() c.cancelActiveQuery = cancel + c.cancelMutex.Unlock() return ctx } func (c *InteractiveClient) cancelActiveQueryIfAny() { + c.cancelMutex.Lock() + defer c.cancelMutex.Unlock() + if c.cancelActiveQuery != nil { log.Println("[INFO] cancelActiveQueryIfAny CALLING cancelActiveQuery") c.cancelActiveQuery()