Files
opentf/internal/backend/remote-state/http/server_test.go
Martin Atkins e74bf2d0a1 go.mod: Use the new "tool" directive
Previously the Go toolchain had no explicit support for "tools" and so we
used the typical Go community workaround of adding "tools.go" files (two,
for some reason) that existed only to trick the Go toolchain into
considering the tools as dependencies we could track in go.mod.

Go 1.24 introduced explicit support for tracking tools as part of go.mod,
and the ability to run those using "go tool" instead of "go run", and so
this commit switches us over to using that strategy for everything we were
previously managing in tools.go.

There are some intentional exceptions here:

- The protobuf-compile script can't use "go tool" or "go run" because the
  tools in question are run only indirectly through protoc. However, we
  do still use the "tool" directive in go.mod to tell the Go toolchain that
  we depend on those tools, so that it'll track which versions we are
  currently using as part of go.mod.
- Our golangci-lint Makefile target uses "go run" to run a specific
  version of golangci-lint. We _intentionally_ don't consider that tool
  to be a direct dependency of OpenTofu because it has a lot of indirect
  dependencies that would pollute our go.mod file. Therefore that continues
  to use "go run" after this commit.
- Both of our tools.go files previously referred to
  github.com/nishanths/exhaustive , but nothing actually appears to be
  using that tool in the current OpenTofu tree, so it's no longer a
  dependency after this commit.

All of the dependencies we have _only_ for tools are now classified as
"indirect" in the go.mod file. This is the default behavior of the Go
toolchain and appears to be motivated by making it clearer that these
modules do not contribute anything to the runtime behavior of OpenTofu.
This also corrected a historical oddity in our go.mod where for some reason
the "indirect" dependencies had been split across two different "require"
directives; they are now all grouped together in a single directive.

Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
2025-10-10 07:06:56 -03:00

440 lines
12 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package http
//go:generate go tool go.uber.org/mock/mockgen -package $GOPACKAGE -source $GOFILE -destination mock_$GOFILE
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"os/signal"
"reflect"
"strings"
"sync"
"syscall"
"testing"
"github.com/opentofu/opentofu/internal/addrs"
"github.com/opentofu/opentofu/internal/backend"
"github.com/opentofu/opentofu/internal/configs"
"github.com/opentofu/opentofu/internal/encryption"
"github.com/opentofu/opentofu/internal/states"
"github.com/zclconf/go-cty/cty"
"go.uber.org/mock/gomock"
)
// sampleState is a small version-4 state document. newHttpServer stores it
// under the "sample" resource so tests have a state to fetch right away.
const sampleState = `
{
"version": 4,
"serial": 0,
"lineage": "666f9301-7e65-4b19-ae23-71184bb19b03",
"remote": {
"type": "http",
"config": {
"path": "local-state.tfstate"
}
}
}
`
// HttpServerCallback is notified about the requests served by httpServer.
// Each method corresponds to one HTTP method handled by the server and is
// invoked (via defer) when the matching handler returns.
type HttpServerCallback interface {
	StateGET(req *http.Request)
	StatePOST(req *http.Request)
	StateDELETE(req *http.Request)
	StateLOCK(req *http.Request)
	StateUNLOCK(req *http.Request)
}

// httpServer is an in-memory state server used by the tests in this file.
type httpServer struct {
	// r routes incoming requests to the state handlers.
	r *http.ServeMux
	// data maps a resource name to the stored state body.
	data map[string]string
	// locks maps a resource name to the body of the lock request holding it.
	locks map[string]string
	// lock guards data and locks.
	lock sync.RWMutex
	// httpServerCallback, when non-nil, is notified about handled requests.
	httpServerCallback HttpServerCallback
}

