// Copyright (c) The OpenTofu Authors // SPDX-License-Identifier: MPL-2.0 // Copyright (c) 2023 HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package remote import ( "context" "log" "sync" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" regaddr "github.com/opentofu/registry-address/v2" "github.com/zclconf/go-cty/cty" "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/opentofu/internal/encryption" "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/states/statefile" "github.com/opentofu/opentofu/internal/states/statemgr" "github.com/opentofu/opentofu/version" ) func TestState_impl(t *testing.T) { var _ statemgr.Reader = new(State) var _ statemgr.Writer = new(State) var _ statemgr.Persister = new(State) var _ statemgr.Refresher = new(State) var _ statemgr.OutputReader = new(State) var _ statemgr.Locker = new(State) } func TestStateRace(t *testing.T) { s := NewState(nilClient{}, encryption.StateEncryptionDisabled()) current := states.NewState() var wg sync.WaitGroup for i := 0; i < 100; i++ { wg.Add(1) go func() { defer wg.Done() if err := s.WriteState(current); err != nil { panic(err) } if err := s.PersistState(t.Context(), nil); err != nil { panic(err) } if err := s.RefreshState(t.Context()); err != nil { panic(err) } }() } wg.Wait() } // testCase encapsulates a test state test type testCase struct { name string // A function to mutate state and return a cleanup function mutationFunc func(*State) (*states.State, func()) // The expected requests to have taken place expectedRequests []mockClientRequest // Mark this case as not having a request noRequest bool } // isRequested ensures a test that is specified as not having // a request doesn't have one by checking if a method exists // on the expectedRequest. func (tc testCase) isRequested(t *testing.T) bool { for _, expectedMethod := range tc.expectedRequests { hasMethod := expectedMethod.Method != "" if tc.noRequest && hasMethod { t.Fatalf("expected no content for %q but got: %v", tc.name, expectedMethod) } } return !tc.noRequest } func TestStatePersist(t *testing.T) { testCases := []testCase{ { name: "first state persistence", mutationFunc: func(mgr *State) (*states.State, func()) { mgr.state = &states.State{ Modules: map[string]*states.Module{"": {}}, } s := mgr.State() s.RootModule().SetResourceInstanceCurrent( addrs.Resource{ Mode: addrs.ManagedResourceMode, Name: "myfile", Type: "local_file", }.Instance(addrs.NoKey), &states.ResourceInstanceObjectSrc{ AttrsFlat: map[string]string{ "filename": "file.txt", }, Status: states.ObjectReady, }, addrs.AbsProviderConfig{ Provider: regaddr.Provider{Namespace: "local"}, }, addrs.NoKey, ) return s, func() {} }, expectedRequests: []mockClientRequest{ // Expect an initial refresh, which returns nothing since there is no remote state. { Method: "Get", Content: nil, }, // Expect a second refresh, since the read state is nil { Method: "Get", Content: nil, }, // Expect an initial push with values and a serial of 1 { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "some meaningless value", "serial": 1.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{}, "resources": []interface{}{ map[string]interface{}{ "instances": []interface{}{ map[string]interface{}{ "attributes_flat": map[string]interface{}{ "filename": "file.txt", }, "schema_version": 0.0, "sensitive_attributes": []interface{}{}, }, }, "mode": "managed", "name": "myfile", "provider": `provider["/local/"]`, "type": "local_file", }, }, "check_results": nil, }, }, }, }, // If lineage changes, expect the serial to increment { name: "change lineage", mutationFunc: func(mgr *State) (*states.State, func()) { mgr.lineage = "mock-lineage" return mgr.State(), func() {} }, expectedRequests: []mockClientRequest{ { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "mock-lineage", "serial": 2.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{}, "resources": []interface{}{ map[string]interface{}{ "instances": []interface{}{ map[string]interface{}{ "attributes_flat": map[string]interface{}{ "filename": "file.txt", }, "schema_version": 0.0, "sensitive_attributes": []interface{}{}, }, }, "mode": "managed", "name": "myfile", "provider": `provider["/local/"]`, "type": "local_file", }, }, "check_results": nil, }, }, }, }, // removing resources should increment the serial { name: "remove resources", mutationFunc: func(mgr *State) (*states.State, func()) { mgr.state.RootModule().Resources = map[string]*states.Resource{} return mgr.State(), func() {} }, expectedRequests: []mockClientRequest{ { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "mock-lineage", "serial": 3.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{}, "resources": []interface{}{}, "check_results": nil, }, }, }, }, // If the remote serial is incremented, then we increment it once more. { name: "change serial", mutationFunc: func(mgr *State) (*states.State, func()) { originalSerial := mgr.serial mgr.serial++ return mgr.State(), func() { mgr.serial = originalSerial } }, expectedRequests: []mockClientRequest{ { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "mock-lineage", "serial": 5.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{}, "resources": []interface{}{}, "check_results": nil, }, }, }, }, // Adding an output should cause the serial to increment as well. { name: "add output to state", mutationFunc: func(mgr *State) (*states.State, func()) { s := mgr.State() s.RootModule().SetOutputValue("foo", cty.StringVal("bar"), false, "") return s, func() {} }, expectedRequests: []mockClientRequest{ { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "mock-lineage", "serial": 4.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{ "foo": map[string]interface{}{ "type": "string", "value": "bar", }, }, "resources": []interface{}{}, "check_results": nil, }, }, }, }, // ...as should changing an output { name: "mutate state bar -> baz", mutationFunc: func(mgr *State) (*states.State, func()) { s := mgr.State() s.RootModule().SetOutputValue("foo", cty.StringVal("baz"), false, "") return s, func() {} }, expectedRequests: []mockClientRequest{ { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "mock-lineage", "serial": 5.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{ "foo": map[string]interface{}{ "type": "string", "value": "baz", }, }, "resources": []interface{}{}, "check_results": nil, }, }, }, }, { name: "nothing changed", mutationFunc: func(mgr *State) (*states.State, func()) { s := mgr.State() return s, func() {} }, noRequest: true, }, // If the remote state's serial is less (force push), then we // increment it once from there. { name: "reset serial (force push style)", mutationFunc: func(mgr *State) (*states.State, func()) { mgr.serial = 2 return mgr.State(), func() {} }, expectedRequests: []mockClientRequest{ { Method: "Put", Content: map[string]interface{}{ "version": 4.0, // encoding/json decodes this as float64 by default "lineage": "mock-lineage", "serial": 3.0, // encoding/json decodes this as float64 by default "terraform_version": version.Version, "outputs": map[string]interface{}{ "foo": map[string]interface{}{ "type": "string", "value": "baz", }, }, "resources": []interface{}{}, "check_results": nil, }, }, }, }, } // Initial setup of state just to give us a fixed starting point for our // test assertions below, or else we'd need to deal with // random lineage. mgr := NewState( &mockClient{}, encryption.StateEncryptionDisabled(), ) // In normal use (during a OpenTofu operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this if err := mgr.RefreshState(t.Context()); err != nil { t.Fatalf("failed to RefreshState: %s", err) } // Our client is a mockClient which has a log we // use to check that operations generate expected requests mockClient := mgr.Client.(*mockClient) // logIdx tracks the current index of the log separate from // the loop iteration so we can check operations that don't // cause any requests to be generated logIdx := 0 // Run tests in order. for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { s, cleanup := tc.mutationFunc(mgr) if err := mgr.WriteState(s); err != nil { t.Fatalf("failed to WriteState for %q: %s", tc.name, err) } if err := mgr.PersistState(t.Context(), nil); err != nil { t.Fatalf("failed to PersistState for %q: %s", tc.name, err) } if tc.isRequested(t) { // Get captured request from the mock client log // based on the index of the current test if logIdx >= len(mockClient.log) { t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) } for expectedRequestIdx := 0; expectedRequestIdx < len(tc.expectedRequests); expectedRequestIdx++ { loggedRequest := mockClient.log[logIdx] logIdx++ if diff := cmp.Diff(tc.expectedRequests[expectedRequestIdx], loggedRequest, cmpopts.IgnoreMapEntries(func(key string, value interface{}) bool { // This is required since the initial state creation causes the lineage to be a UUID that is not known at test time. return tc.name == "first state persistence" && key == "lineage" })); len(diff) > 0 { t.Logf("incorrect client requests for %q:\n%s", tc.name, diff) t.Fail() } } } cleanup() }) } logCnt := len(mockClient.log) if logIdx != logCnt { t.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx) } } func TestState_GetRootOutputValues(t *testing.T) { // Initial setup of state with outputs already defined mgr := NewState( &mockClient{ current: []byte(` { "version": 4, "lineage": "mock-lineage", "serial": 1, "terraform_version":"0.0.0", "outputs": {"foo": {"value":"bar", "type": "string"}}, "resources": [] } `), }, encryption.StateEncryptionDisabled(), ) outputs, err := mgr.GetRootOutputValues(t.Context()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } if len(outputs) != 1 { t.Errorf("Expected %d outputs, but received %d", 1, len(outputs)) } } type migrationTestCase struct { name string // A function to generate a statefile stateFile func(*State) *statefile.File // The expected request to have taken place expectedRequest mockClientRequest // Mark this case as not having a request expectedError string // force flag passed to client force bool } func TestWriteStateForMigration(t *testing.T) { mgr := NewState( &mockClient{ current: []byte(` { "version": 4, "lineage": "mock-lineage", "serial": 3, "terraform_version":"0.0.0", "outputs": {"foo": {"value":"bar", "type": "string"}}, "resources": [] } `), }, encryption.StateEncryptionDisabled(), ) testCases := []migrationTestCase{ // Refreshing state before we run the test loop causes a GET { name: "refresh state", stateFile: func(mgr *State) *statefile.File { return mgr.StateForMigration() }, expectedRequest: mockClientRequest{ Method: "Get", Content: map[string]interface{}{ "version": 4.0, "lineage": "mock-lineage", "serial": 3.0, "terraform_version": "0.0.0", "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, "resources": []interface{}{}, }, }, }, { name: "cannot import lesser serial without force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, mgr.lineage, 1) }, expectedError: "cannot import state with serial 1 over newer state with serial 3", }, { name: "cannot import differing lineage without force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, "different-lineage", mgr.serial) }, expectedError: `cannot import state with lineage "different-lineage" over unrelated state with lineage "mock-lineage"`, }, { name: "can import lesser serial with force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, mgr.lineage, 1) }, expectedRequest: mockClientRequest{ Method: "Put", Content: map[string]interface{}{ "version": 4.0, "lineage": "mock-lineage", "serial": 2.0, "terraform_version": version.Version, "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, "resources": []interface{}{}, "check_results": nil, }, }, force: true, }, { name: "cannot import differing lineage without force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, "different-lineage", mgr.serial) }, expectedRequest: mockClientRequest{ Method: "Put", Content: map[string]interface{}{ "version": 4.0, "lineage": "different-lineage", "serial": 3.0, "terraform_version": version.Version, "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, "resources": []interface{}{}, "check_results": nil, }, }, force: true, }, } // In normal use (during a OpenTofu operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this if err := mgr.RefreshState(t.Context()); err != nil { t.Fatalf("failed to RefreshState: %s", err) } if err := mgr.WriteState(mgr.State()); err != nil { t.Fatalf("failed to write initial state: %s", err) } // Our client is a mockClient which has a log we // use to check that operations generate expected requests mockClient := mgr.Client.(*mockClient) // logIdx tracks the current index of the log separate from // the loop iteration so we can check operations that don't // cause any requests to be generated logIdx := 0 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { sf := tc.stateFile(mgr) err := mgr.WriteStateForMigration(sf, tc.force) shouldError := tc.expectedError != "" // If we are expecting and error check it and move on if shouldError { if err == nil { t.Fatalf("test case %q should have failed with error %q", tc.name, tc.expectedError) } else if err.Error() != tc.expectedError { t.Fatalf("test case %q expected error %q but got %q", tc.name, tc.expectedError, err) } return } if err != nil { t.Fatalf("test case %q failed: %v", tc.name, err) } // At this point we should just do a normal write and persist // as would happen from the CLI if err := mgr.WriteState(mgr.State()); err != nil { t.Fatal(err) } if err := mgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } if logIdx >= len(mockClient.log) { t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) } loggedRequest := mockClient.log[logIdx] logIdx++ if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 { t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff) } }) } logCnt := len(mockClient.log) if logIdx != logCnt { log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx) } } // This test runs the same test cases as above, but with // a client that implements EnableForcePush -- this allows // us to test that -force continues to work for backends without // this interface, but that this interface works for those that do. func TestWriteStateForMigrationWithForcePushClient(t *testing.T) { mgr := NewState( &mockClientForcePusher{ current: []byte(` { "version": 4, "lineage": "mock-lineage", "serial": 3, "terraform_version":"0.0.0", "outputs": {"foo": {"value":"bar", "type": "string"}}, "resources": [] } `), }, encryption.StateEncryptionDisabled(), ) testCases := []migrationTestCase{ // Refreshing state before we run the test loop causes a GET { name: "refresh state", stateFile: func(mgr *State) *statefile.File { return mgr.StateForMigration() }, expectedRequest: mockClientRequest{ Method: "Get", Content: map[string]interface{}{ "version": 4.0, "lineage": "mock-lineage", "serial": 3.0, "terraform_version": "0.0.0", "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, "resources": []interface{}{}, }, }, }, { name: "cannot import lesser serial without force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, mgr.lineage, 1) }, expectedError: "cannot import state with serial 1 over newer state with serial 3", }, { name: "cannot import differing lineage without force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, "different-lineage", mgr.serial) }, expectedError: `cannot import state with lineage "different-lineage" over unrelated state with lineage "mock-lineage"`, }, { name: "can import lesser serial with force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, mgr.lineage, 1) }, expectedRequest: mockClientRequest{ Method: "Force Put", Content: map[string]interface{}{ "version": 4.0, "lineage": "mock-lineage", "serial": 2.0, "terraform_version": version.Version, "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, "resources": []interface{}{}, "check_results": nil, }, }, force: true, }, { name: "cannot import differing lineage without force", stateFile: func(mgr *State) *statefile.File { return statefile.New(mgr.state, "different-lineage", mgr.serial) }, expectedRequest: mockClientRequest{ Method: "Force Put", Content: map[string]interface{}{ "version": 4.0, "lineage": "different-lineage", "serial": 3.0, "terraform_version": version.Version, "outputs": map[string]interface{}{"foo": map[string]interface{}{"type": string("string"), "value": string("bar")}}, "resources": []interface{}{}, "check_results": nil, }, }, force: true, }, } // In normal use (during a OpenTofu operation) we always refresh and read // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this if err := mgr.RefreshState(t.Context()); err != nil { t.Fatalf("failed to RefreshState: %s", err) } if err := mgr.WriteState(mgr.State()); err != nil { t.Fatalf("failed to write initial state: %s", err) } // Our client is a mockClientForcePusher which has a log we // use to check that operations generate expected requests mockClient := mgr.Client.(*mockClientForcePusher) if mockClient.force { t.Fatalf("client should not default to force") } // logIdx tracks the current index of the log separate from // the loop iteration so we can check operations that don't // cause any requests to be generated logIdx := 0 for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Always reset client to not be force pushing mockClient.force = false sf := tc.stateFile(mgr) err := mgr.WriteStateForMigration(sf, tc.force) shouldError := tc.expectedError != "" // If we are expecting and error check it and move on if shouldError { if err == nil { t.Fatalf("test case %q should have failed with error %q", tc.name, tc.expectedError) } else if err.Error() != tc.expectedError { t.Fatalf("test case %q expected error %q but got %q", tc.name, tc.expectedError, err) } return } if err != nil { t.Fatalf("test case %q failed: %v", tc.name, err) } if tc.force && !mockClient.force { t.Fatalf("test case %q should have enabled force push", tc.name) } // At this point we should just do a normal write and persist // as would happen from the CLI if err := mgr.WriteState(mgr.State()); err != nil { t.Fatal(err) } if err := mgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } if logIdx >= len(mockClient.log) { t.Fatalf("request lock and index are out of sync on %q: idx=%d len=%d", tc.name, logIdx, len(mockClient.log)) } loggedRequest := mockClient.log[logIdx] logIdx++ if diff := cmp.Diff(tc.expectedRequest, loggedRequest); len(diff) > 0 { t.Fatalf("incorrect client requests for %q:\n%s", tc.name, diff) } }) } logCnt := len(mockClient.log) if logIdx != logCnt { log.Fatalf("not all requests were read. Expected logIdx to be %d but got %d", logCnt, logIdx) } } // mockOptionalClientLocker is a mock implementation of a client that supports optional locking. type mockOptionalClientLocker struct { *mockClient // Embedded mock client that simulates basic client behavior. lockingEnabled bool // A flag indicating whether locking is enabled or disabled. } type mockClientLocker struct { *mockClient // Embedded mock client that simulates basic client behavior. } // Implement the mock Lock method for mockOptionalClientLocker func (c *mockOptionalClientLocker) Lock(_ context.Context, _ *statemgr.LockInfo) (string, error) { return "", nil } // Implement the mock Unlock method for mockOptionalClientLocker func (c *mockOptionalClientLocker) Unlock(_ context.Context, _ string) error { // Provide a simple implementation return nil } // Implement the mock IsLockingEnabled method for mockOptionalClientLocker func (c *mockOptionalClientLocker) IsLockingEnabled() bool { return c.lockingEnabled } // Implement the mock Lock method for mockClientLocker func (c *mockClientLocker) Lock(_ context.Context, _ *statemgr.LockInfo) (string, error) { return "", nil } // Implement the mock Unlock method for mockClientLocker func (c *mockClientLocker) Unlock(_ context.Context, _ string) error { return nil } // Check for interface compliance var _ OptionalClientLocker = &mockOptionalClientLocker{} var _ ClientLocker = &mockClientLocker{} // Tests whether the IsLockingEnabled method returns the expected values based on the backend. func TestState_IsLockingEnabled(t *testing.T) { tests := []struct { name string disableLocks bool client Client wantResult bool }{ { name: "disableLocks is true", disableLocks: true, client: &mockClient{}, wantResult: false, }, { name: "OptionalClientLocker with IsLockingEnabled() == true", disableLocks: false, client: &mockOptionalClientLocker{ mockClient: &mockClient{}, lockingEnabled: true, }, wantResult: true, }, { name: "OptionalClientLocker with IsLockingEnabled() == false", disableLocks: false, client: &mockOptionalClientLocker{ mockClient: &mockClient{}, lockingEnabled: false, }, wantResult: false, }, { name: "ClientLocker without OptionalClientLocker", disableLocks: false, client: &mockClientLocker{ mockClient: &mockClient{}, }, wantResult: true, }, { name: "Client without any locking", disableLocks: false, client: &mockClient{}, wantResult: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewState(tt.client, encryption.StateEncryptionDisabled()) s.disableLocks = tt.disableLocks gotResult := s.IsLockingEnabled() if gotResult != tt.wantResult { t.Errorf("IsLockingEnabled() = %v; want %v", gotResult, tt.wantResult) } }) } }