Files
opentf/internal/registry/client_test.go
Martin Atkins 65a0f7a656 registry+getproviders: Registry client policy centralized in main
The primary reason for this change is that registry.NewClient was
originally imposing its own decision about service discovery request
policy on every other user of the shared disco.Disco object by modifying
it directly.

We have been moving towards using a dependency inversion style where
package main is responsible for deciding how everything should be
configured based on global CLI arguments, environment variables, and the
CLI configuration, and so this commit moves to using that model for the
HTTP clients used by the module and provider registry client code.

This also makes explicit what was previously hidden away: that all service
discovery requests are made using the same HTTP client policy as for
requests to module registries, even if the service being discovered is not
a registry. This doesn't seem to have been the intention of the code as
previously written, but was still its ultimate effect: there is only one
disco.Disco object shared across all discovery callers and so changing its
configuration in any way changes it for everyone.

This initial rework is certainly not perfect: these components were not
originally designed to work in this way and there are lots of existing
test cases relying on them working the old way, and so this is a compromise
to get the behavior we now need (using consistent HTTP client settings
across all callers) without disrupting too much existing code.

Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
2025-05-12 10:50:17 -07:00

489 lines
14 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package registry
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/hashicorp/go-retryablehttp"
version "github.com/hashicorp/go-version"
"github.com/hashicorp/terraform-svchost/disco"
"github.com/opentofu/opentofu/internal/httpclient"
"github.com/opentofu/opentofu/internal/registry/regsrc"
"github.com/opentofu/opentofu/internal/registry/response"
"github.com/opentofu/opentofu/internal/registry/test"
tfversion "github.com/opentofu/opentofu/version"
)
// TestLookupModuleVersions checks that ModuleVersions finds the
// "test-versions/name/provider" module in the mock registry both with and
// without an explicit hostname, that exactly four versions are reported, and
// that each reported version string parses as a valid version.
func TestLookupModuleVersions(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	const wantSource = "test-versions/name/provider"
	// The same module addressed with and without a hostname.
	sources := []string{
		"example.com/" + wantSource,
		wantSource,
	}
	for _, raw := range sources {
		modSrc, err := regsrc.ParseModuleSource(raw)
		if err != nil {
			t.Fatal(err)
		}

		resp, err := client.ModuleVersions(context.Background(), modSrc)
		if err != nil {
			t.Fatal(err)
		}
		if count := len(resp.Modules); count != 1 {
			t.Fatal("expected 1 module, got", count)
		}

		module := resp.Modules[0]
		if module.Source != wantSource {
			t.Fatalf("expected module name %q, got %q", wantSource, module.Source)
		}
		if count := len(module.Versions); count != 4 {
			t.Fatal("expected 4 versions, got", count)
		}
		for _, entry := range module.Versions {
			if _, err := version.NewVersion(entry.Version); err != nil {
				t.Fatalf("invalid version %q: %s", entry.Version, err)
			}
		}
	}
}
// TestInvalidRegistry checks that ModuleVersions returns an error for a module
// whose source names the hostname "non-existent.localhost.localdomain"
// (presumably a host that the discovery set up by test.Disco cannot resolve).
func TestInvalidRegistry(t *testing.T) {
	server := test.Registry()
	defer server.Close()
	client := NewClient(t.Context(), test.Disco(server), nil)

	modSrc, err := regsrc.ParseModuleSource("non-existent.localhost.localdomain/test-versions/name/provider")
	if err != nil {
		t.Fatal(err)
	}

	_, err = client.ModuleVersions(context.Background(), modSrc)
	if err == nil {
		t.Fatal("expected error")
	}
}
// TestRegistryAuth checks that ModuleVersions and ModuleLocation succeed for
// the "private/name/provider" module while the discovery object has its
// original credentials source, and that both fail once the credentials
// source is cleared.
func TestRegistryAuth(t *testing.T) {
	server := test.Registry()
	defer server.Close()
	client := NewClient(t.Context(), test.Disco(server), nil)

	mod, err := regsrc.ParseModuleSource("private/name/provider")
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()

	// With the credentials source presumably supplied by test.Disco, both
	// lookups are expected to succeed.
	if _, err := client.ModuleVersions(ctx, mod); err != nil {
		t.Fatal(err)
	}
	if _, err := client.ModuleLocation(ctx, mod, "1.0.0"); err != nil {
		t.Fatal(err)
	}

	// Remove the credentials source: both lookups must now be rejected.
	client.services.SetCredentialsSource(nil)
	if _, err := client.ModuleVersions(ctx, mod); err == nil {
		t.Fatal("expected error")
	}
	if _, err := client.ModuleLocation(ctx, mod, "1.0.0"); err == nil {
		t.Fatal("expected error")
	}
}
// TestLookupModuleLocationRelative checks that the download location returned
// for "relative/foo/bar" version 0.2.0 is the mock registry server's URL
// followed by "/relative-path".
func TestLookupModuleLocationRelative(t *testing.T) {
	server := test.Registry()
	defer server.Close()
	client := NewClient(t.Context(), test.Disco(server), nil)

	mod, err := regsrc.ParseModuleSource("relative/foo/bar")
	if err != nil {
		t.Fatal(err)
	}

	location, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
	if err != nil {
		t.Fatal(err)
	}
	if want := server.URL + "/relative-path"; location != want {
		t.Errorf("wrong location %s; want %s", location, want)
	}
}
// TestAccLookupModuleVersions is an acceptance test, run only when TF_ACC is
// set. It uses a real (non-mock) discovery object to look up
// "terraform-aws-modules/vpc/aws" with and without the public registry
// hostname. It checks that exactly one module is returned, that the module
// has at least one version, and that every version string parses.
func TestAccLookupModuleVersions(t *testing.T) {
	if os.Getenv("TF_ACC") == "" {
		t.Skip("set TF_ACC to run acceptance tests against the public registry")
	}
	regDisco := disco.New()
	regDisco.SetUserAgent(httpclient.OpenTofuUserAgent(tfversion.String()))

	// The client does not depend on the module source being looked up, so a
	// single instance serves every iteration below.
	client := NewClient(t.Context(), regDisco, nil)

	const name = "terraform-aws-modules/vpc/aws"
	// test with and without a hostname
	for _, src := range []string{
		name,
		regsrc.PublicRegistryHost.String() + "/" + name,
	} {
		modsrc, err := regsrc.ParseModuleSource(src)
		if err != nil {
			t.Fatal(err)
		}
		resp, err := client.ModuleVersions(t.Context(), modsrc)
		if err != nil {
			t.Fatal(err)
		}
		if len(resp.Modules) != 1 {
			t.Fatal("expected 1 module, got", len(resp.Modules))
		}
		mod := resp.Modules[0]
		if mod.Source != name {
			t.Fatalf("expected module name %q, got %q", name, mod.Source)
		}
		if len(mod.Versions) == 0 {
			t.Fatal("expected at least one version, got 0")
		}
		for _, v := range mod.Versions {
			if _, err := version.NewVersion(v.Version); err != nil {
				t.Fatalf("invalid version %q: %s", v.Version, err)
			}
		}
	}
}
// TestLookupLookupModuleError checks that the error for a module missing from
// the registry quotes the module source exactly as written in configuration
// (not the discovered path with a hostname prepended). It also checks that
// the client's CheckRetry hook is not invoked a second time, i.e. the
// not-found response is not retried.
func TestLookupLookupModuleError(t *testing.T) {
	server := test.Registry()
	defer server.Close()
	client := NewClient(t.Context(), test.Disco(server), nil)

	// This source is not present in the mock registry.
	mod, err := regsrc.ParseModuleSource("bad/local/path")
	if err != nil {
		t.Fatal(err)
	}

	// Wrap CheckRetry so that any invocation after the first (a retry) fails
	// the test immediately; the wrapped policy still makes the decision.
	wrapped := client.client.CheckRetry
	calls := 0
	client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
		if calls != 0 {
			t.Fatal("retried after module not found")
		}
		calls++
		return wrapped(ctx, resp, err)
	}

	_, err = client.ModuleLocation(context.Background(), mod, "0.2.0")
	if err == nil {
		t.Fatal("expected error")
	}
	// Match the exact quoted source to be sure no hostname was prepended.
	if !strings.Contains(err.Error(), `"bad/local/path"`) {
		t.Fatal("error should not include the hostname. got:", err)
	}
}
// TestLookupModuleRetryError checks that when the registry keeps answering
// with retryable errors, ModuleVersions returns no response and an error whose
// text reports the number of attempts made by the default registry client.
func TestLookupModuleRetryError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	defer server.Close()
	client := NewClient(t.Context(), test.Disco(server), nil)

	modsrc, err := regsrc.ParseModuleSource("example.com/test-versions/name/provider")
	if err != nil {
		t.Fatal(err)
	}
	// Use the test's context, as NewClient does, so the request is tied to the
	// lifetime of this test.
	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is nil here, so there is nothing useful to print alongside.
		t.Fatal("expected requests to exceed retry")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}
	// verify maxRetryErrorHandler handler returned the error
	if !strings.Contains(err.Error(), "request failed after 2 attempts") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestLookupModuleNoRetryError checks that a registry client built with
