@@ -78,11 +78,11 @@ def _reply_comment(self, authorperm: str, body: str) -> str:
if self.cache:
self.cache.set(cache_key, "processed")
logger.info(f"Reply comment {authorperm} success body: {body} result {result}")
-            read_steem_success_count.inc()
+            post_steem_success_count.inc()
return f"Reply comment {authorperm} success body: {body}"
except Exception as e:
logger.error(f"Reply comment {authorperm} failed e : {str(e)}")
-            read_steem_failure_count.inc()
+            post_steem_failure_count.inc()
logger.error(traceback.format_exc())
return f"Reply comment {authorperm} failed e : {str(e)}"
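
For context, a minimal sketch of how the renamed counters could be declared with prometheus_client. The actual declarations live outside this hunk; the metric names and help strings below are assumptions that mirror the calls above:

from prometheus_client import Counter

# Hypothetical declarations; the real ones are defined elsewhere in the module.
post_steem_success_count = Counter("post_steem_success_total", "Successful Steem post/reply operations")
post_steem_failure_count = Counter("post_steem_failure_total", "Failed Steem post/reply operations")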

4 changes: 3 additions & 1 deletion packages/sunagent-ext/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "sunagent-ext"
version = "0.0.7b1"
version = "0.0.7b2"
license = {file = "LICENSE-CODE"}
description = "AutoGen extensions library"
readme = "README.md"
@@ -19,7 +19,9 @@ dependencies = [
"aiohttp",
"onepassword-sdk>=0.3.0",
"pytz",
"nats-py==2.11.0",
"redis",
"prometheus_client",
"requests",
]

126 changes: 126 additions & 0 deletions packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py
@@ -0,0 +1,126 @@
# tweet_from_queue.py
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Callable, List

import nats
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from nats.aio.msg import Msg
from nats.aio.subscription import Subscription

logger = logging.getLogger(__name__)


class TweetFromQueueContext:
"""
1. 订阅 NATS subject,累积推文
2. APScheduler 每 10 秒强制 flush 一次(不管有没有消息)
3. 支持优雅关闭
"""

def __init__(
self,
size: int,
user_ids: List[str],
nats_url: str,
callback: Callable[[List[dict[str, Any]]], Any],
flush_seconds: int = 10,
):
self.size = size
self.user_ids = user_ids
self.nats_url = nats_url
self.flush_seconds = flush_seconds

self._nc = None
self._buffer: List[dict[str, Any]] = []
self._callback: Callable[[List[dict[str, Any]]], Any] = callback
self._sub: Subscription | None = None
self._scheduler: AsyncIOScheduler | None = None # type: ignore[no-any-unimported]
self._lock = asyncio.Lock()

    # -------------------- Lifecycle --------------------
async def start(self, subject: str) -> None:
        # 1. Connect to NATS
self._nc = await nats.connect(self.nats_url) # type: ignore[assignment]
self._sub = await self._nc.subscribe(subject, cb=self._nc_consume) # type: ignore[attr-defined]
logger.info("Subscribed to <%s>, agent=%s", subject, self.user_ids)

        # 2. Start the APScheduler periodic flush
self._scheduler = AsyncIOScheduler()
self._scheduler.add_job(
            self._flush_wrap,  # the coroutine run on each tick
trigger="interval",
seconds=self.flush_seconds,
max_instances=1,
            next_run_time=datetime.now(),  # run once immediately
)
self._scheduler.start()

async def close(self) -> None:
"""优雅关闭:停订阅 -> 停调度器 -> 刷剩余数据 -> 断 NATS"""
# 1. 停止订阅(不再接收新消息)
if self._sub:
await self._sub.unsubscribe()
self._sub = None

        # 2. Stop the scheduler (no more periodic flushes)
if self._scheduler:
self._scheduler.shutdown(wait=False)

        # 3. Flush any remaining data; _flush acquires the lock itself, so do not hold it here
        await self._flush()

        # 4. Disconnect from NATS
if self._nc:
await self._nc.close()
logger.info("NATS connection closed")

    # -------------------- NATS callback --------------------
async def _nc_consume(self, msg: Msg) -> None:
try:
tweet = json.loads(msg.data.decode())
tweet = self._fix_tweet_dict(tweet)
            if str(tweet["author_id"]) not in self.user_ids:  # compare as strings; JSON may carry author_id as int
return
except Exception as e:
logger.exception("Bad message: %s", e)
return

        full = False
        async with self._lock:
            self._buffer.append(tweet)
            full = len(self._buffer) >= self.size
        # Flush immediately once the buffer is full. _flush takes the lock itself,
        # so calling it while still holding the (non-reentrant) lock would deadlock.
        if full:
            await self._flush()

    # -------------------- Periodic-flush wrapper --------------------
    async def _flush_wrap(self) -> None:
        """Coroutine wrapper for the APScheduler job; _flush handles its own locking."""
        await self._flush()

    # -------------------- Actual flush --------------------
    async def _flush(self) -> None:
        if not self._buffer:
            return
        # 1. Inside the lock, only copy and clear the buffer
        async with self._lock:
            to_send = self._buffer.copy()
            self._buffer.clear()
        # 2. Run the callback outside the lock so it cannot block enqueueing or the periodic flush
logger.info("Flush %d tweets", len(to_send))
try:
await self._callback(to_send)
except Exception as e:
logger.exception("Callback error: %s", e)

    # -------------------- Helpers --------------------
@staticmethod
def _fix_tweet_dict(msg: dict[str, Any]) -> dict[str, Any]:
fixed = msg.copy()
for key in ("created_at", "updated_at"):
if key in fixed and isinstance(fixed[key], str):
fixed[key] = datetime.fromisoformat(fixed[key])
return fixed
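
For orientation, a minimal usage sketch of TweetFromQueueContext, not part of the diff; the subject name "tweets.stream", the user ids, and handle_batch are illustrative assumptions:

import asyncio
from typing import Any

from sunagent_ext.tweet.tweet_from_queue import TweetFromQueueContext


async def handle_batch(tweets: list[dict[str, Any]]) -> None:
    # Placeholder consumer: a real callback might rank, reply to, or persist tweets.
    print(f"received batch of {len(tweets)} tweets")


async def main() -> None:
    ctx = TweetFromQueueContext(
        size=50,                           # flush as soon as 50 tweets have buffered
        user_ids=["123", "456"],           # hypothetical author ids to keep
        nats_url="nats://localhost:4222",  # assumes a local NATS server
        callback=handle_batch,
        flush_seconds=10,
    )
    await ctx.start("tweets.stream")       # assumed subject name
    try:
        await asyncio.sleep(60)            # consume for a while
    finally:
        await ctx.close()                  # flushes the remainder, then disconnects


asyncio.run(main())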
@@ -5,11 +5,13 @@

import asyncio
import logging
import traceback
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional, cast

import tweepy
from prometheus_client import Counter, Gauge
-from tweepy import Media, NotFound, TwitterServerError, User  # keep the original type annotations
+from tweepy import Media, NotFound, Tweet, TwitterServerError, User  # keep the original type annotations
from tweepy import Response as TwitterResponse

from sunagent_ext.tweet.twitter_client_pool import TwitterClientPool
@@ -223,7 +225,7 @@ async def _fetch_timeline(
read_tweet_success_count.labels(client_key=client_key).inc(len(resp.data or []))

            # Hand off to the middle layer
-            tweet_list, next_token = await self.on_twitter_response(agent_id, resp, filter_func)
+            tweet_list, next_token = await self.on_twitter_response(agent_id, me_id, resp, filter_func)
all_raw.extend(tweet_list)
if not next_token:
break
@@ -254,6 +256,7 @@ async def _fetch_timeline(
async def on_twitter_response( # type: ignore[no-any-unimported]
self,
agent_id: str,
me_id: str,
response: TwitterResponse,
filter_func: Callable[[Dict[str, Any]], bool],
) -> tuple[list[Dict[str, Any]], Optional[str]]:
@@ -267,20 +270,23 @@ async def on_twitter_response(  # type: ignore[no-any-unimported]
out: list[Dict[str, Any]] = []

for tweet in all_tweets:
-            if not await self._should_keep(agent_id, tweet, filter_func):
+            if not await self._should_keep(agent_id, me_id, tweet, filter_func):
continue
norm = await self._normalize_tweet(tweet)
out.append(norm)
return out, next_token

async def _should_keep(
-        self, agent_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool]
+        self, agent_id: str, me_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool]
) -> bool:
is_processed = await self._check_tweet_process(tweet["id"], agent_id)
if is_processed:
logger.info("already processed %s", tweet["id"])
return False
author_id = str(tweet["author_id"])
if me_id == author_id:
logger.info("skip my tweet %s", tweet["id"])
return False
if author_id in self.block_uids:
logger.info("blocked user %s", author_id)
return False
@@ -366,6 +372,91 @@ async def _recursive_fetch(self, tweet: Dict[str, Any], chain: list[Dict[str, An
await self._recursive_fetch(parent, chain, depth + 1)
chain.append(tweet)

async def fetch_new_tweets_manual_( # type: ignore[no-any-unimported]
self,
ids: List[str],
last_seen_id: str | None = None,
) -> tuple[List[Tweet], str | None]:
"""
1. 取所有 ALIVE user 的 twitter_id
2. 将 id 列表拆分成多条不超长 query
3. 逐条交给 fetch_new_tweets_manual_tweets 翻页
4. 返回全部结果以及 **所有结果中最大的 tweet_id**
"""
BASE_EXTRA = " -is:retweet"
max_len = 512 - len(BASE_EXTRA) - 10
queries: List[str] = []

buf: list[str] = []
for uid in ids:
clause = f"from:{uid}"
if len(" OR ".join(buf + [clause])) > max_len:
queries.append(" OR ".join(buf) + BASE_EXTRA)
buf = [clause]
else:
buf.append(clause)
if buf:
queries.append(" OR ".join(buf) + BASE_EXTRA)
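        # Worked example (hypothetical ids): ids=["111", "222", "333"] with a
        # generous max_len collapses into a single query,
        #   "from:111 OR from:222 OR from:333 -is:retweet",
        # while a tighter max_len cuts the buffer early and starts a new query,
        # keeping every query under Twitter's 512-character limit.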
        # 3) Call the inner fetch per query and merge the results
all_tweets: List[tweepy.Tweet] = [] # type: ignore[no-any-unimported]
for q in queries:
tweets = await self.fetch_new_tweets_manual_tweets(query=q, last_seen_id=last_seen_id)
all_tweets.extend(tweets)

        # 4) Use the largest id across all results as the new last_seen_id
last_id = max((tw.id for tw in all_tweets), default=None)
return all_tweets, last_id

async def get_user_tweet(self, user_ids: List[str]) -> List[Tweet]: # type: ignore[no-any-unimported]
cache_key = "user_last_seen_id"
last_seen_id = self.cache.get(cache_key)
tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=user_ids, last_seen_id=last_seen_id)
logger.info(f"get_user_tweet tweets: {len(tweets)} last_seen_id: {last_seen_id}")
if last_seen_id:
self.cache.set(cache_key, last_seen_id)
return tweets

async def fetch_new_tweets_manual_tweets( # type: ignore[no-any-unimported]
self, query: str, last_seen_id: str | None = None, max_per_page: int = 100, hours: int = 24
) -> List[Tweet]:
tweets: List[Any] = []
next_token = None
since = datetime.now(timezone.utc) - timedelta(hours=hours)
start_time = None if last_seen_id else since.isoformat(timespec="seconds")
logger.info(f"query: {query}")
while True:
cli, key = None, ""
try:
cli, key = await self.pool.acquire()
resp = cli.search_recent_tweets(
query=query,
start_time=start_time,
since_id=last_seen_id,
max_results=max_per_page,
tweet_fields=TWEET_FIELDS,
next_token=next_token,
user_auth=True,
)
page_data = resp.data or []
logger.info(f"page_data: {len(page_data)}")
tweets.extend(page_data)
                read_tweet_success_count.labels(client_key=key).inc(len(page_data))
next_token = resp.meta.get("next_token")
if not next_token:
break
except tweepy.TooManyRequests:
logger.error(traceback.format_exc())
read_tweet_failure_count.labels(client_key=key).inc()
if cli:
await self.pool.report_failure(cli)
return tweets
except tweepy.TweepyException:
if cli:
await self.pool.report_failure(cli)
logger.error(traceback.format_exc())
return tweets
return tweets

async def _get_tweet_with_retry(self, tweet_id: str) -> Optional[Dict[str, Any]]:
for attempt in range(3):
cli, client_key = await self.pool.acquire()