diff --git a/Context_Engineering_Research.ipynb b/Context_Engineering_Research.ipynb new file mode 100644 index 0000000..8760b28 --- /dev/null +++ b/Context_Engineering_Research.ipynb @@ -0,0 +1,677 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4c897bc9", + "metadata": {}, + "source": [ + "# Context Engineering 연구 노트북\n", + "\n", + "DeepAgents 라이브러리에서 사용되는 5가지 Context Engineering 전략을 분석하고 실험합니다.\n", + "\n", + "## Context Engineering 5가지 핵심 전략\n", + "\n", + "| 전략 | 설명 | DeepAgents 구현 |\n", + "|------|------|----------------|\n", + "| **1. Offloading** | 대용량 결과를 파일로 축출 | FilesystemMiddleware |\n", + "| **2. Reduction** | Compaction + Summarization | SummarizationMiddleware |\n", + "| **3. Retrieval** | grep/glob 기반 검색 | FilesystemMiddleware |\n", + "| **4. Isolation** | SubAgent로 컨텍스트 격리 | SubAgentMiddleware |\n", + "| **5. Caching** | Prompt Caching | AnthropicPromptCachingMiddleware |\n", + "\n", + "## 아키텍처 개요\n", + "\n", + "```\n", + "┌─────────────────────────────────────────────────────────────────┐\n", + "│ Context Engineering │\n", + "├─────────────────────────────────────────────────────────────────┤\n", + "│ │\n", + "│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │\n", + "│ │ Offloading │ │ Reduction │ │ Caching │ │\n", + "│ │ (20k 토큰) │ │ (85% 임계) │ │ (Anthropic)│ │\n", + "│ └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ │\n", + "│ │ │ │ │\n", + "│ ▼ ▼ ▼ │\n", + "│ ┌─────────────────────────────────────────────────────┐ │\n", + "│ │ Middleware Stack │ │\n", + "│ └─────────────────────────────────────────────────────┘ │\n", + "│ │ │\n", + "│ ┌────────────────┼────────────────┐ │\n", + "│ ▼ ▼ ▼ │\n", + "│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │\n", + "│ │ Retrieval │ │ Isolation │ │ Backend │ │\n", + "│ │(grep/glob) │ │ (SubAgent) │ │ (FileSystem│ │\n", + "│ └────────────┘ └────────────┘ └────────────┘ │\n", + "│ │\n", + "└─────────────────────────────────────────────────────────────────┘\n", + "```" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "fecc3e39", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv(\".env\", override=True)\n", + "\n", + "PROJECT_ROOT = Path.cwd()\n", + "if str(PROJECT_ROOT) not in sys.path:\n", + " sys.path.insert(0, str(PROJECT_ROOT))" + ] + }, + { + "cell_type": "markdown", + "id": "strategy1", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 전략 1: Context Offloading\n", + "\n", + "대용량 도구 결과를 파일시스템으로 축출하여 컨텍스트 윈도우 오버플로우를 방지합니다.\n", + "\n", + "### 핵심 원리\n", + "- 도구 결과가 `tool_token_limit_before_evict` (기본 20,000 토큰) 초과 시 자동 축출\n", + "- `/large_tool_results/{tool_call_id}` 경로에 저장\n", + "- 처음 10줄 미리보기 제공\n", + "- 에이전트가 `read_file`로 필요할 때 로드" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "offloading_demo", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent.context_strategies.offloading import (\n", + " ContextOffloadingStrategy,\n", + " OffloadingConfig,\n", + ")\n", + "\n", + "config = OffloadingConfig(\n", + " token_limit_before_evict=20000,\n", + " eviction_path_prefix=\"/large_tool_results\",\n", + " preview_lines=10,\n", + ")\n", + "\n", + "print(f\"토큰 임계값: {config.token_limit_before_evict:,}\")\n", + "print(f\"축출 경로: {config.eviction_path_prefix}\")\n", + "print(f\"미리보기 줄 수: {config.preview_lines}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "offloading_test", + "metadata": {}, + "outputs": [], + "source": [ + "strategy = ContextOffloadingStrategy(config=config)\n", + "\n", + "small_content = \"짧은 텍스트\" * 100\n", + "large_content = \"대용량 텍스트\" * 30000\n", + "\n", + "print(f\"짧은 콘텐츠: {len(small_content)} 자 → 축출 대상: {strategy._should_offload(small_content)}\")\n", + "print(f\"대용량 콘텐츠: {len(large_content):,} 자 → 축출 대상: {strategy._should_offload(large_content)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "strategy2", + 
"metadata": {}, + "source": [ + "---\n", + "\n", + "## 전략 2: Context Reduction\n", + "\n", + "컨텍스트 윈도우 사용량이 임계값을 초과할 때 자동으로 대화 내용을 압축합니다.\n", + "\n", + "### 두 가지 기법\n", + "\n", + "| 기법 | 설명 | 비용 |\n", + "|------|------|------|\n", + "| **Compaction** | 오래된 도구 호출/결과 제거 | 무료 |\n", + "| **Summarization** | LLM이 대화 요약 | API 비용 발생 |\n", + "\n", + "우선순위: Compaction → Summarization" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "reduction_demo", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent.context_strategies.reduction import (\n", + " ContextReductionStrategy,\n", + " ReductionConfig,\n", + ")\n", + "\n", + "config = ReductionConfig(\n", + " context_threshold=0.85,\n", + " model_context_window=200000,\n", + " compaction_age_threshold=10,\n", + " min_messages_to_keep=5,\n", + ")\n", + "\n", + "print(f\"임계값: {config.context_threshold * 100}%\")\n", + "print(f\"컨텍스트 윈도우: {config.model_context_window:,} 토큰\")\n", + "print(f\"Compaction 대상 나이: {config.compaction_age_threshold} 메시지\")\n", + "print(f\"최소 유지 메시지: {config.min_messages_to_keep}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "reduction_test", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.messages import AIMessage, HumanMessage\n", + "\n", + "strategy = ContextReductionStrategy(config=config)\n", + "\n", + "messages = [\n", + " HumanMessage(content=\"안녕하세요\" * 1000),\n", + " AIMessage(content=\"안녕하세요\" * 1000),\n", + "] * 20\n", + "\n", + "usage_ratio = strategy._get_context_usage_ratio(messages)\n", + "print(f\"컨텍스트 사용률: {usage_ratio * 100:.1f}%\")\n", + "print(f\"축소 필요: {strategy._should_reduce(messages)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "strategy3", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 전략 3: Context Retrieval\n", + "\n", + "grep/glob 기반의 단순하고 빠른 검색으로 필요한 정보만 선택적으로 로드합니다.\n", + "\n", + "### 벡터 검색을 사용하지 않는 이유\n", + "\n", + "| 특성 | 직접 검색 | 벡터 검색 |\n", + 
"|------|----------|----------|\n", + "| 결정성 | ✅ 정확한 매칭 | ❌ 확률적 |\n", + "| 인프라 | ✅ 불필요 | ❌ 벡터 DB 필요 |\n", + "| 속도 | ✅ 빠름 | ❌ 인덱싱 오버헤드 |\n", + "| 디버깅 | ✅ 예측 가능 | ❌ 블랙박스 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "retrieval_demo", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent.context_strategies.retrieval import (\n", + " ContextRetrievalStrategy,\n", + " RetrievalConfig,\n", + ")\n", + "\n", + "config = RetrievalConfig(\n", + " default_read_limit=500,\n", + " max_grep_results=100,\n", + " max_glob_results=100,\n", + " truncate_line_length=2000,\n", + ")\n", + "\n", + "print(f\"기본 읽기 제한: {config.default_read_limit} 줄\")\n", + "print(f\"grep 최대 결과: {config.max_grep_results}\")\n", + "print(f\"glob 최대 결과: {config.max_glob_results}\")\n", + "print(f\"줄 길이 제한: {config.truncate_line_length} 자\")" + ] + }, + { + "cell_type": "markdown", + "id": "strategy4", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 전략 4: Context Isolation\n", + "\n", + "SubAgent를 통해 독립된 컨텍스트 윈도우에서 작업을 수행합니다.\n", + "\n", + "### 장점\n", + "- 메인 에이전트 컨텍스트 오염 방지\n", + "- 복잡한 작업의 격리 처리\n", + "- 병렬 처리 가능\n", + "\n", + "### SubAgent 유형\n", + "\n", + "| 유형 | 구조 | 특징 |\n", + "|------|------|------|\n", + "| Simple | `{name, system_prompt, tools}` | 단일 응답 |\n", + "| Compiled | `{name, runnable}` | 자체 DeepAgent, 다중 턴 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "isolation_demo", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent.context_strategies.isolation import (\n", + " ContextIsolationStrategy,\n", + " IsolationConfig,\n", + ")\n", + "\n", + "config = IsolationConfig(\n", + " default_model=\"gpt-4.1\",\n", + " include_general_purpose_agent=True,\n", + " excluded_state_keys=(\"messages\", \"todos\", \"structured_response\"),\n", + ")\n", + "\n", + "print(f\"기본 모델: {config.default_model}\")\n", + "print(f\"범용 에이전트 포함: 
{config.include_general_purpose_agent}\")\n", + "print(f\"제외 상태 키: {config.excluded_state_keys}\")" + ] + }, + { + "cell_type": "markdown", + "id": "strategy5", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 전략 5: Context Caching\n", + "\n", + "Anthropic Prompt Caching을 활용하여 시스템 프롬프트와 반복 컨텍스트를 캐싱합니다.\n", + "\n", + "### 이점\n", + "- API 호출 비용 절감\n", + "- 응답 속도 향상\n", + "- 동일 세션 내 반복 호출 최적화\n", + "\n", + "### 캐싱 조건\n", + "- 최소 1,024 토큰 이상\n", + "- `cache_control: {\"type\": \"ephemeral\"}` 마커 추가" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "caching_demo", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent.context_strategies.caching import (\n", + " ContextCachingStrategy,\n", + " CachingConfig,\n", + ")\n", + "\n", + "config = CachingConfig(\n", + " min_cacheable_tokens=1024,\n", + " cache_control_type=\"ephemeral\",\n", + " enable_for_system_prompt=True,\n", + " enable_for_tools=True,\n", + ")\n", + "\n", + "print(f\"최소 캐싱 토큰: {config.min_cacheable_tokens:,}\")\n", + "print(f\"캐시 컨트롤 타입: {config.cache_control_type}\")\n", + "print(f\"시스템 프롬프트 캐싱: {config.enable_for_system_prompt}\")\n", + "print(f\"도구 캐싱: {config.enable_for_tools}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "caching_test", + "metadata": {}, + "outputs": [], + "source": [ + "strategy = ContextCachingStrategy(config=config)\n", + "\n", + "short_content = \"짧은 시스템 프롬프트\"\n", + "long_content = \"긴 시스템 프롬프트 \" * 500\n", + "\n", + "print(f\"짧은 콘텐츠: {len(short_content)} 자 → 캐싱 대상: {strategy._should_cache(short_content)}\")\n", + "print(f\"긴 콘텐츠: {len(long_content):,} 자 → 캐싱 대상: {strategy._should_cache(long_content)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "integration", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 통합 에이전트 실행\n", + "\n", + "5가지 전략이 모두 적용된 에이전트를 실행합니다." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "agent_create", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent import create_context_aware_agent\n", + "\n", + "agent = create_context_aware_agent(\n", + " model_name=\"gpt-4.1\",\n", + " enable_offloading=True,\n", + " enable_reduction=True,\n", + " enable_caching=True,\n", + " offloading_token_limit=20000,\n", + " reduction_threshold=0.85,\n", + ")\n", + "\n", + "print(f\"에이전트 타입: {type(agent).__name__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "comparison_intro", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 전략 활성화/비활성화 비교 실험\n", + "\n", + "각 전략을 활성화/비활성화했을 때의 차이점을 실험합니다.\n", + "\n", + "### 실험 설계\n", + "\n", + "| 실험 | Offloading | Reduction | Caching | 목적 |\n", + "|------|------------|-----------|---------|------|\n", + "| 1. 기본 | ❌ | ❌ | ❌ | 베이스라인 |\n", + "| 2. Offloading만 | ✅ | ❌ | ❌ | 대용량 결과 축출 효과 |\n", + "| 3. Reduction만 | ❌ | ✅ | ❌ | 컨텍스트 압축 효과 |\n", + "| 4. 
모두 활성화 | ✅ | ✅ | ✅ | 전체 효과 |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "comparison_setup", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent.context_strategies.offloading import (\n", + " ContextOffloadingStrategy, OffloadingConfig\n", + ")\n", + "from context_engineering_research_agent.context_strategies.reduction import (\n", + " ContextReductionStrategy, ReductionConfig\n", + ")\n", + "from langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n", + "\n", + "print(\"=\" * 60)\n", + "print(\"전략 비교를 위한 테스트 데이터 생성\")\n", + "print(\"=\" * 60)" + ] + }, + { + "cell_type": "markdown", + "id": "exp1_offloading", + "metadata": {}, + "source": [ + "### 실험 1: Offloading 전략 효과\n", + "\n", + "대용량 도구 결과가 있을 때 Offloading 활성화/비활성화 비교" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "exp1_code", + "metadata": {}, + "outputs": [], + "source": [ + "small_result = \"검색 결과: 항목 1, 항목 2, 항목 3\"\n", + "large_result = \"\\n\".join([f\"검색 결과 {i}: \" + \"상세 내용 \" * 100 for i in range(500)])\n", + "\n", + "print(f\"작은 결과 크기: {len(small_result):,} 자\")\n", + "print(f\"대용량 결과 크기: {len(large_result):,} 자\")\n", + "print()\n", + "\n", + "offloading_disabled = ContextOffloadingStrategy(\n", + " config=OffloadingConfig(token_limit_before_evict=999999999)\n", + ")\n", + "offloading_enabled = ContextOffloadingStrategy(\n", + " config=OffloadingConfig(token_limit_before_evict=20000)\n", + ")\n", + "\n", + "print(\"[Offloading 비활성화 시]\")\n", + "print(f\" 작은 결과 축출: {offloading_disabled._should_offload(small_result)}\")\n", + "print(f\" 대용량 결과 축출: {offloading_disabled._should_offload(large_result)}\")\n", + "print(f\" → 대용량 결과가 컨텍스트에 그대로 포함됨\")\n", + "print()\n", + "\n", + "print(\"[Offloading 활성화 시]\")\n", + "print(f\" 작은 결과 축출: {offloading_enabled._should_offload(small_result)}\")\n", + "print(f\" 대용량 결과 축출: {offloading_enabled._should_offload(large_result)}\")\n", + "print(f\" → 대용량 결과는 파일로 
저장, 미리보기만 컨텍스트에 포함\")\n", + "print()\n", + "\n", + "preview = offloading_enabled._create_preview(large_result)\n", + "print(f\"미리보기 크기: {len(preview):,} 자 (원본의 {len(preview)/len(large_result)*100:.1f}%)\")" + ] + }, + { + "cell_type": "markdown", + "id": "exp2_reduction", + "metadata": {}, + "source": [ + "### 실험 2: Reduction 전략 효과\n", + "\n", + "긴 대화에서 Compaction 적용 전/후 비교" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "exp2_code", + "metadata": {}, + "outputs": [], + "source": [ + "messages_with_tools = []\n", + "for i in range(30):\n", + " messages_with_tools.append(HumanMessage(content=f\"질문 {i}: \" + \"내용 \" * 50))\n", + " ai_msg = AIMessage(\n", + " content=f\"답변 {i}: \" + \"응답 \" * 50,\n", + " tool_calls=[{'id': f'call_{i}', 'name': 'search', 'args': {'q': 'test'}}] if i < 25 else []\n", + " )\n", + " messages_with_tools.append(ai_msg)\n", + " if i < 25:\n", + " messages_with_tools.append(ToolMessage(content=f\"도구 결과 {i}: \" + \"결과 \" * 30, tool_call_id=f'call_{i}'))\n", + "\n", + "reduction = ContextReductionStrategy(\n", + " config=ReductionConfig(compaction_age_threshold=10)\n", + ")\n", + "\n", + "original_tokens = reduction._estimate_tokens(messages_with_tools)\n", + "print(f\"[Reduction 비활성화 시]\")\n", + "print(f\" 메시지 수: {len(messages_with_tools)}\")\n", + "print(f\" 추정 토큰: {original_tokens:,}\")\n", + "print(f\" → 모든 도구 호출/결과가 컨텍스트에 유지됨\")\n", + "print()\n", + "\n", + "compacted, result = reduction.apply_compaction(messages_with_tools)\n", + "compacted_tokens = reduction._estimate_tokens(compacted)\n", + "\n", + "print(f\"[Reduction 활성화 시 - Compaction]\")\n", + "print(f\" 메시지 수: {len(messages_with_tools)} → {len(compacted)}\")\n", + "print(f\" 추정 토큰: {original_tokens:,} → {compacted_tokens:,}\")\n", + "print(f\" 절약된 토큰: {result.estimated_tokens_saved:,} ({result.estimated_tokens_saved/original_tokens*100:.1f}%)\")\n", + "print(f\" → 오래된 도구 호출/결과가 제거되어 컨텍스트 효율화\")" + ] + }, + { + "cell_type": "markdown", + "id": 
"exp3_combined", + "metadata": {}, + "source": [ + "### 실험 3: 전략 조합 효과 시뮬레이션\n", + "\n", + "모든 전략을 함께 적용했을 때의 시너지 효과" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "exp3_code", + "metadata": {}, + "outputs": [], + "source": [ + "print(\"=\" * 60)\n", + "print(\"시나리오: 복잡한 연구 작업 수행\")\n", + "print(\"=\" * 60)\n", + "print()\n", + "\n", + "scenario = {\n", + " \"대화 턴 수\": 50,\n", + " \"도구 호출 수\": 40,\n", + " \"대용량 결과 수\": 5,\n", + " \"평균 결과 크기\": \"100k 자\",\n", + "}\n", + "\n", + "print(\"[시나리오 설정]\")\n", + "for k, v in scenario.items():\n", + " print(f\" {k}: {v}\")\n", + "print()\n", + "\n", + "baseline_context = 50 * 500 + 40 * 300 + 5 * 100000\n", + "print(\"[모든 전략 비활성화 시]\")\n", + "print(f\" 예상 컨텍스트 크기: {baseline_context:,} 자 (~{baseline_context//4:,} 토큰)\")\n", + "print(f\" 문제: 컨텍스트 윈도우 초과 가능성 높음\")\n", + "print()\n", + "\n", + "with_offloading = 50 * 500 + 40 * 300 + 5 * 1000\n", + "print(\"[Offloading만 활성화 시]\")\n", + "print(f\" 예상 컨텍스트 크기: {with_offloading:,} 자 (~{with_offloading//4:,} 토큰)\")\n", + "print(f\" 절약: {(baseline_context - with_offloading):,} 자 ({(baseline_context - with_offloading)/baseline_context*100:.1f}%)\")\n", + "print()\n", + "\n", + "with_reduction = with_offloading * 0.6\n", + "print(\"[Offloading + Reduction 활성화 시]\")\n", + "print(f\" 예상 컨텍스트 크기: {int(with_reduction):,} 자 (~{int(with_reduction)//4:,} 토큰)\")\n", + "print(f\" 총 절약: {int(baseline_context - with_reduction):,} 자 ({(baseline_context - with_reduction)/baseline_context*100:.1f}%)\")\n", + "print()\n", + "\n", + "print(\"[+ Caching 활성화 시 추가 효과]\")\n", + "print(f\" 시스템 프롬프트 캐싱으로 반복 호출 비용 90% 절감\")\n", + "print(f\" 응답 속도 향상\")" + ] + }, + { + "cell_type": "markdown", + "id": "exp4_live", + "metadata": {}, + "source": [ + "### 실험 4: 실제 에이전트 실행 비교\n", + "\n", + "실제 에이전트를 다른 설정으로 생성하여 비교합니다." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "exp4_code", + "metadata": {}, + "outputs": [], + "source": [ + "from context_engineering_research_agent import create_context_aware_agent\n", + "\n", + "print(\"에이전트 생성 비교\")\n", + "print(\"=\" * 60)\n", + "\n", + "configs = [\n", + " {\"name\": \"기본 (모두 비활성화)\", \"offloading\": False, \"reduction\": False, \"caching\": False},\n", + " {\"name\": \"Offloading만\", \"offloading\": True, \"reduction\": False, \"caching\": False},\n", + " {\"name\": \"Reduction만\", \"offloading\": False, \"reduction\": True, \"caching\": False},\n", + " {\"name\": \"모두 활성화\", \"offloading\": True, \"reduction\": True, \"caching\": True},\n", + "]\n", + "\n", + "for cfg in configs:\n", + " agent = create_context_aware_agent(\n", + " model_name=\"gpt-4.1\",\n", + " enable_offloading=cfg[\"offloading\"],\n", + " enable_reduction=cfg[\"reduction\"],\n", + " enable_caching=cfg[\"caching\"],\n", + " )\n", + " print(f\"\\n[{cfg['name']}]\")\n", + " print(f\" Offloading: {'✅' if cfg['offloading'] else '❌'}\")\n", + " print(f\" Reduction: {'✅' if cfg['reduction'] else '❌'}\")\n", + " print(f\" Caching: {'✅' if cfg['caching'] else '❌'}\")\n", + " print(f\" 에이전트 타입: {type(agent).__name__}\")\n", + "\n", + "print(\"\\n\" + \"=\" * 60)\n", + "print(\"모든 에이전트가 성공적으로 생성되었습니다.\")" + ] + }, + { + "cell_type": "markdown", + "id": "exp5_recommendation", + "metadata": {}, + "source": [ + "### 권장 설정\n", + "\n", + "| 사용 사례 | Offloading | Reduction | Caching | 이유 |\n", + "|----------|------------|-----------|---------|------|\n", + "| **짧은 대화** | ❌ | ❌ | ✅ | 오버헤드 최소화 |\n", + "| **일반 작업** | ✅ | ❌ | ✅ | 대용량 결과 대비 |\n", + "| **장시간 연구** | ✅ | ✅ | ✅ | 모든 최적화 활용 |\n", + "| **디버깅** | ❌ | ❌ | ❌ | 전체 컨텍스트 확인 |" + ] + }, + { + "cell_type": "markdown", + "id": "summary", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## 요약\n", + "\n", + "### Context Engineering 5가지 전략 요약\n", + "\n", + "| 전략 | 트리거 조건 | 효과 |\n", + 
"|------|------------|------|\n", + "| **Offloading** | 20k 토큰 초과 | 파일로 축출 |\n", + "| **Reduction** | 85% 사용량 초과 | Compaction/Summarization |\n", + "| **Retrieval** | 파일 접근 필요 | grep/glob 검색 |\n", + "| **Isolation** | 복잡한 작업 | SubAgent 위임 |\n", + "| **Caching** | 1k+ 토큰 시스템 프롬프트 | Prompt Caching |\n", + "\n", + "### 핵심 인사이트\n", + "\n", + "1. **파일시스템 = 외부 메모리**: 컨텍스트 윈도우는 제한되어 있지만, 파일시스템은 무한\n", + "2. **점진적 공개**: 모든 정보를 한 번에 로드하지 않고 필요할 때만 로드\n", + "3. **격리된 실행**: SubAgent로 컨텍스트 오염 방지\n", + "4. **자동화된 관리**: 에이전트가 직접 컨텍스트를 관리하도록 미들웨어 설계" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.12.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/context-engineering-more-deep_research_agent/__init__.py b/context_engineering_research_agent/__init__.py similarity index 90% rename from context-engineering-more-deep_research_agent/__init__.py rename to context_engineering_research_agent/__init__.py index 43b0c8a..de6790c 100644 --- a/context-engineering-more-deep_research_agent/__init__.py +++ b/context_engineering_research_agent/__init__.py @@ -37,7 +37,7 @@ ## 모듈 구조 ``` -context-engineering-more-deep_research_agent/ +context_engineering_research_agent/ ├── __init__.py # 이 파일 ├── agent.py # 메인 에이전트 (5가지 전략 통합) ├── prompts.py # 시스템 프롬프트 @@ -65,9 +65,10 @@ context-engineering-more-deep_research_agent/ ## 사용 예시 ```python -from context_engineering_research_agent import agent +from context_engineering_research_agent import get_agent -# 에이전트 실행 +# 에이전트 실행 (API key 필요) +agent = get_agent() result = agent.invoke({ "messages": [{"role": "user", "content": "Context Engineering 전략 연구"}] }) @@ -80,16 +81,14 @@ result = agent.invoke({ - LangGraph: https://docs.langchain.com/oss/python/langgraph/overview """ -# 버전 정보 __version__ = "0.1.0" __author__ = "Context Engineering Research Team" -# 주요 컴포넌트 export -from 
context_engineering_more_deep_research_agent.agent import ( - agent, +from context_engineering_research_agent.agent import ( create_context_aware_agent, + get_agent, ) -from context_engineering_more_deep_research_agent.context_strategies import ( +from context_engineering_research_agent.context_strategies import ( ContextCachingStrategy, ContextIsolationStrategy, ContextOffloadingStrategy, @@ -98,10 +97,8 @@ from context_engineering_more_deep_research_agent.context_strategies import ( ) __all__ = [ - # 에이전트 - "agent", + "get_agent", "create_context_aware_agent", - # Context Engineering 전략 "ContextOffloadingStrategy", "ContextReductionStrategy", "ContextRetrievalStrategy", diff --git a/context_engineering_research_agent/agent.py b/context_engineering_research_agent/agent.py new file mode 100644 index 0000000..8c52289 --- /dev/null +++ b/context_engineering_research_agent/agent.py @@ -0,0 +1,254 @@ +"""Context Engineering 연구용 메인 에이전트. + +5가지 Context Engineering 전략을 명시적으로 통합한 에이전트입니다. +""" + +from datetime import datetime +from pathlib import Path +from typing import Any + +from deepagents import create_deep_agent +from deepagents.backends import CompositeBackend, FilesystemBackend, StateBackend +from langchain.tools import ToolRuntime +from langchain_core.language_models import BaseChatModel +from langchain_openai import ChatOpenAI + +from context_engineering_research_agent.context_strategies import ( + ContextCachingStrategy, + ContextOffloadingStrategy, + ContextReductionStrategy, + PromptCachingTelemetryMiddleware, + ProviderType, + detect_provider, +) + +CONTEXT_ENGINEERING_SYSTEM_PROMPT = """# Context Engineering 연구 에이전트 + +당신은 Context Engineering 전략을 연구하고 실험하는 에이전트입니다. + +## Context Engineering 5가지 핵심 전략 + +### 1. Context Offloading (컨텍스트 오프로딩) +대용량 도구 결과를 파일시스템으로 축출합니다. +- 도구 결과가 20,000 토큰 초과 시 자동 축출 +- /large_tool_results/ 경로에 저장 +- read_file로 필요할 때 로드 + +### 2. Context Reduction (컨텍스트 축소) +컨텍스트 윈도우 사용량이 임계값 초과 시 압축합니다. 
+- Compaction: 오래된 도구 호출/결과 제거 +- Summarization: LLM이 대화 요약 (85% 초과 시) + +### 3. Context Retrieval (컨텍스트 검색) +필요한 정보만 선택적으로 로드합니다. +- grep: 텍스트 패턴 검색 +- glob: 파일명 패턴 매칭 +- read_file: 부분 읽기 (offset/limit) + +### 4. Context Isolation (컨텍스트 격리) +SubAgent를 통해 독립된 컨텍스트에서 작업합니다. +- task() 도구로 작업 위임 +- 메인 컨텍스트 오염 방지 +- 복잡한 작업의 격리 처리 + +### 5. Context Caching (컨텍스트 캐싱) +시스템 프롬프트와 반복 컨텍스트를 캐싱합니다. +- Anthropic Prompt Caching 활용 +- API 비용 절감 +- 응답 속도 향상 + +## 연구 워크플로우 + +1. 연구 요청을 분석하고 TODO 목록 작성 +2. 필요시 SubAgent에게 작업 위임 +3. 중간 결과를 파일시스템에 저장 +4. 최종 보고서 작성 (/final_report.md) + +## 중요 원칙 + +- 대용량 결과는 파일로 저장하고 참조 +- 복잡한 작업은 SubAgent에게 위임 +- 진행 상황을 TODO로 추적 +- 인용 형식: [1], [2], [3] +""" + +current_date = datetime.now().strftime("%Y-%m-%d") + +BASE_DIR = Path(__file__).resolve().parent.parent +RESEARCH_WORKSPACE_DIR = BASE_DIR / "research_workspace" + +_cached_agent = None +_cached_model = None + + +def _infer_openrouter_model_name(model: BaseChatModel) -> str | None: + """OpenRouter 모델에서 모델명을 추출합니다. + + Args: + model: LangChain 모델 인스턴스 + + Returns: + OpenRouter 모델명 (예: "anthropic/claude-3-sonnet") 또는 None + """ + if detect_provider(model) != ProviderType.OPENROUTER: + return None + for attr in ("model_name", "model", "model_id"): + name = getattr(model, attr, None) + if isinstance(name, str) and name.strip(): + return name + return None + + +def _get_fs_backend() -> FilesystemBackend: + return FilesystemBackend( + root_dir=RESEARCH_WORKSPACE_DIR, + virtual_mode=True, + max_file_size_mb=20, + ) + + +def _get_backend_factory(): + fs_backend = _get_fs_backend() + + def backend_factory(rt: ToolRuntime) -> CompositeBackend: + return CompositeBackend( + default=StateBackend(rt), + routes={"/": fs_backend}, + ) + + return backend_factory + + +def get_model(): + global _cached_model + if _cached_model is None: + _cached_model = ChatOpenAI(model="gpt-4.1", temperature=0.0) + return _cached_model + + +def get_agent(): + global _cached_agent + if _cached_agent is None: + model = get_model() 
+ backend_factory = _get_backend_factory() + + offloading_strategy = ContextOffloadingStrategy(backend_factory=backend_factory) + reduction_strategy = ContextReductionStrategy(summarization_model=model) + openrouter_model_name = _infer_openrouter_model_name(model) + caching_strategy = ContextCachingStrategy( + model=model, + openrouter_model_name=openrouter_model_name, + ) + telemetry_middleware = PromptCachingTelemetryMiddleware() + + _cached_agent = create_deep_agent( + model=model, + system_prompt=CONTEXT_ENGINEERING_SYSTEM_PROMPT, + backend=backend_factory, + middleware=[ + offloading_strategy, + reduction_strategy, + caching_strategy, + telemetry_middleware, + ], + ) + return _cached_agent + + +def create_context_aware_agent( + model: BaseChatModel | str = "gpt-4.1", + workspace_dir: Path | str | None = None, + enable_offloading: bool = True, + enable_reduction: bool = True, + enable_caching: bool = True, + enable_cache_telemetry: bool = True, + offloading_token_limit: int = 20000, + reduction_threshold: float = 0.85, + openrouter_model_name: str | None = None, +) -> Any: + """Context Engineering 전략이 적용된 에이전트를 생성합니다. + + Multi-Provider 지원: Anthropic, OpenAI, Gemini, OpenRouter 모델 사용 가능. + Provider는 자동 감지되며, Anthropic만 cache_control 마커가 적용됩니다. 
+ + Args: + model: LLM 모델 객체 또는 모델명 (기본: gpt-4.1) + workspace_dir: 작업 디렉토리 + enable_offloading: Context Offloading 활성화 + enable_reduction: Context Reduction 활성화 + enable_caching: Context Caching 활성화 + enable_cache_telemetry: Cache 텔레메트리 수집 활성화 + offloading_token_limit: Offloading 토큰 임계값 + reduction_threshold: Reduction 트리거 임계값 + openrouter_model_name: OpenRouter 모델명 강제 지정 + + Returns: + 구성된 DeepAgent + """ + from context_engineering_research_agent.context_strategies.offloading import ( + OffloadingConfig, + ) + from context_engineering_research_agent.context_strategies.reduction import ( + ReductionConfig, + ) + + if isinstance(model, str): + llm: BaseChatModel = ChatOpenAI(model=model, temperature=0.0) + else: + llm = model + + workspace = Path(workspace_dir) if workspace_dir else RESEARCH_WORKSPACE_DIR + workspace.mkdir(parents=True, exist_ok=True) + + local_fs_backend = FilesystemBackend( + root_dir=workspace, + virtual_mode=True, + max_file_size_mb=20, + ) + + def local_backend_factory(rt: ToolRuntime) -> CompositeBackend: + return CompositeBackend( + default=StateBackend(rt), + routes={"/": local_fs_backend}, + ) + + middlewares = [] + + if enable_offloading: + offload_config = OffloadingConfig( + token_limit_before_evict=offloading_token_limit + ) + middlewares.append( + ContextOffloadingStrategy( + config=offload_config, backend_factory=local_backend_factory + ) + ) + + if enable_reduction: + reduce_config = ReductionConfig(context_threshold=reduction_threshold) + middlewares.append( + ContextReductionStrategy(config=reduce_config, summarization_model=llm) + ) + + if enable_caching: + inferred_openrouter_model_name = ( + openrouter_model_name + if openrouter_model_name is not None + else _infer_openrouter_model_name(llm) + ) + middlewares.append( + ContextCachingStrategy( + model=llm, + openrouter_model_name=inferred_openrouter_model_name, + ) + ) + + if enable_cache_telemetry: + middlewares.append(PromptCachingTelemetryMiddleware()) + + return 
create_deep_agent( + model=llm, + system_prompt=CONTEXT_ENGINEERING_SYSTEM_PROMPT, + backend=local_backend_factory, + middleware=middlewares, + ) diff --git a/context_engineering_research_agent/backends/__init__.py b/context_engineering_research_agent/backends/__init__.py new file mode 100644 index 0000000..1e76e34 --- /dev/null +++ b/context_engineering_research_agent/backends/__init__.py @@ -0,0 +1,24 @@ +"""백엔드 모듈. + +안전한 코드 실행을 위한 백엔드 구현체들입니다. +""" + +from context_engineering_research_agent.backends.docker_sandbox import ( + DockerSandboxBackend, +) +from context_engineering_research_agent.backends.docker_session import ( + DockerSandboxSession, +) +from context_engineering_research_agent.backends.docker_shared import ( + SharedDockerBackend, +) +from context_engineering_research_agent.backends.pyodide_sandbox import ( + PyodideSandboxBackend, +) + +__all__ = [ + "PyodideSandboxBackend", + "SharedDockerBackend", + "DockerSandboxBackend", + "DockerSandboxSession", +] diff --git a/context_engineering_research_agent/backends/docker_sandbox.py b/context_engineering_research_agent/backends/docker_sandbox.py new file mode 100644 index 0000000..dd68ec0 --- /dev/null +++ b/context_engineering_research_agent/backends/docker_sandbox.py @@ -0,0 +1,211 @@ +"""Docker 샌드박스 백엔드 구현. + +BaseSandbox를 상속하여 Docker 컨테이너 내에서 파일 작업 및 코드 실행을 수행합니다. +""" + +from __future__ import annotations + +import asyncio +import io +import posixpath +import tarfile +import time +from typing import Any + +from deepagents.backends.protocol import ( + ExecuteResponse, + FileDownloadResponse, + FileOperationError, + FileUploadResponse, +) +from deepagents.backends.sandbox import BaseSandbox + +from context_engineering_research_agent.backends.workspace_protocol import ( + WORKSPACE_ROOT, +) + + +class DockerSandboxBackend(BaseSandbox): + """Docker 컨테이너 기반 샌드박스 백엔드. + + 실행과 파일 작업을 모두 컨테이너 내부에서 수행하여 + SubAgent 간 공유 작업공간을 제공합니다. 
+ """ + + def __init__( + self, + container_id: str, + *, + workspace_root: str = WORKSPACE_ROOT, + docker_client: Any | None = None, + ) -> None: + self._container_id = container_id + self._workspace_root = workspace_root + self._docker_client = docker_client + + @property + def id(self) -> str: + """컨테이너 식별자를 반환합니다.""" + return self._container_id + + def _get_docker_client(self) -> Any: + if self._docker_client is None: + try: + import docker + except ImportError as exc: + raise RuntimeError( + "docker 패키지가 설치되지 않았습니다: pip install docker" + ) from exc + try: + self._docker_client = docker.from_env() + except Exception as exc: + docker_exception = getattr( + getattr(docker, "errors", None), "DockerException", None + ) + if docker_exception and isinstance(exc, docker_exception): + raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc + raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc + return self._docker_client + + def _get_container(self) -> Any: + client = self._get_docker_client() + try: + return client.containers.get(self._container_id) + except Exception as exc: + raise RuntimeError(f"컨테이너 조회 실패: {exc}") from exc + + def _resolve_path(self, path: str) -> str: + if path.startswith("/"): + return path + return posixpath.join(self._workspace_root, path) + + def _ensure_parent_dir(self, parent_dir: str) -> None: + if not parent_dir: + return + result = self.execute(f"mkdir -p {parent_dir}") + if result.exit_code not in (0, None): + raise RuntimeError(result.output or "워크스페이스 디렉토리 생성 실패") + + def _truncate_output(self, output: str) -> tuple[str, bool]: + if len(output) <= 100000: + return output, False + return output[:100000] + "\n[출력이 잘렸습니다...]", True + + def execute(self, command: str) -> ExecuteResponse: + """컨테이너 내부에서 명령을 실행합니다. + + shell을 통해 실행하므로 리다이렉션(>), 파이프(|), &&, || 등을 사용할 수 있습니다. 
+ """ + try: + container = self._get_container() + exec_result = container.exec_run( + ["sh", "-c", command], + workdir=self._workspace_root, + ) + raw_output = exec_result.output + if isinstance(raw_output, bytes): + output = raw_output.decode("utf-8", errors="replace") + else: + output = str(raw_output) + output, truncated = self._truncate_output(output) + return ExecuteResponse( + output=output, + exit_code=exec_result.exit_code, + truncated=truncated, + ) + except Exception as exc: + return ExecuteResponse(output=f"Docker 실행 오류: {exc}", exit_code=1) + + async def aexecute(self, command: str) -> ExecuteResponse: + """비동기 실행 래퍼.""" + return await asyncio.to_thread(self.execute, command) + + def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]: + """파일을 컨테이너로 업로드합니다.""" + responses: list[FileUploadResponse] = [] + for path, content in files: + try: + full_path = self._resolve_path(path) + parent_dir = posixpath.dirname(full_path) + file_name = posixpath.basename(full_path) + if not file_name: + raise ValueError("업로드 경로에 파일명이 필요합니다") + self._ensure_parent_dir(parent_dir) + + tar_stream = io.BytesIO() + with tarfile.open(fileobj=tar_stream, mode="w") as tar: + info = tarfile.TarInfo(name=file_name) + info.size = len(content) + info.mtime = time.time() + tar.addfile(info, io.BytesIO(content)) + tar_stream.seek(0) + + container = self._get_container() + container.put_archive(parent_dir or "/", tar_stream.getvalue()) + responses.append(FileUploadResponse(path=path)) + except Exception as exc: + responses.append( + FileUploadResponse(path=path, error=self._map_upload_error(exc)) + ) + return responses + + async def aupload_files( + self, files: list[tuple[str, bytes]] + ) -> list[FileUploadResponse]: + """비동기 업로드 래퍼.""" + return await asyncio.to_thread(self.upload_files, files) + + def download_files(self, paths: list[str]) -> list[FileDownloadResponse]: + """컨테이너에서 파일을 다운로드합니다.""" + responses: list[FileDownloadResponse] = [] + for path in 
paths: + try: + full_path = self._resolve_path(path) + container = self._get_container() + stream, _ = container.get_archive(full_path) + raw = b"".join(chunk for chunk in stream) + content = self._extract_tar_content(raw) + if content is None: + responses.append( + FileDownloadResponse(path=path, error="is_directory") + ) + continue + responses.append(FileDownloadResponse(path=path, content=content)) + except Exception as exc: + responses.append( + FileDownloadResponse(path=path, error=self._map_download_error(exc)) + ) + return responses + + async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]: + """비동기 다운로드 래퍼.""" + return await asyncio.to_thread(self.download_files, paths) + + def _extract_tar_content(self, raw: bytes) -> bytes | None: + with tarfile.open(fileobj=io.BytesIO(raw)) as tar: + members = [member for member in tar.getmembers() if member.isfile()] + if not members: + return None + target = members[0] + file_obj = tar.extractfile(target) + if file_obj is None: + return None + return file_obj.read() + + def _map_upload_error(self, exc: Exception) -> FileOperationError: + message = str(exc).lower() + if "permission" in message: + return "permission_denied" + if "is a directory" in message: + return "is_directory" + return "invalid_path" + + def _map_download_error(self, exc: Exception) -> FileOperationError: + message = str(exc).lower() + if "permission" in message: + return "permission_denied" + if "no such file" in message or "not found" in message: + return "file_not_found" + if "is a directory" in message: + return "is_directory" + return "invalid_path" diff --git a/context_engineering_research_agent/backends/docker_session.py b/context_engineering_research_agent/backends/docker_session.py new file mode 100644 index 0000000..fe0ee82 --- /dev/null +++ b/context_engineering_research_agent/backends/docker_session.py @@ -0,0 +1,118 @@ +"""Docker 샌드박스 세션 관리 모듈. + +요청 단위로 Docker 컨테이너를 생성하고, 모든 subagent가 동일한 /workspace를 공유합니다. 
+""" + +from __future__ import annotations + +import asyncio +from typing import Any + +from context_engineering_research_agent.backends.docker_sandbox import ( + DockerSandboxBackend, +) +from context_engineering_research_agent.backends.workspace_protocol import ( + META_DIR, + SHARED_DIR, + WORKSPACE_ROOT, +) + + +class DockerSandboxSession: + """Docker 컨테이너 라이프사이클을 관리하는 세션. + + 컨테이너는 요청 단위로 생성되며, SubAgent는 동일한 컨테이너를 공유합니다. + """ + + def __init__( + self, + image: str = "python:3.11-slim", + workspace_root: str = WORKSPACE_ROOT, + ) -> None: + self.image = image + self.workspace_root = workspace_root + self._docker_client: Any | None = None + self._container: Any | None = None + self._backend: DockerSandboxBackend | None = None + + def _get_docker_client(self) -> Any: + if self._docker_client is None: + try: + import docker + except ImportError as exc: + raise RuntimeError( + "docker 패키지가 설치되지 않았습니다: pip install docker" + ) from exc + try: + self._docker_client = docker.from_env() + except Exception as exc: + docker_exception = getattr( + getattr(docker, "errors", None), "DockerException", None + ) + if docker_exception and isinstance(exc, docker_exception): + raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc + raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc + return self._docker_client + + async def start(self) -> None: + """보안 옵션이 적용된 컨테이너를 생성합니다.""" + if self._container is not None: + return + + client = self._get_docker_client() + try: + self._container = await asyncio.to_thread( + client.containers.run, + self.image, + command="tail -f /dev/null", + detach=True, + network_mode="none", + cap_drop=["ALL"], + security_opt=["no-new-privileges=true"], + mem_limit="512m", + pids_limit=128, + working_dir=self.workspace_root, + ) + await asyncio.to_thread( + self._container.exec_run, + f"mkdir -p {self.workspace_root}/{META_DIR} {self.workspace_root}/{SHARED_DIR}", + ) + except Exception as exc: + raise RuntimeError(f"Docker 컨테이너 생성 실패: {exc}") 
from exc + + async def stop(self) -> None: + """컨테이너를 중지하고 제거합니다.""" + container = self._container + self._container = None + self._backend = None + if container is None: + return + + try: + await asyncio.to_thread(container.stop) + except Exception: + pass + + try: + await asyncio.to_thread(container.remove) + except Exception: + pass + + def get_backend(self) -> DockerSandboxBackend: + """현재 세션용 백엔드를 반환합니다.""" + if self._container is None: + raise RuntimeError("DockerSandboxSession이 시작되지 않았습니다") + if self._backend is None: + self._backend = DockerSandboxBackend( + container_id=self._container.id, + workspace_root=self.workspace_root, + docker_client=self._get_docker_client(), + ) + return self._backend + + async def __aenter__(self) -> DockerSandboxSession: + await self.start() + return self + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + await self.stop() diff --git a/context_engineering_research_agent/backends/docker_shared.py b/context_engineering_research_agent/backends/docker_shared.py new file mode 100644 index 0000000..e8a0524 --- /dev/null +++ b/context_engineering_research_agent/backends/docker_shared.py @@ -0,0 +1,264 @@ +"""Docker 공유 작업공간 백엔드. + +여러 SubAgent가 동일한 Docker 컨테이너 작업공간을 공유하는 백엔드입니다. 
+ +## 설계 배경 + +DeepAgents의 기본 구조에서 SubAgent들은 독립된 컨텍스트를 가지지만, +파일시스템을 공유해야 하는 경우가 있습니다: +- 연구 결과물 공유 +- 중간 생성물 활용 +- 협업 워크플로우 + +## 아키텍처 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Main Agent (Orchestrator) │ +│ │ +│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ +│ │SubAgent A │ │SubAgent B │ │SubAgent C │ │ +│ │(Research) │ │(Analysis) │ │(Synthesis)│ │ +│ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ │ +│ │ │ │ │ +│ └──────────────┼──────────────┘ │ +│ ▼ │ +│ SharedDockerBackend │ +│ │ │ +└───────────────────────┼─────────────────────────────────────┘ + │ + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Docker Container │ +│ ┌─────────────────────────────────────────────────────────│ +│ │ /workspace (공유 작업 디렉토리) │ +│ │ ├── research/ (SubAgent A 출력) │ +│ │ ├── analysis/ (SubAgent B 출력) │ +│ │ └── synthesis/ (SubAgent C 출력) │ +│ └─────────────────────────────────────────────────────────│ +└─────────────────────────────────────────────────────────────┘ +``` + +## 보안 고려사항 + +1. **컨테이너 격리**: 호스트 시스템과 격리 +2. **볼륨 마운트**: 필요한 디렉토리만 마운트 +3. **네트워크 정책**: 필요시 네트워크 격리 +4. **리소스 제한**: CPU/메모리 제한 설정 +""" + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class DockerConfig: + image: str = "python:3.11-slim" + workspace_path: str = "/workspace" + memory_limit: str = "2g" + cpu_limit: float = 2.0 + network_mode: str = "none" + auto_remove: bool = True + timeout_seconds: int = 300 + + +@dataclass +class ExecuteResponse: + output: str + exit_code: int | None = None + truncated: bool = False + error: str | None = None + + +@dataclass +class WriteResult: + path: str + error: str | None = None + files_update: dict[str, Any] | None = None + + +@dataclass +class EditResult: + path: str + occurrences: int = 0 + error: str | None = None + files_update: dict[str, Any] | None = None + + +class SharedDockerBackend: + """여러 SubAgent가 공유하는 Docker 작업공간 백엔드. 
+ + Args: + config: Docker 설정 + container_id: 기존 컨테이너 ID (재사용 시) + """ + + def __init__( + self, + config: DockerConfig | None = None, + container_id: str | None = None, + ) -> None: + self.config = config or DockerConfig() + self._container_id = container_id + self._docker_client: Any = None + + def _get_docker_client(self) -> Any: + if self._docker_client is None: + try: + import docker + + self._docker_client = docker.from_env() + except ImportError: + raise RuntimeError( + "docker 패키지가 설치되지 않았습니다: pip install docker" + ) + return self._docker_client + + def _ensure_container(self) -> str: + if self._container_id: + return self._container_id + + client = self._get_docker_client() + container = client.containers.run( + self.config.image, + command="tail -f /dev/null", + detach=True, + mem_limit=self.config.memory_limit, + nano_cpus=int(self.config.cpu_limit * 1e9), + network_mode=self.config.network_mode, + auto_remove=self.config.auto_remove, + ) + self._container_id = container.id + return container.id + + def execute(self, command: str) -> ExecuteResponse: + try: + container_id = self._ensure_container() + client = self._get_docker_client() + container = client.containers.get(container_id) + + exec_result = container.exec_run( + command, + workdir=self.config.workspace_path, + ) + + output = exec_result.output.decode("utf-8", errors="replace") + truncated = len(output) > 100000 + if truncated: + output = output[:100000] + "\n[출력이 잘렸습니다...]" + + return ExecuteResponse( + output=output, + exit_code=exec_result.exit_code, + truncated=truncated, + ) + except Exception as e: + return ExecuteResponse( + output="", + exit_code=1, + error=str(e), + ) + + async def aexecute(self, command: str) -> ExecuteResponse: + return self.execute(command) + + def read(self, path: str, offset: int = 0, limit: int = 500) -> str: + full_path = f"{self.config.workspace_path}{path}" + result = self.execute(f"sed -n '{offset + 1},{offset + limit}p' {full_path}") + + if result.error: + 
return f"파일 읽기 오류: {result.error}" + + return result.output + + async def aread(self, path: str, offset: int = 0, limit: int = 500) -> str: + return self.read(path, offset, limit) + + def write(self, path: str, content: str) -> WriteResult: + full_path = f"{self.config.workspace_path}{path}" + + dir_path = "/".join(full_path.split("/")[:-1]) + self.execute(f"mkdir -p {dir_path}") + + escaped_content = content.replace("'", "'\\''") + result = self.execute(f"echo '{escaped_content}' > {full_path}") + + if result.exit_code != 0: + return WriteResult(path=path, error=result.output or result.error) + + return WriteResult(path=path) + + async def awrite(self, path: str, content: str) -> WriteResult: + return self.write(path, content) + + def ls_info(self, path: str) -> list[dict[str, Any]]: + full_path = f"{self.config.workspace_path}{path}" + result = self.execute(f"ls -la {full_path}") + + if result.error: + return [] + + files = [] + for line in result.output.strip().split("\n")[1:]: + parts = line.split() + if len(parts) >= 9: + name = " ".join(parts[8:]) + files.append( + { + "path": f"{path}/{name}".replace("//", "/"), + "is_dir": line.startswith("d"), + } + ) + + return files + + async def als_info(self, path: str) -> list[dict[str, Any]]: + return self.ls_info(path) + + def cleanup(self) -> None: + if self._container_id and self._docker_client: + try: + container = self._docker_client.containers.get(self._container_id) + container.stop() + except Exception: + pass + self._container_id = None + + def __enter__(self) -> "SharedDockerBackend": + return self + + def __exit__(self, *args: Any) -> None: + self.cleanup() + + +SHARED_WORKSPACE_DESIGN_DOC = """ +# 공유 작업공간 설계 + +## 문제점 + +DeepAgents의 SubAgent들은 독립된 컨텍스트를 가지지만, +연구 워크플로우에서는 파일시스템 공유가 필요한 경우가 많습니다. + +예시: +1. Research Agent가 수집한 데이터를 Analysis Agent가 처리 +2. 여러 Agent가 생성한 결과물을 Synthesis Agent가 통합 +3. 중간 체크포인트를 다른 Agent가 이어서 작업 + +## 해결책: SharedDockerBackend + +1. 
**단일 컨테이너**: 모든 SubAgent가 동일한 Docker 컨테이너 사용 +2. **격리된 디렉토리**: 각 SubAgent는 자신의 디렉토리에서 작업 +3. **공유 영역**: /workspace/shared 같은 공유 디렉토리 운영 + +## 장점 + +- 파일 복사 오버헤드 없음 +- 실시간 결과 공유 +- 일관된 실행 환경 + +## 단점 + +- 컨테이너 수명 관리 필요 +- 동시 접근 충돌 가능성 +- 보안 경계가 SubAgent 간에는 없음 +""" diff --git a/context_engineering_research_agent/backends/pyodide_sandbox.py b/context_engineering_research_agent/backends/pyodide_sandbox.py new file mode 100644 index 0000000..7e66443 --- /dev/null +++ b/context_engineering_research_agent/backends/pyodide_sandbox.py @@ -0,0 +1,189 @@ +"""Pyodide 기반 WASM 샌드박스 백엔드. + +WebAssembly 환경에서 Python 코드를 안전하게 실행하기 위한 백엔드입니다. + +## Pyodide란? + +Pyodide는 CPython을 WebAssembly로 컴파일한 프로젝트입니다. +브라우저나 Node.js 환경에서 Python 코드를 실행할 수 있습니다. + +## 보안 모델 + +WASM 샌드박스는 다음과 같은 격리를 제공합니다: +- 호스트 파일시스템 접근 불가 +- 네트워크 접근 제한 (JavaScript API 통해서만) +- 메모리 격리 + +## 한계 + +1. 네이티브 C 확장 라이브러리 제한적 지원 +2. 성능 오버헤드 (네이티브 대비 ~3-10x 느림) +3. 초기 로딩 시간 (Pyodide 런타임 + 패키지) + +## 권장 사용 사례 + +- 신뢰할 수 없는 사용자 코드 실행 +- 브라우저 기반 Python 환경 +- 격리된 데이터 분석 작업 + +## 사용 예시 (JavaScript 환경) + +```javascript +const pyodide = await loadPyodide(); +await pyodide.loadPackagesFromImports(pythonCode); +const result = await pyodide.runPythonAsync(pythonCode); +``` +""" + +from dataclasses import dataclass, field +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class PyodideRuntime(Protocol): + def runPython(self, code: str) -> str: ... + async def runPythonAsync(self, code: str) -> str: ... + def loadPackagesFromImports(self, code: str) -> None: ... 
@dataclass
class PyodideConfig:
    """Pyodide 샌드박스 실행 제한 설정."""

    timeout_seconds: int = 30
    max_memory_mb: int = 512
    # Packages the sandbox is allowed to load into the WASM runtime.
    allowed_packages: list[str] = field(
        default_factory=lambda: [
            "numpy",
            "pandas",
            "scipy",
            "matplotlib",
            "scikit-learn",
        ]
    )
    enable_network: bool = False


@dataclass
class ExecuteResponse:
    """샌드박스 실행 결과."""

    output: str
    exit_code: int | None = None
    truncated: bool = False
    error: str | None = None


class PyodideSandboxBackend:
    """Pyodide WASM 샌드박스 백엔드.

    Note: 이 클래스는 설계 문서입니다.
    실제 구현은 JavaScript/TypeScript 환경에서 이루어집니다.

    Python에서 직접 Pyodide를 실행하려면 별도의 subprocess나
    JS 런타임(Node.js) 연동이 필요합니다.
    """

    def __init__(self, config: PyodideConfig | None = None) -> None:
        self.config = config or PyodideConfig()
        self._runtime: PyodideRuntime | None = None

    def execute(self, code: str) -> ExecuteResponse:
        """Python 코드를 WASM 샌드박스에서 실행합니다.

        실제 구현에서는 Node.js subprocess나
        WebWorker를 통해 Pyodide를 실행합니다.

        Returns:
            항상 exit_code=1, error 설정 (Python 단독 실행 미지원).
        """
        return ExecuteResponse(
            output="Pyodide 실행은 JavaScript 환경에서만 지원됩니다.",
            exit_code=1,
            error="NotImplemented: Python에서 직접 Pyodide 실행 불가",
        )

    async def aexecute(self, code: str) -> ExecuteResponse:
        return self.execute(code)

    def get_pyodide_js_code(self, python_code: str) -> str:
        """주어진 Python 코드를 실행하는 JavaScript 코드를 생성합니다.

        이 JavaScript 코드를 브라우저나 Node.js에서 실행하면
        Pyodide 환경에서 Python 코드가 실행됩니다.
        """
        # BUGFIX: escape backslashes FIRST, then the template-literal
        # metacharacters. The previous version escaped only ` and $, leaving
        # raw backslashes inside the JS template literal, where sequences
        # like "\u" are parsed as escapes and can be a JS SyntaxError.
        escaped_code = (
            python_code.replace("\\", "\\\\")
            .replace("`", "\\`")
            .replace("$", "\\$")
        )

        return f"""
import {{ loadPyodide }} from "pyodide";

async function runPythonInSandbox() {{
    const pyodide = await loadPyodide();

    const pythonCode = `{escaped_code}`;

    await pyodide.loadPackagesFromImports(pythonCode);

    try {{
        const result = await pyodide.runPythonAsync(pythonCode);
        return {{ success: true, result }};
    }} catch (error) {{
        return {{ success: false, error: error.message }};
    }}
}}

runPythonInSandbox().then(console.log);
"""


PYODIDE_DESIGN_DOC = """
# Pyodide 기반 안전한 코드 실행 설계

## 아키텍처

Python Agent (LangGraph/DeepAgents)
  → execute() 호출
PyodideSandboxBackend
  - 코드 전처리 / 보안 검증 / JS 코드 생성
  → subprocess 또는 IPC
Node.js Runtime → WebWorker → Pyodide (WASM)
  - Python 인터프리터 / 제한된 패키지 / 격리된 메모리

## Docker vs Pyodide 비교

| 측면 | Docker | Pyodide (WASM) |
|------|--------|----------------|
| 격리 수준 | 컨테이너 (OS 레벨) | WASM 샌드박스 |
| 시작 시간 | ~1-2초 | ~3-5초 (최초), 이후 빠름 |
| 메모리 | 높음 | 낮음 |
| 패키지 지원 | 완전 | 제한적 |
| 보안 | 높음 | 매우 높음 |
| 호스트 접근 | 마운트 통해 가능 | 불가 |

## 권장 사용 시나리오

### Docker 사용
- 복잡한 라이브러리 필요
- 파일 I/O 필요
- 장시간 실행 작업

### Pyodide 사용
- 간단한 계산/분석
- 신뢰할 수 없는 코드
- 브라우저 환경
- 빠른 피드백 필요
"""
--git a/context_engineering_research_agent/backends/workspace_protocol.py b/context_engineering_research_agent/backends/workspace_protocol.py new file mode 100644 index 0000000..5afdb65 --- /dev/null +++ b/context_engineering_research_agent/backends/workspace_protocol.py @@ -0,0 +1,25 @@ +"""워크스페이스 경로 및 파일 통신 프로토콜 유틸리티.""" + +from __future__ import annotations + +import posixpath + +WORKSPACE_ROOT = "/workspace" +META_DIR = "_meta" +SHARED_DIR = "shared" + + +def _sanitize_segment(segment: str) -> str: + return segment.strip().strip("/") + + +def get_subagent_dir(subagent_type: str) -> str: + """SubAgent별 전용 작업 디렉토리를 반환합니다.""" + safe_segment = _sanitize_segment(subagent_type) + return posixpath.join(WORKSPACE_ROOT, safe_segment) + + +def get_result_path(subagent_type: str) -> str: + """SubAgent 결과 파일 경로를 반환합니다.""" + safe_segment = _sanitize_segment(subagent_type) + return posixpath.join(WORKSPACE_ROOT, META_DIR, safe_segment, "result.json") diff --git a/context_engineering_research_agent/context_strategies/__init__.py b/context_engineering_research_agent/context_strategies/__init__.py new file mode 100644 index 0000000..1687df8 --- /dev/null +++ b/context_engineering_research_agent/context_strategies/__init__.py @@ -0,0 +1,47 @@ +"""Context Engineering 전략 모듈. + +DeepAgents에서 구현된 5가지 Context Engineering 핵심 전략을 +명시적으로 분리하고 문서화한 모듈입니다. + +각 전략은 독립적으로 사용하거나 조합하여 사용할 수 있습니다. 
"""

from context_engineering_research_agent.context_strategies.caching import (
    ContextCachingStrategy,
    OpenRouterSubProvider,
    ProviderType,
    detect_openrouter_sub_provider,
    detect_provider,
    requires_cache_control_marker,
)
from context_engineering_research_agent.context_strategies.caching_telemetry import (
    CacheTelemetry,
    PromptCachingTelemetryMiddleware,
)
from context_engineering_research_agent.context_strategies.isolation import (
    ContextIsolationStrategy,
)
from context_engineering_research_agent.context_strategies.offloading import (
    ContextOffloadingStrategy,
)
from context_engineering_research_agent.context_strategies.reduction import (
    ContextReductionStrategy,
)
from context_engineering_research_agent.context_strategies.retrieval import (
    ContextRetrievalStrategy,
)

# Public API of the context_strategies package (star-import surface).
__all__ = [
    "ContextOffloadingStrategy",
    "ContextReductionStrategy",
    "ContextRetrievalStrategy",
    "ContextIsolationStrategy",
    "ContextCachingStrategy",
    "ProviderType",
    "OpenRouterSubProvider",
    "detect_provider",
    "detect_openrouter_sub_provider",
    "requires_cache_control_marker",
    "CacheTelemetry",
    "PromptCachingTelemetryMiddleware",
]

"""Context Caching 전략 구현.

Multi-Provider Prompt Caching 전략을 구현합니다.
각 Provider별 특성에 맞게 캐싱을 적용합니다:

## Direct Provider Access
- Anthropic: Explicit caching (cache_control 마커 필요), Write 1.25x, Read 0.1x
- OpenAI: Automatic caching (자동), Read 0.5x, 1024+ tokens
- Gemini 2.5: Implicit caching (자동), Read 0.25x, 1028-2048+ tokens
- Gemini 3: Implicit caching (자동), Read 0.1x (90% 할인), 1024-4096+ tokens

## OpenRouter (Multi-Model Gateway)
OpenRouter는 기반 모델에 따라 다른 caching 방식 적용:
- Anthropic Claude: Explicit (cache_control 필요), Write 1.25x, Read 0.1x
- OpenAI: Automatic (자동), Read 0.5x
- Google Gemini: Implicit + Explicit 지원, Read 0.25x
- DeepSeek: Automatic (자동), Write 1x, Read 0.1x
- Groq (Kimi K2): Automatic (자동), Read 0.25x
- Grok (xAI): Automatic (자동), Read 0.25x

DeepAgents의 AnthropicPromptCachingMiddleware와 함께 사용 권장.
"""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from enum import Enum
from typing import Any

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelRequest,
    ModelResponse,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import BaseMessage, SystemMessage


class ProviderType(Enum):
    """LLM Provider 유형."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GEMINI = "gemini"
    GEMINI_3 = "gemini_3"
    OPENROUTER = "openrouter"
    DEEPSEEK = "deepseek"
    GROQ = "groq"
    GROK = "grok"
    UNKNOWN = "unknown"


class OpenRouterSubProvider(Enum):
    """OpenRouter를 통해 접근하는 기반 모델 Provider."""

    ANTHROPIC = "anthropic"
    OPENAI = "openai"
    GEMINI = "gemini"
    DEEPSEEK = "deepseek"
    GROQ = "groq"
    GROK = "grok"
    META_LLAMA = "meta-llama"
    MISTRAL = "mistral"
    UNKNOWN = "unknown"


# Providers that only cache when an explicit cache_control marker is attached.
PROVIDERS_REQUIRING_CACHE_CONTROL = {
    ProviderType.ANTHROPIC,
    OpenRouterSubProvider.ANTHROPIC,
}

# Providers that cache automatically (no request mutation required).
PROVIDERS_WITH_AUTOMATIC_CACHING = {
    ProviderType.OPENAI,
    ProviderType.GEMINI,
    ProviderType.GEMINI_3,
    ProviderType.DEEPSEEK,
    ProviderType.GROQ,
    ProviderType.GROK,
    OpenRouterSubProvider.OPENAI,
OpenRouterSubProvider.DEEPSEEK, + OpenRouterSubProvider.GROQ, + OpenRouterSubProvider.GROK, +} + + +def detect_provider(model: BaseChatModel | None) -> ProviderType: + """모델 객체에서 Provider 유형을 감지합니다.""" + if model is None: + return ProviderType.UNKNOWN + + class_name = model.__class__.__name__.lower() + module_name = model.__class__.__module__.lower() + + if "anthropic" in class_name or "anthropic" in module_name: + return ProviderType.ANTHROPIC + + if "openai" in class_name or "openai" in module_name: + base_url = _get_base_url(model) + if "openrouter" in base_url: + return ProviderType.OPENROUTER + return ProviderType.OPENAI + + if "google" in class_name or "gemini" in class_name or "google" in module_name: + model_name = _get_model_name(model) + if "gemini-3" in model_name or "gemini/3" in model_name: + return ProviderType.GEMINI_3 + return ProviderType.GEMINI + + if "deepseek" in class_name or "deepseek" in module_name: + return ProviderType.DEEPSEEK + + if "groq" in class_name or "groq" in module_name: + return ProviderType.GROQ + + return ProviderType.UNKNOWN + + +def _get_base_url(model: BaseChatModel) -> str: + for attr in ("openai_api_base", "base_url", "api_base"): + if hasattr(model, attr): + url = getattr(model, attr, "") or "" + if url: + return url.lower() + return "" + + +def _get_model_name(model: BaseChatModel) -> str: + for attr in ("model_name", "model", "model_id"): + if hasattr(model, attr): + name = getattr(model, attr, "") or "" + if name: + return name.lower() + return "" + + +def detect_openrouter_sub_provider(model_name: str) -> OpenRouterSubProvider: + """OpenRouter 모델명에서 기반 Provider를 감지합니다. 
+ + OpenRouter 모델명 패턴: "provider/model-name" (예: "anthropic/claude-3-sonnet") + """ + name_lower = model_name.lower() + + if "anthropic" in name_lower or "claude" in name_lower: + return OpenRouterSubProvider.ANTHROPIC + if "openai" in name_lower or "gpt" in name_lower or name_lower.startswith("o1"): + return OpenRouterSubProvider.OPENAI + if "google" in name_lower or "gemini" in name_lower: + return OpenRouterSubProvider.GEMINI + if "deepseek" in name_lower: + return OpenRouterSubProvider.DEEPSEEK + if "groq" in name_lower or "kimi" in name_lower: + return OpenRouterSubProvider.GROQ + if "grok" in name_lower or "xai" in name_lower: + return OpenRouterSubProvider.GROK + if "meta" in name_lower or "llama" in name_lower: + return OpenRouterSubProvider.META_LLAMA + if "mistral" in name_lower: + return OpenRouterSubProvider.MISTRAL + + return OpenRouterSubProvider.UNKNOWN + + +def requires_cache_control_marker( + provider: ProviderType, + sub_provider: OpenRouterSubProvider | None = None, +) -> bool: + """해당 Provider가 cache_control 마커를 필요로 하는지 확인합니다. + + Anthropic (직접 또는 OpenRouter 경유) 만 True 반환. + """ + if provider == ProviderType.ANTHROPIC: + return True + if ( + provider == ProviderType.OPENROUTER + and sub_provider == OpenRouterSubProvider.ANTHROPIC + ): + return True + return False + + +@dataclass +class CachingConfig: + """Context Caching 설정. + + Attributes: + min_cacheable_tokens: 캐싱할 최소 토큰 수 (기본값: 1024) + cache_control_type: 캐시 제어 유형 (기본값: "ephemeral") + enable_for_system_prompt: 시스템 프롬프트 캐싱 활성화 (기본값: True) + enable_for_tools: 도구 정의 캐싱 활성화 (기본값: True) + """ + + min_cacheable_tokens: int = 1024 + cache_control_type: str = "ephemeral" + enable_for_system_prompt: bool = True + enable_for_tools: bool = True + + +@dataclass +class CachingResult: + """Context Caching 결과. 
+ + Attributes: + was_cached: 캐싱이 적용되었는지 여부 + cached_content_type: 캐싱된 컨텐츠 유형 (예: "system_prompt") + estimated_tokens_cached: 캐싱된 추정 토큰 수 + """ + + was_cached: bool + cached_content_type: str | None = None + estimated_tokens_cached: int = 0 + + +class ContextCachingStrategy(AgentMiddleware): + """Multi-Provider Prompt Caching 전략. + + Anthropic (직접 또는 OpenRouter 경유)만 cache_control 마커를 적용하고, + OpenAI/Gemini/DeepSeek/Groq/Grok은 자동 캐싱이므로 pass-through합니다. + """ + + def __init__( + self, + config: CachingConfig | None = None, + model: BaseChatModel | None = None, + openrouter_model_name: str | None = None, + ) -> None: + self.config = config or CachingConfig() + self._model = model + self._provider: ProviderType | None = None + self._sub_provider: OpenRouterSubProvider | None = None + self._openrouter_model_name = openrouter_model_name + + def set_model( + self, + model: BaseChatModel, + openrouter_model_name: str | None = None, + ) -> None: + """런타임에 모델을 설정합니다.""" + self._model = model + self._provider = None + self._sub_provider = None + if openrouter_model_name: + self._openrouter_model_name = openrouter_model_name + + @property + def provider(self) -> ProviderType: + if self._provider is None: + self._provider = detect_provider(self._model) + return self._provider + + @property + def sub_provider(self) -> OpenRouterSubProvider | None: + if self.provider != ProviderType.OPENROUTER: + return None + if self._sub_provider is None and self._openrouter_model_name: + self._sub_provider = detect_openrouter_sub_provider( + self._openrouter_model_name + ) + return self._sub_provider + + @property + def should_apply_cache_markers(self) -> bool: + return requires_cache_control_marker(self.provider, self.sub_provider) + + def _add_cache_control(self, content: Any) -> Any: + if isinstance(content, str): + return { + "type": "text", + "text": content, + "cache_control": {"type": self.config.cache_control_type}, + } + elif isinstance(content, dict) and content.get("type") == "text": 
+ return { + **content, + "cache_control": {"type": self.config.cache_control_type}, + } + elif isinstance(content, list): + if not content: + return content + result = list(content) + last_item = result[-1] + if isinstance(last_item, dict): + result[-1] = { + **last_item, + "cache_control": {"type": self.config.cache_control_type}, + } + return result + return content + + def _process_system_message(self, message: SystemMessage) -> SystemMessage: + cached_content = self._add_cache_control(message.content) + # Ensure cached_content is a list of dicts for SystemMessage compatibility + if isinstance(cached_content, str): + cached_content = [{"type": "text", "text": cached_content}] + elif isinstance(cached_content, dict): + cached_content = [cached_content] + # Type ignore is needed due to complex SystemMessage content type requirements + return SystemMessage(content=cached_content) # type: ignore[arg-type] + + def _estimate_tokens(self, content: Any) -> int: + if isinstance(content, str): + return len(content) // 4 + elif isinstance(content, list): + return sum(self._estimate_tokens(item) for item in content) + elif isinstance(content, dict): + return self._estimate_tokens(content.get("text", "")) + return 0 + + def _should_cache(self, content: Any) -> bool: + estimated_tokens = self._estimate_tokens(content) + return estimated_tokens >= self.config.min_cacheable_tokens + + def apply_caching( + self, + messages: list[BaseMessage], + model: BaseChatModel | None = None, + openrouter_model_name: str | None = None, + ) -> tuple[list[BaseMessage], CachingResult]: + """메시지 리스트에 캐싱을 적용합니다. + + Anthropic (직접 또는 OpenRouter 경유)만 cache_control 마커를 적용합니다. + 다른 provider는 자동 캐싱이므로 메시지를 변형하지 않습니다. 
+ + Args: + messages: 처리할 메시지 리스트 + model: LLM 모델 (Provider 감지용) + openrouter_model_name: OpenRouter 사용 시 모델명 (예: "anthropic/claude-3-sonnet") + + Returns: + 캐싱이 적용된 메시지 리스트와 캐싱 결과 튜플 + """ + if model is not None: + self.set_model(model, openrouter_model_name) + + if not messages: + return messages, CachingResult(was_cached=False) + + if not self.should_apply_cache_markers: + provider_info = self.provider.value + if self.provider == ProviderType.OPENROUTER and self.sub_provider: + provider_info = f"openrouter/{self.sub_provider.value}" + return messages, CachingResult( + was_cached=False, + cached_content_type=f"auto_cached_by_{provider_info}", + ) + + result_messages = list(messages) + cached = False + cached_type = None + tokens_cached = 0 + + for i, msg in enumerate(result_messages): + if isinstance(msg, SystemMessage): + if self.config.enable_for_system_prompt and self._should_cache( + msg.content + ): + result_messages[i] = self._process_system_message(msg) + cached = True + cached_type = "system_prompt" + tokens_cached = self._estimate_tokens(msg.content) + break + + return result_messages, CachingResult( + was_cached=cached, + cached_content_type=cached_type, + estimated_tokens_cached=tokens_cached, + ) + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + """동기 모델 호출을 래핑합니다 (기본 동작). + + Args: + request: 모델 요청 + handler: 다음 핸들러 함수 + + Returns: + 모델 응답 + """ + return handler(request) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + """비동기 모델 호출을 래핑합니다 (기본 동작). + + Args: + request: 모델 요청 + handler: 다음 핸들러 함수 + + Returns: + 모델 응답 + """ + return await handler(request) + + +CACHING_SYSTEM_PROMPT = """## Context Caching + +시스템 프롬프트와 도구 정의가 자동으로 캐싱됩니다. 
+ +이점: +- API 호출 비용 절감 +- 응답 속도 향상 +- 동일 세션 내 반복 호출 최적화 +""" diff --git a/context_engineering_research_agent/context_strategies/caching_telemetry.py b/context_engineering_research_agent/context_strategies/caching_telemetry.py new file mode 100644 index 0000000..e6b9ad4 --- /dev/null +++ b/context_engineering_research_agent/context_strategies/caching_telemetry.py @@ -0,0 +1,244 @@ +"""Prompt Caching Telemetry Middleware. + +모든 Provider의 cache 사용량을 모니터링하고 로깅합니다. +요청 변형 없이 응답의 cache 관련 메타데이터만 수집합니다. + +Provider별 캐시 메타데이터 위치: +- Anthropic: cache_read_input_tokens, cache_creation_input_tokens +- OpenAI: cached_tokens (usage.prompt_tokens_details) +- Gemini 2.5/3: cached_content_token_count +- DeepSeek: cache_hit_tokens, cache_miss_tokens +- OpenRouter: 기반 모델의 메타데이터 형식 따름 +""" + +import logging +from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field +from typing import Any + +from langchain.agents.middleware.types import ( + AgentMiddleware, + ModelRequest, + ModelResponse, +) + +from context_engineering_research_agent.context_strategies.caching import ( + ProviderType, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class CacheTelemetry: + """Provider별 캐시 사용량 데이터.""" + + provider: ProviderType + cache_read_tokens: int = 0 + cache_write_tokens: int = 0 + total_input_tokens: int = 0 + cache_hit_ratio: float = 0.0 + raw_metadata: dict[str, Any] = field(default_factory=dict) + + +def extract_anthropic_cache_metrics(response: ModelResponse) -> CacheTelemetry: + """Anthropic 응답에서 캐시 메트릭을 추출합니다.""" + usage = getattr(response, "usage_metadata", {}) or {} + response_meta = getattr(response, "response_metadata", {}) or {} + usage_from_meta = response_meta.get("usage", {}) + + cache_read = usage_from_meta.get("cache_read_input_tokens", 0) + cache_creation = usage_from_meta.get("cache_creation_input_tokens", 0) + input_tokens = usage.get("input_tokens", 0) or usage_from_meta.get( + "input_tokens", 0 + ) + + hit_ratio = cache_read / 
# ---- context_strategies/caching.py (continued) ----
# NOTE(review): the opening of extract_anthropic_cache_metrics lies before
# this chunk; only its tail is visible here, so it is not restated.


def extract_openai_cache_metrics(response: ModelResponse) -> CacheTelemetry:
    """Extract prompt-cache metrics from an OpenAI response.

    OpenAI reports cached prompt tokens under
    ``token_usage.prompt_tokens_details.cached_tokens``; there is no
    separate cache-write counter, so ``cache_write_tokens`` is always 0.
    """
    usage = getattr(response, "usage_metadata", {}) or {}
    response_meta = getattr(response, "response_metadata", {}) or {}

    token_usage = response_meta.get("token_usage", {})
    cached = token_usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
    prompt_total = usage.get("input_tokens", 0) or token_usage.get("prompt_tokens", 0)

    return CacheTelemetry(
        provider=ProviderType.OPENAI,
        cache_read_tokens=cached,
        cache_write_tokens=0,
        total_input_tokens=prompt_total,
        cache_hit_ratio=cached / prompt_total if prompt_total > 0 else 0.0,
        raw_metadata={"usage": usage, "token_usage": token_usage},
    )


def extract_gemini_cache_metrics(
    response: ModelResponse, provider: ProviderType = ProviderType.GEMINI
) -> CacheTelemetry:
    """Extract prompt-cache metrics from a Gemini 2.5/3 response.

    Gemini exposes ``cached_content_token_count`` in the response metadata;
    no cache-write counter exists, so ``cache_write_tokens`` is always 0.
    """
    usage = getattr(response, "usage_metadata", {}) or {}
    response_meta = getattr(response, "response_metadata", {}) or {}

    cached = response_meta.get("cached_content_token_count", 0)
    prompt_total = usage.get("input_tokens", 0) or response_meta.get(
        "prompt_token_count", 0
    )

    return CacheTelemetry(
        provider=provider,
        cache_read_tokens=cached,
        cache_write_tokens=0,
        total_input_tokens=prompt_total,
        cache_hit_ratio=cached / prompt_total if prompt_total > 0 else 0.0,
        raw_metadata={"usage": usage, "response_metadata": response_meta},
    )


def extract_deepseek_cache_metrics(response: ModelResponse) -> CacheTelemetry:
    """Extract prompt-cache metrics from a DeepSeek response.

    DeepSeek splits the prompt into ``cache_hit_tokens`` and
    ``cache_miss_tokens``; the miss count is recorded as cache writes.
    """
    usage = getattr(response, "usage_metadata", {}) or {}
    response_meta = getattr(response, "response_metadata", {}) or {}

    hit_tokens = response_meta.get("cache_hit_tokens", 0)
    miss_tokens = response_meta.get("cache_miss_tokens", 0)
    prompt_total = usage.get("input_tokens", 0) or (hit_tokens + miss_tokens)

    return CacheTelemetry(
        provider=ProviderType.DEEPSEEK,
        cache_read_tokens=hit_tokens,
        cache_write_tokens=miss_tokens,
        total_input_tokens=prompt_total,
        cache_hit_ratio=hit_tokens / prompt_total if prompt_total > 0 else 0.0,
        raw_metadata={"usage": usage, "response_metadata": response_meta},
    )


def extract_cache_telemetry(
    response: ModelResponse, provider: ProviderType
) -> CacheTelemetry:
    """Dispatch to the provider-specific extractor.

    Unknown providers fall back to a telemetry record that only carries
    the raw metadata (all counters at their defaults).
    """
    extractors: dict[ProviderType, Callable[[ModelResponse], CacheTelemetry]] = {
        ProviderType.ANTHROPIC: extract_anthropic_cache_metrics,
        ProviderType.OPENAI: extract_openai_cache_metrics,
        ProviderType.GEMINI: extract_gemini_cache_metrics,
        ProviderType.GEMINI_3: lambda r: extract_gemini_cache_metrics(
            r, ProviderType.GEMINI_3
        ),
        ProviderType.DEEPSEEK: extract_deepseek_cache_metrics,
    }

    extractor = extractors.get(provider)
    if extractor is not None:
        return extractor(response)

    return CacheTelemetry(
        provider=provider,
        raw_metadata={
            "usage": getattr(response, "usage_metadata", {}),
            "response_metadata": getattr(response, "response_metadata", {}),
        },
    )


class PromptCachingTelemetryMiddleware(AgentMiddleware):
    """Middleware that observes prompt-cache usage for every provider.

    It never mutates requests; it only collects and logs the cache-related
    metadata found on model responses.
    """

    def __init__(self, log_level: int = logging.DEBUG) -> None:
        self._log_level = log_level
        self._telemetry_history: list[CacheTelemetry] = []

    @property
    def telemetry_history(self) -> list[CacheTelemetry]:
        """All telemetry records collected so far, oldest first."""
        return self._telemetry_history

    def get_aggregate_stats(self) -> dict[str, Any]:
        """Aggregate cache counters across every recorded model call."""
        history = self._telemetry_history
        if not history:
            return {"total_calls": 0, "total_cache_read_tokens": 0}

        read_sum = sum(t.cache_read_tokens for t in history)
        write_sum = sum(t.cache_write_tokens for t in history)
        input_sum = sum(t.total_input_tokens for t in history)

        return {
            "total_calls": len(history),
            "total_cache_read_tokens": read_sum,
            "total_cache_write_tokens": write_sum,
            "total_input_tokens": input_sum,
            "overall_cache_hit_ratio": read_sum / input_sum if input_sum else 0.0,
        }

    def _log_telemetry(self, telemetry: CacheTelemetry) -> None:
        # Only emit a log line when the response actually touched the cache.
        if telemetry.cache_read_tokens > 0 or telemetry.cache_write_tokens > 0:
            logger.log(
                self._log_level,
                f"[CacheTelemetry] {telemetry.provider.value}: "
                f"read={telemetry.cache_read_tokens}, "
                f"write={telemetry.cache_write_tokens}, "
                f"hit_ratio={telemetry.cache_hit_ratio:.2%}",
            )

    def _process_response(self, response: ModelResponse) -> ModelResponse:
        # NOTE(review): assumes the provider puts the model name under
        # response_metadata["model"] — confirm for each provider integration.
        model_name = getattr(response, "response_metadata", {}).get("model", "")
        provider = self._detect_provider_from_response(response, model_name)
        telemetry = extract_cache_telemetry(response, provider)
        self._telemetry_history.append(telemetry)
        self._log_telemetry(telemetry)
        return response

    def _detect_provider_from_response(
        self, response: ModelResponse, model_name: str
    ) -> ProviderType:
        """Heuristically map a model-name substring to a provider.

        Rule order matters: e.g. "gemini-3" must be tested before "gemini".
        """
        lowered = model_name.lower()
        rules: tuple[tuple[tuple[str, ...], ProviderType], ...] = (
            (("claude", "anthropic"), ProviderType.ANTHROPIC),
            (("gpt", "o1", "o3"), ProviderType.OPENAI),
            (("gemini-3", "gemini/3"), ProviderType.GEMINI_3),
            (("gemini",), ProviderType.GEMINI),
            (("deepseek",), ProviderType.DEEPSEEK),
            (("groq", "kimi"), ProviderType.GROQ),
            (("grok",), ProviderType.GROK),
        )
        for needles, provider in rules:
            if any(needle in lowered for needle in needles):
                return provider
        return ProviderType.UNKNOWN

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Run the model call, then record cache telemetry from the response."""
        return self._process_response(handler(request))

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async variant of :meth:`wrap_model_call`."""
        return self._process_response(await handler(request))


# ---- context_strategies/isolation.py (module header) ----
"""Context Isolation strategy.

Runs tasks inside sub-agents that own an independent context window so
the main agent's conversation is not polluted. DeepAgents implements
this in SubAgentMiddleware as the task() tool.
"""
+""" + +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from typing import Any, NotRequired, TypedDict + +from langchain.agents.middleware.types import ( + AgentMiddleware, + ModelRequest, + ModelResponse, +) +from langchain.tools import BaseTool, ToolRuntime +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import HumanMessage, ToolMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import StructuredTool +from langgraph.types import Command + + +class SubAgentSpec(TypedDict): + name: str + description: str + system_prompt: str + tools: Sequence[BaseTool | Callable | dict[str, Any]] + model: NotRequired[str | BaseChatModel] + middleware: NotRequired[list[AgentMiddleware]] + + +class CompiledSubAgentSpec(TypedDict): + name: str + description: str + runnable: Runnable + + +@dataclass +class IsolationConfig: + default_model: str | BaseChatModel = "gpt-4.1" + include_general_purpose_agent: bool = True + excluded_state_keys: tuple[str, ...] = ("messages", "todos", "structured_response") + + +@dataclass +class IsolationResult: + subagent_name: str + was_successful: bool + result_length: int + error: str | None = None + + +class ContextIsolationStrategy(AgentMiddleware): + """SubAgent를 통한 Context Isolation 구현. 
+ + Args: + config: Isolation 설정 + subagents: SubAgent 명세 목록 + agent_factory: 에이전트 생성 팩토리 함수 + """ + + def __init__( + self, + config: IsolationConfig | None = None, + subagents: list[SubAgentSpec | CompiledSubAgentSpec] | None = None, + agent_factory: Callable[..., Runnable] | None = None, + ) -> None: + self.config = config or IsolationConfig() + self._subagents = subagents or [] + self._agent_factory = agent_factory + self._compiled_agents: dict[str, Runnable] = {} + self.tools = [self._create_task_tool()] + + def _compile_subagents(self) -> dict[str, Runnable]: + if self._compiled_agents: + return self._compiled_agents + + agents: dict[str, Runnable] = {} + + for spec in self._subagents: + if "runnable" in spec: + compiled = spec # type: ignore + agents[compiled["name"]] = compiled["runnable"] + elif self._agent_factory: + simple = spec # type: ignore + agents[simple["name"]] = self._agent_factory( + model=simple.get("model", self.config.default_model), + system_prompt=simple["system_prompt"], + tools=simple["tools"], + middleware=simple.get("middleware", []), + ) + + self._compiled_agents = agents + return agents + + def _get_subagent_descriptions(self) -> str: + descriptions = [] + for spec in self._subagents: + descriptions.append(f"- {spec['name']}: {spec['description']}") + return "\n".join(descriptions) + + def _prepare_subagent_state( + self, state: dict[str, Any], task_description: str + ) -> dict[str, Any]: + filtered = { + k: v for k, v in state.items() if k not in self.config.excluded_state_keys + } + filtered["messages"] = [HumanMessage(content=task_description)] + return filtered + + def _create_task_tool(self) -> BaseTool: + strategy = self + + def task( + description: str, + subagent_type: str, + runtime: ToolRuntime, + ) -> str | Command: + agents = strategy._compile_subagents() + + if subagent_type not in agents: + allowed = ", ".join(f"`{k}`" for k in agents) + return f"SubAgent '{subagent_type}'가 존재하지 않습니다. 
사용 가능: {allowed}" + + subagent = agents[subagent_type] + subagent_state = strategy._prepare_subagent_state( + runtime.state, description + ) + + result = subagent.invoke(subagent_state, runtime.config) + + final_message = result["messages"][-1].text.rstrip() + + state_update = { + k: v + for k, v in result.items() + if k not in strategy.config.excluded_state_keys + } + + return Command( + update={ + **state_update, + "messages": [ + ToolMessage(final_message, tool_call_id=runtime.tool_call_id) + ], + } + ) + + async def atask( + description: str, + subagent_type: str, + runtime: ToolRuntime, + ) -> str | Command: + agents = strategy._compile_subagents() + + if subagent_type not in agents: + allowed = ", ".join(f"`{k}`" for k in agents) + return f"SubAgent '{subagent_type}'가 존재하지 않습니다. 사용 가능: {allowed}" + + subagent = agents[subagent_type] + subagent_state = strategy._prepare_subagent_state( + runtime.state, description + ) + + result = await subagent.ainvoke(subagent_state, runtime.config) + + final_message = result["messages"][-1].text.rstrip() + + state_update = { + k: v + for k, v in result.items() + if k not in strategy.config.excluded_state_keys + } + + return Command( + update={ + **state_update, + "messages": [ + ToolMessage(final_message, tool_call_id=runtime.tool_call_id) + ], + } + ) + + subagent_list = self._get_subagent_descriptions() + + return StructuredTool.from_function( + name="task", + func=task, + coroutine=atask, + description=f"""SubAgent에게 작업을 위임합니다. + +사용 가능한 SubAgent: +{subagent_list} + +사용법: +- description: 위임할 작업 상세 설명 +- subagent_type: 사용할 SubAgent 이름 + +SubAgent는 독립된 컨텍스트에서 실행되어 메인 에이전트의 컨텍스트를 오염시키지 않습니다. 
+복잡하고 다단계 작업에 적합합니다.""", + ) + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + return handler(request) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + return await handler(request) + + +ISOLATION_SYSTEM_PROMPT = """## Context Isolation (task 도구) + +task 도구로 SubAgent에게 작업을 위임할 수 있습니다. + +장점: +- 독립된 컨텍스트 윈도우 +- 메인 컨텍스트 오염 방지 +- 복잡한 작업의 격리 처리 + +사용 시점: +- 다단계 복잡한 작업 +- 대량의 컨텍스트가 필요한 연구 +- 병렬 처리가 가능한 독립 작업 +""" diff --git a/context_engineering_research_agent/context_strategies/offloading.py b/context_engineering_research_agent/context_strategies/offloading.py new file mode 100644 index 0000000..74dbfa5 --- /dev/null +++ b/context_engineering_research_agent/context_strategies/offloading.py @@ -0,0 +1,260 @@ +"""Context Offloading 전략 구현. + +## 개요 + +Context Offloading은 대용량 도구 결과를 파일시스템으로 축출하여 +컨텍스트 윈도우 오버플로우를 방지하는 전략입니다. + +## 핵심 원리 + +1. 도구 실행 결과가 특정 토큰 임계값을 초과하면 자동으로 파일로 저장 +2. 원본 메시지는 파일 경로 참조로 대체 +3. 
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any

from langchain.agents.middleware.types import (
    AgentMiddleware,
)
from langchain.tools import ToolRuntime
from langchain.tools.tool_node import ToolCallRequest
from langchain_core.messages import ToolMessage
from langgraph.types import Command


@dataclass
class OffloadingConfig:
    """Settings for context offloading."""

    token_limit_before_evict: int = 20000
    # Estimated-token threshold above which a tool result is evicted.

    eviction_path_prefix: str = "/large_tool_results"
    # Directory prefix under which evicted results are stored.

    preview_lines: int = 10
    # Number of leading lines kept as a preview in the replacement message.

    chars_per_token: int = 4
    # Rough characters-per-token divisor (conservative estimate).


@dataclass
class OffloadingResult:
    """Outcome of one offloading decision."""

    was_offloaded: bool
    # True when the content was actually evicted to a file.

    original_size: int
    # Original content length in characters.

    file_path: str | None = None
    # Eviction target path, set only when evicted.

    preview: str | None = None
    # Preview embedded in the replacement message, set only when evicted.


class ContextOffloadingStrategy(AgentMiddleware):
    """Evicts oversized tool results to the filesystem.

    Intercepts tool results in ``wrap_tool_call``; when the estimated
    token count exceeds the configured limit, the content is written via
    the backend and the original ToolMessage is replaced by a short
    pointer message (path + preview). The agent reads the data back on
    demand with ``read_file``. This is the offloading logic of DeepAgents'
    FilesystemMiddleware extracted into an explicit strategy object.

    Args:
        config: Offloading settings; defaults are used when None.
        backend_factory: Factory producing the file backend per runtime.
            When None, offloading is disabled (results pass through).
    """

    def __init__(
        self,
        config: OffloadingConfig | None = None,
        backend_factory: Callable[[ToolRuntime], Any] | None = None,
    ) -> None:
        self.config = config or OffloadingConfig()
        self._backend_factory = backend_factory

    def _estimate_tokens(self, content: str) -> int:
        """Estimate the token count of *content*.

        Uses a conservative chars-per-token divisor so results are not
        evicted prematurely; real counts vary by model and content.
        """
        return len(content) // self.config.chars_per_token

    def _should_offload(self, content: str) -> bool:
        """Return True when the estimated size exceeds the eviction limit."""
        estimated_tokens = self._estimate_tokens(content)
        return estimated_tokens > self.config.token_limit_before_evict

    def _create_preview(self, content: str) -> str:
        """Render the first preview_lines (length-capped) with line numbers."""
        lines = content.splitlines()[: self.config.preview_lines]
        truncated_lines = [line[:1000] for line in lines]
        return "\n".join(f"{i + 1:5}\t{line}" for i, line in enumerate(truncated_lines))

    def _create_offload_message(self, file_path: str, preview: str) -> str:
        """Build the replacement text pointing at the evicted file.

        FIX: dropped the ``tool_call_id`` parameter the original accepted
        but never referenced in the message body.
        """
        return f"""도구 결과가 너무 커서 파일시스템에 저장되었습니다.

경로: {file_path}

read_file 도구로 결과를 읽을 수 있습니다.
대용량 결과의 경우 offset과 limit 파라미터로 부분 읽기를 권장합니다.

처음 {self.config.preview_lines}줄 미리보기:
{preview}
"""

    def process_tool_result(
        self,
        tool_result: ToolMessage,
        runtime: ToolRuntime,
    ) -> tuple[ToolMessage | Command, OffloadingResult]:
        """Process a tool result, evicting it to a file when oversized.

        Args:
            tool_result: The original tool execution result.
            runtime: Tool runtime context (supplies the backend).

        Returns:
            Tuple of the (possibly replaced) message and the offloading
            metadata. On any backend failure the original message is
            returned untouched (best-effort behaviour).
        """
        content = (
            tool_result.content
            if isinstance(tool_result.content, str)
            else str(tool_result.content)
        )

        result = OffloadingResult(
            was_offloaded=False,
            original_size=len(content),
        )

        if not self._should_offload(content):
            return tool_result, result

        if self._backend_factory is None:
            # No backend configured: offloading is effectively disabled.
            return tool_result, result

        backend = self._backend_factory(runtime)

        sanitized_id = self._sanitize_tool_call_id(tool_result.tool_call_id)
        file_path = f"{self.config.eviction_path_prefix}/{sanitized_id}"

        write_result = backend.write(file_path, content)
        if write_result.error:
            # Write failed: keep the original message rather than lose data.
            return tool_result, result

        preview = self._create_preview(content)
        replacement_text = self._create_offload_message(file_path, preview)

        result.was_offloaded = True
        result.file_path = file_path
        result.preview = preview

        if write_result.files_update is not None:
            # Backend tracks files in agent state: propagate via Command.
            return Command(
                update={
                    "files": write_result.files_update,
                    "messages": [
                        ToolMessage(
                            content=replacement_text,
                            tool_call_id=tool_result.tool_call_id,
                        )
                    ],
                }
            ), result

        return ToolMessage(
            content=replacement_text,
            tool_call_id=tool_result.tool_call_id,
        ), result

    def _sanitize_tool_call_id(self, tool_call_id: str) -> str:
        """Map the tool-call id to a filesystem-safe name."""
        return "".join(c if c.isalnum() or c in "-_" else "_" for c in tool_call_id)

    def wrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], ToolMessage | Command],
    ) -> ToolMessage | Command:
        """Wrap a tool call, evicting an oversized ToolMessage result."""
        tool_result = handler(request)

        if isinstance(tool_result, ToolMessage):
            processed, _ = self.process_tool_result(tool_result, request.runtime)
            return processed

        # Command results are passed through untouched.
        return tool_result

    async def awrap_tool_call(
        self,
        request: ToolCallRequest,
        handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
    ) -> ToolMessage | Command:
        """Async counterpart of :meth:`wrap_tool_call`."""
        tool_result = await handler(request)

        if isinstance(tool_result, ToolMessage):
            processed, _ = self.process_tool_result(tool_result, request.runtime)
            return processed

        return tool_result


OFFLOADING_SYSTEM_PROMPT = """## Context Offloading 안내

대용량 도구 결과는 자동으로 파일시스템에 저장됩니다.

결과가 축출된 경우:
1. 파일 경로와 미리보기가 제공됩니다
2. read_file(path, offset=0, limit=100)으로 부분 읽기하세요
3. 전체 내용이 필요한 경우에만 전체 파일을 읽으세요

이 방식으로 컨텍스트 윈도우를 효율적으로 관리할 수 있습니다.
"""


# ---- context_strategies/reduction.py (module header) ----
"""Context Reduction strategy.

Compresses the conversation when context usage crosses a threshold,
using two techniques in priority order:

1. Compaction — strip tool calls / tool results from old messages,
   keeping only textual conclusions (fast, free).
2. Summarization — have an LLM summarize older messages once usage
   exceeds the threshold (default 85%) even after compaction.

Mirrors DeepAgents' SummarizationMiddleware (`context_threshold`,
`model_context_window`).
"""

from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelRequest,
    ModelResponse,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    HumanMessage,
    SystemMessage,
)
@dataclass
class ReductionConfig:
    """Settings for context reduction."""

    context_threshold: float = 0.85
    # Context usage ratio that triggers reduction (0.85 = 85%).

    model_context_window: int = 200000
    # Total context window of the target model, in tokens.

    compaction_age_threshold: int = 10
    # Messages older than this many positions are compaction candidates.

    min_messages_to_keep: int = 5
    # Number of most-recent messages kept verbatim after summarization.

    chars_per_token: int = 4
    # Rough characters-per-token divisor for estimates.


@dataclass
class ReductionResult:
    """Outcome of one reduction attempt."""

    was_reduced: bool
    # True when the message list actually shrank.

    technique_used: str | None = None
    # 'compaction' or 'summarization' when reduction happened.

    original_message_count: int = 0
    reduced_message_count: int = 0
    estimated_tokens_saved: int = 0


class ContextReductionStrategy(AgentMiddleware):
    """Shrinks the conversation when context usage crosses a threshold.

    Before each model call the estimated usage is checked; above the
    threshold, cheap compaction runs first and LLM summarization only if
    usage is still too high.

    Args:
        config: Reduction settings; defaults are used when None.
        summarization_model: LLM used for summaries; None disables
            summarization (compaction still runs).
    """

    def __init__(
        self,
        config: ReductionConfig | None = None,
        summarization_model: BaseChatModel | None = None,
    ) -> None:
        self.config = config or ReductionConfig()
        self._summarization_model = summarization_model

    def _estimate_tokens(self, messages: list[BaseMessage]) -> int:
        """Estimate total tokens from total character count."""
        character_count = sum(len(str(m.content)) for m in messages)
        return character_count // self.config.chars_per_token

    def _get_context_usage_ratio(self, messages: list[BaseMessage]) -> float:
        """Fraction of the model context window currently occupied."""
        return self._estimate_tokens(messages) / self.config.model_context_window

    def _should_reduce(self, messages: list[BaseMessage]) -> bool:
        """True when estimated usage exceeds the configured threshold."""
        return self._get_context_usage_ratio(messages) > self.config.context_threshold

    def apply_compaction(
        self,
        messages: list[BaseMessage],
    ) -> tuple[list[BaseMessage], ReductionResult]:
        """Strip tool calls and tool results from old messages.

        Messages within the age threshold are kept verbatim. Older
        AIMessages lose their tool calls (only non-empty text survives);
        older Human/System messages are kept; anything else old —
        notably ToolMessages — is dropped.
        """
        original_count = len(messages)
        # Indices below `cutoff` have age > compaction_age_threshold.
        cutoff = original_count - self.config.compaction_age_threshold
        kept: list[BaseMessage] = []

        for index, message in enumerate(messages):
            if index >= cutoff:
                kept.append(message)
            elif isinstance(message, AIMessage):
                if message.tool_calls:
                    text_content = (
                        message.text if hasattr(message, "text") else str(message.content)
                    )
                    if text_content.strip():
                        kept.append(AIMessage(content=text_content))
                else:
                    kept.append(message)
            elif isinstance(message, (HumanMessage, SystemMessage)):
                kept.append(message)
            # Old ToolMessages (and other types) are dropped entirely.

        result = ReductionResult(
            was_reduced=len(kept) < original_count,
            technique_used="compaction",
            original_message_count=original_count,
            reduced_message_count=len(kept),
            estimated_tokens_saved=(
                self._estimate_tokens(messages) - self._estimate_tokens(kept)
            ),
        )
        return kept, result

    def apply_summarization(
        self,
        messages: list[BaseMessage],
    ) -> tuple[list[BaseMessage], ReductionResult]:
        """Summarize older messages with the configured LLM.

        The most recent ``min_messages_to_keep`` messages stay verbatim;
        everything older collapses into one summary SystemMessage. A no-op
        when no model is configured or nothing is old enough to summarize.
        """
        if self._summarization_model is None:
            return messages, ReductionResult(was_reduced=False)

        original_count = len(messages)
        keep_count = self.config.min_messages_to_keep

        if len(messages) > keep_count:
            older, recent = messages[:-keep_count], messages[-keep_count:]
        else:
            older, recent = [], messages

        if not older:
            return messages, ReductionResult(was_reduced=False)

        summary_prompt = self._create_summary_prompt(older)

        summary_response = self._summarization_model.invoke(
            [
                SystemMessage(
                    content="당신은 대화 요약 전문가입니다. 핵심 정보만 간결하게 요약하세요."
                ),
                HumanMessage(content=summary_prompt),
            ]
        )

        summary_message = SystemMessage(
            content=f"[이전 대화 요약]\n{summary_response.content}"
        )
        condensed = [summary_message] + list(recent)

        result = ReductionResult(
            was_reduced=True,
            technique_used="summarization",
            original_message_count=original_count,
            reduced_message_count=len(condensed),
            estimated_tokens_saved=(
                self._estimate_tokens(messages) - self._estimate_tokens(condensed)
            ),
        )
        return condensed, result

    def _create_summary_prompt(self, messages: list[BaseMessage]) -> str:
        """Render the messages as a transcript for the summarizer."""
        conversation_text = []
        for msg in messages:
            role = msg.__class__.__name__.replace("Message", "")
            content = str(msg.content)[:500]
            conversation_text.append(f"[{role}]: {content}")

        return f"""다음 대화를 요약해주세요. 핵심 정보, 결정사항, 중요한 컨텍스트만 포함하세요.

대화 내용:
{chr(10).join(conversation_text)}

요약 (한국어로, 500자 이내):"""

    def reduce_context(
        self,
        messages: list[BaseMessage],
    ) -> tuple[list[BaseMessage], ReductionResult]:
        """Reduce the context: compaction first, then summarization.

        Summarization only runs when compaction alone leaves usage above
        the threshold.
        """
        if not self._should_reduce(messages):
            return messages, ReductionResult(was_reduced=False)

        compacted, compaction_result = self.apply_compaction(messages)
        if not self._should_reduce(compacted):
            return compacted, compaction_result

        return self.apply_summarization(compacted)

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Reduce the message state (when needed) before the model call."""
        current = list(request.state.get("messages", []))
        reduced, outcome = self.reduce_context(current)

        if outcome.was_reduced:
            request = request.override(state={**request.state, "messages": reduced})

        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async counterpart of :meth:`wrap_model_call`."""
        current = list(request.state.get("messages", []))
        reduced, outcome = self.reduce_context(current)

        if outcome.was_reduced:
            request = request.override(state={**request.state, "messages": reduced})

        return await handler(request)


REDUCTION_SYSTEM_PROMPT = """## Context Reduction 안내

대화가 길어지면 자동으로 컨텍스트가 압축됩니다.

압축 방식:
1. **Compaction**: 오래된 도구 호출/결과 제거
2. **Summarization**: LLM이 이전 대화 요약

중요한 정보는 파일시스템에 저장하는 것을 권장합니다.
요약으로 인해 세부사항이 손실될 수 있습니다.
"""


# ---- context_strategies/retrieval.py (module header) ----
"""Context Retrieval strategy.

Loads only the information the agent needs, via plain file search
(grep/glob pattern matching and partial reads) instead of a vector
store: deterministic results, no extra infrastructure, no indexing
overhead, easy debugging. Mirrors the read_file / grep / glob / ls
tools of DeepAgents' FilesystemMiddleware.
"""
from collections import Counter
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any, Literal

from langchain.agents.middleware.types import (
    AgentMiddleware,
    ModelRequest,
    ModelResponse,
)
from langchain.tools import ToolRuntime
from langchain_core.tools import BaseTool, StructuredTool


@dataclass
class RetrievalConfig:
    """Settings for context retrieval."""

    default_read_limit: int = 500
    # Default line cap for read_file.

    max_grep_results: int = 100
    # Maximum number of grep matches returned.

    max_glob_results: int = 100
    # Maximum number of glob matches returned.

    truncate_line_length: int = 2000
    # Lines longer than this are truncated in grep output.


@dataclass
class RetrievalResult:
    """Metadata about one retrieval call."""

    tool_used: str
    query: str
    result_count: int
    was_truncated: bool = False


class ContextRetrievalStrategy(AgentMiddleware):
    """grep/glob based retrieval: load only what is needed.

    Provides read_file / grep / glob tools backed by a pluggable file
    backend. Result counts are capped so retrieval itself cannot
    overflow the context window.

    Args:
        config: Retrieval settings; defaults are used when None.
        backend_factory: Factory producing the file backend per runtime.
    """

    def __init__(
        self,
        config: RetrievalConfig | None = None,
        backend_factory: Callable[[ToolRuntime], Any] | None = None,
    ) -> None:
        self.config = config or RetrievalConfig()
        self._backend_factory = backend_factory
        self.tools = self._create_tools()

    def _create_tools(self) -> list[BaseTool]:
        """Build the retrieval tool set."""
        return [
            self._create_read_file_tool(),
            self._create_grep_tool(),
            self._create_glob_tool(),
        ]

    def _create_read_file_tool(self) -> BaseTool:
        """Build the read_file tool (supports offset/limit partial reads)."""
        config = self.config
        backend_factory = self._backend_factory

        def read_file(
            file_path: str,
            runtime: ToolRuntime,
            offset: int = 0,
            limit: int | None = None,
        ) -> str:
            """Read a file, returning line-numbered content."""
            if backend_factory is None:
                return "백엔드가 설정되지 않았습니다."

            backend = backend_factory(runtime)
            actual_limit = limit or config.default_read_limit
            return backend.read(file_path, offset=offset, limit=actual_limit)

        return StructuredTool.from_function(
            name="read_file",
            description=f"""파일을 읽습니다.

사용법:
- file_path: 절대 경로 필수
- offset: 시작 줄 (기본 0)
- limit: 읽을 줄 수 (기본 {config.default_read_limit})

대용량 파일은 offset/limit으로 부분 읽기를 권장합니다.""",
            func=read_file,
        )

    def _create_grep_tool(self) -> BaseTool:
        """Build the grep tool (literal text search with three output modes)."""
        config = self.config
        backend_factory = self._backend_factory

        def grep(
            pattern: str,
            runtime: ToolRuntime,
            path: str | None = None,
            glob_pattern: str | None = None,
            output_mode: Literal[
                "files_with_matches", "content", "count"
            ] = "files_with_matches",
        ) -> str:
            """Search for a literal text pattern in files."""
            if backend_factory is None:
                return "백엔드가 설정되지 않았습니다."

            backend = backend_factory(runtime)
            raw_results = backend.grep_raw(pattern, path=path, glob=glob_pattern)

            # Backends signal errors by returning a plain string.
            if isinstance(raw_results, str):
                return raw_results

            truncated = raw_results[: config.max_grep_results]

            if output_mode == "files_with_matches":
                # FIX: deduplicate while preserving first-match order —
                # the original list(set(...)) made the output order
                # nondeterministic.
                unique_paths = dict.fromkeys(r.get("path", "") for r in truncated)
                return "\n".join(unique_paths)
            elif output_mode == "count":
                counts = Counter(r.get("path", "") for r in truncated)
                return "\n".join(f"{p}: {count}" for p, count in counts.items())
            else:
                # FIX: local renamed from `path` — it shadowed the `path`
                # parameter above.
                lines = []
                for r in truncated:
                    match_path = r.get("path", "")
                    line_num = r.get("line_number", 0)
                    content = r.get("content", "")[: config.truncate_line_length]
                    lines.append(f"{match_path}:{line_num}: {content}")
                return "\n".join(lines)

        return StructuredTool.from_function(
            name="grep",
            description=f"""텍스트 패턴을 검색합니다.

사용법:
- pattern: 검색할 텍스트 (리터럴 문자열)
- path: 검색 디렉토리 (선택)
- glob_pattern: 파일 필터 예: "*.py" (선택)
- output_mode: files_with_matches | content | count

최대 {config.max_grep_results}개 결과를 반환합니다.""",
            func=grep,
        )

    def _create_glob_tool(self) -> BaseTool:
        """Build the glob tool (filename pattern matching)."""
        config = self.config
        backend_factory = self._backend_factory

        def glob(
            pattern: str,
            runtime: ToolRuntime,
            path: str = "/",
        ) -> str:
            """Find files by glob pattern, returning matched paths."""
            if backend_factory is None:
                return "백엔드가 설정되지 않았습니다."

            backend = backend_factory(runtime)
            infos = backend.glob_info(pattern, path=path)

            paths = [fi.get("path", "") for fi in infos[: config.max_glob_results]]
            return "\n".join(paths)

        return StructuredTool.from_function(
            name="glob",
            description=f"""파일명 패턴으로 파일을 찾습니다.

사용법:
- pattern: glob 패턴 (*, **, ? 지원)
- path: 검색 시작 경로 (기본 "/")

예시:
- "**/*.py": 모든 Python 파일
- "src/**/*.ts": src 아래 모든 TypeScript 파일

최대 {config.max_glob_results}개 결과를 반환합니다.""",
            func=glob,
        )

    def wrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], ModelResponse],
    ) -> ModelResponse:
        """Pass-through: this strategy only contributes tools."""
        return handler(request)

    async def awrap_model_call(
        self,
        request: ModelRequest,
        handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
    ) -> ModelResponse:
        """Async pass-through counterpart of :meth:`wrap_model_call`."""
        return await handler(request)


