From 2dbf8ec0adf61d2247c491d89d1c4a57dd1d9364 Mon Sep 17 00:00:00 2001 From: trevorhobenshield Date: Mon, 5 Jun 2023 20:23:35 -0700 Subject: [PATCH] refactor logger --- setup.py | 2 +- twitter/account.py | 7 ++++--- twitter/constants.py | 5 +---- twitter/scraper.py | 44 ++++++++++++++++++++++++++--------------- twitter/search.py | 47 ++++++++++++++++++++++---------------------- 5 files changed, 57 insertions(+), 48 deletions(-) diff --git a/setup.py b/setup.py index 3e934cc..92f2f19 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ install_requires = [ setup( name="twitter-api-client", - version="0.8.7", + version="0.8.8", python_requires=">=3.10.10", description="Twitter API", long_description=dedent(''' diff --git a/twitter/account.py b/twitter/account.py index f664f73..82808a4 100644 --- a/twitter/account.py +++ b/twitter/account.py @@ -101,7 +101,7 @@ class Account: variables['message']['text'] = {'text': text} res = self.gql('POST', Operation.useSendMessageMutation, variables) if find_key(res, 'dm_validation_failure_type'): - logger.debug(f"{RED}Failed to send DM(s) to {receivers}{RESET}") + self.logger.debug(f"{RED}Failed to send DM(s) to {receivers}{RESET}") return res def tweet(self, text: str, *, media: any = None, **kwargs) -> dict: @@ -560,8 +560,9 @@ class Account: def init_logger(cfg: dict) -> Logger: if cfg: logging.config.dictConfig(cfg) - return logging.getLogger(__name__) - return logger + else: + logging.config.dictConfig(LOG_CONFIG) + return logging.getLogger(__name__) @staticmethod def validate_session(*args, **kwargs): diff --git a/twitter/constants.py b/twitter/constants.py index 38afa8c..12a2d13 100644 --- a/twitter/constants.py +++ b/twitter/constants.py @@ -35,7 +35,7 @@ DISABLE_LOG_PROPAGATION = [ 'uvloop', ] -LOGGING_CONFIG = { +LOG_CONFIG = { 'version': 1, 'disable_existing_loggers': False, 'formatters': { @@ -73,9 +73,6 @@ LOGGING_CONFIG = { }, } -logging.config.dictConfig(LOGGING_CONFIG) -logger = logging.getLogger(__name__) - 
@dataclass class SpaceCategory: diff --git a/twitter/scraper.py b/twitter/scraper.py index 40160b3..9911534 100644 --- a/twitter/scraper.py +++ b/twitter/scraper.py @@ -35,8 +35,9 @@ class Scraper: self.guest = False self.logger = self.init_logger(kwargs.get('log_config', False)) self.session = self.validate_session(email, username, password, session, **kwargs) - self.save = kwargs.get('save', True) self.debug = kwargs.get('debug', 0) + self.save = kwargs.get('save', True) + self.pbar = kwargs.get('pbar', True) self.out_path = Path('data') def users(self, screen_names: list[str], **kwargs) -> list[dict]: @@ -120,8 +121,10 @@ class Scraper: async def process(): async with AsyncClient(headers=self.session.headers, cookies=self.session.cookies) as client: - return await tqdm_asyncio.gather(*(download(client, x, y) for x, y in urls), - desc='downloading media') + tasks = (download(client, x, y) for x, y in urls) + if self.pbar: + return await tqdm_asyncio.gather(*tasks, desc='Downloading media') + return await asyncio.gather(*tasks) async def download(client: AsyncClient, post_url: str, cdn_url: str) -> None: name = urlsplit(post_url).path.replace('/', '_')[1:] @@ -154,10 +157,10 @@ class Scraper: "-0200", "-0100", "+0000", "+0100", "+0200", "+0300", "+0400", "+0500", "+0600", "+0700", "+0800", "+0900", "+1000", "+1100", "+1200", "+1300", "+1400"] async with AsyncClient(headers=get_headers(self.session)) as client: - return await tqdm_asyncio.gather( - *(get_trends(client, o, url) for o in offsets), - desc='downloading media' - ) + tasks = (get_trends(client, o, url) for o in offsets) + if self.pbar: + return await tqdm_asyncio.gather(*tasks, desc='Getting trends') + return await asyncio.gather(*tasks) trends = asyncio.run(process()) out = self.out_path / 'raw' / 'trends' @@ -294,7 +297,10 @@ class Scraper: headers = self.session.headers if self.guest else get_headers(self.session) cookies = self.session.cookies async with AsyncClient(limits=limits, headers=headers, 
cookies=cookies, timeout=20) as c: - return await tqdm_asyncio.gather(*(get(c, key) for key in keys), desc='downloading chat') + tasks = (get(c, key) for key in keys) + if self.pbar: + return await tqdm_asyncio.gather(*tasks, desc='Downloading chat data') + return await asyncio.gather(*tasks) return asyncio.run(process()) @@ -311,7 +317,9 @@ class Scraper: tasks = [] for d in data: tasks.extend([get(c, chunk, d['rest_id']) for chunk in d['chunks']]) - return await tqdm_asyncio.gather(*tasks, desc='downloading audio') + if self.pbar: + return await tqdm_asyncio.gather(*tasks, desc='Downloading audio') + return await asyncio.gather(*tasks) chunks = asyncio.run(process(data)) streams = {} @@ -381,10 +389,10 @@ class Scraper: headers = self.session.headers if self.guest else get_headers(self.session) cookies = self.session.cookies async with AsyncClient(limits=limits, headers=headers, cookies=cookies, timeout=20) as c: - return await tqdm_asyncio.gather( - *(self._paginate(c, operation, **q, **kwargs) for q in queries), - desc=operation[-1], - ) + tasks = (self._paginate(c, operation, **q, **kwargs) for q in queries) + if self.pbar: + return await tqdm_asyncio.gather(*tasks, desc=operation[-1]) + return await asyncio.gather(*tasks) async def _paginate(self, client: AsyncClient, operation: tuple, **kwargs): limit = kwargs.pop('limit', math.inf) @@ -507,7 +515,10 @@ class Scraper: limits = Limits(max_connections=100) async with AsyncClient(headers=client.headers, limits=limits, timeout=30) as c: - return await tqdm_asyncio.gather(*(get(c, _id) for _id in spaces), desc='getting live transcripts') + tasks = (get(c, _id) for _id in spaces) + if self.pbar: + return await tqdm_asyncio.gather(*tasks, desc='Getting live transcripts') + return await asyncio.gather(*tasks) def space_live_transcript(self, room: str, frequency: int = 1): async def get(spaces: list[dict]): @@ -592,8 +603,9 @@ class Scraper: def init_logger(cfg: dict) -> Logger: if cfg: logging.config.dictConfig(cfg) 
- return logging.getLogger(__name__) - return logger + else: + logging.config.dictConfig(LOG_CONFIG) + return logging.getLogger(__name__) def validate_session(self, *args, **kwargs): email, username, password, session = args diff --git a/twitter/search.py b/twitter/search.py index 1313191..68ac2e3 100644 --- a/twitter/search.py +++ b/twitter/search.py @@ -41,26 +41,6 @@ class Search: self.api = 'https://api.twitter.com/2/search/adaptive.json?' self.save = kwargs.get('save', True) self.debug = kwargs.get('debug', 0) - self.logger = self.init_logger(kwargs.get('log_config', False)) - - @staticmethod - def init_logger(cfg: dict) -> Logger: - if cfg: - logging.config.dictConfig(cfg) - return logging.getLogger(__name__) - return logger - - @staticmethod - def validate_session(*args, **kwargs): - email, username, password, session = args - if session and all(session.cookies.get(c) for c in {'ct0', 'auth_token'}): - # authenticated session provided - return session - if not session: - # no session provided, login to authenticate - return login(email, username, password, **kwargs) - raise Exception('Session not authenticated. 
' - 'Please use an authenticated session or remove the `session` argument and try again.') def run(self, *args, out: str = 'data', **kwargs): out_path = self.make_output_dirs(out) @@ -72,8 +52,7 @@ class Search: async with AsyncClient(headers=get_headers(self.session)) as s: return await asyncio.gather(*(self.paginate(q, s, config, out, **kwargs) for q in queries)) - async def paginate(self, query: str, session: AsyncClient, config: dict, out: Path, **kwargs) -> list[ - dict]: + async def paginate(self, query: str, session: AsyncClient, config: dict, out: Path, **kwargs) -> list[dict]: config['q'] = query data, next_cursor = await self.backoff(lambda: self.get(session, config), query, **kwargs) all_data = [data] @@ -84,7 +63,7 @@ class Search: if len(ids) >= kwargs.get('limit', math.inf): if self.debug: self.logger.debug( - f'[{GREEN}success{RESET}] returned {len(ids)} search results for {c}{query}{reset}') + f'[{GREEN}success{RESET}] Returned {len(ids)} search results for {c}{query}{reset}') return all_data if self.debug: self.logger.debug(f'{c}{query}{reset}') @@ -120,7 +99,7 @@ class Search: t = 2 ** i + random.random() if self.debug: self.logger.debug( - f'No data for: \u001b[1m{info}\u001b[0m | retrying in {f"{t:.2f}"} seconds\t\t{e}') + f'No data for: {BOLD}{info}{RESET}, retrying in {f"{t:.2f}"} seconds\t\t{e}') time.sleep(t) async def get(self, session: AsyncClient, params: dict) -> tuple: @@ -154,3 +133,23 @@ class Search: (p / 'processed').mkdir(parents=True, exist_ok=True) (p / 'final').mkdir(parents=True, exist_ok=True) return p + + @staticmethod + def init_logger(cfg: dict) -> Logger: + if cfg: + logging.config.dictConfig(cfg) + else: + logging.config.dictConfig(LOG_CONFIG) + return logging.getLogger(__name__) + + @staticmethod + def validate_session(*args, **kwargs): + email, username, password, session = args + if session and all(session.cookies.get(c) for c in {'ct0', 'auth_token'}): + # authenticated session provided + return session + if not 
session: + # no session provided, log in to authenticate + return login(email, username, password, **kwargs) + raise Exception('Session not authenticated. ' + 'Please use an authenticated session or remove the `session` argument and try again.')