194 lines
8.9 KiB
Python
194 lines
8.9 KiB
Python
#
|
|
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
|
|
#
|
|
|
|
|
|
import os
|
|
from typing import Any, Iterator, List, Mapping, MutableMapping, Optional, Tuple, Union
|
|
|
|
import pendulum
|
|
from airbyte_cdk.config_observation import emit_configuration_as_airbyte_control_message
|
|
from airbyte_cdk.models import SyncMode
|
|
from airbyte_cdk.sources import AbstractSource
|
|
from airbyte_cdk.sources.streams import Stream
|
|
from airbyte_cdk.sources.streams.http.requests_native_auth.oauth import SingleUseRefreshTokenOauth2Authenticator
|
|
from airbyte_cdk.sources.streams.http.requests_native_auth.token import TokenAuthenticator
|
|
from requests.auth import AuthBase
|
|
|
|
from .streams import (
|
|
Branches,
|
|
Commits,
|
|
EpicIssues,
|
|
Epics,
|
|
GitlabStream,
|
|
GroupIssueBoards,
|
|
GroupLabels,
|
|
GroupMembers,
|
|
GroupMilestones,
|
|
GroupProjects,
|
|
Groups,
|
|
GroupsList,
|
|
IncludeDescendantGroups,
|
|
Issues,
|
|
Jobs,
|
|
MergeRequestCommits,
|
|
MergeRequests,
|
|
Pipelines,
|
|
PipelinesExtended,
|
|
ProjectLabels,
|
|
ProjectMembers,
|
|
ProjectMilestones,
|
|
Projects,
|
|
Releases,
|
|
Tags,
|
|
Users,
|
|
)
|
|
from .utils import parse_url
|
|
|
|
|
|
class SingleUseRefreshTokenGitlabOAuth2Authenticator(SingleUseRefreshTokenOauth2Authenticator):
    """Single-use-refresh-token OAuth2 authenticator that also tracks token creation time.

    In addition to what the parent class reads from the refresh response, this
    subclass reads a creation timestamp field (``created_at`` by default) and
    derives the expiry date as ``created_at + expires_in``.
    """

    def __init__(self, *args, created_at_name: str = "created_at", **kwargs):
        """Forward all arguments to the parent and remember the creation-time field name.

        :param created_at_name: key in the refresh response holding the token creation timestamp.
        """
        super().__init__(*args, **kwargs)
        self._created_at_name = created_at_name

    def get_created_at_name(self) -> str:
        """Return the response key under which the token creation timestamp is stored."""
        return self._created_at_name

    def get_access_token(self) -> str:
        """Return the current access token, refreshing it first when it has expired.

        On refresh the new access token, refresh token and expiry date are stored,
        and the connector configuration is emitted as an Airbyte control message
        (presumably so the rotated refresh token gets persisted - confirm against the CDK).
        """
        # Fast path: the token we already hold is still valid.
        if not self.token_has_expired():
            return self.access_token

        token, expires_in, created_at, refresh_token = self.refresh_access_token()
        expiry_date = self.get_new_token_expiry_date(expires_in, created_at)

        self.access_token = token
        self.set_refresh_token(refresh_token)
        self.set_token_expiry_date(expiry_date)
        emit_configuration_as_airbyte_control_message(self._connector_config)
        return self.access_token

    @staticmethod
    def get_new_token_expiry_date(access_token_expires_in: int, access_token_created_at: int) -> pendulum.DateTime:
        """Compute the expiry moment as the creation timestamp plus the lifetime.

        :param access_token_expires_in: token lifetime, added to the creation timestamp.
        :param access_token_created_at: creation timestamp as accepted by ``pendulum.from_timestamp``.
        :return: the resulting ``pendulum.DateTime``.
        """
        expires_at = access_token_created_at + access_token_expires_in
        return pendulum.from_timestamp(expires_at)

    def refresh_access_token(self) -> Tuple[str, int, int, str]:
        """Request a token refresh and extract the fields this authenticator needs.

        :return: ``(access_token, expires_in, created_at, refresh_token)`` taken from the refresh response.
        """
        payload = self._get_refresh_access_token_response()
        access_token = payload[self.get_access_token_name()]
        expires_in = payload[self.get_expires_in_name()]
        created_at = payload[self.get_created_at_name()]
        refresh_token = payload[self.get_refresh_token_name()]
        return access_token, expires_in, created_at, refresh_token
