Files
opentf/internal/backend/remote-state/azure/auth/client_cert_auth.go
2025-09-15 19:22:17 +01:00

156 lines
4.0 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package auth
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"errors"
"fmt"
"os"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/opentofu/opentofu/internal/httpclient"
"github.com/opentofu/opentofu/internal/tfdiags"
"golang.org/x/crypto/pkcs12"
)
type ClientCertificateAuthConfig struct {
ClientCertificate string
ClientCertificatePassword string
ClientCertificatePath string
}
type clientCertAuth struct{}
var _ AuthMethod = &clientCertAuth{}
func (cred *clientCertAuth) Name() string {
return "Client Certificate Auth"
}
func (cred *clientCertAuth) Construct(ctx context.Context, config *Config) (azcore.TokenCredential, error) {
client := httpclient.New(ctx)
clientCertificate, err := consolidateCertificate(config.ClientCertificate, config.ClientCertificatePath)
if err != nil {
// This should never happen; this is checked in the Validate function
return nil, err
}
privateKey, certificate, err := pkcs12.Decode(
clientCertificate,
config.ClientCertificatePassword,
)
if err != nil {
return nil, err
}
clientId, err := consolidateClientId(config)
if err != nil {
// This should never happen; this is checked in the Validate function
return nil, err
}
return azidentity.NewClientCertificateCredential(
config.TenantID,
clientId,
[]*x509.Certificate{certificate},
privateKey,
&azidentity.ClientCertificateCredentialOptions{
ClientOptions: clientOptions(client, config.CloudConfig),
},
)
}
func (cred *clientCertAuth) Validate(_ context.Context, config *Config) tfdiags.Diagnostics {
var diags tfdiags.Diagnostics
if config.TenantID == "" {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Invalid Azure Client Certificate Auth",
"Tenant ID is missing.",
))
}
_, err := consolidateClientId(config)
if err != nil {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Invalid Azure Client Certificate Auth",
fmt.Sprintf("The Client ID is misconfigured: %s.", tfdiags.FormatError(err)),
))
}
clientCertificate, err := consolidateCertificate(config.ClientCertificate, config.ClientCertificatePath)
if err != nil {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Invalid Azure Client Certificate Auth",
fmt.Sprintf("The Client Certificate is misconfigured: %s.", tfdiags.FormatError(err)),
))
} else {
_, _, err := pkcs12.Decode(
clientCertificate,
config.ClientCertificatePassword,
)
if err != nil {
diags = diags.Append(tfdiags.Sourceless(
tfdiags.Error,
"Invalid Azure Client Certificate Auth",
fmt.Sprintf("The Client Certificate is invalid: %s.", tfdiags.FormatError(err)),
))
}
}
return diags
}
func (cred *clientCertAuth) AugmentConfig(_ context.Context, config *Config) error {
return checkNamesForAccessKeyCredentials(config.StorageAddresses)
}
func consolidateCertificate(base64EncodedCertificate, certificateFilename string) ([]byte, error) {
var certBytes []byte
var fileBytes []byte
if len(base64EncodedCertificate) > 0 {
var err error
certBytes, err = base64.StdEncoding.DecodeString(base64EncodedCertificate)
if err != nil {
return nil, fmt.Errorf("error decoding client certificate: %w", err)
}
}
if len(certificateFilename) > 0 {
var err error
fileBytes, err = os.ReadFile(certificateFilename)
if err != nil {
return nil, fmt.Errorf("error reading client certificate file: %w", err)
}
}
hasCert := len(certBytes) > 0
hasFile := len(fileBytes) > 0
if !hasCert && !hasFile {
return nil, errors.New("missing certificate, client certificate is required")
}
if !hasCert {
return fileBytes, nil
}
if !hasFile {
return certBytes, nil
}
if !bytes.Equal(certBytes, fileBytes) {
return nil, errors.New("client certificate provided directly and through file do not match; either make them the same value or only provide one")
}
return fileBytes, nil
}