Files
opentf/internal/registry/client_test.go
2025-12-04 13:40:05 -08:00

547 lines
16 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package registry
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/go-retryablehttp"
version "github.com/hashicorp/go-version"
"github.com/opentofu/svchost/disco"
"github.com/opentofu/opentofu/internal/addrs"
"github.com/opentofu/opentofu/internal/httpclient"
"github.com/opentofu/opentofu/internal/registry/regsrc"
"github.com/opentofu/opentofu/internal/registry/response"
"github.com/opentofu/opentofu/internal/registry/test"
)
// TestLookupModuleVersions checks that the registry client lists the versions
// of a module served by the mock registry, both when the module source has an
// explicit hostname and when it does not. The mock is expected to report
// exactly one module with four versions, each of which must parse as a
// version string.
func TestLookupModuleVersions(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	// test with and without a hostname
	for _, src := range []string{
		"example.com/test-versions/name/provider",
		"test-versions/name/provider",
	} {
		modsrc, err := regsrc.ParseModuleSource(src)
		if err != nil {
			t.Fatalf("%q: %s", src, err)
		}

		resp, err := client.ModuleVersions(t.Context(), modsrc)
		if err != nil {
			t.Fatalf("%q: %s", src, err)
		}

		if len(resp.Modules) != 1 {
			t.Fatalf("%q: expected 1 module, got %d", src, len(resp.Modules))
		}

		mod := resp.Modules[0]
		// The response always carries the hostless source, regardless of
		// whether the request included a hostname.
		name := "test-versions/name/provider"
		if mod.Source != name {
			t.Fatalf("%q: expected module name %q, got %q", src, name, mod.Source)
		}

		if len(mod.Versions) != 4 {
			t.Fatalf("%q: expected 4 versions, got %d", src, len(mod.Versions))
		}

		for _, v := range mod.Versions {
			if _, err := version.NewVersion(v.Version); err != nil {
				t.Fatalf("%q: invalid version %q: %s", src, v.Version, err)
			}
		}
	}
}
// TestInvalidRegistry checks that listing module versions fails when the
// module source names a registry host the mock discovery does not serve.
func TestInvalidRegistry(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "non-existent.localhost.localdomain/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	if _, err := client.ModuleVersions(t.Context(), modsrc); err == nil {
		t.Fatal("expected error")
	}
}
// TestRegistryAuth checks that a private module can be listed and located
// while the client's service discovery has a credentials source, and that
// both operations fail once the credentials source is removed.
func TestRegistryAuth(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "private/name/provider"
	mod, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	if _, err := client.ModuleVersions(t.Context(), mod); err != nil {
		t.Fatal(err)
	}
	if _, err := client.ModuleLocation(t.Context(), mod, "1.0.0"); err != nil {
		t.Fatal(err)
	}

	// Also test without a credentials source
	client.services.SetCredentialsSource(nil)

	// both should fail without auth
	if _, err := client.ModuleVersions(t.Context(), mod); err == nil {
		t.Fatal("expected error from ModuleVersions without credentials")
	}
	if _, err := client.ModuleLocation(t.Context(), mod, "1.0.0"); err == nil {
		t.Fatal("expected error from ModuleLocation without credentials")
	}
}
// TestLookupModuleLocationRelative checks that a relative download location
// returned by the registry is resolved against the registry server's URL,
// yielding an indirect package location of server.URL + "/relative-path".
func TestLookupModuleLocationRelative(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "relative/foo/bar"
	mod, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	got, err := client.ModuleLocation(t.Context(), mod, "0.2.0")
	if err != nil {
		t.Fatal(err)
	}

	want := PackageLocationIndirect{
		SourceAddr: addrs.ModuleSourceRemote{
			Package: addrs.ModulePackage(server.URL + "/relative-path"),
		},
	}
	if diff := cmp.Diff(want, got); diff != "" {
		t.Error("wrong location\n" + diff)
	}
}
// TestAccLookupModuleVersions is an acceptance test that talks to the real
// public registry; it only runs when the TF_ACC environment variable is set.
// It lists the versions of a well-known module both with and without the
// public registry hostname and checks that at least one valid version exists.
func TestAccLookupModuleVersions(t *testing.T) {
	if os.Getenv("TF_ACC") == "" {
		t.Skip("set TF_ACC to run acceptance tests against the public registry")
	}

	regDisco := disco.New(
		disco.WithHTTPClient(httpclient.New(t.Context())),
	)
	// The client does not depend on the module source, so a single client is
	// shared by every iteration instead of being rebuilt each time.
	s := NewClient(t.Context(), regDisco, nil)

	// test with and without a hostname
	for _, src := range []string{
		"terraform-aws-modules/vpc/aws",
		regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws",
	} {
		modsrc, err := regsrc.ParseModuleSource(src)
		if err != nil {
			t.Fatalf("%q: %s", src, err)
		}

		resp, err := s.ModuleVersions(t.Context(), modsrc)
		if err != nil {
			t.Fatalf("%q: %s", src, err)
		}

		if len(resp.Modules) != 1 {
			t.Fatalf("%q: expected 1 module, got %d", src, len(resp.Modules))
		}

		mod := resp.Modules[0]
		name := "terraform-aws-modules/vpc/aws"
		if mod.Source != name {
			t.Fatalf("%q: expected module name %q, got %q", src, name, mod.Source)
		}

		if len(mod.Versions) == 0 {
			t.Fatalf("%q: expected multiple versions, got 0", src)
		}

		for _, v := range mod.Versions {
			if _, err := version.NewVersion(v.Version); err != nil {
				t.Fatalf("%q: invalid version %q: %s", src, v.Version, err)
			}
		}
	}
}
// TestLookupLookupModuleError checks that looking up a module the registry
// does not know fails with an error that quotes the module source exactly as
// written in the configuration (no hostname prepended), and that the
// not-found response is not retried.
func TestLookupLookupModuleError(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	// this should not be found in the registry
	src := "bad/local/path"
	mod, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	// Instrument CheckRetry to make sure 404s are not retried: a second call
	// to the retry policy means the request was attempted again. The count is
	// asserted after the request completes instead of calling t.Fatal from
	// inside the HTTP client's callback, which would unwind the library
	// mid-request and is only valid on the test goroutine.
	attempts := 0
	oldCheck := client.client.CheckRetry
	client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
		attempts++
		return oldCheck(ctx, resp, err)
	}

	_, err = client.ModuleLocation(t.Context(), mod, "0.2.0")
	if err == nil {
		t.Fatal("expected error")
	}
	if attempts > 1 {
		t.Fatalf("retried after module not found: CheckRetry called %d times", attempts)
	}

	// check for the exact quoted string to ensure we didn't prepend a hostname.
	if !strings.Contains(err.Error(), `"bad/local/path"`) {
		t.Fatal("error should not include the hostname. got:", err)
	}
}
// TestLookupModuleRetryError checks that, against a registry that only
// returns retryable errors, listing module versions gives up with an error
// reporting the number of attempts and returns no response.
func TestLookupModuleRetryError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "example.com/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is necessarily nil here, so there is nothing useful to print.
		t.Fatal("expected requests to exceed retry")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}

	// verify maxRetryErrorHandler handler returned the error
	if !strings.Contains(err.Error(), "request failed after 2 attempts") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestLookupModuleNoRetryError checks that, with retries disabled on the