RETRIEVAL_SYSTEM_PROMPT = """## Context Retrieval 도구

파일시스템에서 정보를 검색할 수 있습니다.

도구:
- read_file: 파일 읽기 (offset/limit으로 부분 읽기)
- grep: 텍스트 패턴 검색
- glob: 파일명 패턴 매칭

사용 팁:
1. 대용량 파일은 먼저 구조 파악 (limit=100)
2. grep으로 관련 파일 찾은 후 read_file
3. glob으로 파일 위치 확인 후 탐색
"""


# ---- research/__init__.py ----
"""Research agent module."""

# FIX: this import referenced the non-existent package
# `context_engineering_more_deep_research_agent`; the actual package name
# used by every other module (and by this file's own path) is
# `context_engineering_research_agent`.
from context_engineering_research_agent.research.agent import (
    create_researcher_agent,
    get_researcher_subagent,
)

__all__ = [
    "create_researcher_agent",
    "get_researcher_subagent",
]


# ---- research/agent.py (module header) ----
"""Autonomous research agent."""

from datetime import datetime
from typing import Any

from deepagents import create_deep_agent
from deepagents.backends.protocol import BackendFactory, BackendProtocol
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI
from langgraph.graph.state import CompiledStateGraph
연구 에이전트입니다. +"넓게 탐색 → 깊게 파기" 방법론을 따라 주제를 철저히 연구합니다. + +오늘 날짜: {date} + +## 연구 워크플로우 + +### Phase 1: 탐색적 검색 (1-2회) +- 넓은 검색으로 분야 전체 파악 +- 핵심 개념, 주요 플레이어, 최근 트렌드 확인 +- think_tool로 유망한 방향 2-3개 식별 + +### Phase 2: 심층 연구 (방향당 1-2회) +- 식별된 방향별 집중 검색 +- 각 검색 후 think_tool로 평가 +- 가치 있는 정보 획득 여부 판단 + +### Phase 3: 종합 +- 모든 발견 사항 검토 +- 패턴과 연결점 식별 +- 출처 일치/불일치 기록 + +## 도구 제한 +- 탐색: 최대 2회 검색 +- 심층: 최대 3-4회 검색 +- **총합: 5-6회** + +## 종료 조건 +- 포괄적 답변 가능 +- 최근 2회 검색이 유사 정보 반환 +- 최대 검색 횟수 도달 + +## 응답 형식 + +```markdown +## 핵심 발견 + +### 발견 1: [제목] +[인용 포함 상세 설명 [1], [2]] + +### 발견 2: [제목] +[인용 포함 상세 설명] + +## 출처 일치 분석 +- **높은 일치**: [출처들이 동의하는 주제] +- **불일치/불확실**: [충돌 정보] + +## 출처 +[1] 출처 제목: URL +[2] 출처 제목: URL +``` +""" + + +def create_researcher_agent( + model: str | BaseChatModel | None = None, + backend: BackendProtocol | BackendFactory | None = None, +) -> CompiledStateGraph: + if model is None: + model = ChatOpenAI(model="gpt-4.1", temperature=0.0) + + current_date = datetime.now().strftime("%Y-%m-%d") + formatted_prompt = AUTONOMOUS_RESEARCHER_INSTRUCTIONS.format(date=current_date) + + return create_deep_agent( + model=model, + system_prompt=formatted_prompt, + backend=backend, + ) + + +def get_researcher_subagent( + model: str | BaseChatModel | None = None, + backend: BackendProtocol | BackendFactory | None = None, +) -> dict[str, Any]: + researcher = create_researcher_agent(model=model, backend=backend) + + return { + "name": "researcher", + "description": ( + "자율적 심층 연구 에이전트. '넓게 탐색 → 깊게 파기' 방법론 사용. " + "복잡한 주제 연구, 다각적 질문, 트렌드 분석에 적합." + ), + "runnable": researcher, + } diff --git a/context_engineering_research_agent/skills/__init__.py b/context_engineering_research_agent/skills/__init__.py new file mode 100644 index 0000000..a76a5ba --- /dev/null +++ b/context_engineering_research_agent/skills/__init__.py @@ -0,0 +1,36 @@ +"""Context Engineering 연구 에이전트용 스킬 시스템. + +이 모듈은 Anthropic의 Agent Skills 패턴을 구현하여 +에이전트에게 도메인별 전문 지식과 워크플로우를 제공합니다. 
+ +## Progressive Disclosure (점진적 공개) + +스킬은 점진적 공개 패턴을 따릅니다: +1. 세션 시작 시 스킬 메타데이터(이름 + 설명)만 로드 +2. 에이전트가 필요할 때 전체 SKILL.md 내용 읽기 +3. 컨텍스트 윈도우 효율적 사용 + +## Context Engineering 관점에서의 스킬 + +스킬 시스템은 Context Engineering의 핵심 전략들을 활용합니다: + +1. **Context Retrieval**: read_file로 필요할 때만 스킬 내용 로드 +2. **Context Offloading**: 전체 스킬 내용 대신 메타데이터만 시스템 프롬프트에 포함 +3. **Context Isolation**: 각 스킬은 독립적인 도메인 지식 캡슐화 +""" + +from context_engineering_research_agent.skills.load import ( + SkillMetadata, + list_skills, +) +from context_engineering_research_agent.skills.middleware import ( + SkillsMiddleware, + SkillsState, +) + +__all__ = [ + "SkillMetadata", + "list_skills", + "SkillsMiddleware", + "SkillsState", +] diff --git a/context_engineering_research_agent/skills/load.py b/context_engineering_research_agent/skills/load.py new file mode 100644 index 0000000..52789d9 --- /dev/null +++ b/context_engineering_research_agent/skills/load.py @@ -0,0 +1,197 @@ +"""SKILL.md 파일에서 에이전트 스킬을 파싱하고 로드하는 스킬 로더. + +YAML 프론트매터 파싱을 통해 Anthropic Agent Skills 패턴을 구현합니다. 
+""" + +from __future__ import annotations + +import logging +import re +from pathlib import Path +from typing import NotRequired, TypedDict + +import yaml + +logger = logging.getLogger(__name__) + +MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024 # 10MB - DoS 방지 +MAX_SKILL_NAME_LENGTH = 64 +MAX_SKILL_DESCRIPTION_LENGTH = 1024 + + +class SkillMetadata(TypedDict): + """Agent Skills 명세를 따르는 스킬 메타데이터.""" + + name: str + description: str + path: str + source: str + license: NotRequired[str | None] + compatibility: NotRequired[str | None] + metadata: NotRequired[dict[str, str] | None] + allowed_tools: NotRequired[str | None] + + +def _is_safe_path(path: Path, base_dir: Path) -> bool: + """경로가 base_dir 내에 안전하게 포함되어 있는지 확인합니다.""" + try: + resolved_path = path.resolve() + resolved_base = base_dir.resolve() + resolved_path.relative_to(resolved_base) + return True + except ValueError: + return False + except (OSError, RuntimeError): + return False + + +def _validate_skill_name(name: str, directory_name: str) -> tuple[bool, str]: + """Agent Skills 명세에 따라 스킬 이름을 검증합니다.""" + if not name: + return False, "이름은 필수입니다" + if len(name) > MAX_SKILL_NAME_LENGTH: + return False, "이름이 64자를 초과합니다" + if not re.match(r"^[a-z0-9]+(-[a-z0-9]+)*$", name): + return False, "이름은 소문자 영숫자와 단일 하이픈만 사용해야 합니다" + if name != directory_name: + return ( + False, + f"이름 '{name}'은 디렉토리 이름 '{directory_name}'과 일치해야 합니다", + ) + return True, "" + + +def _parse_skill_metadata(skill_md_path: Path, source: str) -> SkillMetadata | None: + """SKILL.md 파일에서 YAML 프론트매터를 파싱합니다.""" + try: + file_size = skill_md_path.stat().st_size + if file_size > MAX_SKILL_FILE_SIZE: + logger.warning( + "%s 건너뜀: 파일이 너무 큼 (%d 바이트)", skill_md_path, file_size + ) + return None + + content = skill_md_path.read_text(encoding="utf-8") + + frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n" + match = re.match(frontmatter_pattern, content, re.DOTALL) + + if not match: + logger.warning( + "%s 건너뜀: 유효한 YAML 프론트매터를 찾을 수 없음", skill_md_path + ) + return None + 
+ frontmatter_str = match.group(1) + + try: + frontmatter_data = yaml.safe_load(frontmatter_str) + except yaml.YAMLError as e: + logger.warning("%s의 YAML이 유효하지 않음: %s", skill_md_path, e) + return None + + if not isinstance(frontmatter_data, dict): + logger.warning("%s 건너뜀: 프론트매터가 매핑이 아님", skill_md_path) + return None + + name = frontmatter_data.get("name") + description = frontmatter_data.get("description") + + if not name or not description: + logger.warning( + "%s 건너뜀: 필수 'name' 또는 'description' 누락", skill_md_path + ) + return None + + directory_name = skill_md_path.parent.name + is_valid, error = _validate_skill_name(str(name), directory_name) + if not is_valid: + logger.warning( + "'%s' 스킬 (%s)이 Agent Skills 명세를 따르지 않음: %s. " + "명세 준수를 위해 이름 변경을 고려하세요.", + name, + skill_md_path, + error, + ) + + description_str = str(description) + if len(description_str) > MAX_SKILL_DESCRIPTION_LENGTH: + logger.warning( + "%s의 설명이 %d자를 초과하여 잘림", + skill_md_path, + MAX_SKILL_DESCRIPTION_LENGTH, + ) + description_str = description_str[:MAX_SKILL_DESCRIPTION_LENGTH] + + return SkillMetadata( + name=str(name), + description=description_str, + path=str(skill_md_path), + source=source, + license=frontmatter_data.get("license"), + compatibility=frontmatter_data.get("compatibility"), + metadata=frontmatter_data.get("metadata"), + allowed_tools=frontmatter_data.get("allowed-tools"), + ) + + except (OSError, UnicodeDecodeError) as e: + logger.warning("%s 읽기 오류: %s", skill_md_path, e) + return None + + +def _list_skills_from_dir(skills_dir: Path, source: str) -> list[SkillMetadata]: + """단일 스킬 디렉토리에서 모든 스킬을 나열합니다.""" + skills_dir = skills_dir.expanduser() + if not skills_dir.exists(): + return [] + + try: + resolved_base = skills_dir.resolve() + except (OSError, RuntimeError): + return [] + + skills: list[SkillMetadata] = [] + + for skill_dir in skills_dir.iterdir(): + if not _is_safe_path(skill_dir, resolved_base): + continue + + if not skill_dir.is_dir(): + continue + + skill_md_path = 
skill_dir / "SKILL.md" + if not skill_md_path.exists(): + continue + + if not _is_safe_path(skill_md_path, resolved_base): + continue + + metadata = _parse_skill_metadata(skill_md_path, source=source) + if metadata: + skills.append(metadata) + + return skills + + +def list_skills( + *, + user_skills_dir: Path | None = None, + project_skills_dir: Path | None = None, +) -> list[SkillMetadata]: + """사용자 및/또는 프로젝트 디렉토리에서 스킬을 나열합니다. + + 프로젝트 스킬이 같은 이름의 사용자 스킬을 오버라이드합니다. + """ + all_skills: dict[str, SkillMetadata] = {} + + if user_skills_dir: + user_skills = _list_skills_from_dir(user_skills_dir, source="user") + for skill in user_skills: + all_skills[skill["name"]] = skill + + if project_skills_dir: + project_skills = _list_skills_from_dir(project_skills_dir, source="project") + for skill in project_skills: + all_skills[skill["name"]] = skill + + return list(all_skills.values()) diff --git a/context_engineering_research_agent/skills/middleware.py b/context_engineering_research_agent/skills/middleware.py new file mode 100644 index 0000000..1d63461 --- /dev/null +++ b/context_engineering_research_agent/skills/middleware.py @@ -0,0 +1,165 @@ +"""스킬 시스템 미들웨어. + +Progressive Disclosure 패턴으로 스킬 메타데이터를 시스템 프롬프트에 주입합니다. +""" + +from collections.abc import Awaitable, Callable +from pathlib import Path +from typing import NotRequired, TypedDict, cast + +from langchain.agents.middleware.types import ( + AgentMiddleware, + AgentState, + ModelRequest, + ModelResponse, +) +from langgraph.runtime import Runtime + +from context_engineering_research_agent.skills.load import SkillMetadata, list_skills + + +class SkillsState(AgentState): + skills_metadata: NotRequired[list[SkillMetadata]] + + +class SkillsStateUpdate(TypedDict): + skills_metadata: list[SkillMetadata] + + +SKILLS_SYSTEM_PROMPT = """ + +## 스킬 시스템 + +스킬 라이브러리를 통해 전문화된 기능과 도메인 지식을 사용할 수 있습니다. + +{skills_locations} + +**사용 가능한 스킬:** + +{skills_list} + +**스킬 사용법 (Progressive Disclosure):** + +스킬은 점진적 공개 패턴을 따릅니다. 
위에서 스킬의 존재(이름 + 설명)를 알 수 있지만, +필요할 때만 전체 지침을 읽습니다: + +1. 스킬 적용 여부 판단: 사용자 요청이 스킬 설명과 일치하는지 확인 +2. 전체 지침 읽기: read_file로 SKILL.md 경로 읽기 +3. 지침 따르기: SKILL.md에는 단계별 워크플로우와 예시 포함 +4. 지원 파일 활용: 스킬에 Python 스크립트나 설정 파일 포함 가능 + +**스킬 사용 시점:** +- 사용자 요청이 스킬 도메인과 일치할 때 +- 전문 지식이나 구조화된 워크플로우가 도움될 때 +- 복잡한 작업에 검증된 패턴이 필요할 때 +""" + + +class SkillsMiddleware(AgentMiddleware): + """Progressive Disclosure 패턴으로 스킬을 노출하는 미들웨어.""" + + state_schema = SkillsState + + def __init__( + self, + *, + skills_dir: str | Path, + assistant_id: str, + project_skills_dir: str | Path | None = None, + ) -> None: + self.skills_dir = Path(skills_dir).expanduser() + self.assistant_id = assistant_id + self.project_skills_dir = ( + Path(project_skills_dir).expanduser() if project_skills_dir else None + ) + self.user_skills_display = f"~/.deepagents/{assistant_id}/skills" + self.system_prompt_template = SKILLS_SYSTEM_PROMPT + + def _format_skills_locations(self) -> str: + locations = [f"**사용자 스킬**: `{self.user_skills_display}`"] + if self.project_skills_dir: + locations.append( + f"**프로젝트 스킬**: `{self.project_skills_dir}` (사용자 스킬 오버라이드)" + ) + return "\n".join(locations) + + def _format_skills_list(self, skills: list[SkillMetadata]) -> str: + if not skills: + locations = [f"{self.user_skills_display}/"] + if self.project_skills_dir: + locations.append(f"{self.project_skills_dir}/") + return f"(사용 가능한 스킬 없음. 
{' 또는 '.join(locations)}에서 스킬 생성 가능)" + + user_skills = [s for s in skills if s["source"] == "user"] + project_skills = [s for s in skills if s["source"] == "project"] + + lines = [] + + if user_skills: + lines.append("**사용자 스킬:**") + for skill in user_skills: + lines.append(f"- **{skill['name']}**: {skill['description']}") + lines.append(f" → 전체 지침: `{skill['path']}`") + lines.append("") + + if project_skills: + lines.append("**프로젝트 스킬:**") + for skill in project_skills: + lines.append(f"- **{skill['name']}**: {skill['description']}") + lines.append(f" → 전체 지침: `{skill['path']}`") + + return "\n".join(lines) + + def before_agent( + self, state: SkillsState, runtime: Runtime + ) -> SkillsStateUpdate | None: + skills = list_skills( + user_skills_dir=self.skills_dir, + project_skills_dir=self.project_skills_dir, + ) + return SkillsStateUpdate(skills_metadata=skills) + + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelResponse: + skills_metadata = request.state.get("skills_metadata", []) + + skills_locations = self._format_skills_locations() + skills_list = self._format_skills_list(skills_metadata) + + skills_section = self.system_prompt_template.format( + skills_locations=skills_locations, + skills_list=skills_list, + ) + + if request.system_prompt: + system_prompt = request.system_prompt + "\n\n" + skills_section + else: + system_prompt = skills_section + + return handler(request.override(system_prompt=system_prompt)) + + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelResponse: + state = cast("SkillsState", request.state) + skills_metadata = state.get("skills_metadata", []) + + skills_locations = self._format_skills_locations() + skills_list = self._format_skills_list(skills_metadata) + + skills_section = self.system_prompt_template.format( + skills_locations=skills_locations, + skills_list=skills_list, + ) + 
+ if request.system_prompt: + system_prompt = request.system_prompt + "\n\n" + skills_section + else: + system_prompt = skills_section + + return await handler(request.override(system_prompt=system_prompt)) diff --git a/pyproject.toml b/pyproject.toml index 5685d8d..d16a6c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,10 +21,12 @@ dependencies = [ [dependency-groups] dev = [ + "docker>=7.1.0", "ipykernel>=7.1.0", "ipywidgets>=8.1.8", "langgraph-cli[inmem]>=0.4.11", "mypy>=1.19.1", + "pytest>=9.0.2", "ruff>=0.14.10", ] @@ -48,14 +50,19 @@ lint.select = [ "T201", "UP", ] -lint.ignore = ["UP006", "UP007", "UP035", "D417", "E501"] +lint.ignore = ["UP006", "UP007", "UP035", "D417", "E501", "D101", "D102", "D103", "D105", "D107"] [tool.ruff.lint.per-file-ignores] -"tests/*" = ["D", "UP"] +"tests/*" = ["D", "UP", "T201"] [tool.ruff.lint.pydocstyle] convention = "google" +[tool.pytest.ini_options] +markers = [ + "integration: 실제 외부 서비스(Docker, API 등)를 사용하는 통합 테스트", +] + [tool.mypy] python_version = "3.13" exclude = [".venv"] diff --git a/rust-research-agent/rig-deepagents/Cargo.toml b/rust-research-agent/rig-deepagents/Cargo.toml index a27c85c..2d1b152 100644 --- a/rust-research-agent/rig-deepagents/Cargo.toml +++ b/rust-research-agent/rig-deepagents/Cargo.toml @@ -10,6 +10,7 @@ default = [] checkpointer-sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"] checkpointer-redis = ["dep:redis"] checkpointer-postgres = ["dep:sqlx"] +tokenizer-tiktoken = ["dep:tiktoken-rs"] [dependencies] rig-core = { version = "0.27", features = ["derive"] } @@ -31,6 +32,7 @@ num_cpus = "1" # For default parallelism configuration humantime-serde = "1" # For Duration serialization in configs zstd = "0.13" # For checkpoint compression regex = "1" +tiktoken-rs = { version = "0.5", optional = true } # HTTP client for external API tools (Tavily, etc.) 
reqwest = { version = "0.12", features = ["json"] } diff --git a/rust-research-agent/rig-deepagents/src/backends/protocol.rs b/rust-research-agent/rig-deepagents/src/backends/protocol.rs index 43e7b85..121e50c 100644 --- a/rust-research-agent/rig-deepagents/src/backends/protocol.rs +++ b/rust-research-agent/rig-deepagents/src/backends/protocol.rs @@ -70,6 +70,11 @@ pub trait Backend: Send + Sync { /// Returns: 라인 번호 포함된 포맷 (cat -n 스타일) async fn read(&self, path: &str, offset: usize, limit: usize) -> Result; + async fn read_plain(&self, path: &str) -> Result { + let formatted = self.read(path, 0, 50_000).await?; + Ok(strip_cat_n(&formatted)) + } + /// 파일 쓰기 (새 파일 생성) /// Python: write(file_path: str, content: str) -> WriteResult async fn write(&self, path: &str, content: &str) -> Result; @@ -122,3 +127,11 @@ pub trait Backend: Send + Sync { /// 파일 삭제 async fn delete(&self, path: &str) -> Result<(), BackendError>; } + +fn strip_cat_n(formatted: &str) -> String { + formatted + .lines() + .map(|line| line.split_once('\t').map(|(_, s)| s).unwrap_or(line)) + .collect::>() + .join("\n") +} diff --git a/rust-research-agent/rig-deepagents/src/executor.rs b/rust-research-agent/rig-deepagents/src/executor.rs index 9b9d622..4de20ae 100644 --- a/rust-research-agent/rig-deepagents/src/executor.rs +++ b/rust-research-agent/rig-deepagents/src/executor.rs @@ -167,7 +167,7 @@ impl AgentExecutor { model_request = model_request.with_config(config.clone()); } - let before_control = self.middleware.before_model(&mut model_request, &state, &runtime).await + let before_control = self.middleware.before_model(&mut model_request, &mut state, &runtime).await .map_err(DeepAgentError::Middleware)?; // before_model 제어 흐름 처리 @@ -235,7 +235,22 @@ impl AgentExecutor { // 도구 호출 처리 if let Some(tool_calls) = &response.tool_calls { + let write_todos_count = tool_calls + .iter() + .filter(|call| call.name == "write_todos") + .count(); + let has_duplicate_write_todos = write_todos_count > 1; + for call 
in tool_calls { + if has_duplicate_write_todos && call.name == "write_todos" { + let result = ToolResult::new( + "Error: multiple write_todos calls in a single response are not allowed", + ); + let tool_message = Message::tool_with_status(&result.message, &call.id, "error"); + state.add_message(tool_message); + continue; + } + let result = self .execute_tool_call(call, &tools, &state, runtime.config()) .await; @@ -455,6 +470,54 @@ mod tests { assert_eq!(result.todos[0].content, "Test todo"); } + #[tokio::test] + async fn test_executor_rejects_duplicate_write_todos() { + let tool_calls = vec![ + ToolCall { + id: "call_1".to_string(), + name: "write_todos".to_string(), + arguments: serde_json::json!({"todos": []}), + }, + ToolCall { + id: "call_2".to_string(), + name: "write_todos".to_string(), + arguments: serde_json::json!({"todos": []}), + }, + ]; + + let responses = vec![ + Message::assistant_with_tool_calls("", tool_calls), + Message::assistant("Done."), + ]; + + let llm = Arc::new(MockLLM::new(responses)); + let backend = Arc::new(MemoryBackend::new()); + let middleware = MiddlewareStack::new(); + + let executor = AgentExecutor::new(llm, middleware, backend) + .with_tools(vec![Arc::new(crate::tools::WriteTodosTool)]); + + let initial_state = AgentState::with_messages(vec![ + Message::user("Update todos"), + ]); + + let result = executor.run(initial_state).await.unwrap(); + + assert!(result.todos.is_empty()); + + let tool_messages: Vec<&Message> = result + .messages + .iter() + .filter(|message| message.role == Role::Tool) + .collect(); + + assert_eq!(tool_messages.len(), 2); + for message in tool_messages { + assert_eq!(message.status.as_deref(), Some("error")); + assert!(message.content.contains("multiple write_todos")); + } + } + #[tokio::test] async fn test_executor_max_iterations() { // Create LLM that always returns tool calls diff --git a/rust-research-agent/rig-deepagents/src/lib.rs b/rust-research-agent/rig-deepagents/src/lib.rs index 37187af..650185b 
100644 --- a/rust-research-agent/rig-deepagents/src/lib.rs +++ b/rust-research-agent/rig-deepagents/src/lib.rs @@ -41,6 +41,7 @@ pub mod skills; pub mod research; pub mod config; pub mod compat; +pub mod tokenization; mod tool_result_eviction; // Re-exports for convenience diff --git a/rust-research-agent/rig-deepagents/src/middleware/stack.rs b/rust-research-agent/rig-deepagents/src/middleware/stack.rs index 983e459..e3a3cea 100644 --- a/rust-research-agent/rig-deepagents/src/middleware/stack.rs +++ b/rust-research-agent/rig-deepagents/src/middleware/stack.rs @@ -110,7 +110,7 @@ impl MiddlewareStack { pub async fn before_model( &self, request: &mut ModelRequest, - state: &AgentState, + state: &mut AgentState, runtime: &ToolRuntime, ) -> Result { for middleware in &self.middlewares { diff --git a/rust-research-agent/rig-deepagents/src/middleware/summarization/mod.rs b/rust-research-agent/rig-deepagents/src/middleware/summarization/mod.rs index 76f0937..f61270a 100644 --- a/rust-research-agent/rig-deepagents/src/middleware/summarization/mod.rs +++ b/rust-research-agent/rig-deepagents/src/middleware/summarization/mod.rs @@ -53,9 +53,10 @@ use tracing::{debug, info, warn}; use crate::error::MiddlewareError; use crate::llm::LLMProvider; -use crate::middleware::traits::{AgentMiddleware, DynTool, StateUpdate}; +use crate::middleware::traits::{AgentMiddleware, DynTool, ModelControl, ModelRequest}; use crate::runtime::ToolRuntime; use crate::state::{AgentState, Message, Role}; +use crate::tokenization::{ApproxTokenCounter, TokenCounter}; /// Summarization Middleware for token budget management. 
/// @@ -67,6 +68,7 @@ pub struct SummarizationMiddleware { llm_provider: Arc, /// Configuration config: SummarizationConfig, + token_counter: Arc, } impl SummarizationMiddleware { @@ -77,7 +79,27 @@ impl SummarizationMiddleware { /// * `llm_provider` - LLM provider for generating summaries /// * `config` - Configuration for triggers, keep size, and prompts pub fn new(llm_provider: Arc, config: SummarizationConfig) -> Self { - Self { llm_provider, config } + let token_counter = Arc::new(ApproxTokenCounter::new( + config.chars_per_token, + config.overhead_per_message as usize, + )); + Self { + llm_provider, + config, + token_counter, + } + } + + pub fn with_token_counter( + llm_provider: Arc, + config: SummarizationConfig, + token_counter: Arc, + ) -> Self { + Self { + llm_provider, + config, + token_counter, + } } /// Create with default configuration. @@ -92,11 +114,7 @@ impl SummarizationMiddleware { /// Count tokens in the current messages. fn count_tokens(&self, messages: &[Message]) -> usize { - count_tokens_approximately( - messages, - self.config.chars_per_token, - self.config.overhead_per_message, - ) + self.token_counter.count_messages(messages) } /// Check if summarization should be triggered. 
@@ -152,11 +170,7 @@ impl SummarizationMiddleware { let mut count = 0; for msg in messages.iter().rev() { - let msg_tokens = count_tokens_approximately( - std::slice::from_ref(msg), - self.config.chars_per_token, - self.config.overhead_per_message, - ); + let msg_tokens = self.token_counter.count_message(msg); if total_tokens + msg_tokens > token_budget { break; @@ -180,9 +194,8 @@ impl SummarizationMiddleware { let mut cutoff = initial_cutoff; - // Advance past Tool messages (they should stay with their AI message) - while cutoff < messages.len() && messages[cutoff].role == Role::Tool { - cutoff += 1; + while cutoff > 0 && messages[cutoff].role == Role::Tool { + cutoff -= 1; } cutoff @@ -233,11 +246,7 @@ impl SummarizationMiddleware { // Take messages from the end (most recent first), respecting token budget for msg in messages.iter().rev() { - let msg_tokens = count_tokens_approximately( - std::slice::from_ref(msg), - self.config.chars_per_token, - self.config.overhead_per_message, - ); + let msg_tokens = self.token_counter.count_message(msg); if total_tokens + msg_tokens > max_tokens { break; @@ -297,11 +306,12 @@ impl AgentMiddleware for SummarizationMiddleware { prompt } - async fn after_agent( + async fn before_model( &self, + request: &mut ModelRequest, state: &mut AgentState, _runtime: &ToolRuntime, - ) -> Result, MiddlewareError> { + ) -> Result { let token_count = self.count_tokens(&state.messages); let message_count = state.messages.len(); @@ -314,7 +324,7 @@ impl AgentMiddleware for SummarizationMiddleware { // Check if we should summarize if !self.should_summarize(token_count, message_count) { - return Ok(None); + return Ok(ModelControl::Continue); } info!( @@ -328,7 +338,7 @@ impl AgentMiddleware for SummarizationMiddleware { if to_summarize.is_empty() { debug!("No messages to summarize"); - return Ok(None); + return Ok(ModelControl::Continue); } debug!( @@ -342,7 +352,7 @@ impl AgentMiddleware for SummarizationMiddleware { Ok(s) => s, Err(e) => { 
warn!(error = %e, "Failed to generate summary, keeping original messages"); - return Ok(None); + return Ok(ModelControl::Continue); } }; @@ -362,7 +372,10 @@ impl AgentMiddleware for SummarizationMiddleware { "Summarization complete" ); - Ok(Some(StateUpdate::SetMessages(new_messages))) + state.messages = new_messages.clone(); + request.messages = new_messages.clone(); + + Ok(ModelControl::ModifyRequest(request.clone())) } } @@ -379,6 +392,8 @@ impl std::fmt::Debug for SummarizationMiddleware { mod tests { use super::*; use crate::llm::{LLMConfig, LLMResponse}; + use crate::middleware::{ModelControl, ModelRequest}; + use crate::runtime::ToolRuntime; /// Mock LLM provider for testing struct MockProvider { @@ -449,10 +464,10 @@ mod tests { } #[test] - fn test_safe_cutoff_skips_tool_messages() { + fn test_safe_cutoff_moves_backward_for_tool_messages() { let provider = Arc::new(MockProvider::new("Summary")); let config = SummarizationConfig::builder() - .keep(KeepSize::Messages(2)) + .keep(KeepSize::Messages(3)) .build(); let middleware = SummarizationMiddleware::new(provider, config); @@ -465,26 +480,16 @@ mod tests { arguments: serde_json::json!({"path": "/test"}), } ]), - Message::tool("File contents", "call_1"), // Tool result + Message::tool("File contents", "call_1"), Message::assistant("Here's what I found"), Message::user("Thanks"), ]; let (to_summarize, preserved) = middleware.partition_messages(&messages); - // Should not split in the middle of AI/Tool pair - // If cutoff lands on Tool message, it advances past it - for msg in &to_summarize { - if msg.role == Role::Tool { - // Check there's an AI message before it in to_summarize - let has_ai_before = to_summarize.iter().any(|m| m.role == Role::Assistant); - assert!(has_ai_before, "Tool message should have AI message before it"); - } - } - - // Preserved messages should include the recent context - assert!(!preserved.is_empty(), "Should have preserved messages"); - // Combined should equal original + 
assert_eq!(to_summarize.len(), 1); + assert_eq!(preserved.len(), 4); + assert!(preserved[0].tool_calls.is_some()); assert_eq!( to_summarize.len() + preserved.len(), messages.len(), @@ -492,6 +497,39 @@ mod tests { ); } + #[tokio::test] + async fn test_before_model_summarizes_request_messages() { + let provider = Arc::new(MockProvider::new("Summary text")); + let config = SummarizationConfig::builder() + .trigger(TriggerCondition::Messages(2)) + .keep(KeepSize::Messages(1)) + .build(); + let middleware = SummarizationMiddleware::new(provider, config); + + let mut state = AgentState::with_messages(vec![ + Message::user("First"), + Message::assistant("Second"), + Message::user("Third"), + ]); + + let mut request = ModelRequest::new(state.messages.clone(), vec![]); + let backend = Arc::new(crate::backends::MemoryBackend::new()); + let runtime = ToolRuntime::new(state.clone(), backend); + + let control = middleware + .before_model(&mut request, &mut state, &runtime) + .await + .unwrap(); + + assert!(matches!(control, ModelControl::ModifyRequest(_))); + assert_eq!(request.messages.len(), state.messages.len()); + assert_eq!(request.messages[0].role, state.messages[0].role); + assert_eq!(request.messages[0].content, state.messages[0].content); + assert_eq!(request.messages[1].content, state.messages[1].content); + assert_eq!(state.messages.len(), 2); + assert!(state.messages[0].content.contains("Summary text")); + } + #[test] fn test_format_messages() { let provider = Arc::new(MockProvider::new("Summary")); diff --git a/rust-research-agent/rig-deepagents/src/middleware/todo_list.rs b/rust-research-agent/rig-deepagents/src/middleware/todo_list.rs index 420ebcf..522f661 100644 --- a/rust-research-agent/rig-deepagents/src/middleware/todo_list.rs +++ b/rust-research-agent/rig-deepagents/src/middleware/todo_list.rs @@ -7,7 +7,7 @@ use std::sync::Arc; use async_trait::async_trait; use crate::middleware::{AgentMiddleware, DynTool}; -use crate::tools::WriteTodosTool; +use 
crate::tools::{ReadTodosTool, WriteTodosTool}; /// Default system prompt for todo planning. pub const TODO_SYSTEM_PROMPT: &str = "## Planning with `write_todos`\n\ @@ -17,7 +17,7 @@ Update the list as you work: mark items in_progress before starting and complete /// Middleware that injects the write_todos tool and planning guidance. pub struct TodoListMiddleware { - tool: DynTool, + tools: Vec, system_prompt: String, } @@ -30,7 +30,7 @@ impl TodoListMiddleware { /// Create a TodoListMiddleware with a custom system prompt. pub fn with_system_prompt(prompt: impl Into) -> Self { Self { - tool: Arc::new(WriteTodosTool), + tools: vec![Arc::new(ReadTodosTool), Arc::new(WriteTodosTool)], system_prompt: prompt.into(), } } @@ -43,7 +43,7 @@ impl AgentMiddleware for TodoListMiddleware { } fn tools(&self) -> Vec { - vec![self.tool.clone()] + self.tools.clone() } fn modify_system_prompt(&self, prompt: String) -> String { @@ -63,8 +63,9 @@ mod tests { fn test_todo_list_injects_tool() { let middleware = TodoListMiddleware::new(); let tools = middleware.tools(); - assert_eq!(tools.len(), 1); - assert_eq!(tools[0].definition().name, "write_todos"); + assert_eq!(tools.len(), 2); + assert_eq!(tools[0].definition().name, "read_todos"); + assert_eq!(tools[1].definition().name, "write_todos"); } #[test] diff --git a/rust-research-agent/rig-deepagents/src/middleware/traits.rs b/rust-research-agent/rig-deepagents/src/middleware/traits.rs index 144cb3f..01f5d97 100644 --- a/rust-research-agent/rig-deepagents/src/middleware/traits.rs +++ b/rust-research-agent/rig-deepagents/src/middleware/traits.rs @@ -466,7 +466,7 @@ pub trait AgentMiddleware: Send + Sync { async fn before_model( &self, _request: &mut ModelRequest, - _state: &AgentState, + _state: &mut AgentState, _runtime: &ToolRuntime, ) -> Result { Ok(ModelControl::Continue) diff --git a/rust-research-agent/rig-deepagents/src/skills/loader.rs b/rust-research-agent/rig-deepagents/src/skills/loader.rs index 350de7e..a6bf4d3 100644 --- 
a/rust-research-agent/rig-deepagents/src/skills/loader.rs +++ b/rust-research-agent/rig-deepagents/src/skills/loader.rs @@ -15,32 +15,40 @@ use tokio::sync::RwLock; use tracing::{debug, warn}; use super::types::{SkillContent, SkillMetadata, SkillSource}; +use crate::backends::Backend; use crate::error::MiddlewareError; -/// Type alias for metadata cache entry (metadata, file path, source) type MetadataCacheEntry = (SkillMetadata, PathBuf, SkillSource); -/// Skill loader with caching support +pub enum SkillStorage { + Filesystem { + user_dir: Option, + project_dir: Option, + }, + Backend { + backend: Arc, + sources: Vec, + }, +} + pub struct SkillLoader { - /// User skills directory (e.g., ~/.claude/skills) - user_dir: Option, - - /// Project skills directory (e.g., ./skills) - project_dir: Option, - - /// Cached metadata (loaded eagerly on init) + storage: SkillStorage, metadata_cache: Arc>>, - - /// Cached full content (loaded lazily on demand) content_cache: Arc>>, } impl SkillLoader { - /// Create a new skill loader with specified directories pub fn new(user_dir: Option, project_dir: Option) -> Self { Self { - user_dir, - project_dir, + storage: SkillStorage::Filesystem { user_dir, project_dir }, + metadata_cache: Arc::new(RwLock::new(HashMap::new())), + content_cache: Arc::new(RwLock::new(HashMap::new())), + } + } + + pub fn from_backend(backend: Arc, sources: Vec) -> Self { + Self { + storage: SkillStorage::Backend { backend, sources }, metadata_cache: Arc::new(RwLock::new(HashMap::new())), content_cache: Arc::new(RwLock::new(HashMap::new())), } @@ -70,18 +78,24 @@ impl SkillLoader { let mut cache = self.metadata_cache.write().await; cache.clear(); - // Scan user skills first (lower priority) - if let Some(user_dir) = &self.user_dir { - if user_dir.exists() { - self.scan_directory(user_dir, SkillSource::User, &mut cache) - .await?; - } - } + match &self.storage { + SkillStorage::Filesystem { user_dir, project_dir } => { + if let Some(user_dir) = user_dir { + 
if user_dir.exists() { + self.scan_directory(user_dir, SkillSource::User, &mut cache) + .await?; + } + } - // Scan project skills (higher priority, can override user skills) - if let Some(project_dir) = &self.project_dir { - if project_dir.exists() { - self.scan_directory(project_dir, SkillSource::Project, &mut cache) + if let Some(project_dir) = project_dir { + if project_dir.exists() { + self.scan_directory(project_dir, SkillSource::Project, &mut cache) + .await?; + } + } + } + SkillStorage::Backend { backend, sources } => { + self.scan_backend_sources(backend, sources, &mut cache) .await?; } } @@ -136,6 +150,55 @@ impl SkillLoader { Ok(()) } + async fn scan_backend_sources( + &self, + backend: &Arc, + sources: &[String], + cache: &mut HashMap, + ) -> Result<(), MiddlewareError> { + for source in sources { + let entries = match backend.ls(source).await { + Ok(entries) => entries, + Err(e) => { + warn!("Failed to list backend source {}: {}", source, e); + continue; + } + }; + + for entry in entries { + if !entry.is_dir { + continue; + } + + let skill_file = format!("{}/SKILL.md", entry.path.trim_end_matches('/')); + match backend.read_plain(&skill_file).await { + Ok(content) => match parse_frontmatter(&content) { + Ok(skill_meta) => { + debug!( + "Loaded skill metadata: {} from {} ({})", + skill_meta.name, + skill_file, + SkillSource::Backend.as_str() + ); + cache.insert( + skill_meta.name.clone(), + (skill_meta, PathBuf::from(&skill_file), SkillSource::Backend), + ); + } + Err(e) => { + warn!("Failed to parse skill {}: {}", skill_file, e); + } + }, + Err(e) => { + warn!("Failed to read skill {}: {}", skill_file, e); + } + } + } + } + + Ok(()) + } + /// Parse only metadata from YAML frontmatter (fast) async fn parse_metadata(&self, path: &Path) -> Result { let content = tokio::fs::read_to_string(path) @@ -191,10 +254,18 @@ impl SkillLoader { } }; - // Load and parse full content - let raw_content = tokio::fs::read_to_string(&path) - .await - .map_err(|e| 
MiddlewareError::ToolExecution(format!("Failed to read skill: {}", e)))?; + let raw_content = match &self.storage { + SkillStorage::Filesystem { .. } => tokio::fs::read_to_string(&path) + .await + .map_err(|e| MiddlewareError::ToolExecution(format!("Failed to read skill: {}", e)))?, + SkillStorage::Backend { backend, .. } => { + let path_str = path.to_string_lossy(); + backend + .read_plain(&path_str) + .await + .map_err(|e| MiddlewareError::ToolExecution(format!("Failed to read skill: {}", e)))? + } + }; let body = parse_body(&raw_content); let content = SkillContent::new(metadata, body, path.to_string_lossy().to_string()); @@ -304,6 +375,7 @@ fn parse_body(content: &str) -> String { #[cfg(test)] mod tests { use super::*; + use crate::backends::{Backend, MemoryBackend}; #[test] fn test_parse_frontmatter_valid() { @@ -467,4 +539,69 @@ This is the test skill body. let result = loader.load_skill("nonexistent").await; assert!(result.is_err()); } + + #[tokio::test] + async fn test_backend_loader_layering() { + let backend: Arc = Arc::new(MemoryBackend::new()); + + backend + .write( + "/source-a/shared/SKILL.md", + r#"--- +name: shared +description: First description +--- +# Shared Skill + +First body. +"#, + ) + .await + .unwrap(); + + backend + .write( + "/source-b/shared/SKILL.md", + r#"--- +name: shared +description: Second description +--- +# Shared Skill + +Second body. +"#, + ) + .await + .unwrap(); + + backend + .write( + "/source-b/unique/SKILL.md", + r#"--- +name: unique +description: Unique description +--- +# Unique Skill + +Unique body. 
+"#, + ) + .await + .unwrap(); + + let loader = SkillLoader::from_backend( + Arc::clone(&backend), + vec!["/source-a".to_string(), "/source-b".to_string()], + ); + loader.initialize().await.unwrap(); + + let metadata = loader.get_metadata("shared").await.unwrap(); + assert_eq!(metadata.description, "Second description"); + + let content = loader.load_skill("shared").await.unwrap(); + assert!(content.body.contains("Second body")); + + let unique = loader.get_metadata("unique").await.unwrap(); + assert_eq!(unique.description, "Unique description"); + } } diff --git a/rust-research-agent/rig-deepagents/src/skills/middleware.rs b/rust-research-agent/rig-deepagents/src/skills/middleware.rs index 2b5b525..0a3d952 100644 --- a/rust-research-agent/rig-deepagents/src/skills/middleware.rs +++ b/rust-research-agent/rig-deepagents/src/skills/middleware.rs @@ -11,8 +11,9 @@ use tokio::sync::RwLock; use super::loader::SkillLoader; use super::types::{SkillMetadata, SkillSource}; use crate::error::MiddlewareError; -use crate::middleware::{AgentMiddleware, DynTool, Tool, ToolDefinition, ToolResult}; +use crate::middleware::{AgentMiddleware, DynTool, Tool, ToolDefinition, ToolResult, StateUpdate}; use crate::runtime::ToolRuntime; +use crate::state::AgentState; /// Skills middleware for progressive skill disclosure /// @@ -126,6 +127,16 @@ impl AgentMiddleware for SkillsMiddleware { None => prompt, } } + + async fn before_agent( + &self, + _state: &mut AgentState, + _runtime: &ToolRuntime, + ) -> Result, MiddlewareError> { + self.loader.initialize().await?; + self.refresh_cache().await; + Ok(None) + } } /// Tool for loading skill content on-demand diff --git a/rust-research-agent/rig-deepagents/src/skills/types.rs b/rust-research-agent/rig-deepagents/src/skills/types.rs index d276bb4..2897c44 100644 --- a/rust-research-agent/rig-deepagents/src/skills/types.rs +++ b/rust-research-agent/rig-deepagents/src/skills/types.rs @@ -93,6 +93,7 @@ pub enum SkillSource { User, /// Project-level 
skills (PROJECT_ROOT/skills/) Project, + Backend, } impl SkillSource { @@ -101,6 +102,7 @@ match self { Self::User => "user", Self::Project => "project", + Self::Backend => "backend", } } } @@ -186,5 +188,6 @@ description: Minimal skill fn test_skill_source() { assert_eq!(SkillSource::User.as_str(), "user"); assert_eq!(SkillSource::Project.as_str(), "project"); + assert_eq!(SkillSource::Backend.as_str(), "backend"); } } diff --git a/rust-research-agent/rig-deepagents/src/state.rs b/rust-research-agent/rig-deepagents/src/state.rs index 5fde8cd..a1f4789 100644 --- a/rust-research-agent/rig-deepagents/src/state.rs +++ b/rust-research-agent/rig-deepagents/src/state.rs @@ -105,15 +105,29 @@ pub struct Message { pub tool_call_id: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option<Vec<ToolCall>>, + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option<String>, } impl Message { pub fn user(content: &str) -> Self { - Self { role: Role::User, content: content.to_string(), tool_call_id: None, tool_calls: None } + Self { + role: Role::User, + content: content.to_string(), + tool_call_id: None, + tool_calls: None, + status: None, + } } pub fn assistant(content: &str) -> Self { - Self { role: Role::Assistant, content: content.to_string(), tool_call_id: None, tool_calls: None } + Self { + role: Role::Assistant, + content: content.to_string(), + tool_call_id: None, + tool_calls: None, + status: None, + } } pub fn assistant_with_tool_calls(content: &str, tool_calls: Vec<ToolCall>) -> Self { @@ -121,12 +135,19 @@ impl Message { role: Role::Assistant, content: content.to_string(), tool_call_id: None, - tool_calls: Some(tool_calls) + tool_calls: Some(tool_calls), + status: None, } } pub fn system(content: &str) -> Self { - Self { role: Role::System, content: content.to_string(), tool_call_id: None, tool_calls: None } + Self { + role: Role::System, + content: content.to_string(), + tool_call_id: None, + tool_calls: None, + status: None, + } } pub fn 
tool(content: &str, tool_call_id: &str) -> Self { @@ -135,6 +156,17 @@ impl Message { content: content.to_string(), tool_call_id: Some(tool_call_id.to_string()), tool_calls: None, + status: None, + } + } + + pub fn tool_with_status(content: &str, tool_call_id: &str, status: &str) -> Self { + Self { + role: Role::Tool, + content: content.to_string(), + tool_call_id: Some(tool_call_id.to_string()), + tool_calls: None, + status: Some(status.to_string()), } } diff --git a/rust-research-agent/rig-deepagents/src/tokenization/mod.rs b/rust-research-agent/rig-deepagents/src/tokenization/mod.rs new file mode 100644 index 0000000..dfcfd7b --- /dev/null +++ b/rust-research-agent/rig-deepagents/src/tokenization/mod.rs @@ -0,0 +1,139 @@ +use crate::middleware::summarization::token_counter::{ + count_tokens_approximately, DEFAULT_CHARS_PER_TOKEN, DEFAULT_OVERHEAD_PER_MESSAGE, +}; +use crate::state::Message; +#[cfg(feature = "tokenizer-tiktoken")] +use crate::state::Role; + +pub trait TokenCounter: Send + Sync { + fn count_text(&self, text: &str) -> usize; + fn count_message(&self, message: &Message) -> usize; + fn count_messages(&self, messages: &[Message]) -> usize { + messages.iter().map(|msg| self.count_message(msg)).sum() + } +} + +#[derive(Debug, Clone)] +pub struct ApproxTokenCounter { + pub chars_per_token: f32, + pub overhead_per_message: usize, +} + +impl ApproxTokenCounter { + pub fn new(chars_per_token: f32, overhead_per_message: usize) -> Self { + Self { + chars_per_token, + overhead_per_message, + } + } +} + +impl Default for ApproxTokenCounter { + fn default() -> Self { + Self { + chars_per_token: DEFAULT_CHARS_PER_TOKEN, + overhead_per_message: DEFAULT_OVERHEAD_PER_MESSAGE as usize, + } + } +} + +impl TokenCounter for ApproxTokenCounter { + fn count_text(&self, text: &str) -> usize { + (text.len() as f32 / self.chars_per_token).ceil() as usize + } + + fn count_message(&self, message: &Message) -> usize { + count_tokens_approximately( + 
std::slice::from_ref(message), + self.chars_per_token, + self.overhead_per_message as f32, + ) + } +} + +#[cfg(feature = "tokenizer-tiktoken")] +#[derive(Debug, Clone)] +pub struct TiktokenTokenCounter { + encoder: tiktoken_rs::CoreBPE, +} + +#[cfg(feature = "tokenizer-tiktoken")] +impl TiktokenTokenCounter { + pub fn new(encoder: tiktoken_rs::CoreBPE) -> Self { + Self { encoder } + } + + pub fn cl100k_base() -> Result { + Ok(Self { + encoder: tiktoken_rs::cl100k_base()?, + }) + } +} + +#[cfg(feature = "tokenizer-tiktoken")] +impl TokenCounter for TiktokenTokenCounter { + fn count_text(&self, text: &str) -> usize { + self.encoder.encode_with_special_tokens(text).len() + } + + fn count_message(&self, message: &Message) -> usize { + let message_text = build_message_text(message); + self.count_text(&message_text) + } +} + +#[cfg(feature = "tokenizer-tiktoken")] +fn role_name(role: &Role) -> &'static str { + match role { + Role::User => "user", + Role::Assistant => "assistant", + Role::System => "system", + Role::Tool => "tool", + } +} + +#[cfg(feature = "tokenizer-tiktoken")] +fn build_message_text(message: &Message) -> String { + let mut text = String::new(); + text.push_str(&message.content); + text.push_str(role_name(&message.role)); + + if let Some(ref tool_call_id) = message.tool_call_id { + text.push_str(tool_call_id); + } + + if let Some(ref tool_calls) = message.tool_calls { + for tc in tool_calls { + text.push_str(&tc.id); + text.push_str(&tc.name); + if let Ok(args_str) = serde_json::to_string(&tc.arguments) { + text.push_str(&args_str); + } + } + } + + text +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::state::Message; + + #[test] + fn test_approx_counter_counts_non_zero() { + let counter = ApproxTokenCounter::new(4.0, 3); + let messages = vec![Message::user("Hello there")]; + assert!(counter.count_messages(&messages) > 0); + assert!(counter.count_text("Hello there") > 0); + } + + #[cfg(feature = "tokenizer-tiktoken")] + #[test] + fn 
test_tiktoken_counter_counts_non_zero() { + let counter = TiktokenTokenCounter::cl100k_base().unwrap(); + let messages = vec![Message::assistant("Hello there")]; + assert!(counter.count_messages(&messages) > 0); + assert!(counter.count_text("Hello there") > 0); + } +} diff --git a/rust-research-agent/rig-deepagents/src/tools/mod.rs b/rust-research-agent/rig-deepagents/src/tools/mod.rs index d01e7a3..749fab0 100644 --- a/rust-research-agent/rig-deepagents/src/tools/mod.rs +++ b/rust-research-agent/rig-deepagents/src/tools/mod.rs @@ -17,6 +17,7 @@ mod edit_file; mod ls; mod glob; mod grep; +mod read_todos; mod write_todos; mod task; @@ -30,6 +31,7 @@ pub use edit_file::EditFileTool; pub use ls::LsTool; pub use glob::GlobTool; pub use grep::GrepTool; +pub use read_todos::ReadTodosTool; pub use write_todos::WriteTodosTool; pub use task::TaskTool; @@ -49,6 +51,7 @@ pub fn default_tools() -> Vec { Arc::new(LsTool), Arc::new(GlobTool), Arc::new(GrepTool), + Arc::new(ReadTodosTool), Arc::new(WriteTodosTool), ] } diff --git a/rust-research-agent/rig-deepagents/src/tools/read_todos.rs b/rust-research-agent/rig-deepagents/src/tools/read_todos.rs new file mode 100644 index 0000000..e7547fd --- /dev/null +++ b/rust-research-agent/rig-deepagents/src/tools/read_todos.rs @@ -0,0 +1,61 @@ +use async_trait::async_trait; + +use crate::error::MiddlewareError; +use crate::middleware::{Tool, ToolDefinition, ToolResult}; +use crate::runtime::ToolRuntime; + +pub struct ReadTodosTool; + +#[async_trait] +impl Tool for ReadTodosTool { + fn definition(&self) -> ToolDefinition { + ToolDefinition { + name: "read_todos".to_string(), + description: "Read the current todo list state.".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": {}, + }), + } + } + + async fn execute( + &self, + _args: serde_json::Value, + runtime: &ToolRuntime, + ) -> Result { + let todos = &runtime.state().todos; + let json = serde_json::to_string(todos) + .map_err(|e| 
MiddlewareError::ToolExecution(format!("Failed to serialize todos: {e}")))?; + Ok(ToolResult::new(json)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::backends::MemoryBackend; + use crate::state::{AgentState, Todo, TodoStatus}; + use std::sync::Arc; + + #[tokio::test] + async fn test_read_todos_returns_state() { + let tool = ReadTodosTool; + let backend = Arc::new(MemoryBackend::new()); + let mut state = AgentState::new(); + state.todos = vec![ + Todo::with_status("First", TodoStatus::Pending), + Todo::with_status("Second", TodoStatus::Completed), + ]; + let runtime = ToolRuntime::new(state, backend); + + let result = tool.execute(serde_json::json!({}), &runtime).await.unwrap(); + let todos: Vec = serde_json::from_str(&result.message).unwrap(); + + assert_eq!(todos.len(), 2); + assert_eq!(todos[0].content, "First"); + assert_eq!(todos[0].status, TodoStatus::Pending); + assert_eq!(todos[1].content, "Second"); + assert_eq!(todos[1].status, TodoStatus::Completed); + } +} diff --git a/rust-research-agent/rig-deepagents/src/workflow/vertices/agent.rs b/rust-research-agent/rig-deepagents/src/workflow/vertices/agent.rs index 73c4c3d..bc9e8bb 100644 --- a/rust-research-agent/rig-deepagents/src/workflow/vertices/agent.rs +++ b/rust-research-agent/rig-deepagents/src/workflow/vertices/agent.rs @@ -228,6 +228,7 @@ impl Vertex for AgentVe content: self.config.system_prompt.clone(), tool_calls: None, tool_call_id: None, + status: None, }]; // Add any incoming workflow messages as user messages @@ -238,6 +239,7 @@ impl Vertex for AgentVe content: value.to_string(), tool_calls: None, tool_call_id: None, + status: None, }); } } @@ -249,6 +251,7 @@ impl Vertex for AgentVe content: "Begin processing.".to_string(), tool_calls: None, tool_call_id: None, + status: None, }); } @@ -354,6 +357,7 @@ mod tests { content: content.into(), tool_calls: None, tool_call_id: None, + status: None, }; self.responses.lock().unwrap().push(message); self @@ -369,6 +373,7 @@ mod tests 
{ arguments: serde_json::json!({}), }]), tool_call_id: None, + status: None, }; self.responses.lock().unwrap().push(message); self @@ -563,6 +568,7 @@ mod tests { content: "Done".to_string(), tool_calls: Some(vec![]), tool_call_id: None, + status: None, }; // State with non-matching phase diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/backends/conftest.py b/tests/backends/conftest.py new file mode 100644 index 0000000..e284ef3 --- /dev/null +++ b/tests/backends/conftest.py @@ -0,0 +1,14 @@ +import sys +from pathlib import Path + +try: + import docker # noqa: F401 +except ImportError: + pass + +_REPO_ROOT = Path(__file__).resolve().parents[2] +_VENDORED_DEEPAGENTS = _REPO_ROOT / "deepagents_sourcecode" / "libs" / "deepagents" +if _VENDORED_DEEPAGENTS.exists() and str(_VENDORED_DEEPAGENTS) not in sys.path: + sys.path.insert(0, str(_VENDORED_DEEPAGENTS)) +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) diff --git a/tests/backends/test_docker_sandbox_integration.py b/tests/backends/test_docker_sandbox_integration.py new file mode 100644 index 0000000..f6286ac --- /dev/null +++ b/tests/backends/test_docker_sandbox_integration.py @@ -0,0 +1,496 @@ +"""DockerSandboxBackend/DockerSandboxSession 실환경 통합 테스트. + +주의: 이 테스트는 실제 Docker 데몬과 컨테이너 실행이 필요합니다. 
+""" + +from __future__ import annotations + +import time +from collections.abc import Iterator + +import pytest +from langchain_core.messages import ToolMessage + +from context_engineering_research_agent.backends.docker_sandbox import ( + DockerSandboxBackend, +) +from context_engineering_research_agent.backends.docker_session import ( + DockerSandboxSession, +) +from context_engineering_research_agent.context_strategies.offloading import ( + ContextOffloadingStrategy, + OffloadingConfig, +) + + +def _docker_available() -> bool: + """Docker 사용 가능 여부를 확인합니다.""" + try: + import docker + + client = docker.from_env() + # ping은 Docker 데몬 연결 여부를 가장 빠르게 확인합니다. + client.ping() + return True + except Exception as e: + print(f"DEBUG: Docker not available: {type(e).__name__}: {e}") + return False + + +DOCKER_AVAILABLE = _docker_available() + +pytestmark = [ + pytest.mark.integration, + pytest.mark.skipif( + not DOCKER_AVAILABLE, + reason="Docker 데몬 또는 python docker SDK를 사용할 수 없습니다.", + ), +] + + +@pytest.fixture(scope="module") +def docker_backend() -> Iterator[DockerSandboxBackend]: + """테스트용 Docker 샌드박스 백엔드를 제공합니다.""" + session = DockerSandboxSession() + try: + import asyncio + + asyncio.run(session.start()) + backend = session.get_backend() + yield backend + finally: + import asyncio + + asyncio.run(session.stop()) + + +@pytest.fixture(autouse=True) +def _reset_workspace(docker_backend: DockerSandboxBackend) -> None: + """각 테스트마다 워크스페이스 내 테스트 디렉토리를 초기화합니다.""" + rm_result = docker_backend.execute("rm -rf test_docker_sandbox") + assert rm_result.exit_code == 0, f"rm failed: {rm_result.output}" + mkdir_result = docker_backend.execute("mkdir -p test_docker_sandbox") + assert mkdir_result.exit_code == 0, f"mkdir failed: {mkdir_result.output}" + + +# --------------------------------------------------------------------------- +# 1) Code Execution Tests +# --------------------------------------------------------------------------- + + +def 
test_execute_basic_commands(docker_backend: DockerSandboxBackend) -> None: + """기본 명령( echo/ls/pwd )이 컨테이너 내부에서 정상 동작하는지 확인합니다.""" + echo = docker_backend.execute("echo 'hello'") + assert echo.exit_code == 0 + assert echo.output.strip() == "hello" + + pwd = docker_backend.execute("pwd") + assert pwd.exit_code == 0 + # DockerSandboxBackend는 workdir=/workspace로 실행합니다. + assert pwd.output.strip() == "/workspace" + + docker_backend.execute("echo 'x' > test_docker_sandbox/file.txt") + ls = docker_backend.execute("ls -la test_docker_sandbox") + assert ls.exit_code == 0 + assert "file.txt" in ls.output + + +def test_execute_python_and_exit_codes(docker_backend: DockerSandboxBackend) -> None: + """파이썬 실행 및 exit code(성공/실패)가 정확히 전달되는지 확인합니다.""" + py = docker_backend.execute('python3 -c "print(2 + 2)"') + assert py.exit_code == 0 + assert py.output.strip() == "4" + + fail = docker_backend.execute('python3 -c "import sys; sys.exit(42)"') + assert fail.exit_code == 42 + + +def test_execute_truncates_large_output(docker_backend: DockerSandboxBackend) -> None: + """대용량 출력이 100,000자 기준으로 잘리는지(truncated) 확인합니다.""" + # 110k 이상 출력 생성 + big = docker_backend.execute("python3 -c \"print('x' * 110500)\"") + assert big.exit_code == 0 + assert big.truncated is True + assert "[출력이 잘렸습니다" in big.output + assert len(big.output) <= 100000 + 200 # 안내 문구 포함 여유 + + +def test_execute_timeout_handling_via_alarm( + docker_backend: DockerSandboxBackend, +) -> None: + """장시간 작업이 자체 타임아웃(알람)으로 빠르게 종료되는지 확인합니다.""" + start = time.monotonic() + res = docker_backend.execute( + 'python3 -c "\n' + "import signal, time\n" + "def _handler(signum, frame):\n" + " raise TimeoutError('alarm')\n" + "signal.signal(signal.SIGALRM, _handler)\n" + "signal.alarm(1)\n" + "time.sleep(10)\n" + '"' + ) + elapsed = time.monotonic() - start + + assert elapsed < 5, f"예상보다 오래 걸렸습니다: {elapsed:.2f}s" + assert res.exit_code != 0 + assert "TimeoutError" in res.output or "alarm" in res.output + + +# 
--------------------------------------------------------------------------- +# 2) File Operations Tests +# --------------------------------------------------------------------------- + + +def test_upload_files_single_and_nested(docker_backend: DockerSandboxBackend) -> None: + """upload_files가 단일 파일 및 중첩 디렉토리 업로드를 지원하는지 확인합니다.""" + files = [ + ("test_docker_sandbox/one.txt", b"one"), + ("test_docker_sandbox/nested/dir/two.bin", b"\x00\x01\x02"), + ] + responses = docker_backend.upload_files(files) + + assert [r.path for r in responses] == [p for p, _ in files] + assert all(r.error is None for r in responses) + + # 컨테이너 내 파일 존재/내용 확인 + cat_one = docker_backend.execute("cat test_docker_sandbox/one.txt") + assert cat_one.exit_code == 0 + assert cat_one.output.strip() == "one" + + # 이진 파일은 다운로드로 검증 + dl = docker_backend.download_files(["test_docker_sandbox/nested/dir/two.bin"]) + assert dl[0].error is None + assert dl[0].content == b"\x00\x01\x02" + + +def test_upload_download_multiple_roundtrip( + docker_backend: DockerSandboxBackend, +) -> None: + """여러 파일을 업로드한 뒤 다운로드하여 내용이 동일한지 확인합니다.""" + files = [ + ("test_docker_sandbox/a.txt", b"A"), + ("test_docker_sandbox/b.txt", b"B"), + ("test_docker_sandbox/sub/c.txt", b"C"), + ] + + up = docker_backend.upload_files(files) + assert len(up) == 3 + assert all(r.error is None for r in up) + + paths = [p for p, _ in files] + dl = docker_backend.download_files(paths) + assert [r.path for r in dl] == paths + assert all(r.error is None for r in dl) + + got = {r.path: r.content for r in dl} + expected = {p: c for p, c in files} + assert got == expected + + +def test_download_files_nonexistent_and_directory( + docker_backend: DockerSandboxBackend, +) -> None: + """download_files가 없는 파일/디렉토리 대상에서 올바른 에러를 반환하는지 확인합니다.""" + docker_backend.execute("mkdir -p test_docker_sandbox/dir_only") + + responses = docker_backend.download_files( + [ + "test_docker_sandbox/does_not_exist.txt", + "test_docker_sandbox/dir_only", + ] + ) + + assert 
responses[0].error == "file_not_found" + assert responses[0].content is None + + assert responses[1].error == "is_directory" + assert responses[1].content is None + + +# --------------------------------------------------------------------------- +# 3) Context Offloading Tests (WITHOUT Agent) +# --------------------------------------------------------------------------- + + +def test_context_offloading_writes_large_tool_result_to_docker_filesystem( + docker_backend: DockerSandboxBackend, +) -> None: + """ContextOffloadingStrategy가 대용량 결과를 Docker 파일시스템에 저장하는지 확인합니다.""" + + # backend_factory는 runtime을 무시하고 현재 DockerSandboxBackend를 반환합니다. + strategy = ContextOffloadingStrategy( + config=OffloadingConfig(token_limit_before_evict=10, chars_per_token=1), + backend_factory=lambda _runtime: docker_backend, + ) + + content_lines = [f"line_{i:03d}: {'x' * 20}" for i in range(50)] + large_content = "\n".join(content_lines) + + tool_result = ToolMessage( + content=large_content, + tool_call_id="call/with:special@chars!", + ) + + class MinimalRuntime: + """ToolRuntime 대체용 최소 객체입니다(backend_factory 호출을 위해서만 사용).""" + + state: dict = {} + config: dict = {} + + processed, offload = strategy.process_tool_result(tool_result, MinimalRuntime()) # type: ignore[arg-type] + + assert offload.was_offloaded is True + assert offload.file_path is not None + assert offload.file_path.startswith("/large_tool_results/") + + # 반환 메시지는 원문 전체가 아니라 경로 참조를 포함해야 합니다. + if hasattr(processed, "content"): + replacement_text = processed.content # ToolMessage + else: + # Command(update={messages:[ToolMessage...]}) 형태 + update = processed.update # type: ignore[attr-defined] + replacement_text = update["messages"][0].content + + assert offload.file_path in replacement_text + assert "read_file" in replacement_text + assert len(replacement_text) < len(large_content) + + # 실제 파일이 컨테이너에 저장되었는지 다운로드로 검증합니다. 
+ downloaded = docker_backend.download_files([offload.file_path]) + assert downloaded[0].error is None + assert downloaded[0].content is not None + assert downloaded[0].content.decode("utf-8") == large_content + + +# --------------------------------------------------------------------------- +# 4) Session Lifecycle Tests +# --------------------------------------------------------------------------- + + +def test_session_initializes_workspace_dirs() -> None: + """세션 시작 시 /workspace/_meta 및 /workspace/shared 디렉토리가 생성되는지 확인합니다.""" + session = DockerSandboxSession() + try: + import asyncio + + asyncio.run(session.start()) + backend = session.get_backend() + + meta = backend.execute("test -d /workspace/_meta && echo ok") + shared = backend.execute("test -d /workspace/shared && echo ok") + assert meta.exit_code == 0 + assert shared.exit_code == 0 + assert meta.output.strip() == "ok" + assert shared.output.strip() == "ok" + finally: + import asyncio + + asyncio.run(session.stop()) + + +def test_multiple_backends_share_same_container_workspace() -> None: + """동일 컨테이너 ID를 사용하는 여러 백엔드가 파일을 공유하는지 확인합니다.""" + try: + import docker + except Exception: + pytest.skip("python docker SDK가 필요합니다") + + session = DockerSandboxSession() + try: + import asyncio + + asyncio.run(session.start()) + backend1 = session.get_backend() + backend2 = DockerSandboxBackend( + container_id=backend1.id, + docker_client=docker.from_env(), + ) + + backend1.execute("mkdir -p test_docker_sandbox") + backend1.write("/workspace/test_docker_sandbox/shared.txt", "shared") + + read_back = backend2.execute("cat /workspace/test_docker_sandbox/shared.txt") + assert read_back.exit_code == 0 + assert read_back.output.strip() == "shared" + finally: + import asyncio + + asyncio.run(session.stop()) + + +def test_session_stop_removes_container() -> None: + """세션 종료 시 컨테이너가 중지/삭제되는지 확인합니다.""" + try: + import docker + except Exception: + pytest.skip("python docker SDK가 필요합니다") + + client = docker.from_env() + session = 
DockerSandboxSession() + + import asyncio + + asyncio.run(session.start()) + backend = session.get_backend() + container_id = backend.id + + # 실제로 컨테이너가 존재하는지 확인 + client.containers.get(container_id) + + asyncio.run(session.stop()) + + # stop()이 swallow하므로, 실제 삭제 여부는 inspect로 확인 + with pytest.raises(Exception): + client.containers.get(container_id) + + +# --------------------------------------------------------------------------- +# 5) Security Verification Tests +# --------------------------------------------------------------------------- + + +def test_container_security_options_applied( + docker_backend: DockerSandboxBackend, +) -> None: + """컨테이너 생성 시 네트워크/권한/메모리 제한 옵션이 적용되는지 확인합니다.""" + try: + import docker + except Exception: + pytest.skip("python docker SDK가 필요합니다") + + client = docker.from_env() + container = client.containers.get(docker_backend.id) + container.reload() + host_cfg = container.attrs.get("HostConfig", {}) + + assert host_cfg.get("NetworkMode") == "none" + + cap_drop = host_cfg.get("CapDrop") or [] + assert "ALL" in cap_drop + + security_opt = host_cfg.get("SecurityOpt") or [] + assert any("no-new-privileges" in opt for opt in security_opt) + + # Docker가 바이트 단위로 변환합니다(512m ≈ 536,870,912 bytes) + memory = host_cfg.get("Memory") + assert memory is not None + assert memory >= 512 * 1024 * 1024 + + pids_limit = host_cfg.get("PidsLimit") + assert pids_limit == 128 + + +def test_network_isolation_blocks_outbound( + docker_backend: DockerSandboxBackend, +) -> None: + """network_mode='none' 설정으로 외부 네트워크 연결이 차단되는지 확인합니다.""" + res = docker_backend.execute( + 'python3 -c "\n' + "import socket\n" + "s = socket.socket()\n" + "s.settimeout(1.0)\n" + "try:\n" + " s.connect(('1.1.1.1', 53))\n" + " print('UNEXPECTED_CONNECTED')\n" + " raise SystemExit(1)\n" + "except OSError as e:\n" + " print('blocked', type(e).__name__)\n" + " raise SystemExit(0)\n" + "finally:\n" + " s.close()\n" + '"' + ) + + assert res.exit_code == 0 + assert "blocked" in res.output + 
assert "UNEXPECTED_CONNECTED" not in res.output + + +# --------------------------------------------------------------------------- +# 6) LLM Output Formatting Tests (코드 실행 결과가 LLM에 전달되는지 검증) +# --------------------------------------------------------------------------- + + +def _format_execute_result_for_llm(result) -> str: + """DeepAgents _execute_tool_generator와 동일한 포맷팅 로직. + + FilesystemMiddleware의 execute tool이 LLM에 반환하는 형식을 재현합니다. + """ + parts = [result.output] + + if result.exit_code is not None: + status = "succeeded" if result.exit_code == 0 else "failed" + parts.append(f"\n[Command {status} with exit code {result.exit_code}]") + + if result.truncated: + parts.append("\n[Output was truncated due to size limits]") + + return "".join(parts) + + +def test_execute_result_formatted_for_llm_success( + docker_backend: DockerSandboxBackend, +) -> None: + """성공한 코드 실행 결과가 LLM이 인지할 수 있는 형태로 포맷팅되는지 확인합니다.""" + result = docker_backend.execute('python3 -c "print(42 * 2)"') + + llm_output = _format_execute_result_for_llm(result) + + assert "84" in llm_output + assert "[Command succeeded with exit code 0]" in llm_output + assert "truncated" not in llm_output.lower() + + +def test_execute_result_formatted_for_llm_failure( + docker_backend: DockerSandboxBackend, +) -> None: + """실패한 코드 실행 결과가 LLM이 인지할 수 있는 형태로 포맷팅되는지 확인합니다.""" + result = docker_backend.execute("python3 -c \"raise ValueError('test error')\"") + + llm_output = _format_execute_result_for_llm(result) + + assert "ValueError" in llm_output + assert "test error" in llm_output + assert "[Command failed with exit code 1]" in llm_output + + +def test_execute_result_formatted_for_llm_multiline_output( + docker_backend: DockerSandboxBackend, +) -> None: + """여러 줄 출력이 LLM에 그대로 전달되는지 확인합니다.""" + result = docker_backend.execute( + "python3 -c \"for i in range(5): print(f'line {i}')\"" + ) + + llm_output = _format_execute_result_for_llm(result) + + for i in range(5): + assert f"line {i}" in llm_output + assert "[Command 
succeeded with exit code 0]" in llm_output + + +def test_execute_result_formatted_for_llm_truncation_notice( + docker_backend: DockerSandboxBackend, +) -> None: + """대용량 출력이 잘릴 때 LLM에 truncation 알림이 포함되는지 확인합니다.""" + result = docker_backend.execute("python3 -c \"print('x' * 110500)\"") + + llm_output = _format_execute_result_for_llm(result) + + assert result.truncated is True + assert "[Output was truncated due to size limits]" in llm_output + assert "[Command succeeded with exit code 0]" in llm_output + + +def test_execute_result_contains_stderr_for_llm( + docker_backend: DockerSandboxBackend, +) -> None: + """stderr 출력이 LLM에 전달되는지 확인합니다.""" + result = docker_backend.execute( + "python3 -c \"import sys; sys.stderr.write('error message\\n')\"" + ) + + llm_output = _format_execute_result_for_llm(result) + + assert "error message" in llm_output diff --git a/tests/context_engineering/__init__.py b/tests/context_engineering/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/context_engineering/test_caching.py b/tests/context_engineering/test_caching.py new file mode 100644 index 0000000..7b8767b --- /dev/null +++ b/tests/context_engineering/test_caching.py @@ -0,0 +1,665 @@ +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from context_engineering_research_agent.context_strategies.caching import ( + CachingConfig, + CachingResult, + ContextCachingStrategy, + OpenRouterSubProvider, + ProviderType, + detect_openrouter_sub_provider, + detect_provider, + requires_cache_control_marker, +) +from context_engineering_research_agent.context_strategies.caching_telemetry import ( + CacheTelemetry, + PromptCachingTelemetryMiddleware, + extract_anthropic_cache_metrics, + extract_cache_telemetry, + extract_deepseek_cache_metrics, + extract_gemini_cache_metrics, + extract_openai_cache_metrics, +) + + +class TestCachingConfig: + def test_default_values(self): + config = 
class TestCachingResult:
    """CachingResult shape for cached vs. not-cached outcomes."""

    def test_not_cached(self):
        res = CachingResult(was_cached=False)

        assert res.was_cached is False
        assert res.cached_content_type is None
        assert res.estimated_tokens_cached == 0

    def test_cached(self):
        res = CachingResult(
            was_cached=True,
            cached_content_type="system_prompt",
            estimated_tokens_cached=5000,
        )

        assert res.was_cached is True
        assert res.cached_content_type == "system_prompt"
        assert res.estimated_tokens_cached == 5000


class TestContextCachingStrategy:
    """Token estimation, cache-control markers and apply_caching behaviour."""

    @pytest.fixture
    def mock_anthropic_model(self) -> MagicMock:
        fake = MagicMock()
        fake.__class__.__name__ = "ChatAnthropic"
        fake.__class__.__module__ = "langchain_anthropic"
        return fake

    @pytest.fixture
    def strategy(self, mock_anthropic_model: MagicMock) -> ContextCachingStrategy:
        return ContextCachingStrategy(model=mock_anthropic_model)

    @pytest.fixture
    def strategy_low_threshold(
        self, mock_anthropic_model: MagicMock
    ) -> ContextCachingStrategy:
        # Tiny threshold so even short prompts qualify for caching.
        return ContextCachingStrategy(
            config=CachingConfig(min_cacheable_tokens=10),
            model=mock_anthropic_model,
        )

    def test_estimate_tokens_string(self, strategy: ContextCachingStrategy):
        # 400 chars at the default 4 chars/token -> 100 tokens.
        assert strategy._estimate_tokens("a" * 400) == 100

    def test_estimate_tokens_list(self, strategy: ContextCachingStrategy):
        blocks = [
            {"type": "text", "text": "a" * 200},
            {"type": "text", "text": "b" * 200},
        ]

        assert strategy._estimate_tokens(blocks) == 100

    def test_estimate_tokens_dict(self, strategy: ContextCachingStrategy):
        assert strategy._estimate_tokens({"type": "text", "text": "a" * 400}) == 100

    def test_should_cache_small_content(self, strategy: ContextCachingStrategy):
        assert strategy._should_cache("short text") is False

    def test_should_cache_large_content(self, strategy: ContextCachingStrategy):
        assert strategy._should_cache("x" * 5000) is True

    def test_add_cache_control_string(self, strategy: ContextCachingStrategy):
        marked = strategy._add_cache_control("test content")

        assert marked["type"] == "text"
        assert marked["text"] == "test content"
        assert marked["cache_control"]["type"] == "ephemeral"

    def test_add_cache_control_dict(self, strategy: ContextCachingStrategy):
        marked = strategy._add_cache_control({"type": "text", "text": "test content"})

        assert marked["cache_control"]["type"] == "ephemeral"

    def test_add_cache_control_list(self, strategy: ContextCachingStrategy):
        marked = strategy._add_cache_control(
            [{"type": "text", "text": "first"}, {"type": "text", "text": "second"}]
        )

        # Only the final block carries the cache marker.
        assert "cache_control" not in marked[0]
        assert marked[1]["cache_control"]["type"] == "ephemeral"

    def test_add_cache_control_empty_list(self, strategy: ContextCachingStrategy):
        assert strategy._add_cache_control([]) == []

    def test_process_system_message(self, strategy: ContextCachingStrategy):
        processed = strategy._process_system_message(
            SystemMessage(content="You are a helpful assistant")
        )

        assert isinstance(processed, SystemMessage)
        assert isinstance(processed.content, list)
        assert processed.content[0]["cache_control"]["type"] == "ephemeral"

    def test_apply_caching_empty_messages(self, strategy: ContextCachingStrategy):
        cached, outcome = strategy.apply_caching([])

        assert outcome.was_cached is False
        assert cached == []

    def test_apply_caching_no_system_message(self, strategy: ContextCachingStrategy):
        history = [HumanMessage(content="Hello"), AIMessage(content="Hi there!")]

        _, outcome = strategy.apply_caching(history)

        assert outcome.was_cached is False

    def test_apply_caching_small_system_message(self, strategy: ContextCachingStrategy):
        history = [SystemMessage(content="Be helpful"), HumanMessage(content="Hello")]

        _, outcome = strategy.apply_caching(history)

        assert outcome.was_cached is False

    def test_apply_caching_large_system_message(
        self, strategy_low_threshold: ContextCachingStrategy
    ):
        history = [
            SystemMessage(content="You are a helpful assistant. " * 50),
            HumanMessage(content="Hello"),
        ]

        _, outcome = strategy_low_threshold.apply_caching(history)

        assert outcome.was_cached is True
        assert outcome.cached_content_type == "system_prompt"
        assert outcome.estimated_tokens_cached > 0

    def test_apply_caching_preserves_message_order(
        self, strategy_low_threshold: ContextCachingStrategy
    ):
        history = [
            SystemMessage(content="System prompt " * 100),
            HumanMessage(content="Question 1"),
            AIMessage(content="Answer 1"),
            HumanMessage(content="Question 2"),
        ]

        cached, _ = strategy_low_threshold.apply_caching(history)

        expected_types = [SystemMessage, HumanMessage, AIMessage, HumanMessage]
        assert len(cached) == len(expected_types)
        for message, expected in zip(cached, expected_types):
            assert isinstance(message, expected)

    def test_custom_cache_control_type(self):
        strategy = ContextCachingStrategy(
            config=CachingConfig(
                cache_control_type="permanent",
                min_cacheable_tokens=10,
            )
        )

        marked = strategy._add_cache_control("test content")

        assert marked["cache_control"]["type"] == "permanent"
cache_control_type="permanent", + min_cacheable_tokens=10, + ) + ) + content = "test content" + cached = strategy._add_cache_control(content) + + assert cached["cache_control"]["type"] == "permanent" + + def test_disabled_system_prompt_caching(self): + strategy = ContextCachingStrategy( + config=CachingConfig( + enable_for_system_prompt=False, + min_cacheable_tokens=10, + ) + ) + large_prompt = "System prompt " * 100 + messages = [ + SystemMessage(content=large_prompt), + HumanMessage(content="Hello"), + ] + cached, result = strategy.apply_caching(messages) + + assert result.was_cached is False + + +class TestProviderDetection: + def test_detect_anthropic_from_class_name(self): + mock_model = MagicMock() + mock_model.__class__.__name__ = "ChatAnthropic" + mock_model.__class__.__module__ = "langchain_anthropic.chat_models" + + assert detect_provider(mock_model) == ProviderType.ANTHROPIC + + def test_detect_openai_from_class_name(self): + mock_model = MagicMock() + mock_model.__class__.__name__ = "ChatOpenAI" + mock_model.__class__.__module__ = "langchain_openai.chat_models" + mock_model.openai_api_base = None + + assert detect_provider(mock_model) == ProviderType.OPENAI + + def test_detect_gemini_from_class_name(self): + mock_model = MagicMock() + mock_model.__class__.__name__ = "ChatGoogleGenerativeAI" + mock_model.__class__.__module__ = "langchain_google_genai" + + assert detect_provider(mock_model) == ProviderType.GEMINI + + def test_detect_openrouter_from_base_url(self): + mock_model = MagicMock() + mock_model.__class__.__name__ = "ChatOpenAI" + mock_model.__class__.__module__ = "langchain_openai" + mock_model.openai_api_base = "https://openrouter.ai/api/v1" + + assert detect_provider(mock_model) == ProviderType.OPENROUTER + + def test_detect_unknown_provider(self): + mock_model = MagicMock() + mock_model.__class__.__name__ = "CustomModel" + mock_model.__class__.__module__ = "custom_module" + + assert detect_provider(mock_model) == ProviderType.UNKNOWN + + def 
class TestContextCachingStrategyMultiProvider:
    """apply_caching behaviour across provider families."""

    @staticmethod
    def _fake_model(class_name: str, module: str) -> MagicMock:
        fake = MagicMock()
        fake.__class__.__name__ = class_name
        fake.__class__.__module__ = module
        return fake

    def test_anthropic_applies_cache_markers(self):
        strategy = ContextCachingStrategy(
            config=CachingConfig(min_cacheable_tokens=10),
            model=self._fake_model("ChatAnthropic", "langchain_anthropic"),
        )

        _, outcome = strategy.apply_caching(
            [SystemMessage(content="System prompt " * 100)]
        )

        assert outcome.was_cached is True
        assert outcome.cached_content_type == "system_prompt"

    def test_openai_skips_cache_markers(self):
        model = self._fake_model("ChatOpenAI", "langchain_openai")
        model.openai_api_base = None
        strategy = ContextCachingStrategy(
            config=CachingConfig(min_cacheable_tokens=10),
            model=model,
        )

        _, outcome = strategy.apply_caching(
            [SystemMessage(content="System prompt " * 100)]
        )

        # OpenAI caches automatically server-side; no explicit markers are added.
        assert outcome.was_cached is False
        assert outcome.cached_content_type == "auto_cached_by_openai"

    def test_gemini_skips_cache_markers(self):
        strategy = ContextCachingStrategy(
            config=CachingConfig(min_cacheable_tokens=10),
            model=self._fake_model("ChatGoogleGenerativeAI", "langchain_google_genai"),
        )

        _, outcome = strategy.apply_caching(
            [SystemMessage(content="System prompt " * 100)]
        )

        assert outcome.was_cached is False
        assert outcome.cached_content_type == "auto_cached_by_gemini"

    def test_set_model_updates_provider(self):
        strategy = ContextCachingStrategy()

        strategy.set_model(self._fake_model("ChatAnthropic", "langchain_anthropic"))
        assert strategy.provider == ProviderType.ANTHROPIC

        openai_model = self._fake_model("ChatOpenAI", "langchain_openai")
        openai_model.openai_api_base = None
        strategy.set_model(openai_model)
        assert strategy.provider == ProviderType.OPENAI
class TestCacheTelemetry:
    """CacheTelemetry dataclass defaults and explicit values."""

    def test_default_values(self):
        telemetry = CacheTelemetry(provider=ProviderType.OPENAI)

        assert telemetry.cache_read_tokens == 0
        assert telemetry.cache_write_tokens == 0
        assert telemetry.cache_hit_ratio == 0.0

    def test_with_values(self):
        telemetry = CacheTelemetry(
            provider=ProviderType.ANTHROPIC,
            cache_read_tokens=1000,
            cache_write_tokens=500,
            total_input_tokens=2000,
            cache_hit_ratio=0.5,
        )

        assert telemetry.cache_read_tokens == 1000
        assert telemetry.cache_write_tokens == 500
        assert telemetry.cache_hit_ratio == 0.5


class TestCacheMetricsExtraction:
    """Per-provider cache metric extraction from model responses."""

    @staticmethod
    def _fake_response(usage_metadata, response_metadata) -> MagicMock:
        # Minimal stand-in for an AIMessage carrying usage metadata.
        fake = MagicMock()
        fake.usage_metadata = usage_metadata
        fake.response_metadata = response_metadata
        return fake

    def test_extract_anthropic_metrics(self):
        response = self._fake_response(
            {"input_tokens": 1000},
            {
                "usage": {
                    "input_tokens": 1000,
                    "cache_read_input_tokens": 800,
                    "cache_creation_input_tokens": 200,
                }
            },
        )

        telemetry = extract_anthropic_cache_metrics(response)

        assert telemetry.provider == ProviderType.ANTHROPIC
        assert telemetry.cache_read_tokens == 800
        assert telemetry.cache_write_tokens == 200
        assert telemetry.cache_hit_ratio == 0.8

    def test_extract_openai_metrics(self):
        response = self._fake_response(
            {"input_tokens": 1000},
            {
                "token_usage": {
                    "prompt_tokens": 1000,
                    "prompt_tokens_details": {"cached_tokens": 500},
                }
            },
        )

        telemetry = extract_openai_cache_metrics(response)

        assert telemetry.provider == ProviderType.OPENAI
        assert telemetry.cache_read_tokens == 500
        assert telemetry.cache_hit_ratio == 0.5

    def test_extract_gemini_metrics(self):
        response = self._fake_response(
            {"input_tokens": 1000},
            {"prompt_token_count": 1000, "cached_content_token_count": 750},
        )

        telemetry = extract_gemini_cache_metrics(response)

        assert telemetry.provider == ProviderType.GEMINI
        assert telemetry.cache_read_tokens == 750
        assert telemetry.cache_hit_ratio == 0.75

    def test_extract_cache_telemetry_unknown_provider(self):
        response = self._fake_response({}, {})

        telemetry = extract_cache_telemetry(response, ProviderType.UNKNOWN)

        assert telemetry.provider == ProviderType.UNKNOWN
        assert telemetry.cache_read_tokens == 0
class TestPromptCachingTelemetryMiddleware:
    """Telemetry collection and aggregation in the caching middleware."""

    def test_initialization(self):
        assert PromptCachingTelemetryMiddleware().telemetry_history == []

    def test_get_aggregate_stats_empty(self):
        stats = PromptCachingTelemetryMiddleware().get_aggregate_stats()

        assert stats["total_calls"] == 0
        assert stats["total_cache_read_tokens"] == 0

    def test_get_aggregate_stats_with_data(self):
        middleware = PromptCachingTelemetryMiddleware()
        middleware._telemetry_history = [
            CacheTelemetry(
                provider=ProviderType.ANTHROPIC,
                cache_read_tokens=read,
                cache_write_tokens=write,
                total_input_tokens=1000,
            )
            for read, write in ((800, 200), (900, 100))
        ]

        stats = middleware.get_aggregate_stats()

        assert stats["total_calls"] == 2
        assert stats["total_cache_read_tokens"] == 1700
        assert stats["total_cache_write_tokens"] == 300
        assert stats["total_input_tokens"] == 2000
        assert stats["overall_cache_hit_ratio"] == 0.85

    def test_wrap_model_call_collects_telemetry(self):
        middleware = PromptCachingTelemetryMiddleware()

        fake_response = MagicMock()
        fake_response.response_metadata = {
            "model": "claude-3-sonnet",
            "usage": {"cache_read_input_tokens": 500},
        }
        fake_response.usage_metadata = {"input_tokens": 1000}

        outcome = middleware.wrap_model_call(
            MagicMock(), lambda request: fake_response
        )

        assert outcome == fake_response
        assert len(middleware.telemetry_history) == 1
class TestOpenRouterSubProvider:
    """Sub-provider detection from OpenRouter model identifiers."""

    def test_detect_anthropic_via_openrouter(self):
        for model_id in ("anthropic/claude-3-sonnet", "anthropic/claude-3.5-sonnet"):
            assert (
                detect_openrouter_sub_provider(model_id)
                == OpenRouterSubProvider.ANTHROPIC
            )

    def test_detect_openai_via_openrouter(self):
        for model_id in ("openai/gpt-4o", "openai/o1-preview"):
            assert (
                detect_openrouter_sub_provider(model_id)
                == OpenRouterSubProvider.OPENAI
            )

    def test_detect_gemini_via_openrouter(self):
        for model_id in ("google/gemini-2.5-pro", "google/gemini-3-flash"):
            assert (
                detect_openrouter_sub_provider(model_id)
                == OpenRouterSubProvider.GEMINI
            )

    def test_detect_deepseek_via_openrouter(self):
        assert (
            detect_openrouter_sub_provider("deepseek/deepseek-chat")
            == OpenRouterSubProvider.DEEPSEEK
        )

    def test_detect_groq_via_openrouter(self):
        assert (
            detect_openrouter_sub_provider("groq/kimi-k2")
            == OpenRouterSubProvider.GROQ
        )

    def test_detect_grok_via_openrouter(self):
        assert (
            detect_openrouter_sub_provider("xai/grok-2")
            == OpenRouterSubProvider.GROK
        )

    def test_detect_llama_via_openrouter(self):
        assert (
            detect_openrouter_sub_provider("meta-llama/llama-3.3-70b")
            == OpenRouterSubProvider.META_LLAMA
        )

    def test_detect_mistral_via_openrouter(self):
        assert (
            detect_openrouter_sub_provider("mistral/mistral-large")
            == OpenRouterSubProvider.MISTRAL
        )

    def test_detect_unknown_via_openrouter(self):
        assert (
            detect_openrouter_sub_provider("some-provider/some-model")
            == OpenRouterSubProvider.UNKNOWN
        )
class TestOpenRouterCachingStrategy:
    """Cache-marker decisions for models routed through OpenRouter."""

    @staticmethod
    def _openrouter_model() -> MagicMock:
        # OpenRouter is detected from the ChatOpenAI base URL.
        fake = MagicMock()
        fake.__class__.__name__ = "ChatOpenAI"
        fake.__class__.__module__ = "langchain_openai"
        fake.openai_api_base = "https://openrouter.ai/api/v1"
        return fake

    def _strategy_for(self, model_name: str) -> ContextCachingStrategy:
        return ContextCachingStrategy(
            config=CachingConfig(min_cacheable_tokens=10),
            model=self._openrouter_model(),
            openrouter_model_name=model_name,
        )

    def test_openrouter_anthropic_applies_cache_markers(self):
        strategy = self._strategy_for("anthropic/claude-3-sonnet")

        assert strategy.provider == ProviderType.OPENROUTER
        assert strategy.sub_provider == OpenRouterSubProvider.ANTHROPIC
        assert strategy.should_apply_cache_markers is True

    def test_openrouter_openai_skips_cache_markers(self):
        strategy = self._strategy_for("openai/gpt-4o")

        assert strategy.provider == ProviderType.OPENROUTER
        assert strategy.sub_provider == OpenRouterSubProvider.OPENAI
        assert strategy.should_apply_cache_markers is False

    def test_openrouter_deepseek_skips_cache_markers(self):
        strategy = self._strategy_for("deepseek/deepseek-chat")

        assert strategy.provider == ProviderType.OPENROUTER
        assert strategy.sub_provider == OpenRouterSubProvider.DEEPSEEK
        assert strategy.should_apply_cache_markers is False
class TestGemini3Detection:
    """Gemini 3 models are distinguished from earlier Gemini versions."""

    @staticmethod
    def _gemini_model(model_name: str) -> MagicMock:
        fake = MagicMock()
        fake.__class__.__name__ = "ChatGoogleGenerativeAI"
        fake.__class__.__module__ = "langchain_google_genai"
        fake.model_name = model_name
        return fake

    def test_detect_gemini_3_from_model_name(self):
        model = self._gemini_model("gemini-3-pro-preview")

        assert detect_provider(model) == ProviderType.GEMINI_3

    def test_detect_gemini_3_flash(self):
        model = self._gemini_model("gemini-3-flash-preview")

        assert detect_provider(model) == ProviderType.GEMINI_3

    def test_detect_gemini_25_not_gemini_3(self):
        model = self._gemini_model("gemini-2.5-pro")

        assert detect_provider(model) == ProviderType.GEMINI


class TestDeepSeekMetrics:
    """DeepSeek cache metric extraction."""

    def test_extract_deepseek_metrics(self):
        response = MagicMock()
        response.usage_metadata = {"input_tokens": 1000}
        response.response_metadata = {
            "cache_hit_tokens": 700,
            "cache_miss_tokens": 300,
        }

        telemetry = extract_deepseek_cache_metrics(response)

        assert telemetry.provider == ProviderType.DEEPSEEK
        assert telemetry.cache_read_tokens == 700
        assert telemetry.cache_write_tokens == 300
        assert telemetry.cache_hit_ratio == 0.7
class TestModuleExports:
    """The package exposes every strategy plus the agent factory."""

    def test_exports_all_strategies(self):
        for exported in (
            ContextOffloadingStrategy,
            ContextReductionStrategy,
            ContextRetrievalStrategy,
            ContextIsolationStrategy,
            ContextCachingStrategy,
        ):
            assert exported is not None

    def test_exports_create_context_aware_agent(self):
        assert create_context_aware_agent is not None
        assert callable(create_context_aware_agent)


class TestStrategyInstantiation:
    """Each strategy accepts both default and custom configurations."""

    def test_offloading_strategy_instantiation(self):
        assert ContextOffloadingStrategy().config.token_limit_before_evict == 20000

        custom = ContextOffloadingStrategy(
            config=OffloadingConfig(token_limit_before_evict=10000)
        )
        assert custom.config.token_limit_before_evict == 10000

    def test_reduction_strategy_instantiation(self):
        assert ContextReductionStrategy().config.context_threshold == 0.85

        custom = ContextReductionStrategy(
            config=ReductionConfig(context_threshold=0.9)
        )
        assert custom.config.context_threshold == 0.9

    def test_retrieval_strategy_instantiation(self):
        assert len(ContextRetrievalStrategy().tools) == 3

    def test_isolation_strategy_instantiation(self):
        assert len(ContextIsolationStrategy().tools) == 1
ContextIsolationStrategy() + assert len(strategy.tools) == 1 + + def test_caching_strategy_instantiation(self): + strategy = ContextCachingStrategy() + assert strategy.config.min_cacheable_tokens == 1024 + + custom_strategy = ContextCachingStrategy( + config=CachingConfig(min_cacheable_tokens=2048) + ) + assert custom_strategy.config.min_cacheable_tokens == 2048 + + +@pytest.mark.skipif(SKIP_OPENAI, reason="OPENAI_API_KEY not set") +class TestCreateContextAwareAgent: + def test_create_agent_default_settings(self): + agent = create_context_aware_agent(model_name="gpt-4.1") + + assert agent is not None + + def test_create_agent_all_disabled(self): + agent = create_context_aware_agent( + model_name="gpt-4.1", + enable_offloading=False, + enable_reduction=False, + enable_caching=False, + ) + + assert agent is not None + + def test_create_agent_all_enabled(self): + agent = create_context_aware_agent( + model_name="gpt-4.1", + enable_offloading=True, + enable_reduction=True, + enable_caching=True, + ) + + assert agent is not None + + def test_create_agent_custom_thresholds(self): + agent = create_context_aware_agent( + model_name="gpt-4.1", + enable_offloading=True, + enable_reduction=True, + offloading_token_limit=10000, + reduction_threshold=0.9, + ) + + assert agent is not None + + def test_create_agent_offloading_only(self): + agent = create_context_aware_agent( + model_name="gpt-4.1", + enable_offloading=True, + enable_reduction=False, + enable_caching=False, + ) + + assert agent is not None + + def test_create_agent_reduction_only(self): + agent = create_context_aware_agent( + model_name="gpt-4.1", + enable_offloading=False, + enable_reduction=True, + enable_caching=False, + ) + + assert agent is not None + + +class TestStrategyCombinations: + def test_offloading_with_reduction(self): + offloading = ContextOffloadingStrategy( + config=OffloadingConfig(token_limit_before_evict=15000) + ) + reduction = ContextReductionStrategy( + 
class TestIsolationConfig:
    """IsolationConfig defaults and overrides."""

    def test_default_values(self):
        cfg = IsolationConfig()

        assert cfg.default_model == "gpt-4.1"
        assert cfg.include_general_purpose_agent is True
        assert cfg.excluded_state_keys == ("messages", "todos", "structured_response")

    def test_custom_values(self):
        cfg = IsolationConfig(
            default_model="claude-3-5-sonnet",
            include_general_purpose_agent=False,
            excluded_state_keys=("messages", "custom_key"),
        )

        assert cfg.default_model == "claude-3-5-sonnet"
        assert cfg.include_general_purpose_agent is False
        assert cfg.excluded_state_keys == ("messages", "custom_key")


class TestIsolationResult:
    """IsolationResult success and failure shapes."""

    def test_successful_result(self):
        outcome = IsolationResult(
            subagent_name="researcher",
            was_successful=True,
            result_length=500,
        )

        assert outcome.subagent_name == "researcher"
        assert outcome.was_successful is True
        assert outcome.result_length == 500
        assert outcome.error is None

    def test_failed_result(self):
        outcome = IsolationResult(
            subagent_name="researcher",
            was_successful=False,
            result_length=0,
            error="SubAgent not found",
        )

        assert outcome.was_successful is False
        assert outcome.error == "SubAgent not found"
class TestContextIsolationStrategy:
    """Behaviour of the SubAgent isolation strategy and its `task` tool."""

    @pytest.fixture
    def strategy(self) -> ContextIsolationStrategy:
        return ContextIsolationStrategy()

    @pytest.fixture
    def strategy_with_subagents(self) -> ContextIsolationStrategy:
        specs = [
            {
                "name": "researcher",
                "description": "Research agent",
                "system_prompt": "You are a researcher",
                "tools": [],
            },
            {
                "name": "coder",
                "description": "Coding agent",
                "system_prompt": "You are a coder",
                "tools": [],
            },
        ]
        return ContextIsolationStrategy(subagents=specs)  # type: ignore

    def test_initialization(self, strategy: ContextIsolationStrategy):
        assert strategy.config is not None
        assert len(strategy.tools) == 1

    def test_creates_task_tool(self, strategy: ContextIsolationStrategy):
        assert strategy.tools[0].name == "task"

    def test_task_tool_description(
        self, strategy_with_subagents: ContextIsolationStrategy
    ):
        description = strategy_with_subagents.tools[0].description

        for expected in ("researcher", "coder", "SubAgent"):
            assert expected in description

    def test_get_subagent_descriptions(
        self, strategy_with_subagents: ContextIsolationStrategy
    ):
        text = strategy_with_subagents._get_subagent_descriptions()

        for expected in ("researcher", "Research agent", "coder", "Coding agent"):
            assert expected in text

    def test_prepare_subagent_state(self, strategy: ContextIsolationStrategy):
        parent_state = {
            "messages": [{"role": "user", "content": "old message"}],
            "todos": ["task1", "task2"],
            "structured_response": {"key": "value"},
            "files": {"path": "/test"},
        }

        prepared = strategy._prepare_subagent_state(
            parent_state, "New task description"
        )

        # Excluded keys are dropped; the message history is replaced by the task.
        assert "todos" not in prepared
        assert "structured_response" not in prepared
        assert "files" in prepared
        assert len(prepared["messages"]) == 1
        assert prepared["messages"][0].content == "New task description"

    def test_prepare_subagent_state_custom_excluded_keys(self):
        strategy = ContextIsolationStrategy(
            config=IsolationConfig(excluded_state_keys=("messages", "custom_exclude"))
        )
        parent_state = {
            "messages": [{"role": "user", "content": "old"}],
            "custom_exclude": "should be excluded",
            "keep_this": "should be kept",
        }

        prepared = strategy._prepare_subagent_state(parent_state, "New task")

        assert "custom_exclude" not in prepared
        assert "keep_this" in prepared

    def test_compile_subagents_empty(self, strategy: ContextIsolationStrategy):
        assert strategy._compile_subagents() == {}

    def test_compile_subagents_caches_result(
        self, strategy_with_subagents: ContextIsolationStrategy
    ):
        strategy_with_subagents._compiled_agents = {"cached": "value"}  # type: ignore

        assert strategy_with_subagents._compile_subagents() == {"cached": "value"}

    def test_task_without_subagents(self, strategy: ContextIsolationStrategy):
        class MockRuntime:
            state = {}
            config = {}
            tool_call_id = "call_123"

        outcome = strategy.tools[0].func(  # type: ignore
            description="Test task",
            subagent_type="researcher",
            runtime=MockRuntime(),
        )

        assert "존재하지 않습니다" in str(outcome)
class TestOffloadingConfig:
    """OffloadingConfig defaults and overrides."""

    def test_default_values(self):
        cfg = OffloadingConfig()

        assert cfg.token_limit_before_evict == 20000
        assert cfg.eviction_path_prefix == "/large_tool_results"
        assert cfg.preview_lines == 10
        assert cfg.chars_per_token == 4

    def test_custom_values(self):
        cfg = OffloadingConfig(
            token_limit_before_evict=15000,
            eviction_path_prefix="/custom_path",
            preview_lines=5,
            chars_per_token=3,
        )

        assert cfg.token_limit_before_evict == 15000
        assert cfg.eviction_path_prefix == "/custom_path"
        assert cfg.preview_lines == 5
        assert cfg.chars_per_token == 3
class TestOffloadingResult:
    """OffloadingResult shapes for kept vs. offloaded content."""

    def test_not_offloaded(self):
        outcome = OffloadingResult(was_offloaded=False, original_size=100)

        assert outcome.was_offloaded is False
        assert outcome.original_size == 100
        assert outcome.file_path is None
        assert outcome.preview is None

    def test_offloaded(self):
        outcome = OffloadingResult(
            was_offloaded=True,
            original_size=100000,
            file_path="/large_tool_results/call_123",
            preview="first 10 lines...",
        )

        assert outcome.was_offloaded is True
        assert outcome.original_size == 100000
        assert outcome.file_path == "/large_tool_results/call_123"
        assert outcome.preview == "first 10 lines..."
class TestContextOffloadingStrategy:
    """Token estimation, offload thresholds, previews and result processing."""

    @pytest.fixture
    def strategy(self) -> ContextOffloadingStrategy:
        return ContextOffloadingStrategy()

    @pytest.fixture
    def strategy_low_threshold(self) -> ContextOffloadingStrategy:
        # Tiny threshold so small payloads exceed the eviction limit.
        return ContextOffloadingStrategy(
            config=OffloadingConfig(token_limit_before_evict=100)
        )

    def test_estimate_tokens(self, strategy: ContextOffloadingStrategy):
        # 400 chars at the default 4 chars/token -> 100 tokens.
        assert strategy._estimate_tokens("a" * 400) == 100

    def test_estimate_tokens_with_custom_chars_per_token(self):
        strategy = ContextOffloadingStrategy(
            config=OffloadingConfig(chars_per_token=2)
        )

        assert strategy._estimate_tokens("a" * 400) == 200

    def test_should_offload_small_content(self, strategy: ContextOffloadingStrategy):
        assert strategy._should_offload("short text" * 100) is False

    def test_should_offload_large_content(self, strategy: ContextOffloadingStrategy):
        assert strategy._should_offload("x" * 100000) is True

    def test_should_offload_boundary(self):
        strategy = ContextOffloadingStrategy(
            config=OffloadingConfig(token_limit_before_evict=100, chars_per_token=4)
        )

        # Exactly at the limit stays inline; anything over is offloaded.
        assert strategy._should_offload("x" * 400) is False
        assert strategy._should_offload("x" * 404) is True

    def test_create_preview_short_content(self, strategy: ContextOffloadingStrategy):
        preview = strategy._create_preview("line1\nline2\nline3")

        for line in ("line1", "line2", "line3"):
            assert line in preview

    def test_create_preview_long_content(self, strategy: ContextOffloadingStrategy):
        preview = strategy._create_preview(
            "\n".join(f"line_{i}" for i in range(100))
        )

        # Default preview keeps the first 10 lines only.
        assert "line_0" in preview
        assert "line_9" in preview
        assert "line_10" not in preview

    def test_create_preview_with_custom_lines(self):
        strategy = ContextOffloadingStrategy(
            config=OffloadingConfig(preview_lines=3)
        )

        preview = strategy._create_preview(
            "\n".join(f"line_{i}" for i in range(10))
        )

        assert "line_0" in preview
        assert "line_2" in preview
        assert "line_3" not in preview

    def test_create_preview_truncates_long_lines(
        self, strategy: ContextOffloadingStrategy
    ):
        preview = strategy._create_preview("x" * 2000)

        # Preview lines look like "<n>\t<content>"; the content part is capped
        # at 1000 chars — presumably the per-line truncation limit; confirm
        # against the implementation if it changes.
        assert len(preview.split("\t")[1]) == 1000

    def test_create_offload_message(self, strategy: ContextOffloadingStrategy):
        notice = strategy._create_offload_message(
            tool_call_id="call_123",
            file_path="/large_tool_results/call_123",
            preview="preview content",
        )

        for expected in (
            "/large_tool_results/call_123",
            "preview content",
            "read_file",
        ):
            assert expected in notice

    def test_sanitize_tool_call_id(self, strategy: ContextOffloadingStrategy):
        assert strategy._sanitize_tool_call_id("call_abc123") == "call_abc123"
        assert (
            strategy._sanitize_tool_call_id("call/with:special@chars!")
            == "call_with_special_chars_"
        )

    def test_process_tool_result_small_content(
        self, strategy: ContextOffloadingStrategy
    ):
        class MockRuntime:
            state = {}
            config = {}

        message = ToolMessage(content="small content", tool_call_id="call_123")

        processed, outcome = strategy.process_tool_result(message, MockRuntime())  # type: ignore

        assert outcome.was_offloaded is False
        assert processed.content == "small content"

    def test_process_tool_result_no_backend(
        self, strategy_low_threshold: ContextOffloadingStrategy
    ):
        class MockRuntime:
            state = {}
            config = {}

        payload = "x" * 1000
        message = ToolMessage(content=payload, tool_call_id="call_123")

        processed, outcome = strategy_low_threshold.process_tool_result(
            message,
            MockRuntime(),  # type: ignore
        )

        # Without a filesystem backend the content must stay inline untouched.
        assert outcome.was_offloaded is False
        assert processed.content == payload
+ +테스트 대상 모델 (Anthropic/OpenAI/Google 제외): +- deepseek/deepseek-v3.2 +- x-ai/grok-* +- xiaomi/mimo-v2-flash +- minimax/minimax-m2.1 +- bytedance-seed/seed-1.6 +- z-ai/glm-4.7 +- allenai/olmo-3.1-32b-instruct +- mistralai/mistral-small-creative +- nvidia/nemotron-3-nano-30b-a3b +- qwen/qwen3-max, qwen3-coder-plus, qwen3-coder-flash, qwen3-vl-32b-instruct, qwen3-next-80b-a3b-thinking +""" + +from __future__ import annotations + +import os + +import pytest +from langchain_openai import ChatOpenAI + +from context_engineering_research_agent.context_strategies.caching import ( + CachingConfig, + ContextCachingStrategy, + OpenRouterSubProvider, + ProviderType, + detect_openrouter_sub_provider, + detect_provider, +) + + +def _openrouter_available() -> bool: + """OpenRouter API 키가 설정되어 있는지 확인합니다.""" + return bool(os.environ.get("OPENROUTER_API_KEY")) + + +OPENROUTER_AVAILABLE = _openrouter_available() +OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1" + +pytestmark = [ + pytest.mark.integration, + pytest.mark.skipif( + not OPENROUTER_AVAILABLE, + reason="OPENROUTER_API_KEY 환경 변수가 설정되지 않았습니다.", + ), +] + + +OPENROUTER_MODELS = [ + ("deepseek/deepseek-chat-v3-0324", OpenRouterSubProvider.DEEPSEEK), + ("x-ai/grok-3-mini-beta", OpenRouterSubProvider.GROK), + ("qwen/qwen3-235b-a22b", OpenRouterSubProvider.UNKNOWN), + ("qwen/qwen3-32b", OpenRouterSubProvider.UNKNOWN), + ("mistralai/mistral-small-3.1-24b-instruct", OpenRouterSubProvider.MISTRAL), + ("meta-llama/llama-4-maverick", OpenRouterSubProvider.META_LLAMA), + ("meta-llama/llama-4-scout", OpenRouterSubProvider.META_LLAMA), + ("nvidia/llama-3.1-nemotron-70b-instruct", OpenRouterSubProvider.UNKNOWN), + ("microsoft/phi-4", OpenRouterSubProvider.UNKNOWN), + ("google/gemma-3-27b-it", OpenRouterSubProvider.UNKNOWN), + ("cohere/command-a", OpenRouterSubProvider.UNKNOWN), + ("perplexity/sonar-pro", OpenRouterSubProvider.UNKNOWN), + ("ai21/jamba-1.6-large", OpenRouterSubProvider.UNKNOWN), + ("inflection/inflection-3-pi", 
OpenRouterSubProvider.UNKNOWN), + ("amazon/nova-pro-v1", OpenRouterSubProvider.UNKNOWN), +] + + +def _create_openrouter_model(model_name: str) -> ChatOpenAI: + """OpenRouter 모델 인스턴스를 생성합니다.""" + return ChatOpenAI( + model=model_name, + openai_api_key=os.environ.get("OPENROUTER_API_KEY"), + openai_api_base=OPENROUTER_BASE_URL, + temperature=0.0, + max_tokens=100, + ) + + +class TestOpenRouterProviderDetection: + """OpenRouter Provider 감지 테스트.""" + + @pytest.mark.parametrize("model_name,expected_sub", OPENROUTER_MODELS) + def test_detect_provider_openrouter( + self, model_name: str, expected_sub: OpenRouterSubProvider + ) -> None: + """OpenRouter 모델이 ProviderType.OPENROUTER로 감지되는지 확인합니다.""" + model = _create_openrouter_model(model_name) + provider = detect_provider(model) + assert provider == ProviderType.OPENROUTER + + @pytest.mark.parametrize("model_name,expected_sub", OPENROUTER_MODELS) + def test_detect_openrouter_sub_provider( + self, model_name: str, expected_sub: OpenRouterSubProvider + ) -> None: + """OpenRouter 모델명에서 sub-provider를 올바르게 감지하는지 확인합니다.""" + sub_provider = detect_openrouter_sub_provider(model_name) + assert sub_provider == expected_sub + + +class TestOpenRouterCachingStrategy: + """OpenRouter 캐싱 전략 테스트.""" + + @pytest.fixture + def low_threshold_config(self) -> CachingConfig: + return CachingConfig(min_cacheable_tokens=10) + + @pytest.mark.parametrize("model_name,expected_sub", OPENROUTER_MODELS[:5]) + def test_caching_strategy_initialization( + self, + model_name: str, + expected_sub: OpenRouterSubProvider, + ) -> None: + """ContextCachingStrategy가 OpenRouter 모델로 올바르게 초기화되는지 확인합니다.""" + model = _create_openrouter_model(model_name) + strategy = ContextCachingStrategy(model=model, openrouter_model_name=model_name) + + assert strategy._provider == ProviderType.OPENROUTER + assert strategy._openrouter_sub_provider == expected_sub + + +class TestOpenRouterModelInvocation: + """OpenRouter 모델 실제 호출 테스트. + + 실제 API 비용이 발생하므로 소수의 모델만 테스트합니다. 
class TestOpenRouterModelInvocation:
    """Live-invocation tests against OpenRouter.

    Only a handful of models are exercised because each call incurs real
    API cost.
    """

    @pytest.mark.parametrize(
        "model_name",
        [
            "deepseek/deepseek-chat-v3-0324",
            "qwen/qwen3-32b",
            "mistralai/mistral-small-3.1-24b-instruct",
        ],
    )
    def test_simple_invocation(self, model_name: str) -> None:
        """The model should produce a non-empty reply to a trivial prompt."""
        reply = _create_openrouter_model(model_name).invoke(
            "Say 'hello' in one word."
        )
        assert reply.content
        assert len(reply.content) > 0

    @pytest.mark.parametrize(
        "model_name",
        [
            "deepseek/deepseek-chat-v3-0324",
            "x-ai/grok-3-mini-beta",
        ],
    )
    def test_caching_strategy_apply_does_not_error(self, model_name: str) -> None:
        """ContextCachingStrategy.apply() must run without raising."""
        from langchain_core.messages import HumanMessage, SystemMessage

        chat_model = _create_openrouter_model(model_name)
        caching = ContextCachingStrategy(
            model=chat_model,
            openrouter_model_name=model_name,
            config=CachingConfig(min_cacheable_tokens=10),
        )

        conversation = [
            SystemMessage(content="You are a helpful assistant. " * 100),
            HumanMessage(content="Hello"),
        ]

        outcome = caching.apply(conversation)
        # Only a smoke-check: the result object must carry a caching verdict.
        assert outcome.was_cached is not None


class TestOpenRouterModelNameExtraction:
    """Reading the model name and base URL back off a ChatOpenAI instance."""

    def test_model_name_extraction_from_model_attribute(self) -> None:
        """Either `model_name` or `model` exposes the requested name."""
        chat_model = _create_openrouter_model("deepseek/deepseek-chat-v3-0324")
        extracted = getattr(chat_model, "model_name", None) or getattr(
            chat_model, "model", None
        )
        assert extracted == "deepseek/deepseek-chat-v3-0324"

    def test_openrouter_base_url_detection(self) -> None:
        """The configured base URL should identify OpenRouter."""
        chat_model = _create_openrouter_model("qwen/qwen3-32b")
        base_url = getattr(chat_model, "openai_api_base", None)
        assert base_url is not None
        assert "openrouter" in base_url.lower()


import pytest
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage

from context_engineering_research_agent.context_strategies.reduction import (
    ContextReductionStrategy,
    ReductionConfig,
    ReductionResult,
)


class TestReductionConfig:
    """Defaults and overrides of ReductionConfig."""

    def test_default_values(self):
        cfg = ReductionConfig()

        assert cfg.context_threshold == 0.85
        assert cfg.model_context_window == 200000
        assert cfg.compaction_age_threshold == 10
        assert cfg.min_messages_to_keep == 5
        assert cfg.chars_per_token == 4

    def test_custom_values(self):
        cfg = ReductionConfig(
            context_threshold=0.90,
            model_context_window=100000,
            compaction_age_threshold=5,
            min_messages_to_keep=3,
        )

        assert cfg.context_threshold == 0.90
        assert cfg.model_context_window == 100000
        assert cfg.compaction_age_threshold == 5
        assert cfg.min_messages_to_keep == 3
class TestReductionResult:
    """Field semantics of ReductionResult."""

    def test_not_reduced(self):
        untouched = ReductionResult(was_reduced=False)

        # A no-op result carries zeroed counters and no technique.
        assert untouched.was_reduced is False
        assert untouched.technique_used is None
        assert untouched.original_message_count == 0
        assert untouched.reduced_message_count == 0
        assert untouched.estimated_tokens_saved == 0

    def test_reduced_with_compaction(self):
        outcome = ReductionResult(
            was_reduced=True,
            technique_used="compaction",
            original_message_count=50,
            reduced_message_count=30,
            estimated_tokens_saved=5000,
        )

        assert outcome.was_reduced is True
        assert outcome.technique_used == "compaction"
        assert outcome.original_message_count == 50
        assert outcome.reduced_message_count == 30
        assert outcome.estimated_tokens_saved == 5000


class TestContextReductionStrategy:
    """Compaction and summarization behavior of ContextReductionStrategy."""

    @pytest.fixture
    def strategy(self) -> ContextReductionStrategy:
        # Default configuration (85% threshold over a 200k window).
        return ContextReductionStrategy()

    @pytest.fixture
    def strategy_low_threshold(self) -> ContextReductionStrategy:
        # Tiny window + low ratio so reduction triggers almost immediately.
        return ContextReductionStrategy(
            config=ReductionConfig(context_threshold=0.1, model_context_window=1000)
        )

    def test_estimate_tokens(self, strategy: ContextReductionStrategy):
        # 800 characters total at 4 chars/token -> 200 tokens.
        convo = [HumanMessage(content="a" * 400), AIMessage(content="b" * 400)]
        assert strategy._estimate_tokens(convo) == 200

    def test_get_context_usage_ratio(self, strategy: ContextReductionStrategy):
        # 40,000 chars -> 10,000 tokens out of the 200,000-token window.
        convo = [HumanMessage(content="x" * 40000)]
        assert strategy._get_context_usage_ratio(convo) == pytest.approx(0.05, rel=0.01)

    def test_should_reduce_below_threshold(self, strategy: ContextReductionStrategy):
        assert strategy._should_reduce([HumanMessage(content="short message")]) is False

    def test_should_reduce_above_threshold(
        self, strategy_low_threshold: ContextReductionStrategy
    ):
        convo = [HumanMessage(content="x" * 1000)]
        assert strategy_low_threshold._should_reduce(convo) is True

    def test_apply_compaction_removes_old_tool_calls(
        self, strategy: ContextReductionStrategy
    ):
        # 20 turns; the first 15 each carry a tool call and its result.
        convo = []
        for turn in range(20):
            has_tools = turn < 15
            convo.append(HumanMessage(content=f"question {turn}"))
            convo.append(
                AIMessage(
                    content=f"answer {turn}",
                    tool_calls=(
                        [{"id": f"call_{turn}", "name": "search", "args": {"q": "test"}}]
                        if has_tools
                        else []
                    ),
                )
            )
            if has_tools:
                convo.append(
                    ToolMessage(content=f"result {turn}", tool_call_id=f"call_{turn}")
                )

        compacted, outcome = strategy.apply_compaction(convo)

        assert outcome.was_reduced is True
        assert outcome.technique_used == "compaction"
        assert len(compacted) < len(convo)

    def test_apply_compaction_keeps_recent_messages(
        self, strategy: ContextReductionStrategy
    ):
        convo = []
        for turn in range(5):
            convo.append(HumanMessage(content=f"recent question {turn}"))
            convo.append(AIMessage(content=f"recent answer {turn}"))

        compacted, _outcome = strategy.apply_compaction(convo)

        # Everything is recent, so nothing should be dropped.
        assert len(compacted) == len(convo)

    def test_apply_compaction_preserves_text_content(self):
        aggressive = ContextReductionStrategy(
            config=ReductionConfig(compaction_age_threshold=2)
        )
        convo = [
            HumanMessage(content="old question"),
            AIMessage(
                content="old answer with important info",
                tool_calls=[{"id": "call_old", "name": "search", "args": {}}],
            ),
            ToolMessage(content="old result", tool_call_id="call_old"),
            HumanMessage(content="recent question"),
            AIMessage(content="recent answer"),
        ]

        compacted, _ = aggressive.apply_compaction(convo)

        # The AI message's textual content must survive compaction.
        assert any("important info" in str(m.content) for m in compacted)

    def test_apply_compaction_removes_tool_messages(self):
        aggressive = ContextReductionStrategy(
            config=ReductionConfig(compaction_age_threshold=2)
        )
        convo = [
            HumanMessage(content="old question"),
            AIMessage(
                content="",
                tool_calls=[{"id": "call_old", "name": "search", "args": {}}],
            ),
            ToolMessage(content="old tool result", tool_call_id="call_old"),
            HumanMessage(content="recent question"),
            AIMessage(content="recent answer"),
        ]

        compacted, _ = aggressive.apply_compaction(convo)

        assert sum(isinstance(m, ToolMessage) for m in compacted) == 0

    def test_reduce_context_no_reduction_needed(
        self, strategy: ContextReductionStrategy
    ):
        convo = [HumanMessage(content="short")]

        reduced, outcome = strategy.reduce_context(convo)

        assert outcome.was_reduced is False
        assert reduced == convo

    def test_reduce_context_with_very_large_messages(self):
        tiny_window = ContextReductionStrategy(
            config=ReductionConfig(
                context_threshold=0.01,
                model_context_window=100,
                compaction_age_threshold=5,
            )
        )
        convo = []
        for turn in range(20):
            convo.append(HumanMessage(content=f"question {turn} " * 100))
            convo.append(
                AIMessage(
                    content=f"answer {turn} " * 100,
                    tool_calls=[{"id": f"c{turn}", "name": "s", "args": {}}],
                )
            )
            convo.append(ToolMessage(content="result " * 100, tool_call_id=f"c{turn}"))

        compacted, outcome = tiny_window.apply_compaction(convo)

        assert outcome.was_reduced is True
        assert len(compacted) < len(convo)

    def test_create_summary_prompt(self, strategy: ContextReductionStrategy):
        convo = [
            HumanMessage(content="What is Python?"),
            AIMessage(content="Python is a programming language."),
        ]

        prompt = strategy._create_summary_prompt(convo)

        # The prompt must include both the content and the speaker roles.
        for fragment in ("Python", "Human", "AI"):
            assert fragment in prompt

    def test_apply_summarization_no_model(self, strategy: ContextReductionStrategy):
        # Without a summarization model the messages pass through untouched.
        convo = [HumanMessage(content="test")]

        summarized, outcome = strategy.apply_summarization(convo)

        assert outcome.was_reduced is False
        assert summarized == convo

    def test_compaction_preserves_system_messages(self):
        aggressive = ContextReductionStrategy(
            config=ReductionConfig(compaction_age_threshold=2)
        )
        convo = [
            SystemMessage(content="You are a helpful assistant"),
            HumanMessage(content="old question"),
            AIMessage(content="old answer"),
            HumanMessage(content="recent question"),
            AIMessage(content="recent answer"),
        ]

        compacted, _ = aggressive.apply_compaction(convo)

        system_messages = [m for m in compacted if isinstance(m, SystemMessage)]
        assert len(system_messages) == 1
        assert "helpful assistant" in str(system_messages[0].content)


import pytest

from context_engineering_research_agent.context_strategies.retrieval import (
    ContextRetrievalStrategy,
    RetrievalConfig,
    RetrievalResult,
)


class TestRetrievalConfig:
    """Defaults and overrides of RetrievalConfig."""

    def test_default_values(self):
        cfg = RetrievalConfig()

        assert cfg.default_read_limit == 500
        assert cfg.max_grep_results == 100
        assert cfg.max_glob_results == 100
        assert cfg.truncate_line_length == 2000

    def test_custom_values(self):
        cfg = RetrievalConfig(
            default_read_limit=1000,
            max_grep_results=50,
            max_glob_results=200,
            truncate_line_length=3000,
        )

        assert cfg.default_read_limit == 1000
        assert cfg.max_grep_results == 50
        assert cfg.max_glob_results == 200
        assert cfg.truncate_line_length == 3000


class TestRetrievalResult:
    """Field semantics of RetrievalResult."""

    def test_basic_result(self):
        hit = RetrievalResult(tool_used="grep", query="TODO", result_count=10)

        assert hit.tool_used == "grep"
        assert hit.query == "TODO"
        assert hit.result_count == 10
        assert hit.was_truncated is False

    def test_truncated_result(self):
        hit = RetrievalResult(
            tool_used="glob",
            query="**/*.py",
            result_count=100,
            was_truncated=True,
        )

        assert hit.was_truncated is True
class TestContextRetrievalStrategy:
    """Tool creation and no-backend behavior of ContextRetrievalStrategy."""

    @pytest.fixture
    def strategy(self) -> ContextRetrievalStrategy:
        return ContextRetrievalStrategy()

    @pytest.fixture
    def strategy_with_custom_config(self) -> ContextRetrievalStrategy:
        return ContextRetrievalStrategy(
            config=RetrievalConfig(
                default_read_limit=100,
                max_grep_results=10,
                max_glob_results=10,
            )
        )

    @staticmethod
    def _tool_named(strategy: ContextRetrievalStrategy, name: str):
        # Look up a tool on the strategy by its registered name.
        return next(t for t in strategy.tools if t.name == name)

    def test_initialization(self, strategy: ContextRetrievalStrategy):
        assert strategy.config is not None
        assert len(strategy.tools) == 3

    def test_creates_read_file_tool(self, strategy: ContextRetrievalStrategy):
        assert "read_file" in [t.name for t in strategy.tools]

    def test_creates_grep_tool(self, strategy: ContextRetrievalStrategy):
        assert "grep" in [t.name for t in strategy.tools]

    def test_creates_glob_tool(self, strategy: ContextRetrievalStrategy):
        assert "glob" in [t.name for t in strategy.tools]

    def test_read_file_tool_description(self, strategy: ContextRetrievalStrategy):
        description = self._tool_named(strategy, "read_file").description
        # Default read limit plus pagination parameters must be advertised.
        assert "500" in description
        assert "offset" in description.lower()
        assert "limit" in description.lower()

    def test_grep_tool_description(self, strategy: ContextRetrievalStrategy):
        description = self._tool_named(strategy, "grep").description
        assert "100" in description
        assert "pattern" in description.lower()

    def test_glob_tool_description(self, strategy: ContextRetrievalStrategy):
        description = self._tool_named(strategy, "glob").description
        assert "100" in description
        assert "**/*.py" in description

    def test_custom_config_affects_tool_descriptions(
        self, strategy_with_custom_config: ContextRetrievalStrategy
    ):
        custom = strategy_with_custom_config
        # Custom limits must be reflected in the generated descriptions.
        assert "100" in self._tool_named(custom, "read_file").description
        assert "10" in self._tool_named(custom, "grep").description

    def test_no_backend_read_file(self, strategy: ContextRetrievalStrategy):
        class _RuntimeStub:
            state = {}
            config = {}

        result = self._tool_named(strategy, "read_file").func(  # type: ignore
            file_path="/test.txt",
            runtime=_RuntimeStub(),
        )

        assert "백엔드가 설정되지 않았습니다" in result

    def test_no_backend_grep(self, strategy: ContextRetrievalStrategy):
        class _RuntimeStub:
            state = {}
            config = {}

        result = self._tool_named(strategy, "grep").func(  # type: ignore
            pattern="TODO",
            runtime=_RuntimeStub(),
        )

        assert "백엔드가 설정되지 않았습니다" in result

    def test_no_backend_glob(self, strategy: ContextRetrievalStrategy):
        class _RuntimeStub:
            state = {}
            config = {}

        result = self._tool_named(strategy, "glob").func(  # type: ignore
            pattern="**/*.py",
            runtime=_RuntimeStub(),
        )

        assert "백엔드가 설정되지 않았습니다" in result
"https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" }, ] +[[package]] +name = "docker" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, + { name = "requests" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" }, +] + [[package]] name = "docstring-parser" version = "0.17.0" @@ -624,6 +642,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "ipykernel" version = "7.1.0" @@ -1408,6 +1435,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "prompt-toolkit" version = "3.0.52" @@ -1597,6 +1633,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" }, ] +[[package]] +name = "pytest" +version = "9.0.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { 
name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1627,6 +1679,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] +[[package]] +name = "pywin32" +version = "311" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, + { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, + { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = 
"2025-07-14T20:13:30.348Z" }, + { url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" }, +] + [[package]] name = "pyyaml" version = "6.0.3"