Files
opentf/internal/backend/remote-state/http/server_test.go
Martin Atkins 67a5cd0911 statemgr+remote: context.Context parameters
This extends statemgr.Persistent, statemgr.Locker and remote.Client to
all expect context.Context parameters, and then updates all of the existing
implementations of those interfaces to support them.

All of the calls to statemgr.Persistent and statemgr.Locker methods outside
of tests consistently use context.TODO() for now, because the caller
landscape of these interfaces has some complications:

1. statemgr.Locker is also used by the clistate package for its state
   implementation that was derived from statemgr.Filesystem's predecessor,
   even though what clistate manages is not actually "state" in the sense
   of package statemgr. The callers of that are not yet ready to provide
   real contexts.

   In a future commit we'll either need to plumb context through to all of
   the clistate callers, or continue the effort to separate statemgr from
   clistate by introducing a clistate-specific "locker" API for it
   to use instead.

2. We call statemgr.Persistent and statemgr.Locker methods in situations
   where the active context might have already been cancelled, and so we'll
   need to make sure to ignore cancellation when calling those.

   This is mainly limited to PersistState and Unlock, since both need to
   be able to complete after a cancellation, but there are various
   codepaths that perform a Lock, Refresh, Persist, Unlock sequence and so
   it isn't yet clear where the best place is to enforce the invariant that
   Persist and Unlock must not be called with a cancelable context. We'll
   deal with that more in subsequent commits.

Within the various state manager and remote client implementations the
contexts _are_ wired together as best as possible with how these subsystems
are already laid out, and so once we deal with the problems above and make
callers provide suitable contexts they should be able to reach all of the
leaf API clients that might want to generate OpenTelemetry traces.

Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
2025-07-10 08:11:39 -07:00

441 lines
12 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package http
//go:generate go run go.uber.org/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"os/signal"
"path/filepath"
"reflect"
"strings"
"sync"
"syscall"
"testing"
"github.com/opentofu/opentofu/internal/addrs"
"github.com/opentofu/opentofu/internal/backend"
"github.com/opentofu/opentofu/internal/configs"
"github.com/opentofu/opentofu/internal/encryption"
"github.com/opentofu/opentofu/internal/states"
"github.com/zclconf/go-cty/cty"
"go.uber.org/mock/gomock"
)
// sampleState is the state document that newHttpServer pre-populates under
// the "sample" resource name. It has serial 0 and no outputs; the mTLS tests
// below rely on that when they check that the "foo" output is initially unset.
const sampleState = `
{
"version": 4,
"serial": 0,
"lineage": "666f9301-7e65-4b19-ae23-71184bb19b03",
"remote": {
"type": "http",
"config": {
"path": "local-state.tfstate"
}
}
}
`
// HttpServerCallback is notified once per handled request for each state
// operation the test server supports. Each handler defers the call, so it
// runs after the handler has finished. A gomock implementation is generated
// from this file (see the go:generate directive) so tests can assert how many
// requests of each kind reached the server.
type HttpServerCallback interface {
	StateGET(req *http.Request)
	StatePOST(req *http.Request)
	StateDELETE(req *http.Request)
	StateLOCK(req *http.Request)
	StateUNLOCK(req *http.Request)
}

// httpServer is a small in-memory server for the HTTP state backend protocol,
// used by the tests in this package.
type httpServer struct {
	// r routes incoming requests; newHttpServer registers /state/ on it.
	r *http.ServeMux

	// data holds stored state documents keyed by resource name.
	data map[string]string

	// locks holds the raw lock-info body of each locked resource.
	locks map[string]string

	// lock guards both data and locks.
	lock sync.RWMutex

	// httpServerCallback is optional; when nil no notifications are sent.
	httpServerCallback HttpServerCallback
}

