Support Arbitrary Catalog IDs on Athena Data Source (#7059)

Co-authored-by: SeongTae Jeong <seongtaejg@gmail.com>
This commit is contained in:
Daisuke Taniwaki
2024-07-24 15:57:27 +09:00
committed by GitHub
parent 80f7ba1b91
commit c244e75352
2 changed files with 111 additions and 5 deletions

View File

@@ -76,6 +76,10 @@ class Athena(BaseQueryRunner):
"default": "default",
},
"glue": {"type": "boolean", "title": "Use Glue Data Catalog"},
"catalog_ids": {
"type": "string",
"title": "Enter Glue Data Catalog IDs, separated by commas (leave blank for default catalog)",
},
"work_group": {
"type": "string",
"title": "Athena Work Group",
@@ -88,7 +92,7 @@ class Athena(BaseQueryRunner):
},
},
"required": ["region", "s3_staging_dir"],
"extra_options": ["glue", "cost_per_tb"],
"extra_options": ["glue", "catalog_ids", "cost_per_tb"],
"order": [
"region",
"s3_staging_dir",
@@ -172,16 +176,23 @@ class Athena(BaseQueryRunner):
"region_name": self.configuration["region"],
}
def __get_schema_from_glue(self):
def __get_schema_from_glue(self, catalog_id=""):
client = boto3.client("glue", **self._get_iam_credentials())
schema = {}
database_paginator = client.get_paginator("get_databases")
table_paginator = client.get_paginator("get_tables")
for databases in database_paginator.paginate():
databases_iterator = database_paginator.paginate(
**({"CatalogId": catalog_id} if catalog_id != "" else {}),
)
for databases in databases_iterator:
for database in databases["DatabaseList"]:
iterator = table_paginator.paginate(DatabaseName=database["Name"])
iterator = table_paginator.paginate(
DatabaseName=database["Name"],
**({"CatalogId": catalog_id} if catalog_id != "" else {}),
)
for table in iterator.search("TableList[]"):
table_name = "%s.%s" % (database["Name"], table["Name"])
if "StorageDescriptor" not in table:
@@ -196,7 +207,8 @@ class Athena(BaseQueryRunner):
def get_schema(self, get_stats=False):
if self.configuration.get("glue", False):
return self.__get_schema_from_glue()
catalog_ids = [id.strip() for id in self.configuration.get("catalog_ids", "").split(",")]
return sum([self.__get_schema_from_glue(catalog_id) for catalog_id in catalog_ids], [])
schema = {}
query = """

View File

@@ -221,3 +221,97 @@ class TestGlueSchema(TestCase):
)
with self.stubber:
assert query_runner.get_schema() == []
def test_multi_catalog_tables(self):
"""Tables of multi-catalogs"""
query_runner = Athena({"glue": True, "region": "mars-east-1", "catalog_ids": "foo,bar"})
self.stubber.add_response("get_databases", {"DatabaseList": [{"Name": "test1"}]}, {"CatalogId": "foo"})
self.stubber.add_response(
"get_tables",
{
"TableList": [
{
"Name": "jdbc_table",
"StorageDescriptor": {
"Columns": [{"Name": "row_id", "Type": "int"}],
"Location": "Database.Schema.Table",
"Compressed": False,
"NumberOfBuckets": -1,
"SerdeInfo": {"Parameters": {}},
"BucketColumns": [],
"SortColumns": [],
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
"StoredAsSubDirectories": False,
},
"PartitionKeys": [],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
}
]
},
{"CatalogId": "foo", "DatabaseName": "test1"},
)
self.stubber.add_response("get_databases", {"DatabaseList": [{"Name": "test2"}]}, {"CatalogId": "bar"})
self.stubber.add_response(
"get_tables",
{
"TableList": [
{
"Name": "jdbc_table",
"StorageDescriptor": {
"Columns": [{"Name": "row_id", "Type": "int"}],
"Location": "Database.Schema.Table",
"Compressed": False,
"NumberOfBuckets": -1,
"SerdeInfo": {"Parameters": {}},
"BucketColumns": [],
"SortColumns": [],
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
"StoredAsSubDirectories": False,
},
"PartitionKeys": [],
"TableType": "EXTERNAL_TABLE",
"Parameters": {
"CrawlerSchemaDeserializerVersion": "1.0",
"CrawlerSchemaSerializerVersion": "1.0",
"UPDATED_BY_CRAWLER": "jdbc",
"classification": "sqlserver",
"compressionType": "none",
"connectionName": "jdbctest",
"typeOfData": "view",
},
}
]
},
{"CatalogId": "bar", "DatabaseName": "test2"},
)
with self.stubber:
assert query_runner.get_schema() == [
{"columns": ["row_id"], "name": "test1.jdbc_table"},
{"columns": ["row_id"], "name": "test2.jdbc_table"},
]