Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def _reply_comment(self, authorperm: str, body: str) -> str:
if self.cache:
self.cache.set(cache_key, "processed")
logger.info(f"Reply comment {authorperm} success body: {body} result {result}")
read_steem_success_count.inc()
post_steem_success_count.inc()
return f"Reply comment {authorperm} success body: {body}"
except Exception as e:
logger.error(f"Reply comment {authorperm} failed e : {str(e)}")
read_steem_failure_count.inc()
post_steem_failure_count.inc()
logger.error(traceback.format_exc())
return f"Reply comment {authorperm} failed e : {str(e)}"

Expand Down
4 changes: 3 additions & 1 deletion packages/sunagent-ext/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "sunagent-ext"
version = "0.0.7b1"
version = "0.0.7b2"
license = {file = "LICENSE-CODE"}
description = "AutoGen extensions library"
readme = "README.md"
Expand All @@ -19,7 +19,9 @@ dependencies = [
"aiohttp",
"onepassword-sdk>=0.3.0",
"pytz",
"nats-py==2.11.0",
"redis",
"prometheus_client",
"requests",
]

Expand Down
149 changes: 149 additions & 0 deletions packages/sunagent-ext/src/sunagent_ext/tweet/tweet_from_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional

import nats
from nats.aio.msg import Msg

logger = logging.getLogger(__name__)


class TweetFromQueueContext:
# ---------- 新增哨兵 ----------
_SENTINEL = object() # 用于通知 worker 立即退出

def __init__(
self,
*,
batch_size: int,
flush_seconds: float,
callback: Callable[[List[Dict[str, Any]]], Any],
nats_url: str,
subject: str,
user_ids: Optional[List[str]] = None,
):
if batch_size <= 0 or flush_seconds <= 0:
raise ValueError("batch_size / flush_seconds 必须为正")
self.batch_size = batch_size
self.flush_seconds = flush_seconds
self.callback = callback
self.nats_url = nats_url
self.subject = subject
self.user_ids = set(user_ids) if user_ids else None

self._nc: Optional[nats.NATS] = None # type: ignore[name-defined]
self._sub = None
# 队列元素现在是 Dict | sentinel
self._queue: asyncio.Queue[Any] = asyncio.Queue()
self._stop_evt = asyncio.Event()
self._worker_task: Optional[asyncio.Task] = None # type: ignore[type-arg]

# -------------------- 生命周期 --------------------
async def start(self) -> None:
self._nc = await nats.connect(self.nats_url)
self._sub = await self._nc.subscribe(self.subject, cb=self._on_msg) # type: ignore[assignment]
logger.info("Subscribed to <%s>, filter=%s", self.subject, self.user_ids)
self._worker_task = asyncio.create_task(self._worker_loop())

async def stop(self) -> None:
logger.info("Stopping AsyncBatchingQueue...")

# 1. 立即停止接收新消息
if self._sub:
await self._sub.unsubscribe()

# 2. 通知 worker 退出并等待它刷完剩余
self._stop_evt.set()
await self._queue.put(self._SENTINEL) # 让 worker 从 queue.get() 立即返回
if self._worker_task:
await self._worker_task # 内部已 _drain_remaining()

# 3. 关闭 NATS 连接
if self._nc:
await self._nc.close()
logger.info("AsyncBatchingQueue stopped")

async def __aenter__(self): # type: ignore[no-untyped-def]
await self.start()
return self

async def __aexit__(self, exc_type, exc, tb): # type: ignore[no-untyped-def]
await self.stop()

# -------------------- 公共入口 --------------------
async def add(self, item: Dict[str, Any]) -> None:
await self._queue.put(item)

# -------------------- NATS 回调 --------------------
async def _on_msg(self, msg: Msg) -> None:
try:
tweet = json.loads(msg.data.decode())
tweet = self._fix_tweet_dict(tweet)
if self.user_ids and tweet.get("author_id") not in self.user_ids:
return
except Exception as e:
logger.exception("Bad msg: %s", e)
return
await self.add(tweet)

# -------------------- 核心 worker --------------------
async def _worker_loop(self) -> None:
"""单协程:永久阻塞等第一条 -> 设 deadline -> 超时/满批刷 -> 收到哨兵退出"""
while not self._stop_evt.is_set():
batch: List[Dict[str, Any]] = []
# 1. 永久阻塞等第一条(CPU 不再空转)
first = await self._queue.get()
if first is self._SENTINEL: # 收到哨兵直接退出
break
batch.append(first)
deadline = asyncio.get_event_loop().time() + self.flush_seconds

