Files
opentf/internal/registry/client_test.go
Martin Atkins d2bef1fd47 Adopt OpenTofu's own "svchost" module
Previously we were using a third-party library, but that doesn't have any
support for passing context.Context through its API and so isn't suitable
for our goals of adding OpenTelemetry tracing for all outgoing network
requests.

We now have our own fork that is updated to use context.Context. It also
has a slightly reduced scope no longer including various details that
are tightly-coupled to our cliconfig mechanism and so better placed in the
main OpenTofu codebase so we can evolve it in future without making
lockstep library releases.

The "registry-address" library also uses svchost and uses some of its types
in its public API, so this also incorporates v2 of that library that is
updated to use our own svchost module.

Unfortunately this commit is a mix of mechanical updates to the new
libraries and some new code dealing with the functionality that is removed
in our fork of svchost. The new code is primarily in the "svcauthconfig"
package, which is similar in purpose to "ociauthconfig" but for OpenTofu's
own auth mechanism instead of the OCI Distribution protocol's auth
mechanism.

This includes some additional plumbing of context.Context where it was
possible to do so without broad changes to files that would not otherwise
have been included in this commit, but there are a few leftover spots that
are context.TODO() which we'll address separately in later commits.

This removes the temporary workaround from d079da6e9e, since we are now
able to plumb the OpenTelemetry span tree all the way to the service
discovery requests.

Signed-off-by: Martin Atkins <mart@degeneration.co.uk>
2025-06-12 09:37:59 -07:00

