From ae12627cfb6d451d08a513831c4ad32c3666e25e Mon Sep 17 00:00:00 2001 From: bobo Date: Wed, 10 Sep 2025 18:44:27 +0800 Subject: [PATCH 01/16] feat: update tweet get --- packages/sunagent-ext/pyproject.toml | 1 + .../sunagent_ext/tweet/twitter_get_context.py | 107 +++++++++++++++++- 2 files changed, 105 insertions(+), 3 deletions(-) diff --git a/packages/sunagent-ext/pyproject.toml b/packages/sunagent-ext/pyproject.toml index b3f9945..01b4d53 100644 --- a/packages/sunagent-ext/pyproject.toml +++ b/packages/sunagent-ext/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "onepassword-sdk>=0.3.0", "pytz", "redis", + "prometheus_client", "requests", ] diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py index 8958366..9723128 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py @@ -5,12 +5,14 @@ import asyncio import logging +import traceback from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional, cast from prometheus_client import Counter, Gauge from tweepy import Media, NotFound, TwitterServerError, User # 保持原类型注解 from tweepy import Response as TwitterResponse +import tweepy from sunagent_ext.tweet.twitter_client_pool import TwitterClientPool @@ -223,7 +225,7 @@ async def _fetch_timeline( read_tweet_success_count.labels(client_key=client_key).inc(len(resp.data or [])) # 交给中间层处理 - tweet_list, next_token = await self.on_twitter_response(agent_id, resp, filter_func) + tweet_list, next_token = await self.on_twitter_response(agent_id, me_id, resp, filter_func) all_raw.extend(tweet_list) if not next_token: break @@ -254,6 +256,7 @@ async def _fetch_timeline( async def on_twitter_response( # type: ignore[no-any-unimported] self, agent_id: str, + me_id: str, response: TwitterResponse, filter_func: Callable[[Dict[str, Any]], bool], ) -> tuple[list[Dict[str, Any]], Optional[str]]: @@ -267,20 +270,23 @@ async def on_twitter_response( # type: ignore[no-any-unimported] out: list[Dict[str, Any]] = [] for tweet in all_tweets: - if not await self._should_keep(agent_id, tweet, filter_func): + if not await self._should_keep(agent_id, me_id, tweet, filter_func): continue norm = await self._normalize_tweet(tweet) out.append(norm) return out, next_token async def _should_keep( - self, agent_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool] + self, agent_id: str, me_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool] ) -> bool: is_processed = await self._check_tweet_process(tweet["id"], agent_id) if is_processed: logger.info("already processed %s", tweet["id"]) return False author_id = str(tweet["author_id"]) + if me_id == author_id: + logger.info("skip my tweet %s", tweet["id"]) + return False if author_id in self.block_uids: logger.info("blocked user %s", author_id) return False @@ -366,6 +372,101 @@ async def _recursive_fetch(self, tweet: Dict[str, Any], chain: list[Dict[str, An await self._recursive_fetch(parent, chain, depth + 1) chain.append(tweet) + async def fetch_new_tweets_manual_( self, + ids: List[str], + last_seen_id: str | None = None, + ): + """ + 1. 取所有 ALIVE KOL 的 twitter_id + 2. 将 id 列表拆分成多条不超长 query + 3. 逐条交给 fetch_new_tweets_manual_tweets 翻页 + 4. 返回全部结果以及 **所有结果中最大的 tweet_id** + """ + BASE_EXTRA = " -is:retweet" + max_len = 512 - len(BASE_EXTRA) - 10 + queries: List[str] = [] + + buf, first = [], True + for uid in ids: + clause = f"from:{uid}" + if len(" OR ".join(buf + [clause])) > max_len: + queries.append(" OR ".join(buf) + BASE_EXTRA) + buf, first = [clause], True + else: + buf.append(clause) + first = False + if buf: + queries.append(" OR ".join(buf) + BASE_EXTRA) + + # 3) 逐条调用内层并合并 + all_tweets: List[tweepy.Tweet] = [] + for q in queries: + tweets = await self.fetch_new_tweets_manual_tweets( + query=q, + last_seen_id=last_seen_id + ) + all_tweets.extend(tweets) + await asyncio.sleep(30) + + # 4) 取所有结果中最大的 id 作为 last_seen_id + last_id = max((tw.id for tw in all_tweets), default=None) + return all_tweets, last_id + + async def get_kol_tweet(self, kol_ids: List[str]): + cache_key = "kol_last_seen_id" + last_seen_id = await self.cache.get(cache_key) + tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=kol_ids, last_seen_id=last_seen_id) + await self.cache.set(cache_key, last_seen_id) + return tweets + + async def fetch_new_tweets_manual_tweets( + self, + query: str, + last_seen_id: str | None = None, + max_per_page: int = 100, + hours: int = 24 + ): + tweets = [] + next_token = None + + since = datetime.now(timezone.utc) - timedelta(hours=hours) + start_time = None if last_seen_id else since.isoformat(timespec="seconds") + logger.info(f"query: {query}") + while True: + cli = None + try: + cli, key = await self.pool.acquire() + resp = cli.search_recent_tweets( + query=query, + start_time=start_time, + since_id=last_seen_id, + max_results=max_per_page, + tweet_fields=TWEET_FIELDS, + next_token=next_token, + user_auth=True + ) + page_data = resp.data or [] + logger.info(f"page_data: {len(page_data)}") + for tw in page_data: + # 1. 已读过的直接停 + tweets.append(tw) + read_tweet_success_count.inc() + next_token = resp.meta.get("next_token") + if not next_token: + break + except tweepy.TooManyRequests as e: + logger.error(traceback.format_exc()) + read_tweet_failure_count.inc() + if cli: + await self.pool.report_failure(cli) + return tweets + except tweepy.TweepyException as e: + if cli: + await self.pool.report_failure(cli) + logger.error(traceback.format_exc()) + return tweets + return tweets + async def _get_tweet_with_retry(self, tweet_id: str) -> Optional[Dict[str, Any]]: for attempt in range(3): cli, client_key = await self.pool.acquire() From 6d8418f3a4e100abdd160df9cd138a3a31f21fe4 Mon Sep 17 00:00:00 2001 From: bobo Date: Thu, 11 Sep 2025 13:56:36 +0800 Subject: [PATCH 02/16] feat: update tweet get --- .../sunagent_ext/tweet/twitter_get_context.py | 52 ++++++++----------- 1 file changed, 22 insertions(+), 30 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py index 9723128..90abb7a 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py @@ -9,10 +9,10 @@ from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional, cast +import tweepy from prometheus_client import Counter, Gauge -from tweepy import Media, NotFound, TwitterServerError, User # 保持原类型注解 +from tweepy import Media, NotFound, Tweet, TwitterServerError, User # 保持原类型注解 from tweepy import Response as TwitterResponse -import tweepy from sunagent_ext.tweet.twitter_client_pool import TwitterClientPool @@ -277,7 +277,7 @@ async def on_twitter_response( # type: ignore[no-any-unimported] return out, next_token async def _should_keep( - self, agent_id: str, me_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool] + self, agent_id: str, me_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool] ) -> bool: is_processed = await self._check_tweet_process(tweet["id"], agent_id) if is_processed: @@ -372,10 +372,11 @@ async def _recursive_fetch(self, tweet: Dict[str, Any], chain: list[Dict[str, An await self._recursive_fetch(parent, chain, depth + 1) chain.append(tweet) - async def fetch_new_tweets_manual_( self, - ids: List[str], - last_seen_id: str | None = None, - ): + async def fetch_new_tweets_manual_( # type: ignore[no-any-unimported] + self, + ids: List[str], + last_seen_id: str | None = None, + ) -> tuple[List[Tweet], str | None]: """ 1. 取所有 ALIVE KOL 的 twitter_id 2. 将 id 列表拆分成多条不超长 query @@ -386,25 +387,20 @@ async def fetch_new_tweets_manual_( self, max_len = 512 - len(BASE_EXTRA) - 10 queries: List[str] = [] - buf, first = [], True + buf: list[str] = [] for uid in ids: clause = f"from:{uid}" if len(" OR ".join(buf + [clause])) > max_len: queries.append(" OR ".join(buf) + BASE_EXTRA) - buf, first = [clause], True + buf = [clause] else: buf.append(clause) - first = False if buf: queries.append(" OR ".join(buf) + BASE_EXTRA) - # 3) 逐条调用内层并合并 - all_tweets: List[tweepy.Tweet] = [] + all_tweets: List[tweepy.Tweet] = [] # type: ignore[no-any-unimported] for q in queries: - tweets = await self.fetch_new_tweets_manual_tweets( - query=q, - last_seen_id=last_seen_id - ) + tweets = await self.fetch_new_tweets_manual_tweets(query=q, last_seen_id=last_seen_id) all_tweets.extend(tweets) await asyncio.sleep(30) @@ -412,20 +408,16 @@ async def fetch_new_tweets_manual_( self, last_id = max((tw.id for tw in all_tweets), default=None) return all_tweets, last_id - async def get_kol_tweet(self, kol_ids: List[str]): + async def get_kol_tweet(self, kol_ids: List[str]) -> List[Tweet]: # type: ignore[no-any-unimported] cache_key = "kol_last_seen_id" - last_seen_id = await self.cache.get(cache_key) + last_seen_id = self.cache.get(cache_key) tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=kol_ids, last_seen_id=last_seen_id) await self.cache.set(cache_key, last_seen_id) return tweets - async def fetch_new_tweets_manual_tweets( - self, - query: str, - last_seen_id: str | None = None, - max_per_page: int = 100, - hours: int = 24 - ): + async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported] + self, query: str, last_seen_id: str | None = None, max_per_page: int = 100, hours: int = 24 + ) -> List[Tweet]: tweets = [] next_token = None @@ -443,24 +435,24 @@ async def fetch_new_tweets_manual_tweets( max_results=max_per_page, tweet_fields=TWEET_FIELDS, next_token=next_token, - user_auth=True + user_auth=True, ) page_data = resp.data or [] logger.info(f"page_data: {len(page_data)}") for tw in page_data: # 1. 已读过的直接停 tweets.append(tw) - read_tweet_success_count.inc() + read_tweet_success_count.labels(client_key=key).inc(len(resp.data or [])) next_token = resp.meta.get("next_token") if not next_token: break - except tweepy.TooManyRequests as e: + except tweepy.TooManyRequests: logger.error(traceback.format_exc()) - read_tweet_failure_count.inc() + read_tweet_failure_count.labels(client_key=key).inc() if cli: await self.pool.report_failure(cli) return tweets - except tweepy.TweepyException as e: + except tweepy.TweepyException: if cli: await self.pool.report_failure(cli) logger.error(traceback.format_exc()) From 6adc2d62daf18c02de2c0fa86fd757afdfc62b9e Mon Sep 17 00:00:00 2001 From: bobo Date: Thu, 11 Sep 2025 16:17:37 +0800 Subject: [PATCH 03/16] feat: add tweet hub client --- .../agents/_steemit_context_builder_agent.py | 4 +- packages/sunagent-ext/pyproject.toml | 2 +- .../sunagent_ext/tweet/tweet_hub_client.py | 79 +++++++++++++++++++ .../sunagent_ext/tweet/twitter_get_context.py | 6 +- 4 files changed, 85 insertions(+), 6 deletions(-) create mode 100644 packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py diff --git a/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py b/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py index 91f9956..f4acddf 100644 --- a/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py +++ b/packages/sunagent-app/src/sunagent_app/agents/_steemit_context_builder_agent.py @@ -78,11 +78,11 @@ def _reply_comment(self, authorperm: str, body: str) -> str: if self.cache: self.cache.set(cache_key, "processed") logger.info(f"Reply comment {authorperm} success body: {body} result {result}") - read_steem_success_count.inc() + post_steem_success_count.inc() return f"Reply comment {authorperm} success body: {body}" except Exception as e: logger.error(f"Reply comment {authorperm} failed e : {str(e)}") - read_steem_failure_count.inc() + post_steem_failure_count.inc() logger.error(traceback.format_exc()) return f"Reply comment {authorperm} failed e : {str(e)}" diff --git a/packages/sunagent-ext/pyproject.toml b/packages/sunagent-ext/pyproject.toml index 01b4d53..2d80af7 100644 --- a/packages/sunagent-ext/pyproject.toml +++ b/packages/sunagent-ext/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "sunagent-ext" -version = "0.0.7b1" +version = "0.0.7b2" license = {file = "LICENSE-CODE"} description = "AutoGen extensions library" readme = "README.md" diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py new file mode 100644 index 0000000..7a2ed4f --- /dev/null +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py @@ -0,0 +1,79 @@ +# kol_sdk.py +import json +import logging +from datetime import datetime +from typing import Any, Dict, List + +import requests +from requests.adapters import HTTPAdapter +from urllib3.util.retry import Retry + +logger = logging.getLogger("sunagent-ext") + +DEFAULT_TIMEOUT = 10 # 秒 + + +class TweetHubClient: + """ + 轻量级 requests 封装,支持: + 1. 创建 KOL 列表 + 2. 删除 KOL 列表 + 3. 自动重试 + 超时 + 4. 中文友好(ensure_ascii=False) + """ + + def __init__(self, base_url: str, agent_id: str = "hub", timeout: int = DEFAULT_TIMEOUT): + """ + :param base_url: 不含尾巴 / ,例:http://127.0.0.1:8084/api/sun + :param agent_id: 默认 agent_id + :param timeout: 单次请求超时 + """ + self.base_url = base_url.rstrip("/") + self.agent_id = agent_id + self.timeout = timeout + self._session = requests.Session() + + # 重试策略:3 次、backoff、状态码 5xx/429 + retry = Retry( + total=3, + backoff_factor=0.3, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods={"POST", "DELETE"}, + ) + self._session.mount("http://", HTTPAdapter(max_retries=retry)) + self._session.mount("https://", HTTPAdapter(max_retries=retry)) + + # ---------- 内部方法 ---------- + def _request(self, method: str, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]: + url = f"{self.base_url}{endpoint}" + # 自动注入 agent_id + payload.setdefault("agent_id", self.agent_id) + try: + resp = self._session.request( + method=method, + url=url, + json=payload, # requests 会自动用 utf-8 编码 + timeout=self.timeout, + ) + resp.raise_for_status() + return resp.json() # type: ignore[no-any-return] + except requests.exceptions.RequestException as e: + logger.error(f"[KolClient] {method} {url} error: {e}") + raise + + # ---------- 业务接口 ---------- + def create_kol(self, kol_list: List[str], agent_id: str | None = None) -> Dict[str, Any]: + """ + 批量创建/绑定 KOL + :param kol_list: twitter_id 列表 + :param agent_id: 可选,不传使用实例默认值 + """ + payload = {"kol_list": kol_list, "agent_id": agent_id or self.agent_id} + return self._request("POST", "/kol", payload) + + def delete_kol(self, kol_list: List[str], agent_id: str | None = None) -> Dict[str, Any]: + """ + 批量删除/解绑 KOL + """ + payload = {"kol_list": kol_list, "agent_id": agent_id or self.agent_id} + return self._request("DELETE", "/kol", payload) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py index 90abb7a..b969988 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py @@ -402,7 +402,6 @@ async def fetch_new_tweets_manual_( # type: ignore[no-any-unimported] for q in queries: tweets = await self.fetch_new_tweets_manual_tweets(query=q, last_seen_id=last_seen_id) all_tweets.extend(tweets) - await asyncio.sleep(30) # 4) 取所有结果中最大的 id 作为 last_seen_id last_id = max((tw.id for tw in all_tweets), default=None) @@ -412,7 +411,9 @@ async def get_kol_tweet(self, kol_ids: List[str]) -> List[Tweet]: # type: ignor cache_key = "kol_last_seen_id" last_seen_id = self.cache.get(cache_key) tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=kol_ids, last_seen_id=last_seen_id) - await self.cache.set(cache_key, last_seen_id) + logger.info(f"get_kol_tweet tweets: {len(tweets)} last_seen_id: {last_seen_id}") + if last_seen_id: + self.cache.set(cache_key, last_seen_id) return tweets async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported] @@ -420,7 +421,6 @@ async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported] ) -> List[Tweet]: tweets = [] next_token = None - since = datetime.now(timezone.utc) - timedelta(hours=hours) start_time = None if last_seen_id else since.isoformat(timespec="seconds") logger.info(f"query: {query}") From 9d04d15a4eb7435e002a833839bee92f4576a6a8 Mon Sep 17 00:00:00 2001 From: bobo Date: Thu, 11 Sep 2025 18:15:42 +0800 Subject: [PATCH 04/16] feat: add tweet from queue --- packages/sunagent-ext/pyproject.toml | 1 + .../sunagent_ext/tweet/tweet_from_queue.py | 113 ++++++++++++++++++ .../sunagent_ext/tweet/tweet_hub_client.py | 79 ------------ 3 files changed, 114 insertions(+), 79 deletions(-) create mode 100644 packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py delete mode 100644 packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py diff --git a/packages/sunagent-ext/pyproject.toml b/packages/sunagent-ext/pyproject.toml index 2d80af7..caa0597 100644 --- a/packages/sunagent-ext/pyproject.toml +++ b/packages/sunagent-ext/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "aiohttp", "onepassword-sdk>=0.3.0", "pytz", + "nats-py==2.11.0", "redis", "prometheus_client", "requests", diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py new file mode 100644 index 0000000..ec66025 --- /dev/null +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -0,0 +1,113 @@ +# tweet_from_queue.py +import asyncio +import json +import logging +from datetime import datetime +from typing import Any, Callable, List + +import nats +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from nats.aio.msg import Msg + +logger = logging.getLogger(__name__) + + +class TweetFromQueueContext: + """ + 1. 订阅 NATS subject,累积推文 + 2. APScheduler 每 10 秒强制 flush 一次(不管有没有消息) + 3. 支持优雅关闭 + """ + + def __init__( + self, + size: int, + agent_id: str, + nats_url: str, + callback: Callable[[List[dict[str, Any]]], Any], + flush_seconds: int = 10, + ): + self.size = size + self.agent_id = agent_id + self.nats_url = nats_url + self.flush_seconds = flush_seconds + + self._nc = None + self._buffer: List[dict[str, Any]] = [] + self._callback: Callable[[List[dict[str, Any]]], Any] = callback + self._scheduler: AsyncIOScheduler | None = None # type: ignore[no-any-unimported] + self._lock = asyncio.Lock() + + # -------------------- 生命周期 -------------------- + async def start(self, subject: str) -> None: + # 1. 连接 NATS + self._nc = await nats.connect(self.nats_url) # type: ignore[assignment] + await self._nc.subscribe(subject, cb=self._nc_consume) # type: ignore[attr-defined] + logger.info("Subscribed to <%s>, agent=%s", subject, self.agent_id) + + # 2. 启动 APScheduler 定时 flush + self._scheduler = AsyncIOScheduler() + self._scheduler.add_job( + self._flush_wrap, # 定时执行的协程 + trigger="interval", + seconds=self.flush_seconds, + max_instances=1, + next_run_time=datetime.now(), # 立即执行一次 + ) + self._scheduler.start() + + async def close(self) -> None: + """优雅关闭:停调度器 -> 刷尾数据 -> 断 NATS""" + if self._scheduler: + self._scheduler.shutdown(wait=False) + + async with self._lock: + await self._flush() + + if self._nc: + await self._nc.close() + logger.info("NATS connection closed") + + # -------------------- NATS 回调 -------------------- + async def _nc_consume(self, msg: Msg) -> None: + try: + tweet = json.loads(msg.data.decode()) + tweet = self._fix_tweet_dict(tweet) + except Exception as e: + logger.exception("Bad message: %s", e) + return + + if self.agent_id not in tweet.get("agent_ids", []): + return + + async with self._lock: + self._buffer.append(tweet) + # 如果满了也立即刷 + if len(self._buffer) >= self.size: + await self._flush() + + # -------------------- 定时刷新包装 -------------------- + async def _flush_wrap(self) -> None: + """APScheduler 只支持调用普通函数,这里包一层协程""" + async with self._lock: + await self._flush() + + # -------------------- 真正刷新 -------------------- + async def _flush(self) -> None: + if not self._buffer: + return + logger.info("Flush %d tweets", len(self._buffer)) + try: + await self._callback(self._buffer.copy()) + except Exception as e: + logger.exception("Callback error: %s", e) + self._buffer.clear() + + # -------------------- 工具 -------------------- + @staticmethod + def _fix_tweet_dict(msg: dict[str, Any]) -> dict[str, Any]: + fixed = msg.copy() + for key in ("created_at", "updated_at"): + if key in fixed and isinstance(fixed[key], str): + fixed[key] = datetime.fromisoformat(fixed[key]) + return fixed diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py deleted file mode 100644 index 7a2ed4f..0000000 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_hub_client.py +++ /dev/null @@ -1,79 +0,0 @@ -# kol_sdk.py -import json -import logging -from datetime import datetime -from typing import Any, Dict, List - -import requests -from requests.adapters import HTTPAdapter -from urllib3.util.retry import Retry - -logger = logging.getLogger("sunagent-ext") - -DEFAULT_TIMEOUT = 10 # 秒 - - -class TweetHubClient: - """ - 轻量级 requests 封装,支持: - 1. 创建 KOL 列表 - 2. 删除 KOL 列表 - 3. 自动重试 + 超时 - 4. 中文友好(ensure_ascii=False) - """ - - def __init__(self, base_url: str, agent_id: str = "hub", timeout: int = DEFAULT_TIMEOUT): - """ - :param base_url: 不含尾巴 / ,例:http://127.0.0.1:8084/api/sun - :param agent_id: 默认 agent_id - :param timeout: 单次请求超时 - """ - self.base_url = base_url.rstrip("/") - self.agent_id = agent_id - self.timeout = timeout - self._session = requests.Session() - - # 重试策略:3 次、backoff、状态码 5xx/429 - retry = Retry( - total=3, - backoff_factor=0.3, - status_forcelist=[429, 500, 502, 503, 504], - allowed_methods={"POST", "DELETE"}, - ) - self._session.mount("http://", HTTPAdapter(max_retries=retry)) - self._session.mount("https://", HTTPAdapter(max_retries=retry)) - - # ---------- 内部方法 ---------- - def _request(self, method: str, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]: - url = f"{self.base_url}{endpoint}" - # 自动注入 agent_id - payload.setdefault("agent_id", self.agent_id) - try: - resp = self._session.request( - method=method, - url=url, - json=payload, # requests 会自动用 utf-8 编码 - timeout=self.timeout, - ) - resp.raise_for_status() - return resp.json() # type: ignore[no-any-return] - except requests.exceptions.RequestException as e: - logger.error(f"[KolClient] {method} {url} error: {e}") - raise - - # ---------- 业务接口 ---------- - def create_kol(self, kol_list: List[str], agent_id: str | None = None) -> Dict[str, Any]: - """ - 批量创建/绑定 KOL - :param kol_list: twitter_id 列表 - :param agent_id: 可选,不传使用实例默认值 - """ - payload = {"kol_list": kol_list, "agent_id": agent_id or self.agent_id} - return self._request("POST", "/kol", payload) - - def delete_kol(self, kol_list: List[str], agent_id: str | None = None) -> Dict[str, Any]: - """ - 批量删除/解绑 KOL - """ - payload = {"kol_list": kol_list, "agent_id": agent_id or self.agent_id} - return self._request("DELETE", "/kol", payload) From 08e7b34658b54eeffc127221a092675322f21b99 Mon Sep 17 00:00:00 2001 From: bobo Date: Thu, 11 Sep 2025 18:23:17 +0800 Subject: [PATCH 05/16] feat: format kol --- .../sunagent_ext/tweet/twitter_get_context.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py index b969988..162200a 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/twitter_get_context.py @@ -378,7 +378,7 @@ async def fetch_new_tweets_manual_( # type: ignore[no-any-unimported] last_seen_id: str | None = None, ) -> tuple[List[Tweet], str | None]: """ - 1. 取所有 ALIVE KOL 的 twitter_id + 1. 取所有 ALIVE user 的 twitter_id 2. 将 id 列表拆分成多条不超长 query 3. 逐条交给 fetch_new_tweets_manual_tweets 翻页 4. 返回全部结果以及 **所有结果中最大的 tweet_id** @@ -407,11 +407,11 @@ async def fetch_new_tweets_manual_( # type: ignore[no-any-unimported] last_id = max((tw.id for tw in all_tweets), default=None) return all_tweets, last_id - async def get_kol_tweet(self, kol_ids: List[str]) -> List[Tweet]: # type: ignore[no-any-unimported] - cache_key = "kol_last_seen_id" + async def get_user_tweet(self, user_ids: List[str]) -> List[Tweet]: # type: ignore[no-any-unimported] + cache_key = "user_last_seen_id" last_seen_id = self.cache.get(cache_key) - tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=kol_ids, last_seen_id=last_seen_id) - logger.info(f"get_kol_tweet tweets: {len(tweets)} last_seen_id: {last_seen_id}") + tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=user_ids, last_seen_id=last_seen_id) + logger.info(f"get_user_tweet tweets: {len(tweets)} last_seen_id: {last_seen_id}") if last_seen_id: self.cache.set(cache_key, last_seen_id) return tweets @@ -419,13 +419,13 @@ async def get_kol_tweet(self, kol_ids: List[str]) -> List[Tweet]: # type: ignor async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported] self, query: str, last_seen_id: str | None = None, max_per_page: int = 100, hours: int = 24 ) -> List[Tweet]: - tweets = [] + tweets: List[Any] = [] next_token = None since = datetime.now(timezone.utc) - timedelta(hours=hours) start_time = None if last_seen_id else since.isoformat(timespec="seconds") logger.info(f"query: {query}") while True: - cli = None + cli, key = None, "" try: cli, key = await self.pool.acquire() resp = cli.search_recent_tweets( @@ -439,9 +439,7 @@ async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported] ) page_data = resp.data or [] logger.info(f"page_data: {len(page_data)}") - for tw in page_data: - # 1. 已读过的直接停 - tweets.append(tw) + tweets.extend(page_data) read_tweet_success_count.labels(client_key=key).inc(len(resp.data or [])) next_token = resp.meta.get("next_token") if not next_token: From c8b06abf5c09f105057a5f0a77b7e0f12f3b351a Mon Sep 17 00:00:00 2001 From: bobo Date: Thu, 11 Sep 2025 20:46:47 +0800 Subject: [PATCH 06/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sunagent_ext/tweet/tweet_from_queue.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index ec66025..b2ee962 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -22,13 +22,13 @@ class TweetFromQueueContext: def __init__( self, size: int, - agent_id: str, + user_ids: List[str], nats_url: str, callback: Callable[[List[dict[str, Any]]], Any], flush_seconds: int = 10, ): self.size = size - self.agent_id = agent_id + self.user_ids = user_ids self.nats_url = nats_url self.flush_seconds = flush_seconds @@ -43,7 +43,7 @@ async def start(self, subject: str) -> None: # 1. 连接 NATS self._nc = await nats.connect(self.nats_url) # type: ignore[assignment] await self._nc.subscribe(subject, cb=self._nc_consume) # type: ignore[attr-defined] - logger.info("Subscribed to <%s>, agent=%s", subject, self.agent_id) + logger.info("Subscribed to <%s>, agent=%s", subject, self.user_ids) # 2. 启动 APScheduler 定时 flush self._scheduler = AsyncIOScheduler() @@ -73,13 +73,12 @@ async def _nc_consume(self, msg: Msg) -> None: try: tweet = json.loads(msg.data.decode()) tweet = self._fix_tweet_dict(tweet) + if tweet["author_id"] not in self.user_ids: + return except Exception as e: logger.exception("Bad message: %s", e) return - if self.agent_id not in tweet.get("agent_ids", []): - return - async with self._lock: self._buffer.append(tweet) # 如果满了也立即刷 @@ -96,12 +95,16 @@ async def _flush_wrap(self) -> None: async def _flush(self) -> None: if not self._buffer: return - logger.info("Flush %d tweets", len(self._buffer)) + # 1. 锁内只做“拷贝 + 清空” + async with self._lock: + to_send = self._buffer.copy() + self._buffer.clear() + # 2. 锁外执行回调,避免阻塞后续入队 / 定时 flush + logger.info("Flush %d tweets", len(to_send)) try: - await self._callback(self._buffer.copy()) + await self._callback(to_send) except Exception as e: logger.exception("Callback error: %s", e) - self._buffer.clear() # -------------------- 工具 -------------------- @staticmethod From 69d730a85f981bcb5f1d6110b8f3c60724eafe0c Mon Sep 17 00:00:00 2001 From: bobo Date: Fri, 12 Sep 2025 11:04:30 +0800 Subject: [PATCH 07/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/sunagent_ext/tweet/tweet_from_queue.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index b2ee962..7ee2a47 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -8,6 +8,7 @@ import nats from apscheduler.schedulers.asyncio import AsyncIOScheduler from nats.aio.msg import Msg +from nats.aio.subscription import Subscription logger = logging.getLogger(__name__) @@ -35,6 +36,7 @@ def __init__( self._nc = None self._buffer: List[dict[str, Any]] = [] self._callback: Callable[[List[dict[str, Any]]], Any] = callback + self._sub: Subscription | None = None self._scheduler: AsyncIOScheduler | None = None # type: ignore[no-any-unimported] self._lock = asyncio.Lock() @@ -42,7 +44,7 @@ def __init__( async def start(self, subject: str) -> None: # 1. 连接 NATS self._nc = await nats.connect(self.nats_url) # type: ignore[assignment] - await self._nc.subscribe(subject, cb=self._nc_consume) # type: ignore[attr-defined] + self._sub = await self._nc.subscribe(subject, cb=self._nc_consume) # type: ignore[attr-defined] logger.info("Subscribed to <%s>, agent=%s", subject, self.user_ids) # 2. 启动 APScheduler 定时 flush @@ -57,13 +59,21 @@ async def start(self, subject: str) -> None: self._scheduler.start() async def close(self) -> None: - """优雅关闭:停调度器 -> 刷尾数据 -> 断 NATS""" + """优雅关闭:停订阅 -> 停调度器 -> 刷剩余数据 -> 断 NATS""" + # 1. 停止订阅(不再接收新消息) + if self._sub: + await self._sub.unsubscribe() + self._sub = None + + # 2. 停止调度器(不再定时 flush) if self._scheduler: self._scheduler.shutdown(wait=False) + # 3. 刷剩余数据 async with self._lock: await self._flush() + # 4. 断开 NATS if self._nc: await self._nc.close() logger.info("NATS connection closed") From d09de68a0488705cc6ee988d58cb647ba67d5526 Mon Sep 17 00:00:00 2001 From: bobo Date: Fri, 12 Sep 2025 14:24:12 +0800 Subject: [PATCH 08/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/sunagent_ext/tweet/tweet_from_queue.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index 7ee2a47..2180dac 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -50,7 +50,7 @@ async def start(self, subject: str) -> None: # 2. 启动 APScheduler 定时 flush self._scheduler = AsyncIOScheduler() self._scheduler.add_job( - self._flush_wrap, # 定时执行的协程 + self._flush, # 定时执行的协程 trigger="interval", seconds=self.flush_seconds, max_instances=1, @@ -95,21 +95,18 @@ async def _nc_consume(self, msg: Msg) -> None: if len(self._buffer) >= self.size: await self._flush() - # -------------------- 定时刷新包装 -------------------- - async def _flush_wrap(self) -> None: - """APScheduler 只支持调用普通函数,这里包一层协程""" - async with self._lock: - await self._flush() - # -------------------- 真正刷新 -------------------- async def _flush(self) -> None: + # 1. 无数据直接返回(避免空跑) if not self._buffer: return - # 1. 锁内只做“拷贝 + 清空” + + # 2. 锁内:拷贝 + 清空 async with self._lock: to_send = self._buffer.copy() self._buffer.clear() - # 2. 锁外执行回调,避免阻塞后续入队 / 定时 flush + + # 3. 锁外:执行回调 logger.info("Flush %d tweets", len(to_send)) try: await self._callback(to_send) From 4a7835bb6bab4923b106b3af0860d13a5dd09c49 Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 11:33:25 +0800 Subject: [PATCH 09/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/sunagent_ext/tweet/tweet_from_queue.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index 2180dac..e849a52 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -89,24 +89,24 @@ async def _nc_consume(self, msg: Msg) -> None: logger.exception("Bad message: %s", e) return + need_flush = False async with self._lock: self._buffer.append(tweet) - # 如果满了也立即刷 if len(self._buffer) >= self.size: - await self._flush() + need_flush = True - # -------------------- 真正刷新 -------------------- - async def _flush(self) -> None: - # 1. 无数据直接返回(避免空跑) - if not self._buffer: - return + # 锁外调用 flush,避免死锁 + if need_flush: + await self._flush() - # 2. 锁内:拷贝 + 清空 + async def _flush(self) -> None: async with self._lock: + if not self._buffer: + return to_send = self._buffer.copy() self._buffer.clear() - # 3. 锁外:执行回调 + # 锁外执行回调,避免阻塞消息处理 logger.info("Flush %d tweets", len(to_send)) try: await self._callback(to_send) From 096ad10db537e5e481b485ccc17240e47a02c458 Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 13:56:18 +0800 Subject: [PATCH 10/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sunagent_ext/tweet/tweet_from_queue.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index e849a52..3c44a96 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -6,7 +6,9 @@ from typing import Any, Callable, List import nats +from apscheduler.executors.pool import ThreadPoolExecutor from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.date import DateTrigger from nats.aio.msg import Msg from nats.aio.subscription import Subscription @@ -37,8 +39,9 @@ def __init__( self._buffer: List[dict[str, Any]] = [] self._callback: Callable[[List[dict[str, Any]]], Any] = callback self._sub: Subscription | None = None - self._scheduler: AsyncIOScheduler | None = None # type: ignore[no-any-unimported] + self._scheduler: AsyncIOScheduler = AsyncIOScheduler() # type: ignore[no-any-unimported] self._lock = asyncio.Lock() + self.flash_job_id = "flash" # -------------------- 生命周期 -------------------- async def start(self, subject: str) -> None: @@ -48,10 +51,10 @@ async def start(self, subject: str) -> None: logger.info("Subscribed to <%s>, agent=%s", subject, self.user_ids) # 2. 启动 APScheduler 定时 flush - self._scheduler = AsyncIOScheduler() self._scheduler.add_job( self._flush, # 定时执行的协程 trigger="interval", + id=self.flash_job_id, seconds=self.flush_seconds, max_instances=1, next_run_time=datetime.now(), # 立即执行一次 @@ -97,7 +100,18 @@ async def _nc_consume(self, msg: Msg) -> None: # 锁外调用 flush,避免死锁 if need_flush: - await self._flush() + await self._trigger_immediate_flush() + + async def _trigger_immediate_flush(self) -> None: + """插一个立即任务,若已有未执行的则跳过/覆盖""" + self._scheduler.add_job( + self._flush, + trigger=DateTrigger(run_date=datetime.now()), + kwargs={"force": True}, # 标记为立即触发 + id=self.flash_job_id, # 与定时器同一 ID + replace_existing=True, # 覆盖未执行的旧任务 + max_instances=1, + ) async def _flush(self) -> None: async with self._lock: @@ -105,8 +119,6 @@ async def _flush(self) -> None: return to_send = self._buffer.copy() self._buffer.clear() - - # 锁外执行回调,避免阻塞消息处理 logger.info("Flush %d tweets", len(to_send)) try: await self._callback(to_send) From 4943650e12f69f0eaf3f899113d479b60a2e79f1 Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 17:25:32 +0800 Subject: [PATCH 11/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sunagent_ext/tweet/tweet_from_queue.py | 190 +++++++++--------- 1 file changed, 96 insertions(+), 94 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index 3c44a96..2d385cf 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -1,133 +1,135 @@ -# tweet_from_queue.py import asyncio import json import logging from datetime import datetime -from typing import Any, Callable, List +from typing import Any, Callable, Dict, List, Optional import nats -from apscheduler.executors.pool import ThreadPoolExecutor -from apscheduler.schedulers.asyncio import AsyncIOScheduler -from apscheduler.triggers.date import DateTrigger from nats.aio.msg import Msg -from nats.aio.subscription import Subscription logger = logging.getLogger(__name__) class TweetFromQueueContext: - """ - 1. 订阅 NATS subject,累积推文 - 2. APScheduler 每 10 秒强制 flush 一次(不管有没有消息) - 3. 支持优雅关闭 - """ - def __init__( self, - size: int, - user_ids: List[str], + *, + batch_size: int, + flush_seconds: float, + callback: Callable[[List[Dict[str, Any]]], Any], nats_url: str, - callback: Callable[[List[dict[str, Any]]], Any], - flush_seconds: int = 10, + subject: str, + user_ids: Optional[List[str]] = None, ): - self.size = size - self.user_ids = user_ids - self.nats_url = nats_url + if batch_size <= 0 or flush_seconds <= 0: + raise ValueError("batch_size / flush_seconds 必须为正") + self.batch_size = batch_size self.flush_seconds = flush_seconds + self.callback = callback + self.nats_url = nats_url + self.subject = subject + self.user_ids = set(user_ids) if user_ids else None - self._nc = None - self._buffer: List[dict[str, Any]] = [] - self._callback: Callable[[List[dict[str, Any]]], Any] = callback - self._sub: Subscription | None = None - self._scheduler: AsyncIOScheduler = AsyncIOScheduler() # type: ignore[no-any-unimported] - self._lock = asyncio.Lock() - self.flash_job_id = "flash" + self._nc: Optional[nats.NATS] = None # type: ignore[name-defined] + self._sub = None + self._queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + self._stop_evt = asyncio.Event() + self._worker_task: Optional[asyncio.Task] = None # type: ignore[type-arg] # -------------------- 生命周期 -------------------- - async def start(self, subject: str) -> None: - # 1. 连接 NATS - self._nc = await nats.connect(self.nats_url) # type: ignore[assignment] - self._sub = await self._nc.subscribe(subject, cb=self._nc_consume) # type: ignore[attr-defined] - logger.info("Subscribed to <%s>, agent=%s", subject, self.user_ids) - - # 2. 启动 APScheduler 定时 flush - self._scheduler.add_job( - self._flush, # 定时执行的协程 - trigger="interval", - id=self.flash_job_id, - seconds=self.flush_seconds, - max_instances=1, - next_run_time=datetime.now(), # 立即执行一次 - ) - self._scheduler.start() - - async def close(self) -> None: - """优雅关闭:停订阅 -> 停调度器 -> 刷剩余数据 -> 断 NATS""" - # 1. 停止订阅(不再接收新消息) + async def start(self) -> None: + self._nc = await nats.connect(self.nats_url) + self._sub = await self._nc.subscribe(self.subject, cb=self._on_msg) # type: ignore[assignment] + logger.info("Subscribed to <%s>, filter=%s", self.subject, self.user_ids) + # 启动单协程 worker + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def stop(self) -> None: + logger.info("Stopping AsyncBatchingQueue...") + self._stop_evt.set() + if self._worker_task: + await self._worker_task if self._sub: await self._sub.unsubscribe() - self._sub = None + await self._flush() # 刷剩余 + if self._nc: + await self._nc.close() + logger.info("AsyncBatchingQueue stopped") - # 2. 停止调度器(不再定时 flush) - if self._scheduler: - self._scheduler.shutdown(wait=False) + async def __aenter__(self): # type: ignore[no-untyped-def] + await self.start() + return self - # 3. 刷剩余数据 - async with self._lock: - await self._flush() + async def __aexit__(self, exc_type, exc, tb): # type: ignore[no-untyped-def] + await self.stop() - # 4. 断开 NATS - if self._nc: - await self._nc.close() - logger.info("NATS connection closed") + # -------------------- 公共入口 -------------------- + async def add(self, item: Dict[str, Any]) -> None: + await self._queue.put(item) # -------------------- NATS 回调 -------------------- - async def _nc_consume(self, msg: Msg) -> None: + async def _on_msg(self, msg: Msg) -> None: try: tweet = json.loads(msg.data.decode()) tweet = self._fix_tweet_dict(tweet) - if tweet["author_id"] not in self.user_ids: - return + # if self.user_ids and tweet.get("author_id") not in self.user_ids: + # return except Exception as e: - logger.exception("Bad message: %s", e) + logger.exception("Bad msg: %s", e) return - - need_flush = False - async with self._lock: - self._buffer.append(tweet) - if len(self._buffer) >= self.size: - need_flush = True - - # 锁外调用 flush,避免死锁 - if need_flush: - await self._trigger_immediate_flush() - - async def _trigger_immediate_flush(self) -> None: - """插一个立即任务,若已有未执行的则跳过/覆盖""" - self._scheduler.add_job( - self._flush, - trigger=DateTrigger(run_date=datetime.now()), - kwargs={"force": True}, # 标记为立即触发 - id=self.flash_job_id, # 与定时器同一 ID - replace_existing=True, # 覆盖未执行的旧任务 - max_instances=1, - ) + await self.add(tweet) + + # -------------------- 核心 worker:完全模仿 _worker_loop -------------------- + async def _worker_loop(self) -> None: + """单协程:等第一条 -> 设 deadline -> 继续 get(timeout=剩余时间) -> 满批/超时刷""" + while not self._stop_evt.is_set(): + batch: List[Dict[str, Any]] = [] + deadline = None + + # 1. 阻塞等第一条 + try: + first = await asyncio.wait_for(self._queue.get(), timeout=0.1) + batch.append(first) + deadline = asyncio.get_event_loop().time() + self.flush_seconds + except asyncio.TimeoutError: + continue + + # 2. 收集剩余 + while len(batch) < self.batch_size and not self._stop_evt.is_set(): + remaining = deadline - asyncio.get_event_loop().time() + if remaining <= 0: + break + try: + item = await asyncio.wait_for(self._queue.get(), timeout=remaining) + batch.append(item) + except asyncio.TimeoutError: + break + + # 3. 处理 + if batch: + try: + await self.callback(batch) + except Exception as e: + logger.exception("Callback error: %s", e) + # 4. 退出时刷剩余 + await self._drain_remaining() async def _flush(self) -> None: - async with self._lock: - if not self._buffer: - return - to_send = self._buffer.copy() - self._buffer.clear() - logger.info("Flush %d tweets", len(to_send)) - try: - await self._callback(to_send) - except Exception as e: - logger.exception("Callback error: %s", e) + """同步刷剩余(给 stop 用)""" + batch = [] + while not self._queue.empty(): + batch.append(self._queue.get_nowait()) + if batch: + try: + await self.callback(batch) + except Exception as e: + logger.exception("Final callback error: %s", e) + + async def _drain_remaining(self) -> None: + await self._flush() - # -------------------- 工具 -------------------- @staticmethod - def _fix_tweet_dict(msg: dict[str, Any]) -> dict[str, Any]: + def _fix_tweet_dict(msg: Dict[str, Any]) -> Dict[str, Any]: fixed = msg.copy() for key in ("created_at", "updated_at"): if key in fixed and isinstance(fixed[key], str): From a6d593d477e9d94ccd848f45c47afa68363f91d9 Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 18:10:01 +0800 Subject: [PATCH 12/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../sunagent_ext/tweet/tweet_from_queue.py | 44 +++++++++++-------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index 2d385cf..d8e3b3d 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -11,6 +11,9 @@ class TweetFromQueueContext: + # ---------- 新增哨兵 ---------- + _SENTINEL = object() # 用于通知 worker 立即退出 + def __init__( self, *, @@ -32,7 +35,8 @@ def __init__( self._nc: Optional[nats.NATS] = None # type: ignore[name-defined] self._sub = None - self._queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + # 队列元素现在是 Dict | sentinel + self._queue: asyncio.Queue[Any] = asyncio.Queue() self._stop_evt = asyncio.Event() self._worker_task: Optional[asyncio.Task] = None # type: ignore[type-arg] @@ -41,17 +45,18 @@ async def start(self) -> None: self._nc = await nats.connect(self.nats_url) self._sub = await self._nc.subscribe(self.subject, cb=self._on_msg) # type: ignore[assignment] logger.info("Subscribed to <%s>, filter=%s", self.subject, self.user_ids) - # 启动单协程 worker self._worker_task = asyncio.create_task(self._worker_loop()) async def stop(self) -> None: logger.info("Stopping AsyncBatchingQueue...") self._stop_evt.set() + # 1. 先塞哨兵,让 worker 从 queue.get() 立即返回 + await self._queue.put(self._SENTINEL) if self._worker_task: await self._worker_task if self._sub: await self._sub.unsubscribe() - await self._flush() # 刷剩余 + # 2. worker loop 退出前已经 _drain_remaining(),这里不再 _flush() if self._nc: await self._nc.close() logger.info("AsyncBatchingQueue stopped") @@ -72,27 +77,24 @@ async def _on_msg(self, msg: Msg) -> None: try: tweet = json.loads(msg.data.decode()) tweet = self._fix_tweet_dict(tweet) - # if self.user_ids and tweet.get("author_id") not in self.user_ids: - # return + if self.user_ids and tweet.get("author_id") not in self.user_ids: + return except Exception as e: logger.exception("Bad msg: %s", e) return await self.add(tweet) - # -------------------- 核心 worker:完全模仿 _worker_loop -------------------- + # -------------------- 核心 worker -------------------- async def _worker_loop(self) -> None: - """单协程:等第一条 -> 设 deadline -> 继续 get(timeout=剩余时间) -> 满批/超时刷""" + """单协程:永久阻塞等第一条 -> 设 deadline -> 超时/满批刷 -> 收到哨兵退出""" while not self._stop_evt.is_set(): batch: List[Dict[str, Any]] = [] - deadline = None - - # 1. 阻塞等第一条 - try: - first = await asyncio.wait_for(self._queue.get(), timeout=0.1) - batch.append(first) - deadline = asyncio.get_event_loop().time() + self.flush_seconds - except asyncio.TimeoutError: - continue + # 1. 永久阻塞等第一条(CPU 不再空转) + first = await self._queue.get() + if first is self._SENTINEL: # 收到哨兵直接退出 + break + batch.append(first) + deadline = asyncio.get_event_loop().time() + self.flush_seconds # 2. 收集剩余 while len(batch) < self.batch_size and not self._stop_evt.is_set(): @@ -101,6 +103,8 @@ async def _worker_loop(self) -> None: break try: item = await asyncio.wait_for(self._queue.get(), timeout=remaining) + if item is self._SENTINEL: # 哨兵提前结束收集 + break batch.append(item) except asyncio.TimeoutError: break @@ -111,14 +115,18 @@ async def _worker_loop(self) -> None: await self.callback(batch) except Exception as e: logger.exception("Callback error: %s", e) + # 4. 退出时刷剩余 await self._drain_remaining() async def _flush(self) -> None: - """同步刷剩余(给 stop 用)""" + """同步刷剩余(给 _drain_remaining 用)""" batch = [] while not self._queue.empty(): - batch.append(self._queue.get_nowait()) + item = self._queue.get_nowait() + if item is self._SENTINEL: # 忽略哨兵 + continue + batch.append(item) if batch: try: await self.callback(batch) From 9e44c5cbab88f2f080081b1b6ddb673e7f8b8160 Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 18:21:19 +0800 Subject: [PATCH 13/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/sunagent_ext/tweet/tweet_from_queue.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py index d8e3b3d..d4247ca 100644 --- a/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py +++ b/packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py @@ -49,14 +49,18 @@ async def start(self) -> None: async def stop(self) -> None: logger.info("Stopping AsyncBatchingQueue...") - self._stop_evt.set() - # 1. 先塞哨兵,让 worker 从 queue.get() 立即返回 - await self._queue.put(self._SENTINEL) - if self._worker_task: - await self._worker_task + + # 1. 立即停止接收新消息 if self._sub: await self._sub.unsubscribe() - # 2. worker loop 退出前已经 _drain_remaining(),这里不再 _flush() + + # 2. 通知 worker 退出并等待它刷完剩余 + self._stop_evt.set() + await self._queue.put(self._SENTINEL) # 让 worker 从 queue.get() 立即返回 + if self._worker_task: + await self._worker_task # 内部已 _drain_remaining() + + # 3. 关闭 NATS 连接 if self._nc: await self._nc.close() logger.info("AsyncBatchingQueue stopped") From 45b28f9f3aa211a8382553875dbb6fcf056403eb Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 18:32:15 +0800 Subject: [PATCH 14/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py index 8d2a57f..4d06455 100644 --- a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py +++ b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py @@ -79,7 +79,7 @@ _FileSpec: TypeAlias = _FileContent | _FileSpecTuple2 | _FileSpecTuple3 | _FileSpecTuple4 _FilesType: TypeAlias = Mapping[str, _FileSpec] | Iterable[tuple[str, _FileSpec]] | None _AuthType: TypeAlias = tuple[str, str] | AuthBase | Callable[[PreparedRequest], PreparedRequest] | None -_TimeoutType: TypeAlias = float | tuple[float, float] | tuple[float, None] | None +_TimeoutType: TypeAlias = float | tuple[float, float] | tuple[float, None] | tuple[None, float] | None _Hook: TypeAlias = Callable[[Response], Any] _HooksType: TypeAlias = Mapping[str, Iterable[_Hook] | _Hook] | None _CertType: TypeAlias = str | tuple[str, str] | None From c6486b55a8c8e5f2150c08bb68acb8b7e3cf0f1a Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 18:34:43 +0800 Subject: [PATCH 15/16] =?UTF-8?q?feat=EF=BC=9A=20update=20consume=20&=20fl?= =?UTF-8?q?ash?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py index 4d06455..8d2a57f 100644 --- a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py +++ b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py @@ -79,7 +79,7 @@ _FileSpec: TypeAlias = _FileContent | _FileSpecTuple2 | _FileSpecTuple3 | _FileSpecTuple4 _FilesType: TypeAlias = Mapping[str, _FileSpec] | Iterable[tuple[str, _FileSpec]] | None _AuthType: TypeAlias = tuple[str, str] | AuthBase | Callable[[PreparedRequest], PreparedRequest] | None -_TimeoutType: TypeAlias = float | tuple[float, float] | tuple[float, None] | tuple[None, float] | None +_TimeoutType: TypeAlias = float | tuple[float, float] | tuple[float, None] | None _Hook: TypeAlias = Callable[[Response], Any] _HooksType: TypeAlias = Mapping[str, Iterable[_Hook] | _Hook] | None _CertType: TypeAlias = str | tuple[str, str] | None From c16316719034e3b127c0813abe86c818027768b5 Mon Sep 17 00:00:00 2001 From: bobo Date: Mon, 15 Sep 2025 18:37:42 +0800 Subject: [PATCH 16/16] =?UTF-8?q?feat=EF=BC=9A=20mypy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py index 8d2a57f..d831c29 100644 --- a/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py +++ b/packages/sunagent-ext/src/sunagent_ext/utils/timeout_session.py @@ -100,7 +100,7 @@ def request( cookies: _CookiesType = None, files: _FilesType = None, auth: _AuthType = None, - timeout: _TimeoutType = None, + timeout: _TimeoutType = None, # type: ignore[override] allow_redirects: bool = True, proxies: _TextMapping | None = None, hooks: _HooksType = None,