#
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
#
import io
import json
import logging
import re
import threading
from dataclasses import dataclass, field
from itertools import groupby
from typing import Any, Callable, ClassVar, Dict, Generator, Iterable, List, Mapping, MutableMapping, Optional, Tuple, Union
import anyascii
import requests
from airbyte_cdk import AirbyteTracedException, FailureType, InterpolatedString
from airbyte_cdk.sources.declarative.decoders.composite_raw_decoder import JsonParser
from airbyte_cdk.sources.declarative.decoders.decoder import Decoder
from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
from airbyte_cdk.sources.declarative.migrations.state_migration import StateMigration
from airbyte_cdk.sources.declarative.requesters.http_requester import HttpRequester
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.declarative.schema import SchemaLoader
from airbyte_cdk.sources.declarative.schema.inline_schema_loader import InlineSchemaLoader
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
from airbyte_cdk.sources.streams.concurrent.default_stream import DefaultStream
from airbyte_cdk.sources.types import Config, Record, StreamSlice, StreamState
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
from .google_ads import GoogleAds
logger = logging.getLogger("airbyte")
REPORT_MAPPING = {
"account_performance_report": "customer",
"ad_group_ad_legacy": "ad_group_ad",
"ad_group_bidding_strategy": "ad_group",
"ad_listing_group_criterion": "ad_group_criterion",
"campaign_real_time_bidding_settings": "campaign",
"campaign_bidding_strategy": "campaign",
"service_accounts": "customer",
}
DATE_TYPES = ("segments.date", "segments.month", "segments.quarter", "segments.week")
GOOGLE_ADS_DATATYPE_MAPPING = {
"INT64": "integer",
"INT32": "integer",
"DOUBLE": "number",
"STRING": "string",
"BOOLEAN": "boolean",
"DATE": "string",
"MESSAGE": "string",
"ENUM": "string",
"RESOURCE_NAME": "string",
}
class AccessibleAccountsExtractor(RecordExtractor):
"""
Custom extractor for the accessible accounts endpoint.
The response is a list of strings instead of a list of objects.
"""
def extract_records(self, response: requests.Response) -> Iterable[Mapping[str, Any]]:
response_data = response.json().get("resourceNames", [])
if response_data:
for resource_name in response_data:
yield {"accessible_customer_id": resource_name.split("/")[-1]}
@dataclass
class CustomerClientFilter(RecordFilter):
"""
Filter duplicated records based on the "clientCustomer" field.
    Duplicates can occur when a customer client account is accessible from multiple customer accounts.
"""
UNIQUE_KEY = "clientCustomer"
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__post_init__(parameters)
self._seen_keys = set()
def filter_records(
self, records: List[Mapping[str, Any]], stream_state: StreamState, stream_slice: Optional[StreamSlice] = None, **kwargs
) -> Iterable[Mapping[str, Any]]:
for record in records:
# Filter out records based on customer_ids if provided in the config
if self.config.get("customer_ids") and (record["id"] not in self.config["customer_ids"]):
continue
# Filter out records based on customer status if provided in the config
if record["status"] not in (
self.config.get("customer_status_filter") or ["UNKNOWN", "ENABLED", "CANCELED", "SUSPENDED", "CLOSED"]
):
continue
key = record[self.UNIQUE_KEY]
if key not in self._seen_keys:
self._seen_keys.add(key)
yield record
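
# Illustrative behaviour of CustomerClientFilter (hedged sketch; the constructor
# arguments follow the CDK's RecordFilter dataclass and are assumptions here).
# With no customer_ids / customer_status_filter configured, only the first record
# per "clientCustomer" value survives:
#
#   record_filter = CustomerClientFilter(config={}, parameters={}, condition="")
#   records = [
#       {"id": "1", "status": "ENABLED", "clientCustomer": "customers/111"},
#       {"id": "1", "status": "ENABLED", "clientCustomer": "customers/111"},  # duplicate, dropped
#   ]
#   # list(record_filter.filter_records(records, stream_state={})) -> one record
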
@dataclass
class FlattenNestedDictsTransformation(RecordTransformation):
"""
Recursively flattens all nested dictionary fields in a record into top-level keys
using dot-separated paths, skipping dictionaries inside lists, and removes
the original nested dictionary fields.
Example:
{"a": {"b": 1, "c": {"d": 2}, "e": [ {"f": 3} ]}, "g": [{"h": 4}]}
becomes:
{"a.b": 1, "a.c.d": 2, "a.e": [ {"f": 3} ], "g": [{"h": 4}]}
"""
delimiter: str = "."
def transform(
self,
record: Dict[str, Any],
config: Optional[Config] = None,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
) -> None:
def _flatten(prefix: str, obj: Dict[str, Any]):
for key, value in list(obj.items()):
new_key = f"{prefix}{self.delimiter}{key}" if prefix else key
if isinstance(value, dict):
# recurse into nested dict
_flatten(new_key, value)
else:
# set flattened value at top level
record[new_key] = value
# find all top-level dict fields (skip dicts in lists)
for top_key in [k for k, v in list(record.items()) if isinstance(v, dict)]:
nested = record.pop(top_key)
_flatten(top_key, nested)
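
# Illustrative behaviour of FlattenNestedDictsTransformation (hedged sketch):
# the record is mutated in place, matching the docstring example above.
#
#   record = {"a": {"b": 1, "c": {"d": 2}}, "g": [{"h": 4}]}
#   FlattenNestedDictsTransformation().transform(record)
#   # record == {"g": [{"h": 4}], "a.b": 1, "a.c.d": 2}
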
class DoubleQuotedDictTypeTransformer(TypeTransformer):
"""
Convert arrays of dicts into JSON-formatted string arrays using double quotes only
when the schema defines a (nullable) array of strings. Output strings use no spaces
after commas and a single space after colons for consistent formatting.
Example:
[{'key': 'campaign', 'value': 'gg_nam_dg_search_brand'}]
→ ['{"key": "campaign","value": "gg_nam_dg_search_brand"}']
"""
def __init__(self, *args, **kwargs):
# apply this transformer during schema normalization phase(s)
config = TransformConfig.DefaultSchemaNormalization | TransformConfig.CustomSchemaNormalization
super().__init__(config)
# register our custom transform
self.registerCustomTransform(self.get_transform_function())
@staticmethod
def get_transform_function():
def transform_function(original_value: Any, field_schema: Dict[str, Any]) -> Any:
# Skip null values (schema may include 'null')
if original_value is None:
return original_value
# Only apply if schema type includes 'array'
schema_type = field_schema.get("type")
if isinstance(schema_type, list):
if "array" not in schema_type:
return original_value
elif schema_type != "array":
return original_value
# Only apply if items type includes 'string'
items = field_schema.get("items", {}) or {}
items_type = items.get("type")
if isinstance(items_type, list):
if "string" not in items_type:
return original_value
elif items_type != "string":
return original_value
# Transform only lists where every element is a dict
if isinstance(original_value, list) and all(isinstance(el, dict) for el in original_value):
return [json.dumps(el, separators=(",", ": ")) for el in original_value]
return original_value
return transform_function
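
# Illustrative behaviour of the custom transform (hedged sketch): arrays of dicts
# are stringified only when the field schema declares a (nullable) array of strings.
#
#   fn = DoubleQuotedDictTypeTransformer.get_transform_function()
#   fn(
#       [{"key": "campaign", "value": "gg_nam_dg_search_brand"}],
#       {"type": ["null", "array"], "items": {"type": ["null", "string"]}},
#   )
#   # -> ['{"key": "campaign","value": "gg_nam_dg_search_brand"}']
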
class GoogleAdsPerPartitionStateMigration(StateMigration):
"""
Migrates legacy per-partition Google Ads state to low code format
that includes parent_slice for each customer_id partition.
Example input state:
{
"1234567890": {"segments.date": "2120-10-10"},
"0987654321": {"segments.date": "2120-10-11"}
}
Example output state:
{
"states": [
{
"partition": { "customer_id": "1234567890", "parent_slice": {"customer_id": "1234567890_parent", "parent_slice": {}}},
"cursor": {"segments.date": "2120-10-10"}
},
{
"partition": { "customer_id": "0987654321", "parent_slice": {"customer_id": "0987654321_parent", "parent_slice": {}}},
"cursor": {"segments.date": "2120-10-11"}
}
],
"state": {"segments.date": "2120-10-10"}
}
"""
config: Config
customer_client_stream: DefaultStream
cursor_field: str = "segments.date"
def __init__(self, config: Config, customer_client_stream: DefaultStream, cursor_field: str = "segments.date"):
self._config = config
self._parent_stream = customer_client_stream
self._cursor_field = cursor_field
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
return stream_state and "state" not in stream_state
def _read_parent_stream(self) -> Iterable[Record]:
for partition in self._parent_stream.generate_partitions():
for record in partition.read():
yield record
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
if not self.should_migrate(stream_state):
return stream_state
        stream_state_values = [
            state for state in stream_state.values() if isinstance(state, dict) and self._cursor_field in state
        ]
if not stream_state_values:
logger.warning("No valid cursor field found in the stream state. Returning empty state.")
return {}
min_state = min(stream_state_values, key=lambda state: state[self._cursor_field])
customer_ids_in_state = list(stream_state.keys())
partitions_state = []
for record in self._read_parent_stream():
customer_id = record.data.get("id")
if customer_id in customer_ids_in_state:
legacy_partition_state = stream_state[customer_id]
partitions_state.append(
{
"partition": {
"customer_id": record["clientCustomer"],
"parent_slice": {"customer_id": record.associated_slice.get("customer_id"), "parent_slice": {}},
},
"cursor": legacy_partition_state,
}
)
if not partitions_state:
logger.warning("No matching customer clients found during state migration.")
return {}
state = {"states": partitions_state, "state": min_state}
return state
@dataclass
class GoogleAdsHttpRequester(HttpRequester):
"""
Custom HTTP requester for Google Ads API that uses the accessible accounts endpoint
to retrieve the list of accessible customer IDs.
"""
    schema_loader: Optional[InlineSchemaLoader] = None
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__post_init__(parameters)
self.stream_response = True
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
schema = self.schema_loader.get_json_schema()[self.name]["properties"]
manager = stream_slice.extra_fields.get("manager", False)
resource_name = REPORT_MAPPING.get(self.name, self.name)
fields = [
field
for field in schema.keys()
# exclude metrics.* if this is a manager account
if not (manager and field.startswith("metrics."))
]
if "start_time" in stream_slice and "end_time" in stream_slice:
# For incremental streams
query = f"SELECT {', '.join(fields)} FROM {resource_name} WHERE segments.date BETWEEN '{stream_slice['start_time']}' AND '{stream_slice['end_time']}' ORDER BY segments.date ASC"
else:
# For full refresh streams
query = f"SELECT {', '.join(fields)} FROM {resource_name}"
return {"query": query}
def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return {
"developer-token": self.config["credentials"]["developer_token"],
"login-customer-id": stream_slice["parent_slice"]["customer_id"],
}
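
# Illustrative request body produced by GoogleAdsHttpRequester.get_request_body_json
# (hedged sketch; the selected fields follow the stream's schema properties). For an
# incremental slice carrying start_time/end_time, the generated GAQL is a single line
# of the form:
#
#   SELECT campaign.id, segments.date FROM campaign
#   WHERE segments.date BETWEEN '2024-01-01' AND '2024-01-31'
#   ORDER BY segments.date ASC
#
# For a full-refresh slice (no start_time/end_time), the WHERE and ORDER BY clauses
# are omitted.
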
@dataclass
class ClickViewHttpRequester(GoogleAdsHttpRequester):
"""
Custom HTTP requester for ClickView stream.
"""
    schema_loader: Optional[InlineSchemaLoader] = None
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
schema = self.schema_loader.get_json_schema()["properties"]
        fields = list(schema.keys())
return {"query": f"SELECT {', '.join(fields)} FROM click_view WHERE segments.date = '{stream_slice['start_time']}'"}
@dataclass
class KeysToSnakeCaseGoogleAdsTransformation(RecordTransformation):
"""
Transforms keys in a Google Ads record to snake_case.
    Unlike KeysToSnakeCaseTransformation, this transformation does not add an underscore before digits.
"""
token_pattern: re.Pattern[str] = re.compile(
r"""
\d*[A-Z]+[a-z]*\d* # uppercase word (with optional leading/trailing digits)
| \d*[a-z]+\d* # lowercase word (with optional leading/trailing digits)
| (?P<NoToken>[^a-zA-Z\d]+) # any non-alphanumeric separators
""",
re.VERBOSE,
)
def transform(
self,
record: Dict[str, Any],
config: Optional[Config] = None,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
) -> None:
transformed_record = self._transform_record(record)
record.clear()
record.update(transformed_record)
def _transform_record(self, record: Dict[str, Any]) -> Dict[str, Any]:
transformed_record = {}
for key, value in record.items():
transformed_key = self.process_key(key)
transformed_value = value
if isinstance(value, dict):
transformed_value = self._transform_record(value)
transformed_record[transformed_key] = transformed_value
return transformed_record
def process_key(self, key: str) -> str:
key = self.normalize_key(key)
tokens = self.tokenize_key(key)
tokens = self.filter_tokens(tokens)
return self.tokens_to_snake_case(tokens)
def normalize_key(self, key: str) -> str:
return str(anyascii.anyascii(key))
def tokenize_key(self, key: str) -> List[str]:
tokens = []
for match in self.token_pattern.finditer(key):
token = match.group(0) if match.group("NoToken") is None else ""
tokens.append(token)
return tokens
def filter_tokens(self, tokens: List[str]) -> List[str]:
if len(tokens) >= 3:
tokens = tokens[:1] + [t for t in tokens[1:-1] if t] + tokens[-1:]
return tokens
def tokens_to_snake_case(self, tokens: List[str]) -> str:
return "_".join(token.lower() for token in tokens)
@dataclass
class ChangeStatusRequester(GoogleAdsHttpRequester):
CURSOR_FIELD: str = "change_status.last_change_date_time"
LIMIT: int = 10000
def __post_init__(self, parameters: Mapping[str, Any]) -> None:
super().__post_init__(parameters)
self._resource_type = (parameters or {}).get("resource_type")
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
schema = self.schema_loader.get_json_schema()[self.name]["properties"]
fields = list(schema.keys())
query = (
f"SELECT {', '.join(fields)} FROM {self.name} "
f"WHERE {self.CURSOR_FIELD} BETWEEN '{stream_slice['start_time']}' AND '{stream_slice['end_time']}' "
f"AND change_status.resource_type = {self._resource_type} "
f"ORDER BY {self.CURSOR_FIELD} ASC "
f"LIMIT {self.LIMIT}"
)
return {"query": query}
@dataclass
class CriterionRetriever(SimpleRetriever):
"""
Retrieves Criterion records based on ChangeStatus updates.
For each parent_slice:
1) Emits a “deleted” record for any REMOVED status.
2) Batches the remaining IDs into a single fetch.
3) Attaches the original ChangeStatus timestamp to each returned record.
"""
cursor_field: str = "change_status.last_change_date_time"
def _read_pages(
self,
records_generator_fn: Callable[[Optional[Mapping]], Iterable[Record]],
stream_slice: StreamSlice,
) -> Iterable[Record]:
"""
Emit deletions and then fetch updates grouped by parent_slice:
- GROUP records by parent (zip ids, parents, statuses, times).
- For each group, yield REMOVED records immediately, then perform one fetch for non-removed.
- Attach the original ChangeStatus timestamp to each updated record.
"""
ids = stream_slice.partition[self.primary_key[0]]
parents = stream_slice.partition["parent_slice"]
statuses = stream_slice.extra_fields["change_status.resource_status"]
times = stream_slice.extra_fields["change_status.last_change_date_time"]
# Iterate grouped by parent
for parent, group in groupby(
zip(ids, parents, statuses, times),
key=lambda x: x[1],
):
group_list = list(group)
# Emit deletions first
updated_ids = []
updated_times = []
for _id, _, status, ts in group_list:
if status == "REMOVED":
yield Record(
data={
self.primary_key[0]: _id,
"deleted_at": ts,
},
associated_slice=stream_slice,
stream_name=self.name,
)
else:
updated_ids.append(_id)
updated_times.append(ts)
# Single fetch for non-removed
if not updated_ids:
continue
# build time map for updated records
time_map = dict(zip(updated_ids, updated_times))
new_slice = StreamSlice(
partition={
self.primary_key[0]: updated_ids,
"parent_slice": parent,
},
cursor_slice=stream_slice.cursor_slice,
extra_fields={"change_status.last_change_date_time": updated_times},
)
response = self._fetch_next_page(new_slice)
for rec in records_generator_fn(response):
# attach timestamp from ChangeStatus
rec.data[self.cursor_field] = time_map.get(rec.data.get(self.primary_key[0]))
yield rec
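
# Illustrative slice handling in CriterionRetriever (hedged sketch): given a slice whose
# partition carries parallel lists of ids, statuses, and timestamps, a REMOVED entry is
# emitted immediately as a tombstone record of the form
#
#   {<primary_key>: "...", "deleted_at": "2024-01-01 00:00:00.000000"}
#
# while the remaining ids are fetched in one batched request per parent, and each
# returned record is stamped with its original ChangeStatus timestamp from time_map.
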
@dataclass
class CriterionIncrementalRequester(GoogleAdsHttpRequester):
CURSOR_FIELD: str = "change_status.last_change_date_time"
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
props = self.schema_loader.get_json_schema()[self.name]["properties"]
select_fields = [f for f in props.keys() if f not in (self.CURSOR_FIELD, "deleted_at")]
ids = stream_slice.partition.get(self._parameters["primary_key"][0], [])
in_list = ", ".join(f"'{i}'" for i in ids)
query = (
f"SELECT {', '.join(select_fields)}\n"
f" FROM {self._parameters['resource_name']}\n"
f" WHERE {self._parameters['primary_key'][0]} IN ({in_list})\n"
)
return {"query": query}
def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
        next_page_token: Optional[Mapping[str, Any]] = None,
    ) -> Mapping[str, Any]:
return {
"developer-token": self.config["credentials"]["developer_token"],
"login-customer-id": stream_slice["parent_slice"]["parent_slice"]["customer_id"],
}
@dataclass
class CriterionFullRefreshRequester(GoogleAdsHttpRequester):
CURSOR_FIELD: str = "change_status.last_change_date_time"
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
props = self.schema_loader.get_json_schema()[self.name]["properties"]
fields = [f for f in props.keys() if f not in (self.CURSOR_FIELD, "deleted_at")]
return {"query": f"SELECT {', '.join(fields)} FROM {self._parameters['resource_name']}"}
class GoogleAdsCriterionParentStateMigration(StateMigration):
"""
    Migrates parent state from the legacy format to the new format.
"""
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
return stream_state and not stream_state.get("parent_state")
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
if not self.should_migrate(stream_state):
return stream_state
return {"parent_state": {"change_status": stream_state}}
class GoogleAdsGlobalStateMigration(StateMigration):
"""
    Migrates global state to include the use_global_cursor key. Previously, the legacy GlobalSubstreamCursor was used.
"""
def should_migrate(self, stream_state: Mapping[str, Any]) -> bool:
return stream_state and not stream_state.get("use_global_cursor")
def migrate(self, stream_state: Mapping[str, Any]) -> Mapping[str, Any]:
stream_state["use_global_cursor"] = True
return stream_state
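
# Illustrative migrations (hedged sketch):
#
#   GoogleAdsCriterionParentStateMigration().migrate(
#       {"change_status.last_change_date_time": "2024-01-01 00:00:00.000000"}
#   )
#   # -> {"parent_state": {"change_status": {"change_status.last_change_date_time": "2024-01-01 00:00:00.000000"}}}
#
#   GoogleAdsGlobalStateMigration().migrate({"state": {"segments.date": "2024-01-01"}})
#   # -> {"state": {"segments.date": "2024-01-01"}, "use_global_cursor": True}
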
@dataclass(repr=False, eq=False, frozen=True)
class GAQL:
"""
Simple regex parser of Google Ads Query Language
https://developers.google.com/google-ads/api/docs/query/grammar
"""
    fields: Tuple[str, ...]
resource_name: str
where: str
order_by: str
limit: Optional[int]
parameters: str
REGEX = re.compile(
r"""\s*
SELECT\s+(?P<FieldNames>\S.*)
\s+
FROM\s+(?P<ResourceNames>[a-z][a-zA-Z_]*(\s*,\s*[a-z][a-zA-Z_]*)*)
\s*
(\s+WHERE\s+(?P<WhereClause>\S.*?))?
(\s+ORDER\s+BY\s+(?P<OrderByClause>\S.*?))?
(\s+LIMIT\s+(?P<LimitClause>[1-9]([0-9])*))?
\s*
(\s+PARAMETERS\s+(?P<ParametersClause>\S.*?))?
$""",
flags=re.I | re.DOTALL | re.VERBOSE,
)
REGEX_FIELD_NAME = re.compile(r"^[a-z][a-z0-9._]*$", re.I)
@classmethod
def parse(cls, query):
        m = cls.REGEX.match(query)
        if not m:
            raise ValueError(f"Not a valid GAQL query: {query}")
fields = [f.strip() for f in m.group("FieldNames").split(",")]
for field in fields:
            if not cls.REGEX_FIELD_NAME.match(field):
                raise ValueError(f"Invalid field name in GAQL query: {field}")
resource_names = re.split(r"\s*,\s*", m.group("ResourceNames"))
        if len(resource_names) > 1:
            raise ValueError("Querying multiple resources in one GAQL query is not supported")
resource_name = resource_names[0]
where = cls._normalize(m.group("WhereClause") or "")
order_by = cls._normalize(m.group("OrderByClause") or "")
limit = m.group("LimitClause")
if limit:
limit = int(limit)
parameters = cls._normalize(m.group("ParametersClause") or "")
return cls(tuple(fields), resource_name, where, order_by, limit, parameters)
def __str__(self):
fields = ", ".join(self.fields)
query = f"SELECT {fields} FROM {self.resource_name}"
if self.where:
query += " WHERE " + self.where
if self.order_by:
query += " ORDER BY " + self.order_by
if self.limit is not None:
query += " LIMIT " + str(self.limit)
if self.parameters:
query += " PARAMETERS " + self.parameters
return query
def __repr__(self):
return self.__str__()
@staticmethod
def _normalize(s):
s = s.strip()
return re.sub(r"\s+", " ", s)
def set_where(self, value: str):
return self.__class__(self.fields, self.resource_name, value, self.order_by, self.limit, self.parameters)
def set_limit(self, value: int):
return self.__class__(self.fields, self.resource_name, self.where, self.order_by, value, self.parameters)
def append_field(self, value):
fields = list(self.fields)
fields.append(value)
return self.__class__(tuple(fields), self.resource_name, self.where, self.order_by, self.limit, self.parameters)
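
# Illustrative GAQL round-trip (hedged sketch):
#
#   query = GAQL.parse("SELECT campaign.id, campaign.name FROM campaign")
#   query.resource_name   # -> "campaign"
#   str(query.set_where("segments.date BETWEEN '2024-01-01' AND '2024-01-31'"))
#   # -> "SELECT campaign.id, campaign.name FROM campaign WHERE segments.date BETWEEN '2024-01-01' AND '2024-01-31'"
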
class CustomGAQueryHttpRequester(HttpRequester):
"""
Custom HTTP requester for custom query streams.
"""
parameters: Mapping[str, Any]
def __post_init__(self, parameters: Mapping[str, Any]):
super().__post_init__(parameters=parameters)
self.query = GAQL.parse(parameters.get("query"))
self.stream_response = True
@staticmethod
def is_metrics_in_custom_query(query: GAQL) -> bool:
for field in query.fields:
if field.split(".")[0] == "metrics":
return True
return False
@staticmethod
def is_custom_query_incremental(query: GAQL) -> bool:
        time_segment_in_select = "segments.date" in query.fields
        time_segment_in_where = "segments.date" in query.where
        return time_segment_in_select and not time_segment_in_where
@staticmethod
def _insert_segments_date_expr(query: GAQL, start_date: str, end_date: str) -> GAQL:
if "segments.date" not in query.fields:
query = query.append_field("segments.date")
condition = f"segments.date BETWEEN '{start_date}' AND '{end_date}'"
if query.where:
return query.set_where(query.where + " AND " + condition)
return query.set_where(condition)
def get_request_body_json(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> MutableMapping[str, Any]:
query = self._build_query(stream_slice)
return {"query": query}
def get_request_headers(
self,
*,
stream_state: Optional[StreamState] = None,
stream_slice: Optional[StreamSlice] = None,
next_page_token: Optional[Mapping[str, Any]] = None,
) -> Mapping[str, Any]:
return {
"developer-token": self.config["credentials"]["developer_token"],
"login-customer-id": stream_slice["parent_slice"]["customer_id"],
}
def _build_query(self, stream_slice: StreamSlice) -> str:
is_incremental = self.is_custom_query_incremental(self.query)
if is_incremental:
start_date = stream_slice["start_time"]
end_date = stream_slice["end_time"]
return str(self._insert_segments_date_expr(self.query, start_date, end_date))
else:
return str(self.query)
    def _get_resource_name(self) -> str:
        """
        Return the resource name from the `FROM` clause of the parsed query.
        e.g. For "SELECT field1, field2, field3 FROM table" this returns "table".
        """
        # self.query is a parsed GAQL object (set in __post_init__), not a raw string,
        # so the resource name is read from the parsed query instead of via str.find().
        return self.query.resource_name
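
# Illustrative incremental detection (hedged sketch): a query is treated as incremental
# when segments.date is selected but not already constrained in the WHERE clause.
#
#   q = GAQL.parse("SELECT segments.date, campaign.id FROM campaign")
#   CustomGAQueryHttpRequester.is_custom_query_incremental(q)   # -> True
#   str(CustomGAQueryHttpRequester._insert_segments_date_expr(q, "2024-01-01", "2024-01-31"))
#   # -> "SELECT segments.date, campaign.id FROM campaign WHERE segments.date BETWEEN '2024-01-01' AND '2024-01-31'"
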
class CustomGAQueryClickViewHttpRequester(CustomGAQueryHttpRequester):
@staticmethod
def _insert_segments_date_expr(query: GAQL, start_date: str, end_date: str) -> GAQL:
if "segments.date" not in query.fields:
query = query.append_field("segments.date")
condition = f"segments.date ='{start_date}'"
if query.where:
return query.set_where(query.where + " AND " + condition)
return query.set_where(condition)
@dataclass()
class CustomGAQuerySchemaLoader(SchemaLoader):
"""
    Custom schema loader for custom query streams. Parses the user-provided query to extract the fields, then queries the Google Ads API for each field to retrieve its metadata.
"""
config: Config
query: str = ""
cursor_field: Union[str, InterpolatedString] = ""
    # Class-level singleton client shared by all instances (see google_ads_client()).
    _google_ads_client: ClassVar[Optional[GoogleAds]] = None
    _client_lock: ClassVar[threading.Lock] = threading.Lock()
def __post_init__(self):
self._cursor_field = InterpolatedString.create(self.cursor_field, parameters={}) if self.cursor_field else None
self._validate_query(self.query)
@classmethod
def google_ads_client(cls, config: Config) -> GoogleAds:
"""
Lazily creates a single GoogleAds client shared by all instances of this class.
First call wins (its config is used).
"""
if cls._google_ads_client is None:
with cls._client_lock:
if cls._google_ads_client is None:
cls._google_ads_client = GoogleAds(credentials=cls.get_credentials(config))
return cls._google_ads_client
@staticmethod
def get_credentials(config: Mapping[str, Any]) -> MutableMapping[str, Any]:
credentials = config["credentials"]
        # use_proto_plus is set to True because setting it to False returned wrong value types, which breaks backward compatibility.
        # For more info read the related PR's description: https://github.com/airbytehq/airbyte/pull/9996
credentials.update(use_proto_plus=True)
return credentials
def get_json_schema(self) -> Dict[str, Any]:
local_json_schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {},
"additionalProperties": True,
}
fields = self._get_list_of_fields()
fields_metadata = self.google_ads_client(self.config).get_fields_metadata(fields)
for field, field_metadata in fields_metadata.items():
field_value = self._build_field_value(field, field_metadata)
local_json_schema["properties"][field] = field_value
return local_json_schema
def _build_field_value(self, field: str, field_metadata) -> Any:
        # The data type is returned in enum format: "GoogleAdsFieldDataType.<data_type>"
google_data_type = field_metadata.data_type.name
field_value = {"type": [GOOGLE_ADS_DATATYPE_MAPPING.get(google_data_type, "string"), "null"]}
# Google Ads doesn't differentiate between DATE and DATETIME, so we need to manually check for fields with known type
if google_data_type == "DATE" and field in DATE_TYPES:
field_value["format"] = "date"
if google_data_type == "ENUM":
field_value = {"type": "string", "enum": list(field_metadata.enum_values)}
if field_metadata.is_repeated:
field_value = {"type": ["null", "array"], "items": field_value}
return field_value
def _get_list_of_fields(self) -> List[str]:
"""
Extract field names from the `SELECT` clause of the query.
e.g. Parses a query "SELECT field1, field2, field3 FROM table" and returns ["field1", "field2", "field3"].
"""
query_upper = self.query.upper()
select_index = query_upper.find("SELECT")
match = re.search(r"\bFROM\b", query_upper)
if not match:
raise ValueError("Could not find a valid FROM clause in query")
from_index = match.start()
fields_portion = self.query[select_index + 6 : from_index].strip()
fields = [field.strip() for field in fields_portion.split(",")]
        cursor_field = self._cursor_field.eval(self.config) if self._cursor_field else None
        # Avoid duplicating the cursor field if the user's query already selects it.
        if cursor_field and cursor_field not in fields:
            fields.append(cursor_field)
return fields
    def _validate_query(self, query: str):
        try:
            GAQL.parse(query)
        except ValueError:
            message = (
                f"The provided query is invalid: {query}. "
                "Please refer to the Google Ads API documentation for the correct syntax: "
                "https://developers.google.com/google-ads/api/fields/v20/overview "
                "and validate your query using the Google Ads Query Builder: "
                "https://developers.google.com/google-ads/api/fields/v20/query_validator"
            )
            raise AirbyteTracedException(
                failure_type=FailureType.config_error,
                internal_message=message,
                message=message,
            )
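
# Illustrative schema building (hedged sketch; the metadata object mimics the
# GoogleAdsField shape with data_type.name / enum_values / is_repeated attributes):
#
#   from types import SimpleNamespace
#   loader = CustomGAQuerySchemaLoader(config={}, query="SELECT campaign.id FROM campaign")
#   meta = SimpleNamespace(data_type=SimpleNamespace(name="ENUM"), enum_values=["ENABLED", "PAUSED"], is_repeated=False)
#   loader._build_field_value("campaign.status", meta)
#   # -> {"type": "string", "enum": ["ENABLED", "PAUSED"]}
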
@dataclass
class StringParseState:
inside_string: bool = False
escape_next_character: bool = False
collected_string_chars: List[str] = field(default_factory=list)
last_parsed_key: Optional[str] = None
@dataclass
class TopLevelObjectState:
depth: int = 0
@dataclass
class ResultsArrayState:
inside_results_array: bool = False
array_nesting_depth: int = 0
expecting_results_array_start: bool = False
@dataclass
class RecordParseState:
inside_record: bool = False
record_text_buffer: List[str] = field(default_factory=list)
record_nesting_depth: int = 0
@dataclass
class GoogleAdsStreamingDecoder(Decoder):
"""
JSON streaming decoder optimized for Google Ads API responses.
Uses a fast JSON parse when the full payload fits within max_direct_decode_bytes;
otherwise streams records incrementally from the `results` array.
Ensures truncated or structurally invalid JSON is detected and reported.
"""
chunk_size: int = 5 * 1024 * 1024 # 5 MB
# Fast-path threshold: if whole body < 20 MB, decode with json.loads
max_direct_decode_bytes: int = 20 * 1024 * 1024 # 20 MB
def __post_init__(self):
self.parser = JsonParser()
def is_stream_response(self) -> bool:
return True
def decode(self, response: requests.Response) -> Generator[MutableMapping[str, Any], None, None]:
data, complete = self._buffer_up_to_limit(response)
if complete:
yield from self.parser.parse(io.BytesIO(data))
return
records_batch: List[Dict[str, Any]] = []
for record in self._parse_records_from_stream(data):
records_batch.append(record)
if len(records_batch) >= 100:
yield {"results": records_batch}
records_batch = []
if records_batch:
yield {"results": records_batch}
def _buffer_up_to_limit(self, response: requests.Response) -> Tuple[Union[bytes, Iterable[bytes]], bool]:
buf = bytearray()
response_stream = response.iter_content(chunk_size=self.chunk_size)
        # An explicit `is not None` check ensures an empty chunk doesn't end buffering early.
        while (chunk := next(response_stream, None)) is not None:
buf.extend(chunk)
if len(buf) >= self.max_direct_decode_bytes:
return (self._chain_prefix_and_stream(bytes(buf), response_stream), False)
return (bytes(buf), True)
@staticmethod
def _chain_prefix_and_stream(prefix: bytes, rest_stream: Iterable[bytes]) -> Iterable[bytes]:
yield prefix
yield from rest_stream
def _parse_records_from_stream(self, byte_iter: Iterable[bytes], encoding: str = "utf-8") -> Generator[Dict[str, Any], None, None]:
string_state = StringParseState()
results_state = ResultsArrayState()
record_state = RecordParseState()
top_level_state = TopLevelObjectState()
for chunk in byte_iter:
for char in chunk.decode(encoding, errors="replace"):
self._append_to_current_record_if_any(char, record_state)
if self._update_string_state(char, string_state):
continue
# Track outer braces only outside results array
if not results_state.inside_results_array:
if char == "{":
top_level_state.depth += 1
elif char == "}":
top_level_state.depth = max(0, top_level_state.depth - 1)
if not results_state.inside_results_array:
self._detect_results_array(char, string_state, results_state)
continue
record = self._parse_record_structure(char, results_state, record_state)
if record is not None:
yield record
# EOF validation
if (
string_state.inside_string
or record_state.inside_record
or record_state.record_nesting_depth != 0
or results_state.inside_results_array
or results_state.array_nesting_depth != 0
or top_level_state.depth != 0
):
raise AirbyteTracedException(
message="Response JSON stream ended prematurely and is incomplete.",
internal_message=(
"Detected truncated JSON stream: one or more structural elements were not fully closed before the response ended."
),
failure_type=FailureType.system_error,
)
def _update_string_state(self, char: str, state: StringParseState) -> bool:
"""Return True if char was handled as part of string parsing."""
if state.inside_string:
if state.escape_next_character:
state.escape_next_character = False
return True
if char == "\\":
state.escape_next_character = True
return True
if char == '"':
state.inside_string = False
state.last_parsed_key = "".join(state.collected_string_chars)
state.collected_string_chars.clear()
return True
state.collected_string_chars.append(char)
return True
if char == '"':
state.inside_string = True
state.collected_string_chars.clear()
return True
return False
def _detect_results_array(self, char: str, string_state: StringParseState, results_state: ResultsArrayState) -> None:
if char == ":" and string_state.last_parsed_key == "results":
results_state.expecting_results_array_start = True
elif char == "[" and results_state.expecting_results_array_start:
results_state.inside_results_array = True
results_state.array_nesting_depth = 1
results_state.expecting_results_array_start = False
def _parse_record_structure(
self, char: str, results_state: ResultsArrayState, record_state: RecordParseState
) -> Optional[Dict[str, Any]]:
if char == "{":
if record_state.inside_record:
record_state.record_nesting_depth += 1
else:
self._start_record(record_state)
return None
if char == "}":
if record_state.inside_record:
record_state.record_nesting_depth -= 1
if record_state.record_nesting_depth == 0:
return self._finish_record(record_state)
return None
if char == "[":
if record_state.inside_record:
record_state.record_nesting_depth += 1
else:
results_state.array_nesting_depth += 1
return None
if char == "]":
if record_state.inside_record:
record_state.record_nesting_depth -= 1
else:
results_state.array_nesting_depth -= 1
if results_state.array_nesting_depth == 0:
results_state.inside_results_array = False
return None
@staticmethod
def _append_to_current_record_if_any(char: str, record_state: RecordParseState):
if record_state.inside_record:
record_state.record_text_buffer.append(char)
@staticmethod
def _start_record(record_state: RecordParseState):
record_state.inside_record = True
record_state.record_text_buffer = ["{"]
record_state.record_nesting_depth = 1
@staticmethod
def _finish_record(record_state: RecordParseState) -> Optional[Dict[str, Any]]:
text = "".join(record_state.record_text_buffer).strip()
record_state.inside_record = False
record_state.record_text_buffer.clear()
record_state.record_nesting_depth = 0
return json.loads(text) if text else None
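
# Illustrative streaming parse (hedged sketch): the incremental parser is exercised
# directly with pre-chunked bytes; decode() itself takes the fast json path for
# responses under max_direct_decode_bytes.
#
#   decoder = GoogleAdsStreamingDecoder()
#   chunks = [b'{"results": [{"a": 1}, ', b'{"b": 2}], "fieldMask": "a,b"}']
#   list(decoder._parse_records_from_stream(chunks))
#   # -> [{"a": 1}, {"b": 2}]
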