Files
opentf/internal/encryption/method/external/command.go
2025-03-10 13:16:27 -04:00

189 lines
4.4 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package external
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"os/exec"
"strings"
"time"
"github.com/opentofu/opentofu/internal/encryption/keyprovider"
"github.com/opentofu/opentofu/internal/encryption/method"
)
// command is an encryption method that delegates encryption and decryption to external programs, exchanging
// JSON-encoded data with them over stdin/stdout.
type command struct {
	// keys holds the key material from a key provider. It may be nil, in which case no key is sent to the program.
	keys *keyprovider.Output
	// encryptCommand is the program run by Encrypt: element 0 is the executable, the rest are its arguments.
	encryptCommand []string
	// decryptCommand is the program run by Decrypt: element 0 is the executable, the rest are its arguments.
	decryptCommand []string
}
// Encrypt passes data to the configured external encrypt command and returns the payload it produces.
//
// When a key provider is configured, its encryption key is sent to the command alongside the data. Any failure while
// running the command or parsing its output is wrapped in method.ErrEncryptionFailed.
func (c command) Encrypt(data []byte) ([]byte, error) {
	input := InputV1{Payload: data}
	if c.keys != nil {
		input.Key = c.keys.EncryptionKey
	}

	encrypted, err := c.run(c.encryptCommand, input)
	if err == nil {
		return encrypted, nil
	}
	return nil, &method.ErrEncryptionFailed{Cause: err}
}
// Decrypt passes data to the configured external decrypt command and returns the payload it produces.
//
// When a key provider is configured, its decryption key is sent to the command alongside the data. If the provider
// supplied an encryption key but no decryption key, method.ErrDecryptionKeyUnavailable is returned without running
// the command. Empty input is rejected up front; command and parsing failures are wrapped in
// method.ErrDecryptionFailed.
func (c command) Decrypt(data []byte) ([]byte, error) {
	input := InputV1{Payload: data}
	if c.keys != nil {
		input.Key = c.keys.DecryptionKey
		// An encryption key without a matching decryption key means the provider could not supply what is needed
		// to decrypt; report that distinctly instead of invoking the command with no key.
		if len(input.Key) == 0 && len(c.keys.EncryptionKey) > 0 {
			return nil, &method.ErrDecryptionKeyUnavailable{}
		}
	}

	if len(data) == 0 {
		return nil, &method.ErrDecryptionFailed{Cause: &method.ErrCryptoFailure{
			Message: "Cannot decrypt empty data.",
		}}
	}

	decrypted, err := c.run(c.decryptCommand, input)
	if err != nil {
		return nil, &method.ErrDecryptionFailed{Cause: err}
	}
	return decrypted, nil
}
// run executes the external program described by command, writes the JSON-encoded input to its stdin, and returns
// the payload from the OutputV1 document it prints on stdout.
//
// The program's stdout must start with a single header line, validated by ioHandler, followed by a JSON-encoded
// OutputV1. The program is killed if it runs for longer than one minute. Every failure is returned as
// *method.ErrCryptoFailure; once the program has been started, the captured stderr is attached as supplemental data.
func (c command) run(command []string, input any) ([]byte, error) {
	if len(command) == 0 {
		// Indexing command[0] below would panic on an empty slice.
		return nil, &method.ErrCryptoFailure{
			Message: "no external command configured",
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()

	inputData, err := json.Marshal(input)
	if err != nil {
		return nil, &method.ErrCryptoFailure{
			Message: "failed to marshal input",
			Cause:   err,
		}
	}

	stderr := &bytes.Buffer{}
	// stderrInfo renders the captured stderr for error reports; only call it after the program has finished.
	stderrInfo := func() string {
		return fmt.Sprintf("Stderr:\n-------\n%s\n", stderr.String())
	}

	cmd := exec.CommandContext(ctx, command[0], command[1:]...)
	handler := &ioHandler{
		input:  bytes.NewBuffer(inputData),
		output: []byte{},
		cancel: cancel,
	}
	cmd.Stdin = handler
	cmd.Stdout = handler
	cmd.Stderr = stderr

	if err := cmd.Run(); err != nil {
		if handler.err != nil {
			// Header validation failed and cancelled the context; report the root cause rather than the kill.
			return nil, &method.ErrCryptoFailure{
				Message:          "external encryption method failure",
				Cause:            handler.err,
				SupplementalData: stderrInfo(),
			}
		}
		if errors.Is(ctx.Err(), context.DeadlineExceeded) {
			return nil, &method.ErrCryptoFailure{
				Message:          "external encryption method timed out",
				Cause:            err,
				SupplementalData: stderrInfo(),
			}
		}
		var exitErr *exec.ExitError
		if errors.As(err, &exitErr) && exitErr.ExitCode() != 0 {
			return nil, &method.ErrCryptoFailure{
				Message:          "external encryption method exited with non-zero exit code",
				Cause:            err,
				SupplementalData: stderrInfo(),
			}
		}
		return nil, &method.ErrCryptoFailure{
			Message:          "external encryption method exited with an error",
			Cause:            err,
			SupplementalData: stderrInfo(),
		}
	}

	if !handler.headerFinished {
		// Without a complete header line, handler.output holds unvalidated data; reject it rather than decoding it.
		return nil, &method.ErrCryptoFailure{
			Message:          "external encryption method did not output a header",
			SupplementalData: stderrInfo(),
		}
	}

	var result *OutputV1
	decoder := json.NewDecoder(bytes.NewReader(handler.output))
	decoder.DisallowUnknownFields()
	if err := decoder.Decode(&result); err != nil {
		return nil, &method.ErrCryptoFailure{
			Message:          "external encryption method returned an invalid JSON",
			Cause:            err,
			SupplementalData: stderrInfo(),
		}
	}
	if result == nil {
		// A literal JSON null decodes into a nil pointer without error; dereferencing it would panic.
		return nil, &method.ErrCryptoFailure{
			Message:          "external encryption method returned a null result",
			SupplementalData: stderrInfo(),
		}
	}
	return result.Payload, nil
}
// ioHandler is used by run as both the stdin (Read) and the stdout (Write) of the external program. It feeds the
// JSON-encoded input to the program and collects its output, validating the leading header line on the way.
type ioHandler struct {
	// headerFinished is set once the header line has been received and validated.
	headerFinished bool
	// input holds the JSON-encoded input streamed to the program's stdin.
	input *bytes.Buffer
	// output accumulates stdout; once the header is validated it holds only the data after the header line.
	output []byte
	// cancel aborts the program (through its context) when the header is invalid.
	cancel func()
	// err records the header validation error, if any, so run can report it instead of the resulting kill error.
	err error
}
// Write implements io.Writer for the external program's stdout.
//
// Data is buffered until the first newline arrives; that first line is the protocol header and must carry the
// expected magic and version 1. If validation fails, the error is recorded in i.err, the program's context is
// cancelled, and the error is returned (and on any later call as well) so the stdout copy stops. After a valid
// header, everything following the newline is accumulated in i.output.
func (i *ioHandler) Write(p []byte) (int, error) {
	n := len(p)
	if i.err != nil {
		// Header validation already failed; keep reporting the same error.
		return n, i.err
	}
	i.output = append(i.output, p...)
	if i.headerFinished {
		// Header is finished, just collect the output.
		return n, nil
	}
	// While the header is incomplete, the previously buffered data holds no newline (otherwise it would have been
	// processed on an earlier call), so only the new chunk needs scanning. This keeps header detection linear instead
	// of re-converting the whole buffer on every write.
	if bytes.IndexByte(p, '\n') < 0 {
		return n, nil
	}
	headerLine, rest, _ := strings.Cut(string(i.output), "\n")

	// fail records err, stops the external program and reports err to the caller.
	fail := func(err error) (int, error) {
		i.err = err
		i.cancel()
		return n, err
	}

	var header Header
	// Note: this is intentionally not using strict decoding. Later protocol versions may introduce additional header
	// fields.
	if jsonErr := json.Unmarshal([]byte(headerLine), &header); jsonErr != nil {
		return fail(fmt.Errorf("failed to unmarshal header from external method (%w)", jsonErr))
	}
	if header.Magic != Magic {
		return fail(fmt.Errorf("invalid magic received from external method: %s", header.Magic))
	}
	if header.Version != 1 {
		return fail(fmt.Errorf("invalid version number received from external method: %d", header.Version))
	}
	i.headerFinished = true
	i.output = []byte(rest)
	return n, nil
}
// Read implements io.Reader by draining the buffered input destined for the external program's stdin.
func (i *ioHandler) Read(buf []byte) (n int, err error) {
	n, err = i.input.Read(buf)
	return n, err
}