// retries disabled reports a failing request against the retryable-errors
// server with no response and an error containing "request failed:".
func TestLookupModuleNoRetryError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	defer server.Close()
	client := NewClient(
		t.Context(), test.Disco(server),
		// Retries are disabled by the second argument to this function
		httpclient.NewForRegistryRequests(t.Context(), 0, 10*time.Second),
	)

	modsrc, err := regsrc.ParseModuleSource("example.com/test-versions/name/provider")
	if err != nil {
		t.Fatal(err)
	}
	// Use the test's context, as NewClient does, so the request is tied to the
	// lifetime of this test.
	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is nil here, so there is nothing useful to print alongside.
		t.Fatal("expected request to fail")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}
	// verify the error reports the failed request
	if !strings.Contains(err.Error(), "request failed:") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestLookupModuleNetworkError checks the error reported when the registry
// server is unreachable. The server is closed right after the client and its
// discovery object are constructed, so the lookup cannot reach it.
func TestLookupModuleNetworkError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	client := NewClient(t.Context(), test.Disco(server), nil)
	// Shut down the server to simulate network failure
	server.Close()

	modsrc, err := regsrc.ParseModuleSource("example.com/test-versions/name/provider")
	if err != nil {
		t.Fatal(err)
	}
	// Use the test's context, as NewClient does, so the request is tied to the
	// lifetime of this test.
	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is nil here, so there is nothing useful to print alongside.
		t.Fatal("expected request to fail")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}
	// verify maxRetryErrorHandler handler returned the correct error
	if !strings.Contains(err.Error(), "request failed after 2 attempts") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestModuleLocation_readRegistryResponse exercises ModuleLocation against a
