// Copyright (c) The OpenTofu Authors // SPDX-License-Identifier: MPL-2.0 // Copyright (c) 2023 HashiCorp, Inc. // SPDX-License-Identifier: MPL-2.0 package local import ( "context" "fmt" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/opentofu/opentofu/internal/addrs" "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/states/statemgr" "github.com/opentofu/opentofu/internal/tofu" "github.com/zclconf/go-cty/cty" ) func TestStateHook_impl(t *testing.T) { var _ tofu.Hook = new(StateHook) } func stateHookExpected() *states.State { expected := states.NewState() expected.RootModule().SetOutputValue("sensitive_output", cty.StringVal("it's a secret"), true, "") return expected } func stateHookMutator(state *states.SyncState) { state.SetOutputValue(addrs.AbsOutputValue{OutputValue: addrs.OutputValue{Name: "sensitive_output"}}, cty.StringVal("it's a secret"), true, "") } func TestStateHook(t *testing.T) { is := statemgr.NewTransientInMemory(nil) var hook tofu.Hook = &StateHook{ StateMgr: is, } action, err := hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("err: %s", err) } if action != tofu.HookActionContinue { t.Fatalf("bad: %v", action) } if !is.State().Equal(stateHookExpected()) { t.Fatalf("bad state: %#v", is.State()) } } func TestStateHookStopping(t *testing.T) { is := &testPersistentState{} hook := &StateHook{ StateMgr: is, Schemas: &tofu.Schemas{}, PersistInterval: 4 * time.Hour, intermediatePersist: IntermediateStatePersistInfo{ LastPersist: time.Now(), }, } s := stateHookExpected() action, err := hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if got, want := action, tofu.HookActionContinue; got != want { t.Fatalf("wrong hookaction %#v; want %#v", got, want) } if is.Written == nil || !is.Written.Equal(s) { t.Fatalf("mismatching state written") } if is.Persisted != nil { t.Fatalf("persisted too soon") } // We'll now force lastPersist to be long enough ago that persisting // should be due on the next call. hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour) _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Written == nil || !is.Written.Equal(s) { t.Fatalf("mismatching state written") } if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Written == nil || !is.Written.Equal(s) { t.Fatalf("mismatching state written") } if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } gotLog := is.CallLog wantLog := []string{ // Initial call before we reset lastPersist "MutateState", // Write and then persist after we reset lastPersist "MutateState", "PersistState", // Final call when persisting wasn't due yet. "MutateState", } if diff := cmp.Diff(wantLog, gotLog); diff != "" { t.Fatalf("wrong call log so far\n%s", diff) } // We'll reset the log now before we try seeing what happens after // we use "Stopped". is.CallLog = is.CallLog[:0] is.Persisted = nil hook.Stopping() if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } is.Persisted = nil _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } is.Persisted = nil _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } gotLog = is.CallLog wantLog = []string{ // "Stopping" immediately persisted "PersistState", // PostStateUpdate then writes and persists on every call, // on the assumption that we're now bailing out after // being cancelled and trying to save as much state as we can. "MutateState", "PersistState", "MutateState", "PersistState", } if diff := cmp.Diff(wantLog, gotLog); diff != "" { t.Fatalf("wrong call log once in stopping mode\n%s", diff) } } func TestStateHookCustomPersistRule(t *testing.T) { is := &testPersistentStateThatRefusesToPersist{} hook := &StateHook{ StateMgr: is, Schemas: &tofu.Schemas{}, PersistInterval: 4 * time.Hour, intermediatePersist: IntermediateStatePersistInfo{ LastPersist: time.Now(), }, } s := stateHookExpected() action, err := hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if got, want := action, tofu.HookActionContinue; got != want { t.Fatalf("wrong hookaction %#v; want %#v", got, want) } if is.Written == nil || !is.Written.Equal(s) { t.Fatalf("mismatching state written") } if is.Persisted != nil { t.Fatalf("persisted too soon") } // We'll now force lastPersist to be long enough ago that persisting // should be due on the next call. hook.intermediatePersist.LastPersist = time.Now().Add(-5 * time.Hour) _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Written == nil || !is.Written.Equal(s) { t.Fatalf("mismatching state written") } if is.Persisted != nil { t.Fatalf("has a persisted state, but shouldn't") } _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Written == nil || !is.Written.Equal(s) { t.Fatalf("mismatching state written") } if is.Persisted != nil { t.Fatalf("has a persisted state, but shouldn't") } gotLog := is.CallLog wantLog := []string{ // Initial call before we reset lastPersist "MutateState", "ShouldPersistIntermediateState", // Previous call should return false, preventing a "PersistState" call // Write and then decline to persist "MutateState", "ShouldPersistIntermediateState", // Previous call should return false, preventing a "PersistState" call // Final call before we start "stopping". "MutateState", "ShouldPersistIntermediateState", // Previous call should return false, preventing a "PersistState" call } if diff := cmp.Diff(wantLog, gotLog); diff != "" { t.Fatalf("wrong call log so far\n%s", diff) } // We'll reset the log now before we try seeing what happens after // we use "Stopped". is.CallLog = is.CallLog[:0] is.Persisted = nil hook.Stopping() if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } is.Persisted = nil _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } is.Persisted = nil _, err = hook.PostStateUpdate(stateHookMutator) if err != nil { t.Fatalf("unexpected error from PostStateUpdate: %s", err) } if is.Persisted == nil || !is.Persisted.Equal(s) { t.Fatalf("mismatching state persisted") } gotLog = is.CallLog wantLog = []string{ "ShouldPersistIntermediateState", // Previous call should return true, allowing the following "PersistState" call "PersistState", "MutateState", "ShouldPersistIntermediateState", // Previous call should return true, allowing the following "PersistState" call "PersistState", "MutateState", "ShouldPersistIntermediateState", // Previous call should return true, allowing the following "PersistState" call "PersistState", } if diff := cmp.Diff(wantLog, gotLog); diff != "" { t.Fatalf("wrong call log once in stopping mode\n%s", diff) } } type testPersistentState struct { CallLog []string Written *states.State Persisted *states.State } var _ statemgr.Writer = (*testPersistentState)(nil) var _ statemgr.Persister = (*testPersistentState)(nil) func (sm *testPersistentState) WriteState(state *states.State) error { sm.CallLog = append(sm.CallLog, "WriteState") sm.Written = state return nil } func (sm *testPersistentState) MutateState(fn func(*states.State) *states.State) error { sm.CallLog = append(sm.CallLog, "MutateState") sm.Written = fn(sm.Written) return nil } func (sm *testPersistentState) PersistState(_ context.Context, schemas *tofu.Schemas) error { if schemas == nil { return fmt.Errorf("no schemas") } sm.CallLog = append(sm.CallLog, "PersistState") sm.Persisted = sm.Written return nil } type testPersistentStateThatRefusesToPersist struct { CallLog []string Written *states.State Persisted *states.State } var _ statemgr.Writer = (*testPersistentStateThatRefusesToPersist)(nil) var _ statemgr.Persister = (*testPersistentStateThatRefusesToPersist)(nil) var _ IntermediateStateConditionalPersister = (*testPersistentStateThatRefusesToPersist)(nil) func (sm *testPersistentStateThatRefusesToPersist) WriteState(state *states.State) error { sm.CallLog = append(sm.CallLog, "WriteState") sm.Written = state return nil } func (sm *testPersistentStateThatRefusesToPersist) MutateState(fn func(*states.State) *states.State) error { sm.CallLog = append(sm.CallLog, "MutateState") sm.Written = fn(sm.Written) return nil } func (sm *testPersistentStateThatRefusesToPersist) PersistState(_ context.Context, schemas *tofu.Schemas) error { if schemas == nil { return fmt.Errorf("no schemas") } sm.CallLog = append(sm.CallLog, "PersistState") sm.Persisted = sm.Written return nil } // ShouldPersistIntermediateState implements IntermediateStateConditionalPersister func (sm *testPersistentStateThatRefusesToPersist) ShouldPersistIntermediateState(info *IntermediateStatePersistInfo) bool { sm.CallLog = append(sm.CallLog, "ShouldPersistIntermediateState") return info.ForcePersist }