Files
steampipe/pkg/db/db_client/db_client_execute_test.go
2026-02-13 17:56:39 +05:30

160 lines
6.0 KiB
Go

package db_client
import (
"os"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestTimestamptzTextFormatImplemented verifies that the timestamptz wire protocol fix is in place.
// Reference: https://github.com/turbot/steampipe/issues/4450
//
// This test verifies that startQuery uses QueryResultFormatsByOID to request text format
// for timestamptz columns, ensuring PostgreSQL formats values using the session timezone.
//
// Without this fix, pgx uses binary protocol which loses session timezone info, causing
// timestamptz values to display in the local machine timezone instead of the session timezone.
func TestTimestamptzTextFormatImplemented(t *testing.T) {
// Read the db_client_execute.go file to verify the fix is present
content, err := os.ReadFile("db_client_execute.go")
require.NoError(t, err, "should be able to read db_client_execute.go")
sourceCode := string(content)
// Verify QueryResultFormatsByOID is used
assert.Contains(t, sourceCode, "pgx.QueryResultFormatsByOID",
"QueryResultFormatsByOID must be used to specify format for specific column types")
// Verify TimestamptzOID is referenced
assert.Contains(t, sourceCode, "pgtype.TimestamptzOID",
"TimestamptzOID must be specified to request text format for timestamptz columns")
// Verify TextFormatCode is used
assert.Contains(t, sourceCode, "pgx.TextFormatCode",
"TextFormatCode must be used to request text format")
// Verify the fix is in startQuery function
funcStart := strings.Index(sourceCode, "func (c *DbClient) startQuery")
assert.NotEqual(t, -1, funcStart, "startQuery function must exist")
// Extract just the startQuery function for more precise checking
funcEnd := strings.Index(sourceCode[funcStart:], "\nfunc ")
if funcEnd == -1 {
funcEnd = len(sourceCode)
} else {
funcEnd += funcStart
}
startQueryFunc := sourceCode[funcStart:funcEnd]
// Verify all three components are in startQuery
assert.Contains(t, startQueryFunc, "QueryResultFormatsByOID",
"QueryResultFormatsByOID must be in startQuery function")
assert.Contains(t, startQueryFunc, "TimestamptzOID",
"TimestamptzOID must be in startQuery function")
assert.Contains(t, startQueryFunc, "TextFormatCode",
"TextFormatCode must be in startQuery function")
// Verify there's a comment explaining the fix
hasComment := strings.Contains(startQueryFunc, "session timezone") ||
strings.Contains(startQueryFunc, "text format for timestamptz") ||
strings.Contains(startQueryFunc, "Request text format")
assert.True(t, hasComment,
"Comment should explain why text format is needed for timestamptz")
// Verify queryArgs are constructed and used
assert.Contains(t, startQueryFunc, "queryArgs",
"queryArgs variable must be used to prepend format specification")
assert.Contains(t, startQueryFunc, "conn.Query(ctx, query, queryArgs...)",
"conn.Query must use queryArgs instead of args directly")
}
// TestTimestamptzFormatCorrectness verifies the format specification structure
func TestTimestamptzFormatCorrectness(t *testing.T) {
content, err := os.ReadFile("db_client_execute.go")
require.NoError(t, err, "should be able to read db_client_execute.go")
sourceCode := string(content)
// Verify the QueryResultFormatsByOID is constructed as the first element
// This is critical - it must be the first argument before actual query parameters
assert.Contains(t, sourceCode, "queryArgs := make([]any, 0, len(args)+1)",
"queryArgs must be allocated with capacity for format spec + args")
// Verify format spec is appended first
lines := strings.Split(sourceCode, "\n")
var foundMake, foundAppendFormat, foundAppendArgs bool
var makeIdx, appendFormatIdx, appendArgsIdx int
for i, line := range lines {
if strings.Contains(line, "queryArgs := make([]any, 0, len(args)+1)") {
foundMake = true
makeIdx = i
}
if strings.Contains(line, "queryArgs = append(queryArgs, pgx.QueryResultFormatsByOID{") {
foundAppendFormat = true
appendFormatIdx = i
}
if strings.Contains(line, "queryArgs = append(queryArgs, args...)") {
foundAppendArgs = true
appendArgsIdx = i
}
}
assert.True(t, foundMake, "queryArgs must be allocated")
assert.True(t, foundAppendFormat, "format spec must be appended to queryArgs")
assert.True(t, foundAppendArgs, "original args must be appended to queryArgs")
// Verify correct order: make -> append format spec -> append args
if foundMake && foundAppendFormat && foundAppendArgs {
assert.Less(t, makeIdx, appendFormatIdx,
"queryArgs must be allocated before appending format spec")
assert.Less(t, appendFormatIdx, appendArgsIdx,
"format spec must be appended before original args")
}
}
// TestTimestamptzFormatDoesNotAffectOtherTypes verifies only timestamptz format is changed
func TestTimestamptzFormatDoesNotAffectOtherTypes(t *testing.T) {
content, err := os.ReadFile("db_client_execute.go")
require.NoError(t, err, "should be able to read db_client_execute.go")
sourceCode := string(content)
// Find the QueryResultFormatsByOID map construction
funcStart := strings.Index(sourceCode, "func (c *DbClient) startQuery")
require.NotEqual(t, -1, funcStart, "startQuery function must exist")
funcEnd := strings.Index(sourceCode[funcStart:], "\nfunc ")
if funcEnd == -1 {
funcEnd = len(sourceCode)
} else {
funcEnd += funcStart
}
startQueryFunc := sourceCode[funcStart:funcEnd]
// Verify ONLY TimestamptzOID is in the map (no other OIDs)
// This ensures we don't accidentally change format for other types
otherOIDs := []string{
"DateOID",
"TimestampOID",
"TimeOID",
"IntervalOID",
"JSONOID",
"JSONBOID",
}
for _, oid := range otherOIDs {
assert.NotContains(t, startQueryFunc, "pgtype."+oid,
"Should not change format for "+oid+" - only timestamptz needs text format")
}
// Verify there's only one entry in QueryResultFormatsByOID
// Count how many times we see "OID:" in the map definition
oidCount := strings.Count(startQueryFunc, "OID:")
assert.Equal(t, 1, oidCount,
"QueryResultFormatsByOID should have exactly one entry (TimestamptzOID)")
}