489 lines
14 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package registry
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/hashicorp/go-retryablehttp"
version "github.com/hashicorp/go-version"
"github.com/opentofu/svchost/disco"
"github.com/opentofu/opentofu/internal/httpclient"
"github.com/opentofu/opentofu/internal/registry/regsrc"
"github.com/opentofu/opentofu/internal/registry/response"
"github.com/opentofu/opentofu/internal/registry/test"
)
// TestLookupModuleVersions checks that ModuleVersions returns the single
// "test-versions/name/provider" module with its four valid versions from the
// mock registry, both with and without an explicit hostname in the source.
func TestLookupModuleVersions(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	// test with and without a hostname
	for _, src := range []string{
		"example.com/test-versions/name/provider",
		"test-versions/name/provider",
	} {
		modsrc, err := regsrc.ParseModuleSource(src)
		if err != nil {
			t.Fatalf("parsing %q: %s", src, err)
		}

		// Use the test's context so the request belongs to the test's lifetime
		// (and trace), matching the context the client was constructed with.
		resp, err := client.ModuleVersions(t.Context(), modsrc)
		if err != nil {
			t.Fatalf("%s: %s", src, err)
		}

		if len(resp.Modules) != 1 {
			t.Fatalf("%s: expected 1 module, got %d", src, len(resp.Modules))
		}

		mod := resp.Modules[0]
		name := "test-versions/name/provider"
		if mod.Source != name {
			t.Fatalf("expected module name %q, got %q", name, mod.Source)
		}

		if len(mod.Versions) != 4 {
			t.Fatalf("%s: expected 4 versions, got %d", src, len(mod.Versions))
		}

		for _, v := range mod.Versions {
			if _, err := version.NewVersion(v.Version); err != nil {
				t.Fatalf("invalid version %q: %s", v.Version, err)
			}
		}
	}
}
// TestInvalidRegistry checks that ModuleVersions fails for a module whose
// hostname is not served by the mock registry's discovery.
func TestInvalidRegistry(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "non-existent.localhost.localdomain/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	// The request uses the test's context, consistent with the client above.
	if _, err := client.ModuleVersions(t.Context(), modsrc); err == nil {
		t.Fatal("expected error")
	}
}
// TestRegistryAuth checks that requests for the "private" module succeed with
// the mock disco's credentials source and fail once that source is removed.
func TestRegistryAuth(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "private/name/provider"
	mod, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	ctx := t.Context()

	if _, err := client.ModuleVersions(ctx, mod); err != nil {
		t.Fatal(err)
	}
	if _, err := client.ModuleLocation(ctx, mod, "1.0.0"); err != nil {
		t.Fatal(err)
	}

	// Also test without a credentials source: both requests should now fail
	// because no auth is sent.
	client.services.SetCredentialsSource(nil)

	if _, err := client.ModuleVersions(ctx, mod); err == nil {
		t.Fatal("expected error")
	}
	if _, err := client.ModuleLocation(ctx, mod, "1.0.0"); err == nil {
		t.Fatal("expected error")
	}
}
// TestLookupModuleLocationRelative checks that a relative download location
// returned by the registry is resolved against the registry server's URL.
func TestLookupModuleLocationRelative(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "relative/foo/bar"
	mod, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	got, err := client.ModuleLocation(t.Context(), mod, "0.2.0")
	if err != nil {
		t.Fatal(err)
	}

	want := server.URL + "/relative-path"
	if got != want {
		t.Errorf("wrong location %s; want %s", got, want)
	}
}
// TestAccLookupModuleVersions is an acceptance test against the real public
// registry. It runs only when TF_ACC is set, and checks that the
// terraform-aws-modules/vpc/aws module is found with at least one valid
// version, with and without an explicit hostname.
func TestAccLookupModuleVersions(t *testing.T) {
	if os.Getenv("TF_ACC") == "" {
		t.Skip("set TF_ACC to run acceptance tests against the public registry")
	}

	regDisco := disco.New(
		disco.WithHTTPClient(httpclient.New(t.Context())),
	)
	// One client serves every lookup; there is no per-source state to reset.
	client := NewClient(t.Context(), regDisco, nil)

	// test with and without a hostname
	for _, src := range []string{
		"terraform-aws-modules/vpc/aws",
		regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws",
	} {
		modsrc, err := regsrc.ParseModuleSource(src)
		if err != nil {
			t.Fatalf("parsing %q: %s", src, err)
		}

		resp, err := client.ModuleVersions(t.Context(), modsrc)
		if err != nil {
			t.Fatalf("%s: %s", src, err)
		}

		if len(resp.Modules) != 1 {
			t.Fatalf("%s: expected 1 module, got %d", src, len(resp.Modules))
		}

		mod := resp.Modules[0]
		name := "terraform-aws-modules/vpc/aws"
		if mod.Source != name {
			t.Fatalf("expected module name %q, got %q", name, mod.Source)
		}

		if len(mod.Versions) == 0 {
			t.Fatal("expected multiple versions, got 0")
		}

		for _, v := range mod.Versions {
			if _, err := version.NewVersion(v.Version); err != nil {
				t.Fatalf("invalid version %q: %s", v.Version, err)
			}
		}
	}
}
// TestLookupLookupModuleError checks that a module-not-found error references
// the config source exactly, not the discovered path, and that the 404 is not
// retried.
func TestLookupLookupModuleError(t *testing.T) {
	server := test.Registry()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	// this should not be found in the registry
	src := "bad/local/path"
	mod, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	// Wrap CheckRetry to make sure 404s are not retried: it is consulted once
	// per attempt, so a second call means a retry was attempted.
	attempts := 0
	prevCheck := client.client.CheckRetry
	client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
		if attempts > 0 {
			t.Fatal("retried after module not found")
		}
		attempts++
		return prevCheck(ctx, resp, err)
	}

	_, err = client.ModuleLocation(t.Context(), mod, "0.2.0")
	if err == nil {
		t.Fatal("expected error")
	}

	// check for the exact quoted string to ensure we didn't prepend a hostname.
	if !strings.Contains(err.Error(), `"bad/local/path"`) {
		t.Fatal("error should not include the hostname. got:", err)
	}
}
// TestLookupModuleRetryError checks that, against a server that always
// answers with retryable errors, ModuleVersions gives up with an error that
// reports the number of attempts and returns no response.
func TestLookupModuleRetryError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	defer server.Close()

	client := NewClient(t.Context(), test.Disco(server), nil)

	src := "example.com/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is nil here, so there is nothing useful to print alongside.
		t.Fatal("expected requests to exceed retry")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}

	// verify maxRetryErrorHandler handler returned the error; the attempt count
	// presumably follows from the default client's retry setting (confirm in
	// httpclient if it changes).
	if !strings.Contains(err.Error(), "request failed after 2 attempts") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestLookupModuleNoRetryError checks that, with retries disabled, a single
