From d8f2eb68d4264707676424da39f5bff54705daf2 Mon Sep 17 00:00:00 2001 From: HyunjunJeon Date: Fri, 2 Jan 2026 12:34:02 +0900 Subject: [PATCH] feat(workflow): implement workflow graph system with node types and agent vertex - Added `mod.rs` for the workflow graph system, outlining the structure and usage. - Introduced `node.rs` defining various node types (Agent, Tool, Router, etc.) and their configurations. - Created `vertices/agent.rs` implementing the `AgentVertex` for LLM-based processing with tool calling capabilities. - Added `vertices/mod.rs` to organize vertex implementations. - Implemented serialization and deserialization for node configurations using Serde. - Included tests for node configurations and agent vertex functionality. --- .gitignore | 4 - AGENTS.md | 43 + CLAUDE.md | 288 + docs/plans/2025-12-31-rig-deepagents.md | 4704 ----------------- docs/plans/2026-01-01-rig-deepagents-fixes.md | 804 --- .../2026-01-01-rig-deepagents-phase5-6.md | 1526 ------ .../2026-01-01-rig-deepagents-phase7-9.md | 1307 ----- .../crates/rig-deepagents/Cargo.toml | 12 + .../crates/rig-deepagents/src/lib.rs | 2 + .../src/pregel/checkpoint/file.rs | 438 ++ .../src/pregel/checkpoint/mod.rs | 551 ++ .../rig-deepagents/src/pregel/config.rs | 273 + .../crates/rig-deepagents/src/pregel/error.rs | 243 + .../rig-deepagents/src/pregel/message.rs | 230 + .../crates/rig-deepagents/src/pregel/mod.rs | 45 + .../rig-deepagents/src/pregel/runtime.rs | 871 +++ .../crates/rig-deepagents/src/pregel/state.rs | 297 ++ .../rig-deepagents/src/pregel/vertex.rs | 562 ++ .../crates/rig-deepagents/src/workflow/mod.rs | 57 + .../rig-deepagents/src/workflow/node.rs | 513 ++ .../src/workflow/vertices/agent.rs | 354 ++ .../src/workflow/vertices/mod.rs | 11 + 22 files changed, 4790 insertions(+), 8345 deletions(-) create mode 100644 AGENTS.md create mode 100644 CLAUDE.md delete mode 100644 docs/plans/2025-12-31-rig-deepagents.md delete mode 100644 docs/plans/2026-01-01-rig-deepagents-fixes.md delete 
mode 100644 docs/plans/2026-01-01-rig-deepagents-phase5-6.md delete mode 100644 docs/plans/2026-01-01-rig-deepagents-phase7-9.md create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/file.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/mod.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/config.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/error.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/message.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/mod.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/runtime.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/state.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/pregel/vertex.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/workflow/mod.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/workflow/node.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/workflow/vertices/agent.rs create mode 100644 rust-research-agent/crates/rig-deepagents/src/workflow/vertices/mod.rs diff --git a/.gitignore b/.gitignore index 192e22f..c2c5b44 100644 --- a/.gitignore +++ b/.gitignore @@ -18,10 +18,6 @@ wheels/ *_api/ # AI -AGENTS.md -CLAUDE.md -GEMINI.md -QWEN.md .serena/ # Others diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..ab08d30 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,43 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- `research_agent/` contains the core Python agents, prompts, tools, and subagent utilities. +- `skills/` holds project-level skills as `SKILL.md` files (YAML frontmatter + instructions). +- `research_workspace/` is the agent’s working filesystem for generated outputs; keep it clean or example-only. 
+- `deep-agents-ui/` is the Next.js/React UI with source under `deep-agents-ui/src/`. +- `deepagents_sourcecode/` vendors upstream library sources for reference and comparison. +- `rust-research-agent/` is a standalone Rust tutorial agent with its own build/test flow. +- `langgraph.json` defines the LangGraph deployment entrypoint for the research agent. + +## Build, Test, and Development Commands +Use the UI commands from `deep-agents-ui/` when working on the frontend: +```bash +cd deep-agents-ui && yarn install # install deps +cd deep-agents-ui && yarn dev # run local UI +cd deep-agents-ui && yarn build # production build +cd deep-agents-ui && yarn lint # eslint checks +cd deep-agents-ui && yarn format # prettier format +``` +Python tooling is configured in `pyproject.toml` (ruff + mypy): +```bash +uv run ruff format . +uv run ruff check . +uv run mypy . +``` + +## Coding Style & Naming Conventions +- Python: follow ruff defaults and Google-style docstrings (see `pyproject.toml`); prefer `snake_case` modules and functions. +- TypeScript/React: keep `PascalCase` for components, `camelCase` for hooks/utilities; rely on ESLint + Prettier (Tailwind plugin). +- Skill definitions: keep one skill per directory with a `SKILL.md` entrypoint and clear, task-focused naming. + +## Testing Guidelines +- There are no repository-wide tests for `research_agent/` yet; add `pytest` tests when introducing new logic. +- Subprojects have their own suites: see `deepagents_sourcecode/libs/*/Makefile` and `rust-research-agent/README.md` for `make test` or `cargo test`. + +## Commit & Pull Request Guidelines +- Git history uses short, descriptive messages in English or Korean with no enforced prefix; keep summaries concise and imperative. +- For PRs, include: a brief summary, testing notes (or “not run”), linked issues, and UI screenshots for frontend changes. + +## Configuration & Secrets +- Copy `env.example` to `.env` for API keys; never commit secrets. 
+- UI-only keys can be set via `NEXT_PUBLIC_LANGSMITH_API_KEY` in `deep-agents-ui/`. diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..153ba18 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,288 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +A multi-agent research system demonstrating **FileSystem-based Context Engineering** using LangChain's DeepAgents framework. The system includes: +- **Python DeepAgents**: LangChain-based multi-agent orchestration with web research capabilities +- **Rust `rig-deepagents`**: A port/reimagining using the Rig framework with Pregel-inspired graph execution + +The system enables agents to conduct web research, delegate tasks to sub-agents, and generate comprehensive reports with persistent filesystem state. + +## Development Commands + +### Python Backend + +```bash +# Install dependencies (uses uv package manager) +uv sync + +# Start LangGraph development server (API on localhost:2024) +langgraph dev + +# Linting and formatting +ruff check research_agent/ +ruff format research_agent/ + +# Type checking +mypy research_agent/ +``` + +### Frontend UI (deep-agents-ui/) + +```bash +cd deep-agents-ui +yarn install +yarn dev # Dev server on localhost:3000 +yarn build # Production build +yarn lint # ESLint +yarn format # Prettier +``` + +### Interactive Notebook Development + +```bash +# Open Jupyter for interactive agent testing +jupyter notebook DeepAgent_research.ipynb +``` + +The `research_agent/utils.py` module provides Rich-formatted display helpers for notebooks: +- `format_messages(messages)` - Renders messages with colored panels (Human=blue, AI=green, Tool=yellow) +- `show_prompt(text, title)` - Displays prompts with XML/header syntax highlighting + +### Rust `rig-deepagents` Crate + +```bash +cd rust-research-agent/crates/rig-deepagents + +# Run all tests (159 tests) +cargo test + +# Run tests for a specific module 
+cargo test pregel:: # Pregel runtime tests +cargo test workflow:: # Workflow node tests +cargo test middleware:: # Middleware tests + +# Linting (strict, treats warnings as errors) +cargo clippy -- -D warnings + +# Build with optional features +cargo build --features checkpointer-sqlite +cargo build --features checkpointer-redis +cargo build --features checkpointer-postgres +``` + +### Running the Full Stack + +1. Start backend: `langgraph dev` (port 2024) +2. Start frontend: `cd deep-agents-ui && yarn dev` (port 3000) +3. Configure UI with Deployment URL (`http://127.0.0.1:2024`) and Assistant ID (`research`) + +## Required Environment Variables + +Copy `env.example` to `.env`: +- `OPENAI_API_KEY` - For gpt-4.1 model +- `TAVILY_API_KEY` - For web search functionality +- `LANGSMITH_API_KEY` - Optional, format `lsv2_pt_...` for tracing +- `LANGSMITH_TRACING` / `LANGSMITH_PROJECT` - Optional tracing config + +## Architecture + +### Multi-SubAgent System + +The system uses a three-tier agent hierarchy with two distinct SubAgent types: + +``` +Main Orchestrator Agent (agent.py) + │ + ├── FilesystemBackend (../research_workspace) + │ └── Persistent state via virtual filesystem + │ + └── SubAgents + ├── researcher (CompiledSubAgent) ─ Autonomous DeepAgent + │ └── "Breadth-first, then depth" research pattern + ├── explorer (Simple SubAgent) ─ Fast read-only exploration + └── synthesizer (Simple SubAgent) ─ Research result integration +``` + +**CompiledSubAgent vs Simple SubAgent:** + +| Type | Definition | Execution | Use Case | +|------|------------|-----------|----------| +| CompiledSubAgent | `{"runnable": CompiledStateGraph}` | Multi-turn autonomous | Complex research with self-planning | +| Simple SubAgent | `{"system_prompt": str}` | Single response | Quick tasks, file ops | + +### Core Components + +**`research_agent/agent.py`** - Orchestrator configuration: +- LLM: `ChatOpenAI(model="gpt-4.1", temperature=0.0)` +- Creates researcher via 
`get_researcher_subagent()` (CompiledSubAgent) +- Defines `explorer_agent`, `synthesizer_agent` (Simple SubAgents) +- Assembles `ALL_SUBAGENTS = [researcher_subagent, *SIMPLE_SUBAGENTS]` + +**`research_agent/researcher/`** - Autonomous researcher module: +- `agent.py`: `create_researcher_agent()` factory and `get_researcher_subagent()` wrapper +- `prompts.py`: `AUTONOMOUS_RESEARCHER_INSTRUCTIONS` with three-phase workflow (Exploratory → Directed → Synthesis) + +**Backend Factory Pattern** - The `backend_factory(rt: ToolRuntime)` function demonstrates the recommended pattern: +```python +CompositeBackend( + default=StateBackend(rt), # In-memory state (temporary files) + routes={"/": fs_backend} # Route "/" paths to FilesystemBackend +) +``` +This enables routing: paths starting with "/" go to persistent local filesystem (`research_workspace/`), others use ephemeral state. + +**`research_agent/prompts.py`** - Prompt templates: +- `RESEARCH_WORKFLOW_INSTRUCTIONS` - Main workflow (plan → save → delegate → synthesize → write → verify) +- `SUBAGENT_DELEGATION_INSTRUCTIONS` - When to parallelize (comparisons) vs single agent (overviews) +- `EXPLORER_INSTRUCTIONS` - Fast read-only exploration with filesystem tools +- `SYNTHESIZER_INSTRUCTIONS` - Multi-source integration with confidence levels + +**`research_agent/tools.py`** - Research tools: +- `tavily_search(query, max_results, topic)` - Searches web, fetches full page content, converts to markdown +- `think_tool(reflection)` - Explicit reflection step for deliberate research + +**`langgraph.json`** - Deployment config pointing to `./research_agent/agent.py:agent` + +### Context Engineering Pattern + +The filesystem acts as long-term memory: +1. Agent reads/writes files in virtual `research_workspace/` +2. Structured outputs: reports, TODOs, request files +3. Middleware auto-injects filesystem and sub-agent tools +4. 
Automatic context summarization for token efficiency + +### DeepAgents Auto-Injected Tools + +The `create_deep_agent()` function automatically adds these tools via middleware: +- **TodoListMiddleware**: `write_todos` - Task planning and progress tracking +- **FilesystemMiddleware**: `ls`, `read_file`, `write_file`, `edit_file`, `glob`, `grep` - File operations +- **SubAgentMiddleware**: `task` - Delegate work to sub-agents +- **SkillsMiddleware**: Progressive skill disclosure via `skills/` directory + +Custom tools (`tavily_search`, `think_tool`) are added explicitly in `agent.py`. + +### Skills System + +Project-level skills are located in `PROJECT_ROOT/skills/`: +- `academic-search/` - arXiv paper search with structured output +- `data-synthesis/` - Multi-source data integration and analysis +- `report-writing/` - Structured report generation +- `skill-creator/` - Meta-skill for creating new skills + +Each skill has a `SKILL.md` file with YAML frontmatter (name, description) and detailed instructions. The SkillsMiddleware uses Progressive Disclosure: only skill metadata is injected into the system prompt at session start; full skill content is read on-demand when needed. + +### Research Workflow + +**Orchestrator workflow:** +``` +Plan → Save Request → Delegate to Sub-agents → Synthesize → Write Report → Verify +``` + +**Autonomous Researcher workflow (breadth-first, then depth):** +``` +Phase 1: Exploratory Search (1-2 searches) → Identify directions +Phase 2: Directed Research (1-2 searches per direction) → Deep dive +Phase 3: Synthesis → Combine findings with source agreement analysis +``` + +Sub-agents operate with token budgets (5-6 max searches) and explicit reflection loops (Search → think_tool → Decide → Repeat). + +## Rust `rig-deepagents` Architecture + +The Rust crate provides a Pregel-inspired graph execution runtime for agent workflows. 
+ +### Module Structure + +``` +rust-research-agent/crates/rig-deepagents/src/ +├── lib.rs # Library entry point and re-exports +├── pregel/ # Pregel Runtime (graph execution engine) +│ ├── runtime.rs # Superstep orchestration, workflow timeout, retry policies +│ ├── vertex.rs # Vertex trait and compute context +│ ├── message.rs # Inter-vertex message passing +│ ├── config.rs # PregelConfig, RetryPolicy +│ ├── checkpoint/ # Fault tolerance via checkpointing +│ │ ├── mod.rs # Checkpointer trait and factory +│ │ └── file.rs # FileCheckpointer implementation +│ └── state.rs # WorkflowState trait, UnitState +├── workflow/ # Workflow Builder DSL +│ ├── node.rs # NodeKind (Agent, Tool, Router, SubAgent, FanOut/FanIn) +│ └── mod.rs # WorkflowGraph builder API +├── middleware/ # AgentMiddleware trait and MiddlewareStack +├── backends/ # Backend trait (Memory, Filesystem, Composite) +├── llm/ # LLMProvider abstraction (OpenAI, Anthropic) +└── tools/ # Tool implementations (read_file, write_file, grep, etc.) +``` + +### Pregel Execution Model + +The runtime executes workflows using synchronized supersteps: + +``` +┌─────────────────────────────────────────────────────────────┐ +│ PregelRuntime │ +│ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ +│ │Superstep│→ │Superstep│→ │Superstep│→ ... 
│ +│ │ 0 │ │ 1 │ │ 2 │ │ +│ └─────────┘ └─────────┘ └─────────┘ │ +│ │ │ │ │ +│ ▼ ▼ ▼ │ +│ Per-Superstep: Deliver → Compute → Collect → Route │ +└─────────────────────────────────────────────────────────────┘ +``` + +- **Vertex**: Computation unit with `compute()` method (Agent, Tool, Router) +- **Message**: Communication between vertices across supersteps +- **Checkpointing**: Fault tolerance via periodic state snapshots +- **Retry Policy**: Exponential backoff with configurable max retries + +### Key Types + +| Type | Purpose | +|------|---------| +| `PregelRuntime` | Executes workflow graph with state S and message M | +| `Vertex` | Trait for computation nodes | +| `WorkflowState` | Trait for workflow state (must be serializable) | +| `PregelConfig` | Runtime configuration (max supersteps, parallelism, timeout) | +| `Checkpointer` | Trait for state persistence (Memory, File, SQLite, Redis, Postgres) | + +### Design Documents + +- `docs/plans/2026-01-02-rig-deepagents-pregel-design.md` - Comprehensive Pregel runtime design +- `docs/plans/2026-01-02-rig-deepagents-implementation-tasks.md` - Implementation task breakdown + +## Key Files for Understanding the System + +**Python DeepAgents:** +1. `research_agent/agent.py` - Orchestrator creation and SubAgent assembly +2. `research_agent/researcher/agent.py` - Autonomous researcher factory (CompiledSubAgent pattern) +3. `research_agent/researcher/prompts.py` - Three-phase autonomous workflow +4. `research_agent/prompts.py` - Orchestrator and Simple SubAgent prompts +5. `research_agent/tools.py` - Tool implementations +6. `research_agent/skills/middleware.py` - SkillsMiddleware with progressive disclosure + +**Rust rig-deepagents:** +7. `rust-research-agent/crates/rig-deepagents/src/pregel/runtime.rs` - Pregel execution engine +8. `rust-research-agent/crates/rig-deepagents/src/pregel/vertex.rs` - Vertex abstraction +9. `rust-research-agent/crates/rig-deepagents/src/workflow/node.rs` - Node type definitions +10. 
`rust-research-agent/crates/rig-deepagents/src/llm/provider.rs` - LLMProvider trait + +**Documentation:** +11. `DeepAgents_Technical_Guide.md` - Python DeepAgents reference (Korean) +12. `docs/plans/2026-01-02-rig-deepagents-pregel-design.md` - Rust Pregel design + +## Tech Stack + +- **Python 3.13**: deepagents, langchain-openai, langgraph-cli, tavily-python +- **Rust**: rig-core 0.27, tokio, serde, async-trait, thiserror +- **Frontend**: Next.js 16, React, TailwindCSS, Radix UI +- **Package managers**: uv (Python), Yarn (Node.js), Cargo (Rust) + +## External Resources + +- [LangChain DeepAgent Docs](https://docs.langchain.com/oss/python/deepagents/overview) +- [LangGraph CLI Docs](https://docs.langchain.com/langsmith/cli#configuration-file) +- [DeepAgent UI](https://github.com/langchain-ai/deep-agents-ui) diff --git a/docs/plans/2025-12-31-rig-deepagents.md b/docs/plans/2025-12-31-rig-deepagents.md deleted file mode 100644 index a493c5c..0000000 --- a/docs/plans/2025-12-31-rig-deepagents.md +++ /dev/null @@ -1,4704 +0,0 @@ -# Rig DeepAgents Implementation Plan (Enhanced v2) - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** LangChain DeepAgents의 **전체 기능** (`create_deep_agent` 패리티)을 Rust/Rig 프레임워크로 구현하여 Python 대비 속도 이점을 **E2E 벤치마크**로 입증한다. 
- -**Architecture:** -- LangChain의 AgentMiddleware 패턴을 Rust 트레이트 시스템으로 포팅 -- **Rig의 Tool 트레이트 직접 통합** (Arc 사용 안 함) -- **Full Parity**: TodoList, Filesystem, SubAgent, Summarization, PatchToolCalls 미들웨어 -- **Agent Execution Loop**: LLM 호출 및 도구 실행 루프 구현 -- **실제 OpenAI API E2E 테스트**로 Python vs Rust 레이턴시/처리량 비교 - -**Tech Stack:** Rust 1.75+, rig-core 0.27, rig-openai 0.27, tokio, serde, async-trait, criterion (벤치마크) - -**Python Reference:** `deepagents_sourcecode/libs/deepagents/deepagents/` (LangChain 구현) - -**이전 검증 피드백 반영:** -- ✅ `files_update` 필드를 WriteResult/EditResult에 추가 -- ✅ `grep`는 리터럴 검색 (정규식 아님) - **프롬프트도 수정** -- ✅ Rig Tool 트레이트 직접 통합 (`DynTool = Any` 제거) -- ✅ `tokio::sync::RwLock` 사용 (async 안전성) -- ✅ SubAgentMiddleware, SummarizationMiddleware, PatchToolCallsMiddleware 추가 -- ✅ criterion 기반 통계적 벤치마크 -- ✅ **HashMap import 추가** -- ✅ **FileData 중복 정의 해결** -- ✅ **Agent Execution Loop 추가** -- ✅ **FilesystemBackend, CompositeBackend 추가** -- ✅ **ToolRuntime 개념 추가** -- ✅ **각 미들웨어에 실제 도구 구현** - ---- - -## Phase 1: 프로젝트 초기화 - -### Task 1.1: Cargo 프로젝트 생성 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/Cargo.toml` -- Create: `rust-research-agent/crates/rig-deepagents/src/lib.rs` - -**Step 1: 디렉토리 구조 생성** - -```bash -mkdir -p rust-research-agent/crates/rig-deepagents/src -``` - -**Step 2: Cargo.toml 작성** - -```toml -[package] -name = "rig-deepagents" -version = "0.1.0" -edition = "2021" -description = "DeepAgents-style middleware system for Rig framework" -license = "MIT" - -[dependencies] -rig-core = { version = "0.27", features = ["derive"] } -tokio = { version = "1", features = ["full", "sync"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" -async-trait = "0.1" -thiserror = "2" -anyhow = "1" -tracing = "0.1" -chrono = { version = "0.4", features = ["serde"] } -glob = "0.3" -uuid = { version = "1", features = ["v4"] } - -[dev-dependencies] -rig-openai = "0.27" -tokio-test = "0.4" -dotenv = "0.15" -criterion = { version = 
"0.5", features = ["async_tokio"] } - -[[bench]] -name = "middleware_benchmark" -harness = false -``` - -**Step 3: lib.rs 기본 구조 작성** - -```rust -//! rig-deepagents: DeepAgents-style middleware system for Rig -//! -//! LangChain DeepAgents의 핵심 패턴을 Rust로 구현합니다. -//! - AgentMiddleware 트레이트: 도구 자동 주입, 프롬프트 수정 -//! - Backend 트레이트: 파일시스템 추상화 -//! - MiddlewareStack: 미들웨어 조합 및 실행 -//! - AgentExecutor: LLM 호출 및 도구 실행 루프 - -pub mod error; -pub mod state; -pub mod backends; -pub mod middleware; -pub mod runtime; -pub mod executor; -pub mod tools; - -pub use error::{BackendError, MiddlewareError, DeepAgentError}; -pub use state::{AgentState, Message, Role, Todo, TodoStatus, FileData}; -pub use backends::{Backend, FileInfo, GrepMatch, MemoryBackend}; -pub use middleware::{AgentMiddleware, MiddlewareStack, StateUpdate}; -pub use runtime::ToolRuntime; -pub use executor::AgentExecutor; -``` - -**Step 4: 빌드 확인** - -Run: `cd rust-research-agent/crates/rig-deepagents && cargo check` -Expected: Compiling rig-deepagents... - -**Step 5: Commit** - -```bash -git add rust-research-agent/crates/rig-deepagents/ -git commit -m "feat: initialize rig-deepagents crate structure" -``` - ---- - -## Phase 2: 에러 타입 및 상태 정의 - -### Task 2.1: 에러 타입 정의 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/error.rs` - -**Python Reference:** `deepagents/backends/protocol.py` - `FileOperationError`, `WriteResult`, `EditResult` - -**Step 1: 에러 타입 구현** (FileData는 state.rs에서 정의) - -```rust -// src/error.rs -//! 에러 타입 정의 -//! -//! 
Python Reference: deepagents/backends/protocol.py의 FileOperationError - -use std::collections::HashMap; -use thiserror::Error; -use crate::state::FileData; - -/// 백엔드 작업 에러 -/// Python: FileOperationError literal type -#[derive(Error, Debug, Clone)] -pub enum BackendError { - #[error("File not found: {0}")] - FileNotFound(String), - - #[error("Permission denied: {0}")] - PermissionDenied(String), - - #[error("Is a directory: {0}")] - IsDirectory(String), - - #[error("Invalid path: {0}")] - InvalidPath(String), - - #[error("Path traversal not allowed: {0}")] - PathTraversal(String), - - #[error("File already exists: {0}")] - FileExists(String), - - #[error("IO error: {0}")] - Io(String), - - #[error("Pattern error: {0}")] - Pattern(String), -} - -/// 미들웨어 에러 -#[derive(Error, Debug)] -pub enum MiddlewareError { - #[error("Backend error: {0}")] - Backend(#[from] BackendError), - - #[error("Tool execution error: {0}")] - ToolExecution(String), - - #[error("State update error: {0}")] - StateUpdate(String), - - #[error("Serialization error: {0}")] - Serialization(#[from] serde_json::Error), - - #[error("SubAgent error: {0}")] - SubAgent(String), -} - -/// DeepAgent 최상위 에러 -#[derive(Error, Debug)] -pub enum DeepAgentError { - #[error("Middleware error: {0}")] - Middleware(#[from] MiddlewareError), - - #[error("Agent execution error: {0}")] - AgentExecution(String), - - #[error("Configuration error: {0}")] - Config(String), - - #[error("LLM error: {0}")] - LlmError(String), - - #[error("Tool not found: {0}")] - ToolNotFound(String), -} - -/// 쓰기 작업 결과 -/// Python: WriteResult dataclass -/// -/// **Codex 피드백 반영:** `files_update` 필드 추가 -/// - 체크포인트 백엔드: {file_path: FileData} 형태로 상태 업데이트 -/// - 외부 백엔드 (디스크/S3): None (이미 영구 저장됨) -#[derive(Debug, Clone)] -pub struct WriteResult { - pub error: Option, - pub path: Option, - /// 체크포인트 백엔드를 위한 상태 업데이트 - /// Python: files_update: dict[str, Any] | None - pub files_update: Option>, -} - -impl WriteResult { - /// 체크포인트 백엔드용 성공 결과 - pub 
fn success_with_update(path: &str, file_data: FileData) -> Self { - let mut files = HashMap::new(); - files.insert(path.to_string(), file_data); - Self { error: None, path: Some(path.to_string()), files_update: Some(files) } - } - - /// 외부 백엔드용 성공 결과 (files_update = None) - pub fn success_external(path: &str) -> Self { - Self { error: None, path: Some(path.to_string()), files_update: None } - } - - pub fn error(msg: &str) -> Self { - Self { error: Some(msg.to_string()), path: None, files_update: None } - } - - pub fn is_ok(&self) -> bool { - self.error.is_none() - } -} - -/// 편집 작업 결과 -/// Python: EditResult dataclass -#[derive(Debug, Clone)] -pub struct EditResult { - pub error: Option, - pub path: Option, - /// 체크포인트 백엔드를 위한 상태 업데이트 - pub files_update: Option>, - pub occurrences: Option, -} - -impl EditResult { - /// 체크포인트 백엔드용 성공 결과 - pub fn success_with_update(path: &str, file_data: FileData, occurrences: usize) -> Self { - let mut files = HashMap::new(); - files.insert(path.to_string(), file_data); - Self { - error: None, - path: Some(path.to_string()), - files_update: Some(files), - occurrences: Some(occurrences), - } - } - - /// 외부 백엔드용 성공 결과 - pub fn success_external(path: &str, occurrences: usize) -> Self { - Self { - error: None, - path: Some(path.to_string()), - files_update: None, - occurrences: Some(occurrences), - } - } - - pub fn error(msg: &str) -> Self { - Self { error: Some(msg.to_string()), path: None, files_update: None, occurrences: None } - } - - pub fn is_ok(&self) -> bool { - self.error.is_none() - } -} -``` - -**Step 2: 테스트 추가** (src/error.rs 하단) - -```rust -#[cfg(test)] -mod tests { - use super::*; - use crate::state::FileData; - - #[test] - fn test_backend_error_display() { - let err = BackendError::FileNotFound("/test.txt".to_string()); - assert!(err.to_string().contains("/test.txt")); - } - - #[test] - fn test_middleware_error_from_backend() { - let backend_err = BackendError::FileNotFound("/test.txt".to_string()); - let middleware_err: 
MiddlewareError = backend_err.into(); - assert!(matches!(middleware_err, MiddlewareError::Backend(_))); - } - - #[test] - fn test_write_result_success() { - let file_data = FileData::new("hello"); - let result = WriteResult::success_with_update("/test.txt", file_data); - assert!(result.is_ok()); - assert!(result.files_update.is_some()); - } - - #[test] - fn test_write_result_external() { - let result = WriteResult::success_external("/test.txt"); - assert!(result.is_ok()); - assert!(result.files_update.is_none()); - } -} -``` - -**Step 3: 테스트 실행** - -Run: `cargo test error` -Expected: PASS - -**Step 4: Commit** - -```bash -git add -A && git commit -m "feat: add error types with WriteResult and EditResult" -``` - ---- - -### Task 2.2: AgentState 정의 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/state.rs` - -**Python Reference:** `langchain/agents/middleware/types.py` - `AgentState(TypedDict)` - -**Step 1: state.rs 구현** - -```rust -// src/state.rs -//! 에이전트 상태 정의 -//! -//! 
Python Reference: langchain/agents/middleware/types.py의 AgentState - -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; -use std::any::Any; -use chrono::Utc; - -/// Todo 상태 -/// Python: Literal["pending", "in_progress", "completed"] -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "snake_case")] -pub enum TodoStatus { - Pending, - InProgress, - Completed, -} - -/// Todo 아이템 -/// Python: Todo(TypedDict) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Todo { - pub content: String, - pub status: TodoStatus, -} - -impl Todo { - pub fn new(content: &str) -> Self { - Self { - content: content.to_string(), - status: TodoStatus::Pending, - } - } - - pub fn with_status(content: &str, status: TodoStatus) -> Self { - Self { - content: content.to_string(), - status, - } - } -} - -/// 파일 데이터 -/// Python: FileData(TypedDict) in filesystem.py -/// -/// **Note:** 이 타입은 error.rs의 WriteResult/EditResult에서도 사용됨 -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FileData { - pub content: Vec, - pub created_at: String, - pub modified_at: String, -} - -impl FileData { - pub fn new(content: &str) -> Self { - let now = Utc::now().to_rfc3339(); - Self { - content: content.lines().map(String::from).collect(), - created_at: now.clone(), - modified_at: now, - } - } - - pub fn as_string(&self) -> String { - self.content.join("\n") - } - - pub fn update(&mut self, new_content: &str) { - self.content = new_content.lines().map(String::from).collect(); - self.modified_at = Utc::now().to_rfc3339(); - } - - pub fn line_count(&self) -> usize { - self.content.len() - } -} - -/// 메시지 역할 -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, - Tool, -} - -/// 도구 호출 정보 -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCall { - pub id: String, - pub name: String, - pub arguments: serde_json::Value, -} - -/// 메시지 
-#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Message { - pub role: Role, - pub content: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_call_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tool_calls: Option>, -} - -impl Message { - pub fn user(content: &str) -> Self { - Self { role: Role::User, content: content.to_string(), tool_call_id: None, tool_calls: None } - } - - pub fn assistant(content: &str) -> Self { - Self { role: Role::Assistant, content: content.to_string(), tool_call_id: None, tool_calls: None } - } - - pub fn assistant_with_tool_calls(content: &str, tool_calls: Vec) -> Self { - Self { - role: Role::Assistant, - content: content.to_string(), - tool_call_id: None, - tool_calls: Some(tool_calls) - } - } - - pub fn system(content: &str) -> Self { - Self { role: Role::System, content: content.to_string(), tool_call_id: None, tool_calls: None } - } - - pub fn tool(content: &str, tool_call_id: &str) -> Self { - Self { - role: Role::Tool, - content: content.to_string(), - tool_call_id: Some(tool_call_id.to_string()), - tool_calls: None, - } - } - - /// 이 메시지에 dangling tool call이 있는지 확인 - pub fn has_tool_calls(&self) -> bool { - self.tool_calls.as_ref().map_or(false, |tc| !tc.is_empty()) - } -} - -/// 에이전트 상태 -/// Python: AgentState(TypedDict) + FilesystemState + PlanningState -#[derive(Debug, Clone, Default)] -pub struct AgentState { - /// 메시지 히스토리 - pub messages: Vec, - - /// Todo 리스트 (TodoListMiddleware) - pub todos: Vec, - - /// 가상 파일 시스템 (FilesystemMiddleware) - pub files: HashMap, - - /// 구조화된 응답 - pub structured_response: Option, - - /// 확장 데이터 (미들웨어별 커스텀 상태) - extensions: HashMap>, -} - -impl AgentState { - pub fn new() -> Self { - Self::default() - } - - /// 초기 메시지로 상태 생성 - pub fn with_messages(messages: Vec) -> Self { - Self { - messages, - ..Default::default() - } - } - - /// 확장 데이터 설정 - pub fn set_extension(&mut self, key: &str, value: T) { - self.extensions.insert(key.to_string(), 
Box::new(value)); - } - - /// 확장 데이터 조회 - pub fn get_extension(&self, key: &str) -> Option<&T> { - self.extensions.get(key).and_then(|v| v.downcast_ref::()) - } - - /// 마지막 사용자 메시지 가져오기 - pub fn last_user_message(&self) -> Option<&Message> { - self.messages.iter().rev().find(|m| m.role == Role::User) - } - - /// 마지막 어시스턴트 메시지 가져오기 - pub fn last_assistant_message(&self) -> Option<&Message> { - self.messages.iter().rev().find(|m| m.role == Role::Assistant) - } - - /// 메시지 추가 - pub fn add_message(&mut self, message: Message) { - self.messages.push(message); - } - - /// 메시지 수 반환 - pub fn message_count(&self) -> usize { - self.messages.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_todo_status_serialization() { - let status = TodoStatus::InProgress; - let json = serde_json::to_string(&status).unwrap(); - assert_eq!(json, "\"in_progress\""); - } - - #[test] - fn test_agent_state_default() { - let state = AgentState::new(); - assert!(state.messages.is_empty()); - assert!(state.todos.is_empty()); - assert!(state.files.is_empty()); - } - - #[test] - fn test_file_data_creation() { - let file = FileData::new("hello\nworld"); - assert_eq!(file.content, vec!["hello", "world"]); - assert!(!file.created_at.is_empty()); - assert_eq!(file.line_count(), 2); - } - - #[test] - fn test_message_with_tool_calls() { - let tool_call = ToolCall { - id: "call_123".to_string(), - name: "read_file".to_string(), - arguments: serde_json::json!({"path": "/test.txt"}), - }; - let msg = Message::assistant_with_tool_calls("", vec![tool_call]); - assert!(msg.has_tool_calls()); - } - - #[test] - fn test_agent_state_with_messages() { - let state = AgentState::with_messages(vec![Message::user("Hello")]); - assert_eq!(state.message_count(), 1); - assert!(state.last_user_message().is_some()); - } -} -``` - -**Step 2: 테스트 실행** - -Run: `cargo test state` -Expected: PASS - -**Step 3: Commit** - -```bash -git add -A && git commit -m "feat: add AgentState with Todo, FileData, 
and Message types" -``` - ---- - -## Phase 3: Backend 트레이트 - -### Task 3.1: Backend 프로토콜 정의 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/backends/mod.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/backends/protocol.rs` - -**Python Reference:** `deepagents/backends/protocol.py` - `BackendProtocol(ABC)` - -**Step 1: 디렉토리 생성** - -```bash -mkdir -p rust-research-agent/crates/rig-deepagents/src/backends -``` - -**Step 2: protocol.rs 구현** - -```rust -// src/backends/protocol.rs -//! Backend 프로토콜 정의 -//! -//! Python Reference: deepagents/backends/protocol.py의 BackendProtocol - -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use crate::error::{BackendError, WriteResult, EditResult}; - -/// 파일 정보 -/// Python: FileInfo(TypedDict) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct FileInfo { - pub path: String, - pub is_dir: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub size: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub modified_at: Option, -} - -impl FileInfo { - pub fn file(path: &str, size: u64) -> Self { - Self { path: path.to_string(), is_dir: false, size: Some(size), modified_at: None } - } - - pub fn file_with_time(path: &str, size: u64, modified_at: &str) -> Self { - Self { - path: path.to_string(), - is_dir: false, - size: Some(size), - modified_at: Some(modified_at.to_string()), - } - } - - pub fn dir(path: &str) -> Self { - Self { path: path.to_string(), is_dir: true, size: None, modified_at: None } - } -} - -/// Grep 검색 결과 -/// Python: GrepMatch(TypedDict) -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct GrepMatch { - pub path: String, - pub line: usize, - pub text: String, -} - -impl GrepMatch { - pub fn new(path: &str, line: usize, text: &str) -> Self { - Self { path: path.to_string(), line, text: text.to_string() } - } -} - -/// Backend 프로토콜 -/// Python: BackendProtocol(ABC) -/// -/// 모든 백엔드 구현체가 준수해야 하는 인터페이스입니다. 
-/// 파일시스템 추상화를 제공하여 인메모리, 디스크, 클라우드 등 다양한 저장소 지원. -#[async_trait] -pub trait Backend: Send + Sync { - /// 디렉토리 내용 나열 - /// Python: ls_info(path: str) -> list[FileInfo] - async fn ls(&self, path: &str) -> Result, BackendError>; - - /// 파일 읽기 (페이지네이션 지원) - /// Python: read(file_path: str, offset: int, limit: int) -> str - /// - /// Returns: 라인 번호 포함된 포맷 (cat -n 스타일) - async fn read(&self, path: &str, offset: usize, limit: usize) -> Result; - - /// 파일 쓰기 (새 파일 생성) - /// Python: write(file_path: str, content: str) -> WriteResult - async fn write(&self, path: &str, content: &str) -> Result; - - /// 파일 편집 (문자열 교체) - /// Python: edit(file_path: str, old_string: str, new_string: str, replace_all: bool) -> EditResult - async fn edit( - &self, - path: &str, - old_string: &str, - new_string: &str, - replace_all: bool - ) -> Result; - - /// Glob 패턴 검색 - /// Python: glob_info(pattern: str, path: str) -> list[FileInfo] - async fn glob(&self, pattern: &str, path: &str) -> Result, BackendError>; - - /// 텍스트 검색 (리터럴 문자열) - /// Python: grep_raw(pattern: str, path: str | None, glob: str | None) -> list[GrepMatch] - /// - /// **Important:** pattern은 리터럴 문자열입니다 (정규식 아님!) - async fn grep( - &self, - pattern: &str, - path: Option<&str>, - glob_filter: Option<&str>, - ) -> Result, BackendError>; - - /// 파일 존재 여부 확인 - async fn exists(&self, path: &str) -> Result; - - /// 파일 삭제 - async fn delete(&self, path: &str) -> Result<(), BackendError>; -} -``` - -**Step 3: mod.rs 생성** - -```rust -// src/backends/mod.rs -//! 백엔드 모듈 -//! -//! 파일시스템 추상화를 제공합니다. 
- -pub mod protocol; -pub mod memory; -pub mod filesystem; -pub mod composite; - -pub use protocol::{Backend, FileInfo, GrepMatch}; -pub use memory::MemoryBackend; -pub use filesystem::FilesystemBackend; -pub use composite::CompositeBackend; -``` - -**Step 4: Commit** - -```bash -git add -A && git commit -m "feat: add Backend trait protocol" -``` - ---- - -### Task 3.2: MemoryBackend 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/backends/memory.rs` - -**Python Reference:** `deepagents/backends/state.py` - `StateBackend` - -**Step 1: memory.rs 구현** - -```rust -// src/backends/memory.rs -//! 인메모리 백엔드 구현 -//! -//! Python Reference: deepagents/backends/state.py의 StateBackend -//! -//! **Codex 피드백 반영:** -//! - `tokio::sync::RwLock` 사용 (async 안전성) -//! - `grep`는 리터럴 검색 (정규식 아님) - -use async_trait::async_trait; -use std::collections::{HashMap, HashSet}; -use tokio::sync::RwLock; -use glob::Pattern; - -use super::protocol::{Backend, FileInfo, GrepMatch}; -use crate::error::{BackendError, WriteResult, EditResult}; -use crate::state::FileData; - -/// 인메모리 백엔드 -/// Python: StateBackend - 상태에 파일 저장 -/// -/// **Note:** tokio::sync::RwLock을 사용하여 async 컨텍스트에서 안전하게 동작 -pub struct MemoryBackend { - files: RwLock>, -} - -impl MemoryBackend { - pub fn new() -> Self { - Self { - files: RwLock::new(HashMap::new()), - } - } - - /// 기존 파일로 초기화 - pub fn with_files(files: HashMap) -> Self { - Self { - files: RwLock::new(files), - } - } - - /// 경로 정규화 및 검증 - fn validate_path(path: &str) -> Result { - if path.contains("..") || path.starts_with("~") { - return Err(BackendError::PathTraversal(path.to_string())); - } - - let normalized = if path.starts_with('/') { - path.to_string() - } else { - format!("/{}", path) - }; - - Ok(normalized) - } - - /// 라인 번호 포맷팅 - fn format_with_line_numbers(content: &str, offset: usize) -> String { - content - .lines() - .enumerate() - .map(|(i, line)| format!("{}\t{}", offset + i + 1, line)) - .collect::>() - .join("\n") - } - - 
/// 상위 디렉토리 생성 (가상) - fn ensure_parent_dirs(_path: &str) { - // 인메모리 백엔드에서는 디렉토리 자동 생성 불필요 - } -} - -impl Default for MemoryBackend { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl Backend for MemoryBackend { - async fn ls(&self, path: &str) -> Result, BackendError> { - let path = Self::validate_path(path)?; - let files = self.files.read().await; - - let prefix = if path == "/" { "" } else { &path }; - let mut results = Vec::new(); - let mut dirs_seen = HashSet::new(); - - for (file_path, data) in files.iter() { - if file_path.starts_with(prefix) || prefix.is_empty() { - let relative = file_path.strip_prefix(prefix).unwrap_or(file_path); - let relative = relative.trim_start_matches('/'); - - if let Some(slash_pos) = relative.find('/') { - // 서브디렉토리 - let dir_name = &relative[..slash_pos]; - let dir_path = format!("{}/{}", path.trim_end_matches('/'), dir_name); - if dirs_seen.insert(dir_path.clone()) { - results.push(FileInfo::dir(&format!("{}/", dir_path))); - } - } else if !relative.is_empty() { - // 파일 - let size = data.content.iter().map(|s| s.len()).sum::() as u64; - results.push(FileInfo::file_with_time( - file_path, - size, - &data.modified_at, - )); - } - } - } - - results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(results) - } - - async fn read(&self, path: &str, offset: usize, limit: usize) -> Result { - let path = Self::validate_path(path)?; - let files = self.files.read().await; - - let file = files.get(&path).ok_or_else(|| BackendError::FileNotFound(path.clone()))?; - - let lines: Vec<_> = file.content.iter() - .skip(offset) - .take(limit) - .cloned() - .collect(); - - let content = lines.join("\n"); - Ok(Self::format_with_line_numbers(&content, offset)) - } - - async fn write(&self, path: &str, content: &str) -> Result { - let path = Self::validate_path(path)?; - let mut files = self.files.write().await; - - // 이미 존재하면 에러 - if files.contains_key(&path) { - return Ok(WriteResult::error(&format!( - "Cannot write to {} because it 
already exists. Read and then make an edit, or write to a new path.", - path - ))); - } - - let file_data = FileData::new(content); - files.insert(path.clone(), file_data.clone()); - - // 체크포인트 백엔드이므로 files_update 포함 - Ok(WriteResult::success_with_update(&path, file_data)) - } - - async fn edit( - &self, - path: &str, - old_string: &str, - new_string: &str, - replace_all: bool - ) -> Result { - let path = Self::validate_path(path)?; - let mut files = self.files.write().await; - - let file = files.get_mut(&path).ok_or_else(|| BackendError::FileNotFound(path.clone()))?; - - let content = file.as_string(); - let occurrences = content.matches(old_string).count(); - - if occurrences == 0 { - return Ok(EditResult::error(&format!("String '{}' not found in file", old_string))); - } - - if !replace_all && occurrences > 1 { - return Ok(EditResult::error(&format!( - "String '{}' found {} times. Use replace_all=true or provide more context.", - old_string, occurrences - ))); - } - - let new_content = if replace_all { - content.replace(old_string, new_string) - } else { - content.replacen(old_string, new_string, 1) - }; - - file.update(&new_content); - let updated_file = file.clone(); - let actual_occurrences = if replace_all { occurrences } else { 1 }; - - // 체크포인트 백엔드이므로 files_update 포함 - Ok(EditResult::success_with_update(&path, updated_file, actual_occurrences)) - } - - async fn glob(&self, pattern: &str, base_path: &str) -> Result, BackendError> { - let _base = Self::validate_path(base_path)?; - let files = self.files.read().await; - - let glob_pattern = Pattern::new(pattern) - .map_err(|e| BackendError::Pattern(e.to_string()))?; - - let mut results = Vec::new(); - for (file_path, data) in files.iter() { - let match_path = file_path.trim_start_matches('/'); - if glob_pattern.matches(match_path) { - let size = data.content.iter().map(|s| s.len()).sum::() as u64; - results.push(FileInfo::file_with_time( - file_path, - size, - &data.modified_at, - )); - } - } - - 
results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(results) - } - - /// 리터럴 텍스트 검색 - /// - /// **Codex 피드백 반영:** 정규식이 아닌 리터럴 문자열 검색 - /// Python: grep_raw의 docstring - "검색할 리터럴 문자열 (정규식 아님)" - async fn grep( - &self, - pattern: &str, - path: Option<&str>, - glob_filter: Option<&str>, - ) -> Result, BackendError> { - let files = self.files.read().await; - - let glob_pattern = glob_filter.map(|g| Pattern::new(g)).transpose() - .map_err(|e| BackendError::Pattern(e.to_string()))?; - - let mut results = Vec::new(); - - for (file_path, data) in files.iter() { - // Path filter - if let Some(p) = path { - if !file_path.starts_with(p) { - continue; - } - } - - // Glob filter - if let Some(ref gp) = glob_pattern { - let match_path = file_path.trim_start_matches('/'); - if !gp.matches(match_path) { - continue; - } - } - - // 리터럴 검색 (정규식 아님) - for (line_num, line) in data.content.iter().enumerate() { - if line.contains(pattern) { - results.push(GrepMatch::new(file_path, line_num + 1, line)); - } - } - } - - Ok(results) - } - - async fn exists(&self, path: &str) -> Result { - let path = Self::validate_path(path)?; - let files = self.files.read().await; - Ok(files.contains_key(&path)) - } - - async fn delete(&self, path: &str) -> Result<(), BackendError> { - let path = Self::validate_path(path)?; - let mut files = self.files.write().await; - - if files.remove(&path).is_none() { - return Err(BackendError::FileNotFound(path)); - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_memory_backend_write_and_read() { - let backend = MemoryBackend::new(); - - // Write - let result = backend.write("/test.txt", "Hello, World!").await.unwrap(); - assert!(result.is_ok()); - assert!(result.files_update.is_some()); - - // Read - let content = backend.read("/test.txt", 0, 100).await.unwrap(); - assert!(content.contains("Hello, World!")); - } - - #[tokio::test] - async fn test_memory_backend_write_existing_file() { - let backend = 
MemoryBackend::new(); - backend.write("/test.txt", "content").await.unwrap(); - - // 두 번째 쓰기는 에러 - let result = backend.write("/test.txt", "new content").await.unwrap(); - assert!(!result.is_ok()); - assert!(result.error.unwrap().contains("already exists")); - } - - #[tokio::test] - async fn test_memory_backend_edit() { - let backend = MemoryBackend::new(); - backend.write("/test.txt", "foo bar foo").await.unwrap(); - - // Edit single - let result = backend.edit("/test.txt", "foo", "baz", false).await.unwrap(); - assert!(!result.is_ok()); // 2번 발견되어 에러 - - // Edit all - let result = backend.edit("/test.txt", "foo", "baz", true).await.unwrap(); - assert!(result.is_ok()); - assert_eq!(result.occurrences, Some(2)); - - let content = backend.read("/test.txt", 0, 100).await.unwrap(); - assert!(content.contains("baz bar baz")); - } - - #[tokio::test] - async fn test_memory_backend_ls() { - let backend = MemoryBackend::new(); - backend.write("/dir/file1.txt", "content1").await.unwrap(); - backend.write("/dir/file2.txt", "content2").await.unwrap(); - - let files = backend.ls("/dir").await.unwrap(); - assert_eq!(files.len(), 2); - } - - #[tokio::test] - async fn test_memory_backend_glob() { - let backend = MemoryBackend::new(); - backend.write("/src/main.rs", "fn main()").await.unwrap(); - backend.write("/src/lib.rs", "pub mod").await.unwrap(); - backend.write("/test.txt", "test").await.unwrap(); - - let files = backend.glob("**/*.rs", "/").await.unwrap(); - assert_eq!(files.len(), 2); - } - - #[tokio::test] - async fn test_memory_backend_grep_literal() { - let backend = MemoryBackend::new(); - backend.write("/test.rs", "fn main() {\n println!(\"hello\");\n}").await.unwrap(); - - // 리터럴 검색 - 정규식 메타문자가 리터럴로 처리됨 - let matches = backend.grep("()", None, None).await.unwrap(); - assert!(!matches.is_empty()); // "()" 를 리터럴로 찾음 - } - - #[tokio::test] - async fn test_memory_backend_delete() { - let backend = MemoryBackend::new(); - backend.write("/test.txt", 
"content").await.unwrap(); - - assert!(backend.exists("/test.txt").await.unwrap()); - backend.delete("/test.txt").await.unwrap(); - assert!(!backend.exists("/test.txt").await.unwrap()); - } -} -``` - -**Step 2: 테스트 실행** - -Run: `cargo test memory` -Expected: PASS - -**Step 3: Commit** - -```bash -git add -A && git commit -m "feat: implement MemoryBackend with tokio RwLock" -``` - ---- - -### Task 3.3: FilesystemBackend 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/backends/filesystem.rs` - -**Python Reference:** `deepagents/backends/filesystem.py` - -**Step 1: filesystem.rs 구현** - -```rust -// src/backends/filesystem.rs -//! 실제 파일시스템 백엔드 구현 -//! -//! Python Reference: deepagents/backends/filesystem.py - -use async_trait::async_trait; -use std::path::{Path, PathBuf}; -use tokio::fs; -use glob::Pattern; -use chrono::{DateTime, Utc}; - -use super::protocol::{Backend, FileInfo, GrepMatch}; -use crate::error::{BackendError, WriteResult, EditResult}; - -/// 파일시스템 백엔드 -/// Python: FilesystemBackend -/// -/// 실제 파일시스템에서 직접 파일을 읽고 씁니다. 
-pub struct FilesystemBackend { - /// 루트 디렉토리 - root: PathBuf, - /// 가상 모드 - 모든 경로를 루트 내부로 제한 - virtual_mode: bool, -} - -impl FilesystemBackend { - pub fn new(root: impl AsRef) -> Self { - Self { - root: root.as_ref().to_path_buf(), - virtual_mode: true, - } - } - - pub fn with_virtual_mode(root: impl AsRef, virtual_mode: bool) -> Self { - Self { - root: root.as_ref().to_path_buf(), - virtual_mode, - } - } - - /// 경로 검증 및 해결 - fn resolve_path(&self, path: &str) -> Result { - if self.virtual_mode { - // 경로 탐색 방지 - if path.contains("..") || path.starts_with("~") { - return Err(BackendError::PathTraversal(path.to_string())); - } - - let clean_path = path.trim_start_matches('/'); - let resolved = self.root.join(clean_path).canonicalize() - .unwrap_or_else(|_| self.root.join(clean_path)); - - // 루트 외부 접근 방지 - if !resolved.starts_with(&self.root) { - return Err(BackendError::PathTraversal(path.to_string())); - } - - Ok(resolved) - } else { - Ok(PathBuf::from(path)) - } - } - - /// 가상 경로로 변환 - fn to_virtual_path(&self, path: &Path) -> String { - if self.virtual_mode { - path.strip_prefix(&self.root) - .map(|p| format!("/{}", p.display())) - .unwrap_or_else(|_| path.display().to_string()) - } else { - path.display().to_string() - } - } - - fn format_with_line_numbers(content: &str, offset: usize) -> String { - content - .lines() - .enumerate() - .map(|(i, line)| format!("{}\t{}", offset + i + 1, line)) - .collect::>() - .join("\n") - } -} - -#[async_trait] -impl Backend for FilesystemBackend { - async fn ls(&self, path: &str) -> Result, BackendError> { - let resolved = self.resolve_path(path)?; - - if !resolved.exists() || !resolved.is_dir() { - return Ok(vec![]); - } - - let mut results = Vec::new(); - let mut entries = fs::read_dir(&resolved).await - .map_err(|e| BackendError::Io(e.to_string()))?; - - while let Some(entry) = entries.next_entry().await - .map_err(|e| BackendError::Io(e.to_string()))? 
- { - let path = entry.path(); - let metadata = entry.metadata().await - .map_err(|e| BackendError::Io(e.to_string()))?; - - let virt_path = self.to_virtual_path(&path); - - if metadata.is_dir() { - results.push(FileInfo::dir(&format!("{}/", virt_path))); - } else { - let modified = metadata.modified() - .ok() - .map(|t| DateTime::::from(t).to_rfc3339()); - - results.push(FileInfo { - path: virt_path, - is_dir: false, - size: Some(metadata.len()), - modified_at: modified, - }); - } - } - - results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(results) - } - - async fn read(&self, path: &str, offset: usize, limit: usize) -> Result { - let resolved = self.resolve_path(path)?; - - if !resolved.exists() || !resolved.is_file() { - return Err(BackendError::FileNotFound(path.to_string())); - } - - let content = fs::read_to_string(&resolved).await - .map_err(|e| BackendError::Io(e.to_string()))?; - - let lines: Vec<&str> = content.lines().collect(); - let start = offset.min(lines.len()); - let end = (offset + limit).min(lines.len()); - - let selected = lines[start..end].join("\n"); - Ok(Self::format_with_line_numbers(&selected, offset)) - } - - async fn write(&self, path: &str, content: &str) -> Result { - let resolved = self.resolve_path(path)?; - - if resolved.exists() { - return Ok(WriteResult::error(&format!( - "Cannot write to {} because it already exists. 
Read and then make an edit.", - path - ))); - } - - // 상위 디렉토리 생성 - if let Some(parent) = resolved.parent() { - fs::create_dir_all(parent).await - .map_err(|e| BackendError::Io(e.to_string()))?; - } - - fs::write(&resolved, content).await - .map_err(|e| BackendError::Io(e.to_string()))?; - - // 외부 백엔드이므로 files_update = None - Ok(WriteResult::success_external(path)) - } - - async fn edit( - &self, - path: &str, - old_string: &str, - new_string: &str, - replace_all: bool - ) -> Result { - let resolved = self.resolve_path(path)?; - - if !resolved.exists() || !resolved.is_file() { - return Err(BackendError::FileNotFound(path.to_string())); - } - - let content = fs::read_to_string(&resolved).await - .map_err(|e| BackendError::Io(e.to_string()))?; - - let occurrences = content.matches(old_string).count(); - - if occurrences == 0 { - return Ok(EditResult::error(&format!("String '{}' not found in file", old_string))); - } - - if !replace_all && occurrences > 1 { - return Ok(EditResult::error(&format!( - "String '{}' found {} times. 
Use replace_all=true or provide more context.", - old_string, occurrences - ))); - } - - let new_content = if replace_all { - content.replace(old_string, new_string) - } else { - content.replacen(old_string, new_string, 1) - }; - - fs::write(&resolved, &new_content).await - .map_err(|e| BackendError::Io(e.to_string()))?; - - let actual = if replace_all { occurrences } else { 1 }; - Ok(EditResult::success_external(path, actual)) - } - - async fn glob(&self, pattern: &str, base_path: &str) -> Result, BackendError> { - let resolved = self.resolve_path(base_path)?; - - if !resolved.exists() || !resolved.is_dir() { - return Ok(vec![]); - } - - let glob_pattern = Pattern::new(pattern) - .map_err(|e| BackendError::Pattern(e.to_string()))?; - - let mut results = Vec::new(); - - // 재귀적으로 파일 검색 - let walker = walkdir::WalkDir::new(&resolved); - for entry in walker.into_iter().filter_map(|e| e.ok()) { - if !entry.file_type().is_file() { - continue; - } - - let rel_path = entry.path().strip_prefix(&resolved) - .map(|p| p.to_string_lossy().to_string()) - .unwrap_or_default(); - - if glob_pattern.matches(&rel_path) { - let virt_path = self.to_virtual_path(entry.path()); - let metadata = entry.metadata() - .map_err(|e| BackendError::Io(e.to_string()))?; - - results.push(FileInfo { - path: virt_path, - is_dir: false, - size: Some(metadata.len()), - modified_at: metadata.modified() - .ok() - .map(|t| DateTime::::from(t).to_rfc3339()), - }); - } - } - - results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(results) - } - - async fn grep( - &self, - pattern: &str, - path: Option<&str>, - glob_filter: Option<&str>, - ) -> Result, BackendError> { - let search_path = path.unwrap_or("/"); - let resolved = self.resolve_path(search_path)?; - - if !resolved.exists() { - return Ok(vec![]); - } - - let glob_pattern = glob_filter.map(|g| Pattern::new(g)).transpose() - .map_err(|e| BackendError::Pattern(e.to_string()))?; - - let mut results = Vec::new(); - let walker = 
walkdir::WalkDir::new(&resolved); - - for entry in walker.into_iter().filter_map(|e| e.ok()) { - if !entry.file_type().is_file() { - continue; - } - - // Glob filter - if let Some(ref gp) = glob_pattern { - let name = entry.file_name().to_string_lossy(); - if !gp.matches(&name) { - continue; - } - } - - // 파일 읽기 - let content = match std::fs::read_to_string(entry.path()) { - Ok(c) => c, - Err(_) => continue, - }; - - let virt_path = self.to_virtual_path(entry.path()); - - // 리터럴 검색 - for (line_num, line) in content.lines().enumerate() { - if line.contains(pattern) { - results.push(GrepMatch::new(&virt_path, line_num + 1, line)); - } - } - } - - Ok(results) - } - - async fn exists(&self, path: &str) -> Result { - let resolved = self.resolve_path(path)?; - Ok(resolved.exists()) - } - - async fn delete(&self, path: &str) -> Result<(), BackendError> { - let resolved = self.resolve_path(path)?; - - if !resolved.exists() { - return Err(BackendError::FileNotFound(path.to_string())); - } - - fs::remove_file(&resolved).await - .map_err(|e| BackendError::Io(e.to_string()))?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tempfile::TempDir; - - #[tokio::test] - async fn test_filesystem_backend_write_and_read() { - let temp = TempDir::new().unwrap(); - let backend = FilesystemBackend::new(temp.path()); - - let result = backend.write("/test.txt", "Hello").await.unwrap(); - assert!(result.is_ok()); - assert!(result.files_update.is_none()); // 외부 백엔드 - - let content = backend.read("/test.txt", 0, 100).await.unwrap(); - assert!(content.contains("Hello")); - } - - #[tokio::test] - async fn test_filesystem_backend_path_traversal() { - let temp = TempDir::new().unwrap(); - let backend = FilesystemBackend::new(temp.path()); - - let result = backend.read("/../etc/passwd", 0, 100).await; - assert!(result.is_err()); - } -} -``` - -**Step 2: Cargo.toml에 walkdir 추가** - -```toml -# Cargo.toml [dependencies]에 추가 -walkdir = "2" -tempfile = "3" -``` - -**Step 3: 테스트 실행** 
- -Run: `cargo test filesystem` -Expected: PASS - -**Step 4: Commit** - -```bash -git add -A && git commit -m "feat: implement FilesystemBackend with virtual mode" -``` - ---- - -### Task 3.4: CompositeBackend 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/backends/composite.rs` - -**Python Reference:** `deepagents/backends/composite.py` - -**Step 1: composite.rs 구현** - -```rust -// src/backends/composite.rs -//! 복합 백엔드 - 경로 기반 라우팅 -//! -//! Python Reference: deepagents/backends/composite.py - -use async_trait::async_trait; -use std::sync::Arc; - -use super::protocol::{Backend, FileInfo, GrepMatch}; -use crate::error::{BackendError, WriteResult, EditResult}; - -/// 라우트 설정 -pub struct Route { - pub prefix: String, - pub backend: Arc, -} - -/// 복합 백엔드 -/// Python: CompositeBackend -/// -/// 경로 접두사를 기반으로 요청을 다른 백엔드로 라우팅합니다. -pub struct CompositeBackend { - default: Arc, - routes: Vec, -} - -impl CompositeBackend { - pub fn new(default: Arc) -> Self { - Self { - default, - routes: Vec::new(), - } - } - - /// 라우트 추가 (빌더 패턴) - pub fn with_route(mut self, prefix: &str, backend: Arc) -> Self { - // 길이 순으로 정렬 (가장 긴 것 먼저) - let route = Route { - prefix: prefix.to_string(), - backend, - }; - self.routes.push(route); - self.routes.sort_by(|a, b| b.prefix.len().cmp(&a.prefix.len())); - self - } - - /// 경로에 맞는 백엔드와 변환된 경로 반환 - fn get_backend_and_path(&self, path: &str) -> (Arc, String) { - for route in &self.routes { - if path.starts_with(&route.prefix) { - let suffix = &path[route.prefix.len()..]; - let stripped = if suffix.is_empty() || suffix == "/" { - "/".to_string() - } else { - format!("/{}", suffix.trim_start_matches('/')) - }; - return (route.backend.clone(), stripped); - } - } - (self.default.clone(), path.to_string()) - } - - /// 결과 경로에 접두사 복원 - fn restore_prefix(&self, path: &str, original_path: &str) -> String { - for route in &self.routes { - if original_path.starts_with(&route.prefix) { - let prefix = route.prefix.trim_end_matches('/'); 
- return format!("{}{}", prefix, path); - } - } - path.to_string() - } -} - -#[async_trait] -impl Backend for CompositeBackend { - async fn ls(&self, path: &str) -> Result, BackendError> { - // 루트 경로면 모든 백엔드에서 수집 - if path == "/" { - let mut results = self.default.ls("/").await?; - - // 라우트된 디렉토리 추가 - for route in &self.routes { - results.push(FileInfo::dir(&route.prefix)); - } - - results.sort_by(|a, b| a.path.cmp(&b.path)); - return Ok(results); - } - - let (backend, stripped) = self.get_backend_and_path(path); - let mut results = backend.ls(&stripped).await?; - - // 경로 복원 - for info in &mut results { - info.path = self.restore_prefix(&info.path, path); - } - - Ok(results) - } - - async fn read(&self, path: &str, offset: usize, limit: usize) -> Result { - let (backend, stripped) = self.get_backend_and_path(path); - backend.read(&stripped, offset, limit).await - } - - async fn write(&self, path: &str, content: &str) -> Result { - let (backend, stripped) = self.get_backend_and_path(path); - let mut result = backend.write(&stripped, content).await?; - - if result.path.is_some() { - result.path = Some(path.to_string()); - } - - Ok(result) - } - - async fn edit( - &self, - path: &str, - old_string: &str, - new_string: &str, - replace_all: bool - ) -> Result { - let (backend, stripped) = self.get_backend_and_path(path); - let mut result = backend.edit(&stripped, old_string, new_string, replace_all).await?; - - if result.path.is_some() { - result.path = Some(path.to_string()); - } - - Ok(result) - } - - async fn glob(&self, pattern: &str, base_path: &str) -> Result, BackendError> { - let (backend, stripped) = self.get_backend_and_path(base_path); - let mut results = backend.glob(pattern, &stripped).await?; - - for info in &mut results { - info.path = self.restore_prefix(&info.path, base_path); - } - - Ok(results) - } - - async fn grep( - &self, - pattern: &str, - path: Option<&str>, - glob_filter: Option<&str>, - ) -> Result, BackendError> { - let search_path = 
path.unwrap_or("/"); - - // 특정 경로가 라우트에 매칭되면 해당 백엔드만 검색 - for route in &self.routes { - if search_path.starts_with(route.prefix.trim_end_matches('/')) { - let stripped = &search_path[route.prefix.len() - 1..]; - let search = if stripped.is_empty() { "/" } else { stripped }; - - let mut results = route.backend.grep(pattern, Some(search), glob_filter).await?; - - for m in &mut results { - m.path = self.restore_prefix(&m.path, search_path); - } - - return Ok(results); - } - } - - // 전체 검색 - let mut all_results = self.default.grep(pattern, path, glob_filter).await?; - - for route in &self.routes { - let mut route_results = route.backend.grep(pattern, Some("/"), glob_filter).await?; - for m in &mut route_results { - let prefix = route.prefix.trim_end_matches('/'); - m.path = format!("{}{}", prefix, m.path); - } - all_results.extend(route_results); - } - - Ok(all_results) - } - - async fn exists(&self, path: &str) -> Result { - let (backend, stripped) = self.get_backend_and_path(path); - backend.exists(&stripped).await - } - - async fn delete(&self, path: &str) -> Result<(), BackendError> { - let (backend, stripped) = self.get_backend_and_path(path); - backend.delete(&stripped).await - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - - #[tokio::test] - async fn test_composite_backend_routing() { - let default = Arc::new(MemoryBackend::new()); - let memories = Arc::new(MemoryBackend::new()); - - let composite = CompositeBackend::new(default.clone()) - .with_route("/memories/", memories.clone()); - - // 메모리 백엔드에 파일 쓰기 - composite.write("/memories/notes.txt", "my notes").await.unwrap(); - - // 읽기 - let content = composite.read("/memories/notes.txt", 0, 100).await.unwrap(); - assert!(content.contains("my notes")); - - // 기본 백엔드에 파일 쓰기 - composite.write("/other.txt", "other content").await.unwrap(); - - // 루트 ls - let files = composite.ls("/").await.unwrap(); - assert!(files.iter().any(|f| f.path.contains("memories"))); - 
assert!(files.iter().any(|f| f.path.contains("other"))); - } - - #[tokio::test] - async fn test_composite_backend_grep() { - let default = Arc::new(MemoryBackend::new()); - let docs = Arc::new(MemoryBackend::new()); - - let composite = CompositeBackend::new(default.clone()) - .with_route("/docs/", docs.clone()); - - composite.write("/docs/readme.txt", "hello world").await.unwrap(); - composite.write("/other.txt", "hello there").await.unwrap(); - - // 전체 검색 - let matches = composite.grep("hello", None, None).await.unwrap(); - assert_eq!(matches.len(), 2); - } -} -``` - -**Step 2: 테스트 실행** - -Run: `cargo test composite` -Expected: PASS - -**Step 3: Commit** - -```bash -git add -A && git commit -m "feat: implement CompositeBackend with route-based dispatch" -``` - ---- - -## Phase 4: ToolRuntime 및 AgentMiddleware 트레이트 - -### Task 4.1: ToolRuntime 정의 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/runtime.rs` - -**Python Reference:** `langchain/tools.py` - `ToolRuntime` - -**Step 1: runtime.rs 구현** - -```rust -// src/runtime.rs -//! 도구 실행 런타임 -//! -//! Python Reference: langchain/tools.py의 ToolRuntime -//! -//! 도구 실행 시 필요한 컨텍스트를 제공합니다. 
- -use std::sync::Arc; -use crate::state::AgentState; -use crate::backends::Backend; - -/// 도구 실행 런타임 -/// Python: ToolRuntime -/// -/// 도구가 실행될 때 필요한 컨텍스트를 제공합니다: -/// - 현재 에이전트 상태 -/// - 백엔드 접근 -/// - 도구 호출 ID -pub struct ToolRuntime { - /// 현재 에이전트 상태 (읽기 전용 스냅샷) - state: AgentState, - /// 백엔드 (파일 시스템 접근) - backend: Arc, - /// 현재 도구 호출 ID - tool_call_id: Option, - /// 추가 설정 - config: RuntimeConfig, -} - -/// 런타임 설정 -#[derive(Debug, Clone, Default)] -pub struct RuntimeConfig { - /// 디버그 모드 - pub debug: bool, - /// 최대 재귀 깊이 (SubAgent용) - pub max_recursion: usize, - /// 현재 재귀 깊이 - pub current_recursion: usize, -} - -impl RuntimeConfig { - pub fn new() -> Self { - Self { - debug: false, - max_recursion: 10, - current_recursion: 0, - } - } -} - -impl ToolRuntime { - pub fn new(state: AgentState, backend: Arc) -> Self { - Self { - state, - backend, - tool_call_id: None, - config: RuntimeConfig::new(), - } - } - - pub fn with_tool_call_id(mut self, id: &str) -> Self { - self.tool_call_id = Some(id.to_string()); - self - } - - pub fn with_config(mut self, config: RuntimeConfig) -> Self { - self.config = config; - self - } - - /// 현재 상태 참조 - pub fn state(&self) -> &AgentState { - &self.state - } - - /// 백엔드 참조 - pub fn backend(&self) -> &Arc { - &self.backend - } - - /// 도구 호출 ID - pub fn tool_call_id(&self) -> Option<&str> { - self.tool_call_id.as_deref() - } - - /// 설정 참조 - pub fn config(&self) -> &RuntimeConfig { - &self.config - } - - /// 재귀 깊이 증가한 새 런타임 생성 - pub fn with_increased_recursion(&self) -> Self { - let mut new_config = self.config.clone(); - new_config.current_recursion += 1; - - Self { - state: self.state.clone(), - backend: self.backend.clone(), - tool_call_id: None, - config: new_config, - } - } - - /// 재귀 한도 초과 확인 - pub fn is_recursion_limit_exceeded(&self) -> bool { - self.config.current_recursion >= self.config.max_recursion - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - - #[test] - fn 
test_tool_runtime_creation() { - let state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - - let runtime = ToolRuntime::new(state, backend) - .with_tool_call_id("call_123"); - - assert_eq!(runtime.tool_call_id(), Some("call_123")); - } - - #[test] - fn test_recursion_limit() { - let state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - - let mut runtime = ToolRuntime::new(state, backend); - - for _ in 0..10 { - runtime = runtime.with_increased_recursion(); - } - - assert!(runtime.is_recursion_limit_exceeded()); - } -} -``` - -**Step 2: Commit** - -```bash -git add -A && git commit -m "feat: add ToolRuntime for tool execution context" -``` - ---- - -### Task 4.2: AgentMiddleware 트레이트 정의 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/mod.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/traits.rs` - -**Python Reference:** `langchain/agents/middleware/types.py` - -**Step 1: 디렉토리 생성** - -```bash -mkdir -p rust-research-agent/crates/rig-deepagents/src/middleware -``` - -**Step 2: traits.rs 구현** - -```rust -// src/middleware/traits.rs -//! AgentMiddleware 트레이트 정의 -//! -//! 
Python Reference: langchain/agents/middleware/types.py - -use async_trait::async_trait; -use std::sync::Arc; -use crate::state::{AgentState, Message, Todo, FileData}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; - -/// 상태 업데이트 커맨드 -/// Python: langgraph.types.Command -#[derive(Debug, Clone)] -pub enum StateUpdate { - /// 메시지 추가 - AddMessages(Vec), - /// Todo 업데이트 - SetTodos(Vec), - /// 파일 업데이트 (None = 삭제) - UpdateFiles(std::collections::HashMap>), - /// 복합 업데이트 - Batch(Vec), -} - -/// 도구 정의 -#[derive(Debug, Clone)] -pub struct ToolDefinition { - pub name: String, - pub description: String, - pub parameters: serde_json::Value, -} - -/// 도구 인터페이스 -#[async_trait] -pub trait Tool: Send + Sync { - /// 도구 정의 반환 - fn definition(&self) -> ToolDefinition; - - /// 도구 실행 - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result; -} - -/// 동적 도구 타입 -pub type DynTool = Arc; - -/// AgentMiddleware 트레이트 -/// -/// Python Reference: AgentMiddleware(Generic[StateT, ContextT]) -/// -/// 핵심 기능: -/// - tools(): 에이전트에 자동 주입할 도구 목록 반환 -/// - modify_system_prompt(): 시스템 프롬프트 수정 (체이닝) -/// - before_agent() / after_agent(): 라이프사이클 훅 -#[async_trait] -pub trait AgentMiddleware: Send + Sync { - /// 미들웨어 이름 - fn name(&self) -> &str; - - /// 이 미들웨어가 제공하는 도구 목록 - fn tools(&self) -> Vec { - vec![] - } - - /// 시스템 프롬프트 수정 - fn modify_system_prompt(&self, prompt: String) -> String { - prompt - } - - /// 에이전트 실행 전 훅 - /// Python: before_agent(self, state, runtime) -> dict | None - async fn before_agent( - &self, - _state: &mut AgentState, - _runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - Ok(None) - } - - /// 에이전트 실행 후 훅 - /// Python: after_agent(self, state, runtime) -> dict | None - async fn after_agent( - &self, - _state: &mut AgentState, - _runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - Ok(None) - } -} -``` - -**Step 3: mod.rs 생성** - -```rust -// src/middleware/mod.rs -//! 
미들웨어 모듈 - -pub mod traits; -pub mod stack; -pub mod todo; -pub mod filesystem; -pub mod patch_tool_calls; -pub mod summarization; -pub mod subagent; - -pub use traits::{AgentMiddleware, DynTool, Tool, ToolDefinition, StateUpdate}; -pub use stack::MiddlewareStack; -pub use todo::TodoListMiddleware; -pub use filesystem::FilesystemMiddleware; -pub use patch_tool_calls::PatchToolCallsMiddleware; -pub use summarization::SummarizationMiddleware; -pub use subagent::SubAgentMiddleware; -``` - -**Step 4: Commit** - -```bash -git add -A && git commit -m "feat: add AgentMiddleware trait and Tool interface" -``` - ---- - -### Task 4.3: MiddlewareStack 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/stack.rs` - -**Step 1: stack.rs 구현** - -```rust -// src/middleware/stack.rs -//! 미들웨어 스택 -//! -//! 여러 미들웨어를 조합하여 순차적으로 실행합니다. - -use std::sync::Arc; -use crate::state::{AgentState, Message, FileData}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use super::traits::{AgentMiddleware, DynTool, StateUpdate}; - -/// 미들웨어 스택 -pub struct MiddlewareStack { - middlewares: Vec>, -} - -impl MiddlewareStack { - pub fn new() -> Self { - Self { middlewares: vec![] } - } - - /// 미들웨어 추가 (빌더 패턴) - pub fn add(mut self, middleware: M) -> Self { - self.middlewares.push(Arc::new(middleware)); - self - } - - /// Arc로 래핑된 미들웨어 추가 - pub fn add_arc(mut self, middleware: Arc) -> Self { - self.middlewares.push(middleware); - self - } - - /// 미들웨어 개수 - pub fn len(&self) -> usize { - self.middlewares.len() - } - - pub fn is_empty(&self) -> bool { - self.middlewares.is_empty() - } - - /// 모든 미들웨어의 도구 수집 - pub fn collect_tools(&self) -> Vec { - self.middlewares - .iter() - .flat_map(|m| m.tools()) - .collect() - } - - /// 시스템 프롬프트 빌드 (체이닝) - pub fn build_system_prompt(&self, base: &str) -> String { - self.middlewares.iter().fold( - base.to_string(), - |acc, m| m.modify_system_prompt(acc) - ) - } - - /// before_agent 훅 실행 (순차) - pub async fn 
before_agent( - &self, - state: &mut AgentState, - runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - let mut updates = vec![]; - - for middleware in &self.middlewares { - if let Some(update) = middleware.before_agent(state, runtime).await? { - Self::apply_update(state, &update); - updates.push(update); - } - } - - Ok(updates) - } - - /// after_agent 훅 실행 (역순) - pub async fn after_agent( - &self, - state: &mut AgentState, - runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - let mut updates = vec![]; - - for middleware in self.middlewares.iter().rev() { - if let Some(update) = middleware.after_agent(state, runtime).await? { - Self::apply_update(state, &update); - updates.push(update); - } - } - - Ok(updates) - } - - /// 상태 업데이트 적용 - fn apply_update(state: &mut AgentState, update: &StateUpdate) { - match update { - StateUpdate::AddMessages(msgs) => { - state.messages.extend(msgs.clone()); - } - StateUpdate::SetTodos(todos) => { - state.todos = todos.clone(); - } - StateUpdate::UpdateFiles(files) => { - for (path, data) in files { - if let Some(d) = data { - state.files.insert(path.clone(), d.clone()); - } else { - state.files.remove(path); - } - } - } - StateUpdate::Batch(updates) => { - for u in updates { - Self::apply_update(state, u); - } - } - } - } -} - -impl Default for MiddlewareStack { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use std::sync::Arc; - - struct TestMiddleware { - name: String, - prompt_addition: String, - } - - #[async_trait::async_trait] - impl AgentMiddleware for TestMiddleware { - fn name(&self) -> &str { - &self.name - } - - fn modify_system_prompt(&self, prompt: String) -> String { - format!("{}\n{}", prompt, self.prompt_addition) - } - } - - #[test] - fn test_middleware_stack_prompt_chaining() { - let stack = MiddlewareStack::new() - .add(TestMiddleware { - name: "First".to_string(), - prompt_addition: "First addition".to_string() - }) - 
.add(TestMiddleware { - name: "Second".to_string(), - prompt_addition: "Second addition".to_string() - }); - - let result = stack.build_system_prompt("Base prompt"); - assert!(result.contains("Base prompt")); - assert!(result.contains("First addition")); - assert!(result.contains("Second addition")); - } - - #[tokio::test] - async fn test_middleware_stack_hooks() { - let stack = MiddlewareStack::new() - .add(TestMiddleware { - name: "Test".to_string(), - prompt_addition: "Test".to_string() - }); - - let mut state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(state.clone(), backend); - - let updates = stack.before_agent(&mut state, &runtime).await.unwrap(); - assert!(updates.is_empty()); // 기본 미들웨어는 None 반환 - } -} -``` - -**Step 2: Commit** - -```bash -git add -A && git commit -m "feat: implement MiddlewareStack for middleware composition" -``` - ---- - -## Phase 5: 미들웨어 및 도구 구현 - -이 Phase에서는 모든 미들웨어와 실제 도구를 구현합니다. - -### Task 5.1: TodoListMiddleware 및 write_todos 도구 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/todo.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/mod.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/write_todos.rs` - -**Python Reference:** `langchain/agents/middleware/todo.py` - -**Step 1: tools/mod.rs 생성** - -```rust -// src/tools/mod.rs -//! 도구 모듈 - -pub mod write_todos; -pub mod filesystem; - -pub use write_todos::WriteTodosTool; -pub use filesystem::{LsTool, ReadFileTool, WriteFileTool, EditFileTool, GlobTool, GrepTool}; -``` - -**Step 2: write_todos.rs 구현** - -```rust -// src/tools/write_todos.rs -//! 
write_todos 도구 구현 - -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; -use crate::middleware::traits::{Tool, ToolDefinition}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use crate::state::{Todo, TodoStatus}; - -/// write_todos 도구 인자 -#[derive(Debug, Deserialize)] -pub struct WriteTodosArgs { - pub todos: Vec, -} - -#[derive(Debug, Deserialize)] -pub struct TodoInput { - pub content: String, - pub status: String, -} - -/// write_todos 도구 -pub struct WriteTodosTool; - -impl WriteTodosTool { - pub fn new() -> Self { - Self - } -} - -impl Default for WriteTodosTool { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl Tool for WriteTodosTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "write_todos".to_string(), - description: "Update the todo list to track progress on complex tasks.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "todos": { - "type": "array", - "description": "The updated todo list", - "items": { - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "Description of the task" - }, - "status": { - "type": "string", - "enum": ["pending", "in_progress", "completed"], - "description": "Current status of the task" - } - }, - "required": ["content", "status"] - } - } - }, - "required": ["todos"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: WriteTodosArgs = serde_json::from_value(args)?; - - let todos: Vec = parsed.todos.into_iter().map(|t| { - let status = match t.status.as_str() { - "in_progress" => TodoStatus::InProgress, - "completed" => TodoStatus::Completed, - _ => TodoStatus::Pending, - }; - Todo::with_status(&t.content, status) - }).collect(); - - let count = todos.len(); - let completed = todos.iter().filter(|t| t.status == TodoStatus::Completed).count(); - - Ok(format!( - "Todo list updated: {} items 
total, {} completed", - count, completed - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use crate::state::AgentState; - use std::sync::Arc; - - #[tokio::test] - async fn test_write_todos_tool() { - let tool = WriteTodosTool::new(); - let state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(state, backend); - - let args = serde_json::json!({ - "todos": [ - {"content": "Task 1", "status": "pending"}, - {"content": "Task 2", "status": "completed"} - ] - }); - - let result = tool.execute(args, &runtime).await.unwrap(); - assert!(result.contains("2 items")); - assert!(result.contains("1 completed")); - } -} -``` - -**Step 3: todo.rs 미들웨어 구현** - -```rust -// src/middleware/todo.rs -//! TodoListMiddleware 구현 -//! -//! Python Reference: langchain/agents/middleware/todo.py - -use async_trait::async_trait; -use std::sync::Arc; -use crate::state::AgentState; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use crate::tools::WriteTodosTool; -use super::traits::{AgentMiddleware, DynTool, StateUpdate}; - -const TODO_SYSTEM_PROMPT: &str = r#"## `write_todos` - -You have access to the `write_todos` tool to help you manage and plan complex objectives. -Use this tool for complex objectives to ensure that you are tracking each necessary step. - -It is critical that you mark todos as completed as soon as you are done with a step. 
- -## Task States -- pending: Task not yet started -- in_progress: Currently working on -- completed: Task finished successfully - -## When to Use -- Complex multi-step tasks (3+ steps) -- Tasks requiring careful planning -- User explicitly requests todo list - -## When NOT to Use -- Single, straightforward tasks -- Trivial tasks (< 3 steps) -- Purely conversational requests"#; - -/// TodoListMiddleware -pub struct TodoListMiddleware { - system_prompt: String, - tool: Arc, -} - -impl TodoListMiddleware { - pub fn new() -> Self { - Self { - system_prompt: TODO_SYSTEM_PROMPT.to_string(), - tool: Arc::new(WriteTodosTool::new()), - } - } -} - -impl Default for TodoListMiddleware { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl AgentMiddleware for TodoListMiddleware { - fn name(&self) -> &str { - "TodoListMiddleware" - } - - fn tools(&self) -> Vec { - vec![self.tool.clone()] - } - - fn modify_system_prompt(&self, prompt: String) -> String { - format!("{}\n\n{}", prompt, self.system_prompt) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_todo_middleware_prompt() { - let mw = TodoListMiddleware::new(); - let prompt = mw.modify_system_prompt("Base prompt".to_string()); - - assert!(prompt.contains("write_todos")); - assert!(prompt.contains("pending")); - assert!(prompt.contains("in_progress")); - assert!(prompt.contains("completed")); - } - - #[test] - fn test_todo_middleware_tools() { - let mw = TodoListMiddleware::new(); - let tools = mw.tools(); - - assert_eq!(tools.len(), 1); - assert_eq!(tools[0].definition().name, "write_todos"); - } -} -``` - -**Step 4: Commit** - -```bash -git add -A && git commit -m "feat: implement TodoListMiddleware with write_todos tool" -``` - ---- - -### Task 5.2: FilesystemMiddleware 및 도구들 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/filesystem.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/filesystem.rs` - -**Python Reference:** 
`deepagents/middleware/filesystem.py` - -**Step 1: tools/filesystem.rs 구현** (도구 6개: ls, read_file, write_file, edit_file, glob, grep) - -```rust -// src/tools/filesystem.rs -//! 파일시스템 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; -use std::sync::Arc; -use crate::middleware::traits::{Tool, ToolDefinition}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use crate::backends::Backend; - -// ============= ls 도구 ============= - -#[derive(Debug, Deserialize)] -pub struct LsArgs { - pub path: String, -} - -pub struct LsTool { - backend: Arc, -} - -impl LsTool { - pub fn new(backend: Arc) -> Self { - Self { backend } - } -} - -#[async_trait] -impl Tool for LsTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "ls".to_string(), - description: "List directory contents. Returns files and subdirectories.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Directory path to list (e.g., '/' or '/src')" - } - }, - "required": ["path"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: LsArgs = serde_json::from_value(args)?; - - let files = self.backend.ls(&parsed.path).await - .map_err(|e| MiddlewareError::ToolExecution(e.to_string()))?; - - if files.is_empty() { - return Ok(format!("Directory '{}' is empty or does not exist", parsed.path)); - } - - let output: Vec = files.iter().map(|f| { - if f.is_dir { - format!("{} (dir)", f.path) - } else { - format!("{} ({} bytes)", f.path, f.size.unwrap_or(0)) - } - }).collect(); - - Ok(output.join("\n")) - } -} - -// ============= read_file 도구 ============= - -#[derive(Debug, Deserialize)] -pub struct ReadFileArgs { - pub path: String, - #[serde(default)] - pub offset: usize, - #[serde(default = "default_limit")] - pub limit: usize, -} - -fn default_limit() -> usize { 500 } - -pub struct ReadFileTool { - backend: 
Arc, -} - -impl ReadFileTool { - pub fn new(backend: Arc) -> Self { - Self { backend } - } -} - -#[async_trait] -impl Tool for ReadFileTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "read_file".to_string(), - description: "Read file contents with line numbers (cat -n format). Use pagination for large files.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "File path to read" - }, - "offset": { - "type": "integer", - "description": "Line offset to start reading from (0-based)", - "default": 0 - }, - "limit": { - "type": "integer", - "description": "Maximum number of lines to read", - "default": 500 - } - }, - "required": ["path"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: ReadFileArgs = serde_json::from_value(args)?; - - self.backend.read(&parsed.path, parsed.offset, parsed.limit).await - .map_err(|e| MiddlewareError::ToolExecution(e.to_string())) - } -} - -// ============= write_file 도구 ============= - -#[derive(Debug, Deserialize)] -pub struct WriteFileArgs { - pub path: String, - pub content: String, -} - -pub struct WriteFileTool { - backend: Arc, -} - -impl WriteFileTool { - pub fn new(backend: Arc) -> Self { - Self { backend } - } -} - -#[async_trait] -impl Tool for WriteFileTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "write_file".to_string(), - description: "Create a new file. Fails if file already exists. 
Use edit_file for modifications.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Path for the new file" - }, - "content": { - "type": "string", - "description": "Content to write to the file" - } - }, - "required": ["path", "content"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: WriteFileArgs = serde_json::from_value(args)?; - - let result = self.backend.write(&parsed.path, &parsed.content).await - .map_err(|e| MiddlewareError::ToolExecution(e.to_string()))?; - - if let Some(error) = result.error { - Ok(format!("Error: {}", error)) - } else { - Ok(format!("Successfully created file: {}", parsed.path)) - } - } -} - -// ============= edit_file 도구 ============= - -#[derive(Debug, Deserialize)] -pub struct EditFileArgs { - pub path: String, - pub old_string: String, - pub new_string: String, - #[serde(default)] - pub replace_all: bool, -} - -pub struct EditFileTool { - backend: Arc, -} - -impl EditFileTool { - pub fn new(backend: Arc) -> Self { - Self { backend } - } -} - -#[async_trait] -impl Tool for EditFileTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "edit_file".to_string(), - description: "Edit a file by replacing old_string with new_string. 
Must read file before editing.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "Path to the file to edit" - }, - "old_string": { - "type": "string", - "description": "Exact string to find and replace" - }, - "new_string": { - "type": "string", - "description": "Replacement string" - }, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences (default: false)", - "default": false - } - }, - "required": ["path", "old_string", "new_string"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: EditFileArgs = serde_json::from_value(args)?; - - let result = self.backend.edit( - &parsed.path, - &parsed.old_string, - &parsed.new_string, - parsed.replace_all - ).await.map_err(|e| MiddlewareError::ToolExecution(e.to_string()))?; - - if let Some(error) = result.error { - Ok(format!("Error: {}", error)) - } else { - Ok(format!( - "Successfully edited {}: {} occurrence(s) replaced", - parsed.path, - result.occurrences.unwrap_or(0) - )) - } - } -} - -// ============= glob 도구 ============= - -#[derive(Debug, Deserialize)] -pub struct GlobArgs { - pub pattern: String, - #[serde(default = "default_path")] - pub path: String, -} - -fn default_path() -> String { "/".to_string() } - -pub struct GlobTool { - backend: Arc, -} - -impl GlobTool { - pub fn new(backend: Arc) -> Self { - Self { backend } - } -} - -#[async_trait] -impl Tool for GlobTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "glob".to_string(), - description: "Find files matching a glob pattern (e.g., '**/*.rs', 'src/*.py').".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Glob pattern to match files" - }, - "path": { - "type": "string", - "description": "Base directory to search from", - "default": "/" - } - }, - 
"required": ["pattern"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: GlobArgs = serde_json::from_value(args)?; - - let files = self.backend.glob(&parsed.pattern, &parsed.path).await - .map_err(|e| MiddlewareError::ToolExecution(e.to_string()))?; - - if files.is_empty() { - return Ok(format!("No files matching pattern '{}'", parsed.pattern)); - } - - let output: Vec = files.iter() - .map(|f| f.path.clone()) - .collect(); - - Ok(format!("Found {} files:\n{}", files.len(), output.join("\n"))) - } -} - -// ============= grep 도구 ============= - -#[derive(Debug, Deserialize)] -pub struct GrepArgs { - pub pattern: String, - pub path: Option, - pub glob: Option, -} - -pub struct GrepTool { - backend: Arc, -} - -impl GrepTool { - pub fn new(backend: Arc) -> Self { - Self { backend } - } -} - -#[async_trait] -impl Tool for GrepTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "grep".to_string(), - // **수정됨**: 리터럴 검색임을 명시 - description: "Search for a literal string pattern in files. 
NOT regex - searches for exact text match.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Literal string to search for (NOT regex)" - }, - "path": { - "type": "string", - "description": "Directory to search in (optional)" - }, - "glob": { - "type": "string", - "description": "File pattern filter (e.g., '*.rs')" - } - }, - "required": ["pattern"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let parsed: GrepArgs = serde_json::from_value(args)?; - - let matches = self.backend.grep( - &parsed.pattern, - parsed.path.as_deref(), - parsed.glob.as_deref() - ).await.map_err(|e| MiddlewareError::ToolExecution(e.to_string()))?; - - if matches.is_empty() { - return Ok(format!("No matches found for '{}'", parsed.pattern)); - } - - let output: Vec = matches.iter() - .take(50) // 결과 제한 - .map(|m| format!("{}:{}:{}", m.path, m.line, m.text)) - .collect(); - - let suffix = if matches.len() > 50 { - format!("\n... 
and {} more matches", matches.len() - 50) - } else { - String::new() - }; - - Ok(format!("Found {} matches:\n{}{}", matches.len(), output.join("\n"), suffix)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use crate::state::AgentState; - - #[tokio::test] - async fn test_ls_tool() { - let backend = Arc::new(MemoryBackend::new()); - backend.write("/test.txt", "content").await.unwrap(); - - let tool = LsTool::new(backend.clone()); - let state = AgentState::new(); - let runtime = ToolRuntime::new(state, backend); - - let result = tool.execute(serde_json::json!({"path": "/"}), &runtime).await.unwrap(); - assert!(result.contains("test.txt")); - } - - #[tokio::test] - async fn test_read_write_edit_tools() { - let backend = Arc::new(MemoryBackend::new()); - let state = AgentState::new(); - let runtime = ToolRuntime::new(state, backend.clone()); - - // Write - let write_tool = WriteFileTool::new(backend.clone()); - let result = write_tool.execute( - serde_json::json!({"path": "/test.txt", "content": "hello world"}), - &runtime - ).await.unwrap(); - assert!(result.contains("Successfully")); - - // Read - let read_tool = ReadFileTool::new(backend.clone()); - let result = read_tool.execute( - serde_json::json!({"path": "/test.txt"}), - &runtime - ).await.unwrap(); - assert!(result.contains("hello world")); - - // Edit - let edit_tool = EditFileTool::new(backend.clone()); - let result = edit_tool.execute( - serde_json::json!({ - "path": "/test.txt", - "old_string": "hello", - "new_string": "goodbye" - }), - &runtime - ).await.unwrap(); - assert!(result.contains("Successfully edited")); - } - - #[tokio::test] - async fn test_grep_literal_search() { - let backend = Arc::new(MemoryBackend::new()); - backend.write("/test.rs", "fn main() {}").await.unwrap(); - - let tool = GrepTool::new(backend.clone()); - let state = AgentState::new(); - let runtime = ToolRuntime::new(state, backend); - - // 리터럴 검색 - 괄호가 그대로 검색됨 - let result = 
tool.execute( - serde_json::json!({"pattern": "()"}), - &runtime - ).await.unwrap(); - assert!(result.contains("Found")); - } -} -``` - -**Step 2: filesystem.rs 미들웨어 구현** - -```rust -// src/middleware/filesystem.rs -//! FilesystemMiddleware 구현 -//! -//! Python Reference: deepagents/middleware/filesystem.py - -use async_trait::async_trait; -use std::sync::Arc; -use crate::backends::Backend; -use crate::state::AgentState; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use crate::tools::{LsTool, ReadFileTool, WriteFileTool, EditFileTool, GlobTool, GrepTool}; -use super::traits::{AgentMiddleware, DynTool, StateUpdate}; - -/// 파일시스템 도구 시스템 프롬프트 -/// **수정됨**: grep은 리터럴 검색임을 명시 -const FILESYSTEM_SYSTEM_PROMPT: &str = r#"## Filesystem Tools - -You have access to filesystem tools for managing files: - -### `ls` - List directory contents -Usage: ls(path="/dir") - -### `read_file` - Read file contents -Usage: read_file(path="/file.txt", offset=0, limit=500) -- Returns content with line numbers (cat -n format) -- Use pagination for large files - -### `write_file` - Create new file -Usage: write_file(path="/file.txt", content="...") -- Use for creating new files only -- Fails if file already exists -- Prefer edit_file for modifications - -### `edit_file` - Edit existing file -Usage: edit_file(path="/file.txt", old_string="...", new_string="...", replace_all=false) -- Must read file before editing -- old_string must be unique unless replace_all=true - -### `glob` - Pattern matching -Usage: glob(pattern="**/*.rs", path="/") -- Standard glob: *, **, ? 
- -### `grep` - Text search -Usage: grep(pattern="search term", path="/src", glob="*.rs") -- **IMPORTANT**: Uses LITERAL string matching (NOT regex) -- Searches for exact text occurrences"#; - -/// FilesystemMiddleware -pub struct FilesystemMiddleware { - backend: Arc, - system_prompt: String, - tools: Vec, -} - -impl FilesystemMiddleware { - pub fn new(backend: Arc) -> Self { - let tools: Vec = vec![ - Arc::new(LsTool::new(backend.clone())), - Arc::new(ReadFileTool::new(backend.clone())), - Arc::new(WriteFileTool::new(backend.clone())), - Arc::new(EditFileTool::new(backend.clone())), - Arc::new(GlobTool::new(backend.clone())), - Arc::new(GrepTool::new(backend.clone())), - ]; - - Self { - backend, - system_prompt: FILESYSTEM_SYSTEM_PROMPT.to_string(), - tools, - } - } - - /// 백엔드 참조 반환 - pub fn backend(&self) -> &Arc { - &self.backend - } -} - -#[async_trait] -impl AgentMiddleware for FilesystemMiddleware { - fn name(&self) -> &str { - "FilesystemMiddleware" - } - - fn tools(&self) -> Vec { - self.tools.clone() - } - - fn modify_system_prompt(&self, prompt: String) -> String { - format!("{}\n\n{}", prompt, self.system_prompt) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - - #[test] - fn test_filesystem_middleware_prompt() { - let backend = Arc::new(MemoryBackend::new()); - let mw = FilesystemMiddleware::new(backend); - let prompt = mw.modify_system_prompt("Base".to_string()); - - assert!(prompt.contains("ls")); - assert!(prompt.contains("read_file")); - assert!(prompt.contains("write_file")); - assert!(prompt.contains("edit_file")); - assert!(prompt.contains("LITERAL")); // grep이 리터럴임을 확인 - } - - #[test] - fn test_filesystem_middleware_tools() { - let backend = Arc::new(MemoryBackend::new()); - let mw = FilesystemMiddleware::new(backend); - let tools = mw.tools(); - - assert_eq!(tools.len(), 6); - - let names: Vec<_> = tools.iter().map(|t| t.definition().name).collect(); - assert!(names.contains(&"ls".to_string())); - 
assert!(names.contains(&"read_file".to_string())); - assert!(names.contains(&"write_file".to_string())); - assert!(names.contains(&"edit_file".to_string())); - assert!(names.contains(&"glob".to_string())); - assert!(names.contains(&"grep".to_string())); - } -} -``` - -**Step 3: Commit** - -```bash -git add -A && git commit -m "feat: implement FilesystemMiddleware with 6 filesystem tools" -``` - ---- - -### Task 5.3: PatchToolCallsMiddleware 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/patch_tool_calls.rs` - -**Python Reference:** `deepagents/middleware/patch_tool_calls.py` - -**Step 1: patch_tool_calls.rs 구현** - -```rust -// src/middleware/patch_tool_calls.rs -//! PatchToolCallsMiddleware - Dangling tool call 패치 -//! -//! Python Reference: deepagents/middleware/patch_tool_calls.py -//! -//! AI 메시지의 tool_calls 중 대응하는 ToolMessage가 없는 것들을 -//! 자동으로 패치하여 API 에러를 방지합니다. - -use async_trait::async_trait; -use crate::state::{AgentState, Message, Role}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use super::traits::{AgentMiddleware, StateUpdate}; - -/// PatchToolCallsMiddleware -/// -/// Python: PatchToolCallsMiddleware -/// -/// 메시지 기록에서 dangling tool call을 패치합니다. 
-/// - AIMessage가 tool_calls를 가지고 있지만 -/// - 대응하는 ToolMessage가 없는 경우 -/// - 취소 메시지를 자동 삽입 -pub struct PatchToolCallsMiddleware; - -impl PatchToolCallsMiddleware { - pub fn new() -> Self { - Self - } - - /// Dangling tool call 찾아서 패치 - fn patch_dangling_calls(messages: &[Message]) -> Vec { - let mut patched = Vec::new(); - - for (i, msg) in messages.iter().enumerate() { - patched.push(msg.clone()); - - // AI 메시지에 tool_calls가 있는 경우 - if msg.role == Role::Assistant { - if let Some(tool_calls) = &msg.tool_calls { - for tc in tool_calls { - // 나머지 메시지에서 대응하는 ToolMessage 찾기 - let has_response = messages[i..].iter().any(|m| { - m.role == Role::Tool && - m.tool_call_id.as_ref() == Some(&tc.id) - }); - - if !has_response { - // Dangling - 패치 메시지 삽입 - patched.push(Message::tool( - &format!( - "Tool call {} (ID: {}) was cancelled - \ - another message arrived before completion.", - tc.name, tc.id - ), - &tc.id - )); - } - } - } - } - } - - patched - } -} - -impl Default for PatchToolCallsMiddleware { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl AgentMiddleware for PatchToolCallsMiddleware { - fn name(&self) -> &str { - "PatchToolCallsMiddleware" - } - - /// 에이전트 실행 전에 dangling tool call 패치 - async fn before_agent( - &self, - state: &mut AgentState, - _runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - if state.messages.is_empty() { - return Ok(None); - } - - let patched = Self::patch_dangling_calls(&state.messages); - - // 메시지가 변경되었으면 업데이트 - if patched.len() != state.messages.len() { - state.messages = patched; - // Note: StateUpdate로 반환하지 않고 직접 수정 - // (Overwrite 시맨틱) - } - - Ok(None) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::state::ToolCall; - use crate::backends::MemoryBackend; - use std::sync::Arc; - - #[tokio::test] - async fn test_patch_dangling_tool_calls() { - let mw = PatchToolCallsMiddleware::new(); - - // Dangling tool call이 있는 메시지 - let mut state = AgentState::new(); - state.messages = vec![ - 
Message::user("Hello"), - Message::assistant_with_tool_calls( - "", - vec![ToolCall { - id: "call_123".to_string(), - name: "read_file".to_string(), - arguments: serde_json::json!({"path": "/test.txt"}), - }] - ), - // ToolMessage 없음 - dangling! - Message::user("Continue"), - ]; - - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(AgentState::new(), backend); - - mw.before_agent(&mut state, &runtime).await.unwrap(); - - // 패치 메시지가 삽입되었는지 확인 - assert!(state.messages.len() > 3); - assert!(state.messages.iter().any(|m| - m.role == Role::Tool && m.content.contains("cancelled") - )); - } - - #[tokio::test] - async fn test_no_patch_when_response_exists() { - let mw = PatchToolCallsMiddleware::new(); - - // 정상적인 tool call/response 쌍 - let mut state = AgentState::new(); - state.messages = vec![ - Message::user("Hello"), - Message::assistant_with_tool_calls( - "", - vec![ToolCall { - id: "call_123".to_string(), - name: "read_file".to_string(), - arguments: serde_json::json!({"path": "/test.txt"}), - }] - ), - Message::tool("file contents", "call_123"), // 응답 있음 - Message::assistant("Here's the file"), - ]; - - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(AgentState::new(), backend); - let original_len = state.messages.len(); - - mw.before_agent(&mut state, &runtime).await.unwrap(); - - // 패치 없어야 함 - assert_eq!(state.messages.len(), original_len); - } -} -``` - -**Step 2: Commit** - -```bash -git add -A && git commit -m "feat: implement PatchToolCallsMiddleware for dangling tool calls" -``` - ---- - -### Task 5.4: SummarizationMiddleware 구현 (간소화) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/summarization.rs` - -**Python Reference:** `langchain/agents/middleware/summarization.py` - -**Step 1: summarization.rs 구현** (토큰 기반 요약 - 간소화 버전) - -```rust -// src/middleware/summarization.rs -//! SummarizationMiddleware - 컨텍스트 요약 -//! -//! 
Python Reference: langchain/agents/middleware/summarization.py -//! -//! 메시지 히스토리가 너무 길어지면 자동으로 요약합니다. -//! 이 구현은 간소화된 버전으로, 토큰 카운팅 대신 메시지 수 기반입니다. - -use async_trait::async_trait; -use crate::state::{AgentState, Message, Role}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use super::traits::{AgentMiddleware, StateUpdate}; - -/// 요약 트리거 조건 -#[derive(Debug, Clone)] -pub enum SummarizationTrigger { - /// 메시지 수 기반 - MessageCount(usize), - /// 대략적인 토큰 수 기반 (문자 수 / 4) - ApproximateTokens(usize), -} - -/// SummarizationMiddleware -/// -/// 대화 기록이 임계치에 도달하면 오래된 메시지를 요약합니다. -pub struct SummarizationMiddleware { - trigger: SummarizationTrigger, - keep_messages: usize, -} - -impl SummarizationMiddleware { - pub fn new(trigger: SummarizationTrigger, keep_messages: usize) -> Self { - Self { trigger, keep_messages } - } - - /// 메시지 수 기반 트리거로 생성 - pub fn with_message_limit(max_messages: usize, keep: usize) -> Self { - Self::new(SummarizationTrigger::MessageCount(max_messages), keep) - } - - /// 요약이 필요한지 확인 - fn needs_summarization(&self, state: &AgentState) -> bool { - match self.trigger { - SummarizationTrigger::MessageCount(max) => { - state.messages.len() > max - } - SummarizationTrigger::ApproximateTokens(max_tokens) => { - let approx_tokens: usize = state.messages.iter() - .map(|m| m.content.len() / 4) - .sum(); - approx_tokens > max_tokens - } - } - } - - /// 메시지 요약 생성 - fn create_summary(messages: &[Message]) -> String { - let mut summary_parts = Vec::new(); - - for msg in messages { - let prefix = match msg.role { - Role::User => "User", - Role::Assistant => "Assistant", - Role::System => "System", - Role::Tool => "Tool", - }; - - // 긴 메시지는 잘라서 요약 - let content = if msg.content.len() > 200 { - format!("{}...", &msg.content[..200]) - } else { - msg.content.clone() - }; - - summary_parts.push(format!("{}: {}", prefix, content)); - } - - format!( - "[Previous conversation summary]\n{}", - summary_parts.join("\n") - ) - } - - /// 요약 적용 - fn 
apply_summarization(&self, state: &mut AgentState) { - let total = state.messages.len(); - let keep_count = self.keep_messages.min(total); - let summarize_count = total.saturating_sub(keep_count); - - if summarize_count == 0 { - return; - } - - // 요약할 메시지들 - let to_summarize: Vec<_> = state.messages.drain(..summarize_count).collect(); - - // 요약 메시지 생성 - let summary = Self::create_summary(&to_summarize); - - // 요약 메시지를 맨 앞에 삽입 - state.messages.insert(0, Message::system(&summary)); - } -} - -impl Default for SummarizationMiddleware { - fn default() -> Self { - Self::with_message_limit(50, 10) - } -} - -#[async_trait] -impl AgentMiddleware for SummarizationMiddleware { - fn name(&self) -> &str { - "SummarizationMiddleware" - } - - async fn before_agent( - &self, - state: &mut AgentState, - _runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - if self.needs_summarization(state) { - self.apply_summarization(state); - } - Ok(None) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use std::sync::Arc; - - #[tokio::test] - async fn test_summarization_trigger() { - let mw = SummarizationMiddleware::with_message_limit(5, 2); - - // 6개 메시지 - 트리거됨 - let mut state = AgentState::new(); - for i in 0..6 { - state.add_message(Message::user(&format!("Message {}", i))); - } - - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(AgentState::new(), backend); - - mw.before_agent(&mut state, &runtime).await.unwrap(); - - // 요약 후: 요약 메시지 1개 + 유지된 메시지 2개 = 3개 - assert_eq!(state.messages.len(), 3); - assert!(state.messages[0].content.contains("summary")); - } - - #[tokio::test] - async fn test_no_summarization_below_limit() { - let mw = SummarizationMiddleware::with_message_limit(10, 2); - - let mut state = AgentState::new(); - for i in 0..5 { - state.add_message(Message::user(&format!("Message {}", i))); - } - - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(AgentState::new(), 
backend); - - mw.before_agent(&mut state, &runtime).await.unwrap(); - - // 변경 없음 - assert_eq!(state.messages.len(), 5); - } -} -``` - -**Step 2: Commit** - -```bash -git add -A && git commit -m "feat: implement SummarizationMiddleware for context management" -``` - ---- - -### Task 5.5: SubAgentMiddleware 구현 (간소화) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/middleware/subagent.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/task.rs` - -**Python Reference:** `deepagents/middleware/subagents.py` - -이 태스크는 가장 복잡한 부분으로, 핵심 기능만 구현합니다. - -**Step 1: tools/task.rs 구현** - -```rust -// src/tools/task.rs -//! task 도구 - 서브에이전트 실행 - -use async_trait::async_trait; -use serde::Deserialize; -use std::collections::HashMap; -use std::sync::Arc; -use crate::middleware::traits::{Tool, ToolDefinition}; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; - -/// SubAgent 정의 -pub struct SubAgentDef { - pub name: String, - pub description: String, - pub system_prompt: String, -} - -/// task 도구 인자 -#[derive(Debug, Deserialize)] -pub struct TaskArgs { - pub description: String, - pub subagent_type: String, -} - -/// task 도구 -pub struct TaskTool { - subagents: HashMap, - task_description: String, -} - -impl TaskTool { - pub fn new(subagents: Vec) -> Self { - let mut map = HashMap::new(); - let mut descriptions = Vec::new(); - - for sa in subagents { - descriptions.push(format!("- {}: {}", sa.name, sa.description)); - map.insert(sa.name.clone(), sa); - } - - let task_description = format!( - "Launch a subagent to handle complex, multi-step tasks.\n\n\ - Available subagent types:\n{}\n\n\ - Use subagent_type parameter to select the agent type.", - descriptions.join("\n") - ); - - Self { - subagents: map, - task_description, - } - } - - /// 범용 에이전트만 있는 기본 설정 - pub fn with_general_purpose() -> Self { - Self::new(vec![SubAgentDef { - name: "general-purpose".to_string(), - description: "General-purpose agent for complex tasks with 
access to all tools.".to_string(), - system_prompt: "You are a helpful assistant that completes tasks autonomously.".to_string(), - }]) - } -} - -#[async_trait] -impl Tool for TaskTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "task".to_string(), - description: self.task_description.clone(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "description": { - "type": "string", - "description": "Detailed description of the task for the subagent" - }, - "subagent_type": { - "type": "string", - "description": "Type of subagent to use" - } - }, - "required": ["description", "subagent_type"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let parsed: TaskArgs = serde_json::from_value(args)?; - - // 재귀 한도 확인 - if runtime.is_recursion_limit_exceeded() { - return Ok("Error: Maximum subagent recursion depth exceeded.".to_string()); - } - - // 서브에이전트 유효성 확인 - let subagent = match self.subagents.get(&parsed.subagent_type) { - Some(sa) => sa, - None => { - let valid_types: Vec<_> = self.subagents.keys().collect(); - return Ok(format!( - "Error: Unknown subagent type '{}'. 
Valid types: {:?}", - parsed.subagent_type, valid_types - )); - } - }; - - // 실제 서브에이전트 실행은 AgentExecutor에서 처리됨 - // 여기서는 메시지만 반환 - Ok(format!( - "[SubAgent '{}' would execute task: {}]\n\ - Note: Full subagent execution requires AgentExecutor integration.", - subagent.name, parsed.description - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use crate::state::AgentState; - - #[tokio::test] - async fn test_task_tool_validation() { - let tool = TaskTool::with_general_purpose(); - let state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(state, backend); - - // 유효한 타입 - let result = tool.execute( - serde_json::json!({ - "description": "Research topic X", - "subagent_type": "general-purpose" - }), - &runtime - ).await.unwrap(); - assert!(result.contains("SubAgent")); - - // 잘못된 타입 - let result = tool.execute( - serde_json::json!({ - "description": "Task", - "subagent_type": "invalid-type" - }), - &runtime - ).await.unwrap(); - assert!(result.contains("Error")); - } -} -``` - -**Step 2: tools/mod.rs 업데이트** - -```rust -// src/tools/mod.rs (업데이트) -pub mod write_todos; -pub mod filesystem; -pub mod task; - -pub use write_todos::WriteTodosTool; -pub use filesystem::{LsTool, ReadFileTool, WriteFileTool, EditFileTool, GlobTool, GrepTool}; -pub use task::{TaskTool, SubAgentDef}; -``` - -**Step 3: subagent.rs 미들웨어 구현** - -```rust -// src/middleware/subagent.rs -//! SubAgentMiddleware - 서브에이전트 지원 -//! -//! Python Reference: deepagents/middleware/subagents.py - -use async_trait::async_trait; -use std::sync::Arc; -use crate::state::AgentState; -use crate::error::MiddlewareError; -use crate::runtime::ToolRuntime; -use crate::tools::{TaskTool, SubAgentDef}; -use super::traits::{AgentMiddleware, DynTool, StateUpdate}; - -const SUBAGENT_SYSTEM_PROMPT: &str = r#"## `task` (subagent spawner) - -You have access to the `task` tool to spawn ephemeral subagents for isolated tasks. 
- -When to use task tool: -- Complex multi-step tasks that can be fully delegated -- Independent tasks that can run in parallel -- Tasks requiring intensive reasoning that would bloat the main thread - -When NOT to use task tool: -- Need to verify intermediate reasoning -- Trivial tasks (few tool calls) -- Delegation doesn't reduce complexity - -Subagent lifecycle: -1. Spawn → provide clear role and instructions -2. Run → subagent works autonomously -3. Return → subagent provides single result -4. Reconcile → integrate results"#; - -/// SubAgentMiddleware -pub struct SubAgentMiddleware { - tool: Arc, - system_prompt: String, -} - -impl SubAgentMiddleware { - pub fn new(subagents: Vec) -> Self { - Self { - tool: Arc::new(TaskTool::new(subagents)), - system_prompt: SUBAGENT_SYSTEM_PROMPT.to_string(), - } - } - - /// 범용 에이전트만 포함하는 기본 설정 - pub fn with_general_purpose() -> Self { - Self { - tool: Arc::new(TaskTool::with_general_purpose()), - system_prompt: SUBAGENT_SYSTEM_PROMPT.to_string(), - } - } -} - -#[async_trait] -impl AgentMiddleware for SubAgentMiddleware { - fn name(&self) -> &str { - "SubAgentMiddleware" - } - - fn tools(&self) -> Vec { - vec![self.tool.clone()] - } - - fn modify_system_prompt(&self, prompt: String) -> String { - format!("{}\n\n{}", prompt, self.system_prompt) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_subagent_middleware() { - let mw = SubAgentMiddleware::with_general_purpose(); - - let prompt = mw.modify_system_prompt("Base".to_string()); - assert!(prompt.contains("task")); - assert!(prompt.contains("subagent")); - - let tools = mw.tools(); - assert_eq!(tools.len(), 1); - assert_eq!(tools[0].definition().name, "task"); - } -} -``` - -**Step 4: Commit** - -```bash -git add -A && git commit -m "feat: implement SubAgentMiddleware with task tool" -``` - ---- - -## Phase 6: Agent Execution Loop - -### Task 6.1: AgentExecutor 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/executor.rs` - 
-이것은 가장 핵심적인 부분으로, LLM 호출과 도구 실행 루프를 구현합니다. - -**Step 1: executor.rs 구현** - -```rust -// src/executor.rs -//! Agent Execution Loop -//! -//! LLM 호출 및 도구 실행을 관리하는 핵심 실행기입니다. - -use std::collections::HashMap; -use std::sync::Arc; -use crate::state::{AgentState, Message, Role, ToolCall}; -use crate::error::DeepAgentError; -use crate::middleware::{MiddlewareStack, DynTool, Tool}; -use crate::runtime::{ToolRuntime, RuntimeConfig}; -use crate::backends::Backend; - -/// LLM 응답 타입 -pub struct LlmResponse { - pub content: String, - pub tool_calls: Vec, - pub finish_reason: String, -} - -/// LLM 인터페이스 트레이트 -#[async_trait::async_trait] -pub trait LlmClient: Send + Sync { - async fn chat( - &self, - messages: &[Message], - system_prompt: &str, - tools: &[serde_json::Value], - ) -> Result; -} - -/// Agent Executor 설정 -pub struct ExecutorConfig { - pub max_iterations: usize, - pub debug: bool, -} - -impl Default for ExecutorConfig { - fn default() -> Self { - Self { - max_iterations: 50, - debug: false, - } - } -} - -/// Agent Executor -/// -/// LLM 호출 및 도구 실행 루프를 관리합니다. 
-pub struct AgentExecutor { - llm: Arc, - middleware: MiddlewareStack, - backend: Arc, - config: ExecutorConfig, - tools: HashMap, -} - -impl AgentExecutor { - pub fn new( - llm: Arc, - middleware: MiddlewareStack, - backend: Arc, - ) -> Self { - Self::with_config(llm, middleware, backend, ExecutorConfig::default()) - } - - pub fn with_config( - llm: Arc, - middleware: MiddlewareStack, - backend: Arc, - config: ExecutorConfig, - ) -> Self { - // 모든 미들웨어에서 도구 수집 - let tools_list = middleware.collect_tools(); - let mut tools = HashMap::new(); - for tool in tools_list { - tools.insert(tool.definition().name.clone(), tool); - } - - Self { - llm, - middleware, - backend, - config, - tools, - } - } - - /// 에이전트 실행 - pub async fn run(&self, state: &mut AgentState) -> Result { - let runtime = ToolRuntime::new(state.clone(), self.backend.clone()) - .with_config(RuntimeConfig { - debug: self.config.debug, - max_recursion: 10, - current_recursion: 0, - }); - - // before_agent 훅 실행 - self.middleware.before_agent(state, &runtime).await - .map_err(|e| DeepAgentError::Middleware(e))?; - - // 시스템 프롬프트 빌드 - let system_prompt = self.middleware.build_system_prompt( - "You are a helpful assistant with access to various tools." 
- ); - - // 도구 스키마 준비 - let tool_schemas: Vec<_> = self.tools.values() - .map(|t| { - let def = t.definition(); - serde_json::json!({ - "type": "function", - "function": { - "name": def.name, - "description": def.description, - "parameters": def.parameters - } - }) - }) - .collect(); - - // 실행 루프 - for iteration in 0..self.config.max_iterations { - if self.config.debug { - eprintln!("[Executor] Iteration {}", iteration); - } - - // LLM 호출 - let response = self.llm.chat( - &state.messages, - &system_prompt, - &tool_schemas, - ).await?; - - // 어시스턴트 메시지 추가 - if response.tool_calls.is_empty() { - // 도구 호출 없음 - 최종 응답 - state.add_message(Message::assistant(&response.content)); - - // after_agent 훅 실행 - self.middleware.after_agent(state, &runtime).await - .map_err(|e| DeepAgentError::Middleware(e))?; - - return Ok(response.content); - } - - // 도구 호출이 있는 경우 - state.add_message(Message::assistant_with_tool_calls( - &response.content, - response.tool_calls.clone(), - )); - - // 각 도구 실행 - for tool_call in &response.tool_calls { - let result = self.execute_tool(tool_call, &runtime).await?; - state.add_message(Message::tool(&result, &tool_call.id)); - } - } - - Err(DeepAgentError::AgentExecution( - "Maximum iterations exceeded".to_string() - )) - } - - /// 단일 도구 실행 - async fn execute_tool( - &self, - tool_call: &ToolCall, - runtime: &ToolRuntime, - ) -> Result { - let tool = self.tools.get(&tool_call.name) - .ok_or_else(|| DeepAgentError::ToolNotFound(tool_call.name.clone()))?; - - if self.config.debug { - eprintln!("[Executor] Executing tool: {} with args: {}", - tool_call.name, tool_call.arguments); - } - - tool.execute(tool_call.arguments.clone(), runtime).await - .map_err(|e| DeepAgentError::Middleware(e)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use crate::middleware::{TodoListMiddleware, FilesystemMiddleware}; - - /// 테스트용 Mock LLM - struct MockLlm { - responses: Vec, - } - - impl MockLlm { - fn new(responses: Vec) -> 
Self { - Self { responses } - } - } - - #[async_trait::async_trait] - impl LlmClient for MockLlm { - async fn chat( - &self, - _messages: &[Message], - _system_prompt: &str, - _tools: &[serde_json::Value], - ) -> Result { - // 간단히 첫 번째 응답 반환 (실제로는 상태 관리 필요) - Ok(self.responses[0].clone()) - } - } - - impl Clone for LlmResponse { - fn clone(&self) -> Self { - Self { - content: self.content.clone(), - tool_calls: self.tool_calls.clone(), - finish_reason: self.finish_reason.clone(), - } - } - } - - #[tokio::test] - async fn test_executor_simple_response() { - let llm = Arc::new(MockLlm::new(vec![ - LlmResponse { - content: "Hello! I'm here to help.".to_string(), - tool_calls: vec![], - finish_reason: "stop".to_string(), - } - ])); - - let backend = Arc::new(MemoryBackend::new()); - let middleware = MiddlewareStack::new() - .add(TodoListMiddleware::new()) - .add(FilesystemMiddleware::new(backend.clone())); - - let executor = AgentExecutor::new(llm, middleware, backend); - - let mut state = AgentState::with_messages(vec![ - Message::user("Hello!"), - ]); - - let result = executor.run(&mut state).await.unwrap(); - assert_eq!(result, "Hello! 
I'm here to help."); - } - - #[tokio::test] - async fn test_executor_with_tool_call() { - let backend = Arc::new(MemoryBackend::new()); - - // 파일 미리 생성 - backend.write("/test.txt", "Hello from file!").await.unwrap(); - - let llm = Arc::new(MockLlm::new(vec![ - LlmResponse { - content: "".to_string(), - tool_calls: vec![ToolCall { - id: "call_1".to_string(), - name: "read_file".to_string(), - arguments: serde_json::json!({"path": "/test.txt"}), - }], - finish_reason: "tool_calls".to_string(), - } - ])); - - let middleware = MiddlewareStack::new() - .add(FilesystemMiddleware::new(backend.clone())); - - let executor = AgentExecutor::new(llm, middleware, backend); - - let mut state = AgentState::with_messages(vec![ - Message::user("Read the file"), - ]); - - // 첫 반복만 실행하면 도구 호출 후 다시 LLM 호출 필요 - // Mock이 단순하므로 에러 발생 예상 - let result = executor.run(&mut state).await; - // 도구는 실행되었을 것 - assert!(state.messages.len() > 1); - } -} -``` - -**Step 2: Commit** - -```bash -git add -A && git commit -m "feat: implement AgentExecutor with LLM and tool execution loop" -``` - ---- - -## Phase 7: OpenAI 통합 테스트 - -### Task 7.1: OpenAI LLM Client 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/tests/openai_client.rs` -- Create: `rust-research-agent/crates/rig-deepagents/tests/integration_openai.rs` - -**Step 1: openai_client.rs 구현** (테스트 헬퍼) - -```rust -// tests/openai_client.rs -//! 
OpenAI Client for integration tests - -use async_trait::async_trait; -use rig_deepagents::executor::{LlmClient, LlmResponse}; -use rig_deepagents::state::{Message, Role, ToolCall}; -use rig_deepagents::error::DeepAgentError; - -/// OpenAI API Client (테스트용) -pub struct OpenAiClient { - api_key: String, - model: String, -} - -impl OpenAiClient { - pub fn new(api_key: &str) -> Self { - Self { - api_key: api_key.to_string(), - model: "gpt-4o-mini".to_string(), - } - } - - pub fn with_model(api_key: &str, model: &str) -> Self { - Self { - api_key: api_key.to_string(), - model: model.to_string(), - } - } -} - -#[async_trait] -impl LlmClient for OpenAiClient { - async fn chat( - &self, - messages: &[Message], - system_prompt: &str, - tools: &[serde_json::Value], - ) -> Result { - let client = reqwest::Client::new(); - - // 메시지 변환 - let mut api_messages: Vec = vec![ - serde_json::json!({ - "role": "system", - "content": system_prompt - }) - ]; - - for msg in messages { - let role = match msg.role { - Role::User => "user", - Role::Assistant => "assistant", - Role::System => "system", - Role::Tool => "tool", - }; - - let mut api_msg = serde_json::json!({ - "role": role, - "content": msg.content - }); - - if let Some(ref id) = msg.tool_call_id { - api_msg["tool_call_id"] = serde_json::json!(id); - } - - if let Some(ref tcs) = msg.tool_calls { - let tool_calls: Vec<_> = tcs.iter().map(|tc| { - serde_json::json!({ - "id": tc.id, - "type": "function", - "function": { - "name": tc.name, - "arguments": tc.arguments.to_string() - } - }) - }).collect(); - api_msg["tool_calls"] = serde_json::json!(tool_calls); - } - - api_messages.push(api_msg); - } - - let mut body = serde_json::json!({ - "model": self.model, - "messages": api_messages, - }); - - if !tools.is_empty() { - body["tools"] = serde_json::json!(tools); - } - - let response = client - .post("https://api.openai.com/v1/chat/completions") - .header("Authorization", format!("Bearer {}", self.api_key)) - .header("Content-Type", 
"application/json") - .json(&body) - .send() - .await - .map_err(|e| DeepAgentError::LlmError(e.to_string()))?; - - let json: serde_json::Value = response.json().await - .map_err(|e| DeepAgentError::LlmError(e.to_string()))?; - - // 응답 파싱 - let choice = &json["choices"][0]; - let message = &choice["message"]; - - let content = message["content"].as_str().unwrap_or("").to_string(); - let finish_reason = choice["finish_reason"].as_str().unwrap_or("stop").to_string(); - - let tool_calls: Vec = message["tool_calls"] - .as_array() - .map(|arr| { - arr.iter().filter_map(|tc| { - Some(ToolCall { - id: tc["id"].as_str()?.to_string(), - name: tc["function"]["name"].as_str()?.to_string(), - arguments: serde_json::from_str( - tc["function"]["arguments"].as_str()? - ).ok()?, - }) - }).collect() - }) - .unwrap_or_default(); - - Ok(LlmResponse { - content, - tool_calls, - finish_reason, - }) - } -} -``` - -**Step 2: integration_openai.rs 구현** - -```rust -// tests/integration_openai.rs -//! OpenAI 통합 테스트 - -mod openai_client; - -use std::sync::Arc; -use std::time::Instant; -use rig_deepagents::middleware::*; -use rig_deepagents::backends::MemoryBackend; -use rig_deepagents::executor::{AgentExecutor, ExecutorConfig}; -use rig_deepagents::state::{AgentState, Message}; -use openai_client::OpenAiClient; - -fn get_openai_key() -> Option { - dotenv::dotenv().ok(); - std::env::var("OPENAI_API_KEY").ok() -} - -#[tokio::test] -#[ignore] // cargo test -- --ignored -async fn test_real_openai_simple() { - let api_key = match get_openai_key() { - Some(k) => k, - None => { - eprintln!("OPENAI_API_KEY not set, skipping test"); - return; - } - }; - - let llm = Arc::new(OpenAiClient::new(&api_key)); - let backend = Arc::new(MemoryBackend::new()); - - let middleware = MiddlewareStack::new() - .add(TodoListMiddleware::new()) - .add(FilesystemMiddleware::new(backend.clone())); - - let executor = AgentExecutor::with_config( - llm, - middleware, - backend, - ExecutorConfig { max_iterations: 10, debug: 
true }, - ); - - let mut state = AgentState::with_messages(vec![ - Message::user("Hello! Can you tell me what tools you have access to?"), - ]); - - let start = Instant::now(); - let result = executor.run(&mut state).await.unwrap(); - let elapsed = start.elapsed(); - - println!("Response: {}", result); - println!("Time: {:?}", elapsed); - - assert!(!result.is_empty()); -} - -#[tokio::test] -#[ignore] -async fn test_real_openai_with_tool_use() { - let api_key = match get_openai_key() { - Some(k) => k, - None => { - eprintln!("OPENAI_API_KEY not set, skipping test"); - return; - } - }; - - let llm = Arc::new(OpenAiClient::new(&api_key)); - let backend = Arc::new(MemoryBackend::new()); - - // 파일 미리 생성 - backend.write("/readme.txt", "# Welcome\nThis is a test file.").await.unwrap(); - - let middleware = MiddlewareStack::new() - .add(FilesystemMiddleware::new(backend.clone())); - - let executor = AgentExecutor::with_config( - llm, - middleware, - backend, - ExecutorConfig { max_iterations: 10, debug: true }, - ); - - let mut state = AgentState::with_messages(vec![ - Message::user("Please read the file /readme.txt and tell me what it contains."), - ]); - - let start = Instant::now(); - let result = executor.run(&mut state).await.unwrap(); - let elapsed = start.elapsed(); - - println!("Response: {}", result); - println!("Time: {:?}", elapsed); - - assert!(result.contains("Welcome") || result.contains("test")); -} - -#[tokio::test] -#[ignore] -async fn benchmark_middleware_stack() { - let backend = Arc::new(MemoryBackend::new()); - - let start = Instant::now(); - for _ in 0..1000 { - let _ = MiddlewareStack::new() - .add(TodoListMiddleware::new()) - .add(FilesystemMiddleware::new(backend.clone())) - .add(PatchToolCallsMiddleware::new()) - .add(SummarizationMiddleware::default()) - .add(SubAgentMiddleware::with_general_purpose()); - } - let elapsed = start.elapsed(); - - println!("1000 middleware stack creations: {:?}", elapsed); - println!("Average: {:?}", elapsed / 1000); 
- - assert!(elapsed.as_millis() < 100, "Should be < 100ms total"); -} - -#[tokio::test] -#[ignore] -async fn benchmark_memory_backend() { - let backend = MemoryBackend::new(); - - let start = Instant::now(); - for i in 0..1000 { - backend.write(&format!("/test/file{}.txt", i), &format!("Content {}", i)) - .await.unwrap(); - } - let write_time = start.elapsed(); - - let start = Instant::now(); - let files = backend.glob("**/*.txt", "/").await.unwrap(); - let glob_time = start.elapsed(); - - let start = Instant::now(); - let matches = backend.grep("Content 5", None, None).await.unwrap(); - let grep_time = start.elapsed(); - - println!("Write 1000 files: {:?}", write_time); - println!("Glob search: {:?} ({} files)", glob_time, files.len()); - println!("Grep search: {:?} ({} matches)", grep_time, matches.len()); - - assert_eq!(files.len(), 1000); -} -``` - -**Step 3: Cargo.toml에 reqwest 추가** - -```toml -# Cargo.toml [dev-dependencies]에 추가 -reqwest = { version = "0.11", features = ["json"] } -``` - -**Step 4: Commit** - -```bash -git add -A && git commit -m "test: add OpenAI integration tests" -``` - ---- - -## Phase 8: Criterion 벤치마크 - -### Task 8.1: Criterion 벤치마크 구현 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/benches/middleware_benchmark.rs` - -**Step 1: 벤치마크 구현** - -```rust -// benches/middleware_benchmark.rs -//! 
Criterion 벤치마크 - -use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; -use std::sync::Arc; -use tokio::runtime::Runtime; - -use rig_deepagents::middleware::*; -use rig_deepagents::backends::MemoryBackend; -use rig_deepagents::state::AgentState; -use rig_deepagents::runtime::ToolRuntime; - -fn bench_middleware_stack_creation(c: &mut Criterion) { - c.bench_function("middleware_stack_creation", |b| { - let backend = Arc::new(MemoryBackend::new()); - - b.iter(|| { - black_box( - MiddlewareStack::new() - .add(TodoListMiddleware::new()) - .add(FilesystemMiddleware::new(backend.clone())) - .add(PatchToolCallsMiddleware::new()) - .add(SummarizationMiddleware::default()) - .add(SubAgentMiddleware::with_general_purpose()) - ) - }); - }); -} - -fn bench_prompt_building(c: &mut Criterion) { - let backend = Arc::new(MemoryBackend::new()); - let stack = MiddlewareStack::new() - .add(TodoListMiddleware::new()) - .add(FilesystemMiddleware::new(backend.clone())) - .add(PatchToolCallsMiddleware::new()) - .add(SubAgentMiddleware::with_general_purpose()); - - c.bench_function("prompt_building", |b| { - b.iter(|| { - black_box(stack.build_system_prompt("You are a helpful assistant.")) - }); - }); -} - -fn bench_memory_backend_write(c: &mut Criterion) { - let rt = Runtime::new().unwrap(); - - c.bench_function("memory_backend_write_100", |b| { - b.iter(|| { - let backend = MemoryBackend::new(); - rt.block_on(async { - for i in 0..100 { - black_box( - backend.write(&format!("/file{}.txt", i), "content").await.unwrap() - ); - } - }); - }); - }); -} - -fn bench_memory_backend_glob(c: &mut Criterion) { - let rt = Runtime::new().unwrap(); - let backend = MemoryBackend::new(); - - // Setup: 1000 files - rt.block_on(async { - for i in 0..1000 { - backend.write(&format!("/test/file{}.txt", i), "content").await.unwrap(); - } - }); - - c.bench_function("memory_backend_glob_1000", |b| { - b.iter(|| { - rt.block_on(async { - black_box(backend.glob("**/*.txt", 
"/").await.unwrap()) - }) - }); - }); -} - -fn bench_memory_backend_grep(c: &mut Criterion) { - let rt = Runtime::new().unwrap(); - let backend = MemoryBackend::new(); - - // Setup - rt.block_on(async { - for i in 0..100 { - backend.write( - &format!("/src/file{}.rs", i), - &format!("fn main() {{\n println!(\"hello {}\");\n}}", i) - ).await.unwrap(); - } - }); - - c.bench_function("memory_backend_grep_100", |b| { - b.iter(|| { - rt.block_on(async { - black_box(backend.grep("println", None, Some("*.rs")).await.unwrap()) - }) - }); - }); -} - -fn bench_tool_collection(c: &mut Criterion) { - let backend = Arc::new(MemoryBackend::new()); - let stack = MiddlewareStack::new() - .add(TodoListMiddleware::new()) - .add(FilesystemMiddleware::new(backend.clone())) - .add(SubAgentMiddleware::with_general_purpose()); - - c.bench_function("tool_collection", |b| { - b.iter(|| { - black_box(stack.collect_tools()) - }); - }); -} - -criterion_group!( - benches, - bench_middleware_stack_creation, - bench_prompt_building, - bench_memory_backend_write, - bench_memory_backend_glob, - bench_memory_backend_grep, - bench_tool_collection, -); - -criterion_main!(benches); -``` - -**Step 2: 벤치마크 실행** - -Run: `cargo bench` -Expected: Benchmark results with statistical analysis - -**Step 3: Commit** - -```bash -git add -A && git commit -m "bench: add Criterion benchmarks for performance validation" -``` - ---- - -## Summary - -이 계획은 LangChain DeepAgents의 **전체 기능**을 Rust/Rig로 구현합니다: - -### Phases Overview - -| Phase | 내용 | 주요 파일 | -|-------|------|----------| -| 1 | 프로젝트 초기화 | `Cargo.toml`, `lib.rs` | -| 2 | 에러 타입 및 상태 | `error.rs`, `state.rs` | -| 3 | Backend 트레이트 | `protocol.rs`, `memory.rs`, `filesystem.rs`, `composite.rs` | -| 4 | ToolRuntime & Middleware | `runtime.rs`, `traits.rs`, `stack.rs` | -| 5 | 미들웨어 구현 | `todo.rs`, `filesystem.rs`, `patch_tool_calls.rs`, `summarization.rs`, `subagent.rs` | -| 6 | Agent Executor | `executor.rs` | -| 7 | OpenAI 통합 | `integration_openai.rs` | -| 8 | 
Criterion 벤치마크 | `middleware_benchmark.rs` | - -### 이전 계획 대비 주요 개선사항 - -1. ✅ **누락된 미들웨어 추가**: SubAgentMiddleware, SummarizationMiddleware, PatchToolCallsMiddleware -2. ✅ **실제 도구 구현**: 각 미들웨어에 실제 동작하는 도구 포함 -3. ✅ **Agent Execution Loop**: LlmClient 트레이트 + AgentExecutor -4. ✅ **FilesystemBackend & CompositeBackend**: 경로 기반 라우팅 -5. ✅ **ToolRuntime**: 도구 실행 컨텍스트 -6. ✅ **HashMap import 수정**: error.rs -7. ✅ **FileData 통합**: state.rs에서 정의 -8. ✅ **grep 프롬프트 수정**: 리터럴 검색 명시 -9. ✅ **Criterion 벤치마크**: 실제 코드 포함 - -### Python Reference Files - -- `deepagents/graph.py` - create_deep_agent -- `langchain/agents/middleware/types.py` - AgentMiddleware -- `deepagents/backends/protocol.py` - BackendProtocol -- `deepagents/backends/composite.py` - CompositeBackend -- `deepagents/backends/filesystem.py` - FilesystemBackend -- `deepagents/middleware/filesystem.py` - FilesystemMiddleware -- `deepagents/middleware/subagents.py` - SubAgentMiddleware -- `deepagents/middleware/patch_tool_calls.py` - PatchToolCallsMiddleware -- `langchain/agents/middleware/todo.py` - TodoListMiddleware -- `langchain/agents/middleware/summarization.py` - SummarizationMiddleware diff --git a/docs/plans/2026-01-01-rig-deepagents-fixes.md b/docs/plans/2026-01-01-rig-deepagents-fixes.md deleted file mode 100644 index 67d8946..0000000 --- a/docs/plans/2026-01-01-rig-deepagents-fixes.md +++ /dev/null @@ -1,804 +0,0 @@ -# Rig-DeepAgents Code Review 수정 계획 - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** Codex 및 Code Review에서 식별된 HIGH/MEDIUM 이슈를 수정하여 Python DeepAgents와의 패리티를 확보하고 프로덕션 준비 상태를 달성한다. 
- -**Architecture:** -- MemoryBackend와 AgentState.files의 이중 소스 문제는 `files_update` 반환 패턴을 통해 해결 (Python 패턴 유지) -- 경로 처리는 중앙화된 `normalize_path` 유틸리티로 통일 -- CompositeBackend의 glob/write/edit 결과 집계 및 경로 복원 로직 보강 - -**Tech Stack:** Rust 1.75+, tokio, async-trait, glob crate - -**검증 피드백 출처:** -- Codex CLI (gpt-5.2-codex): 55,431 토큰 분석 -- Code Reviewer Subagent: 독립 검증 - ---- - -## Phase 1: HIGH Severity 수정 (필수) - -### Task 1.1: MemoryBackend glob() base_path 필터 적용 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/memory.rs:191-213` -- Test: 동일 파일 하단 tests 모듈 - -**문제:** `glob()`이 `base_path`를 검증만 하고 실제 필터링에 사용하지 않아 모든 파일이 검색됨 - -**Step 1: 실패하는 테스트 작성** - -`memory.rs` tests 모듈에 추가: - -```rust -#[tokio::test] -async fn test_memory_backend_glob_respects_base_path() { - let backend = MemoryBackend::new(); - backend.write("/src/main.rs", "fn main()").await.unwrap(); - backend.write("/src/lib.rs", "pub mod").await.unwrap(); - backend.write("/docs/readme.md", "# Readme").await.unwrap(); - backend.write("/tests/test.rs", "test code").await.unwrap(); - - // /src 하위에서만 검색해야 함 - let files = backend.glob("**/*.rs", "/src").await.unwrap(); - - // /src 하위의 .rs 파일만 포함되어야 함 - assert_eq!(files.len(), 2); - assert!(files.iter().all(|f| f.path.starts_with("/src"))); - - // /tests/test.rs는 포함되면 안 됨 - assert!(!files.iter().any(|f| f.path.contains("/tests/"))); -} -``` - -**Step 2: 테스트 실패 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_memory_backend_glob_respects_base_path` -Expected: FAIL - 현재 구현은 모든 .rs 파일(4개)을 반환 - -**Step 3: glob() 구현 수정** - -`memory.rs` 191-213 라인을 다음으로 교체: - -```rust -async fn glob(&self, pattern: &str, base_path: &str) -> Result, BackendError> { - let base = Self::validate_path(base_path)?; - let files = self.files.read().await; - - let glob_pattern = Pattern::new(pattern) - .map_err(|e| BackendError::Pattern(e.to_string()))?; - - let mut results = Vec::new(); - for (file_path, data) in 
files.iter() { - // base_path 하위 파일만 검색 - let normalized_base = base.trim_end_matches('/'); - if !file_path.starts_with(normalized_base) { - continue; - } - // base_path와 정확히 같거나 base_path/ 로 시작해야 함 - if file_path.len() > normalized_base.len() - && !file_path[normalized_base.len()..].starts_with('/') { - continue; - } - - let match_path = file_path.trim_start_matches('/'); - if glob_pattern.matches(match_path) { - let size = data.content.iter().map(|s| s.len()).sum::() as u64; - results.push(FileInfo::file_with_time( - file_path, - size, - &data.modified_at, - )); - } - } - - results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(results) -} -``` - -**Step 4: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_memory_backend_glob_respects_base_path` -Expected: PASS - -**Step 5: 전체 테스트 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test` -Expected: 모든 테스트 PASS - ---- - -### Task 1.2: CompositeBackend files_update 키 경로 복원 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/composite.rs:107-133` -- Test: 동일 파일 하단 tests 모듈 - -**문제:** routed backend에서 반환된 `files_update`의 키가 stripped 경로로 남아있어 상태 업데이트 시 경로 오류 발생 - -**Step 1: 실패하는 테스트 작성** - -`composite.rs` tests 모듈에 추가: - -```rust -#[tokio::test] -async fn test_composite_backend_write_files_update_path_restoration() { - let default = Arc::new(MemoryBackend::new()); - let memories = Arc::new(MemoryBackend::new()); - - let composite = CompositeBackend::new(default.clone()) - .with_route("/memories/", memories.clone()); - - // routed backend에 파일 쓰기 - let result = composite.write("/memories/notes.txt", "my notes").await.unwrap(); - - assert!(result.is_ok()); - assert_eq!(result.path, Some("/memories/notes.txt".to_string())); - - // files_update가 있다면 키도 복원되어야 함 - if let Some(files_update) = &result.files_update { - // 키가 /memories/notes.txt 여야 함 (stripped된 /notes.txt가 아님) - 
assert!(files_update.contains_key("/memories/notes.txt"), - "files_update key should be '/memories/notes.txt', got keys: {:?}", - files_update.keys().collect::>()); - } -} - -#[tokio::test] -async fn test_composite_backend_edit_files_update_path_restoration() { - let default = Arc::new(MemoryBackend::new()); - let memories = Arc::new(MemoryBackend::new()); - - let composite = CompositeBackend::new(default.clone()) - .with_route("/memories/", memories.clone()); - - // 먼저 파일 생성 - composite.write("/memories/notes.txt", "hello world").await.unwrap(); - - // 편집 - let result = composite.edit("/memories/notes.txt", "hello", "goodbye", false).await.unwrap(); - - assert!(result.is_ok()); - - // files_update가 있다면 키도 복원되어야 함 - if let Some(files_update) = &result.files_update { - assert!(files_update.contains_key("/memories/notes.txt"), - "files_update key should be '/memories/notes.txt', got keys: {:?}", - files_update.keys().collect::>()); - } -} -``` - -**Step 2: 테스트 실패 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_composite_backend_write_files_update` -Expected: FAIL - 키가 `/notes.txt`로 되어있음 - -**Step 3: write() 및 edit() 구현 수정** - -`composite.rs`의 write() 메서드 (107-116 라인)를 다음으로 교체: - -```rust -async fn write(&self, path: &str, content: &str) -> Result { - let (backend, stripped) = self.get_backend_and_path(path); - let mut result = backend.write(&stripped, content).await?; - - // 경로 복원 - if result.path.is_some() { - result.path = Some(path.to_string()); - } - - // files_update 키도 복원 - if let Some(ref mut files_update) = result.files_update { - let restored: std::collections::HashMap = files_update - .drain() - .map(|(k, v)| (self.restore_prefix(&k, path), v)) - .collect(); - *files_update = restored; - } - - Ok(result) -} -``` - -`composite.rs`의 edit() 메서드 (118-133 라인)를 다음으로 교체: - -```rust -async fn edit( - &self, - path: &str, - old_string: &str, - new_string: &str, - replace_all: bool -) -> Result { - let (backend, 
stripped) = self.get_backend_and_path(path); - let mut result = backend.edit(&stripped, old_string, new_string, replace_all).await?; - - // 경로 복원 - if result.path.is_some() { - result.path = Some(path.to_string()); - } - - // files_update 키도 복원 - if let Some(ref mut files_update) = result.files_update { - let restored: std::collections::HashMap = files_update - .drain() - .map(|(k, v)| (self.restore_prefix(&k, path), v)) - .collect(); - *files_update = restored; - } - - Ok(result) -} -``` - -**Step 4: 상단에 import 추가 확인** - -`composite.rs` 상단에 `use crate::state::FileData;` 가 있는지 확인. 없으면 추가. - -**Step 5: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_composite_backend` -Expected: 모든 composite 테스트 PASS - ---- - -### Task 1.3: CompositeBackend glob() 집계 로직 추가 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/composite.rs:135-144` -- Test: 동일 파일 하단 tests 모듈 - -**문제:** glob()이 단일 백엔드만 쿼리하여 Python처럼 모든 백엔드 결과를 집계하지 않음 - -**Step 1: 실패하는 테스트 작성** - -```rust -#[tokio::test] -async fn test_composite_backend_glob_aggregates_all_backends() { - let default = Arc::new(MemoryBackend::new()); - let docs = Arc::new(MemoryBackend::new()); - - let composite = CompositeBackend::new(default.clone()) - .with_route("/docs/", docs.clone()); - - // 각 백엔드에 파일 생성 - composite.write("/src/main.rs", "fn main()").await.unwrap(); - composite.write("/docs/guide.md", "# Guide").await.unwrap(); - composite.write("/docs/api.md", "# API").await.unwrap(); - - // 루트에서 모든 .md 파일 검색 - 모든 백엔드에서 집계해야 함 - let files = composite.glob("**/*.md", "/").await.unwrap(); - - // docs 백엔드의 2개 파일이 모두 포함되어야 함 - assert_eq!(files.len(), 2, "Expected 2 .md files, got: {:?}", files); - assert!(files.iter().any(|f| f.path.contains("guide.md"))); - assert!(files.iter().any(|f| f.path.contains("api.md"))); -} -``` - -**Step 2: 테스트 실패 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test 
test_composite_backend_glob_aggregates` -Expected: FAIL - 현재는 default 백엔드만 검색 - -**Step 3: glob() 구현 수정** - -`composite.rs`의 glob() 메서드 (135-144 라인)를 다음으로 교체: - -```rust -async fn glob(&self, pattern: &str, base_path: &str) -> Result, BackendError> { - // 특정 라우트 경로인 경우 해당 백엔드만 검색 - for route in &self.routes { - let route_prefix = route.prefix.trim_end_matches('/'); - if base_path.starts_with(route_prefix) && - (base_path.len() == route_prefix.len() || base_path[route_prefix.len()..].starts_with('/')) { - let (backend, stripped) = self.get_backend_and_path(base_path); - let mut results = backend.glob(pattern, &stripped).await?; - - for info in &mut results { - info.path = self.restore_prefix(&info.path, base_path); - } - return Ok(results); - } - } - - // 루트 또는 라우트되지 않은 경로 - 모든 백엔드에서 집계 - let mut all_results = self.default.glob(pattern, base_path).await?; - - for route in &self.routes { - let mut route_results = route.backend.glob(pattern, "/").await?; - for info in &mut route_results { - let prefix = route.prefix.trim_end_matches('/'); - info.path = format!("{}{}", prefix, info.path); - } - all_results.extend(route_results); - } - - all_results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(all_results) -} -``` - -**Step 4: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_composite_backend_glob` -Expected: PASS - ---- - -### Task 1.4: AgentState.clone() extensions 경고 로그 추가 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/state.rs:169-180` -- Modify: `rust-research-agent/crates/rig-deepagents/Cargo.toml` (tracing 의존성 확인) - -**문제:** clone() 시 extensions가 빈 HashMap으로 초기화되어 미들웨어 상태 손실 - -**Step 1: 현재 Clone 구현 확인** - -`state.rs`의 Clone 구현 확인: - -```rust -impl Clone for AgentState { - fn clone(&self) -> Self { - Self { - messages: self.messages.clone(), - todos: self.todos.clone(), - files: self.files.clone(), - structured_response: self.structured_response.clone(), - extensions: HashMap::new(), 
// 데이터 손실! - } - } -} -``` - -**Step 2: tracing import 추가** - -`state.rs` 상단에 추가: - -```rust -use tracing::warn; -``` - -**Step 3: Clone 구현에 경고 로그 추가** - -`state.rs`의 Clone 구현을 다음으로 교체: - -```rust -impl Clone for AgentState { - fn clone(&self) -> Self { - // extensions가 비어있지 않으면 경고 - if !self.extensions.is_empty() { - warn!( - extension_count = self.extensions.len(), - extension_keys = ?self.extensions.keys().collect::>(), - "AgentState.clone() called with non-empty extensions - extensions will be lost" - ); - } - - Self { - messages: self.messages.clone(), - todos: self.todos.clone(), - files: self.files.clone(), - structured_response: self.structured_response.clone(), - // extensions는 Box를 clone할 수 없어서 빈 상태로 시작 - // 향후 Arc> 패턴으로 개선 고려 - extensions: HashMap::new(), - } - } -} -``` - -**Step 4: 컴파일 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo check` -Expected: 성공 (경고만 있을 수 있음) - ---- - -## Phase 2: MEDIUM Severity 수정 - -### Task 2.1: 경로 정규화 유틸리티 추가 및 적용 - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/backends/path_utils.rs` -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/mod.rs` -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/memory.rs` - -**Step 1: path_utils.rs 생성** - -```rust -// src/backends/path_utils.rs -//! 경로 정규화 유틸리티 -//! -//! 
모든 백엔드에서 일관된 경로 처리를 위한 헬퍼 함수들 - -use crate::error::BackendError; - -/// 경로 정규화 -/// - 앞에 `/` 추가 -/// - 연속된 슬래시 제거 (`//` -> `/`) -/// - 후행 슬래시 제거 (루트 제외) -/// - 경로 순회 공격 방지 -pub fn normalize_path(path: &str) -> Result { - // 경로 순회 공격 방지 - if path.contains("..") || path.starts_with("~") { - return Err(BackendError::PathTraversal(path.to_string())); - } - - // 빈 경로는 루트로 - if path.is_empty() { - return Ok("/".to_string()); - } - - // 연속된 슬래시 제거 - let parts: Vec<&str> = path.split('/') - .filter(|p| !p.is_empty()) - .collect(); - - if parts.is_empty() { - return Ok("/".to_string()); - } - - Ok(format!("/{}", parts.join("/"))) -} - -/// 경로가 base_path 하위에 있는지 확인 -/// `/dir`은 `/dir2`와 매칭되지 않음 (정확한 디렉토리 경계 확인) -pub fn is_under_path(path: &str, base_path: &str) -> bool { - let normalized_base = base_path.trim_end_matches('/'); - - if normalized_base.is_empty() || normalized_base == "/" { - return true; // 루트는 모든 경로 포함 - } - - if path == normalized_base { - return true; // 정확히 같은 경로 - } - - // base_path + "/" 로 시작해야 함 - path.starts_with(&format!("{}/", normalized_base)) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_normalize_path_basic() { - assert_eq!(normalize_path("/test.txt").unwrap(), "/test.txt"); - assert_eq!(normalize_path("test.txt").unwrap(), "/test.txt"); - assert_eq!(normalize_path("/dir/file.txt").unwrap(), "/dir/file.txt"); - } - - #[test] - fn test_normalize_path_double_slashes() { - assert_eq!(normalize_path("/dir//file.txt").unwrap(), "/dir/file.txt"); - assert_eq!(normalize_path("//dir///file.txt").unwrap(), "/dir/file.txt"); - } - - #[test] - fn test_normalize_path_trailing_slash() { - assert_eq!(normalize_path("/dir/").unwrap(), "/dir"); - assert_eq!(normalize_path("/").unwrap(), "/"); - } - - #[test] - fn test_normalize_path_traversal_attack() { - assert!(normalize_path("../etc/passwd").is_err()); - assert!(normalize_path("/dir/../etc/passwd").is_err()); - assert!(normalize_path("~/.ssh/id_rsa").is_err()); - } - - #[test] - fn 
test_is_under_path() { - assert!(is_under_path("/dir/file.txt", "/dir")); - assert!(is_under_path("/dir/sub/file.txt", "/dir")); - assert!(is_under_path("/dir", "/dir")); - assert!(is_under_path("/anything", "/")); - - // /dir 은 /dir2 에 포함되지 않음 - assert!(!is_under_path("/dir2/file.txt", "/dir")); - assert!(!is_under_path("/directory/file.txt", "/dir")); - } -} -``` - -**Step 2: mod.rs에 모듈 추가** - -`backends/mod.rs` 수정: - -```rust -// src/backends/mod.rs -//! 백엔드 모듈 -//! -//! 파일시스템 추상화를 제공합니다. - -pub mod protocol; -pub mod memory; -pub mod filesystem; -pub mod composite; -pub mod path_utils; - -pub use protocol::{Backend, FileInfo, GrepMatch}; -pub use memory::MemoryBackend; -pub use filesystem::FilesystemBackend; -pub use composite::CompositeBackend; -pub use path_utils::{normalize_path, is_under_path}; -``` - -**Step 3: MemoryBackend에서 path_utils 사용** - -`memory.rs` 상단 import 추가: - -```rust -use super::path_utils::{normalize_path, is_under_path}; -``` - -`memory.rs`의 `validate_path` 함수를 제거하고 `normalize_path` 사용으로 교체: - -기존: -```rust -fn validate_path(path: &str) -> Result { ... } -``` - -모든 `Self::validate_path(path)?` 호출을 `normalize_path(path)?`로 교체. 
- -**Step 4: 테스트 실행** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test path_utils` -Expected: PASS - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test` -Expected: 모든 테스트 PASS - ---- - -### Task 2.2: CompositeBackend 라우트 매칭 후행 슬래시 정규화 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/composite.rs:48-61` - -**문제:** `/memories` 경로가 `/memories/` 라우트와 매칭되지 않음 - -**Step 1: 실패하는 테스트 작성** - -```rust -#[tokio::test] -async fn test_composite_backend_route_matching_without_trailing_slash() { - let default = Arc::new(MemoryBackend::new()); - let memories = Arc::new(MemoryBackend::new()); - - let composite = CompositeBackend::new(default.clone()) - .with_route("/memories/", memories.clone()); - - // 후행 슬래시 없이도 라우트되어야 함 - composite.write("/memories/notes.txt", "my notes").await.unwrap(); - - // /memories 경로로 읽기 (후행 슬래시 없음) - let files = composite.ls("/memories").await.unwrap(); - assert!(!files.is_empty(), "Should find files under /memories route"); -} -``` - -**Step 2: get_backend_and_path() 수정** - -`composite.rs`의 `get_backend_and_path` 메서드를 다음으로 교체: - -```rust -/// 경로에 맞는 백엔드와 변환된 경로 반환 -fn get_backend_and_path(&self, path: &str) -> (Arc, String) { - // 경로 정규화 (후행 슬래시 제거) - let normalized_path = path.trim_end_matches('/'); - - for route in &self.routes { - let route_prefix = route.prefix.trim_end_matches('/'); - - // 정확히 일치하거나 route_prefix/ 로 시작하는 경우 - if normalized_path == route_prefix || - normalized_path.starts_with(&format!("{}/", route_prefix)) { - let suffix = if normalized_path == route_prefix { - "" - } else { - &normalized_path[route_prefix.len()..] 
- }; - - let stripped = if suffix.is_empty() { - "/".to_string() - } else { - suffix.to_string() - }; - return (route.backend.clone(), stripped); - } - } - (self.default.clone(), path.to_string()) -} -``` - -**Step 3: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_composite_backend_route_matching` -Expected: PASS - ---- - -### Task 2.3: FilesystemBackend grep에서 tokio::fs 사용 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/filesystem.rs:283-300` - -**문제:** async 함수 내에서 sync `std::fs::read_to_string` 사용으로 런타임 블로킹 - -**Step 1: grep() 내부 파일 읽기를 async로 변경** - -`filesystem.rs`의 grep 메서드에서 파일 읽기 부분을 수정. - -기존: -```rust -let content = match std::fs::read_to_string(entry.path()) { - Ok(c) => c, - Err(_) => continue, -}; -``` - -수정: -```rust -let content = match fs::read_to_string(entry.path()).await { - Ok(c) => c, - Err(e) => { - tracing::debug!(path = ?entry.path(), error = %e, "Skipping file in grep due to read error"); - continue; - } -}; -``` - -**Step 2: filesystem.rs 상단에 tracing import 추가** - -```rust -use tracing; -``` - -**Step 3: 컴파일 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo check` -Expected: 성공 - ---- - -### Task 2.4: max_recursion 설정 가능하게 변경 - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/runtime.rs:41-48` - -**Step 1: RuntimeConfig::new()에 기본값을 Python과 동일하게 수정** - -```rust -impl RuntimeConfig { - pub fn new() -> Self { - Self { - debug: false, - max_recursion: 100, // Python 기본값에 가깝게 조정 (1000은 너무 높음) - current_recursion: 0, - } - } - - /// 커스텀 재귀 제한으로 생성 - pub fn with_max_recursion(max_recursion: usize) -> Self { - Self { - debug: false, - max_recursion, - current_recursion: 0, - } - } -} -``` - -**Step 2: 테스트 수정** - -`runtime.rs`의 기존 테스트 `test_recursion_limit`를 수정: - -```rust -#[test] -fn test_recursion_limit() { - let state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - - 
// 커스텀 재귀 제한 사용 - let config = RuntimeConfig::with_max_recursion(10); - let mut runtime = ToolRuntime::new(state, backend).with_config(config); - - for _ in 0..10 { - runtime = runtime.with_increased_recursion(); - } - - assert!(runtime.is_recursion_limit_exceeded()); -} - -#[test] -fn test_default_recursion_limit() { - let state = AgentState::new(); - let backend = Arc::new(MemoryBackend::new()); - let runtime = ToolRuntime::new(state, backend); - - // 기본 제한은 100 - assert_eq!(runtime.config().max_recursion, 100); -} -``` - -**Step 3: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test runtime` -Expected: PASS - ---- - -## Phase 3: 최종 검증 - -### Task 3.1: 전체 테스트 실행 및 검증 - -**Step 1: 전체 테스트 실행** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test` -Expected: 모든 테스트 PASS (30개 이상) - -**Step 2: Clippy 린트 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo clippy -- -D warnings` -Expected: 경고 없음 (또는 허용 가능한 경고만) - -**Step 3: 문서 주석 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo doc --no-deps` -Expected: 문서 생성 성공 - ---- - -### Task 3.2: 변경사항 커밋 - -**Step 1: 변경 파일 확인** - -Run: `git status` -Expected: 수정된 Rust 파일들 표시 - -**Step 2: 커밋** - -```bash -git add rust-research-agent/crates/rig-deepagents/ -git commit -m "fix: address HIGH and MEDIUM severity issues from Codex/Code Review - -HIGH fixes: -- glob() now respects base_path filtering (security fix) -- CompositeBackend write/edit restores files_update key paths -- CompositeBackend glob aggregates results from all backends -- AgentState.clone() logs warning when extensions lost - -MEDIUM fixes: -- Add centralized path_utils for consistent path normalization -- CompositeBackend route matching normalizes trailing slashes -- FilesystemBackend grep uses async tokio::fs -- RuntimeConfig.max_recursion increased to 100 (was 10) - -Verified by: Codex CLI 
(gpt-5.2-codex), Code Reviewer subagent" -``` - ---- - -## 수정 우선순위 요약 - -| 우선순위 | Task | 예상 시간 | 위험도 | -|---------|------|----------|--------| -| 1 | Task 1.1: glob base_path 필터 | 5분 | Low | -| 2 | Task 1.2: files_update 키 복원 | 5분 | Low | -| 3 | Task 1.3: glob 집계 로직 | 10분 | Medium | -| 4 | Task 1.4: clone extensions 경고 | 5분 | Low | -| 5 | Task 2.1: path_utils 유틸리티 | 15분 | Low | -| 6 | Task 2.2: 라우트 매칭 정규화 | 5분 | Low | -| 7 | Task 2.3: async grep | 5분 | Low | -| 8 | Task 2.4: max_recursion 설정 | 5분 | Low | -| 9 | Task 3.1: 최종 검증 | 5분 | N/A | -| 10 | Task 3.2: 커밋 | 2분 | N/A | - -**총 예상 시간:** 약 60분 diff --git a/docs/plans/2026-01-01-rig-deepagents-phase5-6.md b/docs/plans/2026-01-01-rig-deepagents-phase5-6.md deleted file mode 100644 index 345ebae..0000000 --- a/docs/plans/2026-01-01-rig-deepagents-phase5-6.md +++ /dev/null @@ -1,1526 +0,0 @@ -# Rig-DeepAgents Phase 5-6 Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task. - -**Goal:** Complete the Rig-DeepAgents Rust port by fixing security issues, aligning backend behaviors with Python, implementing all tools, and building the executor loop for full feature parity. 
- -**Architecture:** -- Phase 1 fixes security vulnerabilities identified by Codex CLI (symlink traversal) -- Phase 2 aligns backend behaviors with Python reference (ls boundary, glob patterns, grep) -- Phase 3 implements tool layer connecting Backend trait methods to AgentMiddleware -- Phase 4 builds executor loop with LLM integration via Rig framework -- Phase 5 adds SubAgent system for task delegation - -**Tech Stack:** Rust 1.75+, tokio, async-trait, rig-core, serde_json, tracing - -**Verification Sources:** -- Codex CLI (gpt-5.2-codex): 164,210 tokens of analysis -- Qwen CLI: SubAgent/Tools gap analysis -- Python Reference: `deepagents_sourcecode/libs/deepagents/deepagents/backends/` - ---- - -## Phase 1: Security Fixes (HIGH Priority) - -### Task 1.1: Fix Symlink Path Traversal in FilesystemBackend - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/filesystem.rs:42-66` -- Test: 동일 파일 하단 tests 모듈 - -**문제:** `resolve_path`가 존재하지 않는 경로에 대해 canonicalize 없이 `target`을 반환하여, 부모 디렉토리가 루트 외부로의 심볼릭 링크인 경우 경로 탈출이 가능함. 
- -**Step 1: 실패하는 테스트 작성** - -`filesystem.rs` tests 모듈에 추가: - -```rust -#[tokio::test] -async fn test_filesystem_backend_symlink_traversal_prevention() { - use std::os::unix::fs::symlink; - use tempfile::tempdir; - - // 테스트용 디렉토리 생성 - let root = tempdir().unwrap(); - let outside = tempdir().unwrap(); - - // 루트 외부에 파일 생성 - let outside_file = outside.path().join("secret.txt"); - std::fs::write(&outside_file, "secret data").unwrap(); - - // 루트 내부에 외부를 가리키는 심볼릭 링크 생성 - let symlink_path = root.path().join("escape"); - symlink(outside.path(), &symlink_path).unwrap(); - - let backend = FilesystemBackend::new(root.path()); - - // 심볼릭 링크를 통한 접근 시도 - 차단되어야 함 - let result = backend.read("/escape/secret.txt", 0, 100).await; - assert!(result.is_err(), "Should block symlink traversal"); - - // 심볼릭 링크를 통한 쓰기 시도 - 차단되어야 함 - let result = backend.write("/escape/malicious.txt", "pwned").await; - assert!(result.is_err() || result.unwrap().is_err(), "Should block write via symlink"); -} -``` - -**Step 2: 테스트 실패 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_filesystem_backend_symlink_traversal` -Expected: FAIL - 현재는 symlink를 통한 접근이 허용됨 - -**Step 3: Cargo.toml에 tempfile 의존성 추가** - -`Cargo.toml`의 `[dev-dependencies]` 섹션에: - -```toml -[dev-dependencies] -tempfile = "3" -``` - -**Step 4: resolve_path 구현 수정** - -`filesystem.rs`의 `resolve_path` 메서드를 다음으로 교체: - -```rust -/// 경로 검증 및 해결 -/// Security: 심볼릭 링크를 통한 루트 탈출 방지 -fn resolve_path(&self, path: &str) -> Result { - if self.virtual_mode { - // 경로 탐색 방지 - if path.contains("..") || path.starts_with("~") { - return Err(BackendError::PathTraversal(path.to_string())); - } - - let clean_path = path.trim_start_matches('/'); - let target = self.root.join(clean_path); - - // 루트를 canonicalize - let canonical_root = self.root.canonicalize() - .unwrap_or_else(|_| self.root.clone()); - - // 부모 디렉토리가 존재하면 canonicalize하여 symlink 해석 - if let Some(parent) = target.parent() { - if parent.exists() { - let 
canonical_parent = parent.canonicalize() - .map_err(|e| BackendError::IoError(e.to_string()))?; - - // 부모가 루트 외부이면 차단 - if !canonical_parent.starts_with(&canonical_root) { - return Err(BackendError::PathTraversal( - format!("Symlink escape detected: {}", path) - )); - } - } - } - - // 존재하는 경로는 canonicalize해서 최종 확인 - if target.exists() { - let resolved = target.canonicalize() - .map_err(|e| BackendError::IoError(e.to_string()))?; - - if !resolved.starts_with(&canonical_root) { - return Err(BackendError::PathTraversal(path.to_string())); - } - } - - Ok(target) - } else { - Ok(PathBuf::from(path)) - } -} -``` - -**Step 5: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_filesystem_backend_symlink` -Expected: PASS - -**Step 6: 전체 테스트 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test` -Expected: 모든 테스트 PASS - ---- - -## Phase 2: Backend Behavioral Fixes (MEDIUM Priority) - -### Task 2.1: Fix MemoryBackend::ls Boundary Check - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/memory.rs:61-94` -- Test: 동일 파일 하단 tests 모듈 - -**문제:** `starts_with` 검사가 경계를 확인하지 않아 `/dir`이 `/directory`의 파일도 매칭함 - -**Step 1: 실패하는 테스트 작성** - -```rust -#[tokio::test] -async fn test_memory_backend_ls_boundary_check() { - let backend = MemoryBackend::new(); - backend.write("/dir/file.txt", "in dir").await.unwrap(); - backend.write("/directory/other.txt", "in directory").await.unwrap(); - - // /dir 에서 ls 하면 /directory 파일은 보이지 않아야 함 - let files = backend.ls("/dir").await.unwrap(); - - assert_eq!(files.len(), 1, "Should only find files under /dir"); - assert!(files[0].path.contains("/dir/"), "File should be under /dir"); - assert!(!files.iter().any(|f| f.path.contains("/directory")), - "Should not include /directory files"); -} -``` - -**Step 2: 테스트 실패 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test 
test_memory_backend_ls_boundary_check` -Expected: FAIL - -**Step 3: ls() 구현 수정** - -`memory.rs`의 ls() 메서드에서 매칭 로직 수정: - -```rust -async fn ls(&self, path: &str) -> Result, BackendError> { - let path = normalize_path(path)?; - let files = self.files.read().await; - - let normalized_prefix = if path == "/" { - "/".to_string() - } else { - format!("{}/", path.trim_end_matches('/')) - }; - - let mut results = Vec::new(); - let mut dirs_seen = HashSet::new(); - - for (file_path, data) in files.iter() { - // 정확한 디렉토리 경계 확인 - let matches = if path == "/" { - true - } else { - file_path.starts_with(&normalized_prefix) || file_path == &path - }; - - if matches { - let prefix_to_strip = if path == "/" { "/" } else { &normalized_prefix }; - let relative = file_path.strip_prefix(prefix_to_strip) - .unwrap_or(file_path.strip_prefix(&path).unwrap_or(file_path)); - - if let Some(slash_pos) = relative.find('/') { - // 서브디렉토리 - let dir_name = &relative[..slash_pos]; - let dir_path = format!("{}/{}", path.trim_end_matches('/'), dir_name); - if dirs_seen.insert(dir_path.clone()) { - results.push(FileInfo::dir(&format!("{}/", dir_path))); - } - } else if !relative.is_empty() { - // 파일 - let size = data.content.iter().map(|s| s.len()).sum::() as u64; - results.push(FileInfo::file_with_time( - file_path, - size, - &data.modified_at, - )); - } - } - } - - results.sort_by(|a, b| a.path.cmp(&b.path)); - Ok(results) -} -``` - -**Step 4: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_memory_backend_ls_boundary` -Expected: PASS - ---- - -### Task 2.2: Fix normalize_path to Handle "." 
Segments - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/path_utils.rs:32-52` -- Test: 동일 파일 하단 tests 모듈 - -**문제:** `/./file`이 `/file`과 다르게 처리됨 - -**Step 1: 실패하는 테스트 작성** - -```rust -#[test] -fn test_normalize_path_dot_segments() { - assert_eq!(normalize_path("/./file.txt").unwrap(), "/file.txt"); - assert_eq!(normalize_path("/dir/./sub/file.txt").unwrap(), "/dir/sub/file.txt"); - assert_eq!(normalize_path("./file.txt").unwrap(), "/file.txt"); - assert_eq!(normalize_path("/dir/.").unwrap(), "/dir"); -} -``` - -**Step 2: normalize_path 수정** - -```rust -pub fn normalize_path(path: &str) -> Result { - // 경로 순회 공격 방지 (..는 차단, .은 허용) - if path.contains("..") || path.starts_with("~") { - return Err(BackendError::PathTraversal(path.to_string())); - } - - // 빈 경로는 루트로 - if path.is_empty() { - return Ok("/".to_string()); - } - - // 연속된 슬래시 제거 및 "." 세그먼트 필터링 - let parts: Vec<&str> = path.split('/') - .filter(|p| !p.is_empty() && *p != ".") - .collect(); - - if parts.is_empty() { - return Ok("/".to_string()); - } - - Ok(format!("/{}", parts.join("/"))) -} -``` - -**Step 3: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_normalize_path_dot` -Expected: PASS - ---- - -### Task 2.3: Fix FilesystemBackend::grep glob_filter to Match Full Path - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/filesystem.rs:280-310` - -**문제:** `glob_filter`가 파일명만 매칭하여 `**/*.rs` 같은 패턴이 작동하지 않음 - -**Step 1: 실패하는 테스트 작성** - -```rust -#[tokio::test] -async fn test_filesystem_backend_grep_path_glob() { - use tempfile::tempdir; - - let root = tempdir().unwrap(); - - // 중첩 디렉토리에 파일 생성 - let src_dir = root.path().join("src"); - std::fs::create_dir_all(&src_dir).unwrap(); - std::fs::write(src_dir.join("main.rs"), "fn main() { println!(\"hello\"); }").unwrap(); - std::fs::write(src_dir.join("lib.rs"), "pub fn hello() {}").unwrap(); - std::fs::write(root.path().join("README.md"), "# 
Hello").unwrap(); - - let backend = FilesystemBackend::new(root.path()); - - // **/*.rs 패턴으로 검색 - let results = backend.grep("fn", None, Some("**/*.rs")).await.unwrap(); - - assert!(!results.is_empty(), "Should find matches in .rs files"); - assert!(results.iter().all(|m| m.path.ends_with(".rs")), - "All matches should be from .rs files"); -} -``` - -**Step 2: grep 구현 수정** - -```rust -async fn grep( - &self, - pattern: &str, - path: Option<&str>, - glob_filter: Option<&str>, -) -> Result, BackendError> { - let search_path = path.unwrap_or("/"); - let resolved = self.resolve_path(search_path)?; - - if !resolved.exists() { - return Ok(vec![]); - } - - let glob_pattern = glob_filter.map(|g| { - // 패턴이 **로 시작하면 그대로, 아니면 앞에 **/ 추가 - let normalized = if g.starts_with("**/") || g.starts_with("/") { - g.to_string() - } else { - format!("**/{}", g) - }; - glob::Pattern::new(&normalized) - }).transpose() - .map_err(|e| BackendError::Pattern(e.to_string()))?; - - let mut results = Vec::new(); - let walker = walkdir::WalkDir::new(&resolved); - - for entry in walker.into_iter().filter_map(|e| e.ok()) { - if !entry.file_type().is_file() { - continue; - } - - // Glob filter - 전체 경로에 대해 매칭 - if let Some(ref gp) = glob_pattern { - let relative_path = entry.path() - .strip_prefix(&resolved) - .map(|p| p.to_string_lossy().to_string()) - .unwrap_or_else(|_| entry.path().to_string_lossy().to_string()); - - if !gp.matches(&relative_path) && !gp.matches(&entry.file_name().to_string_lossy()) { - continue; - } - } - - // 파일 읽기 (async) - let content = match fs::read_to_string(entry.path()).await { - Ok(c) => c, - Err(e) => { - tracing::debug!(path = ?entry.path(), error = %e, "Skipping file in grep due to read error"); - continue; - } - }; - - let virt_path = self.to_virtual_path(entry.path()); - - // 리터럴 검색 - for (line_num, line) in content.lines().enumerate() { - if line.contains(pattern) { - results.push(GrepMatch::new(&virt_path, line_num + 1, line)); - } - } - } - - Ok(results) -} -``` 
- -**Step 3: 테스트 통과 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test test_filesystem_backend_grep_path_glob` -Expected: PASS - ---- - -### Task 2.4: Document Grep Literal vs Regex Design Decision - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/protocol.rs:60-70` (문서 주석) -- Modify: `rust-research-agent/crates/rig-deepagents/src/backends/memory.rs:1-10` (모듈 문서) - -**Step 1: protocol.rs에 grep 설계 결정 문서화** - -`protocol.rs`의 `grep` 메서드 문서에 추가: - -```rust -/// 파일 내용에서 패턴 검색 -/// -/// # Design Decision: Literal Search -/// -/// Rust 구현은 **리터럴 문자열 검색**을 사용합니다 (Python의 regex와 다름). -/// 이유: -/// - 보안: 정규식 패턴 주입 공격 방지 -/// - 성능: 정규식 컴파일 오버헤드 없음 -/// - 단순성: LLM 에이전트가 이해하기 쉬움 -/// -/// 정규식이 필요한 경우: -/// - `regex` crate를 사용하는 `grep_regex` 메서드 추가 고려 -/// - 또는 Backend 구현체에서 regex 기능 확장 -async fn grep( - &self, - pattern: &str, - path: Option<&str>, - glob_filter: Option<&str>, -) -> Result, BackendError>; -``` - -**Step 2: 커밋** - -```bash -git add rust-research-agent/crates/rig-deepagents/ -git commit -m "docs: document grep literal search design decision - -Rust grep uses literal substring matching (not regex like Python). -This is intentional for security, performance, and simplicity." 
-``` - ---- - -## Phase 3: Tool Implementations (CRITICAL) - -### Task 3.1: Create Tool Module Structure - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/read_file.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/write_file.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/edit_file.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/ls.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/glob.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/grep.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/write_todos.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/task.rs` -- Modify: `rust-research-agent/crates/rig-deepagents/src/tools/mod.rs` - -**Step 1: mod.rs 업데이트** - -```rust -//! Tool implementations for DeepAgents -//! -//! This module provides the 8 core tools auto-injected by middleware: -//! - File operations: read_file, write_file, edit_file, ls, glob, grep -//! - Planning: write_todos -//! 
- Delegation: task (SubAgent) - -mod read_file; -mod write_file; -mod edit_file; -mod ls; -mod glob; -mod grep; -mod write_todos; -mod task; - -pub use read_file::ReadFileTool; -pub use write_file::WriteFileTool; -pub use edit_file::EditFileTool; -pub use ls::LsTool; -pub use glob::GlobTool; -pub use grep::GrepTool; -pub use write_todos::WriteTodosTool; -pub use task::TaskTool; - -use crate::middleware::DynTool; -use std::sync::Arc; - -/// 모든 기본 도구 반환 -pub fn default_tools() -> Vec { - vec![ - Arc::new(ReadFileTool), - Arc::new(WriteFileTool), - Arc::new(EditFileTool), - Arc::new(LsTool), - Arc::new(GlobTool), - Arc::new(GrepTool), - Arc::new(WriteTodosTool), - ] -} - -/// SubAgent task 도구 포함하여 모든 도구 반환 -pub fn all_tools() -> Vec { - let mut tools = default_tools(); - tools.push(Arc::new(TaskTool)); - tools -} -``` - ---- - -### Task 3.2: Implement ReadFileTool - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/read_file.rs` - -**Step 1: 파일 생성** - -```rust -//! 
read_file 도구 구현 - -use async_trait::async_trait; -use serde::{Deserialize, Serialize}; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -/// read_file 도구 -pub struct ReadFileTool; - -#[derive(Debug, Deserialize)] -struct ReadFileArgs { - file_path: String, - #[serde(default)] - offset: usize, - #[serde(default = "default_limit")] - limit: usize, -} - -fn default_limit() -> usize { - 2000 -} - -#[async_trait] -impl Tool for ReadFileTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "read_file".to_string(), - description: "Read content from a file with optional line offset and limit.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The absolute path to the file to read" - }, - "offset": { - "type": "integer", - "description": "Line number to start reading from (0-indexed)", - "default": 0 - }, - "limit": { - "type": "integer", - "description": "Maximum number of lines to read", - "default": 2000 - } - }, - "required": ["file_path"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: ReadFileArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - runtime.backend() - .read(&args.file_path, args.offset, args.limit) - .await - .map_err(|e| MiddlewareError::Backend(e)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::backends::MemoryBackend; - use crate::state::AgentState; - use std::sync::Arc; - - #[tokio::test] - async fn test_read_file_tool() { - let backend = Arc::new(MemoryBackend::new()); - backend.write("/test.txt", "line1\nline2\nline3").await.unwrap(); - - let state = AgentState::new(); - let runtime = ToolRuntime::new(state, backend); - let tool = ReadFileTool; - - let result = tool.execute( - 
serde_json::json!({"file_path": "/test.txt"}), - &runtime, - ).await.unwrap(); - - assert!(result.contains("line1")); - assert!(result.contains("line2")); - } -} -``` - ---- - -### Task 3.3: Implement WriteFileTool - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/write_file.rs` - -**Step 1: 파일 생성** - -```rust -//! write_file 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -/// write_file 도구 -pub struct WriteFileTool; - -#[derive(Debug, Deserialize)] -struct WriteFileArgs { - file_path: String, - content: String, -} - -#[async_trait] -impl Tool for WriteFileTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "write_file".to_string(), - description: "Write content to a file, creating it if it doesn't exist.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The absolute path to the file to write" - }, - "content": { - "type": "string", - "description": "The content to write to the file" - } - }, - "required": ["file_path", "content"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: WriteFileArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - let result = runtime.backend() - .write(&args.file_path, &args.content) - .await - .map_err(|e| MiddlewareError::Backend(e))?; - - if result.is_ok() { - Ok(format!("Successfully wrote to {}", args.file_path)) - } else { - Err(MiddlewareError::ToolExecution( - result.error.unwrap_or_else(|| "Unknown error".to_string()) - )) - } - } -} -``` - ---- - -### Task 3.4: Implement EditFileTool - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/edit_file.rs` - -**Step 1: 파일 생성** - -```rust 
-//! edit_file 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -/// edit_file 도구 -pub struct EditFileTool; - -#[derive(Debug, Deserialize)] -struct EditFileArgs { - file_path: String, - old_string: String, - new_string: String, - #[serde(default)] - replace_all: bool, -} - -#[async_trait] -impl Tool for EditFileTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "edit_file".to_string(), - description: "Edit a file by replacing old_string with new_string.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "The absolute path to the file to edit" - }, - "old_string": { - "type": "string", - "description": "The string to find and replace" - }, - "new_string": { - "type": "string", - "description": "The replacement string" - }, - "replace_all": { - "type": "boolean", - "description": "Replace all occurrences (default: false)", - "default": false - } - }, - "required": ["file_path", "old_string", "new_string"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: EditFileArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - let result = runtime.backend() - .edit(&args.file_path, &args.old_string, &args.new_string, args.replace_all) - .await - .map_err(|e| MiddlewareError::Backend(e))?; - - if result.is_ok() { - let occurrences = result.occurrences.unwrap_or(1); - Ok(format!("Replaced {} occurrence(s) in {}", occurrences, args.file_path)) - } else { - Err(MiddlewareError::ToolExecution( - result.error.unwrap_or_else(|| "Unknown error".to_string()) - )) - } - } -} -``` - ---- - -### Task 3.5: Implement LsTool, GlobTool, GrepTool - -**Files:** -- Create: 
`rust-research-agent/crates/rig-deepagents/src/tools/ls.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/glob.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/grep.rs` - -**Step 1: ls.rs 생성** - -```rust -//! ls 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -pub struct LsTool; - -#[derive(Debug, Deserialize)] -struct LsArgs { - #[serde(default = "default_path")] - path: String, -} - -fn default_path() -> String { - "/".to_string() -} - -#[async_trait] -impl Tool for LsTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "ls".to_string(), - description: "List files and directories at the given path.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "path": { - "type": "string", - "description": "The directory path to list", - "default": "/" - } - } - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: LsArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - let files = runtime.backend() - .ls(&args.path) - .await - .map_err(|e| MiddlewareError::Backend(e))?; - - let output: Vec = files.iter() - .map(|f| { - if f.is_dir { - format!("{}/ (dir)", f.path) - } else { - format!("{} ({} bytes)", f.path, f.size.unwrap_or(0)) - } - }) - .collect(); - - Ok(output.join("\n")) - } -} -``` - -**Step 2: glob.rs 생성** - -```rust -//! 
glob 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -pub struct GlobTool; - -#[derive(Debug, Deserialize)] -struct GlobArgs { - pattern: String, - #[serde(default = "default_path")] - base_path: String, -} - -fn default_path() -> String { - "/".to_string() -} - -#[async_trait] -impl Tool for GlobTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "glob".to_string(), - description: "Find files matching a glob pattern.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Glob pattern (e.g., '**/*.rs', '*.txt')" - }, - "base_path": { - "type": "string", - "description": "Base path to search from", - "default": "/" - } - }, - "required": ["pattern"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: GlobArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - let files = runtime.backend() - .glob(&args.pattern, &args.base_path) - .await - .map_err(|e| MiddlewareError::Backend(e))?; - - let paths: Vec = files.iter().map(|f| f.path.clone()).collect(); - - if paths.is_empty() { - Ok("No files found matching pattern.".to_string()) - } else { - Ok(format!("Found {} files:\n{}", paths.len(), paths.join("\n"))) - } - } -} -``` - -**Step 3: grep.rs 생성** - -```rust -//! 
grep 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -pub struct GrepTool; - -#[derive(Debug, Deserialize)] -struct GrepArgs { - pattern: String, - #[serde(default)] - path: Option, - #[serde(default)] - glob_filter: Option, -} - -#[async_trait] -impl Tool for GrepTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "grep".to_string(), - description: "Search for a literal text pattern in files.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Literal text pattern to search for" - }, - "path": { - "type": "string", - "description": "Directory to search in (default: /)" - }, - "glob_filter": { - "type": "string", - "description": "Glob pattern to filter files (e.g., '**/*.rs')" - } - }, - "required": ["pattern"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: GrepArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - let matches = runtime.backend() - .grep(&args.pattern, args.path.as_deref(), args.glob_filter.as_deref()) - .await - .map_err(|e| MiddlewareError::Backend(e))?; - - if matches.is_empty() { - Ok("No matches found.".to_string()) - } else { - let output: Vec = matches.iter() - .map(|m| format!("{}:{}: {}", m.path, m.line_number, m.content)) - .collect(); - Ok(format!("Found {} matches:\n{}", matches.len(), output.join("\n"))) - } - } -} -``` - ---- - -### Task 3.6: Implement WriteTodosTool - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/write_todos.rs` - -**Step 1: 파일 생성** - -```rust -//! 
write_todos 도구 구현 - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; -use crate::state::{Todo, TodoStatus}; - -pub struct WriteTodosTool; - -#[derive(Debug, Deserialize)] -struct TodoItem { - content: String, - #[serde(default)] - status: String, -} - -#[derive(Debug, Deserialize)] -struct WriteTodosArgs { - todos: Vec, -} - -#[async_trait] -impl Tool for WriteTodosTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "write_todos".to_string(), - description: "Update the todo list with new items.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "todos": { - "type": "array", - "items": { - "type": "object", - "properties": { - "content": { - "type": "string", - "description": "The todo item content" - }, - "status": { - "type": "string", - "enum": ["pending", "in_progress", "completed"], - "default": "pending" - } - }, - "required": ["content"] - }, - "description": "List of todo items" - } - }, - "required": ["todos"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let args: WriteTodosArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - let todos: Vec = args.todos.iter() - .map(|t| { - let status = match t.status.as_str() { - "in_progress" => TodoStatus::InProgress, - "completed" => TodoStatus::Completed, - _ => TodoStatus::Pending, - }; - Todo::with_status(&t.content, status) - }) - .collect(); - - // Note: 실제 상태 업데이트는 미들웨어 레벨에서 처리 - // 여기서는 검증 및 포맷만 수행 - Ok(format!("Updated {} todo items", todos.len())) - } -} -``` - ---- - -### Task 3.7: Implement TaskTool (SubAgent Delegation) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/task.rs` - -**Step 1: 파일 생성** - -```rust -//! 
task 도구 구현 (SubAgent 위임) - -use async_trait::async_trait; -use serde::Deserialize; - -use crate::error::MiddlewareError; -use crate::middleware::{Tool, ToolDefinition}; -use crate::runtime::ToolRuntime; - -pub struct TaskTool; - -#[derive(Debug, Deserialize)] -struct TaskArgs { - subagent_type: String, - prompt: String, - #[serde(default)] - description: Option, -} - -#[async_trait] -impl Tool for TaskTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "task".to_string(), - description: "Delegate a task to a sub-agent for specialized processing.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "subagent_type": { - "type": "string", - "description": "The type of sub-agent to use (e.g., 'researcher', 'explorer', 'synthesizer')" - }, - "prompt": { - "type": "string", - "description": "The task prompt for the sub-agent" - }, - "description": { - "type": "string", - "description": "A short description of the task" - } - }, - "required": ["subagent_type", "prompt"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - runtime: &ToolRuntime, - ) -> Result { - let args: TaskArgs = serde_json::from_value(args) - .map_err(|e| MiddlewareError::ToolExecution(format!("Invalid arguments: {}", e)))?; - - // 재귀 한도 확인 - if runtime.is_recursion_limit_exceeded() { - return Err(MiddlewareError::RecursionLimit( - format!("Recursion limit exceeded. 
Cannot delegate to subagent '{}'", args.subagent_type) - )); - } - - // Note: 실제 SubAgent 실행은 executor에서 처리 - // 이 도구는 요청을 구조화하고 검증만 수행 - Ok(format!( - "Task delegation requested:\n- Agent: {}\n- Description: {}\n- Prompt: {}", - args.subagent_type, - args.description.unwrap_or_else(|| "N/A".to_string()), - args.prompt - )) - } -} -``` - ---- - -### Task 3.8: Update lib.rs Exports - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/lib.rs` - -**Step 1: tools 모듈 export 추가** - -```rust -// lib.rs 상단에 추가 -pub mod tools; - -// pub use 추가 -pub use tools::{ - ReadFileTool, WriteFileTool, EditFileTool, - LsTool, GlobTool, GrepTool, - WriteTodosTool, TaskTool, - default_tools, all_tools, -}; -``` - ---- - -### Task 3.9: 전체 테스트 실행 - -**Step 1: 컴파일 확인** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo check` -Expected: 성공 - -**Step 2: 테스트 실행** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test` -Expected: 모든 테스트 PASS - -**Step 3: 커밋** - -```bash -git add rust-research-agent/crates/rig-deepagents/ -git commit -m "feat(tools): implement all 8 core tools - -Tools implemented: -- read_file: Read file content with offset/limit -- write_file: Create/overwrite files -- edit_file: String replacement in files -- ls: List directory contents -- glob: Pattern-based file search -- grep: Content search (literal) -- write_todos: Todo list management -- task: SubAgent delegation request - -All tools connect to Backend trait for operations." -``` - ---- - -## Phase 4: Executor Loop (CRITICAL) - -### Task 4.1: Create Executor Module Structure - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/executor.rs` - -**Step 1: Executor 구조체 정의** - -```rust -//! Agent executor - 메시지 처리 및 도구 실행 루프 -//! -//! 
Python Reference: deepagents/graph.py - -use std::sync::Arc; -use async_trait::async_trait; - -use crate::backends::Backend; -use crate::error::{DeepAgentError, MiddlewareError}; -use crate::middleware::{MiddlewareStack, DynTool, StateUpdate}; -use crate::runtime::ToolRuntime; -use crate::state::{AgentState, Message, Role, ToolCall}; - -/// LLM 인터페이스 트레이트 -#[async_trait] -pub trait LLMProvider: Send + Sync { - /// 메시지로부터 응답 생성 - async fn generate( - &self, - messages: &[Message], - tools: &[crate::middleware::ToolDefinition], - ) -> Result; -} - -/// Agent Executor -pub struct AgentExecutor { - llm: L, - middleware: MiddlewareStack, - backend: Arc, - max_iterations: usize, -} - -impl AgentExecutor { - pub fn new( - llm: L, - middleware: MiddlewareStack, - backend: Arc, - ) -> Self { - Self { - llm, - middleware, - backend, - max_iterations: 50, - } - } - - pub fn with_max_iterations(mut self, max: usize) -> Self { - self.max_iterations = max; - self - } - - /// 에이전트 실행 - pub async fn run(&self, initial_state: AgentState) -> Result { - let mut state = initial_state; - let runtime = ToolRuntime::new(state.clone(), self.backend.clone()); - - // Before hooks 실행 - if let Some(update) = self.middleware.before_agent(&mut state, &runtime).await? 
{ - self.apply_update(&mut state, update); - } - - // 도구 수집 - let tools = self.middleware.collect_tools(); - let tool_definitions: Vec<_> = tools.iter() - .map(|t| t.definition()) - .collect(); - - // 메인 실행 루프 - for iteration in 0..self.max_iterations { - tracing::debug!(iteration, "Agent iteration"); - - // LLM 호출 - let response = self.llm.generate(&state.messages, &tool_definitions).await?; - state.add_message(response.clone()); - - // 도구 호출이 없으면 종료 - if !response.has_tool_calls() { - tracing::debug!("No tool calls, finishing"); - break; - } - - // 도구 호출 처리 - if let Some(tool_calls) = &response.tool_calls { - for call in tool_calls { - let result = self.execute_tool_call(call, &tools, &runtime).await; - let tool_message = Message::tool(&result, &call.id); - state.add_message(tool_message); - } - } - } - - // After hooks 실행 - if let Some(update) = self.middleware.after_agent(&mut state, &runtime).await? { - self.apply_update(&mut state, update); - } - - Ok(state) - } - - /// 도구 호출 실행 - async fn execute_tool_call( - &self, - call: &ToolCall, - tools: &[DynTool], - runtime: &ToolRuntime, - ) -> String { - let tool = tools.iter().find(|t| t.definition().name == call.name); - - match tool { - Some(t) => { - match t.execute(call.arguments.clone(), runtime).await { - Ok(result) => result, - Err(e) => format!("Tool error: {}", e), - } - } - None => format!("Unknown tool: {}", call.name), - } - } - - /// 상태 업데이트 적용 - fn apply_update(&self, state: &mut AgentState, update: StateUpdate) { - match update { - StateUpdate::AddMessages(msgs) => { - for msg in msgs { - state.add_message(msg); - } - } - StateUpdate::SetTodos(todos) => { - state.todos = todos; - } - StateUpdate::UpdateFiles(files) => { - for (path, data) in files { - if let Some(file_data) = data { - state.files.insert(path, file_data); - } else { - state.files.remove(&path); - } - } - } - StateUpdate::Batch(updates) => { - for u in updates { - self.apply_update(state, u); - } - } - } - } -} - -#[cfg(test)] -mod 
tests { - use super::*; - use crate::backends::MemoryBackend; - - struct MockLLM; - - #[async_trait] - impl LLMProvider for MockLLM { - async fn generate( - &self, - _messages: &[Message], - _tools: &[crate::middleware::ToolDefinition], - ) -> Result { - Ok(Message::assistant("Hello! I'm a mock assistant.")) - } - } - - #[tokio::test] - async fn test_executor_basic() { - let llm = MockLLM; - let backend = Arc::new(MemoryBackend::new()); - let middleware = MiddlewareStack::new(); - - let executor = AgentExecutor::new(llm, middleware, backend); - - let initial_state = AgentState::with_messages(vec![ - Message::user("Hello!") - ]); - - let result = executor.run(initial_state).await.unwrap(); - - assert!(result.messages.len() >= 2); - assert!(result.last_assistant_message().is_some()); - } -} -``` - ---- - -## Phase 5: 최종 검증 및 문서화 - -### Task 5.1: 전체 테스트 및 Clippy - -**Step 1: 전체 테스트** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test` -Expected: 40+ 테스트 PASS - -**Step 2: Clippy** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo clippy -- -D warnings` -Expected: 경고 없음 - -**Step 3: 문서 생성** - -Run: `source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo doc --no-deps --open` -Expected: 문서 생성 성공 - ---- - -### Task 5.2: 최종 커밋 및 태그 - -```bash -git add rust-research-agent/ -git commit -m "feat: complete Phase 5-6 implementation - -Phase 1 (Security): -- Fix symlink path traversal in FilesystemBackend - -Phase 2 (Behavioral): -- Fix MemoryBackend ls boundary check -- Fix normalize_path to handle . 
segments -- Fix grep glob_filter to match full paths -- Document grep literal vs regex decision - -Phase 3 (Tools): -- Implement all 8 core tools -- Add default_tools() and all_tools() helpers - -Phase 4 (Executor): -- Add LLMProvider trait -- Implement AgentExecutor with tool execution loop -- Add state update application - -Test coverage: 45+ tests passing -All Clippy warnings resolved" - -git tag v0.2.0-alpha -``` - ---- - -## 수정 우선순위 요약 - -| 우선순위 | Phase | Task 수 | 예상 시간 | -|----------|-------|---------|-----------| -| 🔴 HIGH | Phase 1 (Security) | 1 | 30분 | -| 🟡 MEDIUM | Phase 2 (Behavioral) | 4 | 1시간 | -| 🔴 CRITICAL | Phase 3 (Tools) | 9 | 2시간 | -| 🔴 CRITICAL | Phase 4 (Executor) | 1 | 1시간 | -| 🟢 LOW | Phase 5 (Verification) | 2 | 30분 | - -**총 예상 시간:** 약 5시간 - ---- - -## 의존성 그래프 - -``` -Phase 1 (Security) - ↓ -Phase 2 (Behavioral) ──→ Phase 3 (Tools) - ↓ - Phase 4 (Executor) - ↓ - Phase 5 (Verification) -``` - -Phase 1과 2는 독립적으로 진행 가능. Phase 3는 2 완료 후, Phase 4는 3 완료 후 진행. diff --git a/docs/plans/2026-01-01-rig-deepagents-phase7-9.md b/docs/plans/2026-01-01-rig-deepagents-phase7-9.md deleted file mode 100644 index e48165c..0000000 --- a/docs/plans/2026-01-01-rig-deepagents-phase7-9.md +++ /dev/null @@ -1,1307 +0,0 @@ -# Rig-DeepAgents Phase 7-9 Implementation Plan - -> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task with TDD approach. - -**Goal:** Complete the Rig-DeepAgents Rust port by implementing LLM provider abstraction, SubAgent execution system, Skills middleware, and domain tools. 
- -**Architecture Principles:** -- **Abstraction First**: Design provider-agnostic interfaces before concrete implementations -- **TDD Approach**: Write failing tests first, then implement to make them pass -- **LangChain Patterns**: Reference langchain-openai/langchain-anthropic for proven patterns -- **Rig Integration**: Leverage rig-core's CompletionModel trait internally - -**Tech Stack:** Rust 1.75+, rig-core 0.27, tokio, async-trait, serde_json, tracing - -**Reference Sources:** -- LangChain Python: `langchain_openai`, `langchain_anthropic` packages -- Rig Core: `rig-core/src/completion/request.rs`, `rig-core/src/agent/completion.rs` -- Python DeepAgents: `research_agent/skills/`, `research_agent/subagents/` - ---- - -## Dependency Graph - -``` -Phase 7: LLM Provider Abstraction - │ - ▼ -Phase 8: SubAgent Execution System (requires Phase 7) - │ - ▼ -Phase 9a: Skills Middleware ──┬── Phase 9b: Domain Tools - (independent) │ (independent) - │ - ▼ - Integration Testing -``` - ---- - -## Phase 7: LLM Provider Abstraction (CRITICAL) - -### Overview - -Create a provider-agnostic LLM interface that bridges DeepAgents with Rig's CompletionModel. - -**Priority:** -1. OpenAI (gpt-4.1) - Primary target, matches Python reference -2. Anthropic (Claude) - Secondary target - -### Task 7.1: Create LLM Module Structure - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/mod.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/provider.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/config.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/message.rs` -- Modify: `rust-research-agent/crates/rig-deepagents/src/lib.rs` - -**Step 1: Create mod.rs** - -```rust -//! LLM Provider abstractions for DeepAgents -//! -//! This module provides provider-agnostic interfaces for LLM completion, -//! bridging DeepAgents with various LLM providers via Rig framework. 
- -mod provider; -mod config; -mod message; - -pub use provider::{LLMProvider, LLMResponse, LLMResponseStream}; -pub use config::{LLMConfig, TokenUsage}; -pub use message::{MessageConverter, ToolConverter}; -``` - ---- - -### Task 7.2: Define Core Types (TDD) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/config.rs` - -**Step 1: Write failing tests first** - -```rust -//! LLM configuration types - -use serde::{Deserialize, Serialize}; - -/// Token usage statistics -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] -pub struct TokenUsage { - pub input_tokens: u64, - pub output_tokens: u64, - pub total_tokens: u64, -} - -impl TokenUsage { - pub fn new(input: u64, output: u64) -> Self { - Self { - input_tokens: input, - output_tokens: output, - total_tokens: input + output, - } - } -} - -impl std::ops::Add for TokenUsage { - type Output = Self; - - fn add(self, other: Self) -> Self::Output { - Self { - input_tokens: self.input_tokens + other.input_tokens, - output_tokens: self.output_tokens + other.output_tokens, - total_tokens: self.total_tokens + other.total_tokens, - } - } -} - -impl std::ops::AddAssign for TokenUsage { - fn add_assign(&mut self, other: Self) { - self.input_tokens += other.input_tokens; - self.output_tokens += other.output_tokens; - self.total_tokens += other.total_tokens; - } -} - -/// LLM Provider configuration -#[derive(Clone, Debug, Default, Serialize, Deserialize)] -pub struct LLMConfig { - /// Model identifier (e.g., "gpt-4.1", "claude-3-opus") - pub model: String, - /// Sampling temperature (0.0 - 2.0) - pub temperature: Option, - /// Maximum tokens to generate - pub max_tokens: Option, - /// API key (optional, can use environment variable) - pub api_key: Option, - /// API base URL (optional, for custom endpoints) - pub api_base: Option, -} - -impl LLMConfig { - pub fn new(model: impl Into) -> Self { - Self { - model: model.into(), - ..Default::default() - } - } - - pub fn with_temperature(mut 
self, temp: f64) -> Self { - self.temperature = Some(temp); - self - } - - pub fn with_max_tokens(mut self, tokens: u64) -> Self { - self.max_tokens = Some(tokens); - self - } - - pub fn with_api_key(mut self, key: impl Into) -> Self { - self.api_key = Some(key.into()); - self - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_token_usage_add() { - let a = TokenUsage::new(100, 50); - let b = TokenUsage::new(200, 100); - let c = a + b; - - assert_eq!(c.input_tokens, 300); - assert_eq!(c.output_tokens, 150); - assert_eq!(c.total_tokens, 450); - } - - #[test] - fn test_llm_config_builder() { - let config = LLMConfig::new("gpt-4.1") - .with_temperature(0.7) - .with_max_tokens(4096); - - assert_eq!(config.model, "gpt-4.1"); - assert_eq!(config.temperature, Some(0.7)); - assert_eq!(config.max_tokens, Some(4096)); - } -} -``` - ---- - -### Task 7.3: Define LLMProvider Trait (TDD) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/provider.rs` - -**Step 1: Write trait definition with tests** - -```rust -//! 
LLM Provider trait definition - -use async_trait::async_trait; -use std::pin::Pin; -use futures::Stream; - -use crate::error::DeepAgentError; -use crate::state::Message; -use crate::middleware::ToolDefinition; -use super::config::{LLMConfig, TokenUsage}; - -/// LLM completion response -#[derive(Debug, Clone)] -pub struct LLMResponse { - /// The assistant's response message - pub message: Message, - /// Token usage statistics (if available) - pub usage: Option, -} - -impl LLMResponse { - pub fn new(message: Message) -> Self { - Self { message, usage: None } - } - - pub fn with_usage(mut self, usage: TokenUsage) -> Self { - self.usage = Some(usage); - self - } -} - -/// Streaming response chunk -#[derive(Debug, Clone)] -pub struct MessageChunk { - pub content: String, - pub is_final: bool, - pub usage: Option, -} - -/// Streaming response wrapper -pub struct LLMResponseStream { - inner: Pin> + Send>>, -} - -impl LLMResponseStream { - pub fn new(stream: S) -> Self - where - S: Stream> + Send + 'static, - { - Self { - inner: Box::pin(stream), - } - } - - /// Create from a complete (non-streaming) response - pub fn from_complete(response: LLMResponse) -> Self { - let content = response.message.content().unwrap_or_default().to_string(); - let chunk = MessageChunk { - content, - is_final: true, - usage: response.usage, - }; - Self::new(futures::stream::once(async move { Ok(chunk) })) - } -} - -/// Core LLM Provider trait -/// -/// Provides a provider-agnostic interface for LLM completion. -/// Implementations should bridge to specific providers (OpenAI, Anthropic, etc.) -/// via Rig's CompletionModel trait. 
-#[async_trait] -pub trait LLMProvider: Send + Sync { - /// Generate a completion response (non-streaming) - /// - /// # Arguments - /// * `messages` - Conversation history including the current prompt - /// * `tools` - Available tools for the model to call - /// * `config` - Optional runtime configuration overrides - async fn complete( - &self, - messages: &[Message], - tools: &[ToolDefinition], - config: Option<&LLMConfig>, - ) -> Result; - - /// Generate a streaming completion response - /// - /// Default implementation falls back to non-streaming. - async fn stream( - &self, - messages: &[Message], - tools: &[ToolDefinition], - config: Option<&LLMConfig>, - ) -> Result { - let response = self.complete(messages, tools, config).await?; - Ok(LLMResponseStream::from_complete(response)) - } - - /// Provider name for logging/debugging - fn name(&self) -> &str; - - /// Default model for this provider - fn default_model(&self) -> &str; -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::state::Role; - - struct MockProvider; - - #[async_trait] - impl LLMProvider for MockProvider { - async fn complete( - &self, - messages: &[Message], - _tools: &[ToolDefinition], - _config: Option<&LLMConfig>, - ) -> Result { - let last_content = messages.last() - .and_then(|m| m.content()) - .unwrap_or("Hello"); - - Ok(LLMResponse::new(Message::assistant(&format!("Echo: {}", last_content)))) - } - - fn name(&self) -> &str { - "mock" - } - - fn default_model(&self) -> &str { - "mock-model" - } - } - - #[tokio::test] - async fn test_mock_provider_complete() { - let provider = MockProvider; - let messages = vec![Message::user("Hello, world!")]; - - let response = provider.complete(&messages, &[], None).await.unwrap(); - - assert!(response.message.content().unwrap().contains("Echo:")); - assert!(response.message.content().unwrap().contains("Hello, world!")); - } - - #[tokio::test] - async fn test_stream_fallback() { - let provider = MockProvider; - let messages = 
vec![Message::user("Test")]; - - let stream = provider.stream(&messages, &[], None).await.unwrap(); - // Stream should work via fallback - assert!(true); // Stream created successfully - } -} -``` - ---- - -### Task 7.4: Message Conversion Layer (TDD) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/message.rs` - -**Step 1: Define conversion traits** - -```rust -//! Message conversion between DeepAgents and Rig formats - -use crate::state::{Message, Role, ToolCall}; -use crate::middleware::ToolDefinition; -use crate::error::DeepAgentError; -use rig::completion::{Message as RigMessage, ToolDefinition as RigToolDefinition}; - -/// Converts DeepAgents messages to Rig format -pub trait ToRigMessage { - fn to_rig_message(&self) -> Result; -} - -/// Converts Rig messages to DeepAgents format -pub trait FromRigMessage { - fn from_rig_message(msg: &RigMessage) -> Result - where - Self: Sized; -} - -/// Converts DeepAgents tool definitions to Rig format -pub trait ToRigTool { - fn to_rig_tool(&self) -> RigToolDefinition; -} - -impl ToRigMessage for Message { - fn to_rig_message(&self) -> Result { - match self.role { - Role::User => Ok(RigMessage::user(self.content().unwrap_or(""))), - Role::Assistant => { - if let Some(tool_calls) = &self.tool_calls { - // Handle assistant message with tool calls - // Rig uses AssistantContent::ToolCall for this - Ok(RigMessage::assistant(self.content().unwrap_or(""))) - } else { - Ok(RigMessage::assistant(self.content().unwrap_or(""))) - } - } - Role::System => { - // Rig handles system as preamble, not a message - // Return as user message with system prefix for compatibility - Ok(RigMessage::user(format!("[System]: {}", self.content().unwrap_or("")))) - } - Role::Tool => { - // Tool results need special handling - Ok(RigMessage::user(format!("[Tool Result]: {}", self.content().unwrap_or("")))) - } - } - } -} - -impl ToRigTool for ToolDefinition { - fn to_rig_tool(&self) -> RigToolDefinition { - 
RigToolDefinition { - name: self.name.clone(), - description: self.description.clone(), - parameters: self.parameters.clone(), - } - } -} - -/// Convert a slice of DeepAgents messages to Rig format -pub fn convert_messages(messages: &[Message]) -> Result, DeepAgentError> { - messages.iter().map(|m| m.to_rig_message()).collect() -} - -/// Convert a slice of tool definitions to Rig format -pub fn convert_tools(tools: &[ToolDefinition]) -> Vec { - tools.iter().map(|t| t.to_rig_tool()).collect() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_user_message_conversion() { - let msg = Message::user("Hello!"); - let rig_msg = msg.to_rig_message().unwrap(); - // Verify conversion worked (exact format depends on Rig internals) - assert!(true); - } - - #[test] - fn test_tool_definition_conversion() { - let tool = ToolDefinition { - name: "read_file".to_string(), - description: "Read a file".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "path": {"type": "string"} - } - }), - }; - - let rig_tool = tool.to_rig_tool(); - assert_eq!(rig_tool.name, "read_file"); - assert_eq!(rig_tool.description, "Read a file"); - } -} -``` - ---- - -### Task 7.5: OpenAI Provider Implementation (TDD) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/llm/openai.rs` - -**Step 1: Write failing test first** - -```rust -//! 
OpenAI LLM Provider implementation via Rig - -use async_trait::async_trait; -use rig::providers::openai::{Client, CompletionModel}; -use std::sync::Arc; - -use super::provider::{LLMProvider, LLMResponse, LLMResponseStream}; -use super::config::{LLMConfig, TokenUsage}; -use super::message::{convert_messages, convert_tools}; -use crate::error::DeepAgentError; -use crate::state::Message; -use crate::middleware::ToolDefinition; - -/// OpenAI LLM Provider -pub struct OpenAIProvider { - client: Client, - default_model: String, - default_config: LLMConfig, -} - -impl OpenAIProvider { - /// Create a new OpenAI provider with API key from environment - pub fn from_env() -> Result { - Self::from_env_with_model("gpt-4.1") - } - - /// Create with specific model - pub fn from_env_with_model(model: impl Into) -> Result { - let client = Client::from_env(); - let model = model.into(); - - Ok(Self { - client, - default_model: model.clone(), - default_config: LLMConfig::new(model), - }) - } - - /// Create with explicit API key - pub fn new(api_key: impl Into, model: impl Into) -> Self { - let client = Client::new(&api_key.into()); - let model = model.into(); - - Self { - client, - default_model: model.clone(), - default_config: LLMConfig::new(model), - } - } - - /// Get effective config (override with runtime config if provided) - fn effective_config<'a>(&'a self, runtime: Option<&'a LLMConfig>) -> &'a LLMConfig { - runtime.unwrap_or(&self.default_config) - } -} - -#[async_trait] -impl LLMProvider for OpenAIProvider { - async fn complete( - &self, - messages: &[Message], - tools: &[ToolDefinition], - config: Option<&LLMConfig>, - ) -> Result { - let config = self.effective_config(config); - - // Get completion model from Rig - let model = self.client.completion_model(&config.model); - - // Convert messages and tools - let rig_messages = convert_messages(messages)?; - let rig_tools = convert_tools(tools); - - // Build completion request - let mut request_builder = 
model.completion_request( - rig_messages.last().cloned().unwrap_or_else(|| rig::message::Message::user("")) - ); - - // Add chat history (all but last message) - if rig_messages.len() > 1 { - request_builder = request_builder.messages(rig_messages[..rig_messages.len()-1].to_vec()); - } - - // Add tools - request_builder = request_builder.tools(rig_tools); - - // Apply config - if let Some(temp) = config.temperature { - request_builder = request_builder.temperature(temp); - } - if let Some(max) = config.max_tokens { - request_builder = request_builder.max_tokens(max); - } - - // Execute - let response = request_builder.send().await - .map_err(|e| DeepAgentError::LLMError(e.to_string()))?; - - // Convert response - let content = response.choice.first() - .map(|c| c.to_string()) - .unwrap_or_default(); - - let usage = TokenUsage { - input_tokens: response.usage.input_tokens, - output_tokens: response.usage.output_tokens, - total_tokens: response.usage.total_tokens, - }; - - Ok(LLMResponse::new(Message::assistant(&content)).with_usage(usage)) - } - - fn name(&self) -> &str { - "openai" - } - - fn default_model(&self) -> &str { - &self.default_model - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore] // Requires OPENAI_API_KEY - async fn test_openai_provider_complete() { - let provider = OpenAIProvider::from_env().unwrap(); - let messages = vec![Message::user("Say hello in exactly 3 words.")]; - - let response = provider.complete(&messages, &[], None).await.unwrap(); - - assert!(!response.message.content().unwrap().is_empty()); - assert!(response.usage.is_some()); - } - - #[test] - fn test_openai_provider_creation() { - // This test doesn't make API calls, just verifies construction - let provider = OpenAIProvider::new("test-key", "gpt-4.1"); - assert_eq!(provider.name(), "openai"); - assert_eq!(provider.default_model(), "gpt-4.1"); - } -} -``` - ---- - -### Task 7.6: Anthropic Provider Implementation (TDD) - -**Files:** -- Create: 
`rust-research-agent/crates/rig-deepagents/src/llm/anthropic.rs` - -**Step 1: Similar structure to OpenAI** - -```rust -//! Anthropic (Claude) LLM Provider implementation via Rig - -use async_trait::async_trait; -use rig::providers::anthropic::{Client, CompletionModel, CLAUDE_3_5_SONNET}; -use std::sync::Arc; - -use super::provider::{LLMProvider, LLMResponse, LLMResponseStream}; -use super::config::{LLMConfig, TokenUsage}; -use super::message::{convert_messages, convert_tools}; -use crate::error::DeepAgentError; -use crate::state::Message; -use crate::middleware::ToolDefinition; - -/// Anthropic (Claude) LLM Provider -pub struct AnthropicProvider { - client: Client, - default_model: String, - default_config: LLMConfig, -} - -impl AnthropicProvider { - pub fn from_env() -> Result { - Self::from_env_with_model(CLAUDE_3_5_SONNET) - } - - pub fn from_env_with_model(model: impl Into) -> Result { - let client = Client::from_env(); - let model = model.into(); - - Ok(Self { - client, - default_model: model.clone(), - default_config: LLMConfig::new(model).with_max_tokens(4096), // Anthropic requires max_tokens - }) - } - - pub fn new(api_key: impl Into, model: impl Into) -> Self { - let client = Client::new(&api_key.into()); - let model = model.into(); - - Self { - client, - default_model: model.clone(), - default_config: LLMConfig::new(model).with_max_tokens(4096), - } - } -} - -#[async_trait] -impl LLMProvider for AnthropicProvider { - async fn complete( - &self, - messages: &[Message], - tools: &[ToolDefinition], - config: Option<&LLMConfig>, - ) -> Result { - let config = config.unwrap_or(&self.default_config); - - let model = self.client.completion_model(&config.model); - let rig_messages = convert_messages(messages)?; - let rig_tools = convert_tools(tools); - - let mut request_builder = model.completion_request( - rig_messages.last().cloned().unwrap_or_else(|| rig::message::Message::user("")) - ); - - if rig_messages.len() > 1 { - request_builder = 
request_builder.messages(rig_messages[..rig_messages.len()-1].to_vec()); - } - - request_builder = request_builder.tools(rig_tools); - - if let Some(temp) = config.temperature { - request_builder = request_builder.temperature(temp); - } - - // Anthropic requires max_tokens - let max_tokens = config.max_tokens.unwrap_or(4096); - request_builder = request_builder.max_tokens(max_tokens); - - let response = request_builder.send().await - .map_err(|e| DeepAgentError::LLMError(e.to_string()))?; - - let content = response.choice.first() - .map(|c| c.to_string()) - .unwrap_or_default(); - - let usage = TokenUsage { - input_tokens: response.usage.input_tokens, - output_tokens: response.usage.output_tokens, - total_tokens: response.usage.total_tokens, - }; - - Ok(LLMResponse::new(Message::assistant(&content)).with_usage(usage)) - } - - fn name(&self) -> &str { - "anthropic" - } - - fn default_model(&self) -> &str { - &self.default_model - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[ignore] // Requires ANTHROPIC_API_KEY - async fn test_anthropic_provider_complete() { - let provider = AnthropicProvider::from_env().unwrap(); - let messages = vec![Message::user("Say hello in exactly 3 words.")]; - - let response = provider.complete(&messages, &[], None).await.unwrap(); - - assert!(!response.message.content().unwrap().is_empty()); - } -} -``` - ---- - -### Task 7.7: Update Executor to Use LLMProvider - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/executor.rs` - -**Step 1: Replace old LLMProvider with new one** - -```rust -// Update imports -use crate::llm::{LLMProvider, LLMConfig}; - -// Update AgentExecutor to use Arc -pub struct AgentExecutor { - llm: Arc, - middleware: MiddlewareStack, - backend: Arc, - max_iterations: usize, - config: Option, -} - -impl AgentExecutor { - pub fn new( - llm: Arc, - middleware: MiddlewareStack, - backend: Arc, - ) -> Self { - Self { - llm, - middleware, - backend, - max_iterations: 50, - 
config: None, - } - } - - pub fn with_config(mut self, config: LLMConfig) -> Self { - self.config = Some(config); - self - } - - // Update run() to use new LLMProvider interface - pub async fn run(&self, initial_state: AgentState) -> Result { - // ... existing logic ... - - // LLM call changes to: - let response = self.llm.complete( - &state.messages, - &tool_definitions, - self.config.as_ref(), - ).await?; - - state.add_message(response.message); - - // ... rest of loop ... - } -} -``` - ---- - -### Task 7.8: Update lib.rs Exports - -**Files:** -- Modify: `rust-research-agent/crates/rig-deepagents/src/lib.rs` - -```rust -// Add llm module -pub mod llm; - -// Re-exports -pub use llm::{ - LLMProvider, LLMResponse, LLMConfig, TokenUsage, - OpenAIProvider, AnthropicProvider, -}; -``` - ---- - -### Task 7.9: Verification - -**Step 1: Run all tests** - -```bash -source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo test -``` - -**Step 2: Run clippy** - -```bash -source ~/.cargo/env && cd rust-research-agent/crates/rig-deepagents && cargo clippy -- -D warnings -``` - -**Step 3: Commit** - -```bash -git add rust-research-agent/crates/rig-deepagents/ -git commit -m "feat(llm): implement LLMProvider abstraction with OpenAI and Anthropic support - -Phase 7 implementation: -- Add LLMProvider trait for provider-agnostic LLM access -- Implement OpenAIProvider via rig-core -- Implement AnthropicProvider via rig-core -- Add message/tool conversion layer -- Update AgentExecutor to use new LLMProvider - -All providers tested with TDD approach." -``` - ---- - -## Phase 8: SubAgent Execution System - -### Overview - -Implement SubAgent registration, execution, and isolated context management. 
- -**Python Reference:** `research_agent/subagents/registry.py`, `research_agent/subagents/definitions.py` - -### Task 8.1: Create SubAgent Module Structure - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/subagent/mod.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/subagent/definition.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/subagent/registry.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/subagent/executor.rs` - -### Task 8.2: Define SubAgent Types (TDD) - -```rust -//! SubAgent definition types - -use serde::{Deserialize, Serialize}; - -/// SubAgent type classification -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub enum SubAgentType { - /// Simple single-response agent (prompt-based) - Simple, - /// Compiled multi-turn agent (StateGraph-based) - Compiled, -} - -/// SubAgent definition -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SubAgentDefinition { - pub name: String, - pub description: String, - pub agent_type: SubAgentType, - pub system_prompt: Option, - pub tools: Vec, // Tool names to include - pub max_iterations: Option, -} - -/// SubAgent execution context -pub struct SubAgentContext { - pub parent_state: AgentState, - pub recursion_depth: usize, - pub isolated_backend: Arc, -} -``` - -### Task 8.3: Implement SubAgent Registry (TDD) - -```rust -//! 
SubAgent registry for dynamic agent lookup - -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; - -pub struct SubAgentRegistry { - agents: RwLock>, -} - -impl SubAgentRegistry { - pub fn new() -> Self { - Self { - agents: RwLock::new(HashMap::new()), - } - } - - pub async fn register(&self, definition: SubAgentDefinition) { - self.agents.write().await.insert(definition.name.clone(), definition); - } - - pub async fn get(&self, name: &str) -> Option { - self.agents.read().await.get(name).cloned() - } - - pub async fn list(&self) -> Vec { - self.agents.read().await.values().cloned().collect() - } -} -``` - -### Task 8.4: Implement SubAgent Executor (TDD) - -```rust -//! SubAgent execution logic - -pub struct SubAgentExecutor { - registry: Arc, - llm: Arc, - max_recursion_depth: usize, -} - -impl SubAgentExecutor { - pub async fn execute( - &self, - agent_name: &str, - prompt: &str, - context: SubAgentContext, - ) -> Result { - // Check recursion limit - if context.recursion_depth >= self.max_recursion_depth { - return Err(DeepAgentError::RecursionLimit( - format!("Max recursion depth {} exceeded", self.max_recursion_depth) - )); - } - - // Get agent definition - let definition = self.registry.get(agent_name).await - .ok_or_else(|| DeepAgentError::SubAgent( - format!("Unknown SubAgent: {}", agent_name) - ))?; - - match definition.agent_type { - SubAgentType::Simple => self.execute_simple(&definition, prompt, context).await, - SubAgentType::Compiled => self.execute_compiled(&definition, prompt, context).await, - } - } - - async fn execute_simple( - &self, - definition: &SubAgentDefinition, - prompt: &str, - context: SubAgentContext, - ) -> Result { - // Single LLM call with system prompt - let messages = vec![ - Message::system(definition.system_prompt.as_deref().unwrap_or("")), - Message::user(prompt), - ]; - - let response = self.llm.complete(&messages, &[], None).await?; - Ok(response.message.content().unwrap_or("").to_string()) - } - - 
async fn execute_compiled( - &self, - definition: &SubAgentDefinition, - prompt: &str, - context: SubAgentContext, - ) -> Result { - // Multi-turn execution with tools - // Create a child AgentExecutor with isolated context - // ... implementation details ... - todo!("Compiled SubAgent execution") - } -} -``` - -### Task 8.5: Update TaskTool to Use SubAgentExecutor - -Update `tools/task.rs` to delegate to `SubAgentExecutor`. - -### Task 8.6: Verification & Commit - ---- - -## Phase 9a: Skills Middleware - -### Overview - -Implement Skills middleware with progressive disclosure pattern from Python. - -**Python Reference:** `research_agent/skills/middleware.py`, `research_agent/skills/load.py` - -### Task 9a.1: Create Skills Module Structure - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/skills/mod.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/skills/loader.rs` -- Create: `rust-research-agent/crates/rig-deepagents/src/skills/middleware.rs` - -### Task 9a.2: Define Skill Types (TDD) - -```rust -//! Skill definition types - -use serde::{Deserialize, Serialize}; -use std::path::PathBuf; - -/// Skill metadata parsed from SKILL.md frontmatter -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SkillMetadata { - pub name: String, - pub description: String, - pub path: PathBuf, -} - -/// Full skill content (loaded on demand) -#[derive(Debug, Clone)] -pub struct SkillContent { - pub metadata: SkillMetadata, - pub instructions: String, -} -``` - -### Task 9a.3: Implement Skill Loader (TDD) - -```rust -//! 
Skill loading from filesystem - -pub struct SkillLoader { - skills_dir: PathBuf, -} - -impl SkillLoader { - pub fn new(skills_dir: impl Into) -> Self { - Self { skills_dir: skills_dir.into() } - } - - /// List all skills (metadata only - progressive disclosure) - pub async fn list_skills(&self) -> Result, DeepAgentError> { - // Scan skills_dir for SKILL.md files - // Parse YAML frontmatter for name + description - todo!() - } - - /// Load full skill content by name - pub async fn load_skill(&self, name: &str) -> Result { - // Read full SKILL.md content - todo!() - } -} -``` - -### Task 9a.4: Implement SkillsMiddleware (TDD) - -```rust -//! Skills middleware with progressive disclosure - -pub struct SkillsMiddleware { - loader: SkillLoader, - cached_metadata: RwLock>, -} - -#[async_trait] -impl AgentMiddleware for SkillsMiddleware { - async fn modify_prompt(&self, prompt: &str, _state: &AgentState) -> String { - // Inject skill metadata into system prompt - let skills = self.cached_metadata.read().await; - let skills_section = format_skills_prompt(&skills); - format!("{}\n\n{}", prompt, skills_section) - } - - async fn before_agent( - &self, - state: &mut AgentState, - _runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { - // Load skill metadata on first call - if self.cached_metadata.read().await.is_empty() { - let metadata = self.loader.list_skills().await?; - *self.cached_metadata.write().await = metadata; - } - Ok(None) - } -} -``` - ---- - -## Phase 9b: Domain Tools - -### Overview - -Implement research-specific tools matching Python reference. - -**Python Reference:** `research_agent/tools.py` - -### Task 9b.1: Implement TavilySearchTool (TDD) - -**Files:** -- Create: `rust-research-agent/crates/rig-deepagents/src/tools/tavily.rs` - -```rust -//! 
Tavily web search tool - -use reqwest::Client; -use serde::{Deserialize, Serialize}; - -pub struct TavilySearchTool { - client: Client, - api_key: String, -} - -#[derive(Debug, Serialize)] -struct TavilyRequest { - query: String, - max_results: usize, - search_depth: String, - include_raw_content: bool, -} - -#[derive(Debug, Deserialize)] -struct TavilyResponse { - results: Vec, -} - -#[derive(Debug, Deserialize)] -struct TavilyResult { - title: String, - url: String, - content: String, - raw_content: Option, -} - -#[async_trait] -impl Tool for TavilySearchTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "tavily_search".to_string(), - description: "Search the web using Tavily API and retrieve relevant content.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "Search query" - }, - "max_results": { - "type": "integer", - "description": "Maximum number of results", - "default": 5 - }, - "topic": { - "type": "string", - "enum": ["general", "news"], - "default": "general" - } - }, - "required": ["query"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - // Implement Tavily API call - todo!() - } -} -``` - -### Task 9b.2: Implement ThinkTool (TDD) - -```rust -//! 
Think tool for explicit reasoning - -pub struct ThinkTool; - -#[async_trait] -impl Tool for ThinkTool { - fn definition(&self) -> ToolDefinition { - ToolDefinition { - name: "think".to_string(), - description: "Use this tool for explicit reflection and reasoning before making decisions.".to_string(), - parameters: serde_json::json!({ - "type": "object", - "properties": { - "reflection": { - "type": "string", - "description": "Your current thinking and reasoning" - } - }, - "required": ["reflection"] - }), - } - } - - async fn execute( - &self, - args: serde_json::Value, - _runtime: &ToolRuntime, - ) -> Result { - let reflection: String = serde_json::from_value(args["reflection"].clone()) - .unwrap_or_default(); - - // Think tool just returns the reflection - it's for prompting explicit reasoning - Ok(format!("Reflection recorded: {}", reflection)) - } -} -``` - -### Task 9b.3: Update tools/mod.rs - -Add new tools to exports and `all_tools()` function. - ---- - -## Summary - -| Phase | Tasks | Priority | Est. Time | -|-------|-------|----------|-----------| -| **Phase 7** | LLM Provider Abstraction | 🔴 CRITICAL | 3-4 hours | -| **Phase 8** | SubAgent Execution | 🔴 CRITICAL | 3-4 hours | -| **Phase 9a** | Skills Middleware | 🟡 HIGH | 2-3 hours | -| **Phase 9b** | Domain Tools | 🟡 HIGH | 2-3 hours | - -**Total Estimated Time:** 10-14 hours - -**TDD Verification at Each Phase:** -1. Write failing test -2. Implement minimum code to pass -3. Refactor if needed -4. Run full test suite -5. Run clippy -6. 
Commit - ---- - -## Error Type Additions - -Add to `src/error.rs`: - -```rust -#[error("LLM error: {0}")] -LLMError(String), - -#[error("Recursion limit exceeded: {0}")] -RecursionLimit(String), - -#[error("Skill error: {0}")] -SkillError(String), -``` diff --git a/rust-research-agent/crates/rig-deepagents/Cargo.toml b/rust-research-agent/crates/rig-deepagents/Cargo.toml index 0147580..b25ba85 100644 --- a/rust-research-agent/crates/rig-deepagents/Cargo.toml +++ b/rust-research-agent/crates/rig-deepagents/Cargo.toml @@ -5,6 +5,12 @@ edition = "2021" description = "DeepAgents-style middleware system for Rig framework" license = "MIT" +[features] +default = [] +checkpointer-sqlite = [] +checkpointer-redis = [] +checkpointer-postgres = [] + [dependencies] rig-core = { version = "0.27", features = ["derive"] } tokio = { version = "1", features = ["full", "sync"] } @@ -20,12 +26,18 @@ uuid = { version = "1", features = ["v4"] } walkdir = "2" # Added: needed for FilesystemBackend recursive traversal futures = "0.3" # Added: needed for LLMProvider streaming support +# Pregel runtime dependencies +num_cpus = "1" # For default parallelism configuration +humantime-serde = "1" # For Duration serialization in configs +zstd = "0.13" # For checkpoint compression + [dev-dependencies] # OpenAI support is built into rig-core tokio-test = "0.4" dotenv = "0.15" criterion = { version = "0.5", features = ["async_tokio"] } tempfile = "3" # For filesystem tests +static_assertions = "1" # For compile-time trait checks [[bench]] name = "middleware_benchmark" diff --git a/rust-research-agent/crates/rig-deepagents/src/lib.rs b/rust-research-agent/crates/rig-deepagents/src/lib.rs index 5594409..30e567a 100644 --- a/rust-research-agent/crates/rig-deepagents/src/lib.rs +++ b/rust-research-agent/crates/rig-deepagents/src/lib.rs @@ -32,6 +32,8 @@ pub mod runtime; pub mod executor; pub mod tools; pub mod llm; +pub mod pregel; +pub mod workflow; // Re-exports for convenience pub use 
error::{BackendError, MiddlewareError, DeepAgentError, WriteResult, EditResult}; diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/file.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/file.rs new file mode 100644 index 0000000..85240fc --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/file.rs @@ -0,0 +1,438 @@ +//! File-based Checkpointer Implementation +//! +//! Stores checkpoints as JSON files in a directory structure. +//! Supports optional compression via zstd for reduced storage. +//! +//! # Directory Structure +//! +//! ```text +//! checkpoints/ +//! └── {workflow_id}/ +//! ├── checkpoint_00001.json[.zst] +//! ├── checkpoint_00005.json[.zst] +//! └── checkpoint_00010.json[.zst] +//! ``` + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::io::Write; +use std::path::{Path, PathBuf}; +use tokio::fs; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +use super::{Checkpoint, Checkpointer}; +use crate::pregel::error::PregelError; +use crate::pregel::state::WorkflowState; + +/// File-based checkpointer that stores checkpoints as JSON files. +/// +/// Each checkpoint is stored in a separate file, named by superstep number. +/// Atomic writes are ensured via temporary file + rename pattern. +#[derive(Debug)] +pub struct FileCheckpointer { + /// Workflow-specific subdirectory + workflow_path: PathBuf, + /// Whether to compress checkpoints with zstd + compression: bool, +} + +impl FileCheckpointer { + /// Create a new file-based checkpointer. 
+ /// + /// # Arguments + /// + /// * `base_path` - Base directory for storing checkpoints + /// * `workflow_id` - Unique identifier for this workflow + /// * `compression` - Whether to compress checkpoint data + pub fn new(base_path: impl Into, workflow_id: impl AsRef, compression: bool) -> Self { + let base_path = base_path.into(); + let workflow_path = base_path.join(workflow_id.as_ref()); + + Self { + workflow_path, + compression, + } + } + + /// Get the file path for a checkpoint at a given superstep + fn checkpoint_path(&self, superstep: usize) -> PathBuf { + let filename = if self.compression { + format!("checkpoint_{:05}.json.zst", superstep) + } else { + format!("checkpoint_{:05}.json", superstep) + }; + self.workflow_path.join(filename) + } + + /// Get the temporary file path for atomic writes + fn temp_path(&self, superstep: usize) -> PathBuf { + let filename = format!("checkpoint_{:05}.tmp", superstep); + self.workflow_path.join(filename) + } + + /// Ensure the checkpoint directory exists + async fn ensure_dir(&self) -> Result<(), PregelError> { + fs::create_dir_all(&self.workflow_path) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to create directory: {}", e))) + } + + /// Compress data using zstd + fn compress(data: &[u8]) -> Result, PregelError> { + let mut encoder = zstd::stream::Encoder::new(Vec::new(), 3) + .map_err(|e| PregelError::checkpoint_error(format!("Compression init failed: {}", e)))?; + encoder + .write_all(data) + .map_err(|e| PregelError::checkpoint_error(format!("Compression write failed: {}", e)))?; + encoder + .finish() + .map_err(|e| PregelError::checkpoint_error(format!("Compression finish failed: {}", e))) + } + + /// Decompress data using zstd + fn decompress(data: &[u8]) -> Result, PregelError> { + zstd::stream::decode_all(data) + .map_err(|e| PregelError::checkpoint_error(format!("Decompression failed: {}", e))) + } + + /// Parse superstep number from filename + fn parse_superstep(path: &Path) -> Option 
{ + let filename = path.file_name()?.to_str()?; + if !filename.starts_with("checkpoint_") { + return None; + } + + // Extract the number between "checkpoint_" and the extension + let num_part = filename + .strip_prefix("checkpoint_")? + .split('.') + .next()?; + + num_part.parse().ok() + } + + /// List all superstep numbers (non-generic helper method) + async fn list_supersteps(&self) -> Result, PregelError> { + if !self.workflow_path.exists() { + return Ok(Vec::new()); + } + + let mut entries = fs::read_dir(&self.workflow_path) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to read directory: {}", e)))?; + + let mut supersteps = Vec::new(); + + while let Some(entry) = entries + .next_entry() + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to read entry: {}", e)))? + { + if let Some(superstep) = Self::parse_superstep(&entry.path()) { + supersteps.push(superstep); + } + } + + supersteps.sort(); + Ok(supersteps) + } +} + +#[async_trait] +impl Checkpointer for FileCheckpointer +where + S: WorkflowState + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>, +{ + async fn save(&self, checkpoint: &Checkpoint) -> Result<(), PregelError> { + self.ensure_dir().await?; + + // Serialize checkpoint + let json = serde_json::to_vec_pretty(checkpoint) + .map_err(|e| PregelError::checkpoint_error(format!("Serialization failed: {}", e)))?; + + // Optionally compress + let data = if self.compression { + Self::compress(&json)? 
+ } else { + json + }; + + // Write to temp file first (atomic write pattern) + let temp_path = self.temp_path(checkpoint.superstep); + let final_path = self.checkpoint_path(checkpoint.superstep); + + let mut file = fs::File::create(&temp_path) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to create temp file: {}", e)))?; + + file.write_all(&data) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to write data: {}", e)))?; + + file.sync_all() + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to sync file: {}", e)))?; + + // Atomic rename + fs::rename(&temp_path, &final_path) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to rename file: {}", e)))?; + + Ok(()) + } + + async fn load(&self, superstep: usize) -> Result>, PregelError> { + let path = self.checkpoint_path(superstep); + + if !path.exists() { + return Ok(None); + } + + let mut file = fs::File::open(&path) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to open file: {}", e)))?; + + let mut data = Vec::new(); + file.read_to_end(&mut data) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to read file: {}", e)))?; + + // Decompress if needed + let json = if self.compression { + Self::decompress(&data)? 
+ } else { + data + }; + + let checkpoint: Checkpoint = serde_json::from_slice(&json) + .map_err(|e| PregelError::checkpoint_error(format!("Deserialization failed: {}", e)))?; + + Ok(Some(checkpoint)) + } + + async fn latest(&self) -> Result>, PregelError> { + // Use our own list_supersteps to avoid type inference issues + let supersteps = self.list_supersteps().await?; + match supersteps.last() { + Some(&superstep) => self.load(superstep).await, + None => Ok(None), + } + } + + async fn list(&self) -> Result, PregelError> { + self.list_supersteps().await + } + + async fn delete(&self, superstep: usize) -> Result<(), PregelError> { + let path = self.checkpoint_path(superstep); + + if path.exists() { + fs::remove_file(&path) + .await + .map_err(|e| PregelError::checkpoint_error(format!("Failed to delete file: {}", e)))?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pregel::state::UnitState; + use crate::pregel::vertex::{VertexId, VertexState}; + use std::collections::HashMap; + use tempfile::tempdir; + + #[tokio::test] + async fn test_file_checkpointer_save_load() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + let checkpoint = Checkpoint::new( + "test-workflow", + 5, + UnitState, + HashMap::new(), + HashMap::new(), + ); + + checkpointer.save(&checkpoint).await.unwrap(); + let loaded: Checkpoint = checkpointer.load(5).await.unwrap().unwrap(); + + assert_eq!(loaded.superstep, 5); + assert_eq!(loaded.workflow_id, "test-workflow"); + } + + #[tokio::test] + async fn test_file_checkpointer_with_compression() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "compressed-workflow", true); + + // Create a checkpoint with some data + let mut vertex_states = HashMap::new(); + vertex_states.insert(VertexId::new("vertex1"), VertexState::Active); + vertex_states.insert(VertexId::new("vertex2"), VertexState::Halted); + 
+ let checkpoint = Checkpoint::new( + "compressed-workflow", + 10, + UnitState, + vertex_states, + HashMap::new(), + ); + + checkpointer.save(&checkpoint).await.unwrap(); + + // Verify the file is compressed (has .zst extension) + let path = temp_dir.path().join("compressed-workflow/checkpoint_00010.json.zst"); + assert!(path.exists()); + + // Load and verify + let loaded: Checkpoint = checkpointer.load(10).await.unwrap().unwrap(); + assert_eq!(loaded.superstep, 10); + assert_eq!(loaded.vertex_states.len(), 2); + } + + #[tokio::test] + async fn test_file_checkpointer_load_nonexistent() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + let result: Option> = checkpointer.load(999).await.unwrap(); + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_file_checkpointer_list() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + // Save checkpoints at supersteps 5, 1, 10 (out of order) + for superstep in [5, 1, 10] { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + let list = >::list(&checkpointer).await.unwrap(); + assert_eq!(list, vec![1, 5, 10]); // Should be sorted + } + + #[tokio::test] + async fn test_file_checkpointer_latest() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + // Save checkpoints + for superstep in [1, 5, 3] { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + let latest: Checkpoint = checkpointer.latest().await.unwrap().unwrap(); + assert_eq!(latest.superstep, 5); + } + + #[tokio::test] + async fn test_file_checkpointer_delete() { + let temp_dir = 
tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + let checkpoint = Checkpoint::new( + "test-workflow", + 5, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + + // Verify exists + let exists: Option> = checkpointer.load(5).await.unwrap(); + assert!(exists.is_some()); + + // Delete + >::delete(&checkpointer, 5).await.unwrap(); + + // Verify gone + let gone: Option> = checkpointer.load(5).await.unwrap(); + assert!(gone.is_none()); + } + + #[tokio::test] + async fn test_file_checkpointer_prune() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + // Create 5 checkpoints + for superstep in 1..=5 { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + // Prune to keep only 2 + let deleted = >::prune(&checkpointer, 2).await.unwrap(); + assert_eq!(deleted, 3); + + let remaining = >::list(&checkpointer).await.unwrap(); + assert_eq!(remaining, vec![4, 5]); + } + + #[tokio::test] + async fn test_file_checkpointer_atomic_write() { + let temp_dir = tempdir().unwrap(); + let checkpointer = FileCheckpointer::new(temp_dir.path(), "test-workflow", false); + + let checkpoint = Checkpoint::new( + "test-workflow", + 7, + UnitState, + HashMap::new(), + HashMap::new(), + ); + + checkpointer.save(&checkpoint).await.unwrap(); + + // Verify no temp file remains + let temp_path = temp_dir.path().join("test-workflow/checkpoint_00007.tmp"); + assert!(!temp_path.exists()); + + // Verify final file exists + let final_path = temp_dir.path().join("test-workflow/checkpoint_00007.json"); + assert!(final_path.exists()); + } + + #[test] + fn test_parse_superstep() { + assert_eq!( + FileCheckpointer::parse_superstep(Path::new("checkpoint_00005.json")), + Some(5) + ); + assert_eq!( + 
FileCheckpointer::parse_superstep(Path::new("checkpoint_00123.json.zst")), + Some(123) + ); + assert_eq!( + FileCheckpointer::parse_superstep(Path::new("other_file.json")), + None + ); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/mod.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/mod.rs new file mode 100644 index 0000000..0255537 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/checkpoint/mod.rs @@ -0,0 +1,551 @@ +//! Checkpointing System for Pregel Runtime +//! +//! Provides durable state persistence for fault-tolerant workflow execution. +//! Checkpoints capture the complete workflow state at superstep boundaries, +//! enabling recovery from failures without losing progress. +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────┐ +//! │ Checkpointer │ +//! │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ +//! │ │ File │ │ SQLite │ │ Redis │ │ Postgres │ │ +//! │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ +//! │ │ │ │ │ │ +//! │ └────────────┴────────────┴────────────┘ │ +//! │ │ │ +//! │ ▼ │ +//! │ Checkpoint │ +//! └─────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Usage +//! +//! ```ignore +//! use rig_deepagents::pregel::checkpoint::{Checkpointer, CheckpointerConfig, create_checkpointer}; +//! +//! // Create a file-based checkpointer +//! let config = CheckpointerConfig::File { +//! path: PathBuf::from("./checkpoints"), +//! compression: true, +//! }; +//! let checkpointer = create_checkpointer::(config)?; +//! +//! // Save a checkpoint +//! checkpointer.save(&checkpoint).await?; +//! +//! // Load the latest checkpoint +//! if let Some(checkpoint) = checkpointer.latest().await? { +//! // Resume from checkpoint +//! } +//! 
``` + +mod file; + +pub use file::FileCheckpointer; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt::Debug; +use std::path::PathBuf; + +use super::error::PregelError; +use super::message::WorkflowMessage; +use super::state::WorkflowState; +use super::vertex::{VertexId, VertexState}; + +/// A checkpoint captures the complete workflow state at a superstep boundary. +/// +/// Checkpoints are the foundation of fault tolerance in the Pregel runtime. +/// They capture: +/// - The workflow state (user-defined data) +/// - The state of each vertex (Active, Halted, Completed) +/// - Pending messages that haven't been delivered yet +/// - Metadata for debugging and recovery +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Checkpoint<S> +where + S: WorkflowState, +{ + /// Unique identifier for the workflow instance + pub workflow_id: String, + + /// The superstep number when this checkpoint was created + pub superstep: usize, + + /// The workflow state at this superstep + pub state: S, + + /// The state of each vertex (Active, Halted, Completed) + pub vertex_states: HashMap<VertexId, VertexState>, + + /// Pending messages waiting to be delivered in the next superstep + pub pending_messages: HashMap<VertexId, Vec<WorkflowMessage>>, + + /// When this checkpoint was created + pub timestamp: DateTime<Utc>, + + /// Optional metadata for debugging or external tools + #[serde(default)] + pub metadata: HashMap<String, String>, +} + +impl<S> Checkpoint<S> +where + S: WorkflowState, +{ + /// Create a new checkpoint + pub fn new( + workflow_id: impl Into<String>, + superstep: usize, + state: S, + vertex_states: HashMap<VertexId, VertexState>, + pending_messages: HashMap<VertexId, Vec<WorkflowMessage>>, + ) -> Self { + Self { + workflow_id: workflow_id.into(), + superstep, + state, + vertex_states, + pending_messages, + timestamp: Utc::now(), + metadata: HashMap::new(), + } + } + + /// Add metadata to this checkpoint + pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self { 
+ self.metadata.insert(key.into(), value.into()); + self + } + + /// Check if this checkpoint is empty (no vertex states or messages) + pub fn is_empty(&self) -> bool { + self.vertex_states.is_empty() && self.pending_messages.is_empty() + } + + /// Get the total number of pending messages across all vertices + pub fn pending_message_count(&self) -> usize { + self.pending_messages.values().map(|v| v.len()).sum() + } +} + +/// Trait for checkpointing workflow state. +/// +/// Implementations provide durable storage for checkpoints, enabling +/// recovery from failures and inspection of workflow history. +#[async_trait] +pub trait Checkpointer<S>: Send + Sync +where + S: WorkflowState + Send + Sync, +{ + /// Save a checkpoint. + /// + /// Implementations should ensure atomic writes to prevent corruption. + async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), PregelError>; + + /// Load a checkpoint by superstep number. + /// + /// Returns `None` if no checkpoint exists for that superstep. + async fn load(&self, superstep: usize) -> Result<Option<Checkpoint<S>>, PregelError>; + + /// Load the latest checkpoint. + /// + /// Returns `None` if no checkpoints exist. + async fn latest(&self) -> Result<Option<Checkpoint<S>>, PregelError>; + + /// List all available checkpoint superstep numbers, sorted ascending. + async fn list(&self) -> Result<Vec<usize>, PregelError>; + + /// Delete a specific checkpoint. + async fn delete(&self, superstep: usize) -> Result<(), PregelError>; + + /// Prune checkpoints, keeping only the most recent `keep` checkpoints. + /// + /// This is useful for managing storage space in long-running workflows. + async fn prune(&self, keep: usize) -> Result<usize, PregelError> { + let checkpoints = self.list().await?; + let to_delete = checkpoints.len().saturating_sub(keep); + let mut deleted = 0; + + for superstep in checkpoints.into_iter().take(to_delete) { + self.delete(superstep).await?; + deleted += 1; + } + + Ok(deleted) + } + + /// Clear all checkpoints for this workflow. 
+ async fn clear(&self) -> Result<(), PregelError> { + for superstep in self.list().await? { + self.delete(superstep).await?; + } + Ok(()) + } +} + +/// Configuration for creating checkpointers. +/// +/// Use with `create_checkpointer()` to instantiate the appropriate backend. +#[derive(Debug, Clone, Default)] +pub enum CheckpointerConfig { + /// In-memory checkpointing (for testing only, not durable) + #[default] + Memory, + + /// File-based checkpointing + File { + /// Directory to store checkpoint files + path: PathBuf, + /// Whether to compress checkpoint data (uses zstd) + compression: bool, + }, + + /// SQLite-based checkpointing (requires `checkpointer-sqlite` feature) + #[cfg(feature = "checkpointer-sqlite")] + Sqlite { + /// Path to the SQLite database file, or `:memory:` for in-memory + path: String, + }, + + /// Redis-based checkpointing (requires `checkpointer-redis` feature) + #[cfg(feature = "checkpointer-redis")] + Redis { + /// Redis connection URL + url: String, + /// TTL for checkpoint keys (optional) + ttl_seconds: Option, + }, + + /// PostgreSQL-based checkpointing (requires `checkpointer-postgres` feature) + #[cfg(feature = "checkpointer-postgres")] + Postgres { + /// PostgreSQL connection URL + url: String, + }, +} + + +/// In-memory checkpointer for testing. +/// +/// This implementation stores checkpoints in memory and is not durable. +/// Use only for testing or development. 
+#[derive(Debug, Default)] +pub struct MemoryCheckpointer<S> +where + S: WorkflowState, +{ + checkpoints: tokio::sync::RwLock<HashMap<usize, Checkpoint<S>>>, +} + +impl<S> MemoryCheckpointer<S> +where + S: WorkflowState, +{ + /// Create a new in-memory checkpointer + pub fn new() -> Self { + Self { + checkpoints: tokio::sync::RwLock::new(HashMap::new()), + } + } +} + +#[async_trait] +impl<S> Checkpointer<S> for MemoryCheckpointer<S> +where + S: WorkflowState + Clone + Send + Sync, +{ + async fn save(&self, checkpoint: &Checkpoint<S>) -> Result<(), PregelError> { + let mut checkpoints = self.checkpoints.write().await; + checkpoints.insert(checkpoint.superstep, checkpoint.clone()); + Ok(()) + } + + async fn load(&self, superstep: usize) -> Result<Option<Checkpoint<S>>, PregelError> { + let checkpoints = self.checkpoints.read().await; + Ok(checkpoints.get(&superstep).cloned()) + } + + async fn latest(&self) -> Result<Option<Checkpoint<S>>, PregelError> { + let checkpoints = self.checkpoints.read().await; + let max_superstep = checkpoints.keys().max().copied(); + match max_superstep { + Some(superstep) => Ok(checkpoints.get(&superstep).cloned()), + None => Ok(None), + } + } + + async fn list(&self) -> Result<Vec<usize>, PregelError> { + let checkpoints = self.checkpoints.read().await; + let mut supersteps: Vec<usize> = checkpoints.keys().copied().collect(); + supersteps.sort(); + Ok(supersteps) + } + + async fn delete(&self, superstep: usize) -> Result<(), PregelError> { + let mut checkpoints = self.checkpoints.write().await; + checkpoints.remove(&superstep); + Ok(()) + } +} + +/// Create a checkpointer from configuration. +/// +/// This factory function creates the appropriate checkpointer backend +/// based on the provided configuration. 
+/// +/// # Example +/// +/// ```ignore +/// let config = CheckpointerConfig::File { +/// path: PathBuf::from("./checkpoints"), +/// compression: true, +/// }; +/// let checkpointer = create_checkpointer::(config, "workflow-123")?; +/// ``` +pub fn create_checkpointer( + config: CheckpointerConfig, + workflow_id: impl Into, +) -> Result>, PregelError> +where + S: WorkflowState + Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + 'static, +{ + let workflow_id = workflow_id.into(); + + match config { + CheckpointerConfig::Memory => Ok(Box::new(MemoryCheckpointer::::new())), + + CheckpointerConfig::File { path, compression } => { + let checkpointer = FileCheckpointer::new(path, workflow_id, compression); + Ok(Box::new(checkpointer)) + } + + #[cfg(feature = "checkpointer-sqlite")] + CheckpointerConfig::Sqlite { path } => { + // SQLite checkpointer will be implemented in Task 8.2.3 + Err(PregelError::not_implemented("SQLite checkpointer")) + } + + #[cfg(feature = "checkpointer-redis")] + CheckpointerConfig::Redis { url, ttl_seconds } => { + // Redis checkpointer will be implemented in Task 8.2.4 + Err(PregelError::not_implemented("Redis checkpointer")) + } + + #[cfg(feature = "checkpointer-postgres")] + CheckpointerConfig::Postgres { url } => { + // PostgreSQL checkpointer will be implemented in Task 8.2.5 + Err(PregelError::not_implemented("PostgreSQL checkpointer")) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pregel::state::UnitState; + + #[test] + fn test_checkpoint_creation() { + let checkpoint = Checkpoint::new( + "test-workflow", + 5, + UnitState, + HashMap::new(), + HashMap::new(), + ); + + assert_eq!(checkpoint.workflow_id, "test-workflow"); + assert_eq!(checkpoint.superstep, 5); + assert!(checkpoint.is_empty()); + assert_eq!(checkpoint.pending_message_count(), 0); + } + + #[test] + fn test_checkpoint_with_metadata() { + let checkpoint = Checkpoint::new( + "test-workflow", + 10, + UnitState, + HashMap::new(), + HashMap::new(), 
+ ) + .with_metadata("version", "1.0") + .with_metadata("creator", "test"); + + assert_eq!(checkpoint.metadata.get("version"), Some(&"1.0".to_string())); + assert_eq!(checkpoint.metadata.get("creator"), Some(&"test".to_string())); + } + + #[test] + fn test_checkpoint_with_vertex_states() { + let mut vertex_states = HashMap::new(); + vertex_states.insert(VertexId::new("a"), VertexState::Active); + vertex_states.insert(VertexId::new("b"), VertexState::Halted); + + let checkpoint = Checkpoint::new( + "test-workflow", + 3, + UnitState, + vertex_states, + HashMap::new(), + ); + + assert!(!checkpoint.is_empty()); + assert_eq!(checkpoint.vertex_states.len(), 2); + } + + #[test] + fn test_checkpoint_pending_message_count() { + let mut pending_messages = HashMap::new(); + pending_messages.insert( + VertexId::new("a"), + vec![WorkflowMessage::Activate, WorkflowMessage::Activate], + ); + pending_messages.insert( + VertexId::new("b"), + vec![WorkflowMessage::Activate], + ); + + let checkpoint = Checkpoint::new( + "test-workflow", + 7, + UnitState, + HashMap::new(), + pending_messages, + ); + + assert_eq!(checkpoint.pending_message_count(), 3); + } + + #[tokio::test] + async fn test_memory_checkpointer_save_load() { + let checkpointer = MemoryCheckpointer::::new(); + + let checkpoint = Checkpoint::new( + "test-workflow", + 5, + UnitState, + HashMap::new(), + HashMap::new(), + ); + + checkpointer.save(&checkpoint).await.unwrap(); + let loaded = checkpointer.load(5).await.unwrap().unwrap(); + + assert_eq!(loaded.superstep, 5); + assert_eq!(loaded.workflow_id, "test-workflow"); + } + + #[tokio::test] + async fn test_memory_checkpointer_latest() { + let checkpointer = MemoryCheckpointer::::new(); + + // Save checkpoints at supersteps 1, 3, 5 + for superstep in [1, 3, 5] { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + let latest = 
checkpointer.latest().await.unwrap().unwrap(); + assert_eq!(latest.superstep, 5); + } + + #[tokio::test] + async fn test_memory_checkpointer_list() { + let checkpointer = MemoryCheckpointer::::new(); + + // Save checkpoints at supersteps 5, 1, 3 (out of order) + for superstep in [5, 1, 3] { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + let list = checkpointer.list().await.unwrap(); + assert_eq!(list, vec![1, 3, 5]); // Should be sorted + } + + #[tokio::test] + async fn test_memory_checkpointer_delete() { + let checkpointer = MemoryCheckpointer::::new(); + + let checkpoint = Checkpoint::new( + "test-workflow", + 5, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + + checkpointer.delete(5).await.unwrap(); + let loaded = checkpointer.load(5).await.unwrap(); + assert!(loaded.is_none()); + } + + #[tokio::test] + async fn test_memory_checkpointer_prune() { + let checkpointer = MemoryCheckpointer::::new(); + + // Save 5 checkpoints + for superstep in 1..=5 { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + // Prune to keep only 2 + let deleted = checkpointer.prune(2).await.unwrap(); + assert_eq!(deleted, 3); + + let remaining = checkpointer.list().await.unwrap(); + assert_eq!(remaining, vec![4, 5]); + } + + #[tokio::test] + async fn test_memory_checkpointer_clear() { + let checkpointer = MemoryCheckpointer::::new(); + + for superstep in 1..=3 { + let checkpoint = Checkpoint::new( + "test-workflow", + superstep, + UnitState, + HashMap::new(), + HashMap::new(), + ); + checkpointer.save(&checkpoint).await.unwrap(); + } + + checkpointer.clear().await.unwrap(); + let list = checkpointer.list().await.unwrap(); + assert!(list.is_empty()); + } + + #[test] + fn 
test_checkpointer_config_default() { + let config = CheckpointerConfig::default(); + assert!(matches!(config, CheckpointerConfig::Memory)); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/config.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/config.rs new file mode 100644 index 0000000..abe5cee --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/config.rs @@ -0,0 +1,273 @@ +//! Pregel runtime configuration +//! +//! Configuration for the Pregel execution engine including +//! parallelism, timeouts, checkpointing, and retry policies. + +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +/// Pregel runtime configuration +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PregelConfig { + /// Maximum supersteps before forced termination + pub max_supersteps: usize, + + /// Maximum concurrent vertex computations + pub parallelism: usize, + + /// Checkpoint frequency (every N supersteps, 0 = disabled) + pub checkpoint_interval: usize, + + /// Timeout for individual vertex computation + #[serde(with = "humantime_serde")] + pub vertex_timeout: Duration, + + /// Timeout for entire workflow + #[serde(with = "humantime_serde")] + pub workflow_timeout: Duration, + + /// Enable detailed tracing + pub tracing_enabled: bool, + + /// Retry policy for failed vertices + pub retry_policy: RetryPolicy, +} + +impl Default for PregelConfig { + fn default() -> Self { + Self { + max_supersteps: 100, + parallelism: num_cpus::get(), + checkpoint_interval: 10, + vertex_timeout: Duration::from_secs(300), // 5 min per vertex + workflow_timeout: Duration::from_secs(3600), // 1 hour total + tracing_enabled: true, + retry_policy: RetryPolicy::default(), + } + } +} + +impl PregelConfig { + /// Create a new config with defaults + pub fn new() -> Self { + Self::default() + } + + /// Set maximum supersteps + pub fn with_max_supersteps(mut self, max: usize) -> Self { + self.max_supersteps = max; + self + } + + /// Set 
parallelism level + pub fn with_parallelism(mut self, parallelism: usize) -> Self { + self.parallelism = parallelism.max(1); + self + } + + /// Set checkpoint interval (0 to disable) + pub fn with_checkpoint_interval(mut self, interval: usize) -> Self { + self.checkpoint_interval = interval; + self + } + + /// Set vertex timeout + pub fn with_vertex_timeout(mut self, timeout: Duration) -> Self { + self.vertex_timeout = timeout; + self + } + + /// Set workflow timeout + pub fn with_workflow_timeout(mut self, timeout: Duration) -> Self { + self.workflow_timeout = timeout; + self + } + + /// Enable or disable tracing + pub fn with_tracing(mut self, enabled: bool) -> Self { + self.tracing_enabled = enabled; + self + } + + /// Set retry policy + pub fn with_retry_policy(mut self, policy: RetryPolicy) -> Self { + self.retry_policy = policy; + self + } + + /// Check if checkpointing is enabled + pub fn checkpointing_enabled(&self) -> bool { + self.checkpoint_interval > 0 + } + + /// Check if a checkpoint should be taken at this superstep + #[allow(clippy::manual_is_multiple_of)] // Using % for compatibility with older Rust versions + pub fn should_checkpoint(&self, superstep: usize) -> bool { + self.checkpointing_enabled() && superstep > 0 && superstep % self.checkpoint_interval == 0 + } +} + +/// Retry policy for failed vertex computations +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RetryPolicy { + /// Maximum retry attempts + pub max_retries: usize, + + /// Base delay for exponential backoff + #[serde(with = "humantime_serde")] + pub backoff_base: Duration, + + /// Maximum delay between retries + #[serde(with = "humantime_serde")] + pub backoff_max: Duration, +} + +impl Default for RetryPolicy { + fn default() -> Self { + Self { + max_retries: 3, + backoff_base: Duration::from_millis(100), + backoff_max: Duration::from_secs(10), + } + } +} + +impl RetryPolicy { + /// Create a new retry policy + pub fn new(max_retries: usize) -> Self { + Self { + 
max_retries, + ..Default::default() + } + } + + /// Set backoff base duration + pub fn with_backoff_base(mut self, base: Duration) -> Self { + self.backoff_base = base; + self + } + + /// Set maximum backoff duration + pub fn with_backoff_max(mut self, max: Duration) -> Self { + self.backoff_max = max; + self + } + + /// Calculate delay for a given retry attempt (exponential backoff) + pub fn delay_for_attempt(&self, attempt: usize) -> Duration { + let multiplier = 2u32.saturating_pow(attempt as u32); + let delay = self.backoff_base.saturating_mul(multiplier); + delay.min(self.backoff_max) + } + + /// Check if more retries are allowed + pub fn should_retry(&self, attempts: usize) -> bool { + attempts < self.max_retries + } + + /// Create a no-retry policy + pub fn no_retry() -> Self { + Self { + max_retries: 0, + ..Default::default() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = PregelConfig::default(); + assert_eq!(config.max_supersteps, 100); + assert!(config.parallelism > 0); + assert_eq!(config.checkpoint_interval, 10); + assert!(config.tracing_enabled); + } + + #[test] + fn test_config_builder() { + let config = PregelConfig::default() + .with_max_supersteps(50) + .with_parallelism(4) + .with_checkpoint_interval(5); + + assert_eq!(config.max_supersteps, 50); + assert_eq!(config.parallelism, 4); + assert_eq!(config.checkpoint_interval, 5); + } + + #[test] + fn test_parallelism_minimum() { + let config = PregelConfig::default().with_parallelism(0); + assert_eq!(config.parallelism, 1); + } + + #[test] + fn test_checkpointing_enabled() { + let config = PregelConfig::default(); + assert!(config.checkpointing_enabled()); + + let disabled = config.with_checkpoint_interval(0); + assert!(!disabled.checkpointing_enabled()); + } + + #[test] + fn test_should_checkpoint() { + let config = PregelConfig::default().with_checkpoint_interval(5); + + assert!(!config.should_checkpoint(0)); + 
assert!(!config.should_checkpoint(1)); + assert!(config.should_checkpoint(5)); + assert!(config.should_checkpoint(10)); + assert!(!config.should_checkpoint(7)); + } + + #[test] + fn test_retry_policy_default() { + let policy = RetryPolicy::default(); + assert_eq!(policy.max_retries, 3); + assert!(policy.should_retry(0)); + assert!(policy.should_retry(2)); + assert!(!policy.should_retry(3)); + } + + #[test] + fn test_retry_backoff() { + let policy = RetryPolicy::default(); + + let delay0 = policy.delay_for_attempt(0); + let delay1 = policy.delay_for_attempt(1); + let delay2 = policy.delay_for_attempt(2); + + assert_eq!(delay0, Duration::from_millis(100)); + assert_eq!(delay1, Duration::from_millis(200)); + assert_eq!(delay2, Duration::from_millis(400)); + } + + #[test] + fn test_retry_backoff_max() { + let policy = RetryPolicy::default().with_backoff_max(Duration::from_millis(300)); + + let delay_high = policy.delay_for_attempt(10); + assert_eq!(delay_high, Duration::from_millis(300)); + } + + #[test] + fn test_no_retry_policy() { + let policy = RetryPolicy::no_retry(); + assert!(!policy.should_retry(0)); + } + + #[test] + fn test_config_with_timeouts() { + let config = PregelConfig::default() + .with_vertex_timeout(Duration::from_secs(60)) + .with_workflow_timeout(Duration::from_secs(120)); + + assert_eq!(config.vertex_timeout, Duration::from_secs(60)); + assert_eq!(config.workflow_timeout, Duration::from_secs(120)); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/error.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/error.rs new file mode 100644 index 0000000..652a38c --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/error.rs @@ -0,0 +1,243 @@ +//! Error types for Pregel runtime +//! +//! Comprehensive error handling for the Pregel execution engine. 
+ +use super::vertex::VertexId; +use thiserror::Error; + +/// Errors that can occur during Pregel runtime execution +#[derive(Debug, Error)] +pub enum PregelError { + /// Maximum supersteps exceeded + #[error("Max supersteps exceeded: {0}")] + MaxSuperstepsExceeded(usize), + + /// Vertex computation timed out + #[error("Vertex timeout: {0:?}")] + VertexTimeout(VertexId), + + /// Error during vertex computation + #[error("Vertex error in {vertex_id:?}: {message}")] + VertexError { + vertex_id: VertexId, + message: String, + #[source] + source: Option<Box<dyn std::error::Error + Send + Sync>>, + }, + + /// Error during routing decision + #[error("Routing error in {vertex_id:?}: {decision}")] + RoutingError { vertex_id: VertexId, decision: String }, + + /// Recursion depth limit exceeded + #[error("Recursion limit in {vertex_id:?}: depth {depth}, limit {limit}")] + RecursionLimit { + vertex_id: VertexId, + depth: usize, + limit: usize, + }, + + /// Error in workflow state management + #[error("State error: {0}")] + StateError(String), + + /// Error in checkpointing + #[error("Checkpoint error: {0}")] + CheckpointError(String), + + /// Feature not yet implemented + #[error("Not implemented: {0}")] + NotImplemented(String), + + /// Invalid workflow configuration + #[error("Configuration error: {0}")] + ConfigError(String), + + /// Message delivery failed + #[error("Message delivery failed: {0}")] + MessageDeliveryError(String), + + /// Workflow terminated by user + #[error("Workflow cancelled")] + Cancelled, + + /// Workflow execution timed out + #[error("Workflow timeout after {0:?}")] + WorkflowTimeout(std::time::Duration), + + /// Maximum retry attempts exceeded for a vertex + #[error("Max retries exceeded for vertex {vertex_id:?}: {attempts} attempts")] + MaxRetriesExceeded { vertex_id: VertexId, attempts: usize }, +} + +impl PregelError { + /// Create a vertex error with a message + pub fn vertex_error(vertex_id: impl Into<VertexId>, message: impl Into<String>) -> Self { + Self::VertexError { + vertex_id: vertex_id.into(), 
+ message: message.into(), + source: None, + } + } + + /// Create a vertex error with source + pub fn vertex_error_with_source( + vertex_id: impl Into, + message: impl Into, + source: impl std::error::Error + Send + Sync + 'static, + ) -> Self { + Self::VertexError { + vertex_id: vertex_id.into(), + message: message.into(), + source: Some(Box::new(source)), + } + } + + /// Create a routing error + pub fn routing_error(vertex_id: impl Into, decision: impl Into) -> Self { + Self::RoutingError { + vertex_id: vertex_id.into(), + decision: decision.into(), + } + } + + /// Create a recursion limit error + pub fn recursion_limit(vertex_id: impl Into, depth: usize, limit: usize) -> Self { + Self::RecursionLimit { + vertex_id: vertex_id.into(), + depth, + limit, + } + } + + /// Check if the error is recoverable + pub fn is_recoverable(&self) -> bool { + matches!( + self, + PregelError::VertexTimeout(_) + | PregelError::VertexError { .. } + | PregelError::MessageDeliveryError(_) + ) + } + + /// Check if the error is a timeout + pub fn is_timeout(&self) -> bool { + matches!(self, PregelError::VertexTimeout(_)) + } + + /// Create a checkpoint error + pub fn checkpoint_error(message: impl Into) -> Self { + Self::CheckpointError(message.into()) + } + + /// Create a not implemented error + pub fn not_implemented(feature: impl Into) -> Self { + Self::NotImplemented(feature.into()) + } + + /// Create a state error + pub fn state_error(message: impl Into) -> Self { + Self::StateError(message.into()) + } + + /// Create a config error + pub fn config_error(message: impl Into) -> Self { + Self::ConfigError(message.into()) + } + + /// Create a workflow timeout error + pub fn workflow_timeout(duration: std::time::Duration) -> Self { + Self::WorkflowTimeout(duration) + } + + /// Create a max retries exceeded error + pub fn max_retries_exceeded(vertex_id: impl Into, attempts: usize) -> Self { + Self::MaxRetriesExceeded { + vertex_id: vertex_id.into(), + attempts, + } + } +} + +#[cfg(test)] 
+mod tests { + // Ensure errors are Send + Sync (compile-time check) + static_assertions::assert_impl_all!(super::PregelError: Send, Sync); + use super::*; + + #[test] + fn test_error_display() { + let err = PregelError::MaxSuperstepsExceeded(100); + assert_eq!(format!("{}", err), "Max supersteps exceeded: 100"); + } + + #[test] + fn test_vertex_error() { + let err = PregelError::vertex_error("node1", "computation failed"); + match err { + PregelError::VertexError { + vertex_id, + message, + source, + } => { + assert_eq!(vertex_id.0, "node1"); + assert_eq!(message, "computation failed"); + assert!(source.is_none()); + } + _ => panic!("Wrong error type"), + } + } + + #[test] + fn test_vertex_timeout() { + let err = PregelError::VertexTimeout(VertexId::from("slow_node")); + assert!(format!("{}", err).contains("slow_node")); + assert!(err.is_timeout()); + } + + #[test] + fn test_routing_error() { + let err = PregelError::routing_error("router", "no matching branch"); + match err { + PregelError::RoutingError { vertex_id, decision } => { + assert_eq!(vertex_id.0, "router"); + assert_eq!(decision, "no matching branch"); + } + _ => panic!("Wrong error type"), + } + } + + #[test] + fn test_recursion_limit() { + let err = PregelError::recursion_limit("nested_agent", 6, 5); + match err { + PregelError::RecursionLimit { + vertex_id, + depth, + limit, + } => { + assert_eq!(vertex_id.0, "nested_agent"); + assert_eq!(depth, 6); + assert_eq!(limit, 5); + } + _ => panic!("Wrong error type"), + } + } + + #[test] + fn test_is_recoverable() { + assert!(PregelError::VertexTimeout(VertexId::from("x")).is_recoverable()); + assert!(PregelError::vertex_error("x", "err").is_recoverable()); + assert!(PregelError::MessageDeliveryError("err".into()).is_recoverable()); + + assert!(!PregelError::MaxSuperstepsExceeded(100).is_recoverable()); + assert!(!PregelError::Cancelled.is_recoverable()); + assert!(!PregelError::recursion_limit("x", 5, 3).is_recoverable()); + } + + #[test] + fn 
test_errors_are_send_sync() { + fn assert_send_sync() {} + assert_send_sync::(); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/message.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/message.rs new file mode 100644 index 0000000..1ca46f6 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/message.rs @@ -0,0 +1,230 @@ +//! Message types for Pregel vertex communication +//! +//! Vertices communicate by sending messages to each other. +//! Messages are delivered at the start of each superstep. + +use serde::{Deserialize, Serialize}; +use super::vertex::VertexId; + +/// Trait bound for vertex messages +pub trait VertexMessage: Clone + Send + Sync + 'static {} + +/// Priority level for research directions +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +pub enum Priority { + High, + #[default] + Medium, + Low, +} + +/// Source information for research findings +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Source { + /// URL of the source + pub url: String, + /// Title of the source + pub title: String, + /// Relevance score (0.0 to 1.0) + pub relevance: f32, +} + +impl Source { + /// Create a new source + pub fn new(url: impl Into, title: impl Into, relevance: f32) -> Self { + Self { + url: url.into(), + title: title.into(), + relevance: relevance.clamp(0.0, 1.0), + } + } +} + +/// Standard message types for workflow coordination +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum WorkflowMessage { + /// Trigger vertex activation + Activate, + + /// Pass data between vertices + Data { + key: String, + value: serde_json::Value, + }, + + /// Signal completion of upstream work + Completed { + source: VertexId, + result: Option, + }, + + /// Request vertex to halt + Halt, + + /// Research-specific: share findings + ResearchFinding { + query: String, + sources: Vec, + summary: String, + }, + + /// Research-specific: suggest new direction + ResearchDirection { + 
topic: String, + priority: Priority, + rationale: String, + }, +} + +impl VertexMessage for WorkflowMessage {} + +impl WorkflowMessage { + /// Create a Data message + pub fn data(key: impl Into, value: impl Serialize) -> Self { + Self::Data { + key: key.into(), + value: serde_json::to_value(value).unwrap_or(serde_json::Value::Null), + } + } + + /// Create a Completed message + pub fn completed(source: impl Into, result: Option) -> Self { + Self::Completed { + source: source.into(), + result, + } + } + + /// Create a ResearchFinding message + pub fn research_finding( + query: impl Into, + sources: Vec, + summary: impl Into, + ) -> Self { + Self::ResearchFinding { + query: query.into(), + sources, + summary: summary.into(), + } + } + + /// Create a ResearchDirection message + pub fn research_direction( + topic: impl Into, + priority: Priority, + rationale: impl Into, + ) -> Self { + Self::ResearchDirection { + topic: topic.into(), + priority, + rationale: rationale.into(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_workflow_message_serialization() { + let msg = WorkflowMessage::Data { + key: "query".into(), + value: json!("test query"), + }; + let json_str = serde_json::to_string(&msg).unwrap(); + let deserialized: WorkflowMessage = serde_json::from_str(&json_str).unwrap(); + + // Verify roundtrip + match deserialized { + WorkflowMessage::Data { key, value } => { + assert_eq!(key, "query"); + assert_eq!(value, json!("test query")); + } + _ => panic!("Wrong variant"), + } + } + + #[test] + fn test_research_finding_message() { + let msg = WorkflowMessage::ResearchFinding { + query: "rust async".into(), + sources: vec![Source { + url: "https://example.com".into(), + title: "Example".into(), + relevance: 0.95, + }], + summary: "Rust async is great".into(), + }; + assert!(matches!(msg, WorkflowMessage::ResearchFinding { .. 
})); + } + + #[test] + fn test_research_direction_message() { + let msg = WorkflowMessage::research_direction( + "async runtimes", + Priority::High, + "Important for understanding concurrency", + ); + match msg { + WorkflowMessage::ResearchDirection { + topic, + priority, + rationale, + } => { + assert_eq!(topic, "async runtimes"); + assert_eq!(priority, Priority::High); + assert!(!rationale.is_empty()); + } + _ => panic!("Wrong variant"), + } + } + + #[test] + fn test_source_relevance_clamping() { + let source = Source::new("https://test.com", "Test", 1.5); + assert_eq!(source.relevance, 1.0); + + let source = Source::new("https://test.com", "Test", -0.5); + assert_eq!(source.relevance, 0.0); + } + + #[test] + fn test_completed_message() { + let msg = WorkflowMessage::completed("planner", Some("Plan complete".to_string())); + match msg { + WorkflowMessage::Completed { source, result } => { + assert_eq!(source.0, "planner"); + assert_eq!(result, Some("Plan complete".to_string())); + } + _ => panic!("Wrong variant"), + } + } + + #[test] + fn test_data_message_helper() { + let msg = WorkflowMessage::data("count", 42); + match msg { + WorkflowMessage::Data { key, value } => { + assert_eq!(key, "count"); + assert_eq!(value, json!(42)); + } + _ => panic!("Wrong variant"), + } + } + + #[test] + fn test_priority_default() { + assert_eq!(Priority::default(), Priority::Medium); + } + + #[test] + fn test_activate_and_halt() { + let activate = WorkflowMessage::Activate; + let halt = WorkflowMessage::Halt; + + assert!(matches!(activate, WorkflowMessage::Activate)); + assert!(matches!(halt, WorkflowMessage::Halt)); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/mod.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/mod.rs new file mode 100644 index 0000000..276cee6 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/mod.rs @@ -0,0 +1,45 @@ +//! Pregel Runtime for Graph-Based Agent Orchestration +//! +//! 
This module implements a Pregel-inspired runtime for executing agent workflows. +//! Key concepts: +//! +//! - **Vertex**: Computation unit (Agent, Tool, Router, etc.) +//! - **Edge**: Connection between vertices (Direct, Conditional) +//! - **Superstep**: Synchronized execution phase +//! - **Message**: Communication between vertices +//! +//! # Architecture +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────┐ +//! │ PregelRuntime │ +//! │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ +//! │ │Superstep│→ │Superstep│→ │Superstep│→ ... │ +//! │ │ 0 │ │ 1 │ │ 2 │ │ +//! │ └─────────┘ └─────────┘ └─────────┘ │ +//! │ │ │ │ │ +//! │ ▼ ▼ ▼ │ +//! │ ┌─────────────────────────────────────────────────────┐ │ +//! │ │ Per-Superstep: Deliver → Compute → Collect → Route │ │ +//! │ └─────────────────────────────────────────────────────┘ │ +//! └─────────────────────────────────────────────────────────────┘ +//! ``` + +pub mod vertex; +pub mod message; +pub mod config; +pub mod error; +pub mod state; +pub mod runtime; +pub mod checkpoint; + +// Re-exports +pub use vertex::{ + BoxedVertex, ComputeContext, ComputeResult, StateUpdate, Vertex, VertexId, VertexState, +}; +pub use message::{Priority, Source, VertexMessage, WorkflowMessage}; +pub use config::{PregelConfig, RetryPolicy}; +pub use error::PregelError; +pub use state::{UnitState, UnitUpdate, WorkflowState}; +pub use runtime::{PregelRuntime, WorkflowResult}; +pub use checkpoint::{Checkpoint, Checkpointer, CheckpointerConfig, MemoryCheckpointer, FileCheckpointer, create_checkpointer}; diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/runtime.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/runtime.rs new file mode 100644 index 0000000..0d9fea1 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/runtime.rs @@ -0,0 +1,871 @@ +//! Pregel Runtime - Core execution engine for workflow graphs +//! +//! 
The runtime executes workflows through synchronized supersteps. +//! Each superstep follows the sequence: Deliver → Compute → Collect → Route. + +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{Mutex, Semaphore}; +use tokio::time::timeout; + +use super::config::PregelConfig; +use super::error::PregelError; +use super::message::VertexMessage; +use super::state::WorkflowState; +use super::vertex::{BoxedVertex, ComputeContext, ComputeResult, VertexId, VertexState}; + +/// Result of a workflow execution +#[derive(Debug, Clone)] +pub struct WorkflowResult { + /// Final workflow state + pub state: S, + /// Number of supersteps executed + pub supersteps: usize, + /// Whether the workflow completed successfully + pub completed: bool, + /// Final states of all vertices + pub vertex_states: HashMap, +} + +/// Pregel Runtime for executing workflow graphs +/// +/// Manages the execution of vertices through synchronized supersteps, +/// handling message passing, state updates, and termination detection. 
+pub struct PregelRuntime +where + S: WorkflowState, + M: VertexMessage, +{ + /// Configuration for the runtime + config: PregelConfig, + /// Vertices in the workflow graph + vertices: HashMap>, + /// Current state of each vertex + vertex_states: HashMap, + /// Pending messages for each vertex (delivered at start of next superstep) + message_queues: HashMap>, + /// Edges defining message routing (source -> targets) + edges: HashMap>, + /// Retry attempt counts per vertex (for retry policy enforcement) + retry_counts: HashMap, +} + +impl PregelRuntime +where + S: WorkflowState, + M: VertexMessage, +{ + /// Create a new runtime with default configuration + pub fn new() -> Self { + Self::with_config(PregelConfig::default()) + } + + /// Create a new runtime with custom configuration + pub fn with_config(config: PregelConfig) -> Self { + Self { + config, + vertices: HashMap::new(), + vertex_states: HashMap::new(), + message_queues: HashMap::new(), + edges: HashMap::new(), + retry_counts: HashMap::new(), + } + } + + /// Add a vertex to the runtime + pub fn add_vertex(&mut self, vertex: BoxedVertex) -> &mut Self { + let id = vertex.id().clone(); + self.vertex_states.insert(id.clone(), VertexState::Active); + self.message_queues.insert(id.clone(), Vec::new()); + self.vertices.insert(id, vertex); + self + } + + /// Add an edge between vertices + pub fn add_edge(&mut self, from: impl Into, to: impl Into) -> &mut Self { + let from = from.into(); + let to = to.into(); + self.edges.entry(from).or_default().push(to); + self + } + + /// Set the entry point (activate this vertex on start) + pub fn set_entry(&mut self, entry: impl Into) -> &mut Self { + let entry = entry.into(); + if let Some(state) = self.vertex_states.get_mut(&entry) { + *state = VertexState::Active; + } + self + } + + /// Get the configuration + pub fn config(&self) -> &PregelConfig { + &self.config + } + + /// Run the workflow to completion + /// + /// Enforces the configured `workflow_timeout` - if the 
workflow takes longer + /// than this duration, it will return a `WorkflowTimeout` error. + pub async fn run(&mut self, initial_state: S) -> Result, PregelError> { + let workflow_timeout = self.config.workflow_timeout; + + // C2 Fix: Wrap entire run loop with workflow timeout + match timeout(workflow_timeout, self.run_inner(initial_state)).await { + Ok(result) => result, + Err(_) => Err(PregelError::WorkflowTimeout(workflow_timeout)), + } + } + + /// Internal run loop (extracted for timeout wrapping) + async fn run_inner(&mut self, initial_state: S) -> Result, PregelError> { + let mut state = initial_state; + let mut superstep = 0; + + loop { + // Check max supersteps limit + if superstep >= self.config.max_supersteps { + return Err(PregelError::MaxSuperstepsExceeded(superstep)); + } + + // Check if workflow should terminate + if self.should_terminate(&state) { + return Ok(WorkflowResult { + state, + supersteps: superstep, + completed: true, + vertex_states: self.vertex_states.clone(), + }); + } + + // Execute one superstep + let updates = self.execute_superstep(superstep, &state).await?; + + // Apply state updates + state = state.apply_updates(updates); + + superstep += 1; + } + } + + /// Check if the workflow should terminate + fn should_terminate(&self, state: &S) -> bool { + // Terminal state check + if state.is_terminal() { + return true; + } + + // All vertices halted or completed AND no pending messages + let all_inactive = self + .vertex_states + .values() + .all(|s| !s.is_active()); + + let no_pending_messages = self + .message_queues + .values() + .all(|q| q.is_empty()); + + all_inactive && no_pending_messages + } + + /// Execute a single superstep + async fn execute_superstep( + &mut self, + superstep: usize, + state: &S, + ) -> Result, PregelError> { + // 1. Deliver messages - move pending messages to vertex inboxes + let inboxes = self.deliver_messages(); + + // 2. 
Reactivate halted vertices that received messages + for (vertex_id, messages) in &inboxes { + if !messages.is_empty() { + if let Some(vertex_state) = self.vertex_states.get_mut(vertex_id) { + if vertex_state.is_halted() { + if let Some(vertex) = self.vertices.get(vertex_id) { + *vertex_state = vertex.on_reactivation(messages); + } + } + } + } + } + + // 3. Compute active vertices in parallel + let (updates, outboxes) = self.compute_vertices(superstep, state, &inboxes).await?; + + // 4. Route messages to next superstep queues + self.route_messages(outboxes); + + Ok(updates) + } + + /// Deliver pending messages to vertex inboxes + fn deliver_messages(&mut self) -> HashMap> { + let mut inboxes = HashMap::new(); + for (vertex_id, queue) in &mut self.message_queues { + if !queue.is_empty() { + inboxes.insert(vertex_id.clone(), std::mem::take(queue)); + } else { + inboxes.insert(vertex_id.clone(), Vec::new()); + } + } + inboxes + } + + /// Compute all active vertices in parallel + async fn compute_vertices( + &mut self, + superstep: usize, + state: &S, + inboxes: &HashMap>, + ) -> Result<(Vec, HashMap>>), PregelError> { + let semaphore = Arc::new(Semaphore::new(self.config.parallelism)); + let updates = Arc::new(Mutex::new(Vec::new())); + let outboxes = Arc::new(Mutex::new(HashMap::new())); + let vertex_timeout = self.config.vertex_timeout; + + // Collect active vertices to compute + let active_vertices: Vec<_> = self + .vertex_states + .iter() + .filter(|(_, state)| state.is_active()) + .map(|(id, _)| id.clone()) + .collect(); + + // Execute vertices in parallel + let mut handles = Vec::new(); + + for vertex_id in active_vertices { + let vertex = match self.vertices.get(&vertex_id) { + Some(v) => Arc::clone(v), + None => continue, + }; + let messages = inboxes.get(&vertex_id).cloned().unwrap_or_default(); + let state_clone = state.clone(); + let sem_clone = Arc::clone(&semaphore); + let vid = vertex_id.clone(); + + let handle = tokio::spawn(async move { + // Acquire 
semaphore permit for parallelism control + let _permit = sem_clone.acquire().await.unwrap(); + + // Create compute context + let mut ctx = ComputeContext::new(vid.clone(), &messages, superstep, &state_clone); + + // Execute with timeout + let result: Result, PregelError> = match timeout( + vertex_timeout, + vertex.compute(&mut ctx), + ) + .await + { + Ok(result) => result, + Err(_) => Err(PregelError::VertexTimeout(vid.clone())), + }; + + let outbox = ctx.into_outbox(); + + (vid, result, outbox) + }); + + handles.push(handle); + } + + // Collect results + let mut new_vertex_states = HashMap::new(); + + for handle in handles { + let (vid, result, outbox) = handle.await.map_err(|e| { + PregelError::vertex_error_with_source( + "unknown", + "task join error", + std::io::Error::other(e.to_string()), + ) + })?; + + match result { + Ok(compute_result) => { + // Success: reset retry count for this vertex + self.retry_counts.remove(&vid); + updates.lock().await.push(compute_result.update); + new_vertex_states.insert(vid.clone(), compute_result.state); + outboxes.lock().await.insert(vid, outbox); + } + Err(e) => { + if e.is_recoverable() { + // C3 Fix: Track retry attempts and enforce max_retries + // retry_count tracks how many retries we've already attempted + let retry_count = self.retry_counts.entry(vid.clone()).or_insert(0); + + // Check if we can retry BEFORE incrementing + if self.config.retry_policy.should_retry(*retry_count) { + // Apply backoff delay before next retry + let delay = self.config.retry_policy.delay_for_attempt(*retry_count); + tokio::time::sleep(delay).await; + // Track this retry attempt + *retry_count += 1; + // Keep vertex active for retry + new_vertex_states.insert(vid, VertexState::Active); + } else { + // Max retries exceeded (current attempt is retry_count + 1 total) + return Err(PregelError::MaxRetriesExceeded { + vertex_id: vid, + attempts: *retry_count + 1, // +1 for the current failed attempt + }); + } + } else { + return Err(e); + } + } + 
} + } + + // Update vertex states + for (vid, new_state) in new_vertex_states { + self.vertex_states.insert(vid, new_state); + } + + // C1 Fix: Use async-safe lock instead of blocking_lock + let final_updates = match Arc::try_unwrap(updates) { + Ok(mutex) => mutex.into_inner(), + Err(arc) => arc.lock().await.clone(), + }; + + let final_outboxes = match Arc::try_unwrap(outboxes) { + Ok(mutex) => mutex.into_inner(), + Err(arc) => arc.lock().await.clone(), + }; + + Ok((final_updates, final_outboxes)) + } + + /// Route outgoing messages to target vertex queues + fn route_messages(&mut self, outboxes: HashMap>>) { + for (_source, outbox) in outboxes { + for (target, messages) in outbox { + if let Some(queue) = self.message_queues.get_mut(&target) { + queue.extend(messages); + } + } + } + } +} + +impl Default for PregelRuntime +where + S: WorkflowState, + M: VertexMessage, +{ + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use super::super::message::WorkflowMessage; + use super::super::vertex::{StateUpdate, Vertex}; + use async_trait::async_trait; + use tokio::time::Duration; + + #[allow(unused_imports)] + use super::super::state::WorkflowState as _; + + // Test state + #[derive(Clone, Default, Debug)] + struct TestState { + counter: i32, + messages_received: i32, + } + + #[derive(Clone, Debug)] + struct TestUpdate { + counter_delta: i32, + messages_delta: i32, + } + + impl StateUpdate for TestUpdate { + fn empty() -> Self { + TestUpdate { + counter_delta: 0, + messages_delta: 0, + } + } + + fn is_empty(&self) -> bool { + self.counter_delta == 0 && self.messages_delta == 0 + } + } + + impl WorkflowState for TestState { + type Update = TestUpdate; + + fn apply_update(&self, update: Self::Update) -> Self { + TestState { + counter: self.counter + update.counter_delta, + messages_received: self.messages_received + update.messages_delta, + } + } + + fn merge_updates(updates: Vec) -> Self::Update { + TestUpdate { + counter_delta: 
updates.iter().map(|u| u.counter_delta).sum(), + messages_delta: updates.iter().map(|u| u.messages_delta).sum(), + } + } + + fn is_terminal(&self) -> bool { + self.counter >= 10 + } + } + + // Simple vertex that increments counter and halts + struct IncrementVertex { + id: VertexId, + #[allow(dead_code)] + increment: i32, + } + + #[async_trait] + impl Vertex for IncrementVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + _ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + // Just halt immediately + Ok(ComputeResult::halt(TestUpdate::empty())) + } + } + + // Vertex that sends a message then halts + struct MessageSenderVertex { + id: VertexId, + target: VertexId, + } + + #[async_trait] + impl Vertex for MessageSenderVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + if ctx.is_first_superstep() { + ctx.send_message(self.target.clone(), WorkflowMessage::Activate); + } + Ok(ComputeResult::halt(TestUpdate::empty())) + } + } + + // Vertex that counts messages received + struct MessageReceiverVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for MessageReceiverVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + _ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + // Just halt after receiving messages + Ok(ComputeResult::halt(TestUpdate::empty())) + } + } + + #[tokio::test] + async fn test_runtime_creation() { + let runtime: PregelRuntime = PregelRuntime::new(); + assert_eq!(runtime.config().max_supersteps, 100); + } + + #[tokio::test] + async fn test_runtime_single_vertex_halts() { + let mut runtime: PregelRuntime = PregelRuntime::new(); + + runtime.add_vertex(Arc::new(IncrementVertex { + id: VertexId::new("a"), + increment: 1, + })); + + let result = runtime.run(TestState::default()).await; 
+ assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.completed); + // Single vertex computes once then halts, workflow terminates + assert!(result.supersteps <= 2); + } + + #[tokio::test] + async fn test_runtime_message_delivery() { + let mut runtime: PregelRuntime = PregelRuntime::new(); + + runtime.add_vertex(Arc::new(MessageSenderVertex { + id: VertexId::new("sender"), + target: VertexId::new("receiver"), + })); + + runtime.add_vertex(Arc::new(MessageReceiverVertex { + id: VertexId::new("receiver"), + })); + + let result = runtime.run(TestState::default()).await.unwrap(); + assert!(result.completed); + // Sender sends in superstep 0, receiver gets it in superstep 1 + assert!(result.supersteps >= 1); + } + + #[tokio::test] + async fn test_runtime_termination_all_halted() { + let mut runtime: PregelRuntime = PregelRuntime::new(); + + // Add two vertices that halt immediately + runtime.add_vertex(Arc::new(IncrementVertex { + id: VertexId::new("a"), + increment: 1, + })); + + runtime.add_vertex(Arc::new(IncrementVertex { + id: VertexId::new("b"), + increment: 1, + })); + + let result = runtime.run(TestState::default()).await.unwrap(); + assert!(result.completed); + // All vertices halt, no messages pending -> terminate + assert!(result.vertex_states.values().all(|s| !s.is_active())); + } + + #[tokio::test] + async fn test_runtime_max_supersteps_exceeded() { + struct InfiniteLoopVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for InfiniteLoopVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + // Always stay active + ctx.send_message(self.id.clone(), WorkflowMessage::Activate); + Ok(ComputeResult::active(TestUpdate::empty())) + } + } + + let config = PregelConfig::default().with_max_supersteps(5); + let mut runtime: PregelRuntime = + PregelRuntime::with_config(config); + + 
runtime.add_vertex(Arc::new(InfiniteLoopVertex { + id: VertexId::new("loop"), + })); + + let result = runtime.run(TestState::default()).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + PregelError::MaxSuperstepsExceeded(5) + )); + } + + #[tokio::test] + async fn test_runtime_terminal_state() { + struct CounterVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for CounterVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + // Keep running until terminal state + ctx.send_message(self.id.clone(), WorkflowMessage::Activate); + Ok(ComputeResult::active(TestUpdate::empty())) + } + } + + let mut runtime: PregelRuntime = PregelRuntime::new(); + + runtime.add_vertex(Arc::new(CounterVertex { + id: VertexId::new("counter"), + })); + + // Start with counter at 10, which is terminal + let result = runtime + .run(TestState { + counter: 10, + messages_received: 0, + }) + .await + .unwrap(); + + assert!(result.completed); + assert_eq!(result.supersteps, 0); // Terminates immediately + } + + #[tokio::test] + async fn test_runtime_parallel_execution() { + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::time::Instant; + + static EXECUTION_COUNT: AtomicUsize = AtomicUsize::new(0); + + struct SlowVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for SlowVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + _ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + EXECUTION_COUNT.fetch_add(1, Ordering::SeqCst); + tokio::time::sleep(Duration::from_millis(50)).await; + Ok(ComputeResult::halt(TestUpdate::empty())) + } + } + + EXECUTION_COUNT.store(0, Ordering::SeqCst); + + let config = PregelConfig::default().with_parallelism(4); + let mut runtime: PregelRuntime = + PregelRuntime::with_config(config); + + // Add 4 slow vertices + for i in 
0..4 { + runtime.add_vertex(Arc::new(SlowVertex { + id: VertexId::new(format!("slow_{}", i)), + })); + } + + let start = Instant::now(); + let result = runtime.run(TestState::default()).await.unwrap(); + let elapsed = start.elapsed(); + + assert!(result.completed); + assert_eq!(EXECUTION_COUNT.load(Ordering::SeqCst), 4); + // With parallelism=4, should take ~50ms, not ~200ms + assert!(elapsed < Duration::from_millis(150)); + } + + #[tokio::test] + async fn test_runtime_add_edge() { + let mut runtime: PregelRuntime = PregelRuntime::new(); + + runtime + .add_vertex(Arc::new(IncrementVertex { + id: VertexId::new("a"), + increment: 1, + })) + .add_edge("a", "b"); + + assert!(runtime.edges.contains_key(&VertexId::new("a"))); + } + + // ============================================ + // C2: Workflow Timeout Tests (RED - should fail) + // ============================================ + + #[tokio::test] + async fn test_workflow_timeout_enforced() { + // Vertex that runs forever (simulates slow LLM calls) + struct SlowForeverVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for SlowForeverVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + // Sleep for a long time but stay active + tokio::time::sleep(Duration::from_secs(10)).await; + ctx.send_message(self.id.clone(), WorkflowMessage::Activate); + Ok(ComputeResult::active(TestUpdate::empty())) + } + } + + // Set a very short workflow timeout (100ms) + let config = PregelConfig::default() + .with_workflow_timeout(Duration::from_millis(100)) + .with_vertex_timeout(Duration::from_secs(60)) // vertex timeout is longer + .with_max_supersteps(1000); // high limit so it doesn't hit this first + + let mut runtime: PregelRuntime = + PregelRuntime::with_config(config); + + runtime.add_vertex(Arc::new(SlowForeverVertex { + id: VertexId::new("slow"), + })); + + let start = std::time::Instant::now(); + 
let result = runtime.run(TestState::default()).await; + let elapsed = start.elapsed(); + + // Should timeout within ~200ms (some tolerance) + assert!(elapsed < Duration::from_millis(500), "Took too long: {:?}", elapsed); + + // Should return WorkflowTimeout error + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, PregelError::WorkflowTimeout(_)), + "Expected WorkflowTimeout, got {:?}", + err + ); + } + + // ============================================ + // C3: Retry Policy Tests (RED - should fail) + // ============================================ + + #[tokio::test] + async fn test_retry_policy_with_backoff() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + static ATTEMPT_COUNT: AtomicUsize = AtomicUsize::new(0); + + // Vertex that fails first 2 times, then succeeds + struct FailingThenSuccessVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for FailingThenSuccessVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + _ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + let attempt = ATTEMPT_COUNT.fetch_add(1, Ordering::SeqCst); + if attempt < 2 { + // Fail with recoverable error + Err(PregelError::vertex_error(self.id.clone(), format!("transient failure {}", attempt))) + } else { + // Succeed on 3rd attempt + Ok(ComputeResult::halt(TestUpdate { + counter_delta: 1, + messages_delta: 0, + })) + } + } + } + + ATTEMPT_COUNT.store(0, Ordering::SeqCst); + + let config = PregelConfig::default() + .with_retry_policy( + super::super::config::RetryPolicy::new(3) + .with_backoff_base(Duration::from_millis(10)) + ) + .with_max_supersteps(20); + + let mut runtime: PregelRuntime = + PregelRuntime::with_config(config); + + runtime.add_vertex(Arc::new(FailingThenSuccessVertex { + id: VertexId::new("flaky"), + })); + + let result = runtime.run(TestState::default()).await; + + // Should succeed after retries + assert!(result.is_ok(), "Expected success after 
retries, got {:?}", result); + + // Should have attempted 3 times (2 failures + 1 success) + assert_eq!(ATTEMPT_COUNT.load(Ordering::SeqCst), 3); + } + + #[tokio::test] + async fn test_retry_policy_max_exceeded() { + use std::sync::atomic::{AtomicUsize, Ordering}; + + static FAIL_COUNT: AtomicUsize = AtomicUsize::new(0); + + // Vertex that always fails + struct AlwaysFailsVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for AlwaysFailsVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + _ctx: &mut ComputeContext<'_, TestState, WorkflowMessage>, + ) -> Result, PregelError> { + FAIL_COUNT.fetch_add(1, Ordering::SeqCst); + Err(PregelError::vertex_error(self.id.clone(), "always fails")) + } + } + + FAIL_COUNT.store(0, Ordering::SeqCst); + + let config = PregelConfig::default() + .with_retry_policy(super::super::config::RetryPolicy::new(3)) + .with_max_supersteps(100); + + let mut runtime: PregelRuntime = + PregelRuntime::with_config(config); + + runtime.add_vertex(Arc::new(AlwaysFailsVertex { + id: VertexId::new("failing"), + })); + + let result = runtime.run(TestState::default()).await; + + // Should fail with MaxRetriesExceeded + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + matches!(err, PregelError::MaxRetriesExceeded { .. }), + "Expected MaxRetriesExceeded, got {:?}", + err + ); + + // Should have tried exactly max_retries + 1 times (initial + retries) + assert_eq!(FAIL_COUNT.load(Ordering::SeqCst), 4); // 1 initial + 3 retries + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/state.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/state.rs new file mode 100644 index 0000000..1b015df --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/state.rs @@ -0,0 +1,297 @@ +//! Workflow state abstraction for Pregel runtime +//! +//! Defines how workflow state is updated and merged during supersteps. +//! 
The runtime collects updates from all vertices and applies them atomically +//! at the end of each superstep. + +use super::vertex::StateUpdate; + +/// Trait for workflow state managed by the Pregel runtime +/// +/// The workflow state represents the shared data that vertices can read +/// and update during computation. Updates are collected and merged at the +/// end of each superstep. +/// +/// # Example +/// +/// ```ignore +/// #[derive(Clone, Default)] +/// struct ResearchState { +/// findings: Vec, +/// phase: ResearchPhase, +/// completed_topics: HashSet, +/// } +/// +/// impl WorkflowState for ResearchState { +/// type Update = ResearchUpdate; +/// +/// fn apply_update(&self, update: Self::Update) -> Self { +/// let mut new = self.clone(); +/// new.findings.extend(update.new_findings); +/// new.completed_topics.extend(update.completed); +/// if let Some(phase) = update.phase_transition { +/// new.phase = phase; +/// } +/// new +/// } +/// +/// fn merge_updates(updates: Vec) -> Self::Update { +/// ResearchUpdate { +/// new_findings: updates.iter().flat_map(|u| u.new_findings.clone()).collect(), +/// completed: updates.iter().flat_map(|u| u.completed.clone()).collect(), +/// phase_transition: updates.iter().find_map(|u| u.phase_transition.clone()), +/// } +/// } +/// +/// fn is_terminal(&self) -> bool { +/// self.phase == ResearchPhase::Complete +/// } +/// } +/// ``` +pub trait WorkflowState: Clone + Send + Sync + 'static { + /// The update type produced by vertices + type Update: StateUpdate; + + /// Apply an update to produce a new state + /// + /// This should be a pure function - the original state is not modified. + fn apply_update(&self, update: Self::Update) -> Self; + + /// Merge multiple updates into a single update + /// + /// Called when multiple vertices produce updates in the same superstep. + /// The merge should be deterministic (order-independent for correctness). 
+ fn merge_updates(updates: Vec) -> Self::Update; + + /// Check if the state represents a terminal condition + /// + /// When true, the workflow will terminate regardless of vertex states. + fn is_terminal(&self) -> bool { + false + } + + /// Apply multiple updates in sequence + /// + /// Default implementation merges updates then applies the result. + fn apply_updates(&self, updates: Vec) -> Self { + if updates.is_empty() { + return self.clone(); + } + let merged = Self::merge_updates(updates); + self.apply_update(merged) + } +} + +/// A simple unit state for workflows that don't need shared state +/// +/// Useful for workflows where all communication is via messages. +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct UnitState; + +/// Unit update that has no effect +#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)] +pub struct UnitUpdate; + +impl StateUpdate for UnitUpdate { + fn empty() -> Self { + UnitUpdate + } + + fn is_empty(&self) -> bool { + true + } +} + +impl WorkflowState for UnitState { + type Update = UnitUpdate; + + fn apply_update(&self, _update: Self::Update) -> Self { + UnitState + } + + fn merge_updates(_updates: Vec) -> Self::Update { + UnitUpdate + } + + fn is_terminal(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashSet; + + // Counter-based state for testing + #[derive(Clone, Default, Debug, PartialEq)] + struct CounterState { + count: i32, + } + + #[derive(Clone, Debug)] + struct CounterUpdate { + delta: i32, + } + + impl StateUpdate for CounterUpdate { + fn empty() -> Self { + CounterUpdate { delta: 0 } + } + + fn is_empty(&self) -> bool { + self.delta == 0 + } + } + + impl WorkflowState for CounterState { + type Update = CounterUpdate; + + fn apply_update(&self, update: Self::Update) -> Self { + CounterState { + count: self.count + update.delta, + } + } + + fn merge_updates(updates: Vec) -> Self::Update { + CounterUpdate { + delta: 
updates.iter().map(|u| u.delta).sum(), + } + } + + fn is_terminal(&self) -> bool { + self.count >= 100 + } + } + + #[test] + fn test_state_update_merge() { + let updates = vec![ + CounterUpdate { delta: 5 }, + CounterUpdate { delta: 3 }, + CounterUpdate { delta: -2 }, + ]; + let merged = CounterState::merge_updates(updates); + assert_eq!(merged.delta, 6); + } + + #[test] + fn test_state_apply_update() { + let state = CounterState { count: 10 }; + let update = CounterUpdate { delta: 5 }; + let new_state = state.apply_update(update); + assert_eq!(new_state.count, 15); + // Original state unchanged (immutable) + assert_eq!(state.count, 10); + } + + #[test] + fn test_state_apply_updates() { + let state = CounterState { count: 0 }; + let updates = vec![ + CounterUpdate { delta: 10 }, + CounterUpdate { delta: 20 }, + CounterUpdate { delta: 5 }, + ]; + let new_state = state.apply_updates(updates); + assert_eq!(new_state.count, 35); + } + + #[test] + fn test_state_terminal_condition() { + let non_terminal = CounterState { count: 50 }; + assert!(!non_terminal.is_terminal()); + + let terminal = CounterState { count: 100 }; + assert!(terminal.is_terminal()); + + let over_terminal = CounterState { count: 150 }; + assert!(over_terminal.is_terminal()); + } + + #[test] + fn test_empty_updates() { + let state = CounterState { count: 42 }; + let new_state = state.apply_updates(vec![]); + assert_eq!(new_state.count, 42); + } + + // More complex state for testing + #[derive(Clone, Default, Debug)] + struct CollectionState { + items: Vec, + seen: HashSet, + } + + #[derive(Clone, Debug)] + struct CollectionUpdate { + new_items: Vec, + } + + impl StateUpdate for CollectionUpdate { + fn empty() -> Self { + CollectionUpdate { new_items: vec![] } + } + + fn is_empty(&self) -> bool { + self.new_items.is_empty() + } + } + + impl WorkflowState for CollectionState { + type Update = CollectionUpdate; + + fn apply_update(&self, update: Self::Update) -> Self { + let mut items = 
self.items.clone(); + let mut seen = self.seen.clone(); + + for item in update.new_items { + if !seen.contains(&item) { + seen.insert(item.clone()); + items.push(item); + } + } + + CollectionState { items, seen } + } + + fn merge_updates(updates: Vec) -> Self::Update { + CollectionUpdate { + new_items: updates.into_iter().flat_map(|u| u.new_items).collect(), + } + } + } + + #[test] + fn test_collection_state_dedup() { + let state = CollectionState::default(); + let updates = vec![ + CollectionUpdate { + new_items: vec!["a".to_string(), "b".to_string()], + }, + CollectionUpdate { + new_items: vec!["b".to_string(), "c".to_string()], + }, + ]; + + let new_state = state.apply_updates(updates); + // "b" should only appear once due to dedup in apply_update + assert_eq!(new_state.items.len(), 3); + assert_eq!(new_state.seen.len(), 3); + } + + #[test] + fn test_unit_state() { + let state = UnitState; + let update = UnitUpdate; + + assert!(update.is_empty()); + assert!(UnitUpdate::empty().is_empty()); + + let new_state = state.apply_update(update); + assert!(!new_state.is_terminal()); + + let merged = UnitState::merge_updates(vec![UnitUpdate, UnitUpdate]); + assert!(merged.is_empty()); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/pregel/vertex.rs b/rust-research-agent/crates/rig-deepagents/src/pregel/vertex.rs new file mode 100644 index 0000000..67b2f45 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/pregel/vertex.rs @@ -0,0 +1,562 @@ +//! Vertex (Node) abstractions for Pregel runtime +//! +//! A Vertex represents a computation unit in the workflow graph. +//! Vertices communicate via messages and execute in synchronized supersteps. 
+ +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::hash::Hash; +use std::sync::Arc; + +use super::error::PregelError; +use super::message::VertexMessage; + +/// Unique identifier for a vertex in the workflow graph +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +pub struct VertexId(pub String); + +impl VertexId { + /// Create a new VertexId + pub fn new(id: impl Into) -> Self { + Self(id.into()) + } +} + +impl From<&str> for VertexId { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl From for VertexId { + fn from(s: String) -> Self { + Self(s) + } +} + +impl std::fmt::Display for VertexId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +/// Vertex execution state (Pregel's "vote to halt" mechanism) +/// +/// - `Active`: Vertex will compute in the next superstep +/// - `Halted`: Vertex has voted to halt (will reactivate on message receipt) +/// - `Completed`: Vertex has finished and will not reactivate +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +pub enum VertexState { + /// Vertex is active and will compute in next superstep + #[default] + Active, + /// Vertex has voted to halt (will reactivate on message receipt) + Halted, + /// Vertex has completed and will not reactivate + Completed, +} + +impl VertexState { + /// Check if the vertex is active + pub fn is_active(&self) -> bool { + matches!(self, VertexState::Active) + } + + /// Check if the vertex is halted (can be reactivated) + pub fn is_halted(&self) -> bool { + matches!(self, VertexState::Halted) + } + + /// Check if the vertex is completed (cannot be reactivated) + pub fn is_completed(&self) -> bool { + matches!(self, VertexState::Completed) + } +} + +/// Trait for state updates produced by vertex computation +/// +/// State updates are collected from all vertices and merged at the end of each superstep. 
+pub trait StateUpdate: Clone + Send + Sync + 'static { + /// Create an empty (no-op) update + fn empty() -> Self; + + /// Check if this update has no effect + fn is_empty(&self) -> bool; +} + +/// Context provided to a vertex during computation +/// +/// Provides access to: +/// - Incoming messages from other vertices +/// - Outbox for sending messages +/// - Current superstep number +/// - Workflow state (read-only) +pub struct ComputeContext<'a, S, M: VertexMessage> { + /// Messages received from other vertices + pub messages: &'a [M], + /// Current superstep number (0-indexed) + pub superstep: usize, + /// Read-only access to workflow state + pub state: &'a S, + /// Outgoing messages (target vertex -> messages) + outbox: HashMap>, + /// Current vertex ID + vertex_id: VertexId, +} + +impl<'a, S, M: VertexMessage> ComputeContext<'a, S, M> { + /// Create a new compute context + pub fn new(vertex_id: VertexId, messages: &'a [M], superstep: usize, state: &'a S) -> Self { + Self { + messages, + superstep, + state, + outbox: HashMap::new(), + vertex_id, + } + } + + /// Get the current vertex ID + pub fn id(&self) -> &VertexId { + &self.vertex_id + } + + /// Send a message to another vertex + /// + /// Messages will be delivered at the start of the next superstep. 
+ pub fn send_message(&mut self, target: impl Into, message: M) { + let target = target.into(); + self.outbox.entry(target).or_default().push(message); + } + + /// Send a message to multiple targets + pub fn broadcast(&mut self, targets: impl IntoIterator>, message: M) { + for target in targets { + self.send_message(target.into(), message.clone()); + } + } + + /// Check if this is the first superstep + pub fn is_first_superstep(&self) -> bool { + self.superstep == 0 + } + + /// Check if any messages were received + pub fn has_messages(&self) -> bool { + !self.messages.is_empty() + } + + /// Get the count of received messages + pub fn message_count(&self) -> usize { + self.messages.len() + } + + /// Consume the context and return the outbox + pub fn into_outbox(self) -> HashMap> { + self.outbox + } +} + +use super::state::WorkflowState; + +/// The core vertex trait for Pregel computation +/// +/// Each vertex in the workflow graph implements this trait. +/// During each superstep, active vertices have their `compute` method called. 
+/// +/// # Type Parameters +/// +/// - `S`: The workflow state type (must implement WorkflowState) +/// - `M`: The message type used for vertex communication +/// +/// # Example +/// +/// ```ignore +/// struct EchoVertex { +/// id: VertexId, +/// } +/// +/// #[async_trait] +/// impl Vertex for EchoVertex { +/// fn id(&self) -> &VertexId { +/// &self.id +/// } +/// +/// async fn compute( +/// &self, +/// ctx: &mut ComputeContext<'_, MyState, WorkflowMessage>, +/// ) -> Result, PregelError> { +/// for msg in ctx.messages { +/// if let WorkflowMessage::Data { key, value } = msg { +/// ctx.send_message("output", WorkflowMessage::data(format!("echo_{}", key), value.clone())); +/// } +/// } +/// Ok(ComputeResult::halt(MyUpdate::empty())) +/// } +/// } +/// ``` +#[async_trait] +pub trait Vertex: Send + Sync +where + S: WorkflowState, + M: VertexMessage, +{ + /// Get the vertex's unique identifier + fn id(&self) -> &VertexId; + + /// Execute the vertex's computation + /// + /// Called during each superstep for active vertices. + /// Returns a state update and the next vertex state. + async fn compute( + &self, + ctx: &mut ComputeContext<'_, S, M>, + ) -> Result, PregelError>; + + /// Combine multiple messages into one (optional optimization) + /// + /// Default implementation returns messages unchanged. + /// Override to reduce message traffic for commutative/associative operations. + fn combine_messages(&self, messages: Vec) -> Vec { + messages + } + + /// Called when the vertex receives messages while halted + /// + /// By default, returns `Active` to reactivate the vertex. 
+ fn on_reactivation(&self, _messages: &[M]) -> VertexState { + VertexState::Active + } +} + +/// Result of a vertex computation +#[derive(Debug, Clone)] +pub struct ComputeResult { + /// State update to apply + pub update: U, + /// New vertex state + pub state: VertexState, +} + +impl ComputeResult { + /// Create a result that keeps the vertex active + pub fn active(update: U) -> Self { + Self { + update, + state: VertexState::Active, + } + } + + /// Create a result that halts the vertex + pub fn halt(update: U) -> Self { + Self { + update, + state: VertexState::Halted, + } + } + + /// Create a result that completes the vertex + pub fn complete(update: U) -> Self { + Self { + update, + state: VertexState::Completed, + } + } + + /// Create a result with a specific state + pub fn with_state(update: U, state: VertexState) -> Self { + Self { update, state } + } +} + +/// Boxed vertex for dynamic dispatch +pub type BoxedVertex = Arc>; + +#[cfg(test)] +mod tests { + use super::*; + use super::super::message::WorkflowMessage; + + // Test StateUpdate implementation for tests + #[derive(Clone, Debug, Default)] + struct TestUpdate { + delta: i32, + } + + impl StateUpdate for TestUpdate { + fn empty() -> Self { + TestUpdate { delta: 0 } + } + + fn is_empty(&self) -> bool { + self.delta == 0 + } + } + + // Test state for tests + #[derive(Clone, Debug, Default)] + #[allow(dead_code)] + struct TestState { + value: i32, + } + + impl super::super::state::WorkflowState for TestState { + type Update = TestUpdate; + + fn apply_update(&self, update: Self::Update) -> Self { + TestState { + value: self.value + update.delta, + } + } + + fn merge_updates(updates: Vec) -> Self::Update { + TestUpdate { + delta: updates.iter().map(|u| u.delta).sum(), + } + } + } + + // Mock vertex for testing + struct EchoVertex { + id: VertexId, + } + + #[async_trait] + impl Vertex for EchoVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + ctx: &mut ComputeContext<'_, 
TestState, WorkflowMessage>, + ) -> Result, PregelError> { + // Echo back any data messages + for msg in ctx.messages { + if let WorkflowMessage::Data { key, value } = msg { + ctx.send_message( + "output", + WorkflowMessage::data(format!("echo_{}", key), value.clone()), + ); + } + } + Ok(ComputeResult::halt(TestUpdate::empty())) + } + } + + #[tokio::test] + async fn test_mock_vertex_compute() { + let vertex = EchoVertex { + id: VertexId::new("echo"), + }; + + let state = TestState { value: 42 }; + let messages = vec![WorkflowMessage::data("test", "hello")]; + + let mut ctx = ComputeContext::new( + VertexId::new("echo"), + &messages, + 0, + &state, + ); + + let result = vertex.compute(&mut ctx).await; + assert!(result.is_ok()); + + let result = result.unwrap(); + assert!(result.state.is_halted()); + + let outbox = ctx.into_outbox(); + assert!(outbox.contains_key(&VertexId::new("output"))); + assert_eq!(outbox.get(&VertexId::new("output")).unwrap().len(), 1); + } + + #[test] + fn test_compute_context_send_message() { + let state = TestState { value: 0 }; + let messages: Vec = vec![]; + + let mut ctx = ComputeContext::new( + VertexId::new("test"), + &messages, + 0, + &state, + ); + + ctx.send_message("target1", WorkflowMessage::Activate); + ctx.send_message("target1", WorkflowMessage::Halt); + ctx.send_message("target2", WorkflowMessage::Activate); + + let outbox = ctx.into_outbox(); + assert_eq!(outbox.get(&VertexId::new("target1")).unwrap().len(), 2); + assert_eq!(outbox.get(&VertexId::new("target2")).unwrap().len(), 1); + } + + #[test] + fn test_compute_context_broadcast() { + let state = TestState { value: 0 }; + let messages: Vec = vec![]; + + let mut ctx = ComputeContext::new( + VertexId::new("broadcaster"), + &messages, + 0, + &state, + ); + + let targets = vec!["a", "b", "c"]; + ctx.broadcast(targets, WorkflowMessage::Activate); + + let outbox = ctx.into_outbox(); + assert_eq!(outbox.len(), 3); + assert!(outbox.contains_key(&VertexId::new("a"))); + 
assert!(outbox.contains_key(&VertexId::new("b"))); + assert!(outbox.contains_key(&VertexId::new("c"))); + } + + #[test] + fn test_compute_context_helpers() { + let state = TestState { value: 0 }; + let messages = vec![WorkflowMessage::Activate, WorkflowMessage::Halt]; + + let ctx = ComputeContext::::new( + VertexId::new("test"), + &messages, + 0, + &state, + ); + + assert!(ctx.is_first_superstep()); + assert!(ctx.has_messages()); + assert_eq!(ctx.message_count(), 2); + assert_eq!(ctx.id(), &VertexId::new("test")); + + let ctx2 = ComputeContext::::new( + VertexId::new("test2"), + &[], + 5, + &state, + ); + + assert!(!ctx2.is_first_superstep()); + assert!(!ctx2.has_messages()); + assert_eq!(ctx2.message_count(), 0); + } + + #[test] + fn test_compute_result_constructors() { + let update = TestUpdate { delta: 5 }; + + let active = ComputeResult::active(update.clone()); + assert!(active.state.is_active()); + + let halted = ComputeResult::halt(update.clone()); + assert!(halted.state.is_halted()); + + let completed = ComputeResult::complete(update.clone()); + assert!(completed.state.is_completed()); + + let custom = ComputeResult::with_state(update, VertexState::Halted); + assert!(custom.state.is_halted()); + } + + #[test] + fn test_state_update_trait() { + let empty = TestUpdate::empty(); + assert!(empty.is_empty()); + + let non_empty = TestUpdate { delta: 10 }; + assert!(!non_empty.is_empty()); + } + + #[test] + fn test_vertex_state_helpers() { + assert!(VertexState::Active.is_active()); + assert!(!VertexState::Active.is_halted()); + assert!(!VertexState::Active.is_completed()); + + assert!(!VertexState::Halted.is_active()); + assert!(VertexState::Halted.is_halted()); + assert!(!VertexState::Halted.is_completed()); + + assert!(!VertexState::Completed.is_active()); + assert!(!VertexState::Completed.is_halted()); + assert!(VertexState::Completed.is_completed()); + } + + #[test] + fn test_vertex_id_from_str() { + let id: VertexId = "planner".into(); + assert_eq!(id.0, 
"planner"); + } + + #[test] + fn test_vertex_id_from_string() { + let id: VertexId = String::from("router").into(); + assert_eq!(id.0, "router"); + } + + #[test] + fn test_vertex_id_new() { + let id = VertexId::new("explorer"); + assert_eq!(id.0, "explorer"); + } + + #[test] + fn test_vertex_id_equality() { + let id1: VertexId = "node1".into(); + let id2: VertexId = "node1".into(); + let id3: VertexId = "node2".into(); + assert_eq!(id1, id2); + assert_ne!(id1, id3); + } + + #[test] + fn test_vertex_id_hash() { + use std::collections::HashSet; + let mut set = HashSet::new(); + set.insert(VertexId::from("a")); + set.insert(VertexId::from("b")); + set.insert(VertexId::from("a")); // duplicate + assert_eq!(set.len(), 2); + } + + #[test] + fn test_vertex_id_display() { + let id = VertexId::new("test_node"); + assert_eq!(format!("{}", id), "test_node"); + } + + #[test] + fn test_vertex_state_default_is_active() { + assert_eq!(VertexState::default(), VertexState::Active); + } + + #[test] + fn test_vertex_state_variants() { + let active = VertexState::Active; + let halted = VertexState::Halted; + let completed = VertexState::Completed; + + assert_ne!(active, halted); + assert_ne!(halted, completed); + assert_ne!(active, completed); + } + + #[test] + fn test_vertex_id_serialization() { + let id = VertexId::new("test"); + let json = serde_json::to_string(&id).unwrap(); + let deserialized: VertexId = serde_json::from_str(&json).unwrap(); + assert_eq!(id, deserialized); + } + + #[test] + fn test_vertex_state_serialization() { + let state = VertexState::Halted; + let json = serde_json::to_string(&state).unwrap(); + let deserialized: VertexState = serde_json::from_str(&json).unwrap(); + assert_eq!(state, deserialized); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/workflow/mod.rs b/rust-research-agent/crates/rig-deepagents/src/workflow/mod.rs new file mode 100644 index 0000000..6499b04 --- /dev/null +++ 
b/rust-research-agent/crates/rig-deepagents/src/workflow/mod.rs @@ -0,0 +1,57 @@ +//! Workflow Graph System for Pregel-Based Agent Orchestration +//! +//! This module provides the building blocks for constructing and executing +//! agent workflows using a Pregel-inspired graph execution model. +//! +//! # Overview +//! +//! ```text +//! ┌─────────────────────────────────────────────────────────────┐ +//! │ WorkflowGraph │ +//! │ ┌─────────────────────────────────────────────────────┐ │ +//! │ │ Nodes │ │ +//! │ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │ +//! │ │ │ Agent │→ │ Router │→ │ SubAgent│ │ │ +//! │ │ └─────────┘ └────┬────┘ └─────────┘ │ │ +//! │ │ │ │ │ +//! │ │ ┌─────▼─────┐ │ │ +//! │ │ │ Tool │ │ │ +//! │ │ └───────────┘ │ │ +//! │ └─────────────────────────────────────────────────────┘ │ +//! │ │ +//! │ Compile via WorkflowBuilder → CompiledWorkflow │ +//! │ Execute via PregelRuntime │ +//! └─────────────────────────────────────────────────────────────┘ +//! ``` +//! +//! # Usage +//! +//! ```ignore +//! use rig_deepagents::workflow::{WorkflowGraph, NodeKind, AgentNodeConfig}; +//! +//! let workflow = WorkflowGraph::::new() +//! .name("research_agent") +//! .node("planner", NodeKind::Agent(AgentNodeConfig { +//! system_prompt: "Plan the research...".into(), +//! ..Default::default() +//! })) +//! .node("researcher", NodeKind::Agent(AgentNodeConfig { +//! system_prompt: "Execute research...".into(), +//! ..Default::default() +//! })) +//! .entry("planner") +//! .edge("planner", "researcher") +//! .edge("researcher", END) +//! .build()?; +//! 
``` + +pub mod node; +pub mod vertices; + +pub use node::{ + AgentNodeConfig, Branch, BranchCondition, FanInNodeConfig, FanOutNodeConfig, MergeStrategy, + NodeKind, RouterNodeConfig, RoutingStrategy, SplitStrategy, StopCondition, SubAgentNodeConfig, + ToolNodeConfig, +}; + +pub use vertices::agent::AgentVertex; diff --git a/rust-research-agent/crates/rig-deepagents/src/workflow/node.rs b/rust-research-agent/crates/rig-deepagents/src/workflow/node.rs new file mode 100644 index 0000000..5bdff6d --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/workflow/node.rs @@ -0,0 +1,513 @@ +//! Node Types and Configuration for Workflow Graphs +//! +//! Defines the different types of nodes that can exist in a workflow graph. +//! Each node type has specific behavior and configuration options. +//! +//! # Node Types +//! +//! - **Agent**: LLM-based processing with tool calling capabilities +//! - **Tool**: Single tool execution with static or dynamic arguments +//! - **Router**: Conditional branching based on state or LLM decisions +//! - **SubAgent**: Delegation to nested workflows with recursion protection +//! - **FanOut**: Parallel dispatch to multiple targets +//! - **FanIn**: Synchronization point waiting for multiple sources +//! - **Passthrough**: Simple data forwarding (identity transformation) + +use serde::{Deserialize, Serialize}; +use std::collections::{HashMap, HashSet}; +use std::time::Duration; + +/// The kind of node in a workflow graph. +/// +/// Each variant represents a different computation pattern. 
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum NodeKind { + /// An LLM-based agent that can process messages and call tools + Agent(AgentNodeConfig), + + /// A single tool execution node + Tool(ToolNodeConfig), + + /// Conditional routing based on state or LLM decisions + Router(RouterNodeConfig), + + /// Delegation to a sub-workflow + SubAgent(SubAgentNodeConfig), + + /// Parallel dispatch to multiple targets + FanOut(FanOutNodeConfig), + + /// Synchronization point waiting for multiple sources + FanIn(FanInNodeConfig), + + /// Simple passthrough (identity transformation) + #[default] + Passthrough, +} + +/// Configuration for an Agent node. +/// +/// Agents use LLMs to process messages and can optionally call tools. +/// They iterate until a stop condition is met. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentNodeConfig { + /// System prompt for the agent + pub system_prompt: String, + + /// Maximum iterations before forcing termination + #[serde(default = "default_max_iterations")] + pub max_iterations: usize, + + /// Conditions that cause the agent to stop iterating + #[serde(default)] + pub stop_conditions: Vec, + + /// Tools the agent is allowed to use (None = all tools) + #[serde(default)] + pub allowed_tools: Option>, + + /// Timeout for each LLM call + #[serde(default, with = "humantime_serde")] + pub llm_timeout: Option, + + /// Temperature for LLM calls + #[serde(default)] + pub temperature: Option, +} + +impl Default for AgentNodeConfig { + fn default() -> Self { + Self { + system_prompt: String::new(), + max_iterations: 10, + stop_conditions: vec![StopCondition::NoToolCalls], + allowed_tools: None, + llm_timeout: None, + temperature: None, + } + } +} + +fn default_max_iterations() -> usize { + 10 +} + +/// Conditions that cause an agent to stop iterating. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum StopCondition { + /// Stop when the LLM produces no tool calls + NoToolCalls, + + /// Stop when a specific tool is called + OnTool { tool_name: String }, + + /// Stop when the message contains specific text + ContainsText { pattern: String }, + + /// Stop when a state field matches a condition + StateMatch { field: String, value: serde_json::Value }, + + /// Stop after a certain number of iterations + MaxIterations { count: usize }, +} + +/// Configuration for a Tool node. +/// +/// Executes a single tool with arguments from static config or state. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ToolNodeConfig { + /// Name of the tool to execute + #[serde(default)] + pub tool_name: String, + + /// Static arguments (can be overridden by state paths) + #[serde(default)] + pub static_args: HashMap, + + /// Map from argument name to state path for dynamic arguments + #[serde(default)] + pub state_arg_paths: HashMap, + + /// Path in state where the result should be stored + #[serde(default)] + pub result_path: Option, + + /// Timeout for tool execution + #[serde(default, with = "humantime_serde")] + pub timeout: Option, +} + +/// Configuration for a Router node. +/// +/// Determines next node based on state inspection or LLM decision. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RouterNodeConfig { + /// How to make the routing decision + pub strategy: RoutingStrategy, + + /// Branches to evaluate (in order for StateField strategy) + pub branches: Vec, + + /// Default branch if no conditions match (required for StateField) + #[serde(default)] + pub default: Option, +} + +impl Default for RouterNodeConfig { + fn default() -> Self { + Self { + strategy: RoutingStrategy::StateField { + field: String::new(), + }, + branches: Vec::new(), + default: None, + } + } +} + +/// Strategy for making routing decisions. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum RoutingStrategy { + /// Route based on a state field value + StateField { + /// Path to the state field to inspect + field: String, + }, + + /// Route based on LLM classification + LLMDecision { + /// Prompt describing the options to the LLM + prompt: String, + /// Model to use (optional, uses default if not specified) + model: Option, + }, +} + +/// A branch in a routing decision. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Branch { + /// Target node for this branch + pub target: String, + + /// Condition that must be true for this branch + pub condition: BranchCondition, +} + +/// Condition for a routing branch. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "op", rename_all = "snake_case")] +pub enum BranchCondition { + /// Value equals expected + Equals { value: serde_json::Value }, + + /// Value is in set of options + In { values: Vec }, + + /// Value matches regex pattern + Matches { pattern: String }, + + /// Value is truthy (non-null, non-empty, non-false) + IsTruthy, + + /// Value is falsy + IsFalsy, + + /// Always true (used for catch-all branches) + Always, +} + +/// Configuration for a SubAgent node. +/// +/// Delegates work to a nested workflow with recursion protection. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SubAgentNodeConfig { + /// Name of the sub-agent to invoke + pub agent_name: String, + + /// Maximum recursion depth (prevents infinite nesting) + #[serde(default = "default_max_recursion")] + pub max_recursion: usize, + + /// Mapping from parent state paths to sub-agent input + #[serde(default)] + pub input_mapping: HashMap, + + /// Mapping from sub-agent output to parent state paths + #[serde(default)] + pub output_mapping: HashMap, + + /// Timeout for the entire sub-agent execution + #[serde(default, with = "humantime_serde")] + pub timeout: Option, +} + +impl Default for SubAgentNodeConfig { + fn default() -> Self { + Self { + agent_name: String::new(), + max_recursion: 5, + input_mapping: HashMap::new(), + output_mapping: HashMap::new(), + timeout: None, + } + } +} + +fn default_max_recursion() -> usize { + 5 +} + +/// Configuration for a FanOut node. +/// +/// Broadcasts messages to multiple targets in parallel. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FanOutNodeConfig { + /// Target nodes to send to + pub targets: Vec, + + /// How to split the work among targets + #[serde(default)] + pub split_strategy: SplitStrategy, + + /// Path to array in state to split (for Split strategy) + #[serde(default)] + pub split_path: Option, +} + +impl Default for FanOutNodeConfig { + fn default() -> Self { + Self { + targets: Vec::new(), + split_strategy: SplitStrategy::Broadcast, + split_path: None, + } + } +} + +/// Strategy for splitting work in a FanOut node. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SplitStrategy { + /// Send the same message to all targets + #[default] + Broadcast, + + /// Split an array and send one element to each target + Split, + + /// Round-robin distribution + RoundRobin, +} + +/// Configuration for a FanIn node. +/// +/// Waits for messages from multiple sources and merges them. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FanInNodeConfig { + /// Source nodes to wait for + pub sources: Vec, + + /// How to merge results from sources + #[serde(default)] + pub merge_strategy: MergeStrategy, + + /// Path in state where merged results are stored + #[serde(default)] + pub result_path: Option, + + /// Timeout for waiting for all sources + #[serde(default, with = "humantime_serde")] + pub timeout: Option, +} + +impl Default for FanInNodeConfig { + fn default() -> Self { + Self { + sources: Vec::new(), + merge_strategy: MergeStrategy::Collect, + result_path: None, + timeout: None, + } + } +} + +/// Strategy for merging results in a FanIn node. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum MergeStrategy { + /// Collect all results into an array + #[default] + Collect, + + /// Use first result that arrives + First, + + /// Use last result (all must complete) + Last, + + /// Concatenate string results + Concat, + + /// Merge object results (later values overwrite) + Merge, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_kind_serialization() { + let agent = NodeKind::Agent(AgentNodeConfig { + system_prompt: "You are helpful.".into(), + ..Default::default() + }); + + let json = serde_json::to_string(&agent).unwrap(); + let deserialized: NodeKind = serde_json::from_str(&json).unwrap(); + + match deserialized { + NodeKind::Agent(config) => { + assert_eq!(config.system_prompt, "You are helpful."); + } + _ => panic!("Wrong variant"), + } + } + + #[test] + fn test_tool_node_config() { + let tool = ToolNodeConfig { + tool_name: "search".into(), + static_args: [("query".into(), serde_json::json!("test"))].into(), + state_arg_paths: [("max_results".into(), "config.limit".into())].into(), + result_path: Some("search_results".into()), + timeout: Some(Duration::from_secs(30)), + }; + + let json = serde_json::to_string(&tool).unwrap(); + 
assert!(json.contains("search")); + assert!(json.contains("query")); + } + + #[test] + fn test_router_config_with_branches() { + let router = RouterNodeConfig { + strategy: RoutingStrategy::StateField { + field: "phase".into(), + }, + branches: vec![ + Branch { + target: "explore".into(), + condition: BranchCondition::Equals { + value: serde_json::json!("exploratory"), + }, + }, + Branch { + target: "synthesize".into(), + condition: BranchCondition::Equals { + value: serde_json::json!("synthesis"), + }, + }, + ], + default: Some("done".into()), + }; + + let json = serde_json::to_string(&router).unwrap(); + let deserialized: RouterNodeConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.branches.len(), 2); + assert_eq!(deserialized.default, Some("done".into())); + } + + #[test] + fn test_stop_conditions() { + let conditions = vec![ + StopCondition::NoToolCalls, + StopCondition::OnTool { + tool_name: "submit".into(), + }, + StopCondition::ContainsText { + pattern: "DONE".into(), + }, + StopCondition::MaxIterations { count: 5 }, + ]; + + let json = serde_json::to_string(&conditions).unwrap(); + let deserialized: Vec = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.len(), 4); + assert_eq!(deserialized[0], StopCondition::NoToolCalls); + } + + #[test] + fn test_subagent_config() { + let config = SubAgentNodeConfig { + agent_name: "researcher".into(), + max_recursion: 3, + input_mapping: [("query".into(), "research.topic".into())].into(), + output_mapping: [("findings".into(), "research.findings".into())].into(), + timeout: Some(Duration::from_secs(300)), + }; + + assert_eq!(config.agent_name, "researcher"); + assert_eq!(config.max_recursion, 3); + } + + #[test] + fn test_fanout_fanin_config() { + let fanout = FanOutNodeConfig { + targets: vec!["a".into(), "b".into(), "c".into()], + split_strategy: SplitStrategy::Broadcast, + split_path: None, + }; + + let fanin = FanInNodeConfig { + sources: vec!["a".into(), "b".into(), 
"c".into()], + merge_strategy: MergeStrategy::Collect, + result_path: Some("results".into()), + timeout: Some(Duration::from_secs(60)), + }; + + assert_eq!(fanout.targets, fanin.sources); + } + + #[test] + fn test_branch_conditions() { + let conditions = vec![ + BranchCondition::Equals { + value: serde_json::json!("active"), + }, + BranchCondition::In { + values: vec![serde_json::json!(1), serde_json::json!(2)], + }, + BranchCondition::Matches { + pattern: "^done.*".into(), + }, + BranchCondition::IsTruthy, + BranchCondition::Always, + ]; + + for condition in &conditions { + let json = serde_json::to_string(condition).unwrap(); + let _: BranchCondition = serde_json::from_str(&json).unwrap(); + } + } + + #[test] + fn test_node_kind_variants() { + // Ensure all 7 variants can be created + let _agent = NodeKind::Agent(Default::default()); + let _tool = NodeKind::Tool(Default::default()); + let _router = NodeKind::Router(Default::default()); + let _subagent = NodeKind::SubAgent(Default::default()); + let _fanout = NodeKind::FanOut(Default::default()); + let _fanin = NodeKind::FanIn(Default::default()); + let _passthrough = NodeKind::Passthrough; + + // Ensure default is Passthrough + assert!(matches!(NodeKind::default(), NodeKind::Passthrough)); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/workflow/vertices/agent.rs b/rust-research-agent/crates/rig-deepagents/src/workflow/vertices/agent.rs new file mode 100644 index 0000000..d69bac0 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/workflow/vertices/agent.rs @@ -0,0 +1,354 @@ +//! AgentVertex: LLM-based agent node with tool calling capabilities +//! +//! Implements the Vertex trait for agent nodes that use LLMs to process +//! messages and can iteratively call tools until a stop condition is met. 
+ +use async_trait::async_trait; +use std::sync::Arc; + +use crate::llm::{LLMConfig, LLMProvider}; +use crate::middleware::ToolDefinition; +use crate::pregel::error::PregelError; +use crate::pregel::message::WorkflowMessage; +use crate::pregel::state::WorkflowState; +use crate::pregel::vertex::{ComputeContext, ComputeResult, StateUpdate, Vertex, VertexId}; +use crate::state::{Message, Role}; +use crate::workflow::node::{AgentNodeConfig, StopCondition}; + +/// An agent vertex that uses an LLM to process messages and call tools +pub struct AgentVertex { + id: VertexId, + config: AgentNodeConfig, + llm: Arc, + tools: Vec, + _phantom: std::marker::PhantomData, +} + +impl AgentVertex { + /// Create a new agent vertex + pub fn new( + id: impl Into, + config: AgentNodeConfig, + llm: Arc, + tools: Vec, + ) -> Self { + Self { + id: id.into(), + config, + llm, + tools, + _phantom: std::marker::PhantomData, + } + } + + /// Check if any stop condition is met + fn check_stop_conditions(&self, message: &Message, iteration: usize) -> bool { + for condition in &self.config.stop_conditions { + match condition { + StopCondition::NoToolCalls => { + if message.tool_calls.is_none() || message.tool_calls.as_ref().unwrap().is_empty() { + return true; + } + } + StopCondition::OnTool { tool_name } => { + if let Some(tool_calls) = &message.tool_calls { + if tool_calls.iter().any(|tc| &tc.name == tool_name) { + return true; + } + } + } + StopCondition::ContainsText { pattern } => { + if message.content.contains(pattern) { + return true; + } + } + StopCondition::MaxIterations { count } => { + if iteration >= *count { + return true; + } + } + StopCondition::StateMatch { .. 
} => { + // TODO: Implement state matching + continue; + } + } + } + false + } + + /// Filter tools based on allowed list + fn filter_tools(&self) -> Vec { + if let Some(allowed) = &self.config.allowed_tools { + self.tools + .iter() + .filter(|t| allowed.contains(&t.name)) + .cloned() + .collect() + } else { + self.tools.clone() + } + } + + /// Build LLM config from agent config + fn build_llm_config(&self) -> Option { + self.config.temperature.map(|temp| LLMConfig::new("").with_temperature(temp as f64)) + } +} + +#[async_trait] +impl Vertex for AgentVertex { + fn id(&self) -> &VertexId { + &self.id + } + + async fn compute( + &self, + ctx: &mut ComputeContext<'_, S, WorkflowMessage>, + ) -> Result, PregelError> { + // Build message history starting with system prompt + let mut messages = vec![Message { + role: Role::System, + content: self.config.system_prompt.clone(), + tool_calls: None, + tool_call_id: None, + }]; + + // Add any incoming workflow messages as user messages + for msg in ctx.messages { + if let WorkflowMessage::Data { key: _, value } = msg { + messages.push(Message { + role: Role::User, + content: value.to_string(), + tool_calls: None, + tool_call_id: None, + }); + } + } + + // If no user messages, add a default activation message + if messages.len() == 1 { + messages.push(Message { + role: Role::User, + content: "Begin processing.".to_string(), + tool_calls: None, + tool_call_id: None, + }); + } + + let filtered_tools = self.filter_tools(); + let llm_config = self.build_llm_config(); + + // Agent loop: iterate until stop condition or max iterations + for iteration in 0..self.config.max_iterations { + // Call LLM + let response = self + .llm + .complete(&messages, &filtered_tools, llm_config.as_ref()) + .await + .map_err(|e| PregelError::VertexError { + vertex_id: self.id.clone(), + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let assistant_message = response.message.clone(); + messages.push(assistant_message.clone()); + + // 
Check stop conditions + if self.check_stop_conditions(&assistant_message, iteration) { + // Send final response as output message + ctx.send_message( + "output", + WorkflowMessage::Data { + key: "response".to_string(), + value: serde_json::Value::String(assistant_message.content), + }, + ); + return Ok(ComputeResult::halt(S::Update::empty())); + } + + // If there are tool calls, execute them + if let Some(tool_calls) = &assistant_message.tool_calls { + for tool_call in tool_calls { + // TODO: Execute tool calls + // For now, just add a mock tool result + messages.push(Message::tool( + "Tool executed successfully", + &tool_call.id, + )); + } + } else { + // No tool calls and no stop condition matched, halt anyway + ctx.send_message( + "output", + WorkflowMessage::Data { + key: "response".to_string(), + value: serde_json::Value::String(assistant_message.content), + }, + ); + return Ok(ComputeResult::halt(S::Update::empty())); + } + } + + // Max iterations reached + Err(PregelError::vertex_error( + self.id.clone(), + "Max iterations reached", + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::DeepAgentError; + use crate::llm::LLMResponse; + use crate::pregel::state::UnitState; + use crate::pregel::vertex::VertexState; + use crate::state::ToolCall; + use std::sync::Mutex; + + // Mock LLM provider for testing + struct MockLLMProvider { + responses: Arc>>, + } + + impl MockLLMProvider { + fn new() -> Self { + Self { + responses: Arc::new(Mutex::new(Vec::new())), + } + } + + fn with_response(self, content: impl Into) -> Self { + let message = Message { + role: Role::Assistant, + content: content.into(), + tool_calls: None, + tool_call_id: None, + }; + self.responses.lock().unwrap().push(message); + self + } + + fn with_tool_call(self, content: impl Into, tool_name: impl Into) -> Self { + let message = Message { + role: Role::Assistant, + content: content.into(), + tool_calls: Some(vec![ToolCall { + id: "test_call_1".to_string(), + name: 
tool_name.into(), + arguments: serde_json::json!({}), + }]), + tool_call_id: None, + }; + self.responses.lock().unwrap().push(message); + self + } + } + + #[async_trait] + impl LLMProvider for MockLLMProvider { + async fn complete( + &self, + _messages: &[Message], + _tools: &[ToolDefinition], + _config: Option<&LLMConfig>, + ) -> Result { + let mut responses = self.responses.lock().unwrap(); + if responses.is_empty() { + return Err(DeepAgentError::AgentExecution( + "No more mock responses".to_string(), + )); + } + let message = responses.remove(0); + Ok(LLMResponse::new(message)) + } + + fn name(&self) -> &str { + "mock" + } + + fn default_model(&self) -> &str { + "mock-model" + } + } + + #[tokio::test] + async fn test_agent_vertex_single_response() { + let mock_llm = MockLLMProvider::new().with_response("Hello! How can I help?"); + + let vertex = AgentVertex::::new( + "agent", + AgentNodeConfig { + system_prompt: "You are helpful.".into(), + stop_conditions: vec![StopCondition::NoToolCalls], + ..Default::default() + }, + Arc::new(mock_llm), + vec![], + ); + + let mut ctx = + ComputeContext::::new("agent".into(), &[], 0, &UnitState); + + let result = vertex.compute(&mut ctx).await.unwrap(); + + assert_eq!(result.state, VertexState::Halted); + assert!(ctx.has_messages() || !ctx.into_outbox().is_empty()); + } + + #[tokio::test] + async fn test_agent_vertex_stop_on_tool() { + let mock_llm = MockLLMProvider::new().with_tool_call("Let me search for that", "search"); + + let vertex = AgentVertex::::new( + "agent", + AgentNodeConfig { + system_prompt: "You are a researcher.".into(), + stop_conditions: vec![StopCondition::OnTool { + tool_name: "search".to_string(), + }], + ..Default::default() + }, + Arc::new(mock_llm), + vec![], + ); + + let mut ctx = + ComputeContext::::new("agent".into(), &[], 0, &UnitState); + + let result = vertex.compute(&mut ctx).await.unwrap(); + + assert_eq!(result.state, VertexState::Halted); + } + + #[tokio::test] + async fn 
test_agent_vertex_max_iterations() { + // Mock LLM that always returns tool calls (would loop forever without limit) + let mut mock_llm = MockLLMProvider::new(); + for _ in 0..15 { + mock_llm = mock_llm.with_tool_call("Still thinking...", "think"); + } + + let vertex = AgentVertex::::new( + "agent", + AgentNodeConfig { + system_prompt: "You are helpful.".into(), + max_iterations: 3, + stop_conditions: vec![], // No stop conditions, relies on max_iterations + ..Default::default() + }, + Arc::new(mock_llm), + vec![], + ); + + let mut ctx = + ComputeContext::::new("agent".into(), &[], 0, &UnitState); + + let result = vertex.compute(&mut ctx).await; + + // Should hit max iterations and return error + assert!(result.is_err()); + } +} diff --git a/rust-research-agent/crates/rig-deepagents/src/workflow/vertices/mod.rs b/rust-research-agent/crates/rig-deepagents/src/workflow/vertices/mod.rs new file mode 100644 index 0000000..bc4e195 --- /dev/null +++ b/rust-research-agent/crates/rig-deepagents/src/workflow/vertices/mod.rs @@ -0,0 +1,11 @@ +//! Vertex implementations for workflow nodes +//! +//! Each vertex type implements the Vertex trait and corresponds to a NodeKind variant. + +pub mod agent; + +// Future vertex implementations: +// pub mod tool; +// pub mod router; +// pub mod subagent; +// pub mod parallel;