diff --git a/internal/backend/local/backend.go b/internal/backend/local/backend.go index e4901000be..f83f151591 100644 --- a/internal/backend/local/backend.go +++ b/internal/backend/local/backend.go @@ -365,7 +365,7 @@ func (b *Local) opWait( // try to force a PersistState just in case the process is terminated // before we can complete. - if err := opStateMgr.PersistState(nil); err != nil { + if err := opStateMgr.PersistState(context.TODO(), nil); err != nil { // We can't error out from here, but warn the user if there was an error. // If this isn't transient, we will catch it again below, and // attempt to save the state another way. diff --git a/internal/backend/local/backend_apply.go b/internal/backend/local/backend_apply.go index e6f22e518a..9d74e87122 100644 --- a/internal/backend/local/backend_apply.go +++ b/internal/backend/local/backend_apply.go @@ -288,7 +288,7 @@ func (b *Local) opApply( // Store the final state runningOp.State = applyState - err := statemgr.WriteAndPersist(opState, applyState, schemas) + err := statemgr.WriteAndPersist(context.TODO(), opState, applyState, schemas) if err != nil { // Export the state file from the state manager and assign the new // state. This is needed to preserve the existing serial and lineage. diff --git a/internal/backend/local/backend_local.go b/internal/backend/local/backend_local.go index 41f09f8893..73793703c4 100644 --- a/internal/backend/local/backend_local.go +++ b/internal/backend/local/backend_local.go @@ -68,7 +68,7 @@ func (b *Local) localRun(ctx context.Context, op *backend.Operation) (*backend.L }() log.Printf("[TRACE] backend/local: reading remote state for workspace %q", op.Workspace) - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf("error loading state: %w", err)) return nil, nil, nil, diags } diff --git a/internal/backend/local/backend_local_test.go b/internal/backend/local/backend_local_test.go index 62d58a28c2..43021d39dc 100644 --- a/internal/backend/local/backend_local_test.go +++ b/internal/backend/local/backend_local_test.go @@ -142,7 +142,7 @@ func TestLocalRun_stalePlan(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := sm.RefreshState(); err != nil { + if err := sm.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error refreshing state: %s", err) } @@ -242,7 +242,7 @@ type stateStorageThatFailsRefresh struct { var _ statemgr.Full = (*stateStorageThatFailsRefresh)(nil) -func (s *stateStorageThatFailsRefresh) Lock(info *statemgr.LockInfo) (string, error) { +func (s *stateStorageThatFailsRefresh) Lock(_ context.Context, info *statemgr.LockInfo) (string, error) { if s.locked { return "", fmt.Errorf("already locked") } @@ -250,7 +250,7 @@ func (s *stateStorageThatFailsRefresh) Lock(info *statemgr.LockInfo) (string, er return "locked", nil } -func (s *stateStorageThatFailsRefresh) Unlock(id string) error { +func (s *stateStorageThatFailsRefresh) Unlock(_ context.Context, id string) error { if !s.locked { return fmt.Errorf("not locked") } @@ -262,7 +262,7 @@ func (s *stateStorageThatFailsRefresh) State() *states.State { return nil } -func (s *stateStorageThatFailsRefresh) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (s *stateStorageThatFailsRefresh) GetRootOutputValues(_ context.Context) (map[string]*states.OutputValue, error) { return nil, fmt.Errorf("unimplemented") } @@ -270,10 +270,10 @@ func (s *stateStorageThatFailsRefresh) WriteState(*states.State) error { return fmt.Errorf("unimplemented") } -func (s *stateStorageThatFailsRefresh) RefreshState() error { +func (s *stateStorageThatFailsRefresh) RefreshState(_ context.Context) error { return fmt.Errorf("intentionally failing for testing purposes") } -func (s *stateStorageThatFailsRefresh) PersistState(schemas *tofu.Schemas) error { +func (s *stateStorageThatFailsRefresh) PersistState(_ context.Context, schemas *tofu.Schemas) error { return fmt.Errorf("unimplemented") } diff --git a/internal/backend/local/backend_refresh.go b/internal/backend/local/backend_refresh.go index d3047738fd..786b5ef9c7 100644 --- a/internal/backend/local/backend_refresh.go +++ b/internal/backend/local/backend_refresh.go @@ -119,7 +119,7 @@ func (b *Local) opRefresh( return } - err := statemgr.WriteAndPersist(opState, newState, schemas) + err := statemgr.WriteAndPersist(context.TODO(), opState, newState, schemas) if err != nil { diags = diags.Append(fmt.Errorf("failed to write state: %w", err)) op.ReportResult(runningOp, diags) diff --git a/internal/backend/local/hook_state.go b/internal/backend/local/hook_state.go index 8c210fa164..e94fc86867 100644 --- a/internal/backend/local/hook_state.go +++ b/internal/backend/local/hook_state.go @@ -6,6 +6,7 @@ package local import ( + "context" "log" "sync" "time" @@ -81,7 +82,7 @@ func (h *StateHook) PostStateUpdate(new *states.State) (tofu.HookAction, error) } if mgrPersist, ok := h.StateMgr.(statemgr.Persister); ok && h.PersistInterval != 0 && h.Schemas != nil { if h.shouldPersist() { - err := mgrPersist.PersistState(h.Schemas) + err := mgrPersist.PersistState(context.TODO(), h.Schemas) if err != nil { return tofu.HookActionHalt, err } @@ -115,7 +116,7 @@ func (h *StateHook) Stopping() { h.intermediatePersist.ForcePersist = true if h.shouldPersist() { - err := mgrPersist.PersistState(h.Schemas) + err := mgrPersist.PersistState(context.TODO(), h.Schemas) if err != nil { // This hook can't affect OpenTofu Core's ongoing behavior, // but it's a best effort thing anyway, so we'll just emit a diff --git a/internal/backend/local/hook_state_test.go b/internal/backend/local/hook_state_test.go index 3ae87332a9..69d4f53ac3 100644 --- a/internal/backend/local/hook_state_test.go +++ b/internal/backend/local/hook_state_test.go @@ -6,6 +6,7 @@ package local import ( + "context" "fmt" "testing" "time" @@ -281,7 +282,7 @@ func (sm *testPersistentState) WriteState(state *states.State) error { return nil } -func (sm *testPersistentState) PersistState(schemas *tofu.Schemas) error { +func (sm *testPersistentState) PersistState(_ context.Context, schemas *tofu.Schemas) error { if schemas == nil { return fmt.Errorf("no schemas") } @@ -307,7 +308,7 @@ func (sm *testPersistentStateThatRefusesToPersist) WriteState(state *states.Stat return nil } -func (sm *testPersistentStateThatRefusesToPersist) PersistState(schemas *tofu.Schemas) error { +func (sm *testPersistentStateThatRefusesToPersist) PersistState(_ context.Context, schemas *tofu.Schemas) error { if schemas == nil { return fmt.Errorf("no schemas") } diff --git a/internal/backend/local/testing.go b/internal/backend/local/testing.go index a5cf30604b..91c6fe9e7c 100644 --- a/internal/backend/local/testing.go +++ b/internal/backend/local/testing.go @@ -184,7 +184,7 @@ func (b *TestLocalNoDefaultState) StateMgr(ctx context.Context, name string) (st func testStateFile(t *testing.T, path string, s *states.State) { t.Helper() - if err := statemgr.WriteAndPersist(statemgr.NewFilesystem(path, encryption.StateEncryptionDisabled()), s, nil); err != nil { + if err := statemgr.WriteAndPersist(t.Context(), statemgr.NewFilesystem(path, encryption.StateEncryptionDisabled()), s, nil); err != nil { t.Fatal(err) } } @@ -211,7 +211,7 @@ func mustResourceInstanceAddr(s string) addrs.AbsResourceInstance { func assertBackendStateUnlocked(t *testing.T, b *Local) bool { t.Helper() stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Errorf("state is already locked: %s", err.Error()) // lock was obtained return false @@ -226,7 +226,7 @@ func assertBackendStateUnlocked(t *testing.T, b *Local) bool { func assertBackendStateLocked(t *testing.T, b *Local) bool { t.Helper() stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { // lock was not obtained return true } diff --git a/internal/backend/remote-state/azure/backend_state.go b/internal/backend/remote-state/azure/backend_state.go index 10031dda8d..11ba1ab319 100644 --- a/internal/backend/remote-state/azure/backend_state.go +++ b/internal/backend/remote-state/azure/backend_state.go @@ -92,7 +92,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err stateMgr := remote.NewState(client, b.encryption) // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { return nil, err } //if this isn't the default state name, we need to create the object so @@ -101,21 +101,21 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err // take a lock on this state while we write it lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockId, err := client.Lock(lockInfo) + lockId, err := client.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("failed to lock azure state: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(parent error) error { - if err := stateMgr.Unlock(lockId); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockId); err != nil { return fmt.Errorf(strings.TrimSpace(errStateUnlock), lockId, err) } return parent } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { err = lockUnlock(err) return nil, err } @@ -127,7 +127,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/azure/client.go b/internal/backend/remote-state/azure/client.go index 1f5baa9928..5dfa04493d 100644 --- a/internal/backend/remote-state/azure/client.go +++ b/internal/backend/remote-state/azure/client.go @@ -36,9 +36,9 @@ type RemoteClient struct { timeoutSeconds int } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(ctx context.Context) (*remote.Payload, error) { // Get should time out after the timeoutSeconds - ctx, ctxCancel := c.getContextWithTimeout() + ctx, ctxCancel := c.getContextWithTimeout(ctx) defer ctxCancel() blob, err := c.giovanniBlobClient.Get(ctx, c.accountName, c.containerName, c.keyName, blobs.GetInput{LeaseID: c.leaseID}) if err != nil { @@ -60,8 +60,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { return payload, nil } -func (c *RemoteClient) Put(data []byte) error { - ctx := context.TODO() +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { if c.snapshot { snapshotInput := blobs.SnapshotInput{LeaseID: c.leaseID} log.Printf("[DEBUG] Snapshotting existing Blob %q (Container %q / Account %q)", c.keyName, c.containerName, c.accountName) @@ -72,7 +71,7 @@ func (c *RemoteClient) Put(data []byte) error { log.Print("[DEBUG] Created blob snapshot") } - properties, err := c.getBlobProperties() + properties, err := c.getBlobProperties(ctx) if err != nil { if properties.StatusCode != http.StatusNotFound { return err @@ -91,8 +90,7 @@ func (c *RemoteClient) Put(data []byte) error { return err } -func (c *RemoteClient) Delete() error { - ctx := context.TODO() +func (c *RemoteClient) Delete(ctx context.Context) error { resp, err := c.giovanniBlobClient.Delete(ctx, c.accountName, c.containerName, c.keyName, blobs.DeleteInput{LeaseID: c.leaseID}) if err != nil { if !resp.IsHTTPStatus(http.StatusNotFound) { @@ -102,7 +100,7 @@ func (c *RemoteClient) Delete() error { return nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *RemoteClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { stateName := fmt.Sprintf("%s/%s", c.containerName, c.keyName) info.Path = stateName @@ -116,7 +114,7 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { } getLockInfoErr := func(err error) error { - lockInfo, infoErr := c.getLockInfo() + lockInfo, infoErr := c.getLockInfo(ctx) if infoErr != nil { err = multierror.Append(err, infoErr) } @@ -131,10 +129,9 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { ProposedLeaseID: &info.ID, LeaseDuration: -1, } - ctx := context.TODO() // obtain properties to see if the blob lease is already in use. If the blob doesn't exist, create it - properties, err := c.getBlobProperties() + properties, err := c.getBlobProperties(ctx) if err != nil { // error if we had issues getting the blob if !properties.Response.IsHTTPStatus(http.StatusNotFound) { @@ -165,15 +162,15 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { info.ID = leaseID.LeaseID c.setLeaseID(leaseID.LeaseID) - if err := c.writeLockInfo(info); err != nil { + if err := c.writeLockInfo(ctx, info); err != nil { return "", err } return info.ID, nil } -func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) { - properties, err := c.getBlobProperties() +func (c *RemoteClient) getLockInfo(ctx context.Context) (*statemgr.LockInfo, error) { + properties, err := c.getBlobProperties(ctx) if err != nil { return nil, err } @@ -198,9 +195,8 @@ func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) { } // writes info to blob meta data, deletes metadata entry if info is nil -func (c *RemoteClient) writeLockInfo(info *statemgr.LockInfo) error { - ctx := context.TODO() - properties, err := c.getBlobProperties() +func (c *RemoteClient) writeLockInfo(ctx context.Context, info *statemgr.LockInfo) error { + properties, err := c.getBlobProperties(ctx) if err != nil { return err } @@ -221,10 +217,10 @@ func (c *RemoteClient) writeLockInfo(info *statemgr.LockInfo) error { return err } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(ctx context.Context, id string) error { lockErr := &statemgr.LockError{} - lockInfo, err := c.getLockInfo() + lockInfo, err := c.getLockInfo(ctx) if err != nil { lockErr.Err = fmt.Errorf("failed to retrieve lock info: %w", err) return lockErr @@ -237,12 +233,11 @@ func (c *RemoteClient) Unlock(id string) error { } c.setLeaseID(lockInfo.ID) - if err := c.writeLockInfo(nil); err != nil { + if err := c.writeLockInfo(ctx, nil); err != nil { lockErr.Err = fmt.Errorf("failed to delete lock info from metadata: %w", err) return lockErr } - ctx := context.TODO() _, err = c.giovanniBlobClient.ReleaseLease(ctx, c.accountName, c.containerName, c.keyName, id) if err != nil { lockErr.Err = err @@ -255,15 +250,15 @@ func (c *RemoteClient) Unlock(id string) error { } // getBlobProperties wraps the GetProperties method of the giovanniBlobClient with timeout -func (c *RemoteClient) getBlobProperties() (blobs.GetPropertiesResult, error) { - ctx, ctxCancel := c.getContextWithTimeout() +func (c *RemoteClient) getBlobProperties(ctx context.Context) (blobs.GetPropertiesResult, error) { + ctx, ctxCancel := c.getContextWithTimeout(ctx) defer ctxCancel() return c.giovanniBlobClient.GetProperties(ctx, c.accountName, c.containerName, c.keyName, blobs.GetPropertiesInput{LeaseID: c.leaseID}) } // getContextWithTimeout returns a context with timeout based on the timeoutSeconds -func (c *RemoteClient) getContextWithTimeout() (context.Context, context.CancelFunc) { - return context.WithTimeout(context.Background(), time.Duration(c.timeoutSeconds)*time.Second) +func (c *RemoteClient) getContextWithTimeout(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, time.Duration(c.timeoutSeconds)*time.Second) } // setLeaseID takes a string leaseID and sets the leaseID field of the RemoteClient diff --git a/internal/backend/remote-state/azure/client_test.go b/internal/backend/remote-state/azure/client_test.go index 7f6652c6d2..167a483e23 100644 --- a/internal/backend/remote-state/azure/client_test.go +++ b/internal/backend/remote-state/azure/client_test.go @@ -320,7 +320,7 @@ func TestPutMaintainsMetaData(t *testing.T) { } bytes := []byte(acctest.RandString(20)) - err = remoteClient.Put(bytes) + err = remoteClient.Put(t.Context(), bytes) if err != nil { t.Fatalf("Error putting data: %+v", err) } diff --git a/internal/backend/remote-state/consul/backend_state.go b/internal/backend/remote-state/consul/backend_state.go index e698e222fb..6d440c8a6c 100644 --- a/internal/backend/remote-state/consul/backend_state.go +++ b/internal/backend/remote-state/consul/backend_state.go @@ -101,14 +101,14 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error // so States() knows it exists. lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockId, err := stateMgr.Lock(lockInfo) + lockId, err := stateMgr.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("failed to lock state in Consul: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(parent error) error { - if err := stateMgr.Unlock(lockId); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockId); err != nil { return fmt.Errorf(strings.TrimSpace(errStateUnlock), lockId, err) } @@ -116,7 +116,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { err = lockUnlock(err) return nil, err } @@ -127,7 +127,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/consul/client.go b/internal/backend/remote-state/consul/client.go index 86ca2aa0bd..7b878b7c95 100644 --- a/internal/backend/remote-state/consul/client.go +++ b/internal/backend/remote-state/consul/client.go @@ -73,7 +73,7 @@ type RemoteClient struct { sessionCancel context.CancelFunc } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(_ context.Context) (*remote.Payload, error) { c.mu.Lock() defer c.mu.Unlock() @@ -125,7 +125,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { }, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { // The state can be stored in 4 different ways, based on the payload size // and whether the user enabled gzip: // - single entry mode with plain JSON: a single JSON is stored at @@ -295,7 +295,7 @@ func (c *RemoteClient) Put(data []byte) error { return store(payload) } -func (c *RemoteClient) Delete() error { +func (c *RemoteClient) Delete(_ context.Context) error { c.mu.Lock() defer c.mu.Unlock() @@ -361,7 +361,7 @@ func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) { return li, nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *RemoteClient) Lock(_ context.Context, info *statemgr.LockInfo) (string, error) { c.mu.Lock() defer c.mu.Unlock() @@ -551,7 +551,7 @@ func (c *RemoteClient) createSession() (string, error) { return id, nil } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(_ context.Context, id string) error { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/backend/remote-state/consul/client_test.go b/internal/backend/remote-state/consul/client_test.go index f5c6227a9e..6dcdb97e3a 100644 --- a/internal/backend/remote-state/consul/client_test.go +++ b/internal/backend/remote-state/consul/client_test.go @@ -141,12 +141,12 @@ func TestConsul_largeState(t *testing.T) { if err != nil { t.Fatal(err) } - err = c.Put(payload) + err = c.Put(t.Context(), payload) if err != nil { t.Fatal("could not put payload", err) } - remote, err := c.Get() + remote, err := c.Get(t.Context()) if err != nil { t.Fatal(err) } @@ -235,7 +235,7 @@ func TestConsul_largeState(t *testing.T) { ) // Deleting the state should remove all chunks - err = c.Delete() + err = c.Delete(t.Context()) if err != nil { t.Fatal(err) } @@ -312,14 +312,14 @@ func TestConsul_destroyLock(t *testing.T) { clientA := s.(*remote.State).Client.(*RemoteClient) info := statemgr.NewLockInfo() - id, err := clientA.Lock(info) + id, err := clientA.Lock(t.Context(), info) if err != nil { t.Fatal(err) } lockPath := clientA.Path + lockSuffix - if err := clientA.Unlock(id); err != nil { + if err := clientA.Unlock(t.Context(), id); err != nil { t.Fatal(err) } @@ -335,18 +335,18 @@ func TestConsul_destroyLock(t *testing.T) { clientB := s.(*remote.State).Client.(*RemoteClient) info = statemgr.NewLockInfo() - id, err = clientA.Lock(info) + id, err = clientA.Lock(t.Context(), info) if err != nil { t.Fatal(err) } - if err := clientB.Unlock(id); err != nil { + if err := clientB.Unlock(t.Context(), id); err != nil { t.Fatal(err) } testLock(clientA, lockPath) - err = clientA.Unlock(id) + err = clientA.Unlock(t.Context(), id) if err == nil { t.Fatal("consul lock should have been lost") @@ -383,7 +383,7 @@ func TestConsul_lostLock(t *testing.T) { info := statemgr.NewLockInfo() info.Operation = "test-lost-lock" - id, err := sA.Lock(info) + id, err := sA.Lock(t.Context(), info) if err != nil { t.Fatal(err) } @@ -403,7 +403,7 @@ func TestConsul_lostLock(t *testing.T) { <-reLocked - if err := sA.Unlock(id); err != nil { + if err := sA.Unlock(t.Context(), id); err != nil { t.Fatal(err) } } @@ -435,7 +435,7 @@ func TestConsul_lostLockConnection(t *testing.T) { info := statemgr.NewLockInfo() info.Operation = "test-lost-lock-connection" - id, err := s.Lock(info) + id, err := s.Lock(t.Context(), info) if err != nil { t.Fatal(err) } @@ -449,7 +449,7 @@ func TestConsul_lostLockConnection(t *testing.T) { <-dialed } - if err := s.Unlock(id); err != nil { + if err := s.Unlock(t.Context(), id); err != nil { t.Fatal("unlock error:", err) } } diff --git a/internal/backend/remote-state/cos/backend.go b/internal/backend/remote-state/cos/backend.go index 1edc5c4513..a9aef282e5 100644 --- a/internal/backend/remote-state/cos/backend.go +++ b/internal/backend/remote-state/cos/backend.go @@ -42,10 +42,9 @@ type Backend struct { encryption encryption.StateEncryption credential *common.Credential - cosContext context.Context - cosClient *cos.Client - tagClient *tag.Client - stsClient *sts.Client + cosClient *cos.Client + tagClient *tag.Client + stsClient *sts.Client region string bucket string @@ -211,8 +210,7 @@ func (b *Backend) configure(ctx context.Context) error { return nil } - b.cosContext = ctx - data := schema.FromContextBackendConfig(b.cosContext) + data := schema.FromContextBackendConfig(ctx) b.region = data.Get("region").(string) b.bucket = data.Get("bucket").(string) diff --git a/internal/backend/remote-state/cos/backend_state.go b/internal/backend/remote-state/cos/backend_state.go index c365dea81b..ad54bd1b19 100644 --- a/internal/backend/remote-state/cos/backend_state.go +++ b/internal/backend/remote-state/cos/backend_state.go @@ -26,13 +26,13 @@ const ( ) // Workspaces returns a list of names for the workspaces -func (b *Backend) Workspaces(context.Context) ([]string, error) { +func (b *Backend) Workspaces(ctx context.Context) ([]string, error) { c, err := b.client("tencentcloud") if err != nil { return nil, err } - obs, err := c.getBucket(b.prefix) + obs, err := c.getBucket(ctx, b.prefix) log.Printf("[DEBUG] list all workspaces, objects: %v, error: %v", obs, err) if err != nil { return nil, err @@ -63,7 +63,7 @@ func (b *Backend) Workspaces(context.Context) ([]string, error) { } // DeleteWorkspace deletes the named workspaces. The "default" state cannot be deleted. -func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error { +func (b *Backend) DeleteWorkspace(ctx context.Context, name string, _ bool) error { log.Printf("[DEBUG] delete workspace, workspace: %v", name) if name == backend.DefaultStateName || name == "" { @@ -75,7 +75,7 @@ func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error return err } - return c.Delete() + return c.Delete(ctx) } // StateMgr manage the state, if the named state not exists, a new file will created @@ -107,21 +107,21 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err // take a lock on this state while we write it lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockId, err := c.Lock(lockInfo) + lockId, err := c.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("Failed to lock cos state: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(e error) error { - if err := stateMgr.Unlock(lockId); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockId); err != nil { return fmt.Errorf(unlockErrMsg, err, lockId) } return e } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { err = lockUnlock(err) return nil, err } @@ -132,7 +132,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } @@ -154,14 +154,13 @@ func (b *Backend) client(name string) (*remoteClient, error) { } return &remoteClient{ - cosContext: b.cosContext, - cosClient: b.cosClient, - tagClient: b.tagClient, - bucket: b.bucket, - stateFile: b.stateFile(name), - lockFile: b.lockFile(name), - encrypt: b.encrypt, - acl: b.acl, + cosClient: b.cosClient, + tagClient: b.tagClient, + bucket: b.bucket, + stateFile: b.stateFile(name), + lockFile: b.lockFile(name), + encrypt: b.encrypt, + acl: b.acl, }, nil } diff --git a/internal/backend/remote-state/cos/backend_test.go b/internal/backend/remote-state/cos/backend_test.go index 78d8fd731f..7cbfc11690 100644 --- a/internal/backend/remote-state/cos/backend_test.go +++ b/internal/backend/remote-state/cos/backend_test.go @@ -234,7 +234,7 @@ func setupBackend(t *testing.T, bucket, prefix, key string, encrypt bool) backen t.Fatalf("unexpected error: %s", err) } - err = c.putBucket() + err = c.putBucket(t.Context()) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -250,7 +250,7 @@ func teardownBackend(t *testing.T, b backend.Backend) { t.Fatalf("unexpected error: %s", err) } - err = c.deleteBucket(true) + err = c.deleteBucket(t.Context(), true) if err != nil { t.Fatalf("unexpected error: %s", err) } diff --git a/internal/backend/remote-state/cos/client.go b/internal/backend/remote-state/cos/client.go index e96f37508b..b3e8e8ed9c 100644 --- a/internal/backend/remote-state/cos/client.go +++ b/internal/backend/remote-state/cos/client.go @@ -31,9 +31,8 @@ const ( // RemoteClient implements the client of remote state type remoteClient struct { - cosContext context.Context - cosClient *cos.Client - tagClient *tag.Client + cosClient *cos.Client + tagClient *tag.Client bucket string stateFile string @@ -43,10 +42,10 @@ type remoteClient struct { } // Get returns remote state file -func (c *remoteClient) Get() (*remote.Payload, error) { +func (c *remoteClient) Get(ctx context.Context) (*remote.Payload, error) { log.Printf("[DEBUG] get remote state file %s", c.stateFile) - exists, data, checksum, err := c.getObject(c.stateFile) + exists, data, checksum, err := c.getObject(ctx, c.stateFile) if err != nil { return nil, err } @@ -64,97 +63,97 @@ func (c *remoteClient) Get() (*remote.Payload, error) { } // Put put state file to remote -func (c *remoteClient) Put(data []byte) error { +func (c *remoteClient) Put(ctx context.Context, data []byte) error { log.Printf("[DEBUG] put remote state file %s", c.stateFile) - return c.putObject(c.stateFile, data) + return c.putObject(ctx, c.stateFile, data) } // Delete delete remote state file -func (c *remoteClient) Delete() error { +func (c *remoteClient) Delete(ctx context.Context) error { log.Printf("[DEBUG] delete remote state file %s", c.stateFile) - return c.deleteObject(c.stateFile) + return c.deleteObject(ctx, c.stateFile) } // Lock lock remote state file for writing -func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *remoteClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { log.Printf("[DEBUG] lock remote state file %s", c.lockFile) err := c.cosLock(c.bucket, c.lockFile) if err != nil { - return "", c.lockError(err) + return "", c.lockError(ctx, err) } // Local helper function so we can call it multiple places lockUnlock := func(parent error) error { if err := c.cosUnlock(c.bucket, c.lockFile); err != nil { return errors.Join( - fmt.Errorf("error unlocking cos state: %w", c.lockError(err)), + fmt.Errorf("error unlocking cos state: %w", c.lockError(ctx, err)), parent, ) } return parent } - exists, _, _, err := c.getObject(c.lockFile) + exists, _, _, err := c.getObject(ctx, c.lockFile) if err != nil { - return "", lockUnlock(c.lockError(err)) + return "", lockUnlock(c.lockError(ctx, err)) } if exists { - return "", lockUnlock(c.lockError(fmt.Errorf("lock file %s exists", c.lockFile))) + return "", lockUnlock(c.lockError(ctx, fmt.Errorf("lock file %s exists", c.lockFile))) } info.Path = c.lockFile data, err := json.Marshal(info) if err != nil { - return "", lockUnlock(c.lockError(err)) + return "", lockUnlock(c.lockError(ctx, err)) } check := fmt.Sprintf("%x", md5.Sum(data)) - err = c.putObject(c.lockFile, data) + err = c.putObject(ctx, c.lockFile, data) if err != nil { - return "", lockUnlock(c.lockError(err)) + return "", lockUnlock(c.lockError(ctx, err)) } return check, lockUnlock(nil) } // Unlock unlock remote state file -func (c *remoteClient) Unlock(check string) error { +func (c *remoteClient) Unlock(ctx context.Context, check string) error { log.Printf("[DEBUG] unlock remote state file %s", c.lockFile) - info, err := c.lockInfo() + info, err := c.lockInfo(ctx) if err != nil { - return c.lockError(err) + return c.lockError(ctx, err) } if info.ID != check { - return c.lockError(fmt.Errorf("lock id mismatch, %v != %v", info.ID, check)) + return c.lockError(ctx, fmt.Errorf("lock id mismatch, %v != %v", info.ID, check)) } - err = c.deleteObject(c.lockFile) + err = c.deleteObject(ctx, c.lockFile) if err != nil { - return c.lockError(err) + return c.lockError(ctx, err) } err = c.cosUnlock(c.bucket, c.lockFile) if err != nil { - return c.lockError(err) + return c.lockError(ctx, err) } return nil } // lockError returns statemgr.LockError -func (c *remoteClient) lockError(err error) *statemgr.LockError { +func (c *remoteClient) lockError(ctx context.Context, err error) *statemgr.LockError { log.Printf("[DEBUG] failed to lock or unlock %s: %v", c.lockFile, err) lockErr := &statemgr.LockError{ Err: err, } - info, infoErr := c.lockInfo() + info, infoErr := c.lockInfo(ctx) if infoErr != nil { lockErr.Err = multierror.Append(lockErr.Err, infoErr) } else { @@ -165,8 +164,8 @@ func (c *remoteClient) lockError(err error) *statemgr.LockError { } // lockInfo returns LockInfo from lock file -func (c *remoteClient) lockInfo() (*statemgr.LockInfo, error) { - exists, data, checksum, err := c.getObject(c.lockFile) +func (c *remoteClient) lockInfo(ctx context.Context) (*statemgr.LockInfo, error) { + exists, data, checksum, err := c.getObject(ctx, c.lockFile) if err != nil { return nil, err } @@ -186,8 +185,8 @@ func (c *remoteClient) lockInfo() (*statemgr.LockInfo, error) { } // getObject get remote object -func (c *remoteClient) getObject(cosFile string) (exists bool, data []byte, checksum string, err error) { - rsp, err := c.cosClient.Object.Get(c.cosContext, cosFile, nil) +func (c *remoteClient) getObject(ctx context.Context, cosFile string) (exists bool, data []byte, checksum string, err error) { + rsp, err := c.cosClient.Object.Get(ctx, cosFile, nil) if rsp == nil { log.Printf("[DEBUG] getObject %s: error: %v", cosFile, err) err = fmt.Errorf("failed to open file at %v: %w", cosFile, err) @@ -231,7 +230,7 @@ func (c *remoteClient) getObject(cosFile string) (exists bool, data []byte, chec } // putObject put object to remote -func (c *remoteClient) putObject(cosFile string, data []byte) error { +func (c *remoteClient) putObject(ctx context.Context, cosFile string, data []byte) error { opt := &cos.ObjectPutOptions{ ObjectPutHeaderOptions: &cos.ObjectPutHeaderOptions{ XCosMetaXXX: &http.Header{ @@ -248,7 +247,7 @@ func (c *remoteClient) putObject(cosFile string, data []byte) error { } r := bytes.NewReader(data) - rsp, err := c.cosClient.Object.Put(c.cosContext, cosFile, r, opt) + rsp, err := c.cosClient.Object.Put(ctx, cosFile, r, opt) if rsp == nil { log.Printf("[DEBUG] putObject %s: error: %v", cosFile, err) return fmt.Errorf("failed to save file to %v: %w", cosFile, err) @@ -264,8 +263,8 @@ func (c *remoteClient) putObject(cosFile string, data []byte) error { } // deleteObject delete remote object -func (c *remoteClient) deleteObject(cosFile string) error { - rsp, err := c.cosClient.Object.Delete(c.cosContext, cosFile) +func (c *remoteClient) deleteObject(ctx context.Context, cosFile string) error { + rsp, err := c.cosClient.Object.Delete(ctx, cosFile) if rsp == nil { log.Printf("[DEBUG] deleteObject %s: error: %v", cosFile, err) return fmt.Errorf("failed to delete file %v: %w", cosFile, err) @@ -285,8 +284,8 @@ func (c *remoteClient) deleteObject(cosFile string) error { } // getBucket list bucket by prefix -func (c *remoteClient) getBucket(prefix string) (obs []cos.Object, err error) { - fs, rsp, err := c.cosClient.Bucket.Get(c.cosContext, &cos.BucketGetOptions{Prefix: prefix}) +func (c *remoteClient) getBucket(ctx context.Context, prefix string) (obs []cos.Object, err error) { + fs, rsp, err := c.cosClient.Bucket.Get(ctx, &cos.BucketGetOptions{Prefix: prefix}) if rsp == nil { log.Printf("[DEBUG] getBucket %s/%s: error: %v", c.bucket, prefix, err) err = fmt.Errorf("bucket %s not exists", c.bucket) @@ -308,8 +307,8 @@ func (c *remoteClient) getBucket(prefix string) (obs []cos.Object, err error) { } // putBucket create cos bucket -func (c *remoteClient) putBucket() error { - rsp, err := c.cosClient.Bucket.Put(c.cosContext, nil) +func (c *remoteClient) putBucket(ctx context.Context) error { + rsp, err := c.cosClient.Bucket.Put(ctx, nil) if rsp == nil { log.Printf("[DEBUG] putBucket %s: error: %v", c.bucket, err) return fmt.Errorf("failed to create bucket %v: %w", c.bucket, err) @@ -329,9 +328,9 @@ func (c *remoteClient) putBucket() error { } // deleteBucket delete cos bucket -func (c *remoteClient) deleteBucket(recursive bool) error { +func (c *remoteClient) deleteBucket(ctx context.Context, recursive bool) error { if recursive { - obs, err := c.getBucket("") + obs, err := c.getBucket(ctx, "") if err != nil { if strings.Contains(err.Error(), "not exists") { return nil @@ -340,14 +339,14 @@ func (c *remoteClient) deleteBucket(recursive bool) error { return fmt.Errorf("failed to empty bucket %v: %w", c.bucket, err) } for _, v := range obs { - err := c.deleteObject(v.Key) + err := c.deleteObject(ctx, v.Key) if err != nil { return fmt.Errorf("failed to delete object with key %s: %w", v.Key, err) } } } - rsp, err := c.cosClient.Bucket.Delete(c.cosContext) + rsp, err := c.cosClient.Bucket.Delete(ctx) if rsp == nil { log.Printf("[DEBUG] deleteBucket %s: error: %v", c.bucket, err) return fmt.Errorf("failed to delete bucket %v: %w", c.bucket, err) diff --git a/internal/backend/remote-state/gcs/backend.go b/internal/backend/remote-state/gcs/backend.go index 58a1ae902b..2002f6a32b 100644 --- a/internal/backend/remote-state/gcs/backend.go +++ b/internal/backend/remote-state/gcs/backend.go @@ -32,8 +32,7 @@ type Backend struct { *schema.Backend encryption encryption.StateEncryption - storageClient *storage.Client - storageContext context.Context + storageClient *storage.Client bucketName string prefix string @@ -131,13 +130,7 @@ func (b *Backend) configure(ctx context.Context) error { return nil } - // ctx is a background context with the backend config added. - // Since no context is passed to remoteClient.Get(), .Lock(), etc. but - // one is required for calling the GCP API, we're holding on to this - // context here and re-use it later. - b.storageContext = ctx - - data := schema.FromContextBackendConfig(b.storageContext) + data := schema.FromContextBackendConfig(ctx) b.bucketName = data.Get("bucket").(string) b.prefix = strings.TrimLeft(data.Get("prefix").(string), "/") @@ -219,7 +212,7 @@ func (b *Backend) configure(ctx context.Context) error { endpoint := option.WithEndpoint(storageEndpoint.(string)) opts = append(opts, endpoint) } - client, err := storage.NewClient(b.storageContext, opts...) + client, err := storage.NewClient(ctx, opts...) if err != nil { return fmt.Errorf("storage.NewClient() failed: %w", err) } diff --git a/internal/backend/remote-state/gcs/backend_state.go b/internal/backend/remote-state/gcs/backend_state.go index a3c441a57f..44cb3f507c 100644 --- a/internal/backend/remote-state/gcs/backend_state.go +++ b/internal/backend/remote-state/gcs/backend_state.go @@ -28,11 +28,11 @@ const ( // Workspaces returns a list of names for the workspaces found on GCS. The default // state is always returned as the first element in the slice. -func (b *Backend) Workspaces(context.Context) ([]string, error) { +func (b *Backend) Workspaces(ctx context.Context) ([]string, error) { states := []string{backend.DefaultStateName} bucket := b.storageClient.Bucket(b.bucketName) - objs := bucket.Objects(b.storageContext, &storage.Query{ + objs := bucket.Objects(ctx, &storage.Query{ Delimiter: "/", Prefix: b.prefix, }) @@ -61,7 +61,7 @@ func (b *Backend) Workspaces(context.Context) ([]string, error) { } // DeleteWorkspace deletes the named workspaces. The "default" state cannot be deleted. -func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error { +func (b *Backend) DeleteWorkspace(ctx context.Context, name string, _ bool) error { if name == backend.DefaultStateName { return fmt.Errorf("cowardly refusing to delete the %q state", name) } @@ -71,7 +71,7 @@ func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error return err } - return c.Delete() + return c.Delete(ctx) } // client returns a remoteClient for the named state. @@ -81,19 +81,18 @@ func (b *Backend) client(name string) (*remoteClient, error) { } return &remoteClient{ - storageContext: b.storageContext, - storageClient: b.storageClient, - bucketName: b.bucketName, - stateFilePath: b.stateFile(name), - lockFilePath: b.lockFile(name), - encryptionKey: b.encryptionKey, - kmsKeyName: b.kmsKeyName, + storageClient: b.storageClient, + bucketName: b.bucketName, + stateFilePath: b.stateFile(name), + lockFilePath: b.lockFile(name), + encryptionKey: b.encryptionKey, + kmsKeyName: b.kmsKeyName, }, nil } // StateMgr reads and returns the named state from GCS. If the named state does // not yet exist, a new state file is created. -func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error) { +func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, error) { c, err := b.client(name) if err != nil { return nil, err @@ -102,7 +101,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error st := remote.NewState(c, b.encryption) // Grab the value - if err := st.RefreshState(); err != nil { + if err := st.RefreshState(ctx); err != nil { return nil, err } @@ -111,14 +110,14 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockID, err := st.Lock(lockInfo) + lockID, err := st.Lock(ctx, lockInfo) if err != nil { return nil, err } // Local helper function so we can call it multiple places unlock := func(baseErr error) error { - if err := st.Unlock(lockID); err != nil { + if err := st.Unlock(ctx, lockID); err != nil { const unlockErrMsg = `%v Additionally, unlocking the state file on Google Cloud Storage failed: @@ -138,7 +137,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error if err := st.WriteState(states.NewState()); err != nil { return nil, unlock(err) } - if err := st.PersistState(nil); err != nil { + if err := st.PersistState(ctx, nil); err != nil { return nil, unlock(err) } diff --git a/internal/backend/remote-state/gcs/backend_test.go b/internal/backend/remote-state/gcs/backend_test.go index 91b1e34198..9cd598c1b1 100644 --- a/internal/backend/remote-state/gcs/backend_test.go +++ b/internal/backend/remote-state/gcs/backend_test.go @@ -244,7 +244,7 @@ func setupBackend(t *testing.T, bucket, prefix, key, kmsName string) backend.Bac // create the bucket if it doesn't exist bkt := be.storageClient.Bucket(bucket) - _, err := bkt.Attrs(be.storageContext) + _, err := bkt.Attrs(t.Context()) if err != nil { if err != storage.ErrBucketNotExist { t.Fatal(err) @@ -253,7 +253,7 @@ func setupBackend(t *testing.T, bucket, prefix, key, kmsName string) backend.Bac attrs := &storage.BucketAttrs{ Location: os.Getenv("GOOGLE_REGION"), } - err := bkt.Create(be.storageContext, projectID, attrs) + err := bkt.Create(t.Context(), projectID, attrs) if err != nil { t.Fatal(err) } @@ -383,7 +383,7 @@ func teardownBackend(t *testing.T, be backend.Backend, prefix string) { if !ok { t.Fatalf("be is a %T, want a *gcsBackend", be) } - ctx := gcsBE.storageContext + ctx := t.Context() bucket := gcsBE.storageClient.Bucket(gcsBE.bucketName) objs := bucket.Objects(ctx, nil) diff --git a/internal/backend/remote-state/gcs/client.go b/internal/backend/remote-state/gcs/client.go index d454383cc9..e8347342ee 100644 --- a/internal/backend/remote-state/gcs/client.go +++ b/internal/backend/remote-state/gcs/client.go @@ -6,6 +6,7 @@ package gcs import ( + "context" "encoding/json" "errors" "fmt" @@ -16,24 +17,22 @@ import ( multierror "github.com/hashicorp/go-multierror" "github.com/opentofu/opentofu/internal/states/remote" "github.com/opentofu/opentofu/internal/states/statemgr" - "golang.org/x/net/context" ) // remoteClient is used by "state/remote".State to read and write // blobs representing state. // Implements "state/remote".ClientLocker type remoteClient struct { - storageContext context.Context - storageClient *storage.Client - bucketName string - stateFilePath string - lockFilePath string - encryptionKey []byte - kmsKeyName string + storageClient *storage.Client + bucketName string + stateFilePath string + lockFilePath string + encryptionKey []byte + kmsKeyName string } -func (c *remoteClient) Get() (payload *remote.Payload, err error) { - stateFileReader, err := c.stateFile().NewReader(c.storageContext) +func (c *remoteClient) Get(ctx context.Context) (payload *remote.Payload, err error) { + stateFileReader, err := c.stateFile().NewReader(ctx) if err != nil { if err == storage.ErrObjectNotExist { return nil, nil @@ -48,7 +47,7 @@ func (c *remoteClient) Get() (payload *remote.Payload, err error) { return nil, fmt.Errorf("Failed to read state file from %v: %w", c.stateFileURL(), err) } - stateFileAttrs, err := c.stateFile().Attrs(c.storageContext) + stateFileAttrs, err := c.stateFile().Attrs(ctx) if err != nil { return nil, fmt.Errorf("Failed to read state file attrs from %v: %w", c.stateFileURL(), err) } @@ -61,9 +60,9 @@ func (c *remoteClient) Get() (payload *remote.Payload, err error) { return result, nil } -func (c *remoteClient) Put(data []byte) error { +func (c *remoteClient) Put(ctx context.Context, data []byte) error { err := func() error { - stateFileWriter := c.stateFile().NewWriter(c.storageContext) + stateFileWriter := c.stateFile().NewWriter(ctx) if len(c.kmsKeyName) > 0 { stateFileWriter.KMSKeyName = c.kmsKeyName } @@ -79,8 +78,8 @@ func (c *remoteClient) Put(data []byte) error { return nil } -func (c *remoteClient) Delete() error { - if err := c.stateFile().Delete(c.storageContext); err != nil { +func (c *remoteClient) Delete(ctx context.Context) error { + if err := c.stateFile().Delete(ctx); err != nil { return fmt.Errorf("Failed to delete state file %v: %w", c.stateFileURL(), err) } @@ -89,7 +88,7 @@ func (c *remoteClient) Delete() error { // Lock writes to a lock file, ensuring file creation. Returns the generation // number, which must be passed to Unlock(). -func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *remoteClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { // update the path we're using // we can't set the ID until the info is written info.Path = c.lockFileURL() @@ -100,7 +99,7 @@ func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { } lockFile := c.lockFile() - w := lockFile.If(storage.Conditions{DoesNotExist: true}).NewWriter(c.storageContext) + w := lockFile.If(storage.Conditions{DoesNotExist: true}).NewWriter(ctx) err = func() error { if _, err := w.Write(infoJson); err != nil { return err @@ -109,7 +108,7 @@ func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { }() if err != nil { - return "", c.lockError(fmt.Errorf("writing %q failed: %w", c.lockFileURL(), err)) + return "", c.lockError(ctx, fmt.Errorf("writing %q failed: %w", c.lockFileURL(), err)) } info.ID = strconv.FormatInt(w.Attrs().Generation, 10) @@ -117,25 +116,25 @@ func (c *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { return info.ID, nil } -func (c *remoteClient) Unlock(id string) error { +func (c *remoteClient) Unlock(ctx context.Context, id string) error { gen, err := strconv.ParseInt(id, 10, 64) if err != nil { return fmt.Errorf("Lock ID should be numerical value, got '%s'", id) } - if err := c.lockFile().If(storage.Conditions{GenerationMatch: gen}).Delete(c.storageContext); err != nil { - return c.lockError(err) + if err := c.lockFile().If(storage.Conditions{GenerationMatch: gen}).Delete(ctx); err != nil { + return c.lockError(ctx, err) } return nil } -func (c *remoteClient) lockError(err error) *statemgr.LockError { +func (c *remoteClient) lockError(ctx context.Context, err error) *statemgr.LockError { lockErr := &statemgr.LockError{ Err: err, } - info, infoErr := c.lockInfo() + info, infoErr := c.lockInfo(ctx) switch { case errors.Is(infoErr, storage.ErrObjectNotExist): // Race condition - file exists initially but then has been deleted by other process @@ -150,8 +149,8 @@ func (c *remoteClient) lockError(err error) *statemgr.LockError { // lockInfo reads the lock file, parses its contents and returns the parsed // LockInfo struct. -func (c *remoteClient) lockInfo() (*statemgr.LockInfo, error) { - r, err := c.lockFile().NewReader(c.storageContext) +func (c *remoteClient) lockInfo(ctx context.Context) (*statemgr.LockInfo, error) { + r, err := c.lockFile().NewReader(ctx) if err != nil { return nil, err } @@ -170,7 +169,7 @@ func (c *remoteClient) lockInfo() (*statemgr.LockInfo, error) { // We use the Generation as the ID, so overwrite the ID in the json. // This can't be written into the Info, since the generation isn't known // until it's written. - attrs, err := c.lockFile().Attrs(c.storageContext) + attrs, err := c.lockFile().Attrs(ctx) if err != nil { return nil, err } diff --git a/internal/backend/remote-state/http/client.go b/internal/backend/remote-state/http/client.go index 5ba755487b..7785778206 100644 --- a/internal/backend/remote-state/http/client.go +++ b/internal/backend/remote-state/http/client.go @@ -7,6 +7,7 @@ package http import ( "bytes" + "context" "crypto/md5" "encoding/base64" "encoding/json" @@ -43,7 +44,7 @@ type httpClient struct { jsonLockInfo []byte } -func (c *httpClient) httpRequest(method string, url *url.URL, data []byte, what string) (*http.Response, error) { +func (c *httpClient) httpRequest(ctx context.Context, method string, url *url.URL, data []byte, what string) (*http.Response, error) { var body interface{} if len(data) > 0 { body = data @@ -52,7 +53,7 @@ func (c *httpClient) httpRequest(method string, url *url.URL, data []byte, what log.Printf("[DEBUG] Executing HTTP remote state request for: %q", what) // Create the request - req, err := retryablehttp.NewRequest(method, url.String(), body) + req, err := retryablehttp.NewRequestWithContext(ctx, method, url.String(), body) if err != nil { return nil, fmt.Errorf("Failed to make %s HTTP request: %w", what, err) } @@ -89,14 +90,14 @@ func (c *httpClient) httpRequest(method string, url *url.URL, data []byte, what return resp, nil } -func (c *httpClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *httpClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { if c.LockURL == nil { return "", nil } c.lockID = "" jsonLockInfo := info.Marshal() - resp, err := c.httpRequest(c.LockMethod, c.LockURL, jsonLockInfo, "lock") + resp, err := c.httpRequest(ctx, c.LockMethod, c.LockURL, jsonLockInfo, "lock") if err != nil { return "", err } @@ -139,7 +140,7 @@ func (c *httpClient) Lock(info *statemgr.LockInfo) (string, error) { } } -func (c *httpClient) Unlock(id string) error { +func (c *httpClient) Unlock(ctx context.Context, id string) error { if c.UnlockURL == nil { return nil } @@ -162,7 +163,7 @@ func (c *httpClient) Unlock(id string) error { lockInfo.ID = id - resp, err := c.httpRequest(c.UnlockMethod, c.UnlockURL, lockInfo.Marshal(), "unlock") + resp, err := c.httpRequest(ctx, c.UnlockMethod, c.UnlockURL, lockInfo.Marshal(), "unlock") if err != nil { return err } @@ -177,8 +178,8 @@ func (c *httpClient) Unlock(id string) error { } } -func (c *httpClient) Get() (*remote.Payload, error) { - resp, err := c.httpRequest(http.MethodGet, c.URL, nil, "get state") +func (c *httpClient) Get(ctx context.Context) (*remote.Payload, error) { + resp, err := c.httpRequest(ctx, http.MethodGet, c.URL, nil, "get state") if err != nil { return nil, err } @@ -240,7 +241,7 @@ func (c *httpClient) Get() (*remote.Payload, error) { return payload, nil } -func (c *httpClient) Put(data []byte) error { +func (c *httpClient) Put(ctx context.Context, data []byte) error { // Copy the target URL base := *c.URL @@ -263,7 +264,7 @@ func (c *httpClient) Put(data []byte) error { if c.UpdateMethod != "" { method = c.UpdateMethod } - resp, err := c.httpRequest(method, &base, data, "upload state") + resp, err := c.httpRequest(ctx, method, &base, data, "upload state") if err != nil { return err } @@ -279,9 +280,9 @@ func (c *httpClient) Put(data []byte) error { } } -func (c *httpClient) Delete() error { +func (c *httpClient) Delete(ctx context.Context) error { // Make the request - resp, err := c.httpRequest(http.MethodDelete, c.URL, nil, "delete state") + resp, err := c.httpRequest(ctx, http.MethodDelete, c.URL, nil, "delete state") if err != nil { return err } diff --git a/internal/backend/remote-state/http/client_test.go b/internal/backend/remote-state/http/client_test.go index 60d4b93983..64aaa1b58b 100644 --- a/internal/backend/remote-state/http/client_test.go +++ b/internal/backend/remote-state/http/client_test.go @@ -370,7 +370,7 @@ func TestHttpClient_Unlock(t *testing.T) { jsonLockInfo: tt.jsonLockInfo, } - err = client.Unlock(tt.lockID) + err = client.Unlock(t.Context(), tt.lockID) if tt.expectedErrorMsg != nil && err == nil { // no expected error t.Errorf("UnLock() no expected error = %v", tt.expectedErrorMsg) @@ -474,7 +474,7 @@ func TestHttpClient_lock(t *testing.T) { Client: retryablehttp.NewClient(), } - lockID, err := client.Lock(tt.lockInfo) + lockID, err := client.Lock(t.Context(), tt.lockInfo) if tt.expectedErrorMsg != nil && err == nil { // no expected error t.Errorf("Lock() no expected error = %v", tt.expectedErrorMsg) diff --git a/internal/backend/remote-state/http/server_test.go b/internal/backend/remote-state/http/server_test.go index 97e3ffe37d..bcada3c5e3 100644 --- a/internal/backend/remote-state/http/server_test.go +++ b/internal/backend/remote-state/http/server_test.go @@ -289,7 +289,7 @@ func TestMTLSServer_NoCertFails(t *testing.T) { } opErr := new(net.OpError) - err = sm.RefreshState() + err = sm.RefreshState(t.Context()) if err == nil { t.Fatal("expected error when refreshing state without a client cert") } @@ -358,7 +358,7 @@ func TestMTLSServer_WithCertPasses(t *testing.T) { if err != nil { t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err) } - if err = sm.RefreshState(); err != nil { + if err = sm.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error calling RefreshState: %v", err) } state := sm.State() @@ -399,10 +399,10 @@ func TestMTLSServer_WithCertPasses(t *testing.T) { if err = sm.WriteState(state); err != nil { t.Errorf("error writing state: %v", err) } - if err = sm.PersistState(nil); err != nil { + if err = sm.PersistState(t.Context(), nil); err != nil { t.Errorf("error persisting state: %v", err) } - if err = sm.RefreshState(); err != nil { + if err = sm.RefreshState(t.Context()); err != nil { t.Errorf("error refreshing state: %v", err) } diff --git a/internal/backend/remote-state/inmem/backend.go b/internal/backend/remote-state/inmem/backend.go index f5553cc857..acdc4fccf3 100644 --- a/internal/backend/remote-state/inmem/backend.go +++ b/internal/backend/remote-state/inmem/backend.go @@ -139,14 +139,14 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error // take a lock and create a new state if it doesn't exist. lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockID, err := s.Lock(lockInfo) + lockID, err := s.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("failed to lock inmem state: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(parent error) error { - if err := s.Unlock(lockID); err != nil { + if err := s.Unlock(context.TODO(), lockID); err != nil { return errors.Join( fmt.Errorf("error unlocking inmem state: %w", err), parent, @@ -161,7 +161,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error err = lockUnlock(err) return nil, err } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/inmem/backend_test.go b/internal/backend/remote-state/inmem/backend_test.go index f6c77a1e92..67a4fa9d65 100644 --- a/internal/backend/remote-state/inmem/backend_test.go +++ b/internal/backend/remote-state/inmem/backend_test.go @@ -88,11 +88,11 @@ func TestRemoteState(t *testing.T) { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatal(err) } } diff --git a/internal/backend/remote-state/inmem/client.go b/internal/backend/remote-state/inmem/client.go index 7504ede0db..f494ff997f 100644 --- a/internal/backend/remote-state/inmem/client.go +++ b/internal/backend/remote-state/inmem/client.go @@ -6,6 +6,7 @@ package inmem import ( + "context" "crypto/md5" "github.com/opentofu/opentofu/internal/states/remote" @@ -19,7 +20,7 @@ type RemoteClient struct { Name string } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(_ context.Context) (*remote.Payload, error) { if c.Data == nil { return nil, nil } @@ -30,7 +31,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { }, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { md5 := md5.Sum(data) c.Data = data @@ -38,15 +39,15 @@ func (c *RemoteClient) Put(data []byte) error { return nil } -func (c *RemoteClient) Delete() error { +func (c *RemoteClient) Delete(_ context.Context) error { c.Data = nil c.MD5 = nil return nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *RemoteClient) Lock(_ context.Context, info *statemgr.LockInfo) (string, error) { return locks.lock(c.Name, info) } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(_ context.Context, id string) error { return locks.unlock(c.Name, id) } diff --git a/internal/backend/remote-state/kubernetes/backend_state.go b/internal/backend/remote-state/kubernetes/backend_state.go index bc350dd0fa..a9f1a8af0a 100644 --- a/internal/backend/remote-state/kubernetes/backend_state.go +++ b/internal/backend/remote-state/kubernetes/backend_state.go @@ -65,7 +65,7 @@ func (b *Backend) Workspaces(ctx context.Context) ([]string, error) { return states, nil } -func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error { +func (b *Backend) DeleteWorkspace(ctx context.Context, name string, _ bool) error { if name == backend.DefaultStateName || name == "" { return fmt.Errorf("can't delete default state") } @@ -75,7 +75,7 @@ func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error return err } - return client.Delete() + return client.Delete(ctx) } func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error) { @@ -87,7 +87,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error stateMgr := remote.NewState(c, b.encryption) // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { return nil, err } @@ -96,7 +96,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockID, err := stateMgr.Lock(lockInfo) + lockID, err := stateMgr.Lock(context.TODO(), lockInfo) if err != nil { return nil, err } @@ -108,7 +108,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error // Local helper function so we can call it multiple places unlock := func(baseErr error) error { - if err := stateMgr.Unlock(lockID); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockID); err != nil { const unlockErrMsg = `%v Additionally, unlocking the state in Kubernetes failed: @@ -128,7 +128,7 @@ func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error if err := stateMgr.WriteState(states.NewState()); err != nil { return nil, unlock(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { return nil, unlock(err) } diff --git a/internal/backend/remote-state/kubernetes/backend_test.go b/internal/backend/remote-state/kubernetes/backend_test.go index ff8268f54c..b3ea95099f 100644 --- a/internal/backend/remote-state/kubernetes/backend_test.go +++ b/internal/backend/remote-state/kubernetes/backend_test.go @@ -111,7 +111,7 @@ func TestBackendLocksSoak(t *testing.T) { li.Who = fmt.Sprintf("client-%v", n) for i := 0; i < lockAttempts; i++ { - id, err := locker.Lock(li) + id, err := locker.Lock(t.Context(), li) if err != nil { continue } @@ -119,7 +119,7 @@ func TestBackendLocksSoak(t *testing.T) { // hold onto the lock for a little bit time.Sleep(time.Duration(rand.Intn(10)) * time.Microsecond) - err = locker.Unlock(id) + err = locker.Unlock(t.Context(), id) if err != nil { t.Errorf("failed to unlock: %v", err) } diff --git a/internal/backend/remote-state/kubernetes/client.go b/internal/backend/remote-state/kubernetes/client.go index ccc9d2196e..7b41ef38d7 100644 --- a/internal/backend/remote-state/kubernetes/client.go +++ b/internal/backend/remote-state/kubernetes/client.go @@ -47,12 +47,12 @@ type RemoteClient struct { workspace string } -func (c *RemoteClient) Get() (payload *remote.Payload, err error) { +func (c *RemoteClient) Get(ctx context.Context) (payload *remote.Payload, err error) { secretName, err := c.createSecretName() if err != nil { return nil, err } - secret, err := c.kubernetesSecretClient.Get(context.Background(), secretName, metav1.GetOptions{}) + secret, err := c.kubernetesSecretClient.Get(ctx, secretName, metav1.GetOptions{}) if err != nil { if k8serrors.IsNotFound(err) { return nil, nil @@ -83,8 +83,7 @@ func (c *RemoteClient) Get() (payload *remote.Payload, err error) { return p, nil } -func (c *RemoteClient) Put(data []byte) error { - ctx := context.Background() +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { secretName, err := c.createSecretName() if err != nil { return err @@ -95,7 +94,7 @@ func (c *RemoteClient) Put(data []byte) error { return err } - secret, err := c.getSecret(secretName) + secret, err := c.getSecret(ctx, secretName) if err != nil { if !k8serrors.IsNotFound(err) { return err @@ -124,13 +123,13 @@ func (c *RemoteClient) Put(data []byte) error { } // Delete the state secret -func (c *RemoteClient) Delete() error { +func (c *RemoteClient) Delete(ctx context.Context) error { secretName, err := c.createSecretName() if err != nil { return err } - err = c.deleteSecret(secretName) + err = c.deleteSecret(ctx, secretName) if err != nil { if !k8serrors.IsNotFound(err) { return err @@ -142,7 +141,7 @@ func (c *RemoteClient) Delete() error { return err } - err = c.deleteLease(leaseName) + err = c.deleteLease(ctx, leaseName) if err != nil { if !k8serrors.IsNotFound(err) { return err @@ -151,14 +150,13 @@ func (c *RemoteClient) Delete() error { return nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { - ctx := context.Background() +func (c *RemoteClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { leaseName, err := c.createLeaseName() if err != nil { return "", err } - lease, err := c.getLease(leaseName) + lease, err := c.getLease(ctx, leaseName) if err != nil { if !k8serrors.IsNotFound(err) { return "", err @@ -213,13 +211,13 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { return info.ID, err } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(ctx context.Context, id string) error { leaseName, err := c.createLeaseName() if err != nil { return err } - lease, err := c.getLease(leaseName) + lease, err := c.getLease(ctx, leaseName) if err != nil { return err } @@ -242,7 +240,7 @@ func (c *RemoteClient) Unlock(id string) error { lease.Spec.HolderIdentity = nil removeLockInfo(lease) - _, err = c.kubernetesLeaseClient.Update(context.Background(), lease, metav1.UpdateOptions{}) + _, err = c.kubernetesLeaseClient.Update(ctx, lease, metav1.UpdateOptions{}) if err != nil { lockErr.Err = err return lockErr @@ -283,16 +281,16 @@ func (c *RemoteClient) getLabels() map[string]string { return l } -func (c *RemoteClient) getSecret(name string) (*unstructured.Unstructured, error) { - return c.kubernetesSecretClient.Get(context.Background(), name, metav1.GetOptions{}) +func (c *RemoteClient) getSecret(ctx context.Context, name string) (*unstructured.Unstructured, error) { + return c.kubernetesSecretClient.Get(ctx, name, metav1.GetOptions{}) } -func (c *RemoteClient) getLease(name string) (*coordinationv1.Lease, error) { - return c.kubernetesLeaseClient.Get(context.Background(), name, metav1.GetOptions{}) +func (c *RemoteClient) getLease(ctx context.Context, name string) (*coordinationv1.Lease, error) { + return c.kubernetesLeaseClient.Get(ctx, name, metav1.GetOptions{}) } -func (c *RemoteClient) deleteSecret(name string) error { - secret, err := c.getSecret(name) +func (c *RemoteClient) deleteSecret(ctx context.Context, name string) error { + secret, err := c.getSecret(ctx, name) if err != nil { return err } @@ -305,11 +303,11 @@ func (c *RemoteClient) deleteSecret(name string) error { delProp := metav1.DeletePropagationBackground delOps := metav1.DeleteOptions{PropagationPolicy: &delProp} - return c.kubernetesSecretClient.Delete(context.Background(), name, delOps) + return c.kubernetesSecretClient.Delete(ctx, name, delOps) } -func (c *RemoteClient) deleteLease(name string) error { - secret, err := c.getLease(name) +func (c *RemoteClient) deleteLease(ctx context.Context, name string) error { + secret, err := c.getLease(ctx, name) if err != nil { return err } @@ -322,7 +320,7 @@ func (c *RemoteClient) deleteLease(name string) error { delProp := metav1.DeletePropagationBackground delOps := metav1.DeleteOptions{PropagationPolicy: &delProp} - return c.kubernetesLeaseClient.Delete(context.Background(), name, delOps) + return c.kubernetesLeaseClient.Delete(ctx, name, delOps) } func (c *RemoteClient) createSecretName() (string, error) { diff --git a/internal/backend/remote-state/kubernetes/client_test.go b/internal/backend/remote-state/kubernetes/client_test.go index 536f1b8a24..19a0c84d78 100644 --- a/internal/backend/remote-state/kubernetes/client_test.go +++ b/internal/backend/remote-state/kubernetes/client_test.go @@ -82,7 +82,7 @@ func TestForceUnlock(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err := s1.Lock(info) + lockID, err := s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -93,7 +93,7 @@ func TestForceUnlock(t *testing.T) { t.Fatal("failed to get default state to force unlock:", err) } - if err := s2.Unlock(lockID); err != nil { + if err := s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock default state") } @@ -108,7 +108,7 @@ func TestForceUnlock(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err = s1.Lock(info) + lockID, err = s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -119,7 +119,7 @@ func TestForceUnlock(t *testing.T) { t.Fatal("failed to get named state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock named state") } } diff --git a/internal/backend/remote-state/oss/backend_state.go b/internal/backend/remote-state/oss/backend_state.go index ca0212ab57..070901e74b 100644 --- a/internal/backend/remote-state/oss/backend_state.go +++ b/internal/backend/remote-state/oss/backend_state.go @@ -102,7 +102,7 @@ func (b *Backend) Workspaces(context.Context) ([]string, error) { return result, nil } -func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error { +func (b *Backend) DeleteWorkspace(ctx context.Context, name string, _ bool) error { if name == backend.DefaultStateName || name == "" { return fmt.Errorf("can't delete default state") } @@ -111,7 +111,7 @@ func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error if err != nil { return err } - return client.Delete() + return client.Delete(ctx) } func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, error) { @@ -141,21 +141,21 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err // take a lock on this state while we write it lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockId, err := client.Lock(lockInfo) + lockId, err := client.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("failed to lock OSS state: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(e error) error { - if err := stateMgr.Unlock(lockId); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockId); err != nil { return fmt.Errorf(strings.TrimSpace(stateUnlockError), lockId, err) } return e } // Grab the value - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { err = lockUnlock(err) return nil, err } @@ -166,7 +166,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/oss/client.go b/internal/backend/remote-state/oss/client.go index 7788016dbe..db77e51bfb 100644 --- a/internal/backend/remote-state/oss/client.go +++ b/internal/backend/remote-state/oss/client.go @@ -7,6 +7,7 @@ package oss import ( "bytes" + "context" "crypto/md5" "encoding/hex" "encoding/json" @@ -55,7 +56,7 @@ type RemoteClient struct { otsTable string } -func (c *RemoteClient) Get() (payload *remote.Payload, err error) { +func (c *RemoteClient) Get(_ context.Context) (payload *remote.Payload, err error) { deadline := time.Now().Add(consistencyRetryTimeout) // If we have a checksum, and the returned payload doesn't match, we retry @@ -98,7 +99,7 @@ func (c *RemoteClient) Get() (payload *remote.Payload, err error) { return payload, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { bucket, err := c.ossClient.Bucket(c.bucketName) if err != nil { return fmt.Errorf("error getting bucket: %w", err) @@ -131,7 +132,7 @@ func (c *RemoteClient) Put(data []byte) error { return nil } -func (c *RemoteClient) Delete() error { +func (c *RemoteClient) Delete(_ context.Context) error { bucket, err := c.ossClient.Bucket(c.bucketName) if err != nil { return fmt.Errorf("error getting bucket %s: %w", c.bucketName, err) @@ -149,7 +150,7 @@ func (c *RemoteClient) Delete() error { return nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *RemoteClient) Lock(_ context.Context, info *statemgr.LockInfo) (string, error) { if c.otsTable == "" { return "", nil } @@ -360,7 +361,7 @@ func (c *RemoteClient) getLockInfo() (*statemgr.LockInfo, error) { } return lockInfo, nil } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(_ context.Context, id string) error { if c.otsTable == "" { return nil } diff --git a/internal/backend/remote-state/oss/client_test.go b/internal/backend/remote-state/oss/client_test.go index 8414d3e97a..07dd26462f 100644 --- a/internal/backend/remote-state/oss/client_test.go +++ b/internal/backend/remote-state/oss/client_test.go @@ -123,7 +123,7 @@ func TestRemoteClientLocks_multipleStates(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := s1.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := s1.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatal("failed to get lock for s1:", err) } @@ -132,7 +132,7 @@ func TestRemoteClientLocks_multipleStates(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := s2.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := s2.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatal("failed to get lock for s2:", err) } } @@ -175,7 +175,7 @@ func TestRemoteForceUnlock(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err := s1.Lock(info) + lockID, err := s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -186,7 +186,7 @@ func TestRemoteForceUnlock(t *testing.T) { t.Fatal("failed to get default state to force unlock:", err) } - if err := s2.Unlock(lockID); err != nil { + if err := s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock default state") } @@ -201,7 +201,7 @@ func TestRemoteForceUnlock(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err = s1.Lock(info) + lockID, err = s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -212,7 +212,7 @@ func TestRemoteForceUnlock(t *testing.T) { t.Fatal("failed to get named state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock named state") } } @@ -318,22 +318,22 @@ func TestRemoteClient_stateChecksum(t *testing.T) { client2 := s2.(*remote.State).Client // write the new state through client2 so that there is no checksum yet - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(t.Context(), newState.Bytes()); err != nil { t.Fatal(err) } // verify that we can pull a state without a checksum - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(t.Context()); err != nil { t.Fatal(err) } // write the new state back with its checksum - if err := client1.Put(newState.Bytes()); err != nil { + if err := client1.Put(t.Context(), newState.Bytes()); err != nil { t.Fatal(err) } // put an empty state in place to check for panics during get - if err := client2.Put([]byte{}); err != nil { + if err := client2.Put(t.Context(), []byte{}); err != nil { t.Fatal(err) } @@ -349,24 +349,24 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // fetching an empty state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(t.Context()); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // put the old state in place of the new, without updating the checksum - if err := client2.Put(oldState.Bytes()); err != nil { + if err := client2.Put(t.Context(), oldState.Bytes()); err != nil { t.Fatal(err) } // fetching the wrong state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(t.Context()); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // update the state with the correct one after we Get again testChecksumHook = func() { - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(t.Context(), newState.Bytes()); err != nil { t.Fatal(err) } testChecksumHook = nil @@ -377,7 +377,7 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // this final Get will fail to fail the checksum verification, the above // callback will update the state with the correct version, and Get should // retry automatically. - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(t.Context()); err != nil { t.Fatal(err) } } diff --git a/internal/backend/remote-state/pg/backend_state.go b/internal/backend/remote-state/pg/backend_state.go index 93b48bc4d2..d8ea88a316 100644 --- a/internal/backend/remote-state/pg/backend_state.go +++ b/internal/backend/remote-state/pg/backend_state.go @@ -92,14 +92,14 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err if !exists { lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockId, err := stateMgr.Lock(lockInfo) + lockId, err := stateMgr.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("failed to lock state in Postgres: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(parent error) error { - if err := stateMgr.Unlock(lockId); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockId); err != nil { return fmt.Errorf("error unlocking Postgres state: %w", err) } return parent @@ -110,7 +110,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/pg/backend_test.go b/internal/backend/remote-state/pg/backend_test.go index 73eb8f1103..4fcd4e9bdd 100644 --- a/internal/backend/remote-state/pg/backend_test.go +++ b/internal/backend/remote-state/pg/backend_test.go @@ -630,48 +630,48 @@ func TestBackendConcurrentLock(t *testing.T) { // First we need to create the workspace as the lock for creating them is // global - lockID1, err := s1.Lock(i1) + lockID1, err := s1.Lock(t.Context(), i1) if err != nil { t.Fatalf("failed to lock first state: %v", err) } - if err = s1.PersistState(nil); err != nil { + if err = s1.PersistState(t.Context(), nil); err != nil { t.Fatalf("failed to persist state: %v", err) } - if err = s1.Unlock(lockID1); err != nil { + if err = s1.Unlock(t.Context(), lockID1); err != nil { t.Fatalf("failed to unlock first state: %v", err) } - lockID2, err := s2.Lock(i2) + lockID2, err := s2.Lock(t.Context(), i2) if err != nil { t.Fatalf("failed to lock second state: %v", err) } - if err = s2.PersistState(nil); err != nil { + if err = s2.PersistState(t.Context(), nil); err != nil { t.Fatalf("failed to persist state: %v", err) } - if err = s2.Unlock(lockID2); err != nil { + if err = s2.Unlock(t.Context(), lockID2); err != nil { t.Fatalf("failed to unlock first state: %v", err) } // Now we can test concurrent lock - lockID1, err = s1.Lock(i1) + lockID1, err = s1.Lock(t.Context(), i1) if err != nil { t.Fatalf("failed to lock first state: %v", err) } - lockID2, err = s2.Lock(i2) + lockID2, err = s2.Lock(t.Context(), i2) if err != nil { t.Fatalf("failed to lock second state: %v", err) } - if err = s1.Unlock(lockID1); err != nil { + if err = s1.Unlock(t.Context(), lockID1); err != nil { t.Fatalf("failed to unlock first state: %v", err) } - if err = s2.Unlock(lockID2); err != nil { + if err = s2.Unlock(t.Context(), lockID2); err != nil { t.Fatalf("failed to unlock first state: %v", err) } } diff --git a/internal/backend/remote-state/pg/client.go b/internal/backend/remote-state/pg/client.go index 4db12c15b3..fc548d7f1b 100644 --- a/internal/backend/remote-state/pg/client.go +++ b/internal/backend/remote-state/pg/client.go @@ -6,6 +6,7 @@ package pg import ( + "context" "crypto/md5" "database/sql" "fmt" @@ -29,7 +30,7 @@ type RemoteClient struct { info *statemgr.LockInfo } -func (c *RemoteClient) Get() (*remote.Payload, error) { +func (c *RemoteClient) Get(_ context.Context) (*remote.Payload, error) { query := fmt.Sprintf(`SELECT data FROM %s.%s WHERE name = $1`, pq.QuoteIdentifier(c.SchemaName), pq.QuoteIdentifier(c.TableName)) row := c.Client.QueryRow(query, c.Name) var data []byte @@ -49,7 +50,7 @@ func (c *RemoteClient) Get() (*remote.Payload, error) { } } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(_ context.Context, data []byte) error { query := fmt.Sprintf(`INSERT INTO %s.%s (name, data) VALUES ($1, $2) ON CONFLICT (name) DO UPDATE SET data = $2 WHERE %s.name = $1`, pq.QuoteIdentifier(c.SchemaName), pq.QuoteIdentifier(c.TableName), pq.QuoteIdentifier(c.TableName)) @@ -60,7 +61,7 @@ func (c *RemoteClient) Put(data []byte) error { return nil } -func (c *RemoteClient) Delete() error { +func (c *RemoteClient) Delete(_ context.Context) error { query := fmt.Sprintf(`DELETE FROM %s.%s WHERE name = $1`, pq.QuoteIdentifier(c.SchemaName), pq.QuoteIdentifier(c.TableName)) _, err := c.Client.Exec(query, c.Name) if err != nil { @@ -69,7 +70,7 @@ func (c *RemoteClient) Delete() error { return nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *RemoteClient) Lock(_ context.Context, info *statemgr.LockInfo) (string, error) { var err error var lockID string @@ -137,7 +138,7 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { return info.ID, nil } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(_ context.Context, id string) error { if c.info != nil && c.info.Path != "" { query := `SELECT pg_advisory_unlock($1)` row := c.Client.QueryRow(query, c.info.Path) diff --git a/internal/backend/remote-state/pg/client_test.go b/internal/backend/remote-state/pg/client_test.go index 065a53859e..7eea047283 100644 --- a/internal/backend/remote-state/pg/client_test.go +++ b/internal/backend/remote-state/pg/client_test.go @@ -175,10 +175,10 @@ func TestConcurrentCreationLocksInDifferentSchemas(t *testing.T) { // Those calls with empty database must think they are locking // for workspace creation, both of them must succeed since they // are operating on different schemas. - if _, err = firstClient.Lock(lock); err != nil { + if _, err = firstClient.Lock(t.Context(), lock); err != nil { t.Fatal(err) } - if _, err = secondClient.Lock(lock); err != nil { + if _, err = secondClient.Lock(t.Context(), lock); err != nil { t.Fatal(err) } @@ -186,7 +186,7 @@ func TestConcurrentCreationLocksInDifferentSchemas(t *testing.T) { // lock as the first client. We need to make this call from a // separate session, since advisory locks are okay to be re-acquired // during the same session. - if _, err = thirdClient.Lock(lock); err == nil { + if _, err = thirdClient.Lock(t.Context(), lock); err == nil { t.Fatal("Expected an error to be thrown on a second lock attempt") } else if lockErr := err.(*statemgr.LockError); lockErr.Info != lock && //nolint:errcheck,errorlint // this is a test, I am fine with panic here lockErr.Err.Error() != "Already locked for workspace creation: default" { @@ -277,10 +277,10 @@ func TestConcurrentCreationLocksInDifferentTables(t *testing.T) { // Those calls with empty database must think they are locking // for workspace creation, both of them must succeed since they // are operating on different schemas. - if _, err = firstClient.Lock(lock); err != nil { + if _, err = firstClient.Lock(t.Context(), lock); err != nil { t.Fatal(err) } - if _, err = secondClient.Lock(lock); err != nil { + if _, err = secondClient.Lock(t.Context(), lock); err != nil { t.Fatal(err) } @@ -288,7 +288,7 @@ func TestConcurrentCreationLocksInDifferentTables(t *testing.T) { // lock as the first client. We need to make this call from a // separate session, since advisory locks are okay to be re-acquired // during the same session. - if _, err = thirdClient.Lock(lock); err == nil { + if _, err = thirdClient.Lock(t.Context(), lock); err == nil { t.Fatal("Expected an error to be thrown on a second lock attempt") } else if lockErr := err.(*statemgr.LockError); lockErr.Info != lock && //nolint:errcheck // this is a test, I am fine with panic here lockErr.Err.Error() != "Already locked for workspace creation: default" { diff --git a/internal/backend/remote-state/s3/backend_state.go b/internal/backend/remote-state/s3/backend_state.go index ae0c20f16e..bd1b2b3a1d 100644 --- a/internal/backend/remote-state/s3/backend_state.go +++ b/internal/backend/remote-state/s3/backend_state.go @@ -113,7 +113,7 @@ func (b *Backend) keyEnv(key string) string { return parts[0] } -func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error { +func (b *Backend) DeleteWorkspace(ctx context.Context, name string, _ bool) error { if name == backend.DefaultStateName || name == "" { return fmt.Errorf("can't delete default state") } @@ -123,7 +123,7 @@ func (b *Backend) DeleteWorkspace(_ context.Context, name string, _ bool) error return err } - return client.Delete() + return client.Delete(ctx) } // get a remote client configured for this state @@ -182,14 +182,14 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err // take a lock on this state while we write it lockInfo := statemgr.NewLockInfo() lockInfo.Operation = "init" - lockId, err := client.Lock(lockInfo) + lockId, err := client.Lock(context.TODO(), lockInfo) if err != nil { return nil, fmt.Errorf("failed to lock s3 state: %w", err) } // Local helper function so we can call it multiple places lockUnlock := func(parent error) error { - if err := stateMgr.Unlock(lockId); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockId); err != nil { return fmt.Errorf(strings.TrimSpace(errStateUnlock), lockId, err) } return parent @@ -198,7 +198,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err // Grab the value // This is to ensure that no one beat us to writing a state between // the `exists` check and taking the lock. - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { err = lockUnlock(err) return nil, err } @@ -209,7 +209,7 @@ func (b *Backend) StateMgr(ctx context.Context, name string) (statemgr.Full, err err = lockUnlock(err) return nil, err } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(context.TODO(), nil); err != nil { err = lockUnlock(err) return nil, err } diff --git a/internal/backend/remote-state/s3/backend_test.go b/internal/backend/remote-state/s3/backend_test.go index 9a4bada552..88e7036c21 100644 --- a/internal/backend/remote-state/s3/backend_test.go +++ b/internal/backend/remote-state/s3/backend_test.go @@ -1283,7 +1283,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(s1); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } @@ -1295,7 +1295,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr2.WriteState(s2); err != nil { t.Fatal(err) } - if err := stateMgr2.PersistState(nil); err != nil { + if err := stateMgr2.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } @@ -1310,7 +1310,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } if err := checkStateList(t.Context(), b, []string{"default", "s1", "s2"}); err != nil { @@ -1322,7 +1322,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } if err := checkStateList(t.Context(), b, []string{"default", "s1", "s2"}); err != nil { @@ -1330,7 +1330,7 @@ func TestBackendExtraPaths(t *testing.T) { } // remove the state with extra subkey - if err := client.Delete(); err != nil { + if err := client.Delete(t.Context()); err != nil { t.Fatal(err) } @@ -1348,7 +1348,7 @@ func TestBackendExtraPaths(t *testing.T) { if err != nil { t.Fatal(err) } - if err := s2Mgr.RefreshState(); err != nil { + if err := s2Mgr.RefreshState(t.Context()); err != nil { t.Fatal(err) } @@ -1363,7 +1363,7 @@ func TestBackendExtraPaths(t *testing.T) { if err := stateMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := stateMgr.PersistState(nil); err != nil { + if err := stateMgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } @@ -1372,7 +1372,7 @@ func TestBackendExtraPaths(t *testing.T) { if err != nil { t.Fatal(err) } - if err := s2Mgr.RefreshState(); err != nil { + if err := s2Mgr.RefreshState(t.Context()); err != nil { t.Fatal(err) } @@ -1405,7 +1405,7 @@ func TestBackendPrefixInWorkspace(t *testing.T) { if err != nil { t.Fatal(err) } - if err := sMgr.RefreshState(); err != nil { + if err := sMgr.RefreshState(t.Context()); err != nil { t.Fatal(err) } diff --git a/internal/backend/remote-state/s3/client.go b/internal/backend/remote-state/s3/client.go index da54b1b5b4..371510abd9 100644 --- a/internal/backend/remote-state/s3/client.go +++ b/internal/backend/remote-state/s3/client.go @@ -69,8 +69,7 @@ var ( // test hook called when checksums don't match var testChecksumHook func() -func (c *RemoteClient) Get() (payload *remote.Payload, err error) { - ctx := context.TODO() +func (c *RemoteClient) Get(ctx context.Context) (payload *remote.Payload, err error) { deadline := time.Now().Add(consistencyRetryTimeout) // If we have a checksum, and the returned payload doesn't match, we retry @@ -194,7 +193,7 @@ func (c *RemoteClient) get(ctx context.Context) (*remote.Payload, error) { return payload, nil } -func (c *RemoteClient) Put(data []byte) error { +func (c *RemoteClient) Put(ctx context.Context, data []byte) error { contentLength := int64(len(data)) i := &s3.PutObjectInput{ @@ -208,7 +207,7 @@ func (c *RemoteClient) Put(data []byte) error { c.configurePutObjectChecksum(data, i) c.configurePutObjectEncryption(i) c.configurePutObjectACL(i) - ctx := context.TODO() + ctx, _ = attachLoggerToContext(ctx) log.Printf("[DEBUG] Uploading remote state to S3: %#v", i) @@ -227,8 +226,7 @@ func (c *RemoteClient) Put(data []byte) error { return nil } -func (c *RemoteClient) Delete() error { - ctx := context.TODO() +func (c *RemoteClient) Delete(ctx context.Context) error { ctx, _ = attachLoggerToContext(ctx) _, err := c.s3Client.DeleteObject(ctx, &s3.DeleteObjectInput{ @@ -247,7 +245,7 @@ func (c *RemoteClient) Delete() error { return nil } -func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { +func (c *RemoteClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { if !c.IsLockingEnabled() { return "", nil } @@ -261,12 +259,12 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { } info.Path = c.lockPath() - if err := c.s3Lock(info); err != nil { + if err := c.s3Lock(ctx, info); err != nil { return "", err } - if err := c.dynamoDBLock(info); err != nil { + if err := c.dynamoDBLock(ctx, info); err != nil { // when the second lock fails from getting acquired, release the initially acquired one - if uErr := c.s3Unlock(info.ID); uErr != nil { + if uErr := c.s3Unlock(ctx, info.ID); uErr != nil { log.Printf("[WARN] failed to release the S3 lock after failed to acquire the dynamoDD lock: %v", uErr) } return "", err @@ -275,7 +273,7 @@ func (c *RemoteClient) Lock(info *statemgr.LockInfo) (string, error) { } // dynamoDBLock expects the statemgr.LockInfo#ID to be filled already -func (c *RemoteClient) dynamoDBLock(info *statemgr.LockInfo) error { +func (c *RemoteClient) dynamoDBLock(ctx context.Context, info *statemgr.LockInfo) error { if c.ddbTable == "" { return nil } @@ -289,7 +287,6 @@ func (c *RemoteClient) dynamoDBLock(info *statemgr.LockInfo) error { ConditionExpression: aws.String("attribute_not_exists(LockID)"), } - ctx := context.TODO() _, err := c.dynClient.PutItem(ctx, putParams) if err != nil { lockInfo, infoErr := c.getLockInfoFromDynamoDB(ctx) @@ -308,7 +305,7 @@ func (c *RemoteClient) dynamoDBLock(info *statemgr.LockInfo) error { } // s3Lock expects the statemgr.LockInfo#ID to be filled already -func (c *RemoteClient) s3Lock(info *statemgr.LockInfo) error { +func (c *RemoteClient) s3Lock(ctx context.Context, info *statemgr.LockInfo) error { if !c.useLockfile { return nil } @@ -325,7 +322,7 @@ func (c *RemoteClient) s3Lock(info *statemgr.LockInfo) error { c.configurePutObjectChecksum(lInfo, putParams) c.configurePutObjectEncryption(putParams) c.configurePutObjectACL(putParams) - ctx := context.TODO() + ctx, _ = attachLoggerToContext(ctx) log.Printf("[DEBUG] Uploading s3 locking object: %#v", putParams) @@ -488,11 +485,11 @@ func (c *RemoteClient) getLockInfoFromS3(ctx context.Context) (*statemgr.LockInf return lockInfo, nil } -func (c *RemoteClient) Unlock(id string) error { +func (c *RemoteClient) Unlock(ctx context.Context, id string) error { // Attempt to release the lock from both sources. // We want to do so to be sure that we are leaving no locks unhandled - s3Err := c.s3Unlock(id) - dynamoDBErr := c.dynamoDBUnlock(id) + s3Err := c.s3Unlock(ctx, id) + dynamoDBErr := c.dynamoDBUnlock(ctx, id) switch { case s3Err != nil && dynamoDBErr != nil: s3Err.Err = multierror.Append(s3Err.Err, dynamoDBErr.Err) @@ -511,12 +508,11 @@ func (c *RemoteClient) Unlock(id string) error { return nil } -func (c *RemoteClient) s3Unlock(id string) *statemgr.LockError { +func (c *RemoteClient) s3Unlock(ctx context.Context, id string) *statemgr.LockError { if !c.useLockfile { return nil } lockErr := &statemgr.LockError{} - ctx := context.TODO() ctx, _ = attachLoggerToContext(ctx) lockInfo, err := c.getLockInfoFromS3(ctx) @@ -544,13 +540,12 @@ func (c *RemoteClient) s3Unlock(id string) *statemgr.LockError { return nil } -func (c *RemoteClient) dynamoDBUnlock(id string) *statemgr.LockError { +func (c *RemoteClient) dynamoDBUnlock(ctx context.Context, id string) *statemgr.LockError { if c.ddbTable == "" { return nil } lockErr := &statemgr.LockError{} - ctx := context.TODO() lockInfo, err := c.getLockInfoFromDynamoDB(ctx) if err != nil { diff --git a/internal/backend/remote-state/s3/client_test.go b/internal/backend/remote-state/s3/client_test.go index af4fed6366..9e72570268 100644 --- a/internal/backend/remote-state/s3/client_test.go +++ b/internal/backend/remote-state/s3/client_test.go @@ -215,7 +215,7 @@ func TestRemoteS3AndDynamoDBClientLocksWithNoDBInstance(t *testing.T) { infoA.Operation = "test" infoA.Who = "clientA" - if _, err := s1.Lock(infoA); err == nil { + if _, err := s1.Lock(t.Context(), infoA); err == nil { t.Fatal("unexpected successful lock: ", err) } @@ -260,7 +260,7 @@ func TestForceUnlock(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err := s1.Lock(info) + lockID, err := s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -271,7 +271,7 @@ func TestForceUnlock(t *testing.T) { t.Fatal("failed to get default state to force unlock:", err) } - if err := s2.Unlock(lockID); err != nil { + if err := s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock default state") } @@ -286,7 +286,7 @@ func TestForceUnlock(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err = s1.Lock(info) + lockID, err = s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -297,7 +297,7 @@ func TestForceUnlock(t *testing.T) { t.Fatal("failed to get named state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock named state") } @@ -307,7 +307,7 @@ func TestForceUnlock(t *testing.T) { if err != nil { t.Fatal(err) } - err = s2.Unlock(lockID) + err = s2.Unlock(t.Context(), lockID) if err == nil { t.Fatal("expected an error to occur:", err) } @@ -352,7 +352,7 @@ func TestForceUnlockS3Only(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err := s1.Lock(info) + lockID, err := s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -363,7 +363,7 @@ func TestForceUnlockS3Only(t *testing.T) { t.Fatal("failed to get default state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock default state") } @@ -378,7 +378,7 @@ func TestForceUnlockS3Only(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err = s1.Lock(info) + lockID, err = s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -389,7 +389,7 @@ func TestForceUnlockS3Only(t *testing.T) { t.Fatal("failed to get named state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock named state") } @@ -399,7 +399,7 @@ func TestForceUnlockS3Only(t *testing.T) { if err != nil { t.Fatal(err) } - err = s2.Unlock(lockID) + err = s2.Unlock(t.Context(), lockID) if err == nil { t.Fatal("expected an error to occur:", err) } @@ -455,7 +455,7 @@ func TestForceUnlockS3AndDynamo(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err := s1.Lock(info) + lockID, err := s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -466,7 +466,7 @@ func TestForceUnlockS3AndDynamo(t *testing.T) { t.Fatal("failed to get default state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock default state") } @@ -481,7 +481,7 @@ func TestForceUnlockS3AndDynamo(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err = s1.Lock(info) + lockID, err = s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -492,7 +492,7 @@ func TestForceUnlockS3AndDynamo(t *testing.T) { t.Fatal("failed to get named state to force unlock:", err) } - if err = s2.Unlock(lockID); err != nil { + if err = s2.Unlock(t.Context(), lockID); err != nil { t.Fatal("failed to force-unlock named state") } @@ -502,7 +502,7 @@ func TestForceUnlockS3AndDynamo(t *testing.T) { if err != nil { t.Fatal(err) } - err = s2.Unlock(lockID) + err = s2.Unlock(t.Context(), lockID) if err == nil { t.Fatal("expected an error to occur:", err) } @@ -546,7 +546,7 @@ func TestForceUnlockS3WithAndDynamoWithout(t *testing.T) { info.Operation = "test" info.Who = "clientA" - lockID, err := s1.Lock(info) + lockID, err := s1.Lock(t.Context(), info) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -554,7 +554,7 @@ func TestForceUnlockS3WithAndDynamoWithout(t *testing.T) { // Remove the dynamo lock to simulate that the lock in s3 was acquired, dynamo failed but s3 release failed in the end. // Therefore, the user is left in the situation with s3 lock existing and dynamo missing. deleteDynamoEntry(t.Context(), t, b1.dynClient, bucketName, info.Path) - err = s1.Unlock(lockID) + err = s1.Unlock(t.Context(), lockID) if err == nil { t.Fatal("expected to get an error but got nil") } @@ -564,7 +564,7 @@ func TestForceUnlockS3WithAndDynamoWithout(t *testing.T) { } // Now, unlocking should fail with error on both locks - err = s1.Unlock(lockID) + err = s1.Unlock(t.Context(), lockID) if err == nil { t.Fatal("expected to get an error but got nil") } @@ -676,22 +676,22 @@ func TestRemoteClient_stateChecksum(t *testing.T) { client2 := s2.(*remote.State).Client // write the new state through client2 so that there is no checksum yet - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(t.Context(), newState.Bytes()); err != nil { t.Fatal(err) } // verify that we can pull a state without a checksum - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(t.Context()); err != nil { t.Fatal(err) } // write the new state back with its checksum - if err := client1.Put(newState.Bytes()); err != nil { + if err := client1.Put(t.Context(), newState.Bytes()); err != nil { t.Fatal(err) } // put an empty state in place to check for panics during get - if err := client2.Put([]byte{}); err != nil { + if err := client2.Put(t.Context(), []byte{}); err != nil { t.Fatal(err) } @@ -707,24 +707,24 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // fetching an empty state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(t.Context()); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // put the old state in place of the new, without updating the checksum - if err := client2.Put(oldState.Bytes()); err != nil { + if err := client2.Put(t.Context(), oldState.Bytes()); err != nil { t.Fatal(err) } // fetching the wrong state through client1 should now error out due to a // mismatched checksum. - if _, err := client1.Get(); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { + if _, err := client1.Get(t.Context()); !strings.HasPrefix(err.Error(), errBadChecksumFmt[:80]) { t.Fatalf("expected state checksum error: got %s", err) } // update the state with the correct one after we Get again testChecksumHook = func() { - if err := client2.Put(newState.Bytes()); err != nil { + if err := client2.Put(t.Context(), newState.Bytes()); err != nil { t.Fatal(err) } testChecksumHook = nil @@ -735,7 +735,7 @@ func TestRemoteClient_stateChecksum(t *testing.T) { // this final Get will fail to fail the checksum verification, the above // callback will update the state with the correct version, and Get should // retry automatically. - if _, err := client1.Get(); err != nil { + if _, err := client1.Get(t.Context()); err != nil { t.Fatal(err) } } @@ -821,7 +821,7 @@ func TestS3ChecksumsHeaders(t *testing.T) { name: "s3.Put with included checksum", skipChecksum: false, action: func(cl *RemoteClient) error { - return cl.Put([]byte("test")) + return cl.Put(t.Context(), []byte("test")) }, wantMissingHeaders: []string{ "X-Amz-Checksum-Mode", @@ -837,7 +837,7 @@ func TestS3ChecksumsHeaders(t *testing.T) { name: "s3.Put with skipped checksum", skipChecksum: true, action: func(cl *RemoteClient) error { - return cl.Put([]byte("test")) + return cl.Put(t.Context(), []byte("test")) }, wantMissingHeaders: []string{ "X-Amz-Checksum-Mode", @@ -854,7 +854,7 @@ func TestS3ChecksumsHeaders(t *testing.T) { name: "s3.HeadObject and s3.GetObject with included checksum", skipChecksum: false, action: func(cl *RemoteClient) error { - _, err := cl.Get() + _, err := cl.Get(t.Context()) return err }, wantMissingHeaders: []string{ @@ -871,7 +871,7 @@ func TestS3ChecksumsHeaders(t *testing.T) { name: "s3.HeadObject and s3.GetObject with skipped checksum", skipChecksum: true, action: func(cl *RemoteClient) error { - _, err := cl.Get() + _, err := cl.Get(t.Context()) return err }, wantMissingHeaders: []string{ @@ -939,7 +939,7 @@ func TestS3LockingWritingHeaders(t *testing.T) { ) // get the request from state writing { - err := rc.Put([]byte("test")) + err := rc.Put(t.Context(), []byte("test")) if err != nil { t.Fatalf("expected to have no error writing the state object but got one: %s", err) } @@ -950,7 +950,7 @@ func TestS3LockingWritingHeaders(t *testing.T) { } // get the request from lock object writing { - err := rc.s3Lock(&statemgr.LockInfo{Info: "test"}) + err := rc.s3Lock(t.Context(), &statemgr.LockInfo{Info: "test"}) if err != nil { t.Fatalf("expected to have no error writing the lock object but got one: %s", err) } diff --git a/internal/backend/remote/backend.go b/internal/backend/remote/backend.go index bc6e0b5f6d..23d84dc4a5 100644 --- a/internal/backend/remote/backend.go +++ b/internal/backend/remote/backend.go @@ -557,7 +557,7 @@ func (b *Remote) DeleteWorkspace(_ context.Context, name string, _ bool) error { encryption: b.encryption, } - return client.Delete() + return client.Delete(context.TODO()) } // StateMgr implements backend.Enhanced. diff --git a/internal/backend/remote/backend_apply_test.go b/internal/backend/remote/backend_apply_test.go index 0dea0eb592..bf7c29c289 100644 --- a/internal/backend/remote/backend_apply_test.go +++ b/internal/backend/remote/backend_apply_test.go @@ -112,7 +112,7 @@ func TestRemote_applyBasic(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } @@ -140,7 +140,7 @@ func TestRemote_applyCanceled(t *testing.T) { } stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after cancelling apply: %s", err.Error()) } } @@ -488,7 +488,7 @@ func TestRemote_applyWithExclude(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after failed apply: %s", err.Error()) } } @@ -654,7 +654,7 @@ func TestRemote_applyNoConfig(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after failed apply: %s", err.Error()) } } diff --git a/internal/backend/remote/backend_context.go b/internal/backend/remote/backend_context.go index bfb0edebc3..9ed8afbbd5 100644 --- a/internal/backend/remote/backend_context.go +++ b/internal/backend/remote/backend_context.go @@ -64,7 +64,7 @@ func (b *Remote) LocalRun(ctx context.Context, op *backend.Operation) (*backend. }() log.Printf("[TRACE] backend/remote: reading remote state for workspace %q", remoteWorkspaceName) - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf("error loading state: %w", err)) return nil, nil, diags } diff --git a/internal/backend/remote/backend_context_test.go b/internal/backend/remote/backend_context_test.go index 022e0749e7..a4599f391b 100644 --- a/internal/backend/remote/backend_context_test.go +++ b/internal/backend/remote/backend_context_test.go @@ -228,7 +228,7 @@ func TestRemoteContextWithVars(t *testing.T) { // When Context() returns an error, it should unlock the state, // so re-locking it is expected to succeed. stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state: %s", err.Error()) } } else { @@ -237,7 +237,7 @@ func TestRemoteContextWithVars(t *testing.T) { } // When Context() succeeds, this should fail w/ "workspace already locked" stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err == nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err == nil { t.Fatal("unexpected success locking state after Context") } } @@ -445,7 +445,7 @@ func TestRemoteVariablesDoNotOverride(t *testing.T) { } // When Context() succeeds, this should fail w/ "workspace already locked" stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err == nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err == nil { t.Fatal("unexpected success locking state after Context") } diff --git a/internal/backend/remote/backend_plan_test.go b/internal/backend/remote/backend_plan_test.go index 07428b9c28..bd44c9fe66 100644 --- a/internal/backend/remote/backend_plan_test.go +++ b/internal/backend/remote/backend_plan_test.go @@ -97,7 +97,7 @@ func TestRemote_planBasic(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after successful plan: %s", err.Error()) } } @@ -126,7 +126,7 @@ func TestRemote_planCanceled(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), backend.DefaultStateName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after cancelled plan: %s", err.Error()) } } diff --git a/internal/backend/remote/backend_state.go b/internal/backend/remote/backend_state.go index fe2c81a6c8..69caf71f82 100644 --- a/internal/backend/remote/backend_state.go +++ b/internal/backend/remote/backend_state.go @@ -36,9 +36,7 @@ type remoteClient struct { } // Get the remote state. -func (r *remoteClient) Get() (*remote.Payload, error) { - ctx := context.Background() - +func (r *remoteClient) Get(ctx context.Context) (*remote.Payload, error) { sv, err := r.client.StateVersions.ReadCurrent(ctx, r.workspace.ID) if err != nil { if err == tfe.ErrResourceNotFound { @@ -93,9 +91,7 @@ func (r *remoteClient) uploadStateFallback(ctx context.Context, stateFile *state } // Put the remote state. -func (r *remoteClient) Put(state []byte) error { - ctx := context.Background() - +func (r *remoteClient) Put(ctx context.Context, state []byte) error { // Read the raw state into a OpenTofu state. stateFile, err := statefile.Read(bytes.NewReader(state), r.encryption) if err != nil { @@ -145,8 +141,8 @@ func (r *remoteClient) Put(state []byte) error { } // Delete the remote state. -func (r *remoteClient) Delete() error { - err := r.client.Workspaces.Delete(context.Background(), r.organization, r.workspace.Name) +func (r *remoteClient) Delete(ctx context.Context) error { + err := r.client.Workspaces.Delete(ctx, r.organization, r.workspace.Name) if err != nil && err != tfe.ErrResourceNotFound { return fmt.Errorf("error deleting workspace %s: %w", r.workspace.Name, err) } @@ -161,9 +157,7 @@ func (r *remoteClient) EnableForcePush() { } // Lock the remote state. -func (r *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { - ctx := context.Background() - +func (r *remoteClient) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { lockErr := &statemgr.LockError{Info: r.lockInfo} // Lock the workspace. @@ -185,9 +179,7 @@ func (r *remoteClient) Lock(info *statemgr.LockInfo) (string, error) { } // Unlock the remote state. -func (r *remoteClient) Unlock(id string) error { - ctx := context.Background() - +func (r *remoteClient) Unlock(ctx context.Context, id string) error { // We first check if there was an error while uploading the latest // state. If so, we will not unlock the workspace to prevent any // changes from being applied until the correct state is uploaded. diff --git a/internal/backend/remote/backend_state_test.go b/internal/backend/remote/backend_state_test.go index 4a11eb7935..a264f0d1fd 100644 --- a/internal/backend/remote/backend_state_test.go +++ b/internal/backend/remote/backend_state_test.go @@ -60,7 +60,7 @@ func TestRemoteClient_Put_withRunID(t *testing.T) { // Store the new state to verify (this will be done // by the mock that is used) that the run ID is set. - if err := client.Put(buf.Bytes()); err != nil { + if err := client.Put(t.Context(), buf.Bytes()); err != nil { t.Fatalf("expected no error, got %v", err) } } diff --git a/internal/backend/remote/backend_test.go b/internal/backend/remote/backend_test.go index efd8d4711c..8743b598bb 100644 --- a/internal/backend/remote/backend_test.go +++ b/internal/backend/remote/backend_test.go @@ -382,13 +382,13 @@ func TestRemote_Unlock_ignoreVersion(t *testing.T) { t.Fatalf("error: %v", err) } - lockID, err := state.Lock(statemgr.NewLockInfo()) + lockID, err := state.Lock(t.Context(), statemgr.NewLockInfo()) if err != nil { t.Fatalf("error: %v", err) } // this should succeed since the version conflict is ignored - if err = state.Unlock(lockID); err != nil { + if err = state.Unlock(t.Context(), lockID); err != nil { t.Fatalf("error: %v", err) } } diff --git a/internal/backend/testing.go b/internal/backend/testing.go index b1dfd71273..68e8ff57be 100644 --- a/internal/backend/testing.go +++ b/internal/backend/testing.go @@ -165,7 +165,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatalf("error: %s", err) } - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(t.Context()); err != nil { t.Fatalf("bad: %s", err) } if v := foo.State(); v.HasManagedResourceInstanceObjects() { @@ -176,7 +176,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatalf("error: %s", err) } - if err := bar.RefreshState(); err != nil { + if err := bar.RefreshState(t.Context()); err != nil { t.Fatalf("bad: %s", err) } if v := bar.State(); v.HasManagedResourceInstanceObjects() { @@ -194,7 +194,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err := foo.WriteState(fooState); err != nil { t.Fatal("error writing foo state:", err) } - if err := foo.PersistState(nil); err != nil { + if err := foo.PersistState(t.Context(), nil); err != nil { t.Fatal("error persisting foo state:", err) } @@ -223,12 +223,12 @@ func TestBackendStates(t *testing.T, b Backend) { if err := bar.WriteState(barState); err != nil { t.Fatalf("bad: %s", err) } - if err := bar.PersistState(nil); err != nil { + if err := bar.PersistState(t.Context(), nil); err != nil { t.Fatalf("bad: %s", err) } // verify that foo is unchanged with the existing state manager - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(t.Context()); err != nil { t.Fatal("error refreshing foo:", err) } fooState = foo.State() @@ -241,7 +241,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatal("error re-fetching state:", err) } - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(t.Context()); err != nil { t.Fatal("error refreshing foo:", err) } fooState = foo.State() @@ -254,7 +254,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatal("error re-fetching state:", err) } - if err := bar.RefreshState(); err != nil { + if err := bar.RefreshState(t.Context()); err != nil { t.Fatal("error refreshing bar:", err) } barState = bar.State() @@ -298,7 +298,7 @@ func TestBackendStates(t *testing.T, b Backend) { if err != nil { t.Fatalf("error: %s", err) } - if err := foo.RefreshState(); err != nil { + if err := foo.RefreshState(t.Context()); err != nil { t.Fatalf("bad: %s", err) } if v := foo.State(); v.HasManagedResourceInstanceObjects() { @@ -371,7 +371,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo if err != nil { t.Fatalf("error: %s", err) } - if err := b1StateMgr.RefreshState(); err != nil { + if err := b1StateMgr.RefreshState(t.Context()); err != nil { t.Fatalf("bad: %s", err) } @@ -387,7 +387,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo if err != nil { t.Fatalf("error: %s", err) } - if err := b2StateMgr.RefreshState(); err != nil { + if err := b2StateMgr.RefreshState(t.Context()); err != nil { t.Fatalf("bad: %s", err) } @@ -403,7 +403,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo infoB.Operation = "test" infoB.Who = "clientB" - lockIDA, err := lockerA.Lock(infoA) + lockIDA, err := lockerA.Lock(t.Context(), infoA) if err != nil { t.Fatal("unable to get initial lock:", err) } @@ -422,17 +422,17 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo return } - _, err = lockerB.Lock(infoB) + _, err = lockerB.Lock(t.Context(), infoB) if err == nil { - _ = lockerA.Unlock(lockIDA) // test already failed, no need to check err further + _ = lockerA.Unlock(t.Context(), lockIDA) // test already failed, no need to check err further t.Fatal("client B obtained lock while held by client A") } - if err := lockerA.Unlock(lockIDA); err != nil { + if err := lockerA.Unlock(t.Context(), lockIDA); err != nil { t.Fatal("error unlocking client A", err) } - lockIDB, err := lockerB.Lock(infoB) + lockIDB, err := lockerB.Lock(t.Context(), infoB) if err != nil { t.Fatal("unable to obtain lock from client B") } @@ -441,7 +441,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo t.Errorf("duplicate lock IDs: %q", lockIDB) } - if err = lockerB.Unlock(lockIDB); err != nil { + if err = lockerB.Unlock(t.Context(), lockIDB); err != nil { t.Fatal("error unlocking client B:", err) } @@ -457,18 +457,18 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo panic(err) } - lockIDA, err = lockerA.Lock(infoA) + lockIDA, err = lockerA.Lock(t.Context(), infoA) if err != nil { t.Fatal("unable to get re lock A:", err) } unlock := func() { - err := lockerA.Unlock(lockIDA) + err := lockerA.Unlock(t.Context(), lockIDA) if err != nil { t.Fatal(err) } } - _, err = lockerB.Lock(infoB) + _, err = lockerB.Lock(t.Context(), infoB) if err == nil { unlock() t.Fatal("client B obtained lock while held by client A") @@ -481,7 +481,7 @@ func testLocksInWorkspace(t *testing.T, b1, b2 Backend, testForceUnlock bool, wo } // try to unlock with the second unlocker, using the ID from the error - if err := lockerB.Unlock(infoErr.Info.ID); err != nil { + if err := lockerB.Unlock(t.Context(), infoErr.Info.ID); err != nil { unlock() t.Fatalf("could not unlock with the reported ID %q: %s", infoErr.Info.ID, err) } diff --git a/internal/builtin/providers/tf/data_source_state.go b/internal/builtin/providers/tf/data_source_state.go index 586def1f3b..2fc542568b 100644 --- a/internal/builtin/providers/tf/data_source_state.go +++ b/internal/builtin/providers/tf/data_source_state.go @@ -146,7 +146,7 @@ func dataSourceRemoteStateRead(ctx context.Context, d cty.Value, enc encryption. return cty.NilVal, diags } - if err := state.RefreshState(); err != nil { + if err := state.RefreshState(ctx); err != nil { diags = diags.Append(err) return cty.NilVal, diags } diff --git a/internal/cloud/backend_apply_test.go b/internal/cloud/backend_apply_test.go index a79180048a..7e75d43f78 100644 --- a/internal/cloud/backend_apply_test.go +++ b/internal/cloud/backend_apply_test.go @@ -116,7 +116,7 @@ func TestCloud_applyBasic(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } @@ -174,7 +174,7 @@ func TestCloud_applyJSONBasic(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } @@ -261,7 +261,7 @@ func TestCloud_applyJSONWithOutputs(t *testing.T) { } stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } @@ -289,7 +289,7 @@ func TestCloud_applyCanceled(t *testing.T) { } stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after cancelling apply: %s", err.Error()) } } @@ -496,7 +496,7 @@ func TestCloud_applyWithCloudPlan(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } @@ -644,7 +644,7 @@ func TestCloud_applyWithExclude(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after failed apply: %s", err.Error()) } } @@ -744,7 +744,7 @@ func TestCloud_applyNoConfig(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after failed apply: %s", err.Error()) } } @@ -1405,7 +1405,7 @@ func TestCloud_applyJSONWithProvisioner(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } diff --git a/internal/cloud/backend_context.go b/internal/cloud/backend_context.go index 1bf012b881..82c365afbf 100644 --- a/internal/cloud/backend_context.go +++ b/internal/cloud/backend_context.go @@ -64,7 +64,7 @@ func (b *Cloud) LocalRun(ctx context.Context, op *backend.Operation) (*backend.L }() log.Printf("[TRACE] cloud: reading remote state for workspace %q", remoteWorkspaceName) - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf("error loading state: %w", err)) return nil, nil, diags } diff --git a/internal/cloud/backend_context_test.go b/internal/cloud/backend_context_test.go index 422c64d081..d7b87f2055 100644 --- a/internal/cloud/backend_context_test.go +++ b/internal/cloud/backend_context_test.go @@ -227,7 +227,7 @@ func TestRemoteContextWithVars(t *testing.T) { // When Context() returns an error, it should unlock the state, // so re-locking it is expected to succeed. stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state: %s", err.Error()) } } else { @@ -236,7 +236,7 @@ func TestRemoteContextWithVars(t *testing.T) { } // When Context() succeeds, this should fail w/ "workspace already locked" stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err == nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err == nil { t.Fatal("unexpected success locking state after Context") } } @@ -444,7 +444,7 @@ func TestRemoteVariablesDoNotOverride(t *testing.T) { } // When Context() succeeds, this should fail w/ "workspace already locked" stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err == nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err == nil { t.Fatal("unexpected success locking state after Context") } diff --git a/internal/cloud/backend_plan_test.go b/internal/cloud/backend_plan_test.go index e982f59c6b..2e0b96f5f3 100644 --- a/internal/cloud/backend_plan_test.go +++ b/internal/cloud/backend_plan_test.go @@ -100,7 +100,7 @@ func TestCloud_planBasic(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after successful plan: %s", err.Error()) } } @@ -145,7 +145,7 @@ func TestCloud_planJSONBasic(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after successful plan: %s", err.Error()) } } @@ -174,7 +174,7 @@ func TestCloud_planCanceled(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after cancelled plan: %s", err.Error()) } } @@ -254,7 +254,7 @@ func TestCloud_planJSONFull(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after successful plan: %s", err.Error()) } } @@ -1322,7 +1322,7 @@ func TestCloud_planImportConfigGeneration(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after the operation finished - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after successful plan: %s", err.Error()) } diff --git a/internal/cloud/backend_refresh_test.go b/internal/cloud/backend_refresh_test.go index 4577570787..9eda72baa2 100644 --- a/internal/cloud/backend_refresh_test.go +++ b/internal/cloud/backend_refresh_test.go @@ -78,7 +78,7 @@ func TestCloud_refreshBasicActuallyRunsApplyRefresh(t *testing.T) { stateMgr, _ := b.StateMgr(t.Context(), testBackendSingleWorkspaceName) // An error suggests that the state was not unlocked after apply - if _, err := stateMgr.Lock(statemgr.NewLockInfo()); err != nil { + if _, err := stateMgr.Lock(t.Context(), statemgr.NewLockInfo()); err != nil { t.Fatalf("unexpected error locking state after apply: %s", err.Error()) } } diff --git a/internal/cloud/state.go b/internal/cloud/state.go index 3b98be30db..5461664e5d 100644 --- a/internal/cloud/state.go +++ b/internal/cloud/state.go @@ -167,7 +167,7 @@ func (s *State) WriteState(state *states.State) error { } // PersistState uploads a snapshot of the latest state as a StateVersion to Terraform Cloud -func (s *State) PersistState(schemas *tofu.Schemas) error { +func (s *State) PersistState(ctx context.Context, schemas *tofu.Schemas) error { s.mu.Lock() defer s.mu.Unlock() @@ -187,7 +187,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { // We might be writing a new state altogether, but before we do that // we'll check to make sure there isn't already a snapshot present // that we ought to be updating. - err := s.refreshState() + err := s.refreshState(ctx) if err != nil { return fmt.Errorf("failed checking for existing remote state: %w", err) } @@ -234,7 +234,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { return fmt.Errorf("failed to marshal outputs to json: %w", err) } - err = s.uploadState(s.lineage, s.serial, s.forcePush, buf.Bytes(), jsonState, jsonStateOutputs) + err = s.uploadState(ctx, s.lineage, s.serial, s.forcePush, buf.Bytes(), jsonState, jsonStateOutputs) if err != nil { s.stateUploadErr = true return fmt.Errorf("error uploading state: %w", err) @@ -298,9 +298,7 @@ func (s *State) uploadStateFallback(ctx context.Context, lineage string, serial return err } -func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error { - ctx := context.Background() - +func (s *State) uploadState(ctx context.Context, lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error { options := tfe.StateVersionUploadOptions{ StateVersionCreateOptions: tfe.StateVersionCreateOptions{ Lineage: tfe.String(lineage), @@ -337,14 +335,13 @@ func (s *State) uploadState(lineage string, serial uint64, isForcePush bool, sta } // Lock calls the Client's Lock method if it's implemented. -func (s *State) Lock(info *statemgr.LockInfo) (string, error) { +func (s *State) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { s.mu.Lock() defer s.mu.Unlock() if s.disableLocks { return "", nil } - ctx := context.Background() lockErr := &statemgr.LockError{Info: s.lockInfo} @@ -367,16 +364,16 @@ func (s *State) Lock(info *statemgr.LockInfo) (string, error) { } // statemgr.Refresher impl. -func (s *State) RefreshState() error { +func (s *State) RefreshState(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() - return s.refreshState() + return s.refreshState(ctx) } // refreshState is the main implementation of RefreshState, but split out so // that we can make internal calls to it from methods that are already holding // the s.mu lock. -func (s *State) refreshState() error { +func (s *State) refreshState(ctx context.Context) error { payload, err := s.getStatePayload() if err != nil { return err @@ -443,7 +440,7 @@ func (s *State) getStatePayload() (*remote.Payload, error) { } // Unlock calls the Client's Unlock method if it's implemented. -func (s *State) Unlock(id string) error { +func (s *State) Unlock(ctx context.Context, id string) error { s.mu.Lock() defer s.mu.Unlock() @@ -451,8 +448,6 @@ func (s *State) Unlock(id string) error { return nil } - ctx := context.Background() - // We first check if there was an error while uploading the latest // state. If so, we will not unlock the workspace to prevent any // changes from being applied until the correct state is uploaded. @@ -521,9 +516,7 @@ func (s *State) Delete(force bool) error { } // GetRootOutputValues fetches output values from Terraform Cloud -func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { - ctx := context.Background() - +func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { so, err := s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID) if err != nil { @@ -540,7 +533,7 @@ func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { // requires a higher level of authorization. log.Printf("[DEBUG] falling back to reading full state") - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(ctx); err != nil { return nil, fmt.Errorf("failed to load state: %w", err) } diff --git a/internal/cloud/state_test.go b/internal/cloud/state_test.go index e3981e92a4..3d315da152 100644 --- a/internal/cloud/state_test.go +++ b/internal/cloud/state_test.go @@ -45,7 +45,7 @@ func TestState_GetRootOutputValues(t *testing.T) { state := &State{tfeClient: b.client, organization: b.organization, workspace: &tfe.Workspace{ ID: "ws-abcd", }, encryption: encryption.StateEncryptionDisabled()} - outputs, err := state.GetRootOutputValues() + outputs, err := state.GetRootOutputValues(t.Context()) if err != nil { t.Fatalf("error returned from GetRootOutputValues: %s", err) @@ -119,7 +119,7 @@ func TestState(t *testing.T) { } }`) - if err := state.uploadState(state.lineage, state.serial, state.forcePush, data, jsonState, jsonStateOutputs); err != nil { + if err := state.uploadState(t.Context(), state.lineage, state.serial, state.forcePush, data, jsonState, jsonStateOutputs); err != nil { t.Fatalf("put: %s", err) } @@ -175,25 +175,25 @@ func TestCloudLocks(t *testing.T) { infoB.Operation = "test" infoB.Who = "clientB" - lockIDA, err := lockerA.Lock(infoA) + lockIDA, err := lockerA.Lock(t.Context(), infoA) if err != nil { t.Fatal("unable to get initial lock:", err) } - _, err = lockerB.Lock(infoB) + _, err = lockerB.Lock(t.Context(), infoB) if err == nil { - _ = lockerA.Unlock(lockIDA) // test already failed, no need to check err further + _ = lockerA.Unlock(t.Context(), lockIDA) // test already failed, no need to check err further t.Fatal("client B obtained lock while held by client A") } if _, ok := err.(*statemgr.LockError); !ok { t.Errorf("expected a LockError, but was %t: %s", err, err) } - if err := lockerA.Unlock(lockIDA); err != nil { + if err := lockerA.Unlock(t.Context(), lockIDA); err != nil { t.Fatal("error unlocking client A", err) } - lockIDB, err := lockerB.Lock(infoB) + lockIDB, err := lockerB.Lock(t.Context(), infoB) if err != nil { t.Fatal("unable to obtain lock from client B") } @@ -202,7 +202,7 @@ func TestCloudLocks(t *testing.T) { t.Fatalf("duplicate lock IDs: %q", lockIDB) } - if err = lockerB.Unlock(lockIDB); err != nil { + if err = lockerB.Unlock(t.Context(), lockIDB); err != nil { t.Fatal("error unlocking client B:", err) } } @@ -271,7 +271,7 @@ func TestState_PersistState(t *testing.T) { t.Fatal("expected nil initial readState") } - err := cloudState.PersistState(nil) + err := cloudState.PersistState(t.Context(), nil) if err != nil { t.Fatalf("expected no error, got %q", err) } @@ -331,7 +331,7 @@ func TestState_PersistState(t *testing.T) { } cloudState.tfeClient = client - err = cloudState.RefreshState() + err = cloudState.RefreshState(t.Context()) if err != nil { t.Fatal(err) } @@ -345,7 +345,7 @@ func TestState_PersistState(t *testing.T) { t.Fatal(err) } - err = cloudState.PersistState(nil) + err = cloudState.PersistState(t.Context(), nil) if err != nil { t.Fatal(err) } diff --git a/internal/command/apply_test.go b/internal/command/apply_test.go index 32368ac34d..bb8a2b5dce 100644 --- a/internal/command/apply_test.go +++ b/internal/command/apply_test.go @@ -443,7 +443,7 @@ func TestApply_defaultState(t *testing.T) { } // create an existing state file - if err := statemgr.WriteAndPersist(statemgr.NewFilesystem(statePath, encryption.StateEncryptionDisabled()), states.NewState(), nil); err != nil { + if err := statemgr.WriteAndPersist(t.Context(), statemgr.NewFilesystem(statePath, encryption.StateEncryptionDisabled()), states.NewState(), nil); err != nil { t.Fatal(err) } @@ -733,7 +733,7 @@ func TestApply_plan_backup(t *testing.T) { // create a state file that needs to be backed up fs := statemgr.NewFilesystem(statePath, encryption.StateEncryptionDisabled()) fs.StateSnapshotMeta() - if err := statemgr.WriteAndPersist(fs, states.NewState(), nil); err != nil { + if err := statemgr.WriteAndPersist(t.Context(), fs, states.NewState(), nil); err != nil { t.Fatal(err) } diff --git a/internal/command/clistate/local_state.go b/internal/command/clistate/local_state.go index 39fbacdd42..e72f86b4bf 100644 --- a/internal/command/clistate/local_state.go +++ b/internal/command/clistate/local_state.go @@ -7,6 +7,7 @@ package clistate import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -123,12 +124,12 @@ func (s *LocalState) writeState(state *tofu.State) error { // PersistState for LocalState is a no-op since WriteState always persists. // // StatePersister impl. -func (s *LocalState) PersistState() error { +func (s *LocalState) PersistState(_ context.Context) error { return nil } // StateRefresher impl. -func (s *LocalState) RefreshState() error { +func (s *LocalState) RefreshState(_ context.Context) error { s.mu.Lock() defer s.mu.Unlock() @@ -187,7 +188,7 @@ func (s *LocalState) RefreshState() error { } // Lock implements a local filesystem state.Locker. -func (s *LocalState) Lock(info *statemgr.LockInfo) (string, error) { +func (s *LocalState) Lock(_ context.Context, info *statemgr.LockInfo) (string, error) { s.mu.Lock() defer s.mu.Unlock() @@ -219,7 +220,7 @@ func (s *LocalState) Lock(info *statemgr.LockInfo) (string, error) { return s.lockID, s.writeLockInfo(info) } -func (s *LocalState) Unlock(id string) error { +func (s *LocalState) Unlock(_ context.Context, id string) error { s.mu.Lock() defer s.mu.Unlock() diff --git a/internal/command/clistate/state.go b/internal/command/clistate/state.go index 9d9235c899..4c7cebc554 100644 --- a/internal/command/clistate/state.go +++ b/internal/command/clistate/state.go @@ -150,7 +150,7 @@ func (l *locker) Unlock() tfdiags.Diagnostics { } err := slowmessage.Do(LockThreshold, func() error { - return l.state.Unlock(l.lockID) + return l.state.Unlock(l.ctx, l.lockID) }, l.view.Unlocking) if err != nil { diff --git a/internal/command/import.go b/internal/command/import.go index 08159aea2c..e640ad223c 100644 --- a/internal/command/import.go +++ b/internal/command/import.go @@ -6,6 +6,7 @@ package command import ( + "context" "errors" "fmt" "log" @@ -281,7 +282,7 @@ func (c *ImportCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } - if err := state.PersistState(schemas); err != nil { + if err := state.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } diff --git a/internal/command/init.go b/internal/command/init.go index e094e99a0c..353bcbba8a 100644 --- a/internal/command/init.go +++ b/internal/command/init.go @@ -262,7 +262,7 @@ func (c *InitCommand) Run(args []string) int { return 1 } - if err := sMgr.RefreshState(); err != nil { + if err := sMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Error refreshing state: %s", err)) return 1 } diff --git a/internal/command/init_test.go b/internal/command/init_test.go index 462e90816b..13b4a3ef17 100644 --- a/internal/command/init_test.go +++ b/internal/command/init_test.go @@ -1257,7 +1257,7 @@ func TestInit_inputFalse(t *testing.T) { "", ) }) - if err := statemgr.WriteAndPersist(statemgr.NewFilesystem("foo", encryption.StateEncryptionDisabled()), fooState, nil); err != nil { + if err := statemgr.WriteAndPersist(t.Context(), statemgr.NewFilesystem("foo", encryption.StateEncryptionDisabled()), fooState, nil); err != nil { t.Fatal(err) } barState := states.BuildState(func(s *states.SyncState) { @@ -1268,7 +1268,7 @@ func TestInit_inputFalse(t *testing.T) { "", ) }) - if err := statemgr.WriteAndPersist(statemgr.NewFilesystem("bar", encryption.StateEncryptionDisabled()), barState, nil); err != nil { + if err := statemgr.WriteAndPersist(t.Context(), statemgr.NewFilesystem("bar", encryption.StateEncryptionDisabled()), barState, nil); err != nil { t.Fatal(err) } diff --git a/internal/command/meta_backend.go b/internal/command/meta_backend.go index 7165c2b561..31ad141b0a 100644 --- a/internal/command/meta_backend.go +++ b/internal/command/meta_backend.go @@ -565,7 +565,7 @@ func (m *Meta) backendFromConfig(ctx context.Context, opts *BackendOpts, enc enc // we haven't used a non-local backend before. That is okay. statePath := filepath.Join(m.DataDir(), DefaultStateFilename) sMgr := &clistate.LocalState{Path: statePath} - if err := sMgr.RefreshState(); err != nil { + if err := sMgr.RefreshState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf("Failed to load state: %w", err)) return nil, diags } @@ -793,7 +793,7 @@ func (m *Meta) backendFromState(ctx context.Context, enc encryption.StateEncrypt // we haven't used a non-local backend before. That is okay. statePath := filepath.Join(m.DataDir(), DefaultStateFilename) sMgr := &clistate.LocalState{Path: statePath} - if err := sMgr.RefreshState(); err != nil { + if err := sMgr.RefreshState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf("Failed to load state: %w", err)) return nil, diags } @@ -943,7 +943,7 @@ func (m *Meta) backend_c_r_S( diags = diags.Append(fmt.Errorf(strings.TrimSpace(errBackendClearSaved), err)) return nil, diags } - if err := sMgr.PersistState(); err != nil { + if err := sMgr.PersistState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf(strings.TrimSpace(errBackendClearSaved), err)) return nil, diags } @@ -989,7 +989,7 @@ func (m *Meta) backend_C_r_s(ctx context.Context, c *configs.Backend, cHash int, diags = diags.Append(fmt.Errorf(errBackendLocalRead, err)) return nil, diags } - if err := localState.RefreshState(); err != nil { + if err := localState.RefreshState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf(errBackendLocalRead, err)) return nil, diags } @@ -1052,7 +1052,7 @@ func (m *Meta) backend_C_r_s(ctx context.Context, c *configs.Backend, cHash int, diags = diags.Append(fmt.Errorf(errBackendMigrateLocalDelete, err)) return nil, diags } - if err := localState.PersistState(nil); err != nil { + if err := localState.PersistState(context.TODO(), nil); err != nil { diags = diags.Append(fmt.Errorf(errBackendMigrateLocalDelete, err)) return nil, diags } @@ -1112,7 +1112,7 @@ func (m *Meta) backend_C_r_s(ctx context.Context, c *configs.Backend, cHash int, diags = diags.Append(fmt.Errorf(errBackendWriteSaved, err)) return nil, diags } - if err := sMgr.PersistState(); err != nil { + if err := sMgr.PersistState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf(errBackendWriteSaved, err)) return nil, diags } @@ -1244,7 +1244,7 @@ func (m *Meta) backend_C_r_S_changed(ctx context.Context, c *configs.Backend, cH diags = diags.Append(fmt.Errorf(errBackendWriteSaved, err)) return nil, diags } - if err := sMgr.PersistState(); err != nil { + if err := sMgr.PersistState(context.TODO()); err != nil { diags = diags.Append(fmt.Errorf(errBackendWriteSaved, err)) return nil, diags } diff --git a/internal/command/meta_backend_migrate.go b/internal/command/meta_backend_migrate.go index 80bde01f6b..b08b826f68 100644 --- a/internal/command/meta_backend_migrate.go +++ b/internal/command/meta_backend_migrate.go @@ -270,7 +270,7 @@ func (m *Meta) backendMigrateState_s_s(ctx context.Context, opts *backendMigrate return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(context.TODO()); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } @@ -318,7 +318,7 @@ func (m *Meta) backendMigrateState_s_s(ctx context.Context, opts *backendMigrate return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.DestinationType, err) } - if err := destinationState.RefreshState(); err != nil { + if err := destinationState.RefreshState(context.TODO()); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.DestinationType, err) } @@ -371,12 +371,12 @@ func (m *Meta) backendMigrateState_s_s(ctx context.Context, opts *backendMigrate // We now own a lock, so double check that we have the version // corresponding to the lock. log.Print("[TRACE] backendMigrateState: refreshing source workspace state") - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(context.TODO()); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } log.Print("[TRACE] backendMigrateState: refreshing destination workspace state") - if err := destinationState.RefreshState(); err != nil { + if err := destinationState.RefreshState(context.TODO()); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } @@ -455,7 +455,7 @@ func (m *Meta) backendMigrateState_s_s(ctx context.Context, opts *backendMigrate // so requiring schemas here could lead to a catch-22 where it requires some manual // intervention to proceed far enough for provider installation. To avoid this, // when migrating to TFC backend, the initial JSON variant of state won't be generated and stored. - if err := destinationState.PersistState(nil); err != nil { + if err := destinationState.PersistState(context.TODO(), nil); err != nil { return fmt.Errorf(strings.TrimSpace(errBackendStateCopy), opts.SourceType, opts.DestinationType, err) } @@ -500,7 +500,7 @@ func (m *Meta) backendMigrateNonEmptyConfirm( // Helper to write the state saveHelper := func(n, path string, s *states.State) error { - return statemgr.WriteAndPersist(statemgr.NewFilesystem(path, encryption.StateEncryptionDisabled()), s, nil) + return statemgr.WriteAndPersist(context.TODO(), statemgr.NewFilesystem(path, encryption.StateEncryptionDisabled()), s, nil) } // Write the states @@ -602,7 +602,7 @@ func (m *Meta) backendMigrateTFC(ctx context.Context, opts *backendMigrateOpts) if err != nil { return err } - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(context.TODO()); err != nil { return err } if sourceState.State().Empty() { @@ -690,7 +690,7 @@ func (m *Meta) backendMigrateState_S_TFC(ctx context.Context, opts *backendMigra errMigrateSingleLoadDefault), opts.SourceType, err) } // RefreshState is what actually pulls the state to be evaluated. - if err := sourceState.RefreshState(); err != nil { + if err := sourceState.RefreshState(context.TODO()); err != nil { return fmt.Errorf(strings.TrimSpace( errMigrateSingleLoadDefault), opts.SourceType, err) } diff --git a/internal/command/meta_backend_test.go b/internal/command/meta_backend_test.go index 45681cc07d..ea344c257e 100644 --- a/internal/command/meta_backend_test.go +++ b/internal/command/meta_backend_test.go @@ -56,7 +56,7 @@ func TestMetaBackend_emptyDir(t *testing.T) { if err := s.WriteState(testState()); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -123,7 +123,7 @@ func TestMetaBackend_emptyWithDefaultState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("err: %s", err) } if actual := s.State().String(); actual != testState().String() { @@ -146,7 +146,7 @@ func TestMetaBackend_emptyWithDefaultState(t *testing.T) { if err := s.WriteState(next); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -194,7 +194,7 @@ func TestMetaBackend_emptyWithExplicitState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("err: %s", err) } if actual := s.State().String(); actual != testState().String() { @@ -217,7 +217,7 @@ func TestMetaBackend_emptyWithExplicitState(t *testing.T) { if err := s.WriteState(next); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -264,7 +264,7 @@ func TestMetaBackend_configureNew(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -279,7 +279,7 @@ func TestMetaBackend_configureNew(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -337,7 +337,7 @@ func TestMetaBackend_configureNewWithState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - state, err := statemgr.RefreshAndRead(s) + state, err := statemgr.RefreshAndRead(t.Context(), s) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -353,7 +353,7 @@ func TestMetaBackend_configureNewWithState(t *testing.T) { state = states.NewState() mark := markStateForMatching(state, "changing") - if err := statemgr.WriteAndPersist(s, state, nil); err != nil { + if err := statemgr.WriteAndPersist(t.Context(), s, state, nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -460,7 +460,7 @@ func TestMetaBackend_configureNewWithStateNoMigrate(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } if state := s.State(); state != nil { @@ -503,7 +503,7 @@ func TestMetaBackend_configureNewWithStateExisting(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -521,7 +521,7 @@ func TestMetaBackend_configureNewWithStateExisting(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -577,7 +577,7 @@ func TestMetaBackend_configureNewWithStateExistingNoMigrate(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -594,7 +594,7 @@ func TestMetaBackend_configureNewWithStateExistingNoMigrate(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -644,7 +644,7 @@ func TestMetaBackend_configuredUnchanged(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -788,7 +788,7 @@ func TestMetaBackend_configuredUnchangedWithStaticEvalVars(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -834,7 +834,7 @@ func TestMetaBackend_configuredChange(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -859,7 +859,7 @@ func TestMetaBackend_configuredChange(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -922,7 +922,7 @@ func TestMetaBackend_reconfigureChange(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } newState := s.State() @@ -932,7 +932,7 @@ func TestMetaBackend_reconfigureChange(t *testing.T) { // verify that the old state is still there s = statemgr.NewFilesystem("local-state.tfstate", encryption.StateEncryptionDisabled()) - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatal(err) } oldState := s.State() @@ -1062,7 +1062,7 @@ func TestMetaBackend_configuredChangeCopy(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1115,7 +1115,7 @@ func TestMetaBackend_configuredChangeCopy_singleState(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1169,7 +1169,7 @@ func TestMetaBackend_configuredChangeCopy_multiToSingleDefault(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1223,7 +1223,7 @@ func TestMetaBackend_configuredChangeCopy_multiToSingle(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1297,7 +1297,7 @@ func TestMetaBackend_configuredChangeCopy_multiToSingleCurrentEnv(t *testing.T) if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1365,7 +1365,7 @@ func TestMetaBackend_configuredChangeCopy_multiToMulti(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1383,7 +1383,7 @@ func TestMetaBackend_configuredChangeCopy_multiToMulti(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1463,7 +1463,7 @@ func TestMetaBackend_configuredChangeCopy_multiToNoDefaultWithDefault(t *testing if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1537,7 +1537,7 @@ func TestMetaBackend_configuredChangeCopy_multiToNoDefaultWithoutDefault(t *test if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1590,7 +1590,7 @@ func TestMetaBackend_configuredUnset(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1614,7 +1614,7 @@ func TestMetaBackend_configuredUnset(t *testing.T) { if err := s.WriteState(testState()); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1654,7 +1654,7 @@ func TestMetaBackend_configuredUnsetCopy(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1674,7 +1674,7 @@ func TestMetaBackend_configuredUnsetCopy(t *testing.T) { if err := s.WriteState(testState()); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1724,7 +1724,7 @@ func TestMetaBackend_planLocal(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1755,7 +1755,7 @@ func TestMetaBackend_planLocal(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1807,7 +1807,7 @@ func TestMetaBackend_planLocalStatePath(t *testing.T) { statePath := "foo.tfstate" // put an initial state there that needs to be backed up - err = statemgr.WriteAndPersist(statemgr.NewFilesystem(statePath, encryption.StateEncryptionDisabled()), original, nil) + err = statemgr.WriteAndPersist(t.Context(), statemgr.NewFilesystem(statePath, encryption.StateEncryptionDisabled()), original, nil) if err != nil { t.Fatal(err) } @@ -1827,7 +1827,7 @@ func TestMetaBackend_planLocalStatePath(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1858,7 +1858,7 @@ func TestMetaBackend_planLocalStatePath(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } @@ -1918,7 +1918,7 @@ func TestMetaBackend_planLocalMatch(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %s", err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("unexpected error: %s", err) } state := s.State() @@ -1947,7 +1947,7 @@ func TestMetaBackend_planLocalMatch(t *testing.T) { if err := s.WriteState(state); err != nil { t.Fatal(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("unexpected error: %s", err) } diff --git a/internal/command/output.go b/internal/command/output.go index 169bd6aa27..dc008e023b 100644 --- a/internal/command/output.go +++ b/internal/command/output.go @@ -103,7 +103,7 @@ func (c *OutputCommand) Outputs(ctx context.Context, statePath string, enc encry return nil, diags } - output, err := stateStore.GetRootOutputValues() + output, err := stateStore.GetRootOutputValues(context.TODO()) if err != nil { return nil, diags.Append(err) } diff --git a/internal/command/providers.go b/internal/command/providers.go index 20151f50c5..b3876d5732 100644 --- a/internal/command/providers.go +++ b/internal/command/providers.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "path/filepath" "strings" @@ -117,7 +118,7 @@ func (c *ProvidersCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } diff --git a/internal/command/refresh_test.go b/internal/command/refresh_test.go index b2ccc260e8..1df61dd7b1 100644 --- a/internal/command/refresh_test.go +++ b/internal/command/refresh_test.go @@ -238,7 +238,7 @@ func TestRefresh_defaultState(t *testing.T) { statePath := testStateFile(t, originalState) localState := statemgr.NewFilesystem(statePath, encryption.StateEncryptionDisabled()) - if err := localState.RefreshState(); err != nil { + if err := localState.RefreshState(t.Context()); err != nil { t.Fatal(err) } s := localState.State() diff --git a/internal/command/show.go b/internal/command/show.go index 0c4a0cf324..b8863de30e 100644 --- a/internal/command/show.go +++ b/internal/command/show.go @@ -521,7 +521,7 @@ func getStateFromBackend(ctx context.Context, b backend.Backend, workspace strin } // Refresh the state store with the latest state snapshot from persistent storage - if err := stateStore.RefreshState(); err != nil { + if err := stateStore.RefreshState(context.TODO()); err != nil { tracing.SetSpanError(span, err) return nil, fmt.Errorf("failed to load state: %w", err) } diff --git a/internal/command/state_list.go b/internal/command/state_list.go index 59738c8471..c60e1e3ceb 100644 --- a/internal/command/state_list.go +++ b/internal/command/state_list.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" @@ -70,7 +71,7 @@ func (c *StateListCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateLoadingState, err)) return 1 } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } diff --git a/internal/command/state_mv.go b/internal/command/state_mv.go index 4c570f6860..3b21a3eb31 100644 --- a/internal/command/state_mv.go +++ b/internal/command/state_mv.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" @@ -120,7 +121,7 @@ func (c *StateMvCommand) Run(args []string) int { }() } - if err := stateFromMgr.RefreshState(); err != nil { + if err := stateFromMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh source state: %s", err)) return 1 } @@ -158,7 +159,7 @@ func (c *StateMvCommand) Run(args []string) int { }() } - if err := stateToMgr.RefreshState(); err != nil { + if err := stateToMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh destination state: %s", err)) return 1 } @@ -420,7 +421,7 @@ func (c *StateMvCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateToMgr.PersistState(schemas); err != nil { + if err := stateToMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } @@ -431,7 +432,7 @@ func (c *StateMvCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateFromMgr.PersistState(schemas); err != nil { + if err := stateFromMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } diff --git a/internal/command/state_pull.go b/internal/command/state_pull.go index cbe50109c8..4e3059151f 100644 --- a/internal/command/state_pull.go +++ b/internal/command/state_pull.go @@ -7,6 +7,7 @@ package command import ( "bytes" + "context" "fmt" "strings" @@ -65,7 +66,7 @@ func (c *StatePullCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateLoadingState, err)) return 1 } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh state: %s", err)) return 1 } diff --git a/internal/command/state_push.go b/internal/command/state_push.go index b33abc38fa..7bf4918143 100644 --- a/internal/command/state_push.go +++ b/internal/command/state_push.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "io" "os" @@ -128,7 +129,7 @@ func (c *StatePushCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh destination state: %s", err)) return 1 } @@ -155,7 +156,7 @@ func (c *StatePushCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Failed to write state: %s", err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf("Failed to persist state: %s", err)) return 1 } diff --git a/internal/command/state_push_test.go b/internal/command/state_push_test.go index ea361dafbf..de31f3826d 100644 --- a/internal/command/state_push_test.go +++ b/internal/command/state_push_test.go @@ -273,7 +273,7 @@ func TestStatePush_forceRemoteState(t *testing.T) { if err := sMgr.WriteState(states.NewState()); err != nil { t.Fatal(err) } - if err := sMgr.PersistState(nil); err != nil { + if err := sMgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } diff --git a/internal/command/state_replace_provider.go b/internal/command/state_replace_provider.go index 0ee0f9924e..491cc57e6c 100644 --- a/internal/command/state_replace_provider.go +++ b/internal/command/state_replace_provider.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" @@ -110,7 +111,7 @@ func (c *StateReplaceProviderCommand) Run(args []string) int { } // Refresh and load state - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh source state: %s", err)) return 1 } @@ -197,7 +198,7 @@ func (c *StateReplaceProviderCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } diff --git a/internal/command/state_rm.go b/internal/command/state_rm.go index 424485216f..0e39fddbab 100644 --- a/internal/command/state_rm.go +++ b/internal/command/state_rm.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" @@ -77,7 +78,7 @@ func (c *StateRmCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to refresh state: %s", err)) return 1 } @@ -144,7 +145,7 @@ func (c *StateRmCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf(errStateRmPersist, err)) return 1 } diff --git a/internal/command/state_show.go b/internal/command/state_show.go index ed7a00e79c..93b44d5dbd 100644 --- a/internal/command/state_show.go +++ b/internal/command/state_show.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "os" "strings" @@ -138,7 +139,7 @@ func (c *StateShowCommand) Run(args []string) int { c.Streams.Eprintln(fmt.Sprintf(errStateLoadingState, err)) return 1 } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Streams.Eprintf("Failed to refresh state: %s\n", err) return 1 } diff --git a/internal/command/taint.go b/internal/command/taint.go index c3ecf704ca..247b8bea99 100644 --- a/internal/command/taint.go +++ b/internal/command/taint.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" @@ -118,7 +119,7 @@ func (c *TaintCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } @@ -193,7 +194,7 @@ func (c *TaintCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } diff --git a/internal/command/testdata/statelocker.go b/internal/command/testdata/statelocker.go index b8114b2a6f..8ef7b0564a 100644 --- a/internal/command/testdata/statelocker.go +++ b/internal/command/testdata/statelocker.go @@ -4,6 +4,7 @@ package main import ( + "context" "io" "log" "os" @@ -28,7 +29,7 @@ func main() { info.Operation = "test" info.Info = "state locker" - lockID, err := s.Lock(info) + lockID, err := s.Lock(context.Background(), info) if err != nil { io.WriteString(os.Stderr, err.Error()) return @@ -38,7 +39,7 @@ func main() { io.WriteString(os.Stdout, "LOCKID "+lockID) defer func() { - if err := s.Unlock(lockID); err != nil { + if err := s.Unlock(context.Background(), lockID); err != nil { io.WriteString(os.Stderr, err.Error()) } }() diff --git a/internal/command/unlock.go b/internal/command/unlock.go index fcbfe08abf..a6c6e958e1 100644 --- a/internal/command/unlock.go +++ b/internal/command/unlock.go @@ -132,7 +132,7 @@ func (c *UnlockCommand) Run(args []string) int { } } - if err := stateMgr.Unlock(lockID); err != nil { + if err := stateMgr.Unlock(context.TODO(), lockID); err != nil { c.Ui.Error(fmt.Sprintf("Failed to unlock state: %s", err)) return 1 } diff --git a/internal/command/untaint.go b/internal/command/untaint.go index dedf8508cf..53c6fc2844 100644 --- a/internal/command/untaint.go +++ b/internal/command/untaint.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" @@ -109,7 +110,7 @@ func (c *UntaintCommand) Run(args []string) int { }() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { c.Ui.Error(fmt.Sprintf("Failed to load state: %s", err)) return 1 } @@ -194,7 +195,7 @@ func (c *UntaintCommand) Run(args []string) int { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } - if err := stateMgr.PersistState(schemas); err != nil { + if err := stateMgr.PersistState(context.TODO(), schemas); err != nil { c.Ui.Error(fmt.Sprintf("Error writing state file: %s", err)) return 1 } diff --git a/internal/command/workspace_command_test.go b/internal/command/workspace_command_test.go index 8d1bf3d157..761859f480 100644 --- a/internal/command/workspace_command_test.go +++ b/internal/command/workspace_command_test.go @@ -252,7 +252,7 @@ func TestWorkspace_createWithState(t *testing.T) { ) }) - err := statemgr.WriteAndPersist(statemgr.NewFilesystem("test.tfstate", encryption.StateEncryptionDisabled()), originalState, nil) + err := statemgr.WriteAndPersist(t.Context(), statemgr.NewFilesystem("test.tfstate", encryption.StateEncryptionDisabled()), originalState, nil) if err != nil { t.Fatal(err) } @@ -270,7 +270,7 @@ func TestWorkspace_createWithState(t *testing.T) { newPath := filepath.Join(local.DefaultWorkspaceDir, "test", DefaultStateFilename) envState := statemgr.NewFilesystem(newPath, encryption.StateEncryptionDisabled()) - err = envState.RefreshState() + err = envState.RefreshState(t.Context()) if err != nil { t.Fatal(err) } diff --git a/internal/command/workspace_delete.go b/internal/command/workspace_delete.go index a9f0f44aea..9b17a6746d 100644 --- a/internal/command/workspace_delete.go +++ b/internal/command/workspace_delete.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "strings" "time" @@ -134,7 +135,7 @@ func (c *WorkspaceDeleteCommand) Run(args []string) int { stateLocker = clistate.NewNoopLocker() } - if err := stateMgr.RefreshState(); err != nil { + if err := stateMgr.RefreshState(context.TODO()); err != nil { // We need to release the lock before exit stateLocker.Unlock() c.Ui.Error(err.Error()) diff --git a/internal/command/workspace_new.go b/internal/command/workspace_new.go index 4261182e8a..126290df42 100644 --- a/internal/command/workspace_new.go +++ b/internal/command/workspace_new.go @@ -6,6 +6,7 @@ package command import ( + "context" "fmt" "os" "strings" @@ -173,7 +174,7 @@ func (c *WorkspaceNewCommand) Run(args []string) int { c.Ui.Error(err.Error()) return 1 } - err = stateMgr.PersistState(nil) + err = stateMgr.PersistState(context.TODO(), nil) if err != nil { c.Ui.Error(err.Error()) return 1 diff --git a/internal/states/remote/remote.go b/internal/states/remote/remote.go index 6fe46954e3..61c3806e54 100644 --- a/internal/states/remote/remote.go +++ b/internal/states/remote/remote.go @@ -6,6 +6,8 @@ package remote import ( + "context" + "github.com/opentofu/opentofu/internal/states/statemgr" ) @@ -13,9 +15,9 @@ import ( // driver. It supports dumb put/get/delete, and the higher level structs // handle persisting the state properly here. type Client interface { - Get() (*Payload, error) - Put([]byte) error - Delete() error + Get(context.Context) (*Payload, error) + Put(context.Context, []byte) error + Delete(context.Context) error } // ClientForcePusher is an optional interface that allows a remote diff --git a/internal/states/remote/remote_test.go b/internal/states/remote/remote_test.go index 75859eaf3b..58e225e2ca 100644 --- a/internal/states/remote/remote_test.go +++ b/internal/states/remote/remote_test.go @@ -6,6 +6,7 @@ package remote import ( + "context" "crypto/md5" "encoding/json" "testing" @@ -15,7 +16,7 @@ func TestRemoteClient_noPayload(t *testing.T) { s := &State{ Client: nilClient{}, } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatal("error refreshing empty remote state") } } @@ -23,11 +24,11 @@ func TestRemoteClient_noPayload(t *testing.T) { // nilClient returns nil for everything type nilClient struct{} -func (nilClient) Get() (*Payload, error) { return nil, nil } +func (nilClient) Get(context.Context) (*Payload, error) { return nil, nil } -func (c nilClient) Put([]byte) error { return nil } +func (c nilClient) Put(context.Context, []byte) error { return nil } -func (c nilClient) Delete() error { return nil } +func (c nilClient) Delete(context.Context) error { return nil } // mockClient is a client that tracks persisted state snapshots only in // memory and also logs what it has been asked to do for use in test @@ -42,7 +43,7 @@ type mockClientRequest struct { Content map[string]interface{} } -func (c *mockClient) Get() (*Payload, error) { +func (c *mockClient) Get(_ context.Context) (*Payload, error) { c.appendLog("Get", c.current) if c.current == nil { return nil, nil @@ -54,13 +55,13 @@ func (c *mockClient) Get() (*Payload, error) { }, nil } -func (c *mockClient) Put(data []byte) error { +func (c *mockClient) Put(_ context.Context, data []byte) error { c.appendLog("Put", data) c.current = data return nil } -func (c *mockClient) Delete() error { +func (c *mockClient) Delete(_ context.Context) error { c.appendLog("Delete", c.current) c.current = nil return nil @@ -91,7 +92,7 @@ type mockClientForcePusher struct { log []mockClientRequest } -func (c *mockClientForcePusher) Get() (*Payload, error) { +func (c *mockClientForcePusher) Get(_ context.Context) (*Payload, error) { c.appendLog("Get", c.current) if c.current == nil { return nil, nil @@ -103,7 +104,7 @@ func (c *mockClientForcePusher) Get() (*Payload, error) { }, nil } -func (c *mockClientForcePusher) Put(data []byte) error { +func (c *mockClientForcePusher) Put(_ context.Context, data []byte) error { if c.force { c.appendLog("Force Put", data) } else { @@ -118,7 +119,7 @@ func (c *mockClientForcePusher) EnableForcePush() { c.force = true } -func (c *mockClientForcePusher) Delete() error { +func (c *mockClientForcePusher) Delete(_ context.Context) error { c.appendLog("Delete", c.current) c.current = nil return nil diff --git a/internal/states/remote/state.go b/internal/states/remote/state.go index b93e0a9084..9dae13f348 100644 --- a/internal/states/remote/state.go +++ b/internal/states/remote/state.go @@ -7,6 +7,7 @@ package remote import ( "bytes" + "context" "fmt" "log" "sync" @@ -77,8 +78,8 @@ func (s *State) State() *states.State { return s.state.DeepCopy() } -func (s *State) GetRootOutputValues() (map[string]*states.OutputValue, error) { - if err := s.RefreshState(); err != nil { +func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + if err := s.RefreshState(ctx); err != nil { return nil, fmt.Errorf("Failed to load state: %w", err) } @@ -142,17 +143,17 @@ func (s *State) WriteStateForMigration(f *statefile.File, force bool) error { } // statemgr.Refresher impl. -func (s *State) RefreshState() error { +func (s *State) RefreshState(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() - return s.refreshState() + return s.refreshState(ctx) } // refreshState is the main implementation of RefreshState, but split out so // that we can make internal calls to it from methods that are already holding // the s.mu lock. -func (s *State) refreshState() error { - payload, err := s.Client.Get() +func (s *State) refreshState(ctx context.Context) error { + payload, err := s.Client.Get(ctx) if err != nil { return err } @@ -184,7 +185,7 @@ func (s *State) refreshState() error { } // statemgr.Persister impl. -func (s *State) PersistState(schemas *tofu.Schemas) error { +func (s *State) PersistState(ctx context.Context, schemas *tofu.Schemas) error { s.mu.Lock() defer s.mu.Unlock() @@ -204,7 +205,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { // We might be writing a new state altogether, but before we do that // we'll check to make sure there isn't already a snapshot present // that we ought to be updating. - err := s.refreshState() + err := s.refreshState(ctx) if err != nil { return fmt.Errorf("failed checking for existing remote state: %w", err) } @@ -228,7 +229,7 @@ func (s *State) PersistState(schemas *tofu.Schemas) error { return err } - err = s.Client.Put(buf.Bytes()) + err = s.Client.Put(ctx, buf.Bytes()) if err != nil { return err } @@ -255,7 +256,7 @@ func (s *State) ShouldPersistIntermediateState(info *local.IntermediateStatePers } // Lock calls the Client's Lock method if it's implemented. -func (s *State) Lock(info *statemgr.LockInfo) (string, error) { +func (s *State) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) { s.mu.Lock() defer s.mu.Unlock() @@ -264,13 +265,13 @@ func (s *State) Lock(info *statemgr.LockInfo) (string, error) { } if c, ok := s.Client.(ClientLocker); ok { - return c.Lock(info) + return c.Lock(ctx, info) } return "", nil } // Unlock calls the Client's Unlock method if it's implemented. -func (s *State) Unlock(id string) error { +func (s *State) Unlock(ctx context.Context, id string) error { s.mu.Lock() defer s.mu.Unlock() @@ -279,7 +280,7 @@ func (s *State) Unlock(id string) error { } if c, ok := s.Client.(ClientLocker); ok { - return c.Unlock(id) + return c.Unlock(ctx, id) } return nil } diff --git a/internal/states/remote/state_test.go b/internal/states/remote/state_test.go index 0799afef4e..2132fb299b 100644 --- a/internal/states/remote/state_test.go +++ b/internal/states/remote/state_test.go @@ -6,6 +6,7 @@ package remote import ( + "context" "log" "sync" "testing" @@ -46,10 +47,10 @@ func TestStateRace(t *testing.T) { if err := s.WriteState(current); err != nil { panic(err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { panic(err) } - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { panic(err) } }() @@ -342,7 +343,7 @@ func TestStatePersist(t *testing.T) { // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this - if err := mgr.RefreshState(); err != nil { + if err := mgr.RefreshState(t.Context()); err != nil { t.Fatalf("failed to RefreshState: %s", err) } @@ -363,7 +364,7 @@ func TestStatePersist(t *testing.T) { if err := mgr.WriteState(s); err != nil { t.Fatalf("failed to WriteState for %q: %s", tc.name, err) } - if err := mgr.PersistState(nil); err != nil { + if err := mgr.PersistState(t.Context(), nil); err != nil { t.Fatalf("failed to PersistState for %q: %s", tc.name, err) } @@ -412,7 +413,7 @@ func TestState_GetRootOutputValues(t *testing.T) { encryption.StateEncryptionDisabled(), ) - outputs, err := mgr.GetRootOutputValues() + outputs, err := mgr.GetRootOutputValues(t.Context()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } @@ -528,7 +529,7 @@ func TestWriteStateForMigration(t *testing.T) { // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this - if err := mgr.RefreshState(); err != nil { + if err := mgr.RefreshState(t.Context()); err != nil { t.Fatalf("failed to RefreshState: %s", err) } @@ -570,7 +571,7 @@ func TestWriteStateForMigration(t *testing.T) { if err := mgr.WriteState(mgr.State()); err != nil { t.Fatal(err) } - if err := mgr.PersistState(nil); err != nil { + if err := mgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } @@ -689,7 +690,7 @@ func TestWriteStateForMigrationWithForcePushClient(t *testing.T) { // before any writes would happen, so we'll mimic that here for realism. // NB This causes a GET to be logged so the first item in the test cases // must account for this - if err := mgr.RefreshState(); err != nil { + if err := mgr.RefreshState(t.Context()); err != nil { t.Fatalf("failed to RefreshState: %s", err) } @@ -741,7 +742,7 @@ func TestWriteStateForMigrationWithForcePushClient(t *testing.T) { if err := mgr.WriteState(mgr.State()); err != nil { t.Fatal(err) } - if err := mgr.PersistState(nil); err != nil { + if err := mgr.PersistState(t.Context(), nil); err != nil { t.Fatal(err) } @@ -773,12 +774,12 @@ type mockClientLocker struct { } // Implement the mock Lock method for mockOptionalClientLocker -func (c *mockOptionalClientLocker) Lock(_ *statemgr.LockInfo) (string, error) { +func (c *mockOptionalClientLocker) Lock(_ context.Context, _ *statemgr.LockInfo) (string, error) { return "", nil } // Implement the mock Unlock method for mockOptionalClientLocker -func (c *mockOptionalClientLocker) Unlock(_ string) error { +func (c *mockOptionalClientLocker) Unlock(_ context.Context, _ string) error { // Provide a simple implementation return nil } @@ -789,12 +790,12 @@ func (c *mockOptionalClientLocker) IsLockingEnabled() bool { } // Implement the mock Lock method for mockClientLocker -func (c *mockClientLocker) Lock(_ *statemgr.LockInfo) (string, error) { +func (c *mockClientLocker) Lock(_ context.Context, _ *statemgr.LockInfo) (string, error) { return "", nil } // Implement the mock Unlock method for mockClientLocker -func (c *mockClientLocker) Unlock(_ string) error { +func (c *mockClientLocker) Unlock(_ context.Context, _ string) error { return nil } diff --git a/internal/states/remote/testing.go b/internal/states/remote/testing.go index 8153ccc01c..ffa77eddd3 100644 --- a/internal/states/remote/testing.go +++ b/internal/states/remote/testing.go @@ -25,11 +25,11 @@ func TestClient(t *testing.T, c Client) { } data := buf.Bytes() - if err := c.Put(data); err != nil { + if err := c.Put(t.Context(), data); err != nil { t.Fatalf("put: %s", err) } - p, err := c.Get() + p, err := c.Get(t.Context()) if err != nil { t.Fatalf("get: %s", err) } @@ -37,11 +37,11 @@ func TestClient(t *testing.T, c Client) { t.Fatalf("expected full state %q\n\ngot: %q", string(p.Data), string(data)) } - if err := c.Delete(); err != nil { + if err := c.Delete(t.Context()); err != nil { t.Fatalf("delete: %s", err) } - p, err = c.Get() + p, err = c.Get(t.Context()) if err != nil { t.Fatalf("get: %s", err) } @@ -73,14 +73,14 @@ func TestRemoteLocks(t *testing.T, a, b Client) { infoB.Operation = "test" infoB.Who = "clientB" - lockIDA, err := lockerA.Lock(infoA) + lockIDA, err := lockerA.Lock(t.Context(), infoA) if err != nil { t.Fatal("unable to get initial lock:", err) } - _, err = lockerB.Lock(infoB) + _, err = lockerB.Lock(t.Context(), infoB) if err == nil { - if err := lockerA.Unlock(lockIDA); err != nil { + if err := lockerA.Unlock(t.Context(), lockIDA); err != nil { t.Error(err) } t.Fatal("client B obtained lock while held by client A") @@ -89,11 +89,11 @@ func TestRemoteLocks(t *testing.T, a, b Client) { t.Errorf("expected a LockError, but was %t: %s", err, err) } - if err := lockerA.Unlock(lockIDA); err != nil { + if err := lockerA.Unlock(t.Context(), lockIDA); err != nil { t.Fatal("error unlocking client A", err) } - lockIDB, err := lockerB.Lock(infoB) + lockIDB, err := lockerB.Lock(t.Context(), infoB) if err != nil { t.Fatal("unable to obtain lock from client B") } @@ -102,7 +102,7 @@ func TestRemoteLocks(t *testing.T, a, b Client) { t.Fatalf("duplicate lock IDs: %q", lockIDB) } - if err = lockerB.Unlock(lockIDB); err != nil { + if err = lockerB.Unlock(t.Context(), lockIDB); err != nil { t.Fatal("error unlocking client B:", err) } diff --git a/internal/states/statemgr/filesystem.go b/internal/states/statemgr/filesystem.go index 8c347a700b..5bb6c6d140 100644 --- a/internal/states/statemgr/filesystem.go +++ b/internal/states/statemgr/filesystem.go @@ -7,6 +7,7 @@ package statemgr import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -158,7 +159,7 @@ func (s *Filesystem) writeState(state *states.State, meta *SnapshotMeta) error { } // PersistState writes state to a tfstate file. -func (s *Filesystem) PersistState(schemas *tofu.Schemas) error { +func (s *Filesystem) PersistState(_ context.Context, schemas *tofu.Schemas) error { defer s.mutex()() return s.persistState(schemas) @@ -249,13 +250,13 @@ func (s *Filesystem) persistState(schemas *tofu.Schemas) error { } // RefreshState is an implementation of Refresher. -func (s *Filesystem) RefreshState() error { +func (s *Filesystem) RefreshState(_ context.Context) error { defer s.mutex()() return s.refreshState() } -func (s *Filesystem) GetRootOutputValues() (map[string]*states.OutputValue, error) { - err := s.RefreshState() +func (s *Filesystem) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + err := s.RefreshState(ctx) if err != nil { return nil, err } @@ -331,7 +332,7 @@ func (s *Filesystem) refreshState() error { } // Lock implements Locker using filesystem discretionary locks. -func (s *Filesystem) Lock(info *LockInfo) (string, error) { +func (s *Filesystem) Lock(_ context.Context, info *LockInfo) (string, error) { defer s.mutex()() if s.stateFileOut == nil { @@ -364,7 +365,7 @@ func (s *Filesystem) Lock(info *LockInfo) (string, error) { } // Unlock is the companion to Lock, completing the implementation of Locker. -func (s *Filesystem) Unlock(id string) error { +func (s *Filesystem) Unlock(_ context.Context, id string) error { defer s.mutex()() if s.lockID == "" { diff --git a/internal/states/statemgr/filesystem_test.go b/internal/states/statemgr/filesystem_test.go index 0b8b082b07..a76571021b 100644 --- a/internal/states/statemgr/filesystem_test.go +++ b/internal/states/statemgr/filesystem_test.go @@ -59,7 +59,7 @@ func TestFilesystemLocks(t *testing.T) { // lock first info := NewLockInfo() info.Operation = "test" - lockID, err := s.Lock(info) + lockID, err := s.Lock(t.Context(), info) if err != nil { t.Fatal(err) } @@ -84,22 +84,22 @@ func TestFilesystemLocks(t *testing.T) { } // a noop, since we unlock on exit - if err := s.Unlock(lockID); err != nil { + if err := s.Unlock(t.Context(), lockID); err != nil { t.Fatal(err) } // local locks can re-lock - lockID, err = s.Lock(info) + lockID, err = s.Lock(t.Context(), info) if err != nil { t.Fatal(err) } - if err := s.Unlock(lockID); err != nil { + if err := s.Unlock(t.Context(), lockID); err != nil { t.Fatal(err) } // we should not be able to unlock the same lock twice - if err := s.Unlock(lockID); err == nil { + if err := s.Unlock(t.Context(), lockID); err == nil { t.Fatal("unlocking an unlocked state should fail") } @@ -120,12 +120,12 @@ func TestFilesystem_writeWhileLocked(t *testing.T) { // lock first info := NewLockInfo() info.Operation = "test" - lockID, err := s.Lock(info) + lockID, err := s.Lock(t.Context(), info) if err != nil { t.Fatal(err) } defer func() { - if err := s.Unlock(lockID); err != nil { + if err := s.Unlock(t.Context(), lockID); err != nil { t.Fatal(err) } }() @@ -256,17 +256,17 @@ func TestFilesystem_backupAndReadPath(t *testing.T) { "", ) }) - err = WriteAndPersist(ls, newState, nil) + err = WriteAndPersist(t.Context(), ls, newState, nil) if err != nil { t.Fatalf("failed to write new state: %s", err) } - lockID, err := ls.Lock(info) + lockID, err := ls.Lock(t.Context(), info) if err != nil { t.Fatal(err) } - if err := ls.Unlock(lockID); err != nil { + if err := ls.Unlock(t.Context(), lockID); err != nil { t.Fatal(err) } @@ -308,7 +308,7 @@ func TestFilesystem_backupAndReadPath(t *testing.T) { func TestFilesystem_nonExist(t *testing.T) { defer testOverrideVersion(t, "1.2.3")() ls := NewFilesystem("ishouldntexist", encryption.StateEncryptionDisabled()) - if err := ls.RefreshState(); err != nil { + if err := ls.RefreshState(t.Context()); err != nil { t.Fatalf("err: %s", err) } @@ -327,7 +327,7 @@ func TestFilesystem_lockUnlockWithoutWrite(t *testing.T) { os.Remove(ls.path) // Lock the state, and in doing so recreate the tempfile - lockID, err := ls.Lock(info) + lockID, err := ls.Lock(t.Context(), info) if err != nil { t.Fatal(err) } @@ -336,7 +336,7 @@ func TestFilesystem_lockUnlockWithoutWrite(t *testing.T) { t.Fatal("should have marked state as created") } - if err := ls.Unlock(lockID); err != nil { + if err := ls.Unlock(t.Context(), lockID); err != nil { t.Fatal(err) } @@ -381,7 +381,7 @@ func testFilesystem(t *testing.T) *Filesystem { f.Close() ls := NewFilesystem(f.Name(), encryption.StateEncryptionDisabled()) - if err := ls.RefreshState(); err != nil { + if err := ls.RefreshState(t.Context()); err != nil { t.Fatalf("initial refresh failed: %s", err) } @@ -413,17 +413,17 @@ func TestFilesystem_refreshWhileLocked(t *testing.T) { // lock first info := NewLockInfo() info.Operation = "test" - lockID, err := s.Lock(info) + lockID, err := s.Lock(t.Context(), info) if err != nil { t.Fatal(err) } defer func() { - if err := s.Unlock(lockID); err != nil { + if err := s.Unlock(t.Context(), lockID); err != nil { t.Fatal(err) } }() - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatal(err) } @@ -436,7 +436,7 @@ func TestFilesystem_refreshWhileLocked(t *testing.T) { func TestFilesystem_GetRootOutputValues(t *testing.T) { fs := testFilesystem(t) - outputs, err := fs.GetRootOutputValues() + outputs, err := fs.GetRootOutputValues(t.Context()) if err != nil { t.Errorf("Expected GetRootOutputValues to not return an error, but it returned %v", err) } diff --git a/internal/states/statemgr/helper.go b/internal/states/statemgr/helper.go index 7943850c1f..f235016d37 100644 --- a/internal/states/statemgr/helper.go +++ b/internal/states/statemgr/helper.go @@ -9,6 +9,8 @@ package statemgr // operations done against full state managers. import ( + "context" + "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/states/statefile" "github.com/opentofu/opentofu/internal/tofu" @@ -30,8 +32,8 @@ func NewStateFile() *statefile.File { // // This is a wrapper around calling RefreshState and then State on the given // manager. -func RefreshAndRead(mgr Storage) (*states.State, error) { - err := mgr.RefreshState() +func RefreshAndRead(ctx context.Context, mgr Storage) (*states.State, error) { + err := mgr.RefreshState(ctx) if err != nil { return nil, err } @@ -50,10 +52,10 @@ func RefreshAndRead(mgr Storage) (*states.State, error) { // out quickly with a user-facing error. In situations where more control // is required, call WriteState and PersistState on the state manager directly // and handle their errors. -func WriteAndPersist(mgr Storage, state *states.State, schemas *tofu.Schemas) error { +func WriteAndPersist(ctx context.Context, mgr Storage, state *states.State, schemas *tofu.Schemas) error { err := mgr.WriteState(state) if err != nil { return err } - return mgr.PersistState(schemas) + return mgr.PersistState(ctx, schemas) } diff --git a/internal/states/statemgr/lock.go b/internal/states/statemgr/lock.go index 0c289eb84a..b69d595a7c 100644 --- a/internal/states/statemgr/lock.go +++ b/internal/states/statemgr/lock.go @@ -6,6 +6,8 @@ package statemgr import ( + "context" + "github.com/opentofu/opentofu/internal/states" "github.com/opentofu/opentofu/internal/tofu" ) @@ -25,26 +27,26 @@ func (s *LockDisabled) State() *states.State { return s.Inner.State() } -func (s *LockDisabled) GetRootOutputValues() (map[string]*states.OutputValue, error) { - return s.Inner.GetRootOutputValues() +func (s *LockDisabled) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) { + return s.Inner.GetRootOutputValues(ctx) } func (s *LockDisabled) WriteState(v *states.State) error { return s.Inner.WriteState(v) } -func (s *LockDisabled) RefreshState() error { - return s.Inner.RefreshState() +func (s *LockDisabled) RefreshState(ctx context.Context) error { + return s.Inner.RefreshState(ctx) } -func (s *LockDisabled) PersistState(schemas *tofu.Schemas) error { - return s.Inner.PersistState(schemas) +func (s *LockDisabled) PersistState(ctx context.Context, schemas *tofu.Schemas) error { + return s.Inner.PersistState(ctx, schemas) } -func (s *LockDisabled) Lock(info *LockInfo) (string, error) { +func (s *LockDisabled) Lock(_ context.Context, info *LockInfo) (string, error) { return "", nil } -func (s *LockDisabled) Unlock(id string) error { +func (s *LockDisabled) Unlock(_ context.Context, id string) error { return nil } diff --git a/internal/states/statemgr/locker.go b/internal/states/statemgr/locker.go index d530761818..a88273aa9d 100644 --- a/internal/states/statemgr/locker.go +++ b/internal/states/statemgr/locker.go @@ -55,7 +55,7 @@ type Locker interface { // an instance of LockError immediately if the lock is already held, // and the helper function LockWithContext uses this to automatically // retry lock acquisition periodically until a timeout is reached. - Lock(info *LockInfo) (string, error) + Lock(ctx context.Context, info *LockInfo) (string, error) // Unlock releases a lock previously acquired by Lock. // @@ -63,7 +63,7 @@ type Locker interface { // another user with some sort of administrative override privilege -- // then an error is returned explaining the situation in a way that // is suitable for returning to an end-user. - Unlock(id string) error + Unlock(ctx context.Context, id string) error } // OptionalLocker extends Locker interface to allow callers @@ -88,7 +88,10 @@ func LockWithContext(ctx context.Context, s Locker, info *LockInfo) (string, err delay := time.Second maxDelay := 16 * time.Second for { - id, err := s.Lock(info) + // We disable cancellation on the context passed to s.Lock + // because we want it to run to completion if possible and then + // we'll check context cancellation explicitly below. + id, err := s.Lock(context.WithoutCancel(ctx), info) if err == nil { return id, nil } diff --git a/internal/states/statemgr/persistent.go b/internal/states/statemgr/persistent.go index 6dc84ff1e9..23241a026d 100644 --- a/internal/states/statemgr/persistent.go +++ b/internal/states/statemgr/persistent.go @@ -6,6 +6,8 @@ package statemgr import ( + "context" + version "github.com/hashicorp/go-version" "github.com/opentofu/opentofu/internal/states" @@ -33,7 +35,7 @@ type Persistent interface { // to differentiate reading the state and reading the outputs within the state. type OutputReader interface { // GetRootOutputValues fetches the root module output values from state or another source - GetRootOutputValues() (map[string]*states.OutputValue, error) + GetRootOutputValues(context.Context) (map[string]*states.OutputValue, error) } // Refresher is the interface for managers that can read snapshots from @@ -63,7 +65,7 @@ type Refresher interface { // return only a subset of what was written. Callers must assume that // ephemeral portions of the state may be unpopulated after calling // RefreshState. - RefreshState() error + RefreshState(context.Context) error } // Persister is the interface for managers that can write snapshots to @@ -83,7 +85,7 @@ type Refresher interface { // state. For example, when representing state in an external JSON // representation. type Persister interface { - PersistState(*tofu.Schemas) error + PersistState(context.Context, *tofu.Schemas) error } // PersistentMeta is an optional extension to Persistent that allows inspecting diff --git a/internal/states/statemgr/statemgr_fake.go b/internal/states/statemgr/statemgr_fake.go index 1a0ec73064..649b1ea746 100644 --- a/internal/states/statemgr/statemgr_fake.go +++ b/internal/states/statemgr/statemgr_fake.go @@ -6,6 +6,7 @@ package statemgr import ( + "context" "errors" "sync" @@ -63,19 +64,19 @@ func (m *fakeFull) WriteState(s *states.State) error { return m.t.WriteState(s) } -func (m *fakeFull) RefreshState() error { +func (m *fakeFull) RefreshState(_ context.Context) error { return m.t.WriteState(m.fakeP.State()) } -func (m *fakeFull) PersistState(schemas *tofu.Schemas) error { +func (m *fakeFull) PersistState(_ context.Context, schemas *tofu.Schemas) error { return m.fakeP.WriteState(m.t.State()) } -func (m *fakeFull) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (m *fakeFull) GetRootOutputValues(_ context.Context) (map[string]*states.OutputValue, error) { return m.State().RootModule().OutputValues, nil } -func (m *fakeFull) Lock(info *LockInfo) (string, error) { +func (m *fakeFull) Lock(_ context.Context, info *LockInfo) (string, error) { m.lockLock.Lock() defer m.lockLock.Unlock() @@ -90,7 +91,7 @@ func (m *fakeFull) Lock(info *LockInfo) (string, error) { return "placeholder", nil } -func (m *fakeFull) Unlock(id string) error { +func (m *fakeFull) Unlock(_ context.Context, id string) error { m.lockLock.Lock() defer m.lockLock.Unlock() @@ -121,7 +122,7 @@ func (m *fakeErrorFull) State() *states.State { return nil } -func (m *fakeErrorFull) GetRootOutputValues() (map[string]*states.OutputValue, error) { +func (m *fakeErrorFull) GetRootOutputValues(_ context.Context) (map[string]*states.OutputValue, error) { return nil, errors.New("fake state manager error") } @@ -129,18 +130,18 @@ func (m *fakeErrorFull) WriteState(s *states.State) error { return errors.New("fake state manager error") } -func (m *fakeErrorFull) RefreshState() error { +func (m *fakeErrorFull) RefreshState(_ context.Context) error { return errors.New("fake state manager error") } -func (m *fakeErrorFull) PersistState(schemas *tofu.Schemas) error { +func (m *fakeErrorFull) PersistState(_ context.Context, schemas *tofu.Schemas) error { return errors.New("fake state manager error") } -func (m *fakeErrorFull) Lock(info *LockInfo) (string, error) { +func (m *fakeErrorFull) Lock(_ context.Context, info *LockInfo) (string, error) { return "placeholder", nil } -func (m *fakeErrorFull) Unlock(id string) error { +func (m *fakeErrorFull) Unlock(_ context.Context, id string) error { return errors.New("fake state manager error") } diff --git a/internal/states/statemgr/statemgr_test.go b/internal/states/statemgr/statemgr_test.go index 5bc81945a8..c95872d32d 100644 --- a/internal/states/statemgr/statemgr_test.go +++ b/internal/states/statemgr/statemgr_test.go @@ -47,7 +47,7 @@ func TestNewLockInfo(t *testing.T) { func TestLockWithContext(t *testing.T) { s := NewFullFake(nil, TestFullInitialState()) - id, err := s.Lock(NewLockInfo()) + id, err := s.Lock(t.Context(), NewLockInfo()) if err != nil { t.Fatal(err) } @@ -76,7 +76,7 @@ func TestLockWithContext(t *testing.T) { go func() { defer close(unlocked) <-attempted - unlockErr = s.Unlock(id) + unlockErr = s.Unlock(t.Context(), id) }() ctx, cancel = context.WithTimeout(context.Background(), 2*time.Second) diff --git a/internal/states/statemgr/testdata/lockstate.go b/internal/states/statemgr/testdata/lockstate.go index 0811535023..2e3256906e 100644 --- a/internal/states/statemgr/testdata/lockstate.go +++ b/internal/states/statemgr/testdata/lockstate.go @@ -1,6 +1,7 @@ package main import ( + "context" "io" "log" "os" @@ -22,7 +23,7 @@ func main() { info.Operation = "test" info.Info = "state locker" - _, err := s.Lock(info) + _, err := s.Lock(context.Background(), info) if err != nil { io.WriteString(os.Stderr, "lock failed") } diff --git a/internal/states/statemgr/testing.go b/internal/states/statemgr/testing.go index 0753f133aa..823e8d24eb 100644 --- a/internal/states/statemgr/testing.go +++ b/internal/states/statemgr/testing.go @@ -27,7 +27,7 @@ import ( func TestFull(t *testing.T, s Full) { t.Helper() - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("err: %s", err) } @@ -61,12 +61,12 @@ func TestFull(t *testing.T, s Full) { } // Test persistence - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("err: %s", err) } // Refresh if we got it - if err := s.RefreshState(); err != nil { + if err := s.RefreshState(t.Context()); err != nil { t.Fatalf("err: %s", err) } @@ -86,7 +86,7 @@ func TestFull(t *testing.T, s Full) { if err := s.WriteState(current); err != nil { t.Fatalf("err: %s", err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("err: %s", err) } @@ -109,7 +109,7 @@ func TestFull(t *testing.T, s Full) { if err := s.WriteState(current); err != nil { t.Fatalf("err: %s", err) } - if err := s.PersistState(nil); err != nil { + if err := s.PersistState(t.Context(), nil); err != nil { t.Fatalf("err: %s", err) }