1
0
mirror of synced 2025-12-25 02:09:19 -05:00

🎉 Source GitHub: Add option to pull commits from user-specified branches (#6223)

* Add option to pull commits from user-specified branches

* Address comments part 1

* Fix Repositories stream error when repo is not part of an org

* Make compatible with old state version and fix request_params to use branch-specific value

Co-authored-by: Chris Wu <chris@faros.ai>
This commit is contained in:
Yevhenii
2021-09-22 11:56:23 +03:00
committed by GitHub
parent 8de72d566b
commit 1644f9016a
8 changed files with 146 additions and 16 deletions

View File

@@ -12,5 +12,5 @@ RUN pip install .
ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py"
ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]
LABEL io.airbyte.version=0.2.0
LABEL io.airbyte.version=0.2.1
LABEL io.airbyte.name=airbyte/source-github

View File

@@ -19,7 +19,7 @@ tests:
cursor_paths:
comments: ["airbytehq/integration-test", "updated_at"]
commit_comments: ["airbytehq/integration-test", "updated_at"]
commits: ["airbytehq/integration-test", "created_at"]
commits: ["airbytehq/integration-test", "master", "created_at"]
events: ["airbytehq/integration-test", "created_at"]
issue_events: ["airbytehq/integration-test", "created_at"]
issue_milestones: ["airbytehq/integration-test", "updated_at"]

View File

@@ -24,7 +24,7 @@
import re
from typing import Any, List, Mapping, Tuple
from typing import Any, Dict, List, Mapping, Tuple
from airbyte_cdk import AirbyteLogger
from airbyte_cdk.models import SyncMode
@@ -88,6 +88,43 @@ class SourceGithub(AbstractSource):
tokens = [t.strip() for t in token.split(TOKEN_SEPARATOR)]
return MultipleTokenAuthenticator(tokens=tokens, auth_method="token")
@staticmethod
def _get_branches_data(selected_branches: str, full_refresh_args: Dict[str, Any] = None) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
    """
    Work out which branches commits have to be pulled from for each configured repository.

    :param selected_branches: whitespace-delimited `<owner>/<repo>/<branch>` entries from the connector
        config. An empty string or None means "use the default branch of every repository".
    :param full_refresh_args: keyword arguments used to build the `RepositoryStats` and `Branches` streams.
        It must contain the "repositories" list; the `None` default only exists in the signature and is
        not a usable value.
    :return: `(default_branches, branches_to_pull)`: the first maps a repository to its default branch, the
        second maps a repository to the list of branch names to pull commits for.
    """
    # Drop duplicates but keep the order the user wrote, so the per-repository branch order is
    # deterministic between runs (a plain set would not guarantee that).
    requested_branches = list(dict.fromkeys((selected_branches or "").split()))

    # Get the default branch for each repository
    default_branches: Dict[str, str] = {}
    repository_stats_stream = RepositoryStats(**full_refresh_args)
    for stream_slice in repository_stats_stream.stream_slices(sync_mode=SyncMode.full_refresh):
        for repo_stats in repository_stats_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice):
            default_branches[repo_stats["full_name"]] = repo_stats["default_branch"]

    # Every existing branch as "<repository>/<branch>"; a set keeps the membership test below O(1).
    existing_branches = set()
    branches_stream = Branches(**full_refresh_args)
    for stream_slice in branches_stream.stream_slices(sync_mode=SyncMode.full_refresh):
        for branch in branches_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice):
            existing_branches.add(f"{branch['repository']}/{branch['name']}")

    # Group the requested branches that really exist by their "<owner>/<repo>" prefix.
    # Splitting at most twice keeps branch names containing "/" (e.g. "feature/x") intact.
    requested_by_repo: Dict[str, List[str]] = {}
    for entry in requested_branches:
        if entry not in existing_branches:
            continue
        entry_parts = entry.split("/", 2)
        requested_by_repo.setdefault("/".join(entry_parts[:2]), []).append(entry_parts[-1])

    # Create mapping of repository to list of branches to pull commits for
    # If no branches are specified for a repo, use its default branch
    branches_to_pull: Dict[str, List[str]] = {}
    for repo in full_refresh_args["repositories"]:
        repo_branches = list(requested_by_repo.get(repo, []))
        if not repo_branches:
            if repo not in default_branches:
                # No stats record came back for this repository, so there is no default branch to
                # fall back to. Leave it out instead of failing the whole source with a KeyError.
                continue
            repo_branches = [default_branches[repo]]
        branches_to_pull[repo] = repo_branches

    return default_branches, branches_to_pull
