🎉 Source GitHub: Add option to pull commits from user-specified branches (#6223)
* Add option to pull commits from user-specified branches * Address comments part 1 * Fix Repositories stream error when repo is not part of an org * Make compatible with old state version and fix request_params to use branch-specific value Co-authored-by: Chris Wu <chris@faros.ai>
This commit is contained in:
@@ -12,5 +12,5 @@ RUN pip install .
|
||||
ENV AIRBYTE_ENTRYPOINT "python /airbyte/integration_code/main.py"
|
||||
ENTRYPOINT ["python", "/airbyte/integration_code/main.py"]
|
||||
|
||||
LABEL io.airbyte.version=0.2.0
|
||||
LABEL io.airbyte.version=0.2.1
|
||||
LABEL io.airbyte.name=airbyte/source-github
|
||||
|
||||
@@ -19,7 +19,7 @@ tests:
|
||||
cursor_paths:
|
||||
comments: ["airbytehq/integration-test", "updated_at"]
|
||||
commit_comments: ["airbytehq/integration-test", "updated_at"]
|
||||
commits: ["airbytehq/integration-test", "created_at"]
|
||||
commits: ["airbytehq/integration-test", "master", "created_at"]
|
||||
events: ["airbytehq/integration-test", "created_at"]
|
||||
issue_events: ["airbytehq/integration-test", "created_at"]
|
||||
issue_milestones: ["airbytehq/integration-test", "updated_at"]
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
|
||||
|
||||
import re
|
||||
from typing import Any, List, Mapping, Tuple
|
||||
from typing import Any, Dict, List, Mapping, Tuple
|
||||
|
||||
from airbyte_cdk import AirbyteLogger
|
||||
from airbyte_cdk.models import SyncMode
|
||||
@@ -88,6 +88,43 @@ class SourceGithub(AbstractSource):
|
||||
tokens = [t.strip() for t in token.split(TOKEN_SEPARATOR)]
|
||||
return MultipleTokenAuthenticator(tokens=tokens, auth_method="token")
|
||||
|
||||
@staticmethod
|
||||
def _get_branches_data(selected_branches: str, full_refresh_args: Dict[str, Any] = None) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
|
||||
selected_branches = set(filter(None, selected_branches.split(" ")))
|
||||
|
||||
# Get the default branch for each repository
|
||||
default_branches = {}
|
||||
repository_stats_stream = RepositoryStats(**full_refresh_args)
|
||||
for stream_slice in repository_stats_stream.stream_slices(sync_mode=SyncMode.full_refresh):
|
||||
default_branches.update(
|
||||
{
|
||||
repo_stats["full_name"]: repo_stats["default_branch"]
|
||||
for repo_stats in repository_stats_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice)
|
||||
}
|
||||
)
|
||||
|
||||
all_branches = []
|
||||
branches_stream = Branches(**full_refresh_args)
|
||||
for stream_slice in branches_stream.stream_slices(sync_mode=SyncMode.full_refresh):
|
||||
for branch in branches_stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice):
|
||||
all_branches.append(f"{branch['repository']}/{branch['name']}")
|
||||
|
||||
# Create mapping of repository to list of branches to pull commits for
|
||||
# If no branches are specified for a repo, use its default branch
|
||||
branches_to_pull: Dict[str, List[str]] = {}
|
||||
for repo in full_refresh_args["repositories"]:
|
||||
repo_branches = []
|
||||
for branch in selected_branches:
|
||||
branch_parts = branch.split("/", 2)
|
||||
if "/".join(branch_parts[:2]) == repo and branch in all_branches:
|
||||
repo_branches.append(branch_parts[-1])
|
||||
if not repo_branches:
|
||||
repo_branches = [default_branches[repo]]
|
||||
|
||||
branches_to_pull[repo] = repo_branches
|
||||
|
||||
return default_branches, branches_to_pull
|
||||
|
||||
def check_connection(self, logger: AirbyteLogger, config: Mapping[str, Any]) -> Tuple[bool, Any]:
|
||||
try:
|
||||
authenticator = self._get_authenticator(config["access_token"])
|
||||
@@ -110,6 +147,7 @@ class SourceGithub(AbstractSource):
|
||||
full_refresh_args = {"authenticator": authenticator, "repositories": repositories}
|
||||
incremental_args = {**full_refresh_args, "start_date": config["start_date"]}
|
||||
organization_args = {"authenticator": authenticator, "organizations": organizations}
|
||||
default_branches, branches_to_pull = self._get_branches_data(config.get("branch", ""), full_refresh_args)
|
||||
|
||||
return [
|
||||
Assignees(**full_refresh_args),
|
||||
@@ -118,7 +156,7 @@ class SourceGithub(AbstractSource):
|
||||
Comments(**incremental_args),
|
||||
CommitCommentReactions(**incremental_args),
|
||||
CommitComments(**incremental_args),
|
||||
Commits(**incremental_args),
|
||||
Commits(**incremental_args, branches_to_pull=branches_to_pull, default_branches=default_branches),
|
||||
Events(**incremental_args),
|
||||
IssueCommentReactions(**incremental_args),
|
||||
IssueEvents(**incremental_args),
|
||||
|
||||
@@ -23,6 +23,11 @@
|
||||
"description": "The date from which you'd like to replicate data for GitHub in the format YYYY-MM-DDT00:00:00Z. All data generated after this date will be replicated. Note that it will be used only in the following incremental streams: comments, commits and issues.",
|
||||
"examples": ["2021-03-01T00:00:00Z"],
|
||||
"pattern": "^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}Z$"
|
||||
},
|
||||
"branch": {
|
||||
"type": "string",
|
||||
"examples": ["airbytehq/airbyte/master"],
|
||||
"description": "Space-delimited list of GitHub repository branches to pull commits for, e.g. `airbytehq/airbyte/master`. If no branches are specified for a repository, the default branch will be pulled."
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -137,8 +137,12 @@ class GithubStream(HttpStream, ABC):
|
||||
)
|
||||
elif e.response.status_code == requests.codes.NOT_FOUND and "/teams?" in error_msg:
|
||||
# For private repositories `Teams` stream is not available and we get "404 Client Error: Not Found for
|
||||
# url: https://api.github.com/orgs/sherifnada/teams?per_page=100" error.
|
||||
# url: https://api.github.com/orgs/<org_name>/teams?per_page=100" error.
|
||||
error_msg = f"Syncing `Team` stream isn't available for organization `{stream_slice['organization']}`."
|
||||
elif e.response.status_code == requests.codes.NOT_FOUND and "/repos?" in error_msg:
|
||||
# `Repositories` stream is not available for repositories not in an organization.
|
||||
# Handle "404 Client Error: Not Found for url: https://api.github.com/orgs/<org_name>/repos?per_page=100" error.
|
||||
error_msg = f"Syncing `Repositories` stream isn't available for organization `{stream_slice['organization']}`."
|
||||
elif e.response.status_code == requests.codes.GONE and "/projects?" in error_msg:
|
||||
# Some repos don't have projects enabled and we we get "410 Client Error: Gone for
|
||||
# url: https://api.github.com/repos/xyz/projects?per_page=100" error.
|
||||
@@ -618,23 +622,105 @@ class Comments(IncrementalGithubStream):
|
||||
|
||||
class Commits(IncrementalGithubStream):
|
||||
"""
|
||||
API docs: https://docs.github.com/en/rest/reference/issues#list-issue-comments-for-a-repository
|
||||
API docs: https://docs.github.com/en/rest/reference/repos#list-commits
|
||||
|
||||
Pull commits from each branch of each repository, tracking state for each branch
|
||||
"""
|
||||
|
||||
primary_key = "sha"
|
||||
cursor_field = "created_at"
|
||||
|
||||
def transform(self, record: MutableMapping[str, Any], repository: str = None, **kwargs) -> MutableMapping[str, Any]:
|
||||
def __init__(self, branches_to_pull: Mapping[str, List[str]], default_branches: Mapping[str, str], **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.branches_to_pull = branches_to_pull
|
||||
self.default_branches = default_branches
|
||||
|
||||
def request_params(self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, Any] = None, **kwargs) -> MutableMapping[str, Any]:
|
||||
params = super(IncrementalGithubStream, self).request_params(stream_state=stream_state, stream_slice=stream_slice, **kwargs)
|
||||
params["since"] = self.get_starting_point(
|
||||
stream_state=stream_state, repository=stream_slice["repository"], branch=stream_slice["branch"]
|
||||
)
|
||||
params["sha"] = stream_slice["branch"]
|
||||
return params
|
||||
|
||||
def stream_slices(self, **kwargs) -> Iterable[Optional[Mapping[str, Any]]]:
|
||||
for stream_slice in super().stream_slices(**kwargs):
|
||||
repository = stream_slice["repository"]
|
||||
for branch in self.branches_to_pull.get(repository, []):
|
||||
yield {"branch": branch, "repository": repository}
|
||||
|
||||
def parse_response(
|
||||
self,
|
||||
response: requests.Response,
|
||||
stream_state: Mapping[str, Any],
|
||||
stream_slice: Mapping[str, Any] = None,
|
||||
next_page_token: Mapping[str, Any] = None,
|
||||
) -> Iterable[Mapping]:
|
||||
for record in response.json(): # GitHub puts records in an array.
|
||||
yield self.transform(record=record, repository=stream_slice["repository"], branch=stream_slice["branch"])
|
||||
|
||||
def transform(self, record: MutableMapping[str, Any], repository: str = None, branch: str = None, **kwargs) -> MutableMapping[str, Any]:
|
||||
record = super().transform(record=record, repository=repository)
|
||||
|
||||
# Record of the `commits` stream doesn't have an updated_at/created_at field at the top level (so we could
|
||||
# just write `record["updated_at"]` or `record["created_at"]`). Instead each record has such value in
|
||||
# `commit.author.date`. So the easiest way is to just enrich the record returned from API with top level
|
||||
# field `created_at` and use it as cursor_field.
|
||||
# Include the branch in the record
|
||||
record["created_at"] = record["commit"]["author"]["date"]
|
||||
record["branch"] = branch
|
||||
|
||||
return record
|
||||
|
||||
def get_updated_state(self, current_stream_state: MutableMapping[str, Any], latest_record: Mapping[str, Any]):
|
||||
state_value = latest_cursor_value = latest_record.get(self.cursor_field)
|
||||
current_repository = latest_record["repository"]
|
||||
current_branch = latest_record["branch"]
|
||||
|
||||
if current_stream_state.get(current_repository):
|
||||
repository_commits_state = current_stream_state[current_repository]
|
||||
if repository_commits_state.get(self.cursor_field):
|
||||
# transfer state from old source version to per-branch version
|
||||
if current_branch == self.default_branches[current_repository]:
|
||||
state_value = max(latest_cursor_value, repository_commits_state[self.cursor_field])
|
||||
del repository_commits_state[self.cursor_field]
|
||||
elif repository_commits_state.get(current_branch, {}).get(self.cursor_field):
|
||||
state_value = max(latest_cursor_value, repository_commits_state[current_branch][self.cursor_field])
|
||||
|
||||
if current_repository not in current_stream_state:
|
||||
current_stream_state[current_repository] = {}
|
||||
|
||||
current_stream_state[current_repository][current_branch] = {self.cursor_field: state_value}
|
||||
return current_stream_state
|
||||
|
||||
def get_starting_point(self, stream_state: Mapping[str, Any], repository: str, branch: str) -> str:
|
||||
start_point = self._start_date
|
||||
if stream_state and stream_state.get(repository, {}).get(branch, {}).get(self.cursor_field):
|
||||
return max(start_point, stream_state[repository][branch][self.cursor_field])
|
||||
if branch == self.default_branches[repository]:
|
||||
return super().get_starting_point(stream_state=stream_state, repository=repository)
|
||||
return start_point
|
||||
|
||||
def read_records(
|
||||
self,
|
||||
sync_mode: SyncMode,
|
||||
cursor_field: List[str] = None,
|
||||
stream_slice: Mapping[str, Any] = None,
|
||||
stream_state: Mapping[str, Any] = None,
|
||||
) -> Iterable[Mapping[str, Any]]:
|
||||
repository = stream_slice["repository"]
|
||||
start_point_map = {
|
||||
branch: self.get_starting_point(stream_state=stream_state, repository=repository, branch=branch)
|
||||
for branch in self.branches_to_pull.get(repository, [])
|
||||
}
|
||||
for record in super(SemiIncrementalGithubStream, self).read_records(
|
||||
sync_mode=sync_mode, cursor_field=cursor_field, stream_slice=stream_slice, stream_state=stream_state
|
||||
):
|
||||
if record.get(self.cursor_field) > start_point_map[stream_slice["branch"]]:
|
||||
yield record
|
||||
elif self.is_sorted_descending and record.get(self.cursor_field) < start_point_map[stream_slice["branch"]]:
|
||||
break
|
||||
|
||||
|
||||
class Issues(IncrementalGithubStream):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user