// httpServerOpt customises an httpServer while it is being constructed.
type httpServerOpt func(*httpServer)
// withHttpServerCallback returns an option that makes the server report
// handled requests to callback.
func withHttpServerCallback(callback HttpServerCallback) httpServerOpt {
	install := func(srv *httpServer) {
		srv.httpServerCallback = callback
	}
	return install
}
// newHttpServer builds an in-memory state server, applies opts in order,
// pre-loads sampleState under the "sample" resource and registers the state
// handler for every path below "/state/".
func newHttpServer(opts ...httpServerOpt) *httpServer {
	mux := http.NewServeMux()
	srv := &httpServer{
		r:     mux,
		data:  map[string]string{},
		locks: map[string]string{},
	}
	for i := range opts {
		opts[i](srv)
	}
	// The sample state is stored after the options have run, so it is always
	// present in whatever data map the server ends up with.
	srv.data["sample"] = sampleState
	mux.HandleFunc("/state/", srv.handleState)
	return srv
}
// getResource extracts the resource name from the request path: everything
// after the second "/" (for "/state/sample" that is "sample"). It returns ""
// when the path has fewer than three "/"-separated segments.
func (h *httpServer) getResource(req *http.Request) string {
	segments := strings.SplitN(req.URL.Path, "/", 3)
	if len(segments) != 3 {
		return ""
	}
	return segments[2]
}
// handleState dispatches a request under "/state/" to the handler for its
// HTTP method. Methods the server does not implement are rejected with
// 405 Method Not Allowed instead of silently receiving an empty 200 response.
func (h *httpServer) handleState(writer http.ResponseWriter, req *http.Request) {
	switch req.Method {
	case http.MethodGet:
		h.handleStateGET(writer, req)
	case http.MethodPost:
		h.handleStatePOST(writer, req)
	case http.MethodDelete:
		h.handleStateDELETE(writer, req)
	case "LOCK":
		h.handleStateLOCK(writer, req)
	case "UNLOCK":
		h.handleStateUNLOCK(writer, req)
	default:
		// Headers must be set before WriteHeader for them to be sent.
		writer.Header().Set("Allow", "GET, POST, DELETE, LOCK, UNLOCK")
		writer.WriteHeader(http.StatusMethodNotAllowed)
	}
}
// handleStateGET writes the stored state for the requested resource, or
// responds with 404 Not Found when nothing is stored for it.
func (h *httpServer) handleStateGET(writer http.ResponseWriter, req *http.Request) {
	if cb := h.httpServerCallback; cb != nil {
		// Deferred first so it runs last, after the read lock is released.
		defer cb.StateGET(req)
	}

	key := h.getResource(req)

	h.lock.RLock()
	defer h.lock.RUnlock()

	body, found := h.data[key]
	if !found {
		writer.WriteHeader(http.StatusNotFound)
		return
	}
	_, _ = io.WriteString(writer, body)
}
// handleStatePOST stores the request body as the state for the requested
// resource and responds with 200 OK. A body that cannot be read yields
// 400 Bad Request and leaves the stored state untouched.
func (h *httpServer) handleStatePOST(writer http.ResponseWriter, req *http.Request) {
	if cb := h.httpServerCallback; cb != nil {
		// Deferred first so it runs last, after the write lock is released.
		defer cb.StatePOST(req)
	}
	defer req.Body.Close()

	key := h.getResource(req)

	payload, readErr := io.ReadAll(req.Body)
	if readErr != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}

	h.lock.Lock()
	defer h.lock.Unlock()

	h.data[key] = string(payload)
	writer.WriteHeader(http.StatusOK)
}
// handleStateDELETE removes the stored state for the requested resource and
// responds with 200 OK, whether or not a state was present.
func (h *httpServer) handleStateDELETE(writer http.ResponseWriter, req *http.Request) {
	if cb := h.httpServerCallback; cb != nil {
		// Deferred first so it runs last, after the write lock is released.
		defer cb.StateDELETE(req)
	}

	key := h.getResource(req)

	h.lock.Lock()
	defer h.lock.Unlock()

	delete(h.data, key)
	writer.WriteHeader(http.StatusOK)
}
// handleStateLOCK tries to take the lock for the requested resource, using
// the request body as the lock info.
//
// Responses:
//   - 400 Bad Request when the request body cannot be read.
//   - 423 Locked, with the holder's lock info as body, when the resource is
//     already locked.
//   - 200 OK with an empty body when the lock was acquired.
func (h *httpServer) handleStateLOCK(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		// Deferred first so it runs last, after the write lock is released.
		defer h.httpServerCallback.StateLOCK(req)
	}
	defer req.Body.Close()

	resource := h.getResource(req)

	data, err := io.ReadAll(req.Body)
	if err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}

	h.lock.Lock()
	defer h.lock.Unlock()

	if existingLock, ok := h.locks[resource]; ok {
		writer.WriteHeader(http.StatusLocked)
		_, _ = io.WriteString(writer, existingLock)
		return
	}

	h.locks[resource] = string(data)
	// Success carries no body. The status is written explicitly rather than
	// relying on an empty write to trigger the implicit 200.
	writer.WriteHeader(http.StatusOK)
}
// handleStateUNLOCK releases the lock for the requested resource if the "ID"
// in the JSON request body matches the "ID" of the stored lock info.
//
// Responses:
//   - 400 Bad Request when the body cannot be read, or when a lock exists but
//     the request carries no string "ID".
//   - 500 Internal Server Error when the request body or the stored lock info
//     is not valid JSON, or the stored lock info has no string "ID".
//   - 409 Conflict when the resource is not locked (empty body) or is locked
//     under a different ID (holder's lock info as body).
//   - 200 OK, with the released lock info as body, on success.
func (h *httpServer) handleStateUNLOCK(writer http.ResponseWriter, req *http.Request) {
	if h.httpServerCallback != nil {
		// Deferred first so it runs last, after the write lock is released.
		defer h.httpServerCallback.StateUNLOCK(req)
	}
	defer req.Body.Close()

	resource := h.getResource(req)

	data, err := io.ReadAll(req.Body)
	if err != nil {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	var lockInfo map[string]any
	if err = json.Unmarshal(data, &lockInfo); err != nil {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	h.lock.Lock()
	defer h.lock.Unlock()

	existingLock, ok := h.locks[resource]
	if !ok {
		writer.WriteHeader(http.StatusConflict)
		return
	}
	var existingLockInfo map[string]any
	if err = json.Unmarshal([]byte(existingLock), &existingLockInfo); err != nil {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Use the comma-ok form: a missing or non-string "ID" must produce an
	// error response rather than panic inside the handler.
	lockID, ok := lockInfo["ID"].(string)
	if !ok {
		writer.WriteHeader(http.StatusBadRequest)
		return
	}
	existingID, ok := existingLockInfo["ID"].(string)
	if !ok {
		writer.WriteHeader(http.StatusInternalServerError)
		return
	}

	if lockID != existingID {
		writer.WriteHeader(http.StatusConflict)
		_, _ = io.WriteString(writer, existingLock)
		return
	}

	delete(h.locks, resource)
	_, _ = io.WriteString(writer, existingLock)
}
// handler returns the server's request multiplexer as an http.Handler.
func (h *httpServer) handler() http.Handler {
return h.r
}
// NewHttpTestServer starts a TLS httptest server backed by a new in-memory
// state server configured with opts. The server requires every client to
// present a certificate that verifies against testdata/certs/ca.cert.pem, and
// serves testdata/certs/server.crt as its own certificate.
//
// It returns an error when the CA file cannot be read or contains no usable
// certificate, or when the server key pair cannot be loaded. On success the
// caller is responsible for closing the returned server.
func NewHttpTestServer(opts ...httpServerOpt) (*httptest.Server, error) {
	const caPath = "testdata/certs/ca.cert.pem"

	clientCAData, err := os.ReadFile(caPath)
	if err != nil {
		return nil, err
	}
	clientCAs := x509.NewCertPool()
	// AppendCertsFromPEM reports failure only through its boolean result.
	// Ignoring it would leave an empty pool and make every client handshake
	// fail for a non-obvious reason.
	if !clientCAs.AppendCertsFromPEM(clientCAData) {
		return nil, errors.New("no CA certificates could be parsed from " + caPath)
	}

	cert, err := tls.LoadX509KeyPair("testdata/certs/server.crt", "testdata/certs/server.key")
	if err != nil {
		return nil, err
	}

	h := newHttpServer(opts...)
	s := httptest.NewUnstartedServer(h.handler())
	s.TLS = &tls.Config{
		ClientAuth:   tls.RequireAndVerifyClientCert,
		ClientCAs:    clientCAs,
		Certificates: []tls.Certificate{cert},
	}
	s.StartTLS()
	return s, nil
}
// TestMTLSServer_NoCertFails checks that a backend configured without a
// client certificate cannot refresh state from the mTLS test server.
func TestMTLSServer_NoCertFails(t *testing.T) {
	// No expectations are set on the mock: no request should reach the
	// handlers, because tls.RequireAndVerifyClientCert blocks everything.
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	callbacks := NewMockHttpServerCallback(ctrl)

	// Start the test server.
	srv, err := NewHttpTestServer(withHttpServerCallback(callbacks))
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer srv.Close()

	// Point the backend at the pre-populated sample state, without any
	// client certificate material.
	cfg := map[string]cty.Value{
		"address":                cty.StringVal(srv.URL + "/state/sample"),
		"skip_cert_verification": cty.BoolVal(true),
	}
	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", cfg)).(*Backend)
	if b == nil {
		t.Fatal("nil backend")
	}

	// Obtaining the state manager must work; refreshing through it must not.
	sm, err := b.StateMgr(t.Context(), backend.DefaultStateName)
	if err != nil {
		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
	}
	err = sm.RefreshState(t.Context())
	if err == nil {
		t.Fatal("expected error when refreshing state without a client cert")
	}

	// The failure must be a network-level error wrapping a TLS alert. The
	// alert type is not exported, so it is compared by its printed type name.
	var opErr *net.OpError
	if !errors.As(err, &opErr) {
		t.Fatalf("unexpected error: %v", err)
	}
	const expected = "tls.alert"
	if errType := fmt.Sprintf("%T", opErr.Err); errType != expected {
		t.Fatalf("expected net.OpError.Err type: %q got: %q error:%s", expected, errType, err)
	}
}
// TestMTLSServer_WithCertPasses checks that a backend configured with the
// test CA and client certificate can read state from the mTLS test server,
// write a modified state, and read that state back.
func TestMTLSServer_WithCertPasses(t *testing.T) {
	// Set the number of calls the server is expected to receive.
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()
	callbacks := NewMockHttpServerCallback(ctrl)
	// Two or three GETs (caching is not what is being tested here).
	callbacks.EXPECT().StateGET(gomock.Any()).MinTimes(2).MaxTimes(3)
	// Exactly one POST, which writes the data.
	callbacks.EXPECT().StatePOST(gomock.Any())

	// Start the test server.
	srv, err := NewHttpTestServer(withHttpServerCallback(callbacks))
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer srv.Close()

	// Point the backend at the pre-populated sample state, with all of the
	// test certificates supplied.
	endpoint := cty.StringVal(srv.URL + "/state/sample")
	caPEM, err := os.ReadFile("testdata/certs/ca.cert.pem")
	if err != nil {
		t.Fatalf("error reading ca certs: %v", err)
	}
	certPEM, err := os.ReadFile("testdata/certs/client.crt")
	if err != nil {
		t.Fatalf("error reading client cert: %v", err)
	}
	keyPEM, err := os.ReadFile("testdata/certs/client.key")
	if err != nil {
		t.Fatalf("error reading client key: %v", err)
	}
	cfg := map[string]cty.Value{
		"address":                   endpoint,
		"lock_address":              endpoint,
		"unlock_address":            endpoint,
		"client_ca_certificate_pem": cty.StringVal(string(caPEM)),
		"client_certificate_pem":    cty.StringVal(string(certPEM)),
		"client_private_key_pem":    cty.StringVal(string(keyPEM)),
	}
	b := backend.TestBackendConfig(t, New(encryption.StateEncryptionDisabled()), configs.SynthBody("synth", cfg)).(*Backend)
	if b == nil {
		t.Fatal("nil backend")
	}

	// Address of the root module output "foo", inspected throughout.
	fooAddr := addrs.OutputValue{Name: "foo"}.Absolute(addrs.RootModuleInstance)

	// Fetch the state through a state manager; "foo" must not be set yet.
	sm, err := b.StateMgr(t.Context(), backend.DefaultStateName)
	if err != nil {
		t.Fatalf("unexpected error fetching StateMgr with %s: %v", backend.DefaultStateName, err)
	}
	if err = sm.RefreshState(t.Context()); err != nil {
		t.Fatalf("unexpected error calling RefreshState: %v", err)
	}
	state := sm.State()
	if state == nil {
		t.Fatal("nil state")
	}
	if stateFoo := state.OutputValue(fooAddr); stateFoo != nil {
		t.Errorf("expected nil foo from state; got %v", stateFoo)
	}

	// Build a new state with "foo" set to "bar" and sanity-check it.
	state = states.BuildState(func(ss *states.SyncState) {
		ss.SetOutputValue(fooAddr, cty.StringVal("bar"), false, "")
	})
	builtFoo := state.OutputValue(fooAddr)
	if builtFoo == nil {
		t.Fatal("nil foo after building state with foo populated")
	}
	if got := builtFoo.Value.AsString(); got != "bar" {
		t.Errorf("Expected built state foo value to be bar; got %s", got)
	}

	// Building the new state must not have changed what the state manager
	// currently holds.
	curState := sm.State()
	if curStateFoo := curState.OutputValue(fooAddr); curStateFoo != nil {
		t.Errorf("expected session manager state to be unaltered and still nil, but got: %v", curStateFoo)
	}
	if reflect.DeepEqual(state, curState) {
		t.Errorf("expected %v != %v; but they were equal", state, curState)
	}

	// Write the new state, persist it, and refresh.
	if err = sm.WriteState(state); err != nil {
		t.Errorf("error writing state: %v", err)
	}
	if err = sm.PersistState(t.Context(), nil); err != nil {
		t.Errorf("error persisting state: %v", err)
	}
	if err = sm.RefreshState(t.Context()); err != nil {
		t.Errorf("error refreshing state: %v", err)
	}

	// The state manager must now hold the new state, with "foo" set to "bar".
	curState = sm.State()
	if !reflect.DeepEqual(state, curState) {
		t.Errorf("expected %v == %v; but they were unequal", state, curState)
	}
	refreshedFoo := curState.OutputValue(fooAddr)
	if refreshedFoo == nil {
		t.Fatal("nil foo")
	}
	if got := refreshedFoo.Value.AsString(); got != "bar" {
		t.Errorf("expected foo to be bar, but got: %s", got)
	}
}
// TestRunServer starts the test server for local debugging and keeps it
// running until SIGINT (Ctrl-C) or SIGTERM is received. It is skipped unless
// the TEST_RUN_SERVER environment variable is set.
func TestRunServer(t *testing.T) {
	if _, enabled := os.LookupEnv("TEST_RUN_SERVER"); !enabled {
		t.Skip("TEST_RUN_SERVER not set")
	}

	srv, err := NewHttpTestServer()
	if err != nil {
		t.Fatalf("unexpected error creating test server: %v", err)
	}
	defer srv.Close()
	t.Log(srv.URL)

	// Block until one of the signals arrives.
	sigCtx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()
	<-sigCtx.Done()
}