mirror of
https://github.com/opentffoundation/opentf.git
synced 2025-12-19 17:59:05 -05:00
This extends statemgr.Persistent, statemgr.Locker and remote.Client to all expect context.Context parameters, and then updates all of the existing implementations of those interfaces to support them. All of the calls to statemgr.Persistent and statemgr.Locker methods outside of tests are consistently context.TODO() for now, because the caller landscape of these interfaces has some complications: 1. statemgr.Locker is also used by the clistate package for its state implementation that was derived from statemgr.Filesystem's predecessor, even though what clistate manages is not actually "state" in the sense of package statemgr. The callers of that are not yet ready to provide real contexts. In a future commit we'll either need to plumb context through to all of the clistate callers, or continue the effort to separate statemgr from clistate by introducing a clistate-specific "locker" API for it to use instead. 2. We call statemgr.Persistent and statemgr.Locker methods in situations where the active context might have already been cancelled, and so we'll need to make sure to ignore cancellation when calling those. This is mainly limited to PersistState and Unlock, since both need to be able to complete after a cancellation, but there are various codepaths that perform a Lock, Refresh, Persist, Unlock sequence and so it isn't yet clear where the best place is to enforce the invariant that Persist and Unlock must not be called with a cancelable context. We'll deal with that more in subsequent commits. Within the various state manager and remote client implementations the contexts _are_ wired together as best as possible with how these subsystems are already laid out, and so once we deal with the problems above and make callers provide suitable contexts they should be able to reach all of the leaf API clients that might want to generate OpenTelemetry traces. Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
441 lines
12 KiB
Go
441 lines
12 KiB
Go
// Copyright (c) The OpenTofu Authors
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
// Copyright (c) 2023 HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package http
|
|
|
|
//go:generate go run go.uber.org/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"os/signal"
|
|
"path/filepath"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"testing"
|
|
|
|
"github.com/opentofu/opentofu/internal/addrs"
|
|
"github.com/opentofu/opentofu/internal/backend"
|
|
"github.com/opentofu/opentofu/internal/configs"
|
|
"github.com/opentofu/opentofu/internal/encryption"
|
|
"github.com/opentofu/opentofu/internal/states"
|
|
"github.com/zclconf/go-cty/cty"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
// sampleState is the payload newHttpServer pre-loads under the "sample"
// resource. It is a minimal state document (version 4, serial 0, a fixed
// lineage) that also carries a "remote" block whose config names the "http"
// type and a local state path.
const sampleState = `
{
"version": 4,
"serial": 0,
"lineage": "666f9301-7e65-4b19-ae23-71184bb19b03",
"remote": {
"type": "http",
"config": {
"path": "local-state.tfstate"
}
}
}
`
|
|
|
|
// HttpServerCallback is notified by the test server each time one of its
// state handlers returns, with one method per handled verb. The go:generate
// directive in this file produces a gomock implementation of it, which tests
// use to assert how many requests of each kind reached the server.
type HttpServerCallback interface {
	// StateGET is called when the GET handler returns.
	StateGET(req *http.Request)
	// StatePOST is called when the POST handler returns.
	StatePOST(req *http.Request)
	// StateDELETE is called when the DELETE handler returns.
	StateDELETE(req *http.Request)
	// StateLOCK is called when the LOCK handler returns.
	StateLOCK(req *http.Request)
	// StateUNLOCK is called when the UNLOCK handler returns.
	StateUNLOCK(req *http.Request)
}

// httpServer is an in-memory stand-in for a remote HTTP state backend.
type httpServer struct {
	r *http.ServeMux // router that the test server mounts

	data  map[string]string // state payloads keyed by resource name
	locks map[string]string // raw lock request bodies keyed by resource name
	lock  sync.RWMutex      // guards data and locks

	httpServerCallback HttpServerCallback // optional; nil disables notifications
}