# 2. 收集剩余
while len(batch) < self.batch_size and not self._stop_evt.is_set():
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
break
try:
item = await asyncio.wait_for(self._queue.get(), timeout=remaining)
if item is self._SENTINEL: # 哨兵提前结束收集
break
batch.append(item)
except asyncio.TimeoutError:
break

# 3. 处理
if batch:
try:
await self.callback(batch)
except Exception as e:
logger.exception("Callback error: %s", e)

# 4. 退出时刷剩余
await self._drain_remaining()

async def _flush(self) -> None:
"""同步刷剩余(给 _drain_remaining 用)"""
batch = []
while not self._queue.empty():
item = self._queue.get_nowait()
if item is self._SENTINEL: # 忽略哨兵
continue
batch.append(item)
if batch:
try:
await self.callback(batch)
except Exception as e:
logger.exception("Final callback error: %s", e)

async def _drain_remaining(self) -> None:
await self._flush()

@staticmethod
def _fix_tweet_dict(msg: Dict[str, Any]) -> Dict[str, Any]:
fixed = msg.copy()
for key in ("created_at", "updated_at"):
if key in fixed and isinstance(fixed[key], str):
fixed[key] = datetime.fromisoformat(fixed[key])
return fixed
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@

import asyncio
import logging
import traceback
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, List, Optional, cast

import tweepy
from prometheus_client import Counter, Gauge
from tweepy import Media, NotFound, TwitterServerError, User # 保持原类型注解
from tweepy import Media, NotFound, Tweet, TwitterServerError, User # 保持原类型注解
from tweepy import Response as TwitterResponse

