diff --git a/pkg/pluginmanager/state.go b/pkg/pluginmanager/state.go index 745805669..b321d66df 100644 --- a/pkg/pluginmanager/state.go +++ b/pkg/pluginmanager/state.go @@ -4,7 +4,9 @@ import ( "encoding/json" "log" "os" + "path/filepath" "strings" + "sync" "syscall" "github.com/hashicorp/go-plugin" @@ -16,6 +18,9 @@ import ( const PluginManagerStructVersion = 20220411 +// stateMutex protects concurrent writes to the state file +var stateMutex sync.Mutex + type State struct { Protocol plugin.Protocol `json:"protocol"` ProtocolVersion int `json:"protocol_version"` @@ -75,6 +80,10 @@ func LoadState() (*State, error) { } func (s *State) Save() error { + // Protect concurrent writes with a mutex + stateMutex.Lock() + defer stateMutex.Unlock() + // set struct version s.StructVersion = PluginManagerStructVersion @@ -82,7 +91,26 @@ func (s *State) Save() error { if err != nil { return err } - return os.WriteFile(filepaths.PluginManagerStateFilePath(), content, 0644) + + // Use atomic write to prevent file corruption from concurrent writes + // Write to a temporary file first, then atomically rename it + stateFilePath := filepaths.PluginManagerStateFilePath() + + // Ensure the directory exists + if err := os.MkdirAll(filepath.Dir(stateFilePath), 0755); err != nil { + return err + } + + tempFile := stateFilePath + ".tmp" + + // Write to temporary file + if err := os.WriteFile(tempFile, content, 0644); err != nil { + return err + } + + // Atomically rename the temp file to the final location + // This ensures that the state file is never partially written + return os.Rename(tempFile, stateFilePath) } func (s *State) reattachConfig() *plugin.ReattachConfig { diff --git a/pkg/pluginmanager/state_test.go b/pkg/pluginmanager/state_test.go index 877174699..618c3c0ac 100644 --- a/pkg/pluginmanager/state_test.go +++ b/pkg/pluginmanager/state_test.go @@ -1,9 +1,17 @@ package pluginmanager import ( + "encoding/json" + "net" + "os" + "path/filepath" + "sync" "testing" "github.com/hashicorp/go-plugin" + "github.com/turbot/pipe-fittings/v2/app_specific" + "github.com/turbot/steampipe/v2/pkg/filepaths" + pb "github.com/turbot/steampipe/v2/pkg/pluginmanager_service/grpc/proto" ) // TestStateWithNilAddr tests that reattachConfig handles nil Addr gracefully @@ -25,3 +33,80 @@ func TestStateWithNilAddr(t *testing.T) { t.Error("Expected nil reattach config when Addr is nil") } } + +func TestStateFileRaceCondition(t *testing.T) { + // This test demonstrates the race condition in State.Save() + // When multiple goroutines call Save() concurrently, they can corrupt the JSON file + + // Setup: Create a temporary directory for testing + tempDir, err := os.MkdirTemp("", "steampipe-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tempDir) + + // Initialize app_specific.InstallDir for the test + app_specific.InstallDir = filepath.Join(tempDir, ".steampipe") + + // Create multiple states with different data + concurrency := 50 + iterations := 20 + var wg sync.WaitGroup + wg.Add(concurrency) + + // Channel to collect errors from goroutines + errors := make(chan error, concurrency*iterations) + + // Launch concurrent Save() operations to the same file + for i := 0; i < concurrency; i++ { + go func(id int) { + defer wg.Done() + + // Create a new state with unique data + addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 8080 + id} + reattach := &plugin.ReattachConfig{ + Protocol: plugin.ProtocolGRPC, + ProtocolVersion: 1, + Addr: pb.NewSimpleAddr(addr), + Pid: 1000 + id, + } + + state := NewState("/test/executable", reattach) + + // Perform multiple saves to increase race window + for j := 0; j < iterations; j++ { + if err := state.Save(); err != nil { + errors <- err + } + } + }(i) + } + + wg.Wait() + close(errors) + + // Check for any errors during save + for err := range errors { + t.Errorf("Failed to save state: %v", err) + } + + // Verify that the state file is valid JSON + stateFilePath := filepaths.PluginManagerStateFilePath() + content, err := os.ReadFile(stateFilePath) + if err != nil { + t.Fatalf("Failed to read state file: %v", err) + } + + // The main test: Can we unmarshal the file without error? + var state State + err = json.Unmarshal(content, &state) + if err != nil { + t.Fatalf("State file is corrupted (invalid JSON): %v\nContent: %s", err, string(content)) + } + + // Additional validation: ensure required fields are present + if state.StructVersion != PluginManagerStructVersion { + t.Errorf("State file missing or has incorrect struct version: got %d, want %d", + state.StructVersion, PluginManagerStructVersion) + } +}