mirror of
https://github.com/turbot/steampipe.git
synced 2026-02-22 23:00:16 -05:00
339 lines
10 KiB
Go
339 lines
10 KiB
Go
package db_local
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"log"
|
|
"math/big"
|
|
"os"
|
|
"time"
|
|
|
|
"github.com/turbot/go-kit/helpers"
|
|
"github.com/turbot/steampipe/pkg/utils"
|
|
)
|
|
|
|
// CertIssuer is the common name written into the subject of the generated
// root certificate. The validation functions in this file compare the root
// certificate's subject and the server certificate's issuer against it.
const CertIssuer = "steampipe.io"

// Durations controlling certificate lifetime and expiry checks.
var (
	// CertExpiryTolerance is the window before a certificate's NotAfter that
	// isCerticateExpiring measures against the current time (180 days).
	CertExpiryTolerance = 180 * 24 * time.Hour
	// RootCertValidityPeriod is the lifetime given to a newly generated root
	// certificate (5 years of 365 days).
	RootCertValidityPeriod = 5 * 365 * 24 * time.Hour
	// ServerCertValidityPeriod is the lifetime given to a newly generated
	// server certificate (365 days).
	ServerCertValidityPeriod = 365 * 24 * time.Hour
)
|
|
|
|
// CertificatesExist checks if the root and server certificate and key files exist
|
|
func CertificatesExist() bool {
|
|
return helpers.FileExists(getRootCertLocation()) && helpers.FileExists(getServerCertLocation())
|
|
}
|
|
|
|
// RemoveServerCertificate removes the server certificate certificates so it will be regenerated
|
|
func RemoveServerCertificate() error {
|
|
utils.LogTime("db_local.RemoveServerCertificate start")
|
|
defer utils.LogTime("db_local.RemoveServerCertificate end")
|
|
|
|
if err := os.Remove(getServerCertLocation()); err != nil {
|
|
return err
|
|
}
|
|
return os.Remove(getServerCertKeyLocation())
|
|
}
|
|
|
|
// RemoveAllCertificates removes root and server certificates so that they can be regenerated
|
|
func RemoveAllCertificates() error {
|
|
utils.LogTime("db_local.RemoveAllCertificates start")
|
|
defer utils.LogTime("db_local.RemoveAllCertificates end")
|
|
|
|
// remove the root cert (but not key)
|
|
if err := os.Remove(getRootCertLocation()); err != nil {
|
|
return err
|
|
}
|
|
// remove the server cert and key
|
|
return RemoveServerCertificate()
|
|
}
|
|
|
|
// ValidateRootCertificate checks the root certificate exists, is not expired and has correct Subject
|
|
func ValidateRootCertificate() bool {
|
|
utils.LogTime("db_local.ValidateRootCertificates start")
|
|
defer utils.LogTime("db_local.ValidateRootCertificates end")
|
|
|
|
rootCertificate, err := parseCertificateInLocation(getRootCertLocation())
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return (rootCertificate.Subject.CommonName == CertIssuer) && isCerticateExpiring(rootCertificate)
|
|
}
|
|
|
|
// ValidateServerCertificate checks the server certificate exists, is not expired and has correct issuer
|
|
func ValidateServerCertificate() bool {
|
|
utils.LogTime("db_local.ValidateServerCertificates start")
|
|
defer utils.LogTime("db_local.ValidateServerCertificates end")
|
|
|
|
serverCertificate, err := parseCertificateInLocation(getServerCertLocation())
|
|
if err != nil {
|
|
return false
|
|
}
|
|
|
|
return (serverCertificate.Issuer.CommonName == CertIssuer) && isCerticateExpiring(serverCertificate)
|
|
}
|
|
|
|
// if certificate or private key files do not exist, generate them
|
|
func ensureSelfSignedCertificate() (err error) {
|
|
if serverCertificateAndKeyExist() && rootCertificateAndKeyExists() {
|
|
return nil
|
|
}
|
|
|
|
// so one or both of the root and server certificate need creating
|
|
var rootPrivateKey *rsa.PrivateKey
|
|
var rootCertificate *x509.Certificate
|
|
if rootCertificateAndKeyExists() {
|
|
// if the root cert and key exist, load them
|
|
rootPrivateKey, err = loadRootPrivateKey()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rootCertificate, err = parseCertificateInLocation(getRootCertLocation())
|
|
} else {
|
|
// otherwise generate them
|
|
rootCertificate, rootPrivateKey, err = generateRootCertificate()
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// now generate new server cert
|
|
return generateServerCertificates(rootCertificate, rootPrivateKey)
|
|
|
|
}
|
|
|
|
// rootCertificateAndKeyExists checks if the root certificate ands private key files exist
|
|
func rootCertificateAndKeyExists() bool {
|
|
return helpers.FileExists(getRootCertLocation()) && helpers.FileExists(getRootCertKeyLocation())
|
|
}
|
|
|
|
// serverCertificateAndKeyExist checks if the server certificate ands private key files exist
|
|
func serverCertificateAndKeyExist() bool {
|
|
return helpers.FileExists(getServerCertLocation()) && helpers.FileExists(getServerCertKeyLocation())
|
|
}
|
|
|
|
// isCerticateExpiring checks whether the certificate expires within a predefined CertExpiryTolerance period (defined above)
|
|
func isCerticateExpiring(certificate *x509.Certificate) bool {
|
|
return certificate.NotAfter.Add(-CertExpiryTolerance).After(time.Now())
|
|
}
|
|
|
|
func parseCertificateInLocation(location string) (*x509.Certificate, error) {
|
|
utils.LogTime("db_local.parseCertificateInLocation start")
|
|
defer utils.LogTime("db_local.parseCertificateInLocation end")
|
|
|
|
rootCertRaw, err := os.ReadFile(getRootCertLocation())
|
|
if err != nil {
|
|
// if we can't read the certificate, then there's a problem with permissions
|
|
return nil, err
|
|
}
|
|
// decode the pem blocks
|
|
rootPemBlock, _ := pem.Decode(rootCertRaw)
|
|
if rootPemBlock == nil {
|
|
return nil, fmt.Errorf("could not decode PEM blocks from certificate at %s", location)
|
|
}
|
|
// parse the PEM Blocks to Certificates
|
|
return x509.ParseCertificate(rootPemBlock.Bytes)
|
|
}
|
|
|
|
func generateRootCertificate() (*x509.Certificate, *rsa.PrivateKey, error) {
|
|
utils.LogTime("db_local.generateServiceCertificates start")
|
|
defer utils.LogTime("db_local.generateServiceCertificates end")
|
|
|
|
// Load or create our own certificate authority
|
|
caPrivateKey, err := ensureRootPrivateKey()
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
now := time.Now()
|
|
|
|
// Certificate authority input
|
|
caCertificateData := &x509.Certificate{
|
|
SerialNumber: big.NewInt(2020),
|
|
NotBefore: now,
|
|
NotAfter: now.Add(RootCertValidityPeriod),
|
|
Subject: pkix.Name{CommonName: CertIssuer},
|
|
IsCA: true,
|
|
BasicConstraintsValid: true,
|
|
}
|
|
|
|
caCertificate, err := x509.CreateCertificate(rand.Reader, caCertificateData, caCertificateData, &caPrivateKey.PublicKey, caPrivateKey)
|
|
if err != nil {
|
|
log.Println("[WARN] failed to create certificate")
|
|
return nil, nil, err
|
|
}
|
|
|
|
caCertificatePem := &bytes.Buffer{}
|
|
err = pem.Encode(caCertificatePem, &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: caCertificate,
|
|
})
|
|
if err != nil {
|
|
log.Println("[WARN] failed to encode to PEM")
|
|
return nil, nil, err
|
|
}
|
|
|
|
if err := writeCertFile(getRootCertLocation(), caCertificatePem.String()); err != nil {
|
|
log.Println("[WARN] failed to save the certificate")
|
|
return nil, nil, err
|
|
}
|
|
|
|
return caCertificateData, caPrivateKey, nil
|
|
}
|
|
|
|
func generateServerCertificates(caCertificateData *x509.Certificate, caPrivateKey *rsa.PrivateKey) error {
|
|
utils.LogTime("db_local.generateServerCertificates start")
|
|
defer utils.LogTime("db_local.generateServerCertificates end")
|
|
|
|
now := time.Now()
|
|
|
|
// set up for server certificate
|
|
serverCertificateData := &x509.Certificate{
|
|
SerialNumber: big.NewInt(2019),
|
|
Subject: caCertificateData.Subject,
|
|
Issuer: caCertificateData.Subject,
|
|
NotBefore: now,
|
|
NotAfter: now.Add(ServerCertValidityPeriod),
|
|
}
|
|
|
|
// Generate the server private key
|
|
serverPrivKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
serverCertBytes, err := x509.CreateCertificate(rand.Reader, serverCertificateData, caCertificateData, &serverPrivKey.PublicKey, caPrivateKey)
|
|
|
|
if err != nil {
|
|
log.Println("[INFO] Failed to create server certificate")
|
|
return err
|
|
}
|
|
|
|
// Encode and save the server certificate
|
|
serverCertPem := new(bytes.Buffer)
|
|
err = pem.Encode(serverCertPem, &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: serverCertBytes,
|
|
})
|
|
if err != nil {
|
|
log.Println("[INFO] Failed to encode to PEM for server certificate")
|
|
return err
|
|
}
|
|
if err := writeCertFile(getServerCertLocation(), serverCertPem.String()); err != nil {
|
|
log.Println("[INFO] Failed to save server certificate")
|
|
return err
|
|
}
|
|
|
|
// Encode and save the server private key
|
|
serverPrivKeyPem := new(bytes.Buffer)
|
|
err = pem.Encode(serverPrivKeyPem, &pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(serverPrivKey),
|
|
})
|
|
if err != nil {
|
|
log.Println("[INFO] Failed to encode to PEM for server private key")
|
|
return err
|
|
}
|
|
if err := writeCertFile(getServerCertKeyLocation(), serverPrivKeyPem.String()); err != nil {
|
|
log.Println("[INFO] Failed to save server private key")
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// derive ssl status from out ssl mode
|
|
func sslStatus() string {
|
|
status := sslMode()
|
|
if status == "require" {
|
|
return "on"
|
|
}
|
|
return "off"
|
|
}
|
|
|
|
// derive ssl mode from the prsesnce of the server certificate and key file
|
|
func sslMode() string {
|
|
if serverCertificateAndKeyExist() {
|
|
return "require"
|
|
}
|
|
return "disable"
|
|
}
|
|
|
|
func ensureRootPrivateKey() (*rsa.PrivateKey, error) {
|
|
// first try to load the key
|
|
// if any errors are encountered this will just return nil
|
|
caPrivateKey, _ := loadRootPrivateKey()
|
|
if caPrivateKey != nil {
|
|
// we loaded one
|
|
return caPrivateKey, nil
|
|
}
|
|
// so we failed to load the key - generate instead
|
|
var err error
|
|
caPrivateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
log.Println("[WARN] private key creation failed for ca failed")
|
|
return nil, err
|
|
}
|
|
caPrivateKeyPem := new(bytes.Buffer)
|
|
err = pem.Encode(caPrivateKeyPem, &pem.Block{
|
|
Type: "RSA PRIVATE KEY",
|
|
Bytes: x509.MarshalPKCS1PrivateKey(caPrivateKey),
|
|
})
|
|
if err != nil {
|
|
log.Println("[WARN] failed to encode to PEM for ca private key")
|
|
return nil, err
|
|
}
|
|
if err := writeCertFile(getRootCertKeyLocation(), caPrivateKeyPem.String()); err != nil {
|
|
log.Println("[WARN] failed to save root private key")
|
|
return nil, err
|
|
}
|
|
return caPrivateKey, nil
|
|
}
|
|
|
|
func loadRootPrivateKey() (*rsa.PrivateKey, error) {
|
|
location := getRootCertKeyLocation()
|
|
|
|
priv, err := os.ReadFile(location)
|
|
if err != nil {
|
|
log.Printf("[TRACE] loadRootPrivateKey - failed to load key from %s: %s", location, err.Error())
|
|
return nil, err
|
|
}
|
|
|
|
privPem, _ := pem.Decode(priv)
|
|
if privPem.Type != "RSA PRIVATE KEY" {
|
|
log.Printf("[TRACE] RSA private key is of the wrong type: %v", privPem.Type)
|
|
return nil, fmt.Errorf("RSA private key is of the wrong type: %v", privPem.Type)
|
|
}
|
|
|
|
privPemBytes := privPem.Bytes
|
|
|
|
var parsedKey interface{}
|
|
if parsedKey, err = x509.ParsePKCS1PrivateKey(privPemBytes); err != nil {
|
|
if parsedKey, err = x509.ParsePKCS8PrivateKey(privPemBytes); err != nil {
|
|
// note this returns type `interface{}`
|
|
log.Printf("[TRACE] failed to parse RSA private key: %s", err.Error())
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
var privateKey *rsa.PrivateKey
|
|
var ok bool
|
|
privateKey, ok = parsedKey.(*rsa.PrivateKey)
|
|
if !ok {
|
|
log.Printf("[TRACE] failed to parse RSA private key")
|
|
return nil, fmt.Errorf("failed to parse RSA private key")
|
|
}
|
|
return privateKey, nil
|
|
}
|
|
|
|
// writeCertFile writes cert to filePath, creating the file with permission
// 0600 if it does not exist and replacing its contents if it does. It returns
// the error from os.WriteFile.
func writeCertFile(filePath string, cert string) error {
	contents := []byte(cert)
	return os.WriteFile(filePath, contents, 0o600)
}
|