mirror of
https://github.com/opentffoundation/opentf.git
synced 2025-12-22 03:07:51 -05:00
176 lines
4.4 KiB
Go
176 lines
4.4 KiB
Go
// Copyright (c) The OpenTofu Authors
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
// Copyright (c) 2023 HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package azure
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
|
|
"github.com/opentofu/opentofu/internal/backend"
|
|
"github.com/opentofu/opentofu/internal/states"
|
|
"github.com/opentofu/opentofu/internal/states/remote"
|
|
"github.com/opentofu/opentofu/internal/states/statemgr"
|
|
)
|
|
|
|
// keyEnvPrefix separates the configured key name from a workspace name when
// building the blob name of a non-default workspace (key + "env:" + name).
// The colon is an unusual character in blob names, which lowers the chance of
// colliding with unrelated objects already stored in the container.
const keyEnvPrefix = "env:"
|
|
|
|
// getContextWithTimeout returns a context with timeout based on the timeoutSeconds
|
|
func (b *Backend) getContextWithTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
|
|
return context.WithTimeout(ctx, b.timeout)
|
|
}
|
|
|
|
func (b *Backend) Workspaces(ctx context.Context) ([]string, error) {
|
|
ctx, cancel := b.getContextWithTimeout(ctx)
|
|
defer cancel()
|
|
|
|
prefix := b.keyName + keyEnvPrefix
|
|
result, err := getPaginatedResults(ctx, b.containerClient, prefix)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (b *Backend) DeleteWorkspace(ctx context.Context, name string, _ bool) error {
|
|
if name == backend.DefaultStateName || name == "" {
|
|
return fmt.Errorf("can't delete default state")
|
|
}
|
|
|
|
ctx, cancel := b.getContextWithTimeout(ctx)
|
|
defer cancel()
|
|
blobClient := b.containerClient.NewBlockBlobClient(b.path(name))
|
|
|
|
if _, err := blobClient.Delete(ctx, nil); err != nil {
|
|
if !notFoundError(err) {
|
|
return fmt.Errorf("error deleting blob: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (b *Backend) StateMgr(_ context.Context, name string) (statemgr.Full, error) {
|
|
blobClient := b.containerClient.NewBlockBlobClient(b.path(name))
|
|
|
|
client := &RemoteClient{
|
|
blobClient: blobClient,
|
|
snapshot: b.snapshot,
|
|
timeout: b.timeout,
|
|
}
|
|
|
|
stateMgr := remote.NewState(client, b.encryption)
|
|
|
|
// Grab the value
|
|
if err := stateMgr.RefreshState(context.TODO()); err != nil {
|
|
return nil, err
|
|
}
|
|
//if this isn't the default state name, we need to create the object so
|
|
//it's listed by States.
|
|
if v := stateMgr.State(); v == nil {
|
|
// take a lock on this state while we write it
|
|
lockInfo := statemgr.NewLockInfo()
|
|
lockInfo.Operation = "init"
|
|
lockId, err := client.Lock(context.TODO(), lockInfo)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to lock azure state: %w", err)
|
|
}
|
|
|
|
// Local helper function so we can call it multiple places
|
|
lockUnlock := func(parent error) error {
|
|
if err := stateMgr.Unlock(context.TODO(), lockId); err != nil {
|
|
return fmt.Errorf(strings.TrimSpace(errStateUnlock), lockId, err)
|
|
}
|
|
return parent
|
|
}
|
|
|
|
if err := stateMgr.WriteState(states.NewState()); err != nil {
|
|
err = lockUnlock(err)
|
|
return nil, err
|
|
}
|
|
if err := stateMgr.PersistState(context.TODO(), nil); err != nil {
|
|
err = lockUnlock(err)
|
|
return nil, err
|
|
}
|
|
|
|
// Unlock, the state should now be initialized
|
|
if err := lockUnlock(nil); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return stateMgr, nil
|
|
}
|
|
|
|
func (b *Backend) path(name string) string {
|
|
if name == backend.DefaultStateName {
|
|
return b.keyName
|
|
}
|
|
|
|
return b.keyName + keyEnvPrefix + name
|
|
}
|
|
|
|
// errStateUnlock is the fmt.Errorf format used when releasing the state lock
// fails. It takes the lock ID (%s) and the underlying error (%w). Callers
// trim the surrounding newlines with strings.TrimSpace before formatting.
const errStateUnlock = "\n" +
	"Error unlocking Azure state. Lock ID: %s\n" +
	"\n" +
	"Error: %w\n" +
	"\n" +
	"You may have to force-unlock this state in order to use it again.\n"
|
|
|
|
type azureClient interface {
|
|
NewListBlobsFlatPager(o *container.ListBlobsFlatOptions) *runtime.Pager[container.ListBlobsFlatResponse]
|
|
}
|
|
|
|
func getPaginatedResults(ctx context.Context, client azureClient, prefix string) ([]string, error) {
|
|
count := 1
|
|
initialMarker := ""
|
|
|
|
params := container.ListBlobsFlatOptions{
|
|
Prefix: &prefix,
|
|
Marker: &initialMarker,
|
|
}
|
|
result := []string{backend.DefaultStateName}
|
|
pager := client.NewListBlobsFlatPager(¶ms)
|
|
|
|
for pager.More() {
|
|
log.Printf("[TRACE] Getting page %d of blob results", count)
|
|
|
|
resp, err := pager.NextPage(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error listing blobs: %w", err)
|
|
}
|
|
|
|
for _, obj := range resp.Segment.BlobItems {
|
|
key := obj.Name
|
|
if !strings.HasPrefix(*key, prefix) {
|
|
continue
|
|
}
|
|
|
|
name := strings.TrimPrefix(*key, prefix)
|
|
// we store the state in a key, not a directory
|
|
if strings.Contains(name, "/") {
|
|
continue
|
|
}
|
|
result = append(result, name)
|
|
}
|
|
|
|
count++
|
|
}
|
|
|
|
sort.Strings(result[1:])
|
|
return result, nil
|
|
}
|