# mirror of https://github.com/trevorhobenshield/twitter-api-client.git
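"""Async search scraper for Twitter/X's GraphQL SearchTimeline endpoint.

Runs batches of search queries concurrently, follows pagination cursors,
and optionally saves each page of results to disk as JSON.
"""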
import asyncio
import logging.config
import math
import platform
import random
import re
import time
from logging import Logger
from pathlib import Path

import orjson
from httpx import AsyncClient, Client

from .constants import *
from .login import login
from .util import get_headers, find_key, build_params

reset = '\x1b[0m'
colors = [f'\x1b[{i}m' for i in range(31, 37)]

# optional: allow nested event loops (e.g. when running inside Jupyter)
try:
    import nest_asyncio

    nest_asyncio.apply()
except Exception:
    ...

# optional: use uvloop's faster event loop; uvloop does not support Windows
try:
    if platform.system() != 'Windows':
        import uvloop

        uvloop.install()
except Exception:
    ...


class Search:
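    """Scrape Twitter/X search results for a batch of queries.

    Authenticate with (email, username, password) credentials, an
    already-authenticated httpx Client passed as `session`, or cookies
    passed via kwargs (dict or path to a saved cookie file).
    """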

    def __init__(self, email: str = None, username: str = None, password: str = None, session: Client = None, **kwargs):
        self.save = kwargs.get('save', True)
        self.debug = kwargs.get('debug', 0)
        self.logger = self._init_logger(**kwargs)
        self.session = self._validate_session(email, username, password, session, **kwargs)

    def run(self, queries: list[dict], limit: float = math.inf, out: str = 'data/search_results', **kwargs):
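        """Run all searches synchronously; returns one list of entries per query.

        Each query is a dict like {'query': <search terms>, 'category': <tab>},
        where category selects the search tab (e.g. 'Top', 'Latest').
        Pages are also written to `out` as timestamped JSON when save=True.
        """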
        out = Path(out)
        out.mkdir(parents=True, exist_ok=True)
        return asyncio.run(self.process(queries, limit, out, **kwargs))

    async def process(self, queries: list[dict], limit: float, out: Path, **kwargs) -> list:
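        """Fan out all queries concurrently over one shared async client."""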
        async with AsyncClient(headers=get_headers(self.session)) as s:
            return await asyncio.gather(*(self.paginate(s, q, limit, out, **kwargs) for q in queries))

    async def paginate(self, client: AsyncClient, query: dict, limit: float, out: Path, **kwargs) -> list[dict]:
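        """Page through one query via the 'Bottom' cursor until results run out or `limit` is reached."""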
        params = {
            'variables': {
                'count': 20,
                'querySource': 'typed_query',
                'rawQuery': query['query'],
                'product': query['category']
            },
            'features': Operation.default_features,
            'fieldToggles': {'withArticleRichContentState': False},
        }

        res = []
        cursor = ''
        total = set()
        while True:
            if cursor:
                params['variables']['cursor'] = cursor
            data, entries, cursor = await self.backoff(lambda: self.get(client, params), **kwargs)
            res.extend(entries)
            if len(entries) <= 2 or len(total) >= limit:  # just cursors left (no real results), or limit reached
                self.debug and self.logger.debug(
                    f'[{GREEN}success{RESET}] Returned {len(total)} search results for {query["query"]}')
                return res
            total |= set(find_key(entries, 'entryId'))
            self.debug and self.logger.debug(f'{query["query"]}')
            self.save and (out / f'{time.time_ns()}.json').write_bytes(orjson.dumps(entries))

    async def get(self, client: AsyncClient, params: dict) -> tuple:
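        """Fetch a single page and extract its tweet/user entries and next cursor."""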
        _, qid, name = Operation.SearchTimeline
        r = await client.get(f'https://twitter.com/i/api/graphql/{qid}/{name}', params=build_params(params))
        data = r.json()
        cursor = self.get_cursor(data)
        entries = [y for x in find_key(data, 'entries') for y in x if re.search(r'^(tweet|user)-', y['entryId'])]
        # add on query info
        for e in entries:
            e['query'] = params['variables']['rawQuery']
        return data, entries, cursor

    def get_cursor(self, data: list[dict]):
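        """Return the 'Bottom' pagination cursor from a response, if present."""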
        for e in find_key(data, 'content'):
            if e.get('cursorType') == 'Bottom':
                return e['value']

    async def backoff(self, fn, **kwargs):
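        """Call `fn` with exponential backoff; returns (data, entries, cursor), or empty results on repeated failure."""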
        retries = kwargs.get('retries', 3)
        for i in range(retries + 1):
            try:
                data, entries, cursor = await fn()
                if errors := data.get('errors'):
                    for e in errors:
                        self.logger.warning(f'{YELLOW}{e.get("message")}{RESET}')
                    return [], [], ''
                ids = set(find_key(data, 'entryId'))
                if len(ids) >= 2:
                    return data, entries, cursor
            except Exception as e:
                if i == retries:
                    self.logger.debug(f'Max retries exceeded\n{e}')
                    # return an empty result instead of None so callers can still unpack
                    return [], [], ''
                t = 2 ** i + random.random()
                self.logger.debug(f'Retrying in {t:.2f} seconds\t\t{e}')
                await asyncio.sleep(t)
        # retries exhausted without a usable page
        return [], [], ''

    def _init_logger(self, **kwargs) -> Logger:
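        """Configure logging when debug is set; always return a usable logger."""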
        if kwargs.get('debug'):
            cfg = kwargs.get('log_config')
            logging.config.dictConfig(cfg or LOG_CONFIG)

            # only support one logger
            logger_name = list(LOG_CONFIG['loggers'].keys())[0]

            # set level of all other loggers to ERROR
            for name in logging.root.manager.loggerDict:
                if name != logger_name:
                    logging.getLogger(name).setLevel(logging.ERROR)

            return logging.getLogger(logger_name)
        # always return a logger: self.logger is used even when debug is off
        return logging.getLogger(list(LOG_CONFIG['loggers'].keys())[0])

    @staticmethod
    def _validate_session(*args, **kwargs):
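        """Return an authenticated session from credentials, a live session, or cookies."""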
        email, username, password, session = args

        # validate credentials
        if all((email, username, password)):
            return login(email, username, password, **kwargs)

        # invalid credentials, try validating session
        if session and all(session.cookies.get(c) for c in {'ct0', 'auth_token'}):
            return session

        # invalid credentials and session
        cookies = kwargs.get('cookies')

        # try validating cookies dict
        if isinstance(cookies, dict) and all(cookies.get(c) for c in {'ct0', 'auth_token'}):
            _session = Client(cookies=cookies, follow_redirects=True)
            _session.headers.update(get_headers(_session))
            return _session

        # try validating cookies from file
        if isinstance(cookies, str):
            _session = Client(cookies=orjson.loads(Path(cookies).read_bytes()), follow_redirects=True)
            _session.headers.update(get_headers(_session))
            return _session

        raise Exception('Session not authenticated. '
                        'Please use an authenticated session or remove the `session` argument and try again.')

    @property
    def id(self) -> int:
        """ Get User ID """
        return int(re.findall(r'"u=(\d+)"', self.session.cookies.get('twid'))[0])

    def save_cookies(self, fname: str = None):
        """ Save cookies to file """
        cookies = self.session.cookies
        Path(f'{fname or cookies.get("username")}.cookies').write_bytes(orjson.dumps(dict(cookies)))
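

if __name__ == '__main__':
    # Minimal usage sketch. The account details and queries below are
    # placeholders, and the 'Top'/'Latest' category values are the usual
    # search-tab names; substitute real credentials (or pass cookies=...)
    # before running.
    search = Search(
        email='user@example.com',  # placeholder
        username='user',           # placeholder
        password='password',       # placeholder
        debug=1,
    )
    results = search.run(
        queries=[
            {'query': 'python', 'category': 'Top'},
            {'query': 'rust', 'category': 'Latest'},
        ],
        limit=100,
        out='data/search_results',
    )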