mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-13 07:26:45 +00:00
fix: ensure stop_streaming waits for inflight initial flush before final edit; improve message edit/delete return types and logging
- Update stop_streaming logic to await inflight initial flush task, preventing duplicate message sends on stream stop - Change message edit/delete methods to return Optional[bool] for clearer channel mismatch handling - Refine Feishu logging to include message_id instead of full data object - Suppress allowed_objects warnings in __init__.py - Add test to verify stop_streaming waits for inflight flush before final edit - Update .pylintrc to use 'E' for error enabling
This commit is contained in:
@@ -18,7 +18,7 @@ jobs=0
|
||||
|
||||
# 禁用大部分警告、约定和重构建议,只保留错误和重要警告
|
||||
disable=all
|
||||
enable=error,
|
||||
enable=E,
|
||||
syntax-error,
|
||||
undefined-variable,
|
||||
used-before-assignment,
|
||||
|
||||
@@ -16,6 +16,10 @@ from langchain_core.messages import ( # noqa: F401
|
||||
HumanMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message=".*allowed_objects.*")
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
|
||||
@@ -437,15 +437,23 @@ class StreamingHandler:
|
||||
|
||||
async def _cancel_flush_task(self):
|
||||
"""
|
||||
取消当前的定时刷新任务
|
||||
停止当前的定时刷新任务。
|
||||
|
||||
停止流式输出时,刷新任务可能已经在线程池里发出了首条消息。
|
||||
这里先等待该轮刷新自然完成,确保 message_id 等返回信息能落回本地状态;
|
||||
否则最终刷新会误以为尚未发送过消息,从而再次发送一条新消息。
|
||||
"""
|
||||
if self._flush_task and not self._flush_task.done():
|
||||
self._flush_task.cancel()
|
||||
current_task = asyncio.current_task()
|
||||
if (
|
||||
self._flush_task
|
||||
and not self._flush_task.done()
|
||||
and self._flush_task is not current_task
|
||||
):
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._flush_task = None
|
||||
self._flush_task = None
|
||||
|
||||
async def _flush(self):
|
||||
"""
|
||||
|
||||
@@ -409,7 +409,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
source: str,
|
||||
message_id: str,
|
||||
chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
@@ -418,10 +418,10 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
:param chat_id: 聊天ID(频道ID)
|
||||
:return: 删除是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return None
|
||||
success = False
|
||||
for conf in self.get_configs().values():
|
||||
if channel != self._channel:
|
||||
break
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
@@ -441,7 +441,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑消息
|
||||
:param channel: 消息渠道
|
||||
@@ -454,7 +454,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return False
|
||||
return None
|
||||
for conf in self.get_configs().values():
|
||||
if source != conf.name:
|
||||
continue
|
||||
|
||||
@@ -161,9 +161,9 @@ class FeishuModule(_ModuleBase, _MessageBase[Feishu]):
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
if channel != self._channel:
|
||||
return False
|
||||
return None
|
||||
for conf in self.get_configs().values():
|
||||
if source != conf.name:
|
||||
continue
|
||||
@@ -296,7 +296,7 @@ class FeishuModule(_ModuleBase, _MessageBase[Feishu]):
|
||||
message_id: str,
|
||||
reaction_id: str,
|
||||
source: str,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return False
|
||||
|
||||
@@ -837,7 +837,10 @@ class Feishu:
|
||||
return None
|
||||
|
||||
data = getattr(response, "data", None)
|
||||
logger.info(f"_send_message 飞书回复消息成功:data={data}")
|
||||
logger.info(
|
||||
"_send_message 飞书回复消息成功:message_id=%s",
|
||||
getattr(data, "message_id", None),
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": getattr(data, "message_id", None),
|
||||
@@ -880,7 +883,10 @@ class Feishu:
|
||||
return None
|
||||
|
||||
data = getattr(response, "data", None)
|
||||
logger.info(f"_reply_message 飞书回复消息成功:data={data}")
|
||||
logger.info(
|
||||
"_reply_message 飞书回复消息成功:message_id=%s",
|
||||
getattr(data, "message_id", None),
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": getattr(data, "message_id", None),
|
||||
|
||||
@@ -527,7 +527,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
source: str,
|
||||
message_id: str,
|
||||
chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
@@ -536,10 +536,10 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
:param chat_id: 聊天ID(频道ID)
|
||||
:return: 删除是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return None
|
||||
success = False
|
||||
for conf in self.get_configs().values():
|
||||
if channel != self._channel:
|
||||
break
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
@@ -559,7 +559,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑消息
|
||||
:param channel: 消息渠道
|
||||
@@ -572,7 +572,7 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return False
|
||||
return None
|
||||
for conf in self.get_configs().values():
|
||||
if source != conf.name:
|
||||
continue
|
||||
|
||||
@@ -534,7 +534,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
source: str,
|
||||
message_id: int,
|
||||
chat_id: Optional[int] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
@@ -543,10 +543,10 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
:param chat_id: 聊天ID
|
||||
:return: 删除是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return None
|
||||
success = False
|
||||
for conf in self.get_configs().values():
|
||||
if channel != self._channel:
|
||||
break
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
@@ -566,7 +566,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> bool:
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑消息
|
||||
:param channel: 消息渠道
|
||||
@@ -579,7 +579,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return False
|
||||
return None
|
||||
for conf in self.get_configs().values():
|
||||
if source != conf.name:
|
||||
continue
|
||||
|
||||
@@ -192,6 +192,63 @@ class TestAgentToolStreaming(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(handler._sent_text, "hello world")
|
||||
|
||||
def test_stop_streaming_waits_inflight_initial_flush_before_final_edit(self):
|
||||
async def _run():
|
||||
handler = StreamingHandler()
|
||||
handler._channel = MessageChannel.Feishu.value
|
||||
handler._source = "feishu-main"
|
||||
handler._user_id = "ou_user"
|
||||
handler._streaming_enabled = True
|
||||
handler.emit("hello")
|
||||
|
||||
send_started = asyncio.Event()
|
||||
allow_send_finish = asyncio.Event()
|
||||
calls = []
|
||||
|
||||
async def fake_run_in_threadpool(func, *args, **kwargs):
|
||||
calls.append((func.__name__, args, kwargs))
|
||||
if func.__name__ == "send_direct_message":
|
||||
send_started.set()
|
||||
await allow_send_finish.wait()
|
||||
return MessageResponse(
|
||||
message_id="om_stream",
|
||||
chat_id="oc_stream",
|
||||
channel=MessageChannel.Feishu,
|
||||
source="feishu-main",
|
||||
success=True,
|
||||
)
|
||||
return True
|
||||
|
||||
with patch(
|
||||
"app.agent.callback.run_in_threadpool",
|
||||
new=fake_run_in_threadpool,
|
||||
):
|
||||
# 模拟定时刷新已经开始发送首条消息,但飞书 API 尚未返回。
|
||||
handler._flush_task = asyncio.create_task(handler._flush())
|
||||
await send_started.wait()
|
||||
handler.emit(" world")
|
||||
|
||||
stop_task = asyncio.create_task(handler.stop_streaming())
|
||||
await asyncio.sleep(0)
|
||||
self.assertFalse(stop_task.done())
|
||||
|
||||
allow_send_finish.set()
|
||||
all_sent, final_text = await stop_task
|
||||
|
||||
return all_sent, final_text, calls
|
||||
|
||||
all_sent, final_text, calls = asyncio.run(_run())
|
||||
|
||||
self.assertTrue(all_sent)
|
||||
self.assertEqual(final_text, "hello world")
|
||||
self.assertEqual(
|
||||
[call[0] for call in calls],
|
||||
["send_direct_message", "edit_message", "finalize_message"],
|
||||
)
|
||||
edit_kwargs = calls[1][2]
|
||||
self.assertEqual(edit_kwargs["message_id"], "om_stream")
|
||||
self.assertEqual(edit_kwargs["text"], "hello world")
|
||||
|
||||
def test_stop_streaming_uses_generic_finalize_message(self):
|
||||
handler = StreamingHandler()
|
||||
handler._message_response = MessageResponse(
|
||||
|
||||
Reference in New Issue
Block a user