diff --git a/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py b/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py index 91f9956..f4acddf 100644 --- a/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py +++ b/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py @@ -78,11 +78,11 @@ def _reply_comment(self, authorperm: str, body: str) -> str: if self.cache: self.cache.set(cache_key, "processed") logger.info(f"Reply comment {authorperm} success body: {body} result {result}") - read_steem_success_count.inc() + post_steem_success_count.inc() return f"Reply comment {authorperm} success body: {body}" except Exception as e: logger.error(f"Reply comment {authorperm} failed e : {str(e)}") - read_steem_failure_count.inc() + post_steem_failure_count.inc() logger.error(traceback.format_exc()) return f"Reply comment {authorperm} failed e : {str(e)}" diff --git a/packages/sunagent-ext/pyproject.toml b/packages/sunagent-ext/pyproject.toml index b3f9945..caa0597 100644 --- a/packages/sunagent-ext/pyproject.toml +++ b/packages/sunagent-ext/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sunagent-ext" -version = "0.0.7b1" +version = "0.0.7b2" license = {file = "LICENSE-CODE"} description = "AutoGen extensions library" readme = "README.md" @@ -19,7 +19,9 @@ dependencies = [ "aiohttp", "onepassword-sdk>=0.3.0", "pytz", + "nats-py==2.11.0", "redis", + "prometheus_client", "requests", ] diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py new file mode 100644 index 0000000..d4247ca --- /dev/null +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -0,0 +1,149 @@ +import asyncio +import json +import logging +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional + +import nats +from nats.aio.msg import Msg + +logger = logging.getLogger(__name__) + + +class TweetFromQueueContext: + # ---------- 新增哨兵 ---------- + _SENTINEL = object() # 用于通知 worker 立即退出 + + def __init__( + self, + *, + batch_size: int, + flush_seconds: float, + callback: Callable[[List[Dict[str, Any]]], Any], + nats_url: str, + subject: str, + user_ids: Optional[List[str]] = None, + ): + if batch_size <= 0 or flush_seconds <= 0: + raise ValueError("batch_size / flush_seconds 必须为正") + self.batch_size = batch_size + self.flush_seconds = flush_seconds + self.callback = callback + self.nats_url = nats_url + self.subject = subject + self.user_ids = set(user_ids) if user_ids else None + + self._nc: Optional[nats.NATS] = None # type: ignore[name-defined] + self._sub = None + # 队列元素现在是 Dict | sentinel + self._queue: asyncio.Queue[Any] = asyncio.Queue() + self._stop_evt = asyncio.Event() + self._worker_task: Optional[asyncio.Task] = None # type: ignore[type-arg] + + # -------------------- 生命周期 -------------------- + async def start(self) -> None: + self._nc = await nats.connect(self.nats_url) + self._sub = await self._nc.subscribe(self.subject, cb=self._on_msg) # type: ignore[assignment] + logger.info("Subscribed to <%s>, filter=%s", self.subject, self.user_ids) + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def stop(self) -> None: + logger.info("Stopping AsyncBatchingQueue...") + + # 1. 立即停止接收新消息 + if self._sub: + await self._sub.unsubscribe() + + # 2. 通知 worker 退出并等待它刷完剩余 + self._stop_evt.set() + await self._queue.put(self._SENTINEL) # 让 worker 从 queue.get() 立即返回 + if self._worker_task: + await self._worker_task # 内部已 _drain_remaining() + + # 3. 关闭 NATS 连接 + if self._nc: + await self._nc.close() + logger.info("AsyncBatchingQueue stopped") + + async def __aenter__(self): # type: ignore[no-untyped-def] + await self.start() + return self + + async def __aexit__(self, exc_type, exc, tb): # type: ignore[no-untyped-def] + await self.stop() + + # -------------------- 公共入口 -------------------- + async def add(self, item: Dict[str, Any]) -> None: + await self._queue.put(item) + + # -------------------- NATS 回调 -------------------- + async def _on_msg(self, msg: Msg) -> None: + try: + tweet = json.loads(msg.data.decode()) + tweet = self._fix_tweet_dict(tweet) + if self.user_ids and tweet.get("author_id") not in self.user_ids: + return + except Exception as e: + logger.exception("Bad msg: %s", e) + return + await self.add(tweet) + + # -------------------- 核心 worker -------------------- + async def _worker_loop(self) -> None: + """单协程:永久阻塞等第一条 -> 设 deadline -> 超时/满批刷 -> 收到哨兵退出""" + while not self._stop_evt.is_set(): + batch: List[Dict[str, Any]] = [] + # 1. 永久阻塞等第一条(CPU 不再空转) + first = await self._queue.get() + if first is self._SENTINEL: # 收到哨兵直接退出 + break + batch.append(first) + deadline = asyncio.get_event_loop().time() + self.flush_seconds + + # 2. 收集剩余 + while len(batch) < self.batch_size and not self._stop_evt.is_set(): + remaining = deadline - asyncio.get_event_loop().time() + if remaining <= 0: + break + try: + item = await asyncio.wait_for(self._queue.get(), timeout=remaining) + if item is self._SENTINEL: # 哨兵提前结束收集 + break + batch.append(item) + except asyncio.TimeoutError: + break + + # 3. 处理 + if batch: + try: + await self.callback(batch) + except Exception as e: + logger.exception("Callback error: %s", e) + + # 4. 退出时刷剩余 + await self._drain_remaining() + + async def _flush(self) -> None: + """同步刷剩余(给 _drain_remaining 用)""" + batch = [] + while not self._queue.empty(): + item = self._queue.get_nowait() + if item is self._SENTINEL: # 忽略哨兵 + continue + batch.append(item) + if batch: + try: + await self.callback(batch) + except Exception as e: + logger.exception("Final callback error: %s", e) + + async def _drain_remaining(self) -> None: + await self._flush() + + @staticmethod + def _fix_tweet_dict(msg: Dict[str, Any]) -> Dict[str, Any]: + fixed = msg.copy() + for key in ("created_at", "updated_at"): + if key in fixed and isinstance(fixed[key], str): + fixed[key] = datetime.fromisoformat(fixed[key]) + return fixed diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py index 8958366..162200a 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py @@ -5,11 +5,13 @@ import asyncio import logging +import traceback from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional, cast +import tweepy from prometheus_client import Counter, Gauge -from tweepy import Media, NotFound, TwitterServerError, User # 保持原类型注解 +from tweepy import Media, NotFound, Tweet, TwitterServerError, User # 保持原类型注解 from tweepy import Response as TwitterResponse from sunagent_ext.tweet.twitter_client_pool import TwitterClientPool @@ -223,7 +225,7 @@ async def _fetch_timeline( read_tweet_success_count.labels(client_key=client_key).inc(len(resp.data or [])) # 交给中间层处理 - tweet_list, next_token = await self.on_twitter_response(agent_id, resp, filter_func) + tweet_list, next_token = await self.on_twitter_response(agent_id, me_id, resp, filter_func) all_raw.extend(tweet_list) if not next_token: break @@ -254,6 +256,7 @@ async def _fetch_timeline( async def on_twitter_response( # type: ignore[no-any-unimported] self, agent_id: str, + me_id: str, response: TwitterResponse, filter_func: Callable[[Dict[str, Any]], bool], ) -> tuple[list[Dict[str, Any]], Optional[str]]: @@ -267,20 +270,23 @@ async def on_twitter_response( # type: ignore[no-any-unimported] out: list[Dict[str, Any]] = [] for tweet in all_tweets: - if not await self._should_keep(agent_id, tweet, filter_func): + if not await self._should_keep(agent_id, me_id, tweet, filter_func): continue norm = await self._normalize_tweet(tweet) out.append(norm) return out, next_token async def _should_keep( - self, agent_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool] + self, agent_id: str, me_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool] ) -> bool: is_processed = await self._check_tweet_process(tweet["id"], agent_id) if is_processed: logger.info("already processed %s", tweet["id"]) return False author_id = str(tweet["author_id"]) + if me_id == author_id: + logger.info("skip my tweet %s", tweet["id"]) + return False if author_id in self.block_uids: logger.info("blocked user %s", author_id) return False @@ -366,6 +372,91 @@ async def _recursive_fetch(self, tweet: Dict[str, Any], chain: list[Dict[str, An await self._recursive_fetch(parent, chain, depth + 1) chain.append(tweet) + async def fetch_new_tweets_manual_( # type: ignore[no-any-unimported] + self, + ids: List[str], + last_seen_id: str | None = None, + ) -> tuple[List[Tweet], str | None]: + """ + 1. 取所有 ALIVE user 的 twitter_id + 2. 将 id 列表拆分成多条不超长 query + 3. 逐条交给 fetch_new_tweets_manual_tweets 翻页 + 4. 返回全部结果以及 **所有结果中最大的 tweet_id** + """ + BASE_EXTRA = " -is:retweet" + max_len = 512 - len(BASE_EXTRA) - 10 + queries: List[str] = [] + + buf: list[str] = [] + for uid in ids: + clause = f"from:{uid}" + if len(" OR ".join(buf + [clause])) > max_len: + queries.append(" OR ".join(buf) + BASE_EXTRA) + buf = [clause] + else: + buf.append(clause) + if buf: + queries.append(" OR ".join(buf) + BASE_EXTRA) + # 3) 逐条调用内层并合并 + all_tweets: List[tweepy.Tweet] = [] # type: ignore[no-any-unimported] + for q in queries: + tweets = await self.fetch_new_tweets_manual_tweets(query=q, last_seen_id=last_seen_id) + all_tweets.extend(tweets) + + # 4) 取所有结果中最大的 id 作为 last_seen_id + last_id = max((tw.id for tw in all_tweets), default=None) + return all_tweets, last_id + + async def get_user_tweet(self, user_ids: List[str]) -> List[Tweet]: # type: ignore[no-any-unimported] + cache_key = "user_last_seen_id" + last_seen_id = self.cache.get(cache_key) + tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=user_ids, last_seen_id=last_seen_id) + logger.info(f"get_user_tweet tweets: {len(tweets)} last_seen_id: {last_seen_id}") + if last_seen_id: + self.cache.set(cache_key, last_seen_id) + return tweets + + async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported] + self, query: str, last_seen_id: str | None = None, max_per_page: int = 100, hours: int = 24 + ) -> List[Tweet]: + tweets: List[Any] = [] + next_token = None + since = datetime.now(timezone.utc) - timedelta(hours=hours) + start_time = None if last_seen_id else since.isoformat(timespec="seconds") + logger.info(f"query: {query}") + while True: + cli, key = None, "" + try: + cli, key = await self.pool.acquire() + resp = cli.search_recent_tweets( + query=query, + start_time=start_time, + since_id=last_seen_id, + max_results=max_per_page, + tweet_fields=TWEET_FIELDS, + next_token=next_token, + user_auth=True, + ) + page_data = resp.data or [] + logger.info(f"page_data: {len(page_data)}") + tweets.extend(page_data) + read_tweet_success_count.labels(client_key=key).inc(len(resp.data or [])) + next_token = resp.meta.get("next_token") + if not next_token: + break + except tweepy.TooManyRequests: + logger.error(traceback.format_exc()) + read_tweet_failure_count.labels(client_key=key).inc() + if cli: + await self.pool.report_failure(cli) + return tweets + except tweepy.TweepyException: + if cli: + await self.pool.report_failure(cli) + logger.error(traceback.format_exc()) + return tweets + return tweets + async def _get_tweet_with_retry(self, tweet_id: str) -> Optional[Dict[str, Any]]: for attempt in range(3): cli, client_key = await self.pool.acquire() diff --git a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py index 8d2a57f..d831c29 100644 --- a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py +++ b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py @@ -100,7 +100,7 @@ def request( cookies: _CookiesType = None, files: _FilesType = None, auth: _AuthType = None, - timeout: _TimeoutType = None, + timeout: _TimeoutType = None, # type: ignore[override] allow_redirects: bool = True, proxies: _TextMapping | None = None, hooks: _HooksType = None,