// retryable error fails ModuleVersions immediately with a plain "request
// failed:" error and no response.
func TestLookupModuleNoRetryError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	defer server.Close()

	client := NewClient(
		t.Context(), test.Disco(server),
		// Retries are disabled by the second argument to this function
		httpclient.NewForRegistryRequests(t.Context(), 0, 10*time.Second),
	)

	src := "example.com/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is nil here, so there is nothing useful to print alongside.
		t.Fatal("expected request to fail")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}

	// The error must describe a plain failed request, not an exhausted
	// retry budget.
	if !strings.Contains(err.Error(), "request failed:") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestLookupModuleNetworkError checks that ModuleVersions fails with a
// retry-exhausted error, and no response, when the registry server is
// unreachable.
func TestLookupModuleNetworkError(t *testing.T) {
	server := test.RegistryRetryableErrorsServer()
	client := NewClient(t.Context(), test.Disco(server), nil)

	// Shut down the server before any request to simulate network failure.
	server.Close()

	src := "example.com/test-versions/name/provider"
	modsrc, err := regsrc.ParseModuleSource(src)
	if err != nil {
		t.Fatal(err)
	}

	resp, err := client.ModuleVersions(t.Context(), modsrc)
	if err == nil {
		// err is nil here, so there is nothing useful to print alongside.
		t.Fatal("expected request to fail")
	}
	if resp != nil {
		t.Fatal("unexpected response", *resp)
	}

	// verify maxRetryErrorHandler handler returned the correct error
	if !strings.Contains(err.Error(), "request failed after 2 attempts") {
		t.Fatal("unexpected error, got:", err)
	}
}
// TestModuleLocation_readRegistryResponse drives ModuleLocation against a
// mock server whose response is chosen per case. Each case checks the
// returned location, the expected error (or its absence), and the HTTP
// status code the client observed.
func TestModuleLocation_readRegistryResponse(t *testing.T) {
	t.Parallel()

	cases := map[string]struct {
		src                  string
		handlerFunc          func(w http.ResponseWriter, r *http.Request)
		registryFlags        []uint8
		want                 string
		wantErrorStr         string
		wantToReadFromHeader bool
		wantStatusCode       int
	}{
		"shall find the module location in the registry response body": {
			src:            "exists-in-registry/identifier/provider",
			want:           "file:///registry/exists",
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
			},
		},
		"shall find the module location in the registry response header": {
			src:                  "exists-in-registry/identifier/provider",
			registryFlags:        []uint8{test.WithModuleLocationInHeader},
			want:                 "file:///registry/exists",
			wantToReadFromHeader: true,
			wantStatusCode:       http.StatusNoContent,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Terraform-Get", "file:///registry/exists")
				w.WriteHeader(http.StatusNoContent)
			},
		},
		"shall read location from the registry response body even if the header with location address is also set": {
			src:                  "exists-in-registry/identifier/provider",
			want:                 "file:///registry/exists",
			wantStatusCode:       http.StatusOK,
			wantToReadFromHeader: false,
			registryFlags:        []uint8{test.WithModuleLocationInBody, test.WithModuleLocationInHeader},
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("X-Terraform-Get", "file:///registry/exists-header")
				w.WriteHeader(http.StatusOK)
				_ = json.NewEncoder(w).Encode(response.ModuleLocationRegistryResp{Location: "file:///registry/exists"})
			},
		},
		"shall fail to find the module": {
			src: "not-exist/identifier/provider",
			// note that the version is fixed in the mock
			// see: /internal/registry/test/mock_registry.go:testMods
			wantErrorStr:   `module "not-exist/identifier/provider" version "0.2.0" not found`,
			wantStatusCode: http.StatusNotFound,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusNotFound)
			},
		},
		"shall fail because of reading response body error": {
			src:            "foo/bar/baz",
			wantErrorStr:   "error reading response body from registry",
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Length", "1000") // Set incorrect content length
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("{")) // Only write a partial response
				// The connection will close after handler returns, but client will expect more data
			},
		},
		"shall fail to deserialize JSON response": {
			src:            "foo/bar/baz",
			wantErrorStr:   `module "foo/bar/baz" version "0.2.0" failed to deserialize response body {: unexpected end of JSON input`,
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte("{"))
			},
		},
		"shall fail because of unexpected protocol change - 422 http status": {
			src:            "foo/bar/baz",
			wantErrorStr:   `error getting download location for "foo/bar/baz": 422 Unprocessable Entity resp:bar`,
			wantStatusCode: http.StatusUnprocessableEntity,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusUnprocessableEntity)
				_, _ = w.Write([]byte("bar"))
			},
		},
		"shall fail because location is not found in the response": {
			src:            "foo/bar/baz",
			wantErrorStr:   `failed to get download URL for "foo/bar/baz": 200 OK resp:{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`,
			wantStatusCode: http.StatusOK,
			handlerFunc: func(w http.ResponseWriter, r *http.Request) {
				w.WriteHeader(http.StatusOK)
				// note that the response emulates a contract change
				_, _ = w.Write([]byte(`{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`))
			},
		},
	}

	for name, tc := range cases {
		t.Run(name, func(t *testing.T) {
			mockServer := httptest.NewServer(http.HandlerFunc(tc.handlerFunc))
			defer mockServer.Close()

			registryServer := test.Registry(tc.registryFlags...)
			defer registryServer.Close()

			// Discovery goes to the mock registry; the module download request
			// itself is redirected to mockServer by the transport.
			transport := &testTransport{
				mockURL: mockServer.URL,
			}
			httpClient := retryablehttp.NewClient()
			httpClient.HTTPClient.Transport = transport

			client := NewClient(t.Context(), test.Disco(registryServer), httpClient)

			mod, err := regsrc.ParseModuleSource(tc.src)
			if err != nil {
				t.Fatal(err)
			}

			got, err := client.ModuleLocation(t.Context(), mod, "0.2.0")

			// Validate the results. The "expected an error but got none" case
			// must be checked explicitly: otherwise an empty location would
			// still match the empty tc.want and the case would pass silently.
			switch {
			case err != nil && tc.wantErrorStr == "":
				t.Fatalf("unexpected error: %v", err)
			case err == nil && tc.wantErrorStr != "":
				t.Fatalf("expected error containing %q, got none", tc.wantErrorStr)
			case err != nil && !strings.Contains(err.Error(), tc.wantErrorStr):
				t.Fatalf("unexpected error content: want=%s, got=%v", tc.wantErrorStr, err)
			}

			if got != tc.want {
				t.Fatalf("unexpected location: want=%s, got=%v", tc.want, got)
			}

			// Verify the status code whenever the transport captured a response
			// (a round trip can succeed even if reading the body later fails).
			if transport.lastResponse != nil {
				gotStatusCode := transport.lastResponse.StatusCode
				if tc.wantStatusCode != gotStatusCode {
					t.Fatalf("unexpected response status code: want=%d, got=%d", tc.wantStatusCode, gotStatusCode)
				}

				// Check if we expected to read from header
				if tc.wantToReadFromHeader && err == nil {
					headerVal := transport.lastResponse.Header.Get("X-Terraform-Get")
					if headerVal == "" {
						t.Fatalf("expected to read location from header but X-Terraform-Get header was not set")
					}
				}
			}
		})
	}
}
// testTransport is an http.RoundTripper that redirects every request to a
// mock server. It keeps the original method, path, query string, headers,
// body and context, and records the most recent response so tests can inspect
// it after the client under test has returned.
type testTransport struct {
	// mockURL is the base URL of the mock server (e.g. httptest.Server.URL).
	// Only its scheme and host are used; the request's path and query are kept.
	mockURL string
	// lastResponse is the last response received from the mock server, or nil
	// if no round trip has succeeded yet.
	lastResponse *http.Response
}

// RoundTrip implements http.RoundTripper by re-targeting req at t.mockURL and
// delegating to http.DefaultTransport.
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	target, err := url.Parse(t.mockURL)
	if err != nil {
		// The RoundTripper contract requires closing the body even on error.
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, err
	}

	// Clone (rather than a hand-built http.Request) carries over the request
	// context (cancellation, tracing), the query string, ContentLength and the
	// body; only the destination scheme and host are replaced.
	mockReq := req.Clone(req.Context())
	mockReq.URL.Scheme = target.Scheme
	mockReq.URL.Host = target.Host

	// Send the request to the mock server
	resp, err := http.DefaultTransport.RoundTrip(mockReq)
	if err == nil {
		t.lastResponse = resp
	}
	return resp, err
}