Files
opentf/internal/cloud/state.go
Martin Atkins 67a5cd0911 statemgr+remote: context.Context parameters
This extends statemgr.Persistent, statemgr.Locker and remote.Client to
all expect context.Context parameters, and then updates all of the existing
implementations of those interfaces to support them.

All of the calls to statemgr.Persistent and statemgr.Locker methods outside
of tests are consistently context.TODO() for now, because the caller
landscape of these interfaces has some complications:

1. statemgr.Locker is also used by the clistate package for its state
   implementation that was derived from statemgr.Filesystem's predecessor,
   even though what clistate manages is not actually "state" in the sense
   of package statemgr. The callers of that are not yet ready to provide
   real contexts.

   In a future commit we'll either need to plumb context through to all of
   the clistate callers, or continue the effort to separate statemgr from
   clistate by introducing a clistate-specific "locker" API for it
   to use instead.

2. We call statemgr.Persistent and statemgr.Locker methods in situations
   where the active context might have already been cancelled, and so we'll
   need to make sure to ignore cancellation when calling those.

   This is mainly limited to PersistState and Unlock, since both need to
   be able to complete after a cancellation, but there are various
   codepaths that perform a Lock, Refresh, Persist, Unlock sequence and so
   it isn't yet clear where is the best place to enforce the invariant that
   Persist and Unlock must not be called with a cancelable context. We'll
   deal with that more in subsequent commits.

Within the various state manager and remote client implementations the
contexts _are_ wired together as best as possible with how these subsystems
are already laid out, and so once we deal with the problems above and make
callers provide suitable contexts they should be able to reach all of the
leaf API clients that might want to generate OpenTelemetry traces.

Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
2025-07-10 08:11:39 -07:00