|
|
|
|
|
|
def get_authenticator(config: MutableMapping) -> AuthBase:
    """Build the request authenticator that matches the configured credentials.

    :param config: connector configuration; must contain ``credentials`` with an
        ``auth_type`` key and, for the OAuth path, an ``api_url`` key.
    :return: a ``TokenAuthenticator`` when ``auth_type`` is ``access_token``,
        otherwise a ``SingleUseRefreshTokenGitlabOAuth2Authenticator`` whose refresh
        endpoint is ``https://<api_url>/oauth/token``.
    """
    credentials = config["credentials"]
    if credentials["auth_type"] == "access_token":
        return TokenAuthenticator(token=credentials["access_token"])

    # Any other auth type is handled as OAuth with single-use refresh tokens.
    refresh_endpoint = f"https://{config['api_url']}/oauth/token"
    return SingleUseRefreshTokenGitlabOAuth2Authenticator(config, token_refresh_endpoint=refresh_endpoint)
|
|
|
|
|
|
class SourceGitlab(AbstractSource):
    """Airbyte source for GitLab: checks the connection and builds the list of streams.

    The authenticator parameters and the groups/projects parent streams are built
    lazily on first use and cached on the instance, so every stream created by
    this source shares the same objects.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily populated caches, filled by _auth_params / _groups_stream / _projects_stream.
        self.__auth_params: Mapping[str, Any] = {}
        self.__groups_stream: Optional[GitlabStream] = None
        self.__projects_stream: Optional[GitlabStream] = None

    @staticmethod
    def _ensure_default_values(config: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        """Set ``api_url`` to ``gitlab.com`` when it is missing or empty.

        The config is modified in place and also returned for convenience.
        """
        config["api_url"] = config.get("api_url") or "gitlab.com"
        return config

    @staticmethod
    def _split_ids(raw: Optional[str]) -> List[str]:
        """Split a space-separated ID string into a list, dropping empty items.

        ``None`` (key absent or explicitly null in the config) yields an empty list.
        """
        return [item for item in (raw or "").split(" ") if item]

    def _groups_stream(self, config: MutableMapping[str, Any]) -> Groups:
        """Return the cached ``Groups`` stream, creating it on first call.

        Creation resolves the group IDs by fully consuming ``_get_group_list``.
        """
        if not self.__groups_stream:
            auth_params = self._auth_params(config)
            group_ids = [group["id"] for group in self._get_group_list(config)]
            self.__groups_stream = Groups(group_ids=group_ids, **auth_params)
        return self.__groups_stream

    def _projects_stream(self, config: MutableMapping[str, Any]) -> Union[Projects, GroupProjects]:
        """Return the cached projects stream, creating it on first call.

        When the groups stream has any group IDs, projects are read through
        ``GroupProjects`` (with the groups stream as parent); otherwise a plain
        ``Projects`` stream over the configured project IDs is used.
        """
        if not self.__projects_stream:
            auth_params = self._auth_params(config)
            project_ids = self._split_ids(config.get("projects"))
            groups_stream = self._groups_stream(config)
            if groups_stream.group_ids:
                self.__projects_stream = GroupProjects(project_ids=project_ids, parent_stream=groups_stream, **auth_params)
            else:
                self.__projects_stream = Projects(project_ids=project_ids, **auth_params)
        return self.__projects_stream

    def _auth_params(self, config: MutableMapping[str, Any]) -> Mapping[str, Any]:
        """Return the cached keyword arguments (``authenticator``, ``api_url``) shared by all streams."""
        if not self.__auth_params:
            auth = get_authenticator(config)
            self.__auth_params = dict(authenticator=auth, api_url=config["api_url"])
        return self.__auth_params

    def _get_group_list(self, config: MutableMapping[str, Any]) -> Iterator[Mapping[str, Any]]:
        """Lazily yield group records for the configured groups.

        This is a generator: nothing is requested until the caller iterates.
        Each yielded record is indexed by ``"id"`` by the caller.
        """
        group_ids = self._split_ids(config.get("groups"))
        # Gitlab exposes different APIs to get a list of groups.
        # We use https://docs.gitlab.com/ee/api/groups.html#list-groups in case there's no group IDs in the input config.
        # This API provides full information about all available groups, including subgroups.
        #
        # In case there is a definitive list of groups IDs in the input config, the above API can not be used since
        # it does not support filtering by group ID, so we use
        # https://docs.gitlab.com/ee/api/groups.html#details-of-a-group and
        # https://docs.gitlab.com/ee/api/groups.html#list-a-groups-descendant-groups for each group ID. The latter one does not
        # provide full group info so can only be used to retrieve a list of group IDs and pass it further to init a corresponding stream.
        auth_params = self._auth_params(config)
        stream = GroupsList(**auth_params) if not group_ids else IncludeDescendantGroups(group_ids=group_ids, **auth_params)
        for stream_slice in stream.stream_slices(sync_mode=SyncMode.full_refresh):
            yield from stream.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice)

    @staticmethod
    def _is_http_allowed() -> bool:
        """Return True unless the ``DEPLOYMENT_MODE`` environment variable is ``CLOUD`` (case-insensitive)."""
        return os.environ.get("DEPLOYMENT_MODE", "").upper() != "CLOUD"

    def check_connection(self, logger, config) -> Tuple[bool, Any]:
        """Validate the configured API URL and try to read one project record.

        :param logger: unused here; part of the ``AbstractSource`` interface.
        :param config: connector configuration (defaults are filled in place).
        :return: ``(True, None)`` on success, or ``(False, message)`` describing the failure.
        """
        config = self._ensure_default_values(config)
        is_valid, scheme, _ = parse_url(config["api_url"])
        if not is_valid:
            return False, "Invalid API resource locator."
        if scheme == "http" and not self._is_http_allowed():
            return False, "Http scheme is not allowed in this environment. Please use `https` instead."
        try:
            projects = self._projects_stream(config)
            for stream_slice in projects.stream_slices(sync_mode=SyncMode.full_refresh):
                try:
                    next(projects.read_records(sync_mode=SyncMode.full_refresh, stream_slice=stream_slice))
                except StopIteration:
                    # A slice that yields no records is not a connection failure
                    # (e.g. a group without projects); try the next slice instead of
                    # letting StopIteration be reported as an error below.
                    continue
                return True, None
            return True, None  # in case there's no projects
        except Exception as error:
            return False, f"Unable to connect to Gitlab API with the provided credentials - {repr(error)}"

    def streams(self, config: MutableMapping[str, Any]) -> List[Stream]:
        """Build every stream exposed by this source.

        ``config["start_date"]` is read directly, so it must be present.
        Streams that serve as parents of other streams (groups, projects,
        pipelines, merge requests, epics) are created once and shared.
        """
        config = self._ensure_default_values(config)
        auth_params = self._auth_params(config)

        groups, projects = self._groups_stream(config), self._projects_stream(config)
        pipelines = Pipelines(parent_stream=projects, start_date=config["start_date"], **auth_params)
        merge_requests = MergeRequests(parent_stream=projects, start_date=config["start_date"], **auth_params)
        epics = Epics(parent_stream=groups, **auth_params)

        streams = [
            groups,
            projects,
            Branches(parent_stream=projects, repository_part=True, **auth_params),
            Commits(parent_stream=projects, repository_part=True, start_date=config["start_date"], **auth_params),
            epics,
            EpicIssues(parent_stream=epics, **auth_params),
            GroupIssueBoards(parent_stream=groups, **auth_params),
            Issues(parent_stream=projects, start_date=config["start_date"], **auth_params),
            Jobs(parent_stream=pipelines, **auth_params),
            ProjectMilestones(parent_stream=projects, **auth_params),
            GroupMilestones(parent_stream=groups, **auth_params),
            ProjectMembers(parent_stream=projects, **auth_params),
            GroupMembers(parent_stream=groups, **auth_params),
            ProjectLabels(parent_stream=projects, **auth_params),
            GroupLabels(parent_stream=groups, **auth_params),
            merge_requests,
            MergeRequestCommits(parent_stream=merge_requests, **auth_params),
            Releases(parent_stream=projects, **auth_params),
            Tags(parent_stream=projects, repository_part=True, **auth_params),
            pipelines,
            PipelinesExtended(parent_stream=pipelines, **auth_params),
            Users(parent_stream=projects, **auth_params),
        ]

        return streams
|