from sunagent_ext.tweet.twitter_client_pool import TwitterClientPool
Expand Down Expand Up @@ -223,7 +225,7 @@ async def _fetch_timeline(
read_tweet_success_count.labels(client_key=client_key).inc(len(resp.data or []))

# 交给中间层处理
tweet_list, next_token = await self.on_twitter_response(agent_id, resp, filter_func)
tweet_list, next_token = await self.on_twitter_response(agent_id, me_id, resp, filter_func)
all_raw.extend(tweet_list)
if not next_token:
break
Expand Down Expand Up @@ -254,6 +256,7 @@ async def _fetch_timeline(
async def on_twitter_response( # type: ignore[no-any-unimported]
self,
agent_id: str,
me_id: str,
response: TwitterResponse,
filter_func: Callable[[Dict[str, Any]], bool],
) -> tuple[list[Dict[str, Any]], Optional[str]]:
Expand All @@ -267,20 +270,23 @@ async def on_twitter_response( # type: ignore[no-any-unimported]
out: list[Dict[str, Any]] = []

for tweet in all_tweets:
if not await self._should_keep(agent_id, tweet, filter_func):
if not await self._should_keep(agent_id, me_id, tweet, filter_func):
continue
norm = await self._normalize_tweet(tweet)
out.append(norm)
return out, next_token

async def _should_keep(
self, agent_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool]
self, agent_id: str, me_id: str, tweet: Dict[str, Any], filter_func: Callable[[Dict[str, Any]], bool]
) -> bool:
is_processed = await self._check_tweet_process(tweet["id"], agent_id)
if is_processed:
logger.info("already processed %s", tweet["id"])
return False
author_id = str(tweet["author_id"])
if me_id == author_id:
logger.info("skip my tweet %s", tweet["id"])
return False
if author_id in self.block_uids:
logger.info("blocked user %s", author_id)
return False
Expand Down Expand Up @@ -366,6 +372,91 @@ async def _recursive_fetch(self, tweet: Dict[str, Any], chain: list[Dict[str, An
await self._recursive_fetch(parent, chain, depth + 1)
chain.append(tweet)

async def fetch_new_tweets_manual_(  # type: ignore[no-any-unimported]
    self,
    ids: List[str],
    last_seen_id: "str | None" = None,
) -> "tuple[List[Tweet], str | None]":
    """Fetch recent tweets for a list of user ids via the search API.

    1. Builds ``from:<id> OR from:<id> ...`` clauses, splitting the id list
       into multiple queries so each stays under the 512-char query limit.
    2. Runs each query through ``fetch_new_tweets_manual_tweets`` (paged).
    3. Returns all tweets plus the largest tweet id seen (as a string),
       suitable as the next ``since_id``.
    """
    BASE_EXTRA = " -is:retweet"
    max_len = 512 - len(BASE_EXTRA) - 10  # safety margin below the API limit
    queries: List[str] = []

    buf: list[str] = []
    for uid in ids:
        clause = f"from:{uid}"
        # Flush the buffer when adding this clause would overflow. Guard on
        # `buf` being non-empty: the original emitted a query consisting of
        # only BASE_EXTRA when the very first clause was oversized.
        if buf and len(" OR ".join(buf + [clause])) > max_len:
            queries.append(" OR ".join(buf) + BASE_EXTRA)
            buf = [clause]
        else:
            buf.append(clause)
    if buf:
        queries.append(" OR ".join(buf) + BASE_EXTRA)

    # Run every query with the same since_id and merge the results.
    all_tweets: List["tweepy.Tweet"] = []  # type: ignore[no-any-unimported]
    for q in queries:
        tweets = await self.fetch_new_tweets_manual_tweets(query=q, last_seen_id=last_seen_id)
        all_tweets.extend(tweets)

    # The largest id across all results becomes the new last_seen_id.
    # Tweet ids come back as ints, so normalize to str to match the
    # declared return type (the API accepts either for since_id).
    last_id = max((tw.id for tw in all_tweets), default=None)
    return all_tweets, None if last_id is None else str(last_id)

async def get_user_tweet(self, user_ids: List[str]) -> "List[Tweet]":  # type: ignore[no-any-unimported]
    """Fetch tweets for *user_ids* newer than the cached high-water mark.

    Reads the last seen tweet id from the cache, delegates to
    ``fetch_new_tweets_manual_``, and advances the cached id only when a
    new (truthy) last-seen id was returned.
    """
    cache_key = "user_last_seen_id"
    last_seen_id = self.cache.get(cache_key)
    tweets, last_seen_id = await self.fetch_new_tweets_manual_(ids=user_ids, last_seen_id=last_seen_id)
    # Lazy %-style args: formatting cost is paid only when the record is emitted.
    logger.info("get_user_tweet tweets: %s last_seen_id: %s", len(tweets), last_seen_id)
    if last_seen_id:
        self.cache.set(cache_key, last_seen_id)
    return tweets

async def fetch_new_tweets_manual_tweets(  # type: ignore[no-any-unimported]
    self, query: str, last_seen_id: "str | None" = None, max_per_page: int = 100, hours: int = 24
) -> "List[Tweet]":
    """Page through ``search_recent_tweets`` for *query* and collect results.

    Args:
        query: full search query string (already length-limited by the caller).
        last_seen_id: fetch only tweets newer than this id; when absent, fall
            back to a ``start_time`` of ``hours`` ago.
        max_per_page: page size requested from the API.
        hours: look-back window used only when ``last_seen_id`` is missing.

    Returns:
        All tweets collected. On rate limit or any tweepy error the partial
        result gathered so far is returned instead of raising.
    """
    tweets: List[Any] = []
    next_token = None
    since = datetime.now(timezone.utc) - timedelta(hours=hours)
    # start_time and since_id are alternatives here: prefer since_id when known.
    start_time = None if last_seen_id else since.isoformat(timespec="seconds")
    logger.info("query: %s", query)
    while True:
        cli, key = None, ""
        try:
            cli, key = await self.pool.acquire()
            resp = cli.search_recent_tweets(
                query=query,
                start_time=start_time,
                since_id=last_seen_id,
                max_results=max_per_page,
                tweet_fields=TWEET_FIELDS,
                next_token=next_token,
                user_auth=True,
            )
            page_data = resp.data or []
            logger.info("page_data: %s", len(page_data))
            tweets.extend(page_data)
            # Reuse page_data instead of recomputing `resp.data or []`.
            read_tweet_success_count.labels(client_key=key).inc(len(page_data))
            next_token = resp.meta.get("next_token")
            if not next_token:
                break
        except tweepy.TooManyRequests:
            # Rate limited: count the failure, flag the client, return partial.
            logger.error(traceback.format_exc())
            read_tweet_failure_count.labels(client_key=key).inc()
            if cli:
                await self.pool.report_failure(cli)
            return tweets
        except tweepy.TweepyException:
            # NOTE(review): this branch does not bump read_tweet_failure_count,
            # unlike the rate-limit branch — confirm whether that is intended.
            if cli:
                await self.pool.report_failure(cli)
            logger.error(traceback.format_exc())
            return tweets
        # NOTE(review): successful iterations never release `cli` back to the
        # pool — confirm whether TwitterClientPool.acquire expects a release.
    return tweets

async def _get_tweet_with_retry(self, tweet_id: str) -> Optional[Dict[str, Any]]:
for attempt in range(3):
cli, client_key = await self.pool.acquire()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def request(
cookies: _CookiesType = None,
files: _FilesType = None,
auth: _AuthType = None,
timeout: _TimeoutType = None,
timeout: _TimeoutType = None, # type: ignore[override]
allow_redirects: bool = True,
proxies: _TextMapping | None = None,
hooks: _HooksType = None,
Expand Down
Loading