// httpServerOpt customizes an httpServer while it is being constructed.
type httpServerOpt func(*httpServer)
// withHttpServerCallback returns an option that installs callback on the
// server being built, replacing any callback that was set before it.
func withHttpServerCallback(callback HttpServerCallback) httpServerOpt {
	install := func(srv *httpServer) {
		srv.httpServerCallback = callback
	}
	return install
}
// newHttpServer builds an httpServer with empty state and lock tables and
// applies opts in order. It then seeds the "sample" resource with
// sampleState and routes every path under /state/ to handleState.
func newHttpServer(opts ...httpServerOpt) *httpServer {
	mux := http.NewServeMux()
	srv := &httpServer{
		r:     mux,
		data:  map[string]string{},
		locks: map[string]string{},
	}
	for i := range opts {
		opts[i](srv)
	}
	srv.data["sample"] = sampleState
	mux.HandleFunc("/state/", srv.handleState)
	return srv
}
// getResource extracts the resource name from a request path of the form
// "/state/<resource>". Everything after the second slash, including any
// further slashes, belongs to the resource name. Paths with fewer segments
// map to the empty resource name "".
func (h *httpServer) getResource(req *http.Request) string {
	// URL paths always use "/" as their separator, whatever the host OS.
	// Splitting on filepath.Separator (`\` on Windows) would never find the
	// resource name there.
	pathParts := strings.SplitN(req.URL.Path, "/", 3)
	if len(pathParts) != 3 {
		return ""
	}
	return pathParts[2]
}

// handleState dispatches a /state/ request to the handler for its method.
// LOCK and UNLOCK are custom (non-standard) HTTP methods. Any other method is
// rejected with 405 Method Not Allowed instead of silently returning 200
// without doing anything.
func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) {
	switch req.Method {
	case http.MethodGet:
		h.handleStateGET(writer, req)
	case http.MethodPost:
		h.handleStatePOST(writer, req)
	case http.MethodDelete:
		h.handleStateDELETE(writer, req)
	case "LOCK":
		h.handleStateLOCK(writer, req)
	case "UNLOCK":
		h.handleStateUNLOCK(writer, req)
	default:
		writer.Header().Set("Allow", "GET, POST, DELETE, LOCK, UNLOCK")
		writer.WriteHeader(http.StatusMethodNotAllowed)
	}
}

// handleStateGET writes the stored state for the requested resource. It
// responds 404 Not Found when nothing is stored under that name.
func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		defer h.httpServerCallback.StateGET(req)
	}
	resource := h.getResource(req)
	h.lock.RLock()
	defer h.lock.RUnlock()
	state, ok := h.data[resource]
	if !ok {
		writer.WriteHeader(http.StatusNotFound)
		return
	}
	_, _ = io.WriteString(writer, state)
}

// handleStatePOST replaces the stored state for the requested resource with
// the request body. It responds 400 Bad Request if the body cannot be read.
func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		defer h.httpServerCallback.StatePOST(req)
	}
	defer req.Body.Close()
	resource := h.getResource(req)
	data, err := io.ReadAll(req.Body)
	if err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	h.lock.Lock()
	defer h.lock.Unlock()
	h.data[resource] = string(data)
	writer.WriteHeader(http.StatusOK)
}

// handleStateDELETE removes any stored state for the requested resource. It
// always responds 200 OK, even when nothing was stored.
func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		defer h.httpServerCallback.StateDELETE(req)
	}
	resource := h.getResource(req)
	h.lock.Lock()
	defer h.lock.Unlock()
	delete(h.data, resource)
	writer.WriteHeader(http.StatusOK)
}

// handleStateLOCK records the request body as the lock info for the
// requested resource. If the resource is already locked, it responds
// 423 Locked with the existing lock info as the body.
func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		defer h.httpServerCallback.StateLOCK(req)
	}
	defer req.Body.Close()
	resource := h.getResource(req)
	data, err := io.ReadAll(req.Body)
	if err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	h.lock.Lock()
	defer h.lock.Unlock()
	if existingLock, ok := h.locks[resource]; ok {
		writer.WriteHeader(http.StatusLocked)
		_, _ = io.WriteString(writer, existingLock)
		return
	}
	h.locks[resource] = string(data)
	writer.WriteHeader(http.StatusOK)
}

// handleStateUNLOCK releases the lock on the requested resource. The request
// body must be a JSON object whose string "ID" field matches the stored
// lock's "ID".
//
// On a match, the lock is removed and the response is 200 OK with the
// released lock info as the body. A mismatched ID gets 409 Conflict with the
// current lock info. An unlocked resource gets 409 Conflict with an empty
// body. A request without a string "ID" gets 400 Bad Request. A JSON decoding
// failure, or a stored lock without a string "ID", gets 500 Internal Server
// Error.
func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		defer h.httpServerCallback.StateUNLOCK(req)
	}
	defer req.Body.Close()
	resource := h.getResource(req)
	data, err := io.ReadAll(req.Body)
	if err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	var lockInfo map[string]any
	if err = json.Unmarshal(data, &lockInfo); err != nil {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}
	h.lock.Lock()
	defer h.lock.Unlock()
	existingLock, ok := h.locks[resource]
	if !ok {
		writer.WriteHeader(http.StatusConflict)
		return
	}
	var existingLockInfo map[string]any
	if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}
	// Checked assertions: a missing or non-string "ID" would otherwise panic
	// the handler instead of producing an error response.
	lockID, ok := lockInfo["ID"].(string)
	if !ok {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	existingID, ok := existingLockInfo["ID"].(string)
	if !ok {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}
	if lockID != existingID {
		writer.WriteHeader(http.StatusConflict)
		_, _ = io.WriteString(writer, existingLock)
		return
	}
	delete(h.locks, resource)
	_, _ = io.WriteString(writer, existingLock)
}

