mirror of
https://github.com/opentffoundation/opentf.git
synced 2025-12-19 17:59:05 -05:00
156 lines
4.0 KiB
Go
156 lines
4.0 KiB
Go
// Copyright (c) The OpenTofu Authors
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
// Copyright (c) 2023 HashiCorp, Inc.
|
|
// SPDX-License-Identifier: MPL-2.0
|
|
|
|
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
|
|
"github.com/opentofu/opentofu/internal/httpclient"
|
|
"github.com/opentofu/opentofu/internal/tfdiags"
|
|
"golang.org/x/crypto/pkcs12"
|
|
)
|
|
|
|
type ClientCertificateAuthConfig struct {
|
|
ClientCertificate string
|
|
ClientCertificatePassword string
|
|
ClientCertificatePath string
|
|
}
|
|
|
|
type clientCertAuth struct{}
|
|
|
|
var _ AuthMethod = &clientCertAuth{}
|
|
|
|
func (cred *clientCertAuth) Name() string {
|
|
return "Client Certificate Auth"
|
|
}
|
|
|
|
func (cred *clientCertAuth) Construct(ctx context.Context, config *Config) (azcore.TokenCredential, error) {
|
|
client := httpclient.New(ctx)
|
|
|
|
clientCertificate, err := consolidateCertificate(config.ClientCertificate, config.ClientCertificatePath)
|
|
if err != nil {
|
|
// This should never happen; this is checked in the Validate function
|
|
return nil, err
|
|
}
|
|
|
|
privateKey, certificate, err := pkcs12.Decode(
|
|
clientCertificate,
|
|
config.ClientCertificatePassword,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clientId, err := consolidateClientId(config)
|
|
if err != nil {
|
|
// This should never happen; this is checked in the Validate function
|
|
return nil, err
|
|
}
|
|
|
|
return azidentity.NewClientCertificateCredential(
|
|
config.TenantID,
|
|
clientId,
|
|
[]*x509.Certificate{certificate},
|
|
privateKey,
|
|
&azidentity.ClientCertificateCredentialOptions{
|
|
ClientOptions: clientOptions(client, config.CloudConfig),
|
|
},
|
|
)
|
|
}
|
|
|
|
func (cred *clientCertAuth) Validate(_ context.Context, config *Config) tfdiags.Diagnostics {
|
|
var diags tfdiags.Diagnostics
|
|
if config.TenantID == "" {
|
|
diags = diags.Append(tfdiags.Sourceless(
|
|
tfdiags.Error,
|
|
"Invalid Azure Client Certificate Auth",
|
|
"Tenant ID is missing.",
|
|
))
|
|
}
|
|
|
|
_, err := consolidateClientId(config)
|
|
if err != nil {
|
|
diags = diags.Append(tfdiags.Sourceless(
|
|
tfdiags.Error,
|
|
"Invalid Azure Client Certificate Auth",
|
|
fmt.Sprintf("The Client ID is misconfigured: %s.", tfdiags.FormatError(err)),
|
|
))
|
|
}
|
|
clientCertificate, err := consolidateCertificate(config.ClientCertificate, config.ClientCertificatePath)
|
|
if err != nil {
|
|
diags = diags.Append(tfdiags.Sourceless(
|
|
tfdiags.Error,
|
|
"Invalid Azure Client Certificate Auth",
|
|
fmt.Sprintf("The Client Certificate is misconfigured: %s.", tfdiags.FormatError(err)),
|
|
))
|
|
} else {
|
|
_, _, err := pkcs12.Decode(
|
|
clientCertificate,
|
|
config.ClientCertificatePassword,
|
|
)
|
|
if err != nil {
|
|
diags = diags.Append(tfdiags.Sourceless(
|
|
tfdiags.Error,
|
|
"Invalid Azure Client Certificate Auth",
|
|
fmt.Sprintf("The Client Certificate is invalid: %s.", tfdiags.FormatError(err)),
|
|
))
|
|
}
|
|
}
|
|
return diags
|
|
}
|
|
|
|
func (cred *clientCertAuth) AugmentConfig(_ context.Context, config *Config) error {
|
|
return checkNamesForAccessKeyCredentials(config.StorageAddresses)
|
|
}
|
|
|
|
func consolidateCertificate(base64EncodedCertificate, certificateFilename string) ([]byte, error) {
|
|
var certBytes []byte
|
|
var fileBytes []byte
|
|
|
|
if len(base64EncodedCertificate) > 0 {
|
|
var err error
|
|
certBytes, err = base64.StdEncoding.DecodeString(base64EncodedCertificate)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error decoding client certificate: %w", err)
|
|
}
|
|
}
|
|
if len(certificateFilename) > 0 {
|
|
var err error
|
|
fileBytes, err = os.ReadFile(certificateFilename)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error reading client certificate file: %w", err)
|
|
}
|
|
}
|
|
|
|
hasCert := len(certBytes) > 0
|
|
hasFile := len(fileBytes) > 0
|
|
|
|
if !hasCert && !hasFile {
|
|
return nil, errors.New("missing certificate, client certificate is required")
|
|
}
|
|
|
|
if !hasCert {
|
|
return fileBytes, nil
|
|
}
|
|
|
|
if !hasFile {
|
|
return certBytes, nil
|
|
}
|
|
|
|
if !bytes.Equal(certBytes, fileBytes) {
|
|
return nil, errors.New("client certificate provided directly and through file do not match; either make them the same value or only provide one")
|
|
}
|
|
return fileBytes, nil
|
|
}
|