Files
opentf/internal/tofu/context_functions_test.go
Christian Mesh a3fe39ff33 Remove global schema cache and clean up tofu schema/contextPlugins (#3589)
Signed-off-by: Christian Mesh <christianmesh1@gmail.com>
Co-authored-by: Martin Atkins <mart@degeneration.co.uk>
2025-12-17 09:49:39 -05:00

823 lines
20 KiB
Go

// Copyright (c) The OpenTofu Authors
// SPDX-License-Identifier: MPL-2.0
// Copyright (c) 2023 HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package tofu
import (
"context"
"strings"
"testing"
"github.com/hashicorp/hcl/v2"
"github.com/hashicorp/hcl/v2/hclsyntax"
"github.com/opentofu/opentofu/internal/addrs"
"github.com/opentofu/opentofu/internal/configs/configschema"
"github.com/opentofu/opentofu/internal/lang/marks"
"github.com/opentofu/opentofu/internal/providers"
"github.com/opentofu/opentofu/internal/tfdiags"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/function"
)
// TestFunctions registers a set of mock provider functions through
// evalContextProviderFunction and then evaluates HCL expressions that call
// them. It checks that argument validation (arity, null, unknown and
// sensitive values) happens as HCL sees it, that provider results are
// returned, and that a provider-reported argument error is surfaced as an
// "Invalid function argument" diagnostic.
func TestFunctions(t *testing.T) {
	mockProvider := &MockProvider{
		GetProviderSchemaResponse: &providers.GetProviderSchemaResponse{
			Provider: providers.Schema{},
			Functions: map[string]providers.FunctionSpec{
				"echo": {
					Parameters: []providers.FunctionParameterSpec{{
						Name:               "input",
						Type:               cty.String,
						AllowNullValue:     false,
						AllowUnknownValues: false,
					}},
					Return: cty.String,
				},
				"concat": {
					Parameters: []providers.FunctionParameterSpec{{
						Name:               "input",
						Type:               cty.String,
						AllowNullValue:     false,
						AllowUnknownValues: false,
					}},
					VariadicParameter: &providers.FunctionParameterSpec{
						Name:           "vary",
						Type:           cty.String,
						AllowNullValue: false,
					},
					Return: cty.String,
				},
				"coalesce": {
					Parameters: []providers.FunctionParameterSpec{{
						Name:               "input1",
						Type:               cty.String,
						AllowNullValue:     true,
						AllowUnknownValues: false,
					}, {
						Name:               "input2",
						Type:               cty.String,
						AllowNullValue:     false,
						AllowUnknownValues: false,
					}},
					Return: cty.String,
				},
				"unknown_param": {
					Parameters: []providers.FunctionParameterSpec{{
						Name:               "input",
						Type:               cty.String,
						AllowNullValue:     false,
						AllowUnknownValues: true,
					}},
					Return: cty.String,
				},
				"error_param": {
					Parameters: []providers.FunctionParameterSpec{{
						Name:               "input",
						Type:               cty.String,
						AllowNullValue:     false,
						AllowUnknownValues: false,
					}},
					Return: cty.String,
				},
			},
		},
	}

	// Minimal provider-side implementations of the functions declared above.
	mockProvider.CallFunctionFn = func(req providers.CallFunctionRequest) (resp providers.CallFunctionResponse) {
		switch req.Name {
		case "echo":
			resp.Result = req.Arguments[0]
		case "concat":
			var sb strings.Builder
			for _, arg := range req.Arguments {
				sb.WriteString(arg.AsString())
			}
			resp.Result = cty.StringVal(sb.String())
		case "coalesce":
			resp.Result = req.Arguments[0]
			if resp.Result.IsNull() {
				resp.Result = req.Arguments[1]
			}
		case "unknown_param":
			resp.Result = cty.StringVal("knownvalue")
		case "error_param":
			resp.Error = &providers.CallFunctionArgumentError{
				Text:             "my error text",
				FunctionArgument: 0,
			}
		default:
			panic("Invalid function")
		}
		return resp
	}
	mockProvider.GetFunctionsFn = func() (resp providers.GetFunctionsResponse) {
		resp.Functions = mockProvider.GetProviderSchemaResponse.Functions
		return resp
	}

	rng := tfdiags.SourceRange{}
	providerFunc := func(fn string) addrs.ProviderFunction {
		// Every name passed here is a literal provider::mockname::* address,
		// so the second result of AsProviderFunction is not checked.
		pf, _ := addrs.ParseFunction(fn).AsProviderFunction()
		return pf
	}

	// Function missing (validate)
	mockProvider.GetFunctionsCalled = false
	_, diags := evalContextProviderFunction(t.Context(), mockProvider, walkValidate, providerFunc("provider::mockname::missing"), rng)
	if diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	if mockProvider.GetFunctionsCalled {
		t.Fatal("expected GetFunctions NOT to be called since it's not initialized")
	}

	// Function missing (Non-validate)
	mockProvider.GetFunctionsCalled = false
	_, diags = evalContextProviderFunction(t.Context(), mockProvider, walkPlan, providerFunc("provider::mockname::missing"), rng)
	if !diags.HasErrors() {
		t.Fatal("expected unknown function")
	}
	if diags.Err().Error() != `Function not found in provider: Function "provider::mockname::missing" was not registered by provider` {
		t.Fatal(diags.Err())
	}
	if !mockProvider.GetFunctionsCalled {
		t.Fatal("expected GetFunctions to be called")
	}

	ctx := &hcl.EvalContext{
		Functions: map[string]function.Function{},
		Variables: map[string]cty.Value{
			"unknown_value":   cty.UnknownVal(cty.String),
			"sensitive_value": cty.StringVal("sensitive!").Mark(marks.Sensitive),
		},
	}

	// Load functions into ctx
	for _, fn := range []string{"echo", "concat", "coalesce", "unknown_param", "error_param"} {
		pf := providerFunc("provider::mockname::" + fn)
		impl, diags := evalContextProviderFunction(t.Context(), mockProvider, walkPlan, pf, rng)
		if diags.HasErrors() {
			t.Fatal(diags.Err())
		}
		ctx.Functions[pf.String()] = *impl
	}

	// evaluate parses and evaluates exprStr against ctx. It takes the calling
	// (sub)test's t so that a parse failure aborts that test from its own
	// goroutine, as FailNow requires, rather than the parent's.
	evaluate := func(t *testing.T, exprStr string) (cty.Value, hcl.Diagnostics) {
		t.Helper()
		expr, diags := hclsyntax.ParseExpression([]byte(exprStr), "exprtest", hcl.InitialPos)
		if diags.HasErrors() {
			t.Fatal(diags)
		}
		return expr.Value(ctx)
	}

	// Unexpected values are reported with %#v (cty's GoString) rather than
	// AsString, which panics on unknown, null or marked values.
	t.Run("echo function", func(t *testing.T) {
		// These are all assumptions that the provider implementation should not have to worry about:
		t.Log("Checking not enough arguments")
		_, diags := evaluate(t, "provider::mockname::echo()")
		if !strings.Contains(diags.Error(), `Not enough function arguments; Function "provider::mockname::echo" expects 1 argument(s). Missing value for "input"`) {
			t.Error(diags.Error())
		}

		t.Log("Checking too many arguments")
		_, diags = evaluate(t, `provider::mockname::echo("1", "2", "3")`)
		if !strings.Contains(diags.Error(), `Too many function arguments; Function "provider::mockname::echo" expects only 1 argument(s)`) {
			t.Error(diags.Error())
		}

		t.Log("Checking null argument")
		_, diags = evaluate(t, `provider::mockname::echo(null)`)
		if !strings.Contains(diags.Error(), `Invalid function argument; Invalid value for "input" parameter: argument must not be null`) {
			t.Error(diags.Error())
		}

		t.Log("Checking unknown argument")
		val, diags := evaluate(t, `provider::mockname::echo(unknown_value)`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.UnknownVal(cty.String)) {
			t.Errorf("unexpected result: %#v", val)
		}

		// Actually test the function implementation
		t.Log("Checking valid argument")
		val, diags = evaluate(t, `provider::mockname::echo("hello functions!")`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("hello functions!")) {
			t.Errorf("unexpected result: %#v", val)
		}

		t.Log("Checking sensitive argument")
		val, diags = evaluate(t, `provider::mockname::echo(sensitive_value)`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("sensitive!").Mark(marks.Sensitive)) {
			t.Errorf("unexpected result: %#v", val)
		}
	})

	t.Run("concat function", func(t *testing.T) {
		// Make sure varargs are handled properly
		// Single
		val, diags := evaluate(t, `provider::mockname::concat("foo")`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("foo")) {
			t.Errorf("unexpected result: %#v", val)
		}

		// Multi
		val, diags = evaluate(t, `provider::mockname::concat("foo", "bar", "baz")`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("foobarbaz")) {
			t.Errorf("unexpected result: %#v", val)
		}
	})

	t.Run("coalesce function", func(t *testing.T) {
		val, diags := evaluate(t, `provider::mockname::coalesce("first", "second")`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("first")) {
			t.Errorf("unexpected result: %#v", val)
		}

		val, diags = evaluate(t, `provider::mockname::coalesce(null, "second")`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("second")) {
			t.Errorf("unexpected result: %#v", val)
		}
	})

	t.Run("unknown_param function", func(t *testing.T) {
		val, diags := evaluate(t, `provider::mockname::unknown_param(unknown_value)`)
		if diags.HasErrors() {
			t.Error(diags.Error())
		}
		if !val.RawEquals(cty.StringVal("knownvalue")) {
			t.Errorf("unexpected result: %#v", val)
		}
	})

	t.Run("error_param function", func(t *testing.T) {
		_, diags := evaluate(t, `provider::mockname::error_param("foo")`)
		if !strings.Contains(diags.Error(), `Invalid function argument; Invalid value for "input" parameter: my error text.`) {
			t.Error(diags.Error())
		}
	})
}
// Standard scenario using root provider explicitly passed
func TestContext2Functions_providerFunctions(t *testing.T) {
p := testProvider("aws")
p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
Provider: providers.Schema{
Block: &configschema.Block{
Attributes: map[string]*configschema.Attribute{
"region": &configschema.Attribute{
Type: cty.String,
},
},
},
},
Functions: map[string]providers.FunctionSpec{
"arn_parse": providers.FunctionSpec{
Parameters: []providers.FunctionParameterSpec{{
Name: "arn",
Type: cty.String,
}},
Return: cty.Bool,
},
},
}
p.CallFunctionResponse = &providers.CallFunctionResponse{
Result: cty.True,
}
m := testModuleInline(t, map[string]string{
"main.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
provider "aws" {
region="us-east-1"
}
module "mod" {
source = "./mod"
providers = {
aws = aws
}
}
`,
"mod/mod.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
variable "obfmod" {
type = object({
arns = optional(list(string))
})
description = "Configuration for xxx."
validation {
condition = alltrue([
for arn in var.obfmod.arns: can(provider::aws::arn_parse(arn))
])
error_message = "All arns MUST BE a valid AWS ARN format."
}
default = {
arns = [
"arn:partition:service:region:account-id:resource-id",
]
}
}
`,
})
ctx := testContext2(t, &ContextOpts{
Providers: map[addrs.Provider]providers.Factory{
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
},
})
diags := ctx.Validate(context.Background(), m)
if diags.HasErrors() {
t.Fatal(diags.Err())
}
if !p.CallFunctionCalled {
t.Fatalf("Expected function call")
}
}
// TestContext2Functions_providerFunctionsCustom checks a function that the
// mock provider reports only through GetFunctions (arn_parse_custom). It is
// called from a child module that receives an aliased provider configuration
// (aws.primary). Validate must touch neither GetFunctions nor CallFunction;
// Plan must call both.
func TestContext2Functions_providerFunctionsCustom(t *testing.T) {
	p := testProvider("aws")
	p.GetFunctionsResponse = &providers.GetFunctionsResponse{
		Functions: map[string]providers.FunctionSpec{
			"arn_parse_custom": {
				Parameters: []providers.FunctionParameterSpec{{Name: "arn", Type: cty.String}},
				Return:     cty.Bool,
			},
		},
	}
	p.CallFunctionResponse = &providers.CallFunctionResponse{Result: cty.True}

	files := map[string]string{}
	files["main.tf"] = `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
provider "aws" {
region="us-east-1"
alias = "primary"
}
module "mod" {
source = "./mod"
providers = {
aws = aws.primary
}
}
`
	files["mod/mod.tf"] = `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
variable "obfmod" {
type = object({
arns = optional(list(string))
})
description = "Configuration for xxx."
validation {
condition = alltrue([
for arn in var.obfmod.arns: can(provider::aws::arn_parse_custom(arn))
])
error_message = "All arns MUST BE a valid AWS ARN format."
}
default = {
arns = [
"arn:partition:service:region:account-id:resource-id",
]
}
}
`
	m := testModuleInline(t, files)
	ctx := testContext2(t, &ContextOpts{
		Providers: map[addrs.Provider]providers.Factory{
			addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
		},
	})

	// Validate: neither provider function RPC may be used.
	diags := ctx.Validate(context.Background(), m)
	if diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	if p.GetFunctionsCalled || p.CallFunctionCalled {
		t.Fatalf("Unexpected function call")
	}

	// Plan: the function list must be fetched and the function invoked.
	p.GetFunctionsCalled, p.CallFunctionCalled = false, false
	if _, diags = ctx.Plan(context.Background(), m, nil, nil); diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	if !p.GetFunctionsCalled || !p.CallFunctionCalled {
		t.Fatalf("Expected function call")
	}
}
// TestContext2Functions_providerFunctionsStub checks a child module that calls
// provider::aws::arn_parse when the root module neither configures "aws" nor
// passes it to the module, so the provider is only implied. arn_parse is
// declared in the provider schema. Both Validate and Plan are therefore
// expected to load the schema and call the function without calling
// GetFunctions.
func TestContext2Functions_providerFunctionsStub(t *testing.T) {
	p := testProvider("aws")
	p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
		Functions: map[string]providers.FunctionSpec{
			"arn_parse": {
				Parameters: []providers.FunctionParameterSpec{{
					Name: "arn",
					Type: cty.String,
				}},
				Return: cty.Bool,
			},
		},
	}
	p.CallFunctionResponse = &providers.CallFunctionResponse{
		Result: cty.True,
	}
	m := testModuleInline(t, map[string]string{
		"main.tf": `
module "mod" {
source = "./mod"
}
`,
		"mod/mod.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
variable "obfmod" {
type = object({
arns = optional(list(string))
})
description = "Configuration for xxx."
validation {
condition = alltrue([
for arn in var.obfmod.arns: can(provider::aws::arn_parse(arn))
])
error_message = "All arns MUST BE a valid AWS ARN format."
}
default = {
arns = [
"arn:partition:service:region:account-id:resource-id",
]
}
}
`,
	})
	ctx := testContext2(t, &ContextOpts{
		Providers: map[addrs.Provider]providers.Factory{
			addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
		},
	})

	// assertCalls checks the mock's call flags after one operation. phase
	// names that operation in failure messages. The original messages were
	// inverted ("Unexpected" for a missing call and vice versa).
	assertCalls := func(phase string) {
		t.Helper()
		if !p.GetProviderSchemaCalled {
			t.Fatalf("%s: expected GetProviderSchema call", phase)
		}
		if p.GetFunctionsCalled {
			t.Fatalf("%s: unexpected GetFunctions call", phase)
		}
		if !p.CallFunctionCalled {
			t.Fatalf("%s: expected function call", phase)
		}
	}

	diags := ctx.Validate(context.Background(), m)
	if diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	assertCalls("Validate")

	p.GetProviderSchemaCalled = false
	p.GetFunctionsCalled = false
	p.CallFunctionCalled = false
	_, diags = ctx.Plan(context.Background(), m, nil, nil)
	if diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	assertCalls("Plan")
}
// Defaulted stub provider with custom function (no allowed)
func TestContext2Functions_providerFunctionsStubCustom(t *testing.T) {
p := testProvider("aws")
p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
Functions: map[string]providers.FunctionSpec{
"arn_parse": providers.FunctionSpec{
Parameters: []providers.FunctionParameterSpec{{
Name: "arn",
Type: cty.String,
}},
Return: cty.Bool,
},
},
}
p.CallFunctionResponse = &providers.CallFunctionResponse{
Result: cty.True,
}
m := testModuleInline(t, map[string]string{
"main.tf": `
module "mod" {
source = "./mod"
}
`,
"mod/mod.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
variable "obfmod" {
type = object({
arns = optional(list(string))
})
description = "Configuration for xxx."
validation {
condition = alltrue([
for arn in var.obfmod.arns: can(provider::aws::arn_parse_custom(arn))
])
error_message = "All arns MUST BE a valid AWS ARN format."
}
default = {
arns = [
"arn:partition:service:region:account-id:resource-id",
]
}
}
`,
})
ctx := testContext2(t, &ContextOpts{
Providers: map[addrs.Provider]providers.Factory{
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
},
})
_, diags := ctx.Plan(context.Background(), m, nil, nil)
if !diags.HasErrors() {
t.Fatal("Expected error!")
}
expected := `Function not found in provider: Function "provider::aws::arn_parse_custom" was not registered by provider`
if expected != diags.Err().Error() {
t.Fatalf("Expected error %q, got %q", expected, diags.Err().Error())
}
if !p.GetFunctionsCalled {
t.Fatalf("Expected function call")
}
if p.CallFunctionCalled {
t.Fatalf("Unexpected function call")
}
}
// TestContext2Functions_providerFunctionsForEachCount checks provider
// functions inside a module instantiated with for_each. Each module instance
// receives the matching instance of a for_each provider configuration
// (aws.iter[each.key]). arn_parse is declared in the provider schema, so both
// Validate and Plan are expected to load the schema and call the function
// without calling GetFunctions. Despite its name, this test only exercises
// for_each, not count.
func TestContext2Functions_providerFunctionsForEachCount(t *testing.T) {
	p := testProvider("aws")
	p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
		Functions: map[string]providers.FunctionSpec{
			"arn_parse": {
				Parameters: []providers.FunctionParameterSpec{{
					Name: "arn",
					Type: cty.String,
				}},
				Return: cty.Bool,
			},
		},
	}
	p.CallFunctionResponse = &providers.CallFunctionResponse{
		Result: cty.True,
	}
	m := testModuleInline(t, map[string]string{
		"main.tf": `
provider "aws" {
for_each = {"a": 1, "b": 2}
alias = "iter"
}
module "mod" {
source = "./mod"
for_each = {"a": 1, "b": 2}
providers = {
aws = aws.iter[each.key]
}
}
`,
		"mod/mod.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
variable "obfmod" {
type = object({
arns = optional(list(string))
})
description = "Configuration for xxx."
validation {
condition = alltrue([
for arn in var.obfmod.arns: can(provider::aws::arn_parse(arn))
])
error_message = "All arns MUST BE a valid AWS ARN format."
}
default = {
arns = [
"arn:partition:service:region:account-id:resource-id",
]
}
}
`,
	})
	ctx := testContext2(t, &ContextOpts{
		Providers: map[addrs.Provider]providers.Factory{
			addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
		},
	})

	// assertCalls checks the mock's call flags after one operation. phase
	// names that operation in failure messages. The original messages were
	// inverted ("Unexpected" for a missing call and vice versa).
	assertCalls := func(phase string) {
		t.Helper()
		if !p.GetProviderSchemaCalled {
			t.Fatalf("%s: expected GetProviderSchema call", phase)
		}
		if p.GetFunctionsCalled {
			t.Fatalf("%s: unexpected GetFunctions call", phase)
		}
		if !p.CallFunctionCalled {
			t.Fatalf("%s: expected function call", phase)
		}
	}

	diags := ctx.Validate(context.Background(), m)
	if diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	assertCalls("Validate")

	p.GetProviderSchemaCalled = false
	p.GetFunctionsCalled = false
	p.CallFunctionCalled = false
	_, diags = ctx.Plan(context.Background(), m, nil, nil)
	if diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	assertCalls("Plan")
}
// TestContext2Functions_providerFunctionsVariableCustom checks a provider
// function call used directly as a module input value. Module "mod" passes
// provider::aws::arn_parse_custom("foo") as the "value" argument of its child
// "mod2". The function is reported only through GetFunctions. Validate must
// not call GetFunctions or CallFunction; Plan must call both.
func TestContext2Functions_providerFunctionsVariableCustom(t *testing.T) {
	p := testProvider("aws")
	p.GetFunctionsResponse = &providers.GetFunctionsResponse{
		Functions: map[string]providers.FunctionSpec{
			"arn_parse_custom": {
				Parameters: []providers.FunctionParameterSpec{{Name: "arn", Type: cty.String}},
				Return:     cty.Bool,
			},
		},
	}
	p.CallFunctionResponse = &providers.CallFunctionResponse{Result: cty.True}

	sources := map[string]string{
		"main.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
provider "aws" {
region="us-east-1"
alias = "primary"
}
module "mod" {
source = "./mod"
providers = {
aws = aws.primary
}
}
`,
		"mod/mod.tf": `
terraform {
required_providers {
aws = ">=5.70.0"
}
}
module "mod2" {
source = "./mod2"
value = provider::aws::arn_parse_custom("foo")
}
`,
		"mod/mod2/mod.tf": `
variable "value" { }
`,
	}
	m := testModuleInline(t, sources)

	providerFactories := map[addrs.Provider]providers.Factory{
		addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
	}
	ctx := testContext2(t, &ContextOpts{Providers: providerFactories})

	// Validate: no provider function RPCs expected.
	if diags := ctx.Validate(context.Background(), m); diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	if p.GetFunctionsCalled || p.CallFunctionCalled {
		t.Fatalf("Unexpected function call")
	}

	// Plan: the function must be discovered and then invoked.
	p.GetFunctionsCalled, p.CallFunctionCalled = false, false
	if _, diags := ctx.Plan(context.Background(), m, nil, nil); diags.HasErrors() {
		t.Fatal(diags.Err())
	}
	if !p.GetFunctionsCalled || !p.CallFunctionCalled {
		t.Fatalf("Expected function call")
	}
}