// handler returns the root http.Handler for the server's routes.
func (h *httpServer) handler() http.Handler {
	return h.r
}

// NewHttpTestServer starts a TLS httptest.Server serving a fresh httpServer
// configured by opts. The server presents testdata/certs/server.crt and
// requires every client to present a certificate that verifies against
// testdata/certs/ca.cert.pem. The caller must Close the returned server.
func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) {
	certsDir := filepath.Join("testdata", "certs")
	caPath := filepath.Join(certsDir, "ca.cert.pem")
	clientCAData, err := os.ReadFile(caPath)
	if err != nil {
		return nil, err
	}
	clientCAs := x509.NewCertPool()
	// AppendCertsFromPEM reports false, rather than returning an error, when
	// it parses no certificates. Without this check, every client handshake
	// would later fail with a much less obvious error.
	if !clientCAs.AppendCertsFromPEM(clientCAData) {
		return nil, fmt.Errorf("no PEM certificates found in %q", caPath)
	}
	cert, err := tls.LoadX509KeyPair(filepath.Join(certsDir, "server.crt"), filepath.Join(certsDir, "server.key"))
	if err != nil {
		return nil, err
	}
	h := newHttpServer(opts...)
	s := httptest.NewUnstartedServer(h.handler())
	s.TLS = &tls.Config{
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    clientCAs,
		Certificates: []tls.Certificate{cert},
	}
	s.StartTLS()
	return s, nil
}
// TestMTLSServer_NoCertFails checks that a backend configured without a
// client certificate cannot refresh state from the mTLS test server, and that
// the failure is a TLS alert rather than some other error.
func TestMTLSServer_NoCertFails(t *testing.T) {
	// The mock has no expectations, so any callback invocation fails the test.
	// That proves the TLS layer (tls.RequireAndVerifyClientCert) rejects the
	// client before any handler runs.
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	ts, err := NewHttpTestServer(withHttpServerCallback(NewMockHttpServerCallback(ctrl)))
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer ts.Close()

	// Point the backend at the pre-populated sample state, deliberately
	// without any client certificate.
	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", map[string]cty.Value{
		"address":                cty.StringVal(ts.URL + "/state/sample"),
		"skip_cert_verification": cty.BoolVal(true),
	})).(*Backend)
	if b == nil {
		t.Fatal("nil backend")
	}

	sm, err := b.StateMgr(t.Context(), backend.DefaultStateName)
	if err != nil {
		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
	}
	err = sm.RefreshState(t.Context())
	if err == nil {
		t.Fatal("expected error when refreshing state without a client cert")
	}

	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		t.Fatalf("unexpected error: %v", err)
	}
	// The TLS alert type is unexported, so compare its printed type name.
	const expected = "tls.alert"
	if errType := fmt.Sprintf("%T", opErr.Err); errType != expected {
		t.Fatalf("expected net.OpError.Err type: %q got: %q error:%s", expected, errType, err)
	}
}
// TestMTLSServer_WithCertPasses configures the backend with the full set of
// test certificates. It checks that a state round trip (refresh, write,
// persist, refresh) succeeds against the mTLS-protected server, and that the
// server sees the expected number of GET and POST requests.
func TestMTLSServer_WithCertPasses(t *testing.T) {
	// Ensure that the expected amount of calls is made to the server
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	mockCallback := NewMockHttpServerCallback(ctrl)
	// Two or three (not testing the caching here) calls to GET
	mockCallback.EXPECT().
		StateGET(gomock.Any()).
		MinTimes(2).
		MaxTimes(3)
	// One call to the POST to write the data
	mockCallback.EXPECT().
		StatePOST(gomock.Any())
	// Fire up a test server
	ts, err := NewHttpTestServer(withHttpServerCallback(mockCallback))
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer ts.Close()
	// Configure the backend to the pre-populated sample state, and with all the test certs lined up
	url := ts.URL + "/state/sample"
	caData, err := os.ReadFile("testdata/certs/ca.cert.pem")
	if err != nil {
		t.Fatalf("error reading ca certs: %v", err)
	}
	clientCertData, err := os.ReadFile("testdata/certs/client.crt")
	if err != nil {
		t.Fatalf("error reading client cert: %v", err)
	}
	clientKeyData, err := os.ReadFile("testdata/certs/client.key")
	if err != nil {
		t.Fatalf("error reading client key: %v", err)
	}
	conf := map[string]cty.Value{
		"address":                   cty.StringVal(url),
		"lock_address":              cty.StringVal(url),
		"unlock_address":            cty.StringVal(url),
		"client_ca_certificate_pem": cty.StringVal(string(caData)),
		"client_certificate_pem":    cty.StringVal(string(clientCertData)),
		"client_private_key_pem":    cty.StringVal(string(clientKeyData)),
	}
	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", conf)).(*Backend)
	if b == nil {
		t.Fatal("nil backend")
	}
	// The address of the "foo" root module output, inspected throughout.
	fooAddr := addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)

	// Now get a state manager, fetch the state, and ensure that the "foo" output is not set
	sm, err := b.StateMgr(t.Context(), backend.DefaultStateName)
	if err != nil {
		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
	}
	if err = sm.RefreshState(t.Context()); err != nil {
		t.Fatalf("unexpected error calling RefreshState: %v", err)
	}
	state := sm.State()
	if state == nil {
		t.Fatal("nil state")
	}
	if stateFoo := state.OutputValue(fooAddr); stateFoo != nil {
		t.Errorf("expected nil foo from state; got %v", stateFoo)
	}

	// Create a new state that has "foo" set to "bar" and ensure that state is as expected
	state = states.BuildState(func(ss *states.SyncState) {
		ss.SetOutputValue(fooAddr, cty.StringVal("bar"), false, "")
	})
	stateFoo := state.OutputValue(fooAddr)
	if stateFoo == nil {
		t.Fatal("nil foo after building state with foo populated")
	}
	if foo := stateFoo.Value.AsString(); foo != "bar" {
		t.Errorf("Expected built state foo value to be bar; got %s", foo)
	}

	// Ensure the change hasn't altered the current state manager state by checking "foo" and comparing states
	curState := sm.State()
	if curStateFoo := curState.OutputValue(fooAddr); curStateFoo != nil {
		t.Errorf("expected state manager state to be unaltered and still nil, but got: %v", curStateFoo)
	}
	if reflect.DeepEqual(state, curState) {
		t.Errorf("expected %v != %v; but they were equal", state, curState)
	}

	// Write the new state, persist, and refresh. Each step depends on the one
	// before it, so stop at the first failure rather than piling up follow-on
	// errors from comparing against stale state.
	if err = sm.WriteState(state); err != nil {
		t.Fatalf("error writing state: %v", err)
	}
	if err = sm.PersistState(t.Context(), nil); err != nil {
		t.Fatalf("error persisting state: %v", err)
	}
	if err = sm.RefreshState(t.Context()); err != nil {
		t.Fatalf("error refreshing state: %v", err)
	}

	// Get the state again and verify that is now the same as state and has the "foo" value set to "bar"
	curState = sm.State()
	if !reflect.DeepEqual(state, curState) {
		t.Errorf("expected %v == %v; but they were unequal", state, curState)
	}
	curStateFoo := curState.OutputValue(fooAddr)
	if curStateFoo == nil {
		t.Fatal("nil foo")
	}
	if foo := curStateFoo.Value.AsString(); foo != "bar" {
		t.Errorf("expected foo to be bar, but got: %s", foo)
	}
}
// TestRunServer starts the mTLS test server for manual, local debugging and
// keeps it running until SIGINT (Ctrl-C) or SIGTERM arrives. It is skipped
// unless the TEST_RUN_SERVER environment variable is set (to any value).
func TestRunServer(t *testing.T) {
	if _, set := os.LookupEnv("TEST_RUN_SERVER"); !set {
		t.Skip("TEST_RUN_SERVER not set")
	}

	srv, err := NewHttpTestServer()
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer srv.Close()
	t.Log(srv.URL)

	sigCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	// Park here until one of the signals above is delivered.
	<-sigCtx.Done()
}