// range of registry responses. Each case gets its own mock server running the
// case's handlerFunc. Every request sent through the registry HTTP client
// passed to NewClient is redirected to that mock server by testTransport.
// Discovery is configured from test.Registry(registryFlags...) via test.Disco.
func TestModuleLocation_readRegistryResponse(t *testing.T) {
	t.Parallel()

	cases := map[string]struct {
		src                  string
		handlerFunc          func(w http.ResponseWriter, r *http.Request)
		registryFlags        []uint8
		want                 string
		wantErrorStr         string
		wantToReadFromHeader bool
		wantStatusCode       int
	}{
		"shall find the module location in the registry response body": {
			src:            "exists-in-registry/identifier/provider",
			want:           "file:///registry/exists",
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
			},
		},
		"shall find the module location in the registry response header": {
			src:                  "exists-in-registry/identifier/provider",
			registryFlags:        []uint8{test.WithModuleLocationInHeader},
			want:                 "file:///registry/exists",
			wantToReadFromHeader: true,
			wantStatusCode:       http.StatusNoContent,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Terraform-Get", "file:///registry/exists")
				w.WriteHeader(http.StatusNoContent)
			},
		},
		"shall read location from the registry response body even if the header with location address is also set": {
			src:                  "exists-in-registry/identifier/provider",
			want:                 "file:///registry/exists",
			wantStatusCode:       http.StatusOK,
			wantToReadFromHeader: false,
			registryFlags:        []uint8{test.WithModuleLocationInBody, test.WithModuleLocationInHeader},
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Terraform-Get", "file:///registry/exists-header")
				w.WriteHeader(http.StatusOK)
				_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
			},
		},
		"shall fail to find the module": {
			src: "not-exist/identifier/provider",
			// note that the version is fixed in the mock
			// see: /internal/registry/test/mock_registry.go:testMods
			wantErrorStr:   `module "not-exist/identifier/provider" version "0.2.0" not found`,
			wantStatusCode: http.StatusNotFound,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusNotFound)
			},
		},
		"shall fail because of reading response body error": {
			src:            "foo/bar/baz",
			wantErrorStr:   "error reading response body from registry",
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Length", "1000") // Set incorrect content length
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("{")) // Only write a partial response
				// The connection will close after handler returns, but client will expect more data
			},
		},
		"shall fail to deserialize JSON response": {
			src:            "foo/bar/baz",
			wantErrorStr:   `module "foo/bar/baz" version "0.2.0" failed to deserialize response body {: unexpected end of JSON input`,
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("{"))
			},
		},
		"shall fail because of unexpected protocol change - 422 http status": {
			src:            "foo/bar/baz",
			wantErrorStr:   `error getting download location for "foo/bar/baz": 422 Unprocessable Entity resp:bar`,
			wantStatusCode: http.StatusUnprocessableEntity,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusUnprocessableEntity)
				_, _ = w.Write([]byte("bar"))
			},
		},
		"shall fail because location is not found in the response": {
			src:            "foo/bar/baz",
			wantErrorStr:   `failed to get download URL for "foo/bar/baz": 200 OK resp:{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`,
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				// note that the response emulates a contract change
				_, _ = w.Write([]byte(`{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`))
			},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			mockServer := httptest.NewServer(http.HandlerFunc(tc.handlerFunc))
			defer mockServer.Close()
			registryServer := test.Registry(tc.registryFlags...)
			defer registryServer.Close()

			transport := &testTransport{mockURL: mockServer.URL}
			httpClient := retryablehttp.NewClient()
			httpClient.HTTPClient.Transport = transport
			client := NewClient(t.Context(), test.Disco(registryServer), httpClient)

			mod, err := regsrc.ParseModuleSource(tc.src)
			if err != nil {
				t.Fatal(err)
			}
			got, err := client.ModuleLocation(t.Context(), mod, "0.2.0")

			// Validate the error outcome in both directions: an unexpected
			// error, the wrong error, and a missing expected error all fail.
			switch {
			case err != nil && tc.wantErrorStr == "":
				t.Fatalf("unexpected error: %v", err)
			case err != nil && !strings.Contains(err.Error(), tc.wantErrorStr):
				t.Fatalf("unexpected error content: want=%s, got=%v", tc.wantErrorStr, err)
			case err == nil && tc.wantErrorStr != "":
				t.Fatalf("expected error containing %q, got none (location %q)", tc.wantErrorStr, got)
			}
			if got != tc.want {
				t.Fatalf("unexpected location: want=%s, got=%v", tc.want, got)
			}

			// Verify status code if the mock server produced a response.
			if transport.lastResponse != nil {
				gotStatusCode := transport.lastResponse.StatusCode
				if tc.wantStatusCode != gotStatusCode {
					t.Fatalf("unexpected response status code: want=%d, got=%d", tc.wantStatusCode, gotStatusCode)
				}
				// Check if we expected to read from header
				if tc.wantToReadFromHeader && err == nil {
					headerVal := transport.lastResponse.Header.Get("X-Terraform-Get")
					if headerVal == "" {
						t.Fatalf("expected to read location from header but X-Terraform-Get header was not set")
					}
				}
			}
		})
	}
}
// testTransport is an http.RoundTripper that redirects every request to the
// mock server at mockURL. The redirected request keeps the method, path,
// query, headers, body, Host and context of the original. The transport also
// records the most recent response it received so tests can inspect it.
//
// lastResponse is written without synchronization, so a single testTransport
// should not serve concurrent requests.
type testTransport struct {
	// mockURL is the base URL of the mock server, e.g. httptest.Server.URL.
	mockURL string
	// lastResponse is the last response received from the mock server, or nil
	// if no round trip has completed without a transport error.
	lastResponse *http.Response
}

// RoundTrip implements http.RoundTripper. It does not modify req; instead it
// sends a clone whose URL scheme and host are replaced by those of mockURL.
// The clone shares req's body and context.
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	target, err := url.Parse(t.mockURL)
	if err != nil {
		return nil, err
	}

	// Clone deep-copies the URL and headers, so rewriting the destination
	// below leaves the caller's request untouched, as the RoundTripper
	// contract requires.
	mockReq := req.Clone(req.Context())
	mockReq.URL.Scheme = target.Scheme
	mockReq.URL.Host = target.Host

	// Send the request to the mock server
	resp, err := http.DefaultTransport.RoundTrip(mockReq)
	if err != nil {
		return nil, err
	}
	t.lastResponse = resp
	return resp, nil
}