// httpServerOpt customizes an httpServer while it is being constructed.
type httpServerOpt func(*httpServer)
|
|
|
|
func withHttpServerCallback(callback HttpServerCallback) httpServerOpt {
|
|
return func(s *httpServer) {
|
|
s.httpServerCallback = callback
|
|
}
|
|
}
|
|
|
|
func newHttpServer(opts ...httpServerOpt) *httpServer {
|
|
r := http.NewServeMux()
|
|
s := &httpServer{
|
|
r: r,
|
|
data: make(map[string]string),
|
|
locks: make(map[string]string),
|
|
}
|
|
for _, opt := range opts {
|
|
opt(s)
|
|
}
|
|
s.data["sample"] = sampleState
|
|
r.HandleFunc("/state/", s.handleState)
|
|
return s
|
|
}
|
|
|
|
func (h *httpServer) getResource(req *http.Request) string {
|
|
switch pathParts := strings.SplitN(req.URL.Path, string(filepath.Separator), 3); len(pathParts) {
|
|
case 3:
|
|
return pathParts[2]
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) {
|
|
switch req.Method {
|
|
case "GET":
|
|
h.handleStateGET(writer, req)
|
|
case "POST":
|
|
h.handleStatePOST(writer, req)
|
|
case "DELETE":
|
|
h.handleStateDELETE(writer, req)
|
|
case "LOCK":
|
|
h.handleStateLOCK(writer, req)
|
|
case "UNLOCK":
|
|
h.handleStateUNLOCK(writer, req)
|
|
}
|
|
}
|
|
|
|
func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) {
|
|
if h.httpServerCallback != nil {
|
|
defer h.httpServerCallback.StateGET(req)
|
|
}
|
|
resource := h.getResource(req)
|
|
|
|
h.lock.RLock()
|
|
defer h.lock.RUnlock()
|
|
|
|
if state, ok := h.data[resource]; ok {
|
|
_, _ = io.WriteString(writer, state)
|
|
} else {
|
|
writer.WriteHeader(http.StatusNotFound)
|
|
}
|
|
}
|
|
|
|
func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) {
|
|
if h.httpServerCallback != nil {
|
|
defer h.httpServerCallback.StatePOST(req)
|
|
}
|
|
defer req.Body.Close()
|
|
resource := h.getResource(req)
|
|
|
|
data, err := io.ReadAll(req.Body)
|
|
if err != nil {
|
|
writer.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
h.lock.Lock()
|
|
defer h.lock.Unlock()
|
|
|
|
h.data[resource] = string(data)
|
|
writer.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) {
|
|
if h.httpServerCallback != nil {
|
|
defer h.httpServerCallback.StateDELETE(req)
|
|
}
|
|
resource := h.getResource(req)
|
|
|
|
h.lock.Lock()
|
|
defer h.lock.Unlock()
|
|
|
|
delete(h.data, resource)
|
|
writer.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) {
|
|
if h.httpServerCallback != nil {
|
|
defer h.httpServerCallback.StateLOCK(req)
|
|
}
|
|
defer req.Body.Close()
|
|
resource := h.getResource(req)
|
|
|
|
data, err := io.ReadAll(req.Body)
|
|
if err != nil {
|
|
writer.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
h.lock.Lock()
|
|
defer h.lock.Unlock()
|
|
|
|
if existingLock, ok := h.locks[resource]; ok {
|
|
writer.WriteHeader(http.StatusLocked)
|
|
_, _ = io.WriteString(writer, existingLock)
|
|
} else {
|
|
h.locks[resource] = string(data)
|
|
_, _ = io.WriteString(writer, existingLock)
|
|
}
|
|
}
|
|
|
|
func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) {
|
|
if h.httpServerCallback != nil {
|
|
defer h.httpServerCallback.StateUNLOCK(req)
|
|
}
|
|
defer req.Body.Close()
|
|
resource := h.getResource(req)
|
|
|
|
data, err := io.ReadAll(req.Body)
|
|
if err != nil {
|
|
writer.WriteHeader(http.StatusBadRequest)
|
|
return
|
|
}
|
|
var lockInfo map[string]interface{}
|
|
if err = json.Unmarshal(data, &lockInfo); err != nil {
|
|
writer.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
h.lock.Lock()
|
|
defer h.lock.Unlock()
|
|
|
|
if existingLock, ok := h.locks[resource]; ok {
|
|
var existingLockInfo map[string]interface{}
|
|
if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil {
|
|
writer.WriteHeader(http.StatusInternalServerError)
|
|
return
|
|
}
|
|
lockID := lockInfo["ID"].(string)
|
|
existingID := existingLockInfo["ID"].(string)
|
|
if lockID != existingID {
|
|
writer.WriteHeader(http.StatusConflict)
|
|
_, _ = io.WriteString(writer, existingLock)
|
|
} else {
|
|
delete(h.locks, resource)
|
|
_, _ = io.WriteString(writer, existingLock)
|
|
}
|
|
} else {
|
|
writer.WriteHeader(http.StatusConflict)
|
|
}
|
|
}
|
|
|
|
func (h *httpServer) handler() http.Handler {
|
|
return h.r
|
|
}
|
|
|
|
func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) {
|
|
clientCAData, err := os.ReadFile("testdata/certs/ca.cert.pem")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clientCAs := x509.NewCertPool()
|
|
clientCAs.AppendCertsFromPEM(clientCAData)
|
|
|
|
cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
h := newHttpServer(opts...)
|
|
s := httptest.NewUnstartedServer(h.handler())
|
|
s.TLS = &tls.Config{
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: clientCAs,
|
|
Certificates: []tls.Certificate{cert},
|
|
}
|
|
|
|
s.StartTLS()
|
|
return s, nil
|
|
}
|
|
|
|
func TestMTLSServer_NoCertFails(t *testing.T) {
|
|
// Ensure that no calls are made to the server - everything is blocked by the tls.RequireAndVerifyClientCert
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
mockCallback := NewMockHttpServerCallback(ctrl)
|
|
|
|
// Fire up a test server
|
|
ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error creating test server: %v", err)
|
|
}
|
|
defer ts.Close()
|
|
|
|
// Configure the backend to the pre-populated sample state
|
|
url := ts.URL + "/state/sample"
|
|
conf := map[string]cty.Value{
|
|
"address": cty.StringVal(url),
|
|
"skip_cert_verification": cty.BoolVal(true),
|
|
}
|
|
b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend)
|
|
if nil == b {
|
|
t.Fatal("nil backend")
|
|
}
|
|
|
|
// Now get a state manager and check that it fails to refresh the state
|
|
sm, err := b.StateMgr(t.Context(), backend.DefaultStateName)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
|
|
}
|
|
|
|
opErr := new(net.OpError)
|
|
err = sm.RefreshState(t.Context())
|
|
if err == nil {
|
|
t.Fatal("expected error when refreshing state without a client cert")
|
|
}
|
|
if errors.As(err, &opErr) {
|
|
errType := fmt.Sprintf("%T", opErr.Err)
|
|
expected := "tls.alert"
|
|
if errType != expected {
|
|
t.Fatalf("expected net.OpError.Err type: %q got: %q error:%s", expected, errType, err)
|
|
}
|
|
} else {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestMTLSServer_WithCertPasses(t *testing.T) {
|
|
// Ensure that the expected amount of calls is made to the server
|
|
ctrl := gomock.NewController(t)
|
|
defer ctrl.Finish()
|
|
mockCallback := NewMockHttpServerCallback(ctrl)
|
|
|
|
// Two or three (not testing the caching here) calls to GET
|
|
mockCallback.EXPECT().
|
|
StateGET(gomock.Any()).
|
|
MinTimes(2).
|
|
MaxTimes(3)
|
|
// One call to the POST to write the data
|
|
mockCallback.EXPECT().
|
|
StatePOST(gomock.Any())
|
|
|
|
// Fire up a test server
|
|
ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
|
|
if err != nil {
|
|
t.Fatalf("unexpected error creating test server: %v", err)
|
|
}
|
|
defer ts.Close()
|
|
|
|
// Configure the backend to the pre-populated sample state, and with all the test certs lined up
|
|
url := ts.URL + "/state/sample"
|
|
caData, err := os.ReadFile("testdata/certs/ca.cert.pem")
|
|
if err != nil {
|
|
t.Fatalf("error reading ca certs: %v", err)
|
|
}
|
|
clientCertData, err := os.ReadFile("testdata/certs/client.crt")
|
|
if err != nil {
|
|
t.Fatalf("error reading client cert: %v", err)
|
|
}
|
|
clientKeyData, err := os.ReadFile("testdata/certs/client.key")
|
|
if err != nil {
|
|
t.Fatalf("error reading client key: %v", err)
|
|
}
|
|
conf := map[string]cty.Value{
|
|
"address": cty.StringVal(url),
|
|
"lock_address": cty.StringVal(url),
|
|
"unlock_address": cty.StringVal(url),
|
|
"client_ca_certificate_pem": cty.StringVal(string(caData)),
|
|
"client_certificate_pem": cty.StringVal(string(clientCertData)),
|
|
"client_private_key_pem": cty.StringVal(string(clientKeyData)),
|
|
}
|
|
b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend)
|
|
if nil == b {
|
|
t.Fatal("nil backend")
|
|
}
|
|
|
|
// Now get a state manager, fetch the state, and ensure that the "foo" output is not set
|
|
sm, err := b.StateMgr(t.Context(), backend.DefaultStateName)
|
|
if err != nil {
|
|
t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
|
|
}
|
|
if err = sm.RefreshState(t.Context()); err != nil {
|
|
t.Fatalf("unexpected error calling RefreshState: %v", err)
|
|
}
|
|
state := sm.State()
|
|
if nil == state {
|
|
t.Fatal("nil state")
|
|
}
|
|
stateFoo := state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
|
|
if stateFoo != nil {
|
|
t.Errorf("expected nil foo from state; got %v", stateFoo)
|
|
}
|
|
|
|
// Create a new state that has "foo" set to "bar" and ensure that state is as expected
|
|
state = states.BuildState(func(ss *states.SyncState) {
|
|
ss.SetOutputValue(
|
|
addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance),
|
|
cty.StringVal("bar"),
|
|
false, "")
|
|
})
|
|
stateFoo = state.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
|
|
if nil == stateFoo {
|
|
t.Fatal("nil foo after building state with foo populated")
|
|
}
|
|
if foo := stateFoo.Value.AsString(); foo != "bar" {
|
|
t.Errorf("Expected built state foo value to be bar; got %s", foo)
|
|
}
|
|
|
|
// Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states
|
|
curState := sm.State()
|
|
curStateFoo := curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
|
|
if curStateFoo != nil {
|
|
t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo)
|
|
}
|
|
if reflect.DeepEqual(state, curState) {
|
|
t.Errorf("expected %v != %v; but they were equal", state, curState)
|
|
}
|
|
|
|
// Write the new state, persist, and refresh
|
|
if err = sm.WriteState(state); err != nil {
|
|
t.Errorf("error writing state: %v", err)
|
|
}
|
|
if err = sm.PersistState(t.Context(), nil); err != nil {
|
|
t.Errorf("error persisting state: %v", err)
|
|
}
|
|
if err = sm.RefreshState(t.Context()); err != nil {
|
|
t.Errorf("error refreshing state: %v", err)
|
|
}
|
|
|
|
// Get the state again and verify that is now the same as state and has the "foo" value set to "bar"
|
|
curState = sm.State()
|
|
if !reflect.DeepEqual(state, curState) {
|
|
t.Errorf("expected %v == %v; but they were unequal", state, curState)
|
|
}
|
|
curStateFoo = curState.OutputValue(addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance))
|
|
if nil == curStateFoo {
|
|
t.Fatal("nil foo")
|
|
}
|
|
if foo := curStateFoo.Value.AsString(); foo != "bar" {
|
|
t.Errorf("expected foo to be bar, but got: %s", foo)
|
|
}
|
|
}
|
|
|
|
// TestRunServer allows running the server for local debugging; it runs until ctl-c is received
|
|
func TestRunServer(t *testing.T) {
|
|
if _, ok := os.LookupEnv("TEST_RUN_SERVER"); !ok {
|
|
t.Skip("TEST_RUN_SERVER not set")
|
|
}
|
|
s, err := NewHttpTestServer()
|
|
if err != nil {
|
|
t.Fatalf("unexpected error creating test server: %v", err)
|
|
}
|
|
defer s.Close()
|
|
|
|
t.Log(s.URL)
|
|
|
|
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
|
|
defer cancel()
|
|
// wait until signal
|
|
<-ctx.Done()
|
|
}
|