def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Any]:
try:
authenticator = self._get_authenticator(config["access_token"])
@@ -110,6 +147,7 @@ class SourceGithub(AbstractSource):
full_refresh_args = {"authenticator": authenticator, "repositories": repositories}
incremental_args = {**full_refresh_args, "start_date": config["start_date"]}
organization_args = {"authenticator": authenticator, "organizations": organizations}
default_branches, branches_to_pull = self._get_branches_data(config.get("branch", ""), full_refresh_args)
return [
Assignees(**full_refresh_args),
@@ -118,7 +156,7 @@ class SourceGithub(AbstractSource):
Comments(**incremental_args),
CommitCommentReactions(**incremental_args),
CommitComments(**incremental_args),
Commits(**incremental_args),
Commits(**incremental_args, branches_to_pull=branches_to_pull, default_branches=default_branches),
Events(**incremental_args),
IssueCommentReactions(**incremental_args),
IssueEvents(**incremental_args),

View File

@@ -23,6 +23,11 @@
"description": "The date from which you'd like to replicate data for GitHub in the format YYYY-MM-DDT00:00:00Z. All data generated after this date will be replicated. Note that it will be used only in the following incremental streams: comments, commits and issues.",
"examples": ["2021-03-01T00:00:00Z"],
"pattern": "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$"
},
"branch": {
"type": "string",
"examples": ["airbytehq/airbyte/master"],
"description": "Space-delimited list of GitHub repository branches to pull commits for, e.g. `airbytehq/airbyte/master`. If no branches are specified for a repository, the default branch will be pulled."
}
}
}

View File

@@ -137,8 +137,12 @@ class GithubStream(HttpStream, ABC):
)
elif e.response.status_code == requests.codes.NOT_FOUND and "/teams?" in error_msg:
# For private repositories `Teams` stream is not available and we get "404 Client Error: Not Found for
# url: https://api.github.com/orgs/sherifnada/teams?per_page=100" error.
# url: https://api.github.com/orgs/<org_name>/teams?per_page=100" error.
error_msg = f"Syncing `Team` stream isn't available for organization `{stream_slice['organization']}`."
elif e.response.status_code == requests.codes.NOT_FOUND and "/repos?" in error_msg:
# `Repositories` stream is not available for repositories not in an organization.
# Handle "404 Client Error: Not Found for url: https://api.github.com/orgs/<org_name>/repos?per_page=100" error.
error_msg = f"Syncing `Repositories` stream isn't available for organization `{stream_slice['organization']}`."
elif e.response.status_code == requests.codes.GONE and "/projects?" in error_msg:
# Some repos don't have projects enabled and we get "410 Client Error: Gone for
# url: https://api.github.com/repos/xyz/projects?per_page=100" error.
@@ -618,23 +622,105 @@ class Comments(IncrementalGithubStream):
class Commits(IncrementalGithubStream):
"""
API docs: https://docs.github.com/en/rest/reference/issues#list-issue-comments-for-a-repository
API docs: https://docs.github.com/en/rest/reference/repos#list-commits
Pull commits from each branch of each repository, tracking state for each branch
"""
primary_key = "sha"
cursor_field = "created_at"
def transform(self, record: MutableMapping[str, Any], repository: str = None, **kwargs) -> MutableMapping[str, Any]:
def __init__(self, branches_to_pull: Mapping[str, List[str]], default_branches: Mapping[str, str], **kwargs):
    """
    :param branches_to_pull: repository -> names of the branches whose commits are read.
    :param default_branches: repository -> name of its default branch.
    :param kwargs: forwarded unchanged to the parent stream constructor.
    """
    super().__init__(**kwargs)
    self.default_branches = default_branches
    self.branches_to_pull = branches_to_pull
def request_params(self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, **kwargs) -> MutableMapping[str, Any]:
    """
    Build the query parameters for one (repository, branch) slice.

    Starts from the parameters produced by the implementation that follows `IncrementalGithubStream`
    in the MRO, then sets `since` to the branch-specific starting point and `sha` to the branch name.
    """
    query = super(IncrementalGithubStream, self).request_params(stream_state=stream_state, stream_slice=stream_slice, **kwargs)
    repository_name = stream_slice["repository"]
    branch_name = stream_slice["branch"]
    query["since"] = self.get_starting_point(stream_state=stream_state, repository=repository_name, branch=branch_name)
    query["sha"] = branch_name
    return query