// registry HTTP client, a retryable server error fails the request on the
// first attempt with a "request failed:" error and no response.
func TestLookupModuleNoRetryError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	defer server.Close()

	client := NewClient(
		t.Context(), test.Disco(server),
		// Retries are disabled by the second argument to this function
		httpclient.NewForRegistryRequests(t.Context(), 0, 10*time.Second),
	)

	src := "example.com/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is necessarily nil here, so there is nothing useful to print.
		t.Fatal("expected request to fail")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}

	// verify the error reports the failed request
	if !strings.Contains(err.Error(), "request failed:") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestLookupModuleNetworkError checks that a connection failure (the registry
// server is shut down after the client is built) is retried and then reported
// as an error mentioning the number of attempts, with no response returned.
func TestLookupModuleNetworkError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	client := NewClient(t.Context(), test.Disco(server), nil)

	// Shut down the server to simulate network failure
	server.Close()

	src := "example.com/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is necessarily nil here, so there is nothing useful to print.
		t.Fatal("expected request to fail")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}

	// verify maxRetryErrorHandler handler returned the correct error
	if !strings.Contains(err.Error(), "request failed after 2 attempts") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestModuleLocation_readRegistryResponse exercises how Client.ModuleLocation
// interprets the registry's download-location response. It covers:
//   - direct locations in the JSON body, with and without
//     use_registry_credentials;
//   - indirect locations in the body or in the X-Terraform-Get header;
//   - several failure modes: 404, a truncated body, malformed JSON, an
//     unexpected status, and a body without a location.
//
// Each case serves its response from a dedicated httptest server. A
// testTransport installed on the client's HTTP client redirects every request
// sent through that client to the case's server and records the raw
// response. This lets the status code and headers be checked too.
func TestModuleLocation_readRegistryResponse(t *testing.T) {
	t.Parallel()

	makeIndirectLocation := func(packageAddr string, subDir string) PackageLocationIndirect {
		return PackageLocationIndirect{
			SourceAddr: addrs.ModuleSourceRemote{
				Package: addrs.ModulePackage(packageAddr),
				Subdir:  subDir,
			},
		}
	}
	mustParseURL := func(s string) *url.URL {
		ret, err := url.Parse(s)
		if err != nil {
			t.Fatal(err)
		}
		return ret
	}

	cases := map[string]struct {
		src                  string
		handlerFunc          func(w http.ResponseWriter, r *http.Request)
		registryFlags        []uint8
		want                 PackageLocation
		wantErrorStr         string
		wantToReadFromHeader bool
		wantStatusCode       int
	}{
		"shall find direct module location in the registry response body, opting to use the registry's credentials": {
			src: "exists-in-registry/identifier/provider",
			want: PackageLocationDirect{
				module: &regsrc.Module{
					RawHost:      &regsrc.FriendlyHost{Raw: "registry.opentofu.org"},
					RawNamespace: "exists-in-registry",
					RawName:      "identifier",
					RawProvider:  "provider",
				},
				packageURL:             mustParseURL("https://example.com/package.zip"),
				useRegistryCredentials: true,
			},
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte(`{"location":"https://example.com/package.zip","use_registry_credentials":true}`))
			},
		},
		"shall find direct module location in the registry response body, not opting to use the registry's credentials": {
			src: "exists-in-registry/identifier/provider",
			want: PackageLocationDirect{
				module: &regsrc.Module{
					RawHost:      &regsrc.FriendlyHost{Raw: "registry.opentofu.org"},
					RawNamespace: "exists-in-registry",
					RawName:      "identifier",
					RawProvider:  "provider",
				},
				packageURL:             mustParseURL("https://example.com/package.zip"),
				useRegistryCredentials: false,
			},
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte(`{"location":"https://example.com/package.zip","use_registry_credentials":false}`))
			},
		},
		"shall find indirect module location in the registry response body": {
			src:            "exists-in-registry/identifier/provider",
			want:           makeIndirectLocation("file:///registry/exists", ""),
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
			},
		},
		"shall find indirect module location in the registry response header": {
			src:                  "exists-in-registry/identifier/provider",
			registryFlags:        []uint8{test.WithModuleLocationInHeader},
			want:                 makeIndirectLocation("file:///registry/exists", ""),
			wantToReadFromHeader: true,
			wantStatusCode:       http.StatusNoContent,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Terraform-Get", "file:///registry/exists")
				w.WriteHeader(http.StatusNoContent)
			},
		},
		"shall read indirect location from the registry response body even if the header with location address is also set": {
			src:                  "exists-in-registry/identifier/provider",
			want:                 makeIndirectLocation("file:///registry/exists", ""),
			wantStatusCode:       http.StatusOK,
			wantToReadFromHeader: false,
			registryFlags:        []uint8{test.WithModuleLocationInBody, test.WithModuleLocationInHeader},
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Terraform-Get", "file:///registry/exists-header")
				w.WriteHeader(http.StatusOK)
				_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
			},
		},
		"shall fail to find the module": {
			src: "not-exist/identifier/provider",
			// note that the version is fixed in the mock
			// see: /internal/registry/test/mock_registry.go:testMods
			wantErrorStr:   `module "not-exist/identifier/provider" version "0.2.0" not found`,
			wantStatusCode: http.StatusNotFound,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusNotFound)
			},
		},
		"shall fail because of reading response body error": {
			src:            "foo/bar/baz",
			wantErrorStr:   "error reading response body from registry",
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Length", "1000") // Set incorrect content length
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("{")) // Only write a partial response
				// The connection will close after handler returns, but client will expect more data
			},
		},
		"shall fail to deserialize JSON response": {
			src:            "foo/bar/baz",
			wantErrorStr:   `module "foo/bar/baz" version "0.2.0" failed to deserialize response body {: unexpected end of JSON input`,
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("{"))
			},
		},
		"shall fail because of unexpected protocol change - 422 http status": {
			src:            "foo/bar/baz",
			wantErrorStr:   `error getting download location for "foo/bar/baz": 422 Unprocessable Entity resp:bar`,
			wantStatusCode: http.StatusUnprocessableEntity,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusUnprocessableEntity)
				_, _ = w.Write([]byte("bar"))
			},
		},
		"shall fail because location is not found in the response": {
			src:            "foo/bar/baz",
			wantErrorStr:   `registry did not return a location for this package`,
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				// note that the response emulates a contract change
				_, _ = w.Write([]byte(`{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`))
			},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			mockServer := httptest.NewServer(http.HandlerFunc(tc.handlerFunc))
			defer mockServer.Close()

			// registryFlags configure the mock registry passed to test.Disco.
			registryServer := test.Registry(tc.registryFlags...)
			defer registryServer.Close()

			transport := &testTransport{
				mockURL: mockServer.URL,
			}
			httpClient := retryablehttp.NewClient()
			httpClient.HTTPClient.Transport = transport
			client := NewClient(t.Context(), test.Disco(registryServer), httpClient)

			mod, err := regsrc.ParseModuleSource(tc.src)
			if err != nil {
				t.Fatal(err)
			}

			got, err := client.ModuleLocation(t.Context(), mod, "0.2.0")

			// Validate the error: a case either expects no error at all, or
			// an error whose message contains wantErrorStr. A missing error
			// where one is expected is a failure in its own right rather
			// than something left for the location diff to catch by chance.
			switch {
			case err != nil && tc.wantErrorStr == "":
				t.Fatalf("unexpected error: %v", err)
			case err == nil && tc.wantErrorStr != "":
				t.Fatalf("expected error containing %q, got none", tc.wantErrorStr)
			case err != nil && !strings.Contains(err.Error(), tc.wantErrorStr):
				t.Fatalf("unexpected error content: want=%s, got=%v", tc.wantErrorStr, err)
			}

			if diff := cmp.Diff(tc.want, got, cmp.AllowUnexported(PackageLocationDirect{})); diff != "" {
				t.Fatal("unexpected location\n" + diff)
			}

			// Verify status code if the mock server produced a response.
			if transport.lastResponse != nil {
				gotStatusCode := transport.lastResponse.StatusCode
				if tc.wantStatusCode != gotStatusCode {
					t.Fatalf("unexpected response status code: want=%d, got=%d", tc.wantStatusCode, gotStatusCode)
				}

				// Check if we expected to read from header
				if tc.wantToReadFromHeader && err == nil {
					headerVal := transport.lastResponse.Header.Get("X-Terraform-Get")
					if headerVal == "" {
						t.Fatalf("expected to read location from header but X-Terraform-Get header was not set")
					}
				}
			}
		})
	}
}
// testTransport is a custom http.RoundTripper that redirects every request to
// the mock server at mockURL. It keeps the original method, path, query,
// headers, body and context, and captures the response for inspection.
//
// It is not safe for concurrent use: lastResponse is written without
// synchronization.
type testTransport struct {
	// mockURL is the base URL (scheme and host) of the mock server, e.g. the
	// URL field of an httptest.Server.
	mockURL string
	// Store the last response received from the mock server; nil until a
	// round trip succeeds.
	lastResponse *http.Response
}

// RoundTrip implements http.RoundTripper. It sends a copy of req to the mock
// server and records the response in lastResponse on success. The caller's
// request is not modified.
func (tr *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	target, err := url.Parse(tr.mockURL)
	if err != nil {
		return nil, err
	}

	// Clone deep-copies the URL and headers and keeps the request context, so
	// cancellation/deadlines, the query string and ContentLength all survive.
	// Only the destination scheme and host are replaced.
	mockReq := req.Clone(req.Context())
	mockReq.URL.Scheme = target.Scheme
	mockReq.URL.Host = target.Host

	// Send the request to the mock server
	resp, err := http.DefaultTransport.RoundTrip(mockReq)
	if err != nil {
		return nil, err
	}
	tr.lastResponse = resp
	return resp, nil
}