634 lines
20 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package cloud
import (
"bytes"
"context"
"crypto/md5"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/gocty"
tfe "github.com/hashicorp/go-tfe"
uuid "github.com/hashicorp/go-uuid"
"github.com/opentofu/opentofu/internal/backend/local"
"github.com/opentofu/opentofu/internal/command/jsonstate"
"github.com/opentofu/opentofu/internal/encryption"
"github.com/opentofu/opentofu/internal/states"
"github.com/opentofu/opentofu/internal/states/remote"
"github.com/opentofu/opentofu/internal/states/statefile"
"github.com/opentofu/opentofu/internal/states/statemgr"
"github.com/opentofu/opentofu/internal/tofu"
)
const (
// HeaderSnapshotInterval is the header key that controls the snapshot interval
HeaderSnapshotInterval = "x-terraform-snapshot-interval"
)
// State implements the State interfaces in the state package to handle
// reading and writing the remote state to TFC. This State on its own does no
// local caching so every persist will go to the remote storage and local
// writes will go to memory.
type State struct {
// We track two pieces of meta data in addition to the state itself:
//
// lineage - the state's unique ID
// serial - the monotonic counter of "versions" of the state
//
// Both of these (along with state) have a sister field
// that represents the values read in from an existing source.
// All three of these values are used to determine if the new
// state has changed from an existing state we read in.
lineage, readLineage string
serial, readSerial uint64
mu sync.Mutex
state, readState *states.State
disableLocks bool
tfeClient *tfe.Client
organization string
workspace *tfe.Workspace
stateUploadErr bool
forcePush bool
lockInfo *statemgr.LockInfo
// The server can optionally return an X-Terraform-Snapshot-Interval header
// in its response to the "Create State Version" operation, which specifies
// a number of seconds the server would prefer us to wait before trying
// to write a new snapshot. If this is non-zero then we'll wait at least
// this long before allowing another intermediate snapshot. This does
// not effect final snapshots after an operation, which will always
// be written to the remote API.
stateSnapshotInterval time.Duration
// If the header X-Terraform-Snapshot-Interval is present then
// we will enable snapshots
enableIntermediateSnapshots bool
encryption encryption.StateEncryption
}
var ErrStateVersionUnauthorizedUpgradeState = errors.New(strings.TrimSpace(`
You are not authorized to read the full state version containing outputs.
State versions created by tofu v1.3.0 and newer do not require this level
of authorization and therefore this error can usually be fixed by upgrading the
remote state version.
`))
var _ statemgr.Full = (*State)(nil)
var _ statemgr.Migrator = (*State)(nil)
var _ statemgr.PersistentMeta = (*State)(nil)
var _ local.IntermediateStateConditionalPersister = (*State)(nil)
// statemgr.Reader impl.
func (s *State) State() *states.State {
s.mu.Lock()
defer s.mu.Unlock()
return s.state.DeepCopy()
}
// StateForMigration is part of our implementation of statemgr.Migrator.
func (s *State) StateForMigration() *statefile.File {
s.mu.Lock()
defer s.mu.Unlock()
return statefile.New(s.state.DeepCopy(), s.lineage, s.serial)
}
// WriteStateForMigration is part of our implementation of statemgr.Migrator.
func (s *State) WriteStateForMigration(f *statefile.File, force bool) error {
s.mu.Lock()
defer s.mu.Unlock()
if !force {
checkFile := statefile.New(s.state, s.lineage, s.serial)
if err := statemgr.CheckValidImport(f, checkFile); err != nil {
return err
}
}
// We create a deep copy of the state here, because the caller also has
// a reference to the given object and can potentially go on to mutate
// it after we return, but we want the snapshot at this point in time.
s.state = f.State.DeepCopy()
s.lineage = f.Lineage
s.serial = f.Serial
s.forcePush = force
return nil
}
// DisableLocks turns the Lock and Unlock methods into no-ops. This is intended
// to be called during initialization of a state manager and should not be
// called after any of the statemgr.Full interface methods have been called.
func (s *State) DisableLocks() {
s.disableLocks = true
}
// StateSnapshotMeta returns the metadata from the most recently persisted
// or refreshed persistent state snapshot.
//
// This is an implementation of statemgr.PersistentMeta.
func (s *State) StateSnapshotMeta() statemgr.SnapshotMeta {
return statemgr.SnapshotMeta{
Lineage: s.lineage,
Serial: s.serial,
}
}
// statemgr.Writer impl.
func (s *State) WriteState(state *states.State) error {
s.mu.Lock()
defer s.mu.Unlock()
// We create a deep copy of the state here, because the caller also has
// a reference to the given object and can potentially go on to mutate
// it after we return, but we want the snapshot at this point in time.
s.state = state.DeepCopy()
s.forcePush = false
return nil
}
// PersistState uploads a snapshot of the latest state as a StateVersion to Terraform Cloud
func (s *State) PersistState(ctx context.Context, schemas *tofu.Schemas) error {
s.mu.Lock()
defer s.mu.Unlock()
log.Printf("[DEBUG] cloud/state: state read serial is: %d; serial is: %d", s.readSerial, s.serial)
log.Printf("[DEBUG] cloud/state: state read lineage is: %s; lineage is: %s", s.readLineage, s.lineage)
if s.readState != nil {
lineageUnchanged := s.readLineage != "" && s.lineage == s.readLineage
serialUnchanged := s.readSerial != 0 && s.serial == s.readSerial
stateUnchanged := statefile.StatesMarshalEqual(s.state, s.readState)
if stateUnchanged && lineageUnchanged && serialUnchanged {
// If the state, lineage or serial haven't changed at all then we have nothing to do.
return nil
}
s.serial++
} else {
// We might be writing a new state altogether, but before we do that
// we'll check to make sure there isn't already a snapshot present
// that we ought to be updating.
err := s.refreshState(ctx)
if err != nil {
return fmt.Errorf("failed checking for existing remote state: %w", err)
}
log.Printf("[DEBUG] cloud/state: after refresh, state read serial is: %d; serial is: %d", s.readSerial, s.serial)
log.Printf("[DEBUG] cloud/state: after refresh, state read lineage is: %s; lineage is: %s", s.readLineage, s.lineage)
if s.lineage == "" { // indicates that no state snapshot is present yet
lineage, err := uuid.GenerateUUID()
if err != nil {
return fmt.Errorf("failed to generate initial lineage: %w", err)
}
s.lineage = lineage
s.serial++
}
}
f := statefile.New(s.state, s.lineage, s.serial)
var buf bytes.Buffer
err := statefile.Write(f, &buf, s.encryption)
if err != nil {
return err
}
var jsonState []byte
if schemas != nil {
jsonState, err = jsonstate.Marshal(f, schemas)
if err != nil {
return err
}
}
stateFile, err := statefile.Read(bytes.NewReader(buf.Bytes()), s.encryption)
if err != nil {
return fmt.Errorf("failed to read state: %w", err)
}
ov, err := jsonstate.MarshalOutputs(stateFile.State.RootModule().OutputValues)
if err != nil {
return fmt.Errorf("failed to translate outputs: %w", err)
}
jsonStateOutputs, err := json.Marshal(ov)
if err != nil {
return fmt.Errorf("failed to marshal outputs to json: %w", err)
}
err = s.uploadState(ctx, s.lineage, s.serial, s.forcePush, buf.Bytes(), jsonState, jsonStateOutputs)
if err != nil {
s.stateUploadErr = true
return fmt.Errorf("error uploading state: %w", err)
}
// After we've successfully persisted, what we just wrote is our new
// reference state until someone calls RefreshState again.
// We've potentially overwritten (via force) the state, lineage
// and / or serial (and serial was incremented) so we copy over all
// three fields so everything matches the new state and a subsequent
// operation would correctly detect no changes to the lineage, serial or state.
s.readState = s.state.DeepCopy()
s.readLineage = s.lineage
s.readSerial = s.serial
return nil
}
// ShouldPersistIntermediateState implements local.IntermediateStateConditionalPersister
func (s *State) ShouldPersistIntermediateState(info *local.IntermediateStatePersistInfo) bool {
if info.ForcePersist {
return true
}
// This value is controlled by a x-terraform-snapshot-interval header intercepted during
// state-versions API responses
if !s.enableIntermediateSnapshots {
return false
}
// Our persist interval is the largest of either the caller's requested
// interval or the server's requested interval.
wantInterval := info.RequestedPersistInterval
if s.stateSnapshotInterval > wantInterval {
wantInterval = s.stateSnapshotInterval
}
currentInterval := time.Since(info.LastPersist)
return currentInterval >= wantInterval
}
func (s *State) uploadStateFallback(ctx context.Context, lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error {
options := tfe.StateVersionCreateOptions{
Lineage: tfe.String(lineage),
Serial: tfe.Int64(int64(serial)),
MD5: tfe.String(fmt.Sprintf("%x", md5.Sum(state))),
Force: tfe.Bool(isForcePush),
State: tfe.String(base64.StdEncoding.EncodeToString(state)),
JSONState: tfe.String(base64.StdEncoding.EncodeToString(jsonState)),
JSONStateOutputs: tfe.String(base64.StdEncoding.EncodeToString(jsonStateOutputs)),
}
// If we have a run ID, make sure to add it to the options
// so the state will be properly associated with the run.
runID := os.Getenv("TFE_RUN_ID")
if runID != "" {
options.Run = &tfe.Run{ID: runID}
}
// Create the new state.
_, err := s.tfeClient.StateVersions.Create(ctx, s.workspace.ID, options)
return err
}
func (s *State) uploadState(ctx context.Context, lineage string, serial uint64, isForcePush bool, state, jsonState, jsonStateOutputs []byte) error {
options := tfe.StateVersionUploadOptions{
StateVersionCreateOptions: tfe.StateVersionCreateOptions{
Lineage: tfe.String(lineage),
Serial: tfe.Int64(int64(serial)),
MD5: tfe.String(fmt.Sprintf("%x", md5.Sum(state))),
Force: tfe.Bool(isForcePush),
JSONStateOutputs: tfe.String(base64.StdEncoding.EncodeToString(jsonStateOutputs)),
},
RawState: state,
RawJSONState: jsonState,
}
// If we have a run ID, make sure to add it to the options
// so the state will be properly associated with the run.
runID := os.Getenv("TFE_RUN_ID")
if runID != "" {
options.StateVersionCreateOptions.Run = &tfe.Run{ID: runID}
}
// The server is allowed to dynamically request a different time interval
// than we'd normally use, for example if it's currently under heavy load
// and needs clients to backoff for a while.
ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader)
// Create the new state.
_, err := s.tfeClient.StateVersions.Upload(ctx, s.workspace.ID, options)
if errors.Is(err, tfe.ErrStateVersionUploadNotSupported) {
// Create the new state with content included in the request (Terraform Enterprise v202306-1 and below)
log.Println("[INFO] Detected that state version upload is not supported. Retrying using compatibility state upload.")
return s.uploadStateFallback(ctx, lineage, serial, isForcePush, state, jsonState, jsonStateOutputs)
}
return err
}
// Lock calls the Client's Lock method if it's implemented.
func (s *State) Lock(ctx context.Context, info *statemgr.LockInfo) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.disableLocks {
return "", nil
}
lockErr := &statemgr.LockError{Info: s.lockInfo}
// Lock the workspace.
_, err := s.tfeClient.Workspaces.Lock(ctx, s.workspace.ID, tfe.WorkspaceLockOptions{
Reason: tfe.String("Locked by OpenTofu"),
})
if err != nil {
if err == tfe.ErrWorkspaceLocked {
lockErr.Info = info
err = fmt.Errorf("%w (lock ID: \"%s/%s\")", err, s.organization, s.workspace.Name)
}
lockErr.Err = err
return "", lockErr
}
s.lockInfo = info
return s.lockInfo.ID, nil
}
// statemgr.Refresher impl.
func (s *State) RefreshState(ctx context.Context) error {
s.mu.Lock()
defer s.mu.Unlock()
return s.refreshState(ctx)
}
// refreshState is the main implementation of RefreshState, but split out so
// that we can make internal calls to it from methods that are already holding
// the s.mu lock.
func (s *State) refreshState(ctx context.Context) error {
payload, err := s.getStatePayload()
if err != nil {
return err
}
// no remote state is OK
if payload == nil {
s.readState = nil
s.lineage = ""
s.serial = 0
return nil
}
stateFile, err := statefile.Read(bytes.NewReader(payload.Data), s.encryption)
if err != nil {
return err
}
s.lineage = stateFile.Lineage
s.serial = stateFile.Serial
s.state = stateFile.State
// Properties from the remote must be separate so we can
// track changes as lineage, serial and/or state are mutated
s.readLineage = stateFile.Lineage
s.readSerial = stateFile.Serial
s.readState = s.state.DeepCopy()
return nil
}
func (s *State) getStatePayload() (*remote.Payload, error) {
ctx := context.Background()
// Check the x-terraform-snapshot-interval header to see if it has a non-empty
// value which would indicate snapshots are enabled
ctx = tfe.ContextWithResponseHeaderHook(ctx, s.readSnapshotIntervalHeader)
sv, err := s.tfeClient.StateVersions.ReadCurrent(ctx, s.workspace.ID)
if err != nil {
if err == tfe.ErrResourceNotFound {
// If no state exists, then return nil.
return nil, nil
}
return nil, fmt.Errorf("error retrieving state: %w", err)
}
state, err := s.tfeClient.StateVersions.Download(ctx, sv.DownloadURL)
if err != nil {
return nil, fmt.Errorf("error downloading state: %w", err)
}
// If the state is empty, then return nil.
if len(state) == 0 {
return nil, nil
}
// Get the MD5 checksum of the state.
sum := md5.Sum(state)
return &remote.Payload{
Data: state,
MD5: sum[:],
}, nil
}
// Unlock calls the Client's Unlock method if it's implemented.
func (s *State) Unlock(ctx context.Context, id string) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.disableLocks {
return nil
}
// We first check if there was an error while uploading the latest
// state. If so, we will not unlock the workspace to prevent any
// changes from being applied until the correct state is uploaded.
if s.stateUploadErr {
return nil
}
lockErr := &statemgr.LockError{Info: s.lockInfo}
// With lock info this should be treated as a normal unlock.
if s.lockInfo != nil {
// Verify the expected lock ID.
if s.lockInfo.ID != id {
lockErr.Err = fmt.Errorf("lock ID does not match existing lock")
return lockErr
}
// Unlock the workspace.
_, err := s.tfeClient.Workspaces.Unlock(ctx, s.workspace.ID)
if err != nil {
lockErr.Err = err
return lockErr
}
return nil
}
// Verify the optional force-unlock lock ID.
if s.organization+"/"+s.workspace.Name != id {
lockErr.Err = fmt.Errorf(
"lock ID %q does not match existing lock ID \"%s/%s\"",
id,
s.organization,
s.workspace.Name,
)
return lockErr
}
// Force unlock the workspace.
_, err := s.tfeClient.Workspaces.ForceUnlock(ctx, s.workspace.ID)
if err != nil {
lockErr.Err = err
return lockErr
}
return nil
}
// Delete the remote state.
func (s *State) Delete(force bool) error {
var err error
isSafeDeleteSupported := s.workspace.Permissions.CanForceDelete != nil
if force || !isSafeDeleteSupported {
err = s.tfeClient.Workspaces.Delete(context.Background(), s.organization, s.workspace.Name)
} else {
err = s.tfeClient.Workspaces.SafeDelete(context.Background(), s.organization, s.workspace.Name)
}
if err != nil && err != tfe.ErrResourceNotFound {
return fmt.Errorf("error deleting workspace %s: %w", s.workspace.Name, err)
}
return nil
}
// GetRootOutputValues fetches output values from Terraform Cloud
func (s *State) GetRootOutputValues(ctx context.Context) (map[string]*states.OutputValue, error) {
so, err := s.tfeClient.StateVersionOutputs.ReadCurrent(ctx, s.workspace.ID)
if err != nil {
return nil, fmt.Errorf("could not read state version outputs: %w", err)
}
result := make(map[string]*states.OutputValue)
for _, output := range so.Items {
if output.DetailedType == nil {
// If there is no detailed type information available, this state was probably created
// with a version of terraform < 1.3.0. In this case, we'll eject completely from this
// function and fall back to the old behavior of reading the entire state file, which
// requires a higher level of authorization.
log.Printf("[DEBUG] falling back to reading full state")
if err := s.RefreshState(ctx); err != nil {
return nil, fmt.Errorf("failed to load state: %w", err)
}
state := s.State()
if state == nil {
// We know that there is supposed to be state (and this is not simply a new workspace
// without state) because the fallback is only invoked when outputs are present but
// detailed types are not available.
return nil, ErrStateVersionUnauthorizedUpgradeState
}
return state.RootModule().OutputValues, nil
}
if output.Sensitive {
// Since this is a sensitive value, the output must be requested explicitly in order to
// read its value, which is assumed to be present by callers
sensitiveOutput, err := s.tfeClient.StateVersionOutputs.Read(ctx, output.ID)
if err != nil {
return nil, fmt.Errorf("could not read state version output %s: %w", output.ID, err)
}
output.Value = sensitiveOutput.Value
}
cval, err := tfeOutputToCtyValue(*output)
if err != nil {
return nil, fmt.Errorf("could not decode output %s (ID %s)", output.Name, output.ID)
}
result[output.Name] = &states.OutputValue{
Value: cval,
Sensitive: output.Sensitive,
}
}
return result, nil
}
func clamp(val, min, max int64) int64 {
if val < min {
return min
} else if val > max {
return max
}
return val
}
func (s *State) readSnapshotIntervalHeader(status int, header http.Header) {
// Only proceed if this came from tfe.v2 API
contentType := header.Get("Content-Type")
if !strings.Contains(contentType, tfe.ContentTypeJSONAPI) {
log.Printf("[TRACE] Skipping intermediate state interval because Content-Type was %q", contentType)
return
}
intervalStr := header.Get(HeaderSnapshotInterval)
if intervalSecs, err := strconv.ParseInt(intervalStr, 10, 64); err == nil {
// More than an hour is an unreasonable delay, so we'll just
// limit to one hour max.
intervalSecs = clamp(intervalSecs, 0, 3600)
s.stateSnapshotInterval = time.Duration(intervalSecs) * time.Second
} else {
// If the header field is either absent or invalid then we'll
// just choose zero, which effectively means that we'll just use
// the caller's requested interval instead. If the caller has no
// requested interval, or it is zero, then we will disable snapshots.
s.stateSnapshotInterval = time.Duration(0)
}
// We will only enable snapshots for intervals greater than zero
log.Printf("[TRACE] Intermediate state interval is set by header to %v", s.stateSnapshotInterval)
s.enableIntermediateSnapshots = s.stateSnapshotInterval > 0
}
// tfeOutputToCtyValue decodes a combination of TFE output value and detailed-type to create a
// cty value that is suitable for use in tofu.
func tfeOutputToCtyValue(output tfe.StateVersionOutput) (cty.Value, error) {
var result cty.Value
bufType, err := json.Marshal(output.DetailedType)
if err != nil {
return result, fmt.Errorf("could not marshal output %s type: %w", output.ID, err)
}
var ctype cty.Type
err = ctype.UnmarshalJSON(bufType)
if err != nil {
return result, fmt.Errorf("could not interpret output %s type: %w", output.ID, err)
}
result, err = gocty.ToCtyValue(output.Value, ctype)
if err != nil {
return result, fmt.Errorf("could not interpret value %v as type %s for output %s: %w", result, ctype.FriendlyName(), output.ID, err)
}
return result, nil
}