def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
    """
    Expand every repository slice of the parent stream into one slice per branch to pull.

    Repositories without an entry in `branches_to_pull` produce no slices.
    """
    for repository_slice in super().stream_slices(**kwargs):
        repo_name = repository_slice["repository"]
        yield from ({"branch": branch_name, "repository": repo_name} for branch_name in self.branches_to_pull.get(repo_name, []))
def parse_response(
    self,
    response: requests.Response,
    stream_state: Mapping[str, Any],
    stream_slice: Mapping[str, Any] = None,
    next_page_token: Mapping[str, Any] = None,
) -> Iterable[Mapping]:
    """
    Yield every commit of the response, transformed and tagged with the slice's repository and branch.
    """
    # GitHub puts records in an array.
    yield from (
        self.transform(record=raw_commit, repository=stream_slice["repository"], branch=stream_slice["branch"])
        for raw_commit in response.json()
    )
def transform(self, record: MutableMapping[str, Any], repository: str = None, branch: str = None, **kwargs) -> MutableMapping[str, Any]:
    """
    Apply the parent transformation, then add the top-level `created_at` and `branch` fields.
    """
    enriched = super().transform(record=record, repository=repository)

    enriched.update(
        {
            # A `commits` record has no top-level updated_at/created_at; the timestamp lives in
            # `commit.author.date`. Copy it to `created_at` so it can serve as the cursor_field.
            "created_at": enriched["commit"]["author"]["date"],
            # Include the branch in the record
            "branch": branch,
        }
    )
    return enriched
def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
state_value = latest_cursor_value = latest_record.get(self.cursor_field)
current_repository = latest_record["repository"]
current_branch = latest_record["branch"]
if current_stream_state.get(current_repository):
repository_commits_state = current_stream_state[current_repository]
if repository_commits_state.get(self.cursor_field):
# transfer state from old source version to per-branch version
if current_branch == self.default_branches[current_repository]:
state_value = max(latest_cursor_value, repository_commits_state[self.cursor_field])
del repository_commits_state[self.cursor_field]
elif repository_commits_state.get(current_branch, {}).get(self.cursor_field):
state_value = max(latest_cursor_value, repository_commits_state[current_branch][self.cursor_field])
if current_repository not in current_stream_state:
current_stream_state[current_repository] = {}
current_stream_state[current_repository][current_branch] = {self.cursor_field: state_value}
return current_stream_state
def get_starting_point(self, stream_state: Mapping[str, Any], repository: str, branch: str) -> str:
start_point = self._start_date
if stream_state and stream_state.get(repository, {}).get(branch, {}).get(self.cursor_field):
return max(start_point, stream_state[repository][branch][self.cursor_field])
if branch == self.default_branches[repository]:
return super().get_starting_point(stream_state=stream_state, repository=repository)
return start_point
def read_records(
    self,
    sync_mode: SyncMode,
    cursor_field: List[str] = None,
    stream_slice: Mapping[str, Any] = None,
    stream_state: Mapping[str, Any] = None,
) -> Iterable[Mapping[str, Any]]:
    """
    Read the commits of one (repository, branch) slice, keeping only records newer than the branch's
    starting point.

    Records come from the `read_records` implementation that follows `SemiIncrementalGithubStream` in the
    MRO. When the stream is sorted descending, reading stops at the first record older than the
    starting point.

    :param stream_slice: must contain the "repository" and "branch" keys.
    :param stream_state: saved state used to resolve the branch's starting point.
    """
    # Only the branch of the current slice is relevant, so resolve its starting point once
    # instead of computing one for every branch of the repository.
    start_point = self.get_starting_point(
        stream_state=stream_state, repository=stream_slice["repository"], branch=stream_slice["branch"]
    )

    for record in super(SemiIncrementalGithubStream, self).read_records(
        sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
    ):
        record_cursor = record.get(self.cursor_field)
        if record_cursor > start_point:
            yield record
        elif self.is_sorted_descending and record_cursor < start_point:
            break
class Issues(IncrementalGithubStream):
"""