diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index 21a132c7..5bc8e803 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Any, AsyncGenerator, Mapping, Sequence from autogen_core import CancellationToken, ComponentBase -from pydantic import BaseModel +from pydantic import BaseModel, SerializeAsAny from ..messages import BaseAgentEvent, BaseChatMessage from ._task import TaskRunner @@ -13,10 +13,10 @@ from ._task import TaskRunner class Response: """A response from calling :meth:`ChatAgent.on_messages`.""" - chat_message: BaseChatMessage + chat_message: SerializeAsAny[BaseChatMessage] """A chat message produced by the agent as the response.""" - inner_messages: Sequence[BaseAgentEvent | BaseChatMessage] | None = None + inner_messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]] | None = None """Inner messages produced by the agent, they can be :class:`BaseAgentEvent` or :class:`BaseChatMessage`.""" diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py index d585be1d..b858b4a4 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py @@ -1,7 +1,7 @@ from typing import AsyncGenerator, Protocol, Sequence from autogen_core import CancellationToken -from pydantic import BaseModel +from pydantic import BaseModel, SerializeAsAny from ..messages import BaseAgentEvent, BaseChatMessage @@ -9,7 +9,7 @@ from ..messages import BaseAgentEvent, BaseChatMessage class TaskResult(BaseModel): """Result of running a task.""" - messages: Sequence[BaseAgentEvent | BaseChatMessage] + messages: Sequence[SerializeAsAny[BaseAgentEvent | BaseChatMessage]] """Messages produced by the task.""" stop_reason: str | None = None diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py index a954dd6e..a149e586 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_events.py @@ -1,7 +1,7 @@ import traceback from typing import List -from pydantic import BaseModel +from pydantic import BaseModel, SerializeAsAny from ...base import Response, TaskResult from ...messages import BaseAgentEvent, BaseChatMessage, StopMessage @@ -38,7 +38,7 @@ class SerializableException(BaseModel): class GroupChatStart(BaseModel): """A request to start a group chat.""" - messages: List[BaseChatMessage] | None = None + messages: List[SerializeAsAny[BaseChatMessage]] | None = None """An optional list of messages to start the group chat.""" output_task_messages: bool = True @@ -48,7 +48,7 @@ class GroupChatStart(BaseModel): class GroupChatAgentResponse(BaseModel): """A response published to a group chat.""" - response: Response + response: SerializeAsAny[Response] """The response from an agent.""" name: str @@ -58,7 +58,7 @@ class GroupChatAgentResponse(BaseModel): class GroupChatTeamResponse(BaseModel): """A response published to a group chat from a team.""" - result: TaskResult + result: SerializeAsAny[TaskResult] """The result from a team.""" name: str @@ -74,7 +74,7 @@ class GroupChatRequestPublish(BaseModel): class GroupChatMessage(BaseModel): """A message from a group chat.""" - message: BaseAgentEvent | BaseChatMessage + message: SerializeAsAny[BaseAgentEvent | BaseChatMessage] """The message that was published.""" diff --git a/python/packages/autogen-agentchat/tests/test_events.py b/python/packages/autogen-agentchat/tests/test_events.py new file mode 100644 index 00000000..b29f05ca --- /dev/null +++ b/python/packages/autogen-agentchat/tests/test_events.py @@ -0,0 +1,85 @@ +import json + +from autogen_agentchat.base import Response, TaskResult +from autogen_agentchat.messages import TextMessage +from autogen_agentchat.teams._group_chat._events import ( + GroupChatAgentResponse, + GroupChatMessage, + GroupChatStart, + GroupChatTeamResponse, +) + + +def test_group_chat_message_preserves_subclass_data() -> None: + """Test that GroupChatMessage preserves TextMessage subclass fields.""" + # Create a TextMessage with subclass-specific fields + text_msg = TextMessage( + content="Hello, world!", + source="TestAgent", + ) + + # Wrap in GroupChatMessage + group_msg = GroupChatMessage(message=text_msg) + + # Serialize and verify subclass fields are preserved + json_data = group_msg.model_dump_json() + parsed = json.loads(json_data) + + # The critical test: subclass fields should be preserved + assert "content" in parsed["message"], "TextMessage content field should be preserved" + assert "type" in parsed["message"], "TextMessage type field should be preserved" + assert parsed["message"]["content"] == "Hello, world!" + assert parsed["message"]["type"] == "TextMessage" + + +def test_group_chat_start_preserves_message_list_data() -> None: + """Test that GroupChatStart preserves subclass data in message lists.""" + text_msg1 = TextMessage(content="First message", source="Agent1") + text_msg2 = TextMessage(content="Second message", source="Agent2") + + group_start = GroupChatStart(messages=[text_msg1, text_msg2]) + + json_data = group_start.model_dump_json() + parsed = json.loads(json_data) + + # Check both messages preserve subclass data + assert "content" in parsed["messages"][0] + assert "content" in parsed["messages"][1] + assert parsed["messages"][0]["content"] == "First message" + assert parsed["messages"][1]["content"] == "Second message" + + +def test_group_chat_agent_response_preserves_dataclass_fields() -> None: + """Test that GroupChatAgentResponse preserves data in Response dataclass fields.""" + text_msg = TextMessage(content="Response message", source="ResponseAgent") + inner_text_msg = TextMessage(content="Inner message", source="InnerAgent") + response = Response(chat_message=text_msg, inner_messages=[inner_text_msg]) + + group_response = GroupChatAgentResponse(response=response, name="TestAgent") + + json_data = group_response.model_dump_json() + parsed = json.loads(json_data) + + # Verify dataclass field preserves subclass data + assert "content" in parsed["response"]["chat_message"] + assert "type" in parsed["response"]["chat_message"] + assert parsed["response"]["chat_message"]["content"] == "Response message" + inner_msgs = parsed["response"]["inner_messages"] + assert len(inner_msgs) == 1 + assert "content" in inner_msgs[0] + assert inner_msgs[0]["content"] == "Inner message" + + +def test_group_chat_team_response_preserves_nested_data() -> None: + """Test that GroupChatTeamResponse preserves deeply nested subclass data.""" + text_msg = TextMessage(content="Nested message", source="NestedAgent") + task_result = TaskResult(messages=[text_msg]) + + team_response = GroupChatTeamResponse(result=task_result, name="TestTeam") + + json_data = team_response.model_dump_json() + parsed = json.loads(json_data) + + # Verify deeply nested subclass data is preserved + assert "content" in parsed["result"]["messages"][0] + assert parsed["result"]["messages"][0]["content"] == "Nested message"