mirror of
https://github.com/opentffoundation/opentf.git
synced 2025-12-20 02:09:26 -05:00
The primary reason for this change is that registry.NewClient was originally imposing its own decision about service discovery request policy on every other user of the shared disco.Disco object by modifying it directly. We have been moving towards using a dependency inversion style where package main is responsible for deciding how everything should be configured based on global CLI arguments, environment variables, and the CLI configuration, and so this commit moves to using that model for the HTTP clients used by the module and provider registry client code. This also makes explicit what was previously hidden away: that all service discovery requests are made using the same HTTP client policy as for requests to module registries, even if the service being discovered is not a registry. This doesn't seem to have been the intention of the code as previously written, but was still its ultimate effect: there is only one disco.Disco object shared across all discovery callers and so changing its configuration in any way changes it for everyone. This initial rework is certainly not perfect: these components were not originally designed to work in this way and there are lots of existing test cases relying on them working the old way, and so this is a compromise to get the behavior we now need (using consistent HTTP client settings across all callers) without disrupting too much existing code. Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
489 lines
14 KiB
Go
489 lines
14 KiB
Go
// Copyright (c) The OpenTofu Authors
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
// Copyright (c) 2023 HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package registry
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/go-retryablehttp"
|
|
version "github.com/hashicorp/go-version"
|
|
"github.com/hashicorp/terraform-svchost/disco"
|
|
|
|
"github.com/opentofu/opentofu/internal/httpclient"
|
|
"github.com/opentofu/opentofu/internal/registry/regsrc"
|
|
"github.com/opentofu/opentofu/internal/registry/response"
|
|
"github.com/opentofu/opentofu/internal/registry/test"
|
|
tfversion "github.com/opentofu/opentofu/version"
|
|
)
|
|
|
|
func TestLookupModuleVersions(t *testing.T) {
|
|
server := test.Registry()
|
|
defer server.Close()
|
|
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
// test with and without a hostname
|
|
for _, src := range []string{
|
|
"example.com/test-versions/name/provider",
|
|
"test-versions/name/provider",
|
|
} {
|
|
modsrc, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
resp, err := client.ModuleVersions(context.Background(), modsrc)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(resp.Modules) != 1 {
|
|
t.Fatal("expected 1 module, got", len(resp.Modules))
|
|
}
|
|
|
|
mod := resp.Modules[0]
|
|
name := "test-versions/name/provider"
|
|
if mod.Source != name {
|
|
t.Fatalf("expected module name %q, got %q", name, mod.Source)
|
|
}
|
|
|
|
if len(mod.Versions) != 4 {
|
|
t.Fatal("expected 4 versions, got", len(mod.Versions))
|
|
}
|
|
|
|
for _, v := range mod.Versions {
|
|
_, err := version.NewVersion(v.Version)
|
|
if err != nil {
|
|
t.Fatalf("invalid version %q: %s", v.Version, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestInvalidRegistry(t *testing.T) {
|
|
server := test.Registry()
|
|
defer server.Close()
|
|
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
src := "non-existent.localhost.localdomain/test-versions/name/provider"
|
|
modsrc, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if _, err := client.ModuleVersions(context.Background(), modsrc); err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
}
|
|
|
|
func TestRegistryAuth(t *testing.T) {
|
|
server := test.Registry()
|
|
defer server.Close()
|
|
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
src := "private/name/provider"
|
|
mod, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
_, err = client.ModuleVersions(context.Background(), mod)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
_, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Also test without a credentials source
|
|
client.services.SetCredentialsSource(nil)
|
|
|
|
// both should fail without auth
|
|
_, err = client.ModuleVersions(context.Background(), mod)
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
_, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
}
|
|
|
|
func TestLookupModuleLocationRelative(t *testing.T) {
|
|
server := test.Registry()
|
|
defer server.Close()
|
|
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
src := "relative/foo/bar"
|
|
mod, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
got, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
want := server.URL + "/relative-path"
|
|
if got != want {
|
|
t.Errorf("wrong location %s; want %s", got, want)
|
|
}
|
|
}
|
|
|
|
func TestAccLookupModuleVersions(t *testing.T) {
|
|
if os.Getenv("TF_ACC") == "" {
|
|
t.Skip()
|
|
}
|
|
regDisco := disco.New()
|
|
regDisco.SetUserAgent(httpclient.OpenTofuUserAgent(tfversion.String()))
|
|
|
|
// test with and without a hostname
|
|
for _, src := range []string{
|
|
"terraform-aws-modules/vpc/aws",
|
|
regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws",
|
|
} {
|
|
modsrc, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
s := NewClient(t.Context(), regDisco, nil)
|
|
resp, err := s.ModuleVersions(context.Background(), modsrc)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if len(resp.Modules) != 1 {
|
|
t.Fatal("expected 1 module, got", len(resp.Modules))
|
|
}
|
|
|
|
mod := resp.Modules[0]
|
|
name := "terraform-aws-modules/vpc/aws"
|
|
if mod.Source != name {
|
|
t.Fatalf("expected module name %q, got %q", name, mod.Source)
|
|
}
|
|
|
|
if len(mod.Versions) == 0 {
|
|
t.Fatal("expected multiple versions, got 0")
|
|
}
|
|
|
|
for _, v := range mod.Versions {
|
|
_, err := version.NewVersion(v.Version)
|
|
if err != nil {
|
|
t.Fatalf("invalid version %q: %s", v.Version, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// the error should reference the config source exactly, not the discovered path.
|
|
func TestLookupLookupModuleError(t *testing.T) {
|
|
server := test.Registry()
|
|
defer server.Close()
|
|
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
// this should not be found in the registry
|
|
src := "bad/local/path"
|
|
mod, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Instrument CheckRetry to make sure 404s are not retried
|
|
retries := 0
|
|
oldCheck := client.client.CheckRetry
|
|
client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
|
|
if retries > 0 {
|
|
t.Fatal("retried after module not found")
|
|
}
|
|
retries++
|
|
return oldCheck(ctx, resp, err)
|
|
}
|
|
|
|
_, err = client.ModuleLocation(context.Background(), mod, "0.2.0")
|
|
if err == nil {
|
|
t.Fatal("expected error")
|
|
}
|
|
|
|
// check for the exact quoted string to ensure we didn't prepend a hostname.
|
|
if !strings.Contains(err.Error(), `"bad/local/path"`) {
|
|
t.Fatal("error should not include the hostname. got:", err)
|
|
}
|
|
}
|
|
|
|
func TestLookupModuleRetryError(t *testing.T) {
|
|
server := test.RegistryRetryableErrorsServer()
|
|
defer server.Close()
|
|
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
src := "example.com/test-versions/name/provider"
|
|
modsrc, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp, err := client.ModuleVersions(context.Background(), modsrc)
|
|
if err == nil {
|
|
t.Fatal("expected requests to exceed retry", err)
|
|
}
|
|
if resp != nil {
|
|
t.Fatal("unexpected response", *resp)
|
|
}
|
|
|
|
// verify maxRetryErrorHandler handler returned the error
|
|
if !strings.Contains(err.Error(), "request failed after 2 attempts") {
|
|
t.Fatal("unexpected error, got:", err)
|
|
}
|
|
}
|
|
|
|
func TestLookupModuleNoRetryError(t *testing.T) {
|
|
server := test.RegistryRetryableErrorsServer()
|
|
defer server.Close()
|
|
|
|
client := NewClient(
|
|
t.Context(), test.Disco(server),
|
|
// Retries are disabled by the second argument to this function
|
|
httpclient.NewForRegistryRequests(t.Context(), 0, 10*time.Second),
|
|
)
|
|
|
|
src := "example.com/test-versions/name/provider"
|
|
modsrc, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp, err := client.ModuleVersions(context.Background(), modsrc)
|
|
if err == nil {
|
|
t.Fatal("expected request to fail", err)
|
|
}
|
|
if resp != nil {
|
|
t.Fatal("unexpected response", *resp)
|
|
}
|
|
|
|
// verify maxRetryErrorHandler handler returned the error
|
|
if !strings.Contains(err.Error(), "request failed:") {
|
|
t.Fatal("unexpected error, got:", err)
|
|
}
|
|
}
|
|
|
|
func TestLookupModuleNetworkError(t *testing.T) {
|
|
server := test.RegistryRetryableErrorsServer()
|
|
client := NewClient(t.Context(), test.Disco(server), nil)
|
|
|
|
// Shut down the server to simulate network failure
|
|
server.Close()
|
|
|
|
src := "example.com/test-versions/name/provider"
|
|
modsrc, err := regsrc.ParseModuleSource(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
resp, err := client.ModuleVersions(context.Background(), modsrc)
|
|
if err == nil {
|
|
t.Fatal("expected request to fail", err)
|
|
}
|
|
if resp != nil {
|
|
t.Fatal("unexpected response", *resp)
|
|
}
|
|
|
|
// verify maxRetryErrorHandler handler returned the correct error
|
|
if !strings.Contains(err.Error(), "request failed after 2 attempts") {
|
|
t.Fatal("unexpected error, got:", err)
|
|
}
|
|
}
|
|
|
|
func TestModuleLocation_readRegistryResponse(t *testing.T) {
|
|
cases := map[string]struct {
|
|
src string
|
|
handlerFunc func(w http.ResponseWriter, r *http.Request)
|
|
registryFlags []uint8
|
|
want string
|
|
wantErrorStr string
|
|
wantToReadFromHeader bool
|
|
wantStatusCode int
|
|
}{
|
|
"shall find the module location in the registry response body": {
|
|
src: "exists-in-registry/identifier/provider",
|
|
want: "file:///registry/exists",
|
|
wantStatusCode: http.StatusOK,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
|
|
},
|
|
},
|
|
"shall find the module location in the registry response header": {
|
|
src: "exists-in-registry/identifier/provider",
|
|
registryFlags: []uint8{test.WithModuleLocationInHeader},
|
|
want: "file:///registry/exists",
|
|
wantToReadFromHeader: true,
|
|
wantStatusCode: http.StatusNoContent,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Terraform-Get", "file:///registry/exists")
|
|
w.WriteHeader(http.StatusNoContent)
|
|
},
|
|
},
|
|
"shall read location from the registry response body even if the header with location address is also set": {
|
|
src: "exists-in-registry/identifier/provider",
|
|
want: "file:///registry/exists",
|
|
wantStatusCode: http.StatusOK,
|
|
wantToReadFromHeader: false,
|
|
registryFlags: []uint8{test.WithModuleLocationInBody, test.WithModuleLocationInHeader},
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("X-Terraform-Get", "file:///registry/exists-header")
|
|
w.WriteHeader(http.StatusOK)
|
|
_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
|
|
},
|
|
},
|
|
"shall fail to find the module": {
|
|
src: "not-exist/identifier/provider",
|
|
// note that the version is fixed in the mock
|
|
// see: /internal/registry/test/mock_registry.go:testMods
|
|
wantErrorStr: `module "not-exist/identifier/provider" version "0.2.0" not found`,
|
|
wantStatusCode: http.StatusNotFound,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusNotFound)
|
|
},
|
|
},
|
|
"shall fail because of reading response body error": {
|
|
src: "foo/bar/baz",
|
|
wantErrorStr: "error reading response body from registry",
|
|
wantStatusCode: http.StatusOK,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Length", "1000") // Set incorrect content length
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("{")) // Only write a partial response
|
|
// The connection will close after handler returns, but client will expect more data
|
|
},
|
|
},
|
|
"shall fail to deserialize JSON response": {
|
|
src: "foo/bar/baz",
|
|
wantErrorStr: `module "foo/bar/baz" version "0.2.0" failed to deserialize response body {: unexpected end of JSON input`,
|
|
wantStatusCode: http.StatusOK,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("{"))
|
|
},
|
|
},
|
|
"shall fail because of unexpected protocol change - 422 http status": {
|
|
src: "foo/bar/baz",
|
|
wantErrorStr: `error getting download location for "foo/bar/baz": 422 Unprocessable Entity resp:bar`,
|
|
wantStatusCode: http.StatusUnprocessableEntity,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusUnprocessableEntity)
|
|
_, _ = w.Write([]byte("bar"))
|
|
},
|
|
},
|
|
"shall fail because location is not found in the response": {
|
|
src: "foo/bar/baz",
|
|
wantErrorStr: `failed to get download URL for "foo/bar/baz": 200 OK resp:{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`,
|
|
wantStatusCode: http.StatusOK,
|
|
handlerFunc: func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
// note that the response emulates a contract change
|
|
_, _ = w.Write([]byte(`{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`))
|
|
},
|
|
},
|
|
}
|
|
|
|
t.Parallel()
|
|
for name, tc := range cases {
|
|
t.Run(name, func(t *testing.T) {
|
|
mockServer := httptest.NewServer(http.HandlerFunc(tc.handlerFunc))
|
|
defer mockServer.Close()
|
|
|
|
registryServer := test.Registry(tc.registryFlags...)
|
|
defer registryServer.Close()
|
|
|
|
transport := &testTransport{
|
|
mockURL: mockServer.URL,
|
|
}
|
|
httpClient := retryablehttp.NewClient()
|
|
httpClient.HTTPClient.Transport = transport
|
|
client := NewClient(t.Context(), test.Disco(registryServer), httpClient)
|
|
|
|
mod, err := regsrc.ParseModuleSource(tc.src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
got, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
|
|
|
|
// Validate the results
|
|
if err != nil && tc.wantErrorStr == "" {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if err != nil && !strings.Contains(err.Error(), tc.wantErrorStr) {
|
|
t.Fatalf("unexpected error content: want=%s, got=%v", tc.wantErrorStr, err)
|
|
}
|
|
if got != tc.want {
|
|
t.Fatalf("unexpected location: want=%s, got=%v", tc.want, got)
|
|
}
|
|
|
|
// Verify status code if we have a successful response
|
|
if transport.lastResponse != nil {
|
|
gotStatusCode := transport.lastResponse.StatusCode
|
|
if tc.wantStatusCode != gotStatusCode {
|
|
t.Fatalf("unexpected response status code: want=%d, got=%d", tc.wantStatusCode, gotStatusCode)
|
|
}
|
|
|
|
// Check if we expected to read from header
|
|
if tc.wantToReadFromHeader && err == nil {
|
|
headerVal := transport.lastResponse.Header.Get("X-Terraform-Get")
|
|
if headerVal == "" {
|
|
t.Fatalf("expected to read location from header but X-Terraform-Get header was not set")
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// testTransport is a custom http.RoundTripper that redirects every request to
// the mock server at mockURL. It keeps the request's method, path, query,
// headers, body, Host and context. It also records the last response received
// without a transport error so the test can inspect it.
type testTransport struct {
	// mockURL is the base URL of the mock server, for example the URL field
	// of an httptest.Server. Only its scheme and host are used.
	mockURL string

	// lastResponse is the last response received from the mock server, or nil
	// if no round trip has completed without a transport error.
	lastResponse *http.Response
}

// RoundTrip implements http.RoundTripper.
//
// It clones req, because a RoundTripper must not modify the request it is
// given. It points the clone's URL at the mock server's scheme and host and
// sends it with http.DefaultTransport. Cloning, rather than assembling a new
// http.Request field by field, keeps the context (so cancellation and
// deadlines still apply), the query string and a private copy of the headers.
// It returns an error if mockURL cannot be parsed.
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	target, err := url.Parse(t.mockURL)
	if err != nil {
		return nil, err
	}

	mockReq := req.Clone(req.Context())
	mockReq.URL.Scheme = target.Scheme
	mockReq.URL.Host = target.Host
	// mockReq.Host is still req.Host, so the Host header the mock server sees
	// matches the original destination.

	// Send the request to the mock server
	resp, err := http.DefaultTransport.RoundTrip(mockReq)
	if err == nil {
		t.lastResponse = resp
	}
	return resp, err
}
|