Filter provider schemas for performance

Performance testing identified that decoding provider GetSchema
responses took a non-trivial amount of time and memory.

This removes the problematic global provider schema cache and instead
manages the primary cache as a field passed to the gRPC providers.

This, combined with better handling of the GetProviderSchemaOptional=false
case, dramatically reduces the overhead of running providers.

The primary downside to this approach as it currently stands is that a
new contextPlugins is created for each backend operation (validate,
plan, apply). This starts up, and then quickly shuts down, a provider
whose schema is already cached. In practice this is still a net
performance win; the redundant start/stop should be addressed in a
subsequent PR that overhauls the provider and provisioner management
interfaces more comprehensively.
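
For reference, a hedged sketch of the resulting call flow, using names as
they appear in this diff's backend/local changes (not a verbatim excerpt):

    // Build the schema filter from the configuration, widen it with any
    // resources tracked in state, then hand it to NewContext, which passes
    // it to ProvidersFn when the provider factories are constructed.
    reqs := config.ProviderSchemaRequirements()
    if state != nil {
        reqs.Merge(state.ProviderSchemaRequirements())
    }
    tfCtx, moreDiags := tofu.NewContext(coreOpts, reqs)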

Signed-off-by: Christian Mesh <christianmesh1@gmail.com>
Author: Christian Mesh
Date:   2025-07-10 17:48:07 -04:00
parent 868dc2f01b
commit 8f6a372d5e
26 changed files with 322 additions and 229 deletions

View File

@@ -0,0 +1,38 @@
package addrs
type ProviderResourceRequirments map[ResourceMode]map[string]struct{}
type ProviderSchemaRequirements map[Provider]ProviderResourceRequirments
func (p ProviderSchemaRequirements) AddResource(provider Provider, mode ResourceMode, typ string) {
pm, ok := p[provider]
if !ok {
pm = make(ProviderResourceRequirments)
p[provider] = pm
}
mm, ok := pm[mode]
if !ok {
mm = make(map[string]struct{})
pm[mode] = mm
}
mm[typ] = struct{}{}
}
func (s ProviderResourceRequirments) HasResource(mode ResourceMode, typ string) bool {
if s == nil {
// Legacy path
return true
}
_, ok := s[mode][typ]
return ok
}
func (p ProviderSchemaRequirements) Merge(other ProviderSchemaRequirements) {
for provider, pm := range other {
for mode, mm := range pm {
for typ := range mm {
p.AddResource(provider, mode, typ)
}
}
}
}
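
A minimal usage sketch of the types above (NewDefaultProvider is borrowed
from elsewhere in this package; note that a nil ProviderResourceRequirments
deliberately reports true for everything, covering the legacy unfiltered
path):

    reqs := make(ProviderSchemaRequirements)
    aws := NewDefaultProvider("aws")
    reqs.AddResource(aws, ManagedResourceMode, "aws_instance")

    reqs[aws].HasResource(ManagedResourceMode, "aws_instance") // true
    reqs[aws].HasResource(DataResourceMode, "aws_ami")         // false

    var legacy ProviderResourceRequirments // nil: no filtering requested
    legacy.HasResource(DataResourceMode, "aws_ami") // true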

View File

@@ -227,7 +227,12 @@ func (b *Local) localRunDirect(ctx context.Context, op *backend.Operation, run *
}
run.InputState = state
tfCtx, moreDiags := tofu.NewContext(coreOpts)
reqs := config.ProviderSchemaRequirements()
if state != nil {
reqs.Merge(state.ProviderSchemaRequirements())
}
tfCtx, moreDiags := tofu.NewContext(coreOpts, reqs)
diags = diags.Append(moreDiags)
if moreDiags.HasErrors() {
return nil, nil, diags
@@ -388,7 +393,12 @@ func (b *Local) localRunForPlanFile(ctx context.Context, op *backend.Operation,
// refreshing we did while building the plan.
run.InputState = priorStateFile.State
tfCtx, moreDiags := tofu.NewContext(coreOpts)
reqs := config.ProviderSchemaRequirements()
if run.InputState != nil {
reqs.Merge(run.InputState.ProviderSchemaRequirements())
}
tfCtx, moreDiags := tofu.NewContext(coreOpts, reqs)
diags = diags.Append(moreDiags)
if moreDiags.HasErrors() {
return nil, nil, diags

View File

@@ -150,7 +150,7 @@ func (b *Remote) LocalRun(ctx context.Context, op *backend.Operation) (*backend.
}
}
tfCtx, ctxDiags := tofu.NewContext(&opts)
tfCtx, ctxDiags := tofu.NewContext(&opts, nil)
diags = diags.Append(ctxDiags)
ret.Core = tfCtx

View File

@@ -151,7 +151,7 @@ func (b *Cloud) LocalRun(ctx context.Context, op *backend.Operation) (*backend.L
}
}
tfCtx, ctxDiags := tofu.NewContext(&opts)
tfCtx, ctxDiags := tofu.NewContext(&opts, nil)
diags = diags.Append(ctxDiags)
ret.Core = tfCtx

View File

@@ -586,9 +586,7 @@ func (m *Meta) contextOpts(ctx context.Context) (*tofu.ContextOpts, error) {
opts.Providers = m.testingOverrides.Providers
opts.Provisioners = m.testingOverrides.Provisioners
} else {
var providerFactories map[addrs.Provider]providers.Factory
providerFactories, err = m.providerFactories()
opts.Providers = providerFactories
opts.ProvidersFn = m.providerFactories
opts.Provisioners = m.provisionerFactories()
}
@@ -949,7 +947,7 @@ func (c *Meta) MaybeGetSchemas(ctx context.Context, state *states.State, config
diags = diags.Append(err)
return nil, diags
}
tfCtx, ctxDiags := tofu.NewContext(opts)
tfCtx, ctxDiags := tofu.NewContext(opts, nil)
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
return nil, diags

View File

@@ -233,7 +233,7 @@ func (m *Meta) providerDevOverrideRuntimeWarnings() tfdiags.Diagnostics {
// package have been modified outside of the installer. If it returns an error,
// the returned map may be incomplete or invalid, but will be as complete
// as possible given the cause of the error.
func (m *Meta) providerFactories() (map[addrs.Provider]providers.Factory, error) {
func (m *Meta) providerFactories(reqs addrs.ProviderSchemaRequirements) (map[addrs.Provider]providers.Factory, error) {
locks, diags := m.lockedDependencies()
if diags.HasErrors() {
return nil, fmt.Errorf("failed to read dependency lock file: %w", diags.Err())
@@ -276,6 +276,20 @@ func (m *Meta) providerFactories() (map[addrs.Provider]providers.Factory, error)
devOverrideProviders := m.ProviderDevOverrides
unmanagedProviders := m.UnmanagedProviders
providerReqs := func(provider addrs.Provider) addrs.ProviderResourceRequirments {
// Legacy path that does not support filtering
if reqs == nil {
// Unfiltered
return nil
}
if req, ok := reqs[provider]; ok {
return req
}
// Filtered (empty)
return addrs.ProviderResourceRequirments{}
}
factories := make(map[addrs.Provider]providers.Factory, len(providerLocks)+len(internalFactories)+len(unmanagedProviders))
for name, factory := range internalFactories {
factories[addrs.NewBuiltInProvider(name)] = factory
@@ -324,13 +338,13 @@ func (m *Meta) providerFactories() (map[addrs.Provider]providers.Factory, error)
continue
}
}
factories[provider] = providerFactory(cached)
factories[provider] = providerFactory(cached, providerReqs(provider))
}
for provider, localDir := range devOverrideProviders {
factories[provider] = devOverrideProviderFactory(provider, localDir)
factories[provider] = devOverrideProviderFactory(provider, localDir, providerReqs(provider))
}
for provider, reattach := range unmanagedProviders {
factories[provider] = unmanagedProviderFactory(provider, reattach)
factories[provider] = unmanagedProviderFactory(provider, reattach, providerReqs(provider))
}
var err error
@@ -351,7 +365,11 @@ func (m *Meta) internalProviders() map[string]providers.Factory {
// providerFactory produces a provider factory that runs up the executable
// file in the given cache package and uses go-plugin to implement
// providers.Interface against it.
func providerFactory(meta *providercache.CachedProvider) providers.Factory {
func providerFactory(meta *providercache.CachedProvider, reqs addrs.ProviderResourceRequirments) providers.Factory {
// Same schema for each instance
schema := &providers.CachedSchema{
Filter: reqs,
}
return func() (providers.Interface, error) {
execFile, err := meta.ExecutableFile()
if err != nil {
@@ -382,7 +400,7 @@ func providerFactory(meta *providercache.CachedProvider) providers.Factory {
}
protoVer := client.NegotiatedVersion()
p, err := initializeProviderInstance(raw, protoVer, client, meta.Provider)
p, err := initializeProviderInstance(raw, protoVer, client, meta.Provider, schema)
if errors.Is(err, errUnsupportedProtocolVersion) {
panic(err)
}
@@ -393,25 +411,27 @@ func providerFactory(meta *providercache.CachedProvider) providers.Factory {
// initializeProviderInstance uses the plugin dispensed by the RPC client, and initializes a plugin instance
// per the protocol version
func initializeProviderInstance(plugin interface{}, protoVer int, pluginClient *plugin.Client, pluginAddr addrs.Provider) (providers.Interface, error) {
func initializeProviderInstance(plugin interface{}, protoVer int, pluginClient *plugin.Client, pluginAddr addrs.Provider, schema *providers.CachedSchema) (providers.Interface, error) {
// store the client so that the plugin can kill the child process
switch protoVer {
case 5:
p := plugin.(*tfplugin.GRPCProvider)
p.PluginClient = pluginClient
p.Addr = pluginAddr
p.Schema = schema
return p, nil
case 6:
p := plugin.(*tfplugin6.GRPCProvider)
p.PluginClient = pluginClient
p.Addr = pluginAddr
p.Schema = schema
return p, nil
default:
return nil, errUnsupportedProtocolVersion
}
}
func devOverrideProviderFactory(provider addrs.Provider, localDir getproviders.PackageLocalDir) providers.Factory {
func devOverrideProviderFactory(provider addrs.Provider, localDir getproviders.PackageLocalDir, reqs addrs.ProviderResourceRequirments) providers.Factory {
// A dev override is essentially a synthetic cache entry for our purposes
// here, so that's how we'll construct it. The providerFactory function
// doesn't actually care about the version, so we can leave it
@@ -421,13 +441,17 @@ func devOverrideProviderFactory(provider addrs.Provider, localDir getproviders.P
Provider: provider,
Version: getproviders.UnspecifiedVersion,
PackageDir: string(localDir),
})
}, reqs)
}
// unmanagedProviderFactory produces a provider factory that uses the passed
// reattach information to connect to go-plugin processes that are already
// running, and implements providers.Interface against it.
func unmanagedProviderFactory(provider addrs.Provider, reattach *plugin.ReattachConfig) providers.Factory {
func unmanagedProviderFactory(provider addrs.Provider, reattach *plugin.ReattachConfig, reqs addrs.ProviderResourceRequirments) providers.Factory {
// Same schema for each instance
schema := &providers.CachedSchema{
Filter: reqs,
}
return func() (providers.Interface, error) {
config := &plugin.ClientConfig{
HandshakeConfig: tfplugin.Handshake,
@@ -477,7 +501,7 @@ func unmanagedProviderFactory(provider addrs.Provider, reattach *plugin.Reattach
protoVer = 5
}
return initializeProviderInstance(raw, protoVer, client, provider)
return initializeProviderInstance(raw, protoVer, client, provider, schema)
}
}
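
The providerReqs closure above is deliberately three-valued; restated as a
standalone sketch (filterFor is a hypothetical name, the logic is copied
from the closure):

    // filterFor restates providerReqs for a single provider address.
    func filterFor(reqs addrs.ProviderSchemaRequirements, p addrs.Provider) addrs.ProviderResourceRequirments {
        if reqs == nil {
            // Legacy caller passed no requirements: nil filter, decode everything.
            return nil
        }
        if req, ok := reqs[p]; ok {
            // Decode only the resource types the config and state reference.
            return req
        }
        // Provider installed but unused: empty filter, decode nothing.
        return addrs.ProviderResourceRequirments{}
    }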

View File

@@ -701,7 +701,7 @@ func (runner *TestFileRunner) validate(ctx context.Context, config *configs.Conf
var diags tfdiags.Diagnostics
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts)
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts, config.ProviderSchemaRequirements())
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
return diags
@@ -756,7 +756,7 @@ func (runner *TestFileRunner) destroy(ctx context.Context, config *configs.Confi
SetVariables: variables,
}
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts)
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts, config.ProviderSchemaRequirements())
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
return state, diags
@@ -833,7 +833,7 @@ func (runner *TestFileRunner) plan(ctx context.Context, config *configs.Config,
ExternalReferences: references,
}
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts)
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts, config.ProviderSchemaRequirements())
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
return nil, nil, diags
@@ -888,7 +888,7 @@ func (runner *TestFileRunner) apply(ctx context.Context, plan *plans.Plan, state
created = append(created, change)
}
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts)
tfCtx, ctxDiags := tofu.NewContext(runner.Suite.Opts, config.ProviderSchemaRequirements())
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
return nil, state, diags

View File

@@ -113,7 +113,7 @@ func (c *ValidateCommand) validate(ctx context.Context, dir, testDir string, noT
return diags
}
tfCtx, ctxDiags := tofu.NewContext(opts)
tfCtx, ctxDiags := tofu.NewContext(opts, cfg.ProviderSchemaRequirements())
diags = diags.Append(ctxDiags)
if ctxDiags.HasErrors() {
return diags

View File

@@ -817,6 +817,35 @@ func (c *Config) ProviderTypes() []addrs.Provider {
return ret
}
func (c *Config) ProviderSchemaRequirements() addrs.ProviderSchemaRequirements {
result := make(addrs.ProviderSchemaRequirements)
for _, r := range c.Module.ManagedResources {
result.AddResource(r.Provider, r.Mode, r.Type)
}
for _, d := range c.Module.DataResources {
result.AddResource(d.Provider, d.Mode, d.Type)
}
for _, c := range c.Module.Checks {
d := c.DataResource
if d != nil {
result.AddResource(d.Provider, d.Mode, d.Type)
}
}
/* TODO
Import []*Import
Removed []*Removed
*/
for _, ch := range c.Children {
result.Merge(ch.ProviderSchemaRequirements())
}
return result
}
// ResolveAbsProviderAddr returns the AbsProviderConfig represented by the given
// ProviderConfig address, which must not be nil or this method will panic.
//

View File

@@ -9,7 +9,6 @@ import (
"context"
"errors"
"fmt"
"sync"
plugin "github.com/hashicorp/go-plugin"
"github.com/zclconf/go-cty/cty"
@@ -87,42 +86,25 @@ type GRPCProvider struct {
// to use as the parent context for gRPC API calls.
ctx context.Context
mu sync.Mutex
// schema stores the schema for this provider. This is used to properly
// serialize the requests for schemas.
schema providers.GetProviderSchemaResponse
Schema *providers.CachedSchema
hasFetchedSchemas bool
}
var _ providers.Interface = new(GRPCProvider)
func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.GetProviderSchemaResponse) {
logger.Trace("GRPCProvider: GetProviderSchema")
p.mu.Lock()
defer p.mu.Unlock()
// First, we check the global cache.
// The cache could contain this schema if an instance of this provider has previously been started.
if !p.Addr.IsZero() {
// Even if the schema is cached, GetProviderSchemaOptional could be false. This would indicate that once instantiated,
// this provider requires the get schema call to be made at least once, as it handles part of the provider's setup.
// At this point, we don't know if this is the first call to a provider instance or not, so we don't use the result in that case.
if schemaCached, ok := providers.SchemaCache.Get(p.Addr); ok && schemaCached.ServerCapabilities.GetProviderSchemaOptional {
logger.Trace("GRPCProvider: GetProviderSchema: serving from global schema cache", "address", p.Addr)
return schemaCached
}
p.Schema.Lock()
defer p.Schema.Unlock()
// Check to see if the schema cache has been populated AND we are allowed to use it. Some providers require GetProviderSchema to be called on startup.
if p.Schema.Value != nil && (p.Schema.Value.ServerCapabilities.GetProviderSchemaOptional || p.hasFetchedSchemas) {
logger.Trace("GRPCProvider: GetProviderSchema: serving from schema cache", "address", p.Addr)
return *p.Schema.Value
}
// If the local cache is non-zero, we know this instance has called
// GetProviderSchema at least once, so has satisfied the possible requirement of `GetProviderSchemaOptional=false`.
// This means that we can return early now using the locally cached schema, without making this call again.
if p.schema.Provider.Block != nil {
return p.schema
}
resp.ResourceTypes = make(map[string]providers.Schema)
resp.DataSources = make(map[string]providers.Schema)
resp.Functions = make(map[string]providers.FunctionSpec)
// Some providers may generate quite large schemas, and the internal default
// grpc response size limit is 4MB. 64MB should cover most any use case, and
// if we get providers nearing that we may want to consider a finer-grained
@@ -138,6 +120,18 @@ func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.Ge
return resp
}
// Mark that this instance of the provider has fetched schemas
p.hasFetchedSchemas = true
// Check to see if the schema cache has been populated AND we are allowed to use it. We can skip decode if it's already been done by another instance.
if p.Schema.Value != nil {
logger.Trace("GRPCProvider: GetProviderSchema: serving from schema cache", "address", p.Addr)
return *p.Schema.Value
}
resp.ResourceTypes = make(map[string]providers.Schema)
resp.DataSources = make(map[string]providers.Schema)
resp.Functions = make(map[string]providers.FunctionSpec)
resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics))
if resp.Diagnostics.HasErrors() {
@@ -157,11 +151,15 @@ func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.Ge
}
for name, res := range protoResp.ResourceSchemas {
resp.ResourceTypes[name] = convert.ProtoToProviderSchema(res)
if p.Schema.Filter.HasResource(addrs.ManagedResourceMode, name) {
resp.ResourceTypes[name] = convert.ProtoToProviderSchema(res)
}
}
for name, data := range protoResp.DataSourceSchemas {
resp.DataSources[name] = convert.ProtoToProviderSchema(data)
if p.Schema.Filter.HasResource(addrs.DataResourceMode, name) {
resp.DataSources[name] = convert.ProtoToProviderSchema(data)
}
}
for name, fn := range protoResp.Functions {
@@ -173,22 +171,7 @@ func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.Ge
resp.ServerCapabilities.GetProviderSchemaOptional = protoResp.ServerCapabilities.GetProviderSchemaOptional
}
// Set the global provider cache so that future calls to this provider can use the cached value.
// Crucially, this doesn't look at GetProviderSchemaOptional, because the layers above could use this cache
// *without* creating an instance of this provider. And if there is no instance,
// then we don't need to set up anything (cause there is nothing to set up), so we need no call
// to the providers GetSchema rpc.
if !p.Addr.IsZero() {
providers.SchemaCache.Set(p.Addr, resp)
}
// Always store this here in the client for providers that are not able to use GetProviderSchemaOptional.
// Crucially, this indicates that we've made at least one call to GetProviderSchema to this instance of the provider,
// which means in the future we'll be able to return using this cache
// (because the possible setup contained in the GetProviderSchema call has happened).
// If GetProviderSchemaOptional is true then this cache won't actually ever be used, because the calls to this method
// will be satisfied by the global provider cache.
p.schema = resp
p.Schema.Value = &resp
return resp
}
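
The cache-serving condition above is easier to audit pulled out on its own.
A sketch that mirrors it (canServeFromCache is a hypothetical helper, not
part of this change):

    // A cached response may be reused only if it exists and either the
    // provider declared GetProviderSchemaOptional, or this instance has
    // already made the (possibly setup-performing) GetProviderSchema call.
    func canServeFromCache(cached *providers.ProviderSchema, hasFetchedSchemas bool) bool {
        return cached != nil &&
            (cached.ServerCapabilities.GetProviderSchemaOptional || hasFetchedSchemas)
    }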

View File

@@ -121,6 +121,7 @@ func providerProtoSchema() *proto.GetProviderSchema_Response {
func TestGRPCProvider_GetSchema(t *testing.T) {
p := &GRPCProvider{
client: mockProviderClient(t),
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -141,6 +142,7 @@ func TestGRPCProvider_GetSchema_GRPCError(t *testing.T) {
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -151,8 +153,8 @@ func TestGRPCProvider_GetSchema_GRPCError(t *testing.T) {
func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
ctrl := gomock.NewController(t)
client := mockproto.NewMockProviderClient(ctrl)
// The SchemaCache is global and is saved between test runs
providers.SchemaCache = providers.NewMockSchemaCache()
schemaCache := &providers.CachedSchema{}
providerAddr := addrs.Provider{
Namespace: "namespace",
@@ -175,6 +177,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
p := &GRPCProvider{
client: client,
Addr: providerAddr,
Schema: schemaCache,
}
resp := p.GetProviderSchema(t.Context())
@@ -186,6 +189,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
p = &GRPCProvider{
client: client,
Addr: providerAddr,
Schema: schemaCache,
}
resp = p.GetProviderSchema(t.Context())
@@ -198,8 +202,6 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
func TestGRPCProvider_GetSchema_GlobalCacheDisabled(t *testing.T) {
ctrl := gomock.NewController(t)
client := mockproto.NewMockProviderClient(ctrl)
// The SchemaCache is global and is saved between test runs
providers.SchemaCache = providers.NewMockSchemaCache()
providerAddr := addrs.Provider{
Namespace: "namespace",
@@ -222,6 +224,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheDisabled(t *testing.T) {
p := &GRPCProvider{
client: client,
Addr: providerAddr,
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -233,6 +236,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheDisabled(t *testing.T) {
p = &GRPCProvider{
client: client,
Addr: providerAddr,
Schema: &providers.CachedSchema{},
}
resp = p.GetProviderSchema(t.Context())
@@ -266,6 +270,7 @@ func TestGRPCProvider_GetSchema_ResponseErrorDiagnostic(t *testing.T) {
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -277,6 +282,7 @@ func TestGRPCProvider_PrepareProviderConfig(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().PrepareProviderConfig(
@@ -293,6 +299,7 @@ func TestGRPCProvider_ValidateResourceConfig(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ValidateResourceTypeConfig(
@@ -312,6 +319,7 @@ func TestGRPCProvider_ValidateDataSourceConfig(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ValidateDataSourceConfig(
@@ -331,6 +339,7 @@ func TestGRPCProvider_UpgradeResourceState(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().UpgradeResourceState(
@@ -362,6 +371,7 @@ func TestGRPCProvider_UpgradeResourceStateJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().UpgradeResourceState(
@@ -393,6 +403,7 @@ func TestGRPCProvider_MoveResourceState(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().MoveResourceState(
@@ -429,6 +440,7 @@ func TestGRPCProvider_Configure(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().Configure(
@@ -449,6 +461,7 @@ func TestGRPCProvider_Stop(t *testing.T) {
client := mockproto.NewMockProviderClient(ctrl)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().Stop(
@@ -466,6 +479,7 @@ func TestGRPCProvider_ReadResource(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadResource(
@@ -499,6 +513,7 @@ func TestGRPCProvider_ReadResourceJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadResource(
@@ -532,6 +547,7 @@ func TestGRPCProvider_ReadEmptyJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadResource(
@@ -564,6 +580,7 @@ func TestGRPCProvider_PlanResourceChange(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -627,6 +644,7 @@ func TestGRPCProvider_PlanResourceChange_deferred(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().PlanResourceChange(
@@ -673,6 +691,7 @@ func TestGRPCProvider_PlanResourceChangeJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -736,6 +755,7 @@ func TestGRPCProvider_ApplyResourceChange(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -782,6 +802,7 @@ func TestGRPCProvider_ApplyResourceChangeJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -829,6 +850,7 @@ func TestGRPCProvider_ImportResourceState(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -872,6 +894,7 @@ func TestGRPCProvider_ImportResourceStateJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -916,6 +939,7 @@ func TestGRPCProvider_ReadDataSource(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadDataSource(
@@ -949,6 +973,7 @@ func TestGRPCProvider_ReadDataSourceJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadDataSource(
@@ -982,6 +1007,7 @@ func TestGRPCProvider_CallFunction(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().CallFunction(

View File

@@ -9,7 +9,6 @@ import (
"context"
"errors"
"fmt"
"sync"
plugin "github.com/hashicorp/go-plugin"
"github.com/zclconf/go-cty/cty"
@@ -87,42 +86,24 @@ type GRPCProvider struct {
// to use as the parent context for gRPC API calls.
ctx context.Context
mu sync.Mutex
// schema stores the schema for this provider. This is used to properly
// serialize the requests for schemas.
schema providers.GetProviderSchemaResponse
Schema *providers.CachedSchema
hasFetchedSchemas bool
}
var _ providers.Interface = new(GRPCProvider)
func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.GetProviderSchemaResponse) {
logger.Trace("GRPCProvider.v6: GetProviderSchema")
p.mu.Lock()
defer p.mu.Unlock()
p.Schema.Lock()
defer p.Schema.Unlock()
// First, we check the global cache.
// The cache could contain this schema if an instance of this provider has previously been started.
if !p.Addr.IsZero() {
// Even if the schema is cached, GetProviderSchemaOptional could be false. This would indicate that once instantiated,
// this provider requires the get schema call to be made at least once, as it handles part of the provider's setup.
// At this point, we don't know if this is the first call to a provider instance or not, so we don't use the result in that case.
if schemaCached, ok := providers.SchemaCache.Get(p.Addr); ok && schemaCached.ServerCapabilities.GetProviderSchemaOptional {
logger.Trace("GRPCProvider: GetProviderSchema: serving from global schema cache", "address", p.Addr)
return schemaCached
}
// Check to see if the schema cache has been populated AND we are allowed to use it. Some providers require GetProviderSchema to be called on startup.
if p.Schema.Value != nil && (p.Schema.Value.ServerCapabilities.GetProviderSchemaOptional || p.hasFetchedSchemas) {
logger.Trace("GRPCProvider: GetProviderSchema: serving from schema cache", "address", p.Addr)
return *p.Schema.Value
}
// If the local cache is non-zero, we know this instance has called
// GetProviderSchema at least once, so has satisfied the possible requirement of `GetProviderSchemaOptional=false`.
// This means that we can return early now using the locally cached schema, without making this call again.
if p.schema.Provider.Block != nil {
return p.schema
}
resp.ResourceTypes = make(map[string]providers.Schema)
resp.DataSources = make(map[string]providers.Schema)
resp.Functions = make(map[string]providers.FunctionSpec)
// Some providers may generate quite large schemas, and the internal default
// grpc response size limit is 4MB. 64MB should cover most any use case, and
// if we get providers nearing that we may want to consider a finer-grained
@@ -138,6 +119,18 @@ func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.Ge
return resp
}
// Mark that this instance of the provider has fetched schemas
p.hasFetchedSchemas = true
// Check to see if the schema cache has been populated AND we are allowed to use it. We can skip decode if it's already been done by another instance.
if p.Schema.Value != nil {
logger.Trace("GRPCProvider: GetProviderSchema: serving from schema cache", "address", p.Addr)
return *p.Schema.Value
}
resp.ResourceTypes = make(map[string]providers.Schema)
resp.DataSources = make(map[string]providers.Schema)
resp.Functions = make(map[string]providers.FunctionSpec)
resp.Diagnostics = resp.Diagnostics.Append(convert.ProtoToDiagnostics(protoResp.Diagnostics))
if resp.Diagnostics.HasErrors() {
@@ -157,11 +150,15 @@ func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.Ge
}
for name, res := range protoResp.ResourceSchemas {
resp.ResourceTypes[name] = convert.ProtoToProviderSchema(res)
if p.Schema.Filter.HasResource(addrs.ManagedResourceMode, name) {
resp.ResourceTypes[name] = convert.ProtoToProviderSchema(res)
}
}
for name, data := range protoResp.DataSourceSchemas {
resp.DataSources[name] = convert.ProtoToProviderSchema(data)
if p.Schema.Filter.HasResource(addrs.DataResourceMode, name) {
resp.DataSources[name] = convert.ProtoToProviderSchema(data)
}
}
for name, fn := range protoResp.Functions {
@@ -173,22 +170,7 @@ func (p *GRPCProvider) GetProviderSchema(ctx context.Context) (resp providers.Ge
resp.ServerCapabilities.GetProviderSchemaOptional = protoResp.ServerCapabilities.GetProviderSchemaOptional
}
// Set the global provider cache so that future calls to this provider can use the cached value.
// Crucially, this doesn't look at GetProviderSchemaOptional, because the layers above could use this cache
// *without* creating an instance of this provider. And if there is no instance,
// then we don't need to set up anything (cause there is nothing to set up), so we need no call
// to the providers GetSchema rpc.
if !p.Addr.IsZero() {
providers.SchemaCache.Set(p.Addr, resp)
}
// Always store this here in the client for providers that are not able to use GetProviderSchemaOptional.
// Crucially, this indicates that we've made at least one call to GetProviderSchema to this instance of the provider,
// which means in the future we'll be able to return using this cache
// (because the possible setup contained in the GetProviderSchema call has happened).
// If GetProviderSchemaOptional is true then this cache won't actually ever be used, because the calls to this method
// will be satisfied by the global provider cache.
p.schema = resp
p.Schema.Value = &resp
return resp
}

View File

@@ -128,6 +128,7 @@ func providerProtoSchema() *proto.GetProviderSchema_Response {
func TestGRPCProvider_GetSchema(t *testing.T) {
p := &GRPCProvider{
client: mockProviderClient(t),
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -148,6 +149,7 @@ func TestGRPCProvider_GetSchema_GRPCError(t *testing.T) {
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -179,6 +181,7 @@ func TestGRPCProvider_GetSchema_ResponseErrorDiagnostic(t *testing.T) {
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
resp := p.GetProviderSchema(t.Context())
@@ -190,7 +193,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
ctrl := gomock.NewController(t)
client := mockproto.NewMockProviderClient(ctrl)
// The SchemaCache is global and is saved between test runs
providers.SchemaCache = providers.NewMockSchemaCache()
schemaCache := &providers.CachedSchema{}
providerAddr := addrs.Provider{
Namespace: "namespace",
@@ -212,6 +215,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
// Re-initialize the provider before each run to avoid usage of the local cache
p := &GRPCProvider{
client: client,
Schema: schemaCache,
Addr: providerAddr,
}
resp := p.GetProviderSchema(t.Context())
@@ -223,6 +227,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
p = &GRPCProvider{
client: client,
Schema: schemaCache,
Addr: providerAddr,
}
resp = p.GetProviderSchema(t.Context())
@@ -236,8 +241,6 @@ func TestGRPCProvider_GetSchema_GlobalCacheEnabled(t *testing.T) {
func TestGRPCProvider_GetSchema_GlobalCacheDisabled(t *testing.T) {
ctrl := gomock.NewController(t)
client := mockproto.NewMockProviderClient(ctrl)
// The SchemaCache is global and is saved between test runs
providers.SchemaCache = providers.NewMockSchemaCache()
providerAddr := addrs.Provider{
Namespace: "namespace",
@@ -259,6 +262,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheDisabled(t *testing.T) {
// Re-initialize the provider before each run to avoid usage of the local cache
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
Addr: providerAddr,
}
resp := p.GetProviderSchema(t.Context())
@@ -270,6 +274,7 @@ func TestGRPCProvider_GetSchema_GlobalCacheDisabled(t *testing.T) {
p = &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
Addr: providerAddr,
}
resp = p.GetProviderSchema(t.Context())
@@ -284,6 +289,7 @@ func TestGRPCProvider_PrepareProviderConfig(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ValidateProviderConfig(
@@ -300,6 +306,7 @@ func TestGRPCProvider_ValidateResourceConfig(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ValidateResourceConfig(
@@ -319,6 +326,7 @@ func TestGRPCProvider_ValidateDataResourceConfig(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ValidateDataResourceConfig(
@@ -338,6 +346,7 @@ func TestGRPCProvider_UpgradeResourceState(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().UpgradeResourceState(
@@ -369,6 +378,7 @@ func TestGRPCProvider_UpgradeResourceStateJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().UpgradeResourceState(
@@ -400,6 +410,7 @@ func TestGRPCProvider_MoveResourceState(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().MoveResourceState(
@@ -435,6 +446,7 @@ func TestGRPCProvider_Configure(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ConfigureProvider(
@@ -455,6 +467,7 @@ func TestGRPCProvider_Stop(t *testing.T) {
client := mockproto.NewMockProviderClient(ctrl)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().StopProvider(
@@ -472,6 +485,7 @@ func TestGRPCProvider_ReadResource(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadResource(
@@ -505,6 +519,7 @@ func TestGRPCProvider_ReadResourceJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadResource(
@@ -538,6 +553,7 @@ func TestGRPCProvider_ReadEmptyJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadResource(
@@ -570,6 +586,7 @@ func TestGRPCProvider_PlanResourceChange(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -633,6 +650,7 @@ func TestGRPCProvider_PlanResourceChange_deferred(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().PlanResourceChange(
@@ -679,6 +697,7 @@ func TestGRPCProvider_PlanResourceChangeJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -742,6 +761,7 @@ func TestGRPCProvider_ApplyResourceChange(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -788,6 +808,7 @@ func TestGRPCProvider_ApplyResourceChangeJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -835,6 +856,7 @@ func TestGRPCProvider_ImportResourceState(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -878,6 +900,7 @@ func TestGRPCProvider_ImportResourceStateJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
expectedPrivate := []byte(`{"meta": "data"}`)
@@ -922,6 +945,7 @@ func TestGRPCProvider_ReadDataSource(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadDataSource(
@@ -955,6 +979,7 @@ func TestGRPCProvider_ReadDataSourceJSON(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().ReadDataSource(
@@ -988,6 +1013,7 @@ func TestGRPCProvider_CallFunction(t *testing.T) {
client := mockProviderClient(t)
p := &GRPCProvider{
client: client,
Schema: &providers.CachedSchema{},
}
client.EXPECT().CallFunction(

View File

@@ -1,9 +1 @@
package providers
import "github.com/opentofu/opentofu/internal/addrs"
func NewMockSchemaCache() *schemaCache {
return &schemaCache{
m: make(map[addrs.Provider]ProviderSchema),
}
}

View File

@@ -11,39 +11,8 @@ import (
"github.com/opentofu/opentofu/internal/addrs"
)
// SchemaCache is a global cache of Schemas.
// This will be accessed by both core and the provider clients to ensure that
// large schemas are stored in a single location.
var SchemaCache = &schemaCache{
m: make(map[addrs.Provider]ProviderSchema),
}
// Global cache for provider schemas
// Cache the entire response to ensure we capture any new fields, like
// ServerCapabilities. This also serves to capture errors so that multiple
// concurrent calls resulting in an error can be handled in the same manner.
type schemaCache struct {
mu sync.Mutex
m map[addrs.Provider]ProviderSchema
}
func (c *schemaCache) Set(p addrs.Provider, s ProviderSchema) {
c.mu.Lock()
defer c.mu.Unlock()
c.m[p] = s
}
func (c *schemaCache) Get(p addrs.Provider) (ProviderSchema, bool) {
c.mu.Lock()
defer c.mu.Unlock()
s, ok := c.m[p]
return s, ok
}
func (c *schemaCache) Remove(p addrs.Provider) {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.m, p)
type CachedSchema struct {
sync.Mutex
Filter addrs.ProviderResourceRequirments
Value *ProviderSchema
}
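
The intended pattern is one CachedSchema per provider address, created once
in the factory and shared by every dispensed instance; a sketch, where
newInstance is a hypothetical stand-in for the go-plugin startup shown
earlier in this diff:

    schema := &providers.CachedSchema{Filter: reqs} // one per provider
    factory := func() (providers.Interface, error) {
        p, err := newInstance() // start or reattach the plugin process
        if err != nil {
            return nil, err
        }
        p.Schema = schema // all instances decode into, and read, the same cache
        return p, nil
    }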

View File

@@ -289,7 +289,7 @@ func testSession(t *testing.T, test testSessionTest) {
Providers: map[addrs.Provider]providers.Factory{
addrs.NewDefaultProvider("test"): providers.FactoryFixed(p),
},
})
}, nil)
if diags.HasErrors() {
t.Fatalf("failed to create context: %s", diags.Err())
}

View File

@@ -368,6 +368,20 @@ func (s *State) ProviderRequirements() getproviders.Requirements {
return ret
}
func (s *State) ProviderSchemaRequirements() addrs.ProviderSchemaRequirements {
if s == nil {
return nil
}
m := make(addrs.ProviderSchemaRequirements)
for _, ms := range s.Modules {
for _, rc := range ms.Resources {
m.AddResource(rc.ProviderConfig.Provider, addrs.ManagedResourceMode, rc.Addr.Resource.Type)
}
}
return m
}
// PruneResourceHusks is a specialized method that will remove any Resource
// objects that do not contain any instances, even if they have an EachMode.
//

View File

@@ -44,6 +44,7 @@ type ContextOpts struct {
Hooks []Hook
Parallelism int
Providers map[addrs.Provider]providers.Factory
ProvidersFn func(addrs.ProviderSchemaRequirements) (map[addrs.Provider]providers.Factory, error)
Provisioners map[string]provisioners.Factory
Encryption encryption.Encryption
@@ -105,7 +106,7 @@ type Context struct {
//
// If the returned diagnostics contains errors then the resulting context is
// invalid and must not be used.
func NewContext(opts *ContextOpts) (*Context, tfdiags.Diagnostics) {
func NewContext(opts *ContextOpts, reqs addrs.ProviderSchemaRequirements) (*Context, tfdiags.Diagnostics) {
var diags tfdiags.Diagnostics
log.Printf("[TRACE] tofu.NewContext: starting")
@@ -135,6 +136,12 @@ func NewContext(opts *ContextOpts) (*Context, tfdiags.Diagnostics) {
par = 10
}
if opts.ProvidersFn != nil && opts.Providers == nil {
var err error
opts.Providers, err = opts.ProvidersFn(reqs)
diags = diags.Append(err)
}
plugins := newContextPlugins(opts.Providers, opts.Provisioners)
log.Printf("[TRACE] tofu.NewContext: complete")

View File

@@ -4292,7 +4292,7 @@ func TestContext2Apply_excludedWithTaintedInState(t *testing.T) {
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}

View File

@@ -2571,7 +2571,7 @@ func TestContext2Apply_provisionerInterpCount(t *testing.T) {
}
ctxOpts.Providers = Providers
ctxOpts.Provisioners = provisioners
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("failed to create context for plan: %s", diags.Err())
}
@@ -5977,7 +5977,7 @@ func TestContext2Apply_destroyModuleWithAttrsReferencingResource(t *testing.T) {
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -6048,7 +6048,7 @@ func TestContext2Apply_destroyWithModuleVariableAndCount(t *testing.T) {
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -6192,7 +6192,7 @@ func TestContext2Apply_destroyWithModuleVariableAndCountNested(t *testing.T) {
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -8194,7 +8194,7 @@ func TestContext2Apply_issue7824(t *testing.T) {
addrs.NewDefaultProvider("template"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -8268,7 +8268,7 @@ func TestContext2Apply_issue5254(t *testing.T) {
addrs.NewDefaultProvider("template"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -8346,7 +8346,7 @@ func TestContext2Apply_targetedWithTaintedInState(t *testing.T) {
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -8613,7 +8613,7 @@ func TestContext2Apply_destroyNestedModuleWithAttrsReferencingResource(t *testin
addrs.NewDefaultProvider("null"): testProviderFuncFixed(p),
}
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -9199,7 +9199,7 @@ func TestContext2Apply_plannedInterpolatedCount(t *testing.T) {
}
ctxOpts.Providers = Providers
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -9260,7 +9260,7 @@ func TestContext2Apply_plannedDestroyInterpolatedCount(t *testing.T) {
}
ctxOpts.Providers = providers
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("err: %s", diags.Err())
}
@@ -9788,7 +9788,7 @@ func TestContext2Apply_destroyDataCycle(t *testing.T) {
t.Fatal(err)
}
ctxOpts.Providers = Providers
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("failed to create context for plan: %s", diags.Err())
}
@@ -10145,7 +10145,7 @@ func TestContext2Apply_cbdCycle(t *testing.T) {
t.Fatal(err)
}
ctxOpts.Providers = Providers
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("failed to create context for plan: %s", diags.Err())
}
@@ -11665,7 +11665,7 @@ func TestContext2Apply_destroyProviderReference(t *testing.T) {
t.Fatal(err)
}
ctxOpts.Providers = providers
ctx, diags = NewContext(ctxOpts)
ctx, diags = NewContext(ctxOpts, nil)
if diags.HasErrors() {
t.Fatalf("failed to create context for plan: %s", diags.Err())

View File

@@ -471,11 +471,6 @@ variable "obfmod" {
// Defaulted stub provider with non-custom function
func TestContext2Functions_providerFunctionsStub(t *testing.T) {
p := testProvider("aws")
addr := addrs.ImpliedProviderForUnqualifiedType("aws")
// Explicitly non-parallel
t.Setenv("foo", "bar")
defer providers.SchemaCache.Remove(addr)
p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
Functions: map[string]providers.FunctionSpec{
@@ -492,9 +487,6 @@ func TestContext2Functions_providerFunctionsStub(t *testing.T) {
Result: cty.True,
}
// SchemaCache is initialized earlier on in the command package
providers.SchemaCache.Set(addr, *p.GetProviderSchemaResponse)
m := testModuleInline(t, map[string]string{
"main.tf": `
module "mod" {
@@ -571,11 +563,6 @@ variable "obfmod" {
// Defaulted stub provider with custom function (no allowed)
func TestContext2Functions_providerFunctionsStubCustom(t *testing.T) {
p := testProvider("aws")
addr := addrs.ImpliedProviderForUnqualifiedType("aws")
// Explicitly non-parallel
t.Setenv("foo", "bar")
defer providers.SchemaCache.Remove(addr)
p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
Functions: map[string]providers.FunctionSpec{
@@ -592,9 +579,6 @@ func TestContext2Functions_providerFunctionsStubCustom(t *testing.T) {
Result: cty.True,
}
// SchemaCache is initialized earlier on in the command package
providers.SchemaCache.Set(addr, *p.GetProviderSchemaResponse)
m := testModuleInline(t, map[string]string{
"main.tf": `
module "mod" {
@@ -655,11 +639,6 @@ variable "obfmod" {
// Defaulted stub provider
func TestContext2Functions_providerFunctionsForEachCount(t *testing.T) {
p := testProvider("aws")
addr := addrs.ImpliedProviderForUnqualifiedType("aws")
// Explicitly non-parallel
t.Setenv("foo", "bar")
defer providers.SchemaCache.Remove(addr)
p.GetProviderSchemaResponse = &providers.GetProviderSchemaResponse{
Functions: map[string]providers.FunctionSpec{
@@ -676,9 +655,6 @@ func TestContext2Functions_providerFunctionsForEachCount(t *testing.T) {
Result: cty.True,
}
// SchemaCache is initialized earlier on in the command package
providers.SchemaCache.Set(addr, *p.GetProviderSchemaResponse)
m := testModuleInline(t, map[string]string{
"main.tf": `
provider "aws" {

View File

@@ -9,6 +9,7 @@ import (
"context"
"fmt"
"log"
"sync"
"github.com/opentofu/opentofu/internal/addrs"
"github.com/opentofu/opentofu/internal/configs/configschema"
@@ -16,6 +17,11 @@ import (
"github.com/opentofu/opentofu/internal/provisioners"
)
type lockableProviderSchema struct {
sync.Mutex
schema *providers.ProviderSchema
}
// contextPlugins represents a library of available plugins (providers and
// provisioners) which we assume will all be used with the same
// tofu.Context, and thus it'll be safe to cache certain information
@@ -23,12 +29,16 @@ import (
type contextPlugins struct {
providerFactories map[addrs.Provider]providers.Factory
provisionerFactories map[string]provisioners.Factory
providerSchemasLock sync.Mutex
providerSchemas map[addrs.Provider]*lockableProviderSchema
}
func newContextPlugins(providerFactories map[addrs.Provider]providers.Factory, provisionerFactories map[string]provisioners.Factory) *contextPlugins {
return &contextPlugins{
providerFactories: providerFactories,
provisionerFactories: provisionerFactories,
providerSchemas: map[addrs.Provider]*lockableProviderSchema{},
}
}
@@ -68,26 +78,32 @@ func (cp *contextPlugins) NewProvisionerInstance(typ string) (provisioners.Inter
// to repeatedly call this method with the same address if various different
// parts of OpenTofu all need the same schema information.
func (cp *contextPlugins) ProviderSchema(ctx context.Context, addr addrs.Provider) (providers.ProviderSchema, error) {
// Check the global schema cache first.
// This cache is only written by the provider client, and transparently
// used by GetProviderSchema, but we check it here because at this point we
// may be able to avoid spinning up the provider instance at all.
//
// It's worth noting that ServerCapabilities.GetProviderSchemaOptional is ignored here.
// That is because we're checking *prior* to the provider's instantiation.
// GetProviderSchemaOptional only says that *if we instantiate a provider*,
// then we need to run the get schema call at least once.
// BUG This SHORT CIRCUITS the logic below and is not the only code which inserts provider schemas into the cache!!
schemas, ok := providers.SchemaCache.Get(addr)
if ok {
log.Printf("[TRACE] tofu.contextPlugins: Serving provider %q schema from global schema cache", addr)
return schemas, nil
// TODO this cache is probably not the same between validate, plan, apply...
// Hold the coarse lock
cp.providerSchemasLock.Lock()
// Locate the fine lock
lockSchema, ok := cp.providerSchemas[addr]
if !ok {
lockSchema = &lockableProviderSchema{}
cp.providerSchemas[addr] = lockSchema
}
// Release the coarse lock
cp.providerSchemasLock.Unlock()
// Hold the fine lock
lockSchema.Lock()
defer lockSchema.Unlock()
if lockSchema.schema != nil {
return *lockSchema.schema, nil
}
log.Printf("[TRACE] tofu.contextPlugins: Initializing provider %q to read its schema", addr)
provider, err := cp.NewProviderInstance(addr)
if err != nil {
return schemas, fmt.Errorf("failed to instantiate provider %q to obtain schema: %w", addr, err)
return providers.ProviderSchema{}, fmt.Errorf("failed to instantiate provider %q to obtain schema: %w", addr, err)
}
defer provider.Close(ctx)
@@ -122,6 +138,8 @@ func (cp *contextPlugins) ProviderSchema(ctx context.Context, addr addrs.Provide
}
}
lockSchema.schema = &resp
return resp, nil
}
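
The locking scheme above is a per-key "compute once" pattern: a coarse lock
guards only the map, and a per-entry lock serializes the expensive schema
fetch so that distinct providers never block each other. A self-contained
sketch of the same idea (generic names, not OpenTofu API):

    import "sync"

    type entry struct {
        sync.Mutex
        val *string
    }

    type onceMap struct {
        mu sync.Mutex // coarse: guards the map only
        m  map[string]*entry
    }

    func (o *onceMap) get(key string, compute func() string) string {
        o.mu.Lock()
        e, ok := o.m[key]
        if !ok {
            e = &entry{}
            o.m[key] = e
        }
        o.mu.Unlock() // release the coarse lock before the slow work

        e.Lock() // fine: serializes computation per key
        defer e.Unlock()
        if e.val == nil {
            v := compute()
            e.val = &v
        }
        return *e.val
    }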

View File

@@ -42,6 +42,7 @@ func simpleMockPluginLibrary() *contextPlugins {
return provisioner, nil
},
},
providerSchemas: map[addrs.Provider]*lockableProviderSchema{},
}
return ret
}

View File

@@ -1319,7 +1319,7 @@ func TestContext2Refresh_unknownProvider(t *testing.T) {
c, diags := NewContext(&ContextOpts{
Providers: map[addrs.Provider]providers.Factory{},
})
}, nil)
assertNoDiagnostics(t, diags)
_, diags = c.Refresh(context.Background(), m, states.NewState(), &PlanOpts{Mode: plans.NormalMode})

View File

@@ -105,7 +105,7 @@ func TestNewContextRequiredVersion(t *testing.T) {
Required: constraint,
})
}
c, diags := NewContext(&ContextOpts{})
c, diags := NewContext(&ContextOpts{}, nil)
if diags.HasErrors() {
t.Fatalf("unexpected NewContext errors: %s", diags.Err())
}
@@ -164,7 +164,7 @@ terraform {}
Required: constraint,
})
}
c, diags := NewContext(&ContextOpts{})
c, diags := NewContext(&ContextOpts{}, nil)
if diags.HasErrors() {
t.Fatalf("unexpected NewContext errors: %s", diags.Err())
}
@@ -178,7 +178,7 @@ terraform {}
}
func TestContext_missingPlugins(t *testing.T) {
ctx, diags := NewContext(&ContextOpts{})
ctx, diags := NewContext(&ContextOpts{}, nil)
assertNoDiagnostics(t, diags)
configSrc := `
@@ -325,7 +325,7 @@ func TestContext_contextValuesPropagation(t *testing.T) {
func testContext2(t testing.TB, opts *ContextOpts) *Context {
t.Helper()
ctx, diags := NewContext(opts)
ctx, diags := NewContext(opts, nil)
if diags.HasErrors() {
t.Fatalf("failed to create test context\n\n%s\n", diags.Err())
}

View File

@@ -96,7 +96,7 @@ func TestContext2Validate_badVar(t *testing.T) {
func TestContext2Validate_varNoDefaultExplicitType(t *testing.T) {
m := testModule(t, "validate-var-no-default-explicit-type")
c, diags := NewContext(&ContextOpts{})
c, diags := NewContext(&ContextOpts{}, nil)
if diags.HasErrors() {
t.Fatalf("unexpected NewContext errors: %s", diags.Err())
}
@@ -317,7 +317,7 @@ func TestContext2Validate_countVariableNoDefault(t *testing.T) {
Providers: map[addrs.Provider]providers.Factory{
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
},
})
}, nil)
assertNoDiagnostics(t, diags)
_, diags = c.Plan(context.Background(), m, nil, &PlanOpts{})
@@ -866,7 +866,7 @@ func TestContext2Validate_requiredVar(t *testing.T) {
Providers: map[addrs.Provider]providers.Factory{
addrs.NewDefaultProvider("aws"): testProviderFuncFixed(p),
},
})
}, nil)
assertNoDiagnostics(t, diags)
// NOTE: This test has grown idiosyncratic because originally Terraform