init Context Engineering of DeepAgents
This commit is contained in:
677
Context_Engineering_Research.ipynb
Normal file
677
Context_Engineering_Research.ipynb
Normal file
@@ -0,0 +1,677 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4c897bc9",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Context Engineering 연구 노트북\n",
|
||||
"\n",
|
||||
"DeepAgents 라이브러리에서 사용되는 5가지 Context Engineering 전략을 분석하고 실험합니다.\n",
|
||||
"\n",
|
||||
"## Context Engineering 5가지 핵심 전략\n",
|
||||
"\n",
|
||||
"| 전략 | 설명 | DeepAgents 구현 |\n",
|
||||
"|------|------|----------------|\n",
|
||||
"| **1. Offloading** | 대용량 결과를 파일로 축출 | FilesystemMiddleware |\n",
|
||||
"| **2. Reduction** | Compaction + Summarization | SummarizationMiddleware |\n",
|
||||
"| **3. Retrieval** | grep/glob 기반 검색 | FilesystemMiddleware |\n",
|
||||
"| **4. Isolation** | SubAgent로 컨텍스트 격리 | SubAgentMiddleware |\n",
|
||||
"| **5. Caching** | Prompt Caching | AnthropicPromptCachingMiddleware |\n",
|
||||
"\n",
|
||||
"## 아키텍처 개요\n",
|
||||
"\n",
|
||||
"```\n",
|
||||
"┌─────────────────────────────────────────────────────────────────┐\n",
|
||||
"│ Context Engineering │\n",
|
||||
"├─────────────────────────────────────────────────────────────────┤\n",
|
||||
"│ │\n",
|
||||
"│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │\n",
|
||||
"│ │ Offloading │ │ Reduction │ │ Caching │ │\n",
|
||||
"│ │ (20k 토큰) │ │ (85% 임계) │ │ (Anthropic)│ │\n",
|
||||
"│ └─────┬──────┘ └─────┬──────┘ └─────┬──────┘ │\n",
|
||||
"│ │ │ │ │\n",
|
||||
"│ ▼ ▼ ▼ │\n",
|
||||
"│ ┌─────────────────────────────────────────────────────┐ │\n",
|
||||
"│ │ Middleware Stack │ │\n",
|
||||
"│ └─────────────────────────────────────────────────────┘ │\n",
|
||||
"│ │ │\n",
|
||||
"│ ┌────────────────┼────────────────┐ │\n",
|
||||
"│ ▼ ▼ ▼ │\n",
|
||||
"│ ┌────────────┐ ┌────────────┐ ┌────────────┐ │\n",
|
||||
"│ │ Retrieval │ │ Isolation │ │ Backend │ │\n",
|
||||
"│ │(grep/glob) │ │ (SubAgent) │ │ (FileSystem│ │\n",
|
||||
"│ └────────────┘ └────────────┘ └────────────┘ │\n",
|
||||
"│ │\n",
|
||||
"└─────────────────────────────────────────────────────────────────┘\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "fecc3e39",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"from dotenv import load_dotenv\n",
|
||||
"\n",
|
||||
"load_dotenv(\".env\", override=True)\n",
|
||||
"\n",
|
||||
"PROJECT_ROOT = Path.cwd()\n",
|
||||
"if str(PROJECT_ROOT) not in sys.path:\n",
|
||||
" sys.path.insert(0, str(PROJECT_ROOT))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "strategy1",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 전략 1: Context Offloading\n",
|
||||
"\n",
|
||||
"대용량 도구 결과를 파일시스템으로 축출하여 컨텍스트 윈도우 오버플로우를 방지합니다.\n",
|
||||
"\n",
|
||||
"### 핵심 원리\n",
|
||||
"- 도구 결과가 `tool_token_limit_before_evict` (기본 20,000 토큰) 초과 시 자동 축출\n",
|
||||
"- `/large_tool_results/{tool_call_id}` 경로에 저장\n",
|
||||
"- 처음 10줄 미리보기 제공\n",
|
||||
"- 에이전트가 `read_file`로 필요할 때 로드"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "offloading_demo",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent.context_strategies.offloading import (\n",
|
||||
" ContextOffloadingStrategy,\n",
|
||||
" OffloadingConfig,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"config = OffloadingConfig(\n",
|
||||
" token_limit_before_evict=20000,\n",
|
||||
" eviction_path_prefix=\"/large_tool_results\",\n",
|
||||
" preview_lines=10,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"토큰 임계값: {config.token_limit_before_evict:,}\")\n",
|
||||
"print(f\"축출 경로: {config.eviction_path_prefix}\")\n",
|
||||
"print(f\"미리보기 줄 수: {config.preview_lines}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "offloading_test",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"strategy = ContextOffloadingStrategy(config=config)\n",
|
||||
"\n",
|
||||
"small_content = \"짧은 텍스트\" * 100\n",
|
||||
"large_content = \"대용량 텍스트\" * 30000\n",
|
||||
"\n",
|
||||
"print(f\"짧은 콘텐츠: {len(small_content)} 자 → 축출 대상: {strategy._should_offload(small_content)}\")\n",
|
||||
"print(f\"대용량 콘텐츠: {len(large_content):,} 자 → 축출 대상: {strategy._should_offload(large_content)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "strategy2",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 전략 2: Context Reduction\n",
|
||||
"\n",
|
||||
"컨텍스트 윈도우 사용량이 임계값을 초과할 때 자동으로 대화 내용을 압축합니다.\n",
|
||||
"\n",
|
||||
"### 두 가지 기법\n",
|
||||
"\n",
|
||||
"| 기법 | 설명 | 비용 |\n",
|
||||
"|------|------|------|\n",
|
||||
"| **Compaction** | 오래된 도구 호출/결과 제거 | 무료 |\n",
|
||||
"| **Summarization** | LLM이 대화 요약 | API 비용 발생 |\n",
|
||||
"\n",
|
||||
"우선순위: Compaction → Summarization"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "reduction_demo",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent.context_strategies.reduction import (\n",
|
||||
" ContextReductionStrategy,\n",
|
||||
" ReductionConfig,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"config = ReductionConfig(\n",
|
||||
" context_threshold=0.85,\n",
|
||||
" model_context_window=200000,\n",
|
||||
" compaction_age_threshold=10,\n",
|
||||
" min_messages_to_keep=5,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"임계값: {config.context_threshold * 100}%\")\n",
|
||||
"print(f\"컨텍스트 윈도우: {config.model_context_window:,} 토큰\")\n",
|
||||
"print(f\"Compaction 대상 나이: {config.compaction_age_threshold} 메시지\")\n",
|
||||
"print(f\"최소 유지 메시지: {config.min_messages_to_keep}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "reduction_test",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_core.messages import AIMessage, HumanMessage\n",
|
||||
"\n",
|
||||
"strategy = ContextReductionStrategy(config=config)\n",
|
||||
"\n",
|
||||
"messages = [\n",
|
||||
" HumanMessage(content=\"안녕하세요\" * 1000),\n",
|
||||
" AIMessage(content=\"안녕하세요\" * 1000),\n",
|
||||
"] * 20\n",
|
||||
"\n",
|
||||
"usage_ratio = strategy._get_context_usage_ratio(messages)\n",
|
||||
"print(f\"컨텍스트 사용률: {usage_ratio * 100:.1f}%\")\n",
|
||||
"print(f\"축소 필요: {strategy._should_reduce(messages)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "strategy3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 전략 3: Context Retrieval\n",
|
||||
"\n",
|
||||
"grep/glob 기반의 단순하고 빠른 검색으로 필요한 정보만 선택적으로 로드합니다.\n",
|
||||
"\n",
|
||||
"### 벡터 검색을 사용하지 않는 이유\n",
|
||||
"\n",
|
||||
"| 특성 | 직접 검색 | 벡터 검색 |\n",
|
||||
"|------|----------|----------|\n",
|
||||
"| 결정성 | ✅ 정확한 매칭 | ❌ 확률적 |\n",
|
||||
"| 인프라 | ✅ 불필요 | ❌ 벡터 DB 필요 |\n",
|
||||
"| 속도 | ✅ 빠름 | ❌ 인덱싱 오버헤드 |\n",
|
||||
"| 디버깅 | ✅ 예측 가능 | ❌ 블랙박스 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "retrieval_demo",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent.context_strategies.retrieval import (\n",
|
||||
" ContextRetrievalStrategy,\n",
|
||||
" RetrievalConfig,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"config = RetrievalConfig(\n",
|
||||
" default_read_limit=500,\n",
|
||||
" max_grep_results=100,\n",
|
||||
" max_glob_results=100,\n",
|
||||
" truncate_line_length=2000,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"기본 읽기 제한: {config.default_read_limit} 줄\")\n",
|
||||
"print(f\"grep 최대 결과: {config.max_grep_results}\")\n",
|
||||
"print(f\"glob 최대 결과: {config.max_glob_results}\")\n",
|
||||
"print(f\"줄 길이 제한: {config.truncate_line_length} 자\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "strategy4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 전략 4: Context Isolation\n",
|
||||
"\n",
|
||||
"SubAgent를 통해 독립된 컨텍스트 윈도우에서 작업을 수행합니다.\n",
|
||||
"\n",
|
||||
"### 장점\n",
|
||||
"- 메인 에이전트 컨텍스트 오염 방지\n",
|
||||
"- 복잡한 작업의 격리 처리\n",
|
||||
"- 병렬 처리 가능\n",
|
||||
"\n",
|
||||
"### SubAgent 유형\n",
|
||||
"\n",
|
||||
"| 유형 | 구조 | 특징 |\n",
|
||||
"|------|------|------|\n",
|
||||
"| Simple | `{name, system_prompt, tools}` | 단일 응답 |\n",
|
||||
"| Compiled | `{name, runnable}` | 자체 DeepAgent, 다중 턴 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "isolation_demo",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent.context_strategies.isolation import (\n",
|
||||
" ContextIsolationStrategy,\n",
|
||||
" IsolationConfig,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"config = IsolationConfig(\n",
|
||||
" default_model=\"gpt-4.1\",\n",
|
||||
" include_general_purpose_agent=True,\n",
|
||||
" excluded_state_keys=(\"messages\", \"todos\", \"structured_response\"),\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"기본 모델: {config.default_model}\")\n",
|
||||
"print(f\"범용 에이전트 포함: {config.include_general_purpose_agent}\")\n",
|
||||
"print(f\"제외 상태 키: {config.excluded_state_keys}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "strategy5",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 전략 5: Context Caching\n",
|
||||
"\n",
|
||||
"Anthropic Prompt Caching을 활용하여 시스템 프롬프트와 반복 컨텍스트를 캐싱합니다.\n",
|
||||
"\n",
|
||||
"### 이점\n",
|
||||
"- API 호출 비용 절감\n",
|
||||
"- 응답 속도 향상\n",
|
||||
"- 동일 세션 내 반복 호출 최적화\n",
|
||||
"\n",
|
||||
"### 캐싱 조건\n",
|
||||
"- 최소 1,024 토큰 이상\n",
|
||||
"- `cache_control: {\"type\": \"ephemeral\"}` 마커 추가"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "caching_demo",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent.context_strategies.caching import (\n",
|
||||
" ContextCachingStrategy,\n",
|
||||
" CachingConfig,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"config = CachingConfig(\n",
|
||||
" min_cacheable_tokens=1024,\n",
|
||||
" cache_control_type=\"ephemeral\",\n",
|
||||
" enable_for_system_prompt=True,\n",
|
||||
" enable_for_tools=True,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"최소 캐싱 토큰: {config.min_cacheable_tokens:,}\")\n",
|
||||
"print(f\"캐시 컨트롤 타입: {config.cache_control_type}\")\n",
|
||||
"print(f\"시스템 프롬프트 캐싱: {config.enable_for_system_prompt}\")\n",
|
||||
"print(f\"도구 캐싱: {config.enable_for_tools}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "caching_test",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"strategy = ContextCachingStrategy(config=config)\n",
|
||||
"\n",
|
||||
"short_content = \"짧은 시스템 프롬프트\"\n",
|
||||
"long_content = \"긴 시스템 프롬프트 \" * 500\n",
|
||||
"\n",
|
||||
"print(f\"짧은 콘텐츠: {len(short_content)} 자 → 캐싱 대상: {strategy._should_cache(short_content)}\")\n",
|
||||
"print(f\"긴 콘텐츠: {len(long_content):,} 자 → 캐싱 대상: {strategy._should_cache(long_content)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "integration",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 통합 에이전트 실행\n",
|
||||
"\n",
|
||||
"5가지 전략이 모두 적용된 에이전트를 실행합니다."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "agent_create",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent import create_context_aware_agent\n",
|
||||
"\n",
|
||||
"agent = create_context_aware_agent(\n",
|
||||
" model_name=\"gpt-4.1\",\n",
|
||||
" enable_offloading=True,\n",
|
||||
" enable_reduction=True,\n",
|
||||
" enable_caching=True,\n",
|
||||
" offloading_token_limit=20000,\n",
|
||||
" reduction_threshold=0.85,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(f\"에이전트 타입: {type(agent).__name__}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "comparison_intro",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 전략 활성화/비활성화 비교 실험\n",
|
||||
"\n",
|
||||
"각 전략을 활성화/비활성화했을 때의 차이점을 실험합니다.\n",
|
||||
"\n",
|
||||
"### 실험 설계\n",
|
||||
"\n",
|
||||
"| 실험 | Offloading | Reduction | Caching | 목적 |\n",
|
||||
"|------|------------|-----------|---------|------|\n",
|
||||
"| 1. 기본 | ❌ | ❌ | ❌ | 베이스라인 |\n",
|
||||
"| 2. Offloading만 | ✅ | ❌ | ❌ | 대용량 결과 축출 효과 |\n",
|
||||
"| 3. Reduction만 | ❌ | ✅ | ❌ | 컨텍스트 압축 효과 |\n",
|
||||
"| 4. 모두 활성화 | ✅ | ✅ | ✅ | 전체 효과 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "comparison_setup",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent.context_strategies.offloading import (\n",
|
||||
" ContextOffloadingStrategy, OffloadingConfig\n",
|
||||
")\n",
|
||||
"from context_engineering_research_agent.context_strategies.reduction import (\n",
|
||||
" ContextReductionStrategy, ReductionConfig\n",
|
||||
")\n",
|
||||
"from langchain_core.messages import AIMessage, HumanMessage, ToolMessage\n",
|
||||
"\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(\"전략 비교를 위한 테스트 데이터 생성\")\n",
|
||||
"print(\"=\" * 60)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exp1_offloading",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 실험 1: Offloading 전략 효과\n",
|
||||
"\n",
|
||||
"대용량 도구 결과가 있을 때 Offloading 활성화/비활성화 비교"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "exp1_code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"small_result = \"검색 결과: 항목 1, 항목 2, 항목 3\"\n",
|
||||
"large_result = \"\\n\".join([f\"검색 결과 {i}: \" + \"상세 내용 \" * 100 for i in range(500)])\n",
|
||||
"\n",
|
||||
"print(f\"작은 결과 크기: {len(small_result):,} 자\")\n",
|
||||
"print(f\"대용량 결과 크기: {len(large_result):,} 자\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"offloading_disabled = ContextOffloadingStrategy(\n",
|
||||
" config=OffloadingConfig(token_limit_before_evict=999999999)\n",
|
||||
")\n",
|
||||
"offloading_enabled = ContextOffloadingStrategy(\n",
|
||||
" config=OffloadingConfig(token_limit_before_evict=20000)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"[Offloading 비활성화 시]\")\n",
|
||||
"print(f\" 작은 결과 축출: {offloading_disabled._should_offload(small_result)}\")\n",
|
||||
"print(f\" 대용량 결과 축출: {offloading_disabled._should_offload(large_result)}\")\n",
|
||||
"print(f\" → 대용량 결과가 컨텍스트에 그대로 포함됨\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"print(\"[Offloading 활성화 시]\")\n",
|
||||
"print(f\" 작은 결과 축출: {offloading_enabled._should_offload(small_result)}\")\n",
|
||||
"print(f\" 대용량 결과 축출: {offloading_enabled._should_offload(large_result)}\")\n",
|
||||
"print(f\" → 대용량 결과는 파일로 저장, 미리보기만 컨텍스트에 포함\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"preview = offloading_enabled._create_preview(large_result)\n",
|
||||
"print(f\"미리보기 크기: {len(preview):,} 자 (원본의 {len(preview)/len(large_result)*100:.1f}%)\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exp2_reduction",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 실험 2: Reduction 전략 효과\n",
|
||||
"\n",
|
||||
"긴 대화에서 Compaction 적용 전/후 비교"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "exp2_code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"messages_with_tools = []\n",
|
||||
"for i in range(30):\n",
|
||||
" messages_with_tools.append(HumanMessage(content=f\"질문 {i}: \" + \"내용 \" * 50))\n",
|
||||
" ai_msg = AIMessage(\n",
|
||||
" content=f\"답변 {i}: \" + \"응답 \" * 50,\n",
|
||||
" tool_calls=[{'id': f'call_{i}', 'name': 'search', 'args': {'q': 'test'}}] if i < 25 else []\n",
|
||||
" )\n",
|
||||
" messages_with_tools.append(ai_msg)\n",
|
||||
" if i < 25:\n",
|
||||
" messages_with_tools.append(ToolMessage(content=f\"도구 결과 {i}: \" + \"결과 \" * 30, tool_call_id=f'call_{i}'))\n",
|
||||
"\n",
|
||||
"reduction = ContextReductionStrategy(\n",
|
||||
" config=ReductionConfig(compaction_age_threshold=10)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"original_tokens = reduction._estimate_tokens(messages_with_tools)\n",
|
||||
"print(f\"[Reduction 비활성화 시]\")\n",
|
||||
"print(f\" 메시지 수: {len(messages_with_tools)}\")\n",
|
||||
"print(f\" 추정 토큰: {original_tokens:,}\")\n",
|
||||
"print(f\" → 모든 도구 호출/결과가 컨텍스트에 유지됨\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"compacted, result = reduction.apply_compaction(messages_with_tools)\n",
|
||||
"compacted_tokens = reduction._estimate_tokens(compacted)\n",
|
||||
"\n",
|
||||
"print(f\"[Reduction 활성화 시 - Compaction]\")\n",
|
||||
"print(f\" 메시지 수: {len(messages_with_tools)} → {len(compacted)}\")\n",
|
||||
"print(f\" 추정 토큰: {original_tokens:,} → {compacted_tokens:,}\")\n",
|
||||
"print(f\" 절약된 토큰: {result.estimated_tokens_saved:,} ({result.estimated_tokens_saved/original_tokens*100:.1f}%)\")\n",
|
||||
"print(f\" → 오래된 도구 호출/결과가 제거되어 컨텍스트 효율화\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exp3_combined",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 실험 3: 전략 조합 효과 시뮬레이션\n",
|
||||
"\n",
|
||||
"모든 전략을 함께 적용했을 때의 시너지 효과"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "exp3_code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(\"=\" * 60)\n",
|
||||
"print(\"시나리오: 복잡한 연구 작업 수행\")\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"scenario = {\n",
|
||||
" \"대화 턴 수\": 50,\n",
|
||||
" \"도구 호출 수\": 40,\n",
|
||||
" \"대용량 결과 수\": 5,\n",
|
||||
" \"평균 결과 크기\": \"100k 자\",\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"print(\"[시나리오 설정]\")\n",
|
||||
"for k, v in scenario.items():\n",
|
||||
" print(f\" {k}: {v}\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"baseline_context = 50 * 500 + 40 * 300 + 5 * 100000\n",
|
||||
"print(\"[모든 전략 비활성화 시]\")\n",
|
||||
"print(f\" 예상 컨텍스트 크기: {baseline_context:,} 자 (~{baseline_context//4:,} 토큰)\")\n",
|
||||
"print(f\" 문제: 컨텍스트 윈도우 초과 가능성 높음\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"with_offloading = 50 * 500 + 40 * 300 + 5 * 1000\n",
|
||||
"print(\"[Offloading만 활성화 시]\")\n",
|
||||
"print(f\" 예상 컨텍스트 크기: {with_offloading:,} 자 (~{with_offloading//4:,} 토큰)\")\n",
|
||||
"print(f\" 절약: {(baseline_context - with_offloading):,} 자 ({(baseline_context - with_offloading)/baseline_context*100:.1f}%)\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"with_reduction = with_offloading * 0.6\n",
|
||||
"print(\"[Offloading + Reduction 활성화 시]\")\n",
|
||||
"print(f\" 예상 컨텍스트 크기: {int(with_reduction):,} 자 (~{int(with_reduction)//4:,} 토큰)\")\n",
|
||||
"print(f\" 총 절약: {int(baseline_context - with_reduction):,} 자 ({(baseline_context - with_reduction)/baseline_context*100:.1f}%)\")\n",
|
||||
"print()\n",
|
||||
"\n",
|
||||
"print(\"[+ Caching 활성화 시 추가 효과]\")\n",
|
||||
"print(f\" 시스템 프롬프트 캐싱으로 반복 호출 비용 90% 절감\")\n",
|
||||
"print(f\" 응답 속도 향상\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exp4_live",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 실험 4: 실제 에이전트 실행 비교\n",
|
||||
"\n",
|
||||
"실제 에이전트를 다른 설정으로 생성하여 비교합니다."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "exp4_code",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from context_engineering_research_agent import create_context_aware_agent\n",
|
||||
"\n",
|
||||
"print(\"에이전트 생성 비교\")\n",
|
||||
"print(\"=\" * 60)\n",
|
||||
"\n",
|
||||
"configs = [\n",
|
||||
" {\"name\": \"기본 (모두 비활성화)\", \"offloading\": False, \"reduction\": False, \"caching\": False},\n",
|
||||
" {\"name\": \"Offloading만\", \"offloading\": True, \"reduction\": False, \"caching\": False},\n",
|
||||
" {\"name\": \"Reduction만\", \"offloading\": False, \"reduction\": True, \"caching\": False},\n",
|
||||
" {\"name\": \"모두 활성화\", \"offloading\": True, \"reduction\": True, \"caching\": True},\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"for cfg in configs:\n",
|
||||
" agent = create_context_aware_agent(\n",
|
||||
" model_name=\"gpt-4.1\",\n",
|
||||
" enable_offloading=cfg[\"offloading\"],\n",
|
||||
" enable_reduction=cfg[\"reduction\"],\n",
|
||||
" enable_caching=cfg[\"caching\"],\n",
|
||||
" )\n",
|
||||
" print(f\"\\n[{cfg['name']}]\")\n",
|
||||
" print(f\" Offloading: {'✅' if cfg['offloading'] else '❌'}\")\n",
|
||||
" print(f\" Reduction: {'✅' if cfg['reduction'] else '❌'}\")\n",
|
||||
" print(f\" Caching: {'✅' if cfg['caching'] else '❌'}\")\n",
|
||||
" print(f\" 에이전트 타입: {type(agent).__name__}\")\n",
|
||||
"\n",
|
||||
"print(\"\\n\" + \"=\" * 60)\n",
|
||||
"print(\"모든 에이전트가 성공적으로 생성되었습니다.\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "exp5_recommendation",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 권장 설정\n",
|
||||
"\n",
|
||||
"| 사용 사례 | Offloading | Reduction | Caching | 이유 |\n",
|
||||
"|----------|------------|-----------|---------|------|\n",
|
||||
"| **짧은 대화** | ❌ | ❌ | ✅ | 오버헤드 최소화 |\n",
|
||||
"| **일반 작업** | ✅ | ❌ | ✅ | 대용량 결과 대비 |\n",
|
||||
"| **장시간 연구** | ✅ | ✅ | ✅ | 모든 최적화 활용 |\n",
|
||||
"| **디버깅** | ❌ | ❌ | ❌ | 전체 컨텍스트 확인 |"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "summary",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"---\n",
|
||||
"\n",
|
||||
"## 요약\n",
|
||||
"\n",
|
||||
"### Context Engineering 5가지 전략 요약\n",
|
||||
"\n",
|
||||
"| 전략 | 트리거 조건 | 효과 |\n",
|
||||
"|------|------------|------|\n",
|
||||
"| **Offloading** | 20k 토큰 초과 | 파일로 축출 |\n",
|
||||
"| **Reduction** | 85% 사용량 초과 | Compaction/Summarization |\n",
|
||||
"| **Retrieval** | 파일 접근 필요 | grep/glob 검색 |\n",
|
||||
"| **Isolation** | 복잡한 작업 | SubAgent 위임 |\n",
|
||||
"| **Caching** | 1k+ 토큰 시스템 프롬프트 | Prompt Caching |\n",
|
||||
"\n",
|
||||
"### 핵심 인사이트\n",
|
||||
"\n",
|
||||
"1. **파일시스템 = 외부 메모리**: 컨텍스트 윈도우는 제한되어 있지만, 파일시스템은 무한\n",
|
||||
"2. **점진적 공개**: 모든 정보를 한 번에 로드하지 않고 필요할 때만 로드\n",
|
||||
"3. **격리된 실행**: SubAgent로 컨텍스트 오염 방지\n",
|
||||
"4. **자동화된 관리**: 에이전트가 직접 컨텍스트를 관리하도록 미들웨어 설계"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"name": "python",
|
||||
"version": "3.12.1"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -37,7 +37,7 @@
|
||||
## 모듈 구조
|
||||
|
||||
```
|
||||
context-engineering-more-deep_research_agent/
|
||||
context_engineering_research_agent/
|
||||
├── __init__.py # 이 파일
|
||||
├── agent.py # 메인 에이전트 (5가지 전략 통합)
|
||||
├── prompts.py # 시스템 프롬프트
|
||||
@@ -65,9 +65,10 @@ context-engineering-more-deep_research_agent/
|
||||
## 사용 예시
|
||||
|
||||
```python
|
||||
from context_engineering_research_agent import agent
|
||||
from context_engineering_research_agent import get_agent
|
||||
|
||||
# 에이전트 실행
|
||||
# 에이전트 실행 (API key 필요)
|
||||
agent = get_agent()
|
||||
result = agent.invoke({
|
||||
"messages": [{"role": "user", "content": "Context Engineering 전략 연구"}]
|
||||
})
|
||||
@@ -80,16 +81,14 @@ result = agent.invoke({
|
||||
- LangGraph: https://docs.langchain.com/oss/python/langgraph/overview
|
||||
"""
|
||||
|
||||
# 버전 정보
|
||||
__version__ = "0.1.0"
|
||||
__author__ = "Context Engineering Research Team"
|
||||
|
||||
# 주요 컴포넌트 export
|
||||
from context_engineering_more_deep_research_agent.agent import (
|
||||
agent,
|
||||
from context_engineering_research_agent.agent import (
|
||||
create_context_aware_agent,
|
||||
get_agent,
|
||||
)
|
||||
from context_engineering_more_deep_research_agent.context_strategies import (
|
||||
from context_engineering_research_agent.context_strategies import (
|
||||
ContextCachingStrategy,
|
||||
ContextIsolationStrategy,
|
||||
ContextOffloadingStrategy,
|
||||
@@ -98,10 +97,8 @@ from context_engineering_more_deep_research_agent.context_strategies import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 에이전트
|
||||
"agent",
|
||||
"get_agent",
|
||||
"create_context_aware_agent",
|
||||
# Context Engineering 전략
|
||||
"ContextOffloadingStrategy",
|
||||
"ContextReductionStrategy",
|
||||
"ContextRetrievalStrategy",
|
||||
254
context_engineering_research_agent/agent.py
Normal file
254
context_engineering_research_agent/agent.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""Context Engineering 연구용 메인 에이전트.
|
||||
|
||||
5가지 Context Engineering 전략을 명시적으로 통합한 에이전트입니다.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from deepagents import create_deep_agent
|
||||
from deepagents.backends import CompositeBackend, FilesystemBackend, StateBackend
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from context_engineering_research_agent.context_strategies import (
|
||||
ContextCachingStrategy,
|
||||
ContextOffloadingStrategy,
|
||||
ContextReductionStrategy,
|
||||
PromptCachingTelemetryMiddleware,
|
||||
ProviderType,
|
||||
detect_provider,
|
||||
)
|
||||
|
||||
CONTEXT_ENGINEERING_SYSTEM_PROMPT = """# Context Engineering 연구 에이전트
|
||||
|
||||
당신은 Context Engineering 전략을 연구하고 실험하는 에이전트입니다.
|
||||
|
||||
## Context Engineering 5가지 핵심 전략
|
||||
|
||||
### 1. Context Offloading (컨텍스트 오프로딩)
|
||||
대용량 도구 결과를 파일시스템으로 축출합니다.
|
||||
- 도구 결과가 20,000 토큰 초과 시 자동 축출
|
||||
- /large_tool_results/ 경로에 저장
|
||||
- read_file로 필요할 때 로드
|
||||
|
||||
### 2. Context Reduction (컨텍스트 축소)
|
||||
컨텍스트 윈도우 사용량이 임계값 초과 시 압축합니다.
|
||||
- Compaction: 오래된 도구 호출/결과 제거
|
||||
- Summarization: LLM이 대화 요약 (85% 초과 시)
|
||||
|
||||
### 3. Context Retrieval (컨텍스트 검색)
|
||||
필요한 정보만 선택적으로 로드합니다.
|
||||
- grep: 텍스트 패턴 검색
|
||||
- glob: 파일명 패턴 매칭
|
||||
- read_file: 부분 읽기 (offset/limit)
|
||||
|
||||
### 4. Context Isolation (컨텍스트 격리)
|
||||
SubAgent를 통해 독립된 컨텍스트에서 작업합니다.
|
||||
- task() 도구로 작업 위임
|
||||
- 메인 컨텍스트 오염 방지
|
||||
- 복잡한 작업의 격리 처리
|
||||
|
||||
### 5. Context Caching (컨텍스트 캐싱)
|
||||
시스템 프롬프트와 반복 컨텍스트를 캐싱합니다.
|
||||
- Anthropic Prompt Caching 활용
|
||||
- API 비용 절감
|
||||
- 응답 속도 향상
|
||||
|
||||
## 연구 워크플로우
|
||||
|
||||
1. 연구 요청을 분석하고 TODO 목록 작성
|
||||
2. 필요시 SubAgent에게 작업 위임
|
||||
3. 중간 결과를 파일시스템에 저장
|
||||
4. 최종 보고서 작성 (/final_report.md)
|
||||
|
||||
## 중요 원칙
|
||||
|
||||
- 대용량 결과는 파일로 저장하고 참조
|
||||
- 복잡한 작업은 SubAgent에게 위임
|
||||
- 진행 상황을 TODO로 추적
|
||||
- 인용 형식: [1], [2], [3]
|
||||
"""
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parent.parent
|
||||
RESEARCH_WORKSPACE_DIR = BASE_DIR / "research_workspace"
|
||||
|
||||
_cached_agent = None
|
||||
_cached_model = None
|
||||
|
||||
|
||||
def _infer_openrouter_model_name(model: BaseChatModel) -> str | None:
|
||||
"""OpenRouter 모델에서 모델명을 추출합니다.
|
||||
|
||||
Args:
|
||||
model: LangChain 모델 인스턴스
|
||||
|
||||
Returns:
|
||||
OpenRouter 모델명 (예: "anthropic/claude-3-sonnet") 또는 None
|
||||
"""
|
||||
if detect_provider(model) != ProviderType.OPENROUTER:
|
||||
return None
|
||||
for attr in ("model_name", "model", "model_id"):
|
||||
name = getattr(model, attr, None)
|
||||
if isinstance(name, str) and name.strip():
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def _get_fs_backend() -> FilesystemBackend:
|
||||
return FilesystemBackend(
|
||||
root_dir=RESEARCH_WORKSPACE_DIR,
|
||||
virtual_mode=True,
|
||||
max_file_size_mb=20,
|
||||
)
|
||||
|
||||
|
||||
def _get_backend_factory():
|
||||
fs_backend = _get_fs_backend()
|
||||
|
||||
def backend_factory(rt: ToolRuntime) -> CompositeBackend:
|
||||
return CompositeBackend(
|
||||
default=StateBackend(rt),
|
||||
routes={"/": fs_backend},
|
||||
)
|
||||
|
||||
return backend_factory
|
||||
|
||||
|
||||
def get_model():
|
||||
global _cached_model
|
||||
if _cached_model is None:
|
||||
_cached_model = ChatOpenAI(model="gpt-4.1", temperature=0.0)
|
||||
return _cached_model
|
||||
|
||||
|
||||
def get_agent():
|
||||
global _cached_agent
|
||||
if _cached_agent is None:
|
||||
model = get_model()
|
||||
backend_factory = _get_backend_factory()
|
||||
|
||||
offloading_strategy = ContextOffloadingStrategy(backend_factory=backend_factory)
|
||||
reduction_strategy = ContextReductionStrategy(summarization_model=model)
|
||||
openrouter_model_name = _infer_openrouter_model_name(model)
|
||||
caching_strategy = ContextCachingStrategy(
|
||||
model=model,
|
||||
openrouter_model_name=openrouter_model_name,
|
||||
)
|
||||
telemetry_middleware = PromptCachingTelemetryMiddleware()
|
||||
|
||||
_cached_agent = create_deep_agent(
|
||||
model=model,
|
||||
system_prompt=CONTEXT_ENGINEERING_SYSTEM_PROMPT,
|
||||
backend=backend_factory,
|
||||
middleware=[
|
||||
offloading_strategy,
|
||||
reduction_strategy,
|
||||
caching_strategy,
|
||||
telemetry_middleware,
|
||||
],
|
||||
)
|
||||
return _cached_agent
|
||||
|
||||
|
||||
def create_context_aware_agent(
|
||||
model: BaseChatModel | str = "gpt-4.1",
|
||||
workspace_dir: Path | str | None = None,
|
||||
enable_offloading: bool = True,
|
||||
enable_reduction: bool = True,
|
||||
enable_caching: bool = True,
|
||||
enable_cache_telemetry: bool = True,
|
||||
offloading_token_limit: int = 20000,
|
||||
reduction_threshold: float = 0.85,
|
||||
openrouter_model_name: str | None = None,
|
||||
) -> Any:
|
||||
"""Context Engineering 전략이 적용된 에이전트를 생성합니다.
|
||||
|
||||
Multi-Provider 지원: Anthropic, OpenAI, Gemini, OpenRouter 모델 사용 가능.
|
||||
Provider는 자동 감지되며, Anthropic만 cache_control 마커가 적용됩니다.
|
||||
|
||||
Args:
|
||||
model: LLM 모델 객체 또는 모델명 (기본: gpt-4.1)
|
||||
workspace_dir: 작업 디렉토리
|
||||
enable_offloading: Context Offloading 활성화
|
||||
enable_reduction: Context Reduction 활성화
|
||||
enable_caching: Context Caching 활성화
|
||||
enable_cache_telemetry: Cache 텔레메트리 수집 활성화
|
||||
offloading_token_limit: Offloading 토큰 임계값
|
||||
reduction_threshold: Reduction 트리거 임계값
|
||||
openrouter_model_name: OpenRouter 모델명 강제 지정
|
||||
|
||||
Returns:
|
||||
구성된 DeepAgent
|
||||
"""
|
||||
from context_engineering_research_agent.context_strategies.offloading import (
|
||||
OffloadingConfig,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.reduction import (
|
||||
ReductionConfig,
|
||||
)
|
||||
|
||||
if isinstance(model, str):
|
||||
llm: BaseChatModel = ChatOpenAI(model=model, temperature=0.0)
|
||||
else:
|
||||
llm = model
|
||||
|
||||
workspace = Path(workspace_dir) if workspace_dir else RESEARCH_WORKSPACE_DIR
|
||||
workspace.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
local_fs_backend = FilesystemBackend(
|
||||
root_dir=workspace,
|
||||
virtual_mode=True,
|
||||
max_file_size_mb=20,
|
||||
)
|
||||
|
||||
def local_backend_factory(rt: ToolRuntime) -> CompositeBackend:
|
||||
return CompositeBackend(
|
||||
default=StateBackend(rt),
|
||||
routes={"/": local_fs_backend},
|
||||
)
|
||||
|
||||
middlewares = []
|
||||
|
||||
if enable_offloading:
|
||||
offload_config = OffloadingConfig(
|
||||
token_limit_before_evict=offloading_token_limit
|
||||
)
|
||||
middlewares.append(
|
||||
ContextOffloadingStrategy(
|
||||
config=offload_config, backend_factory=local_backend_factory
|
||||
)
|
||||
)
|
||||
|
||||
if enable_reduction:
|
||||
reduce_config = ReductionConfig(context_threshold=reduction_threshold)
|
||||
middlewares.append(
|
||||
ContextReductionStrategy(config=reduce_config, summarization_model=llm)
|
||||
)
|
||||
|
||||
if enable_caching:
|
||||
inferred_openrouter_model_name = (
|
||||
openrouter_model_name
|
||||
if openrouter_model_name is not None
|
||||
else _infer_openrouter_model_name(llm)
|
||||
)
|
||||
middlewares.append(
|
||||
ContextCachingStrategy(
|
||||
model=llm,
|
||||
openrouter_model_name=inferred_openrouter_model_name,
|
||||
)
|
||||
)
|
||||
|
||||
if enable_cache_telemetry:
|
||||
middlewares.append(PromptCachingTelemetryMiddleware())
|
||||
|
||||
return create_deep_agent(
|
||||
model=llm,
|
||||
system_prompt=CONTEXT_ENGINEERING_SYSTEM_PROMPT,
|
||||
backend=local_backend_factory,
|
||||
middleware=middlewares,
|
||||
)
|
||||
24
context_engineering_research_agent/backends/__init__.py
Normal file
24
context_engineering_research_agent/backends/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""백엔드 모듈.
|
||||
|
||||
안전한 코드 실행을 위한 백엔드 구현체들입니다.
|
||||
"""
|
||||
|
||||
from context_engineering_research_agent.backends.docker_sandbox import (
|
||||
DockerSandboxBackend,
|
||||
)
|
||||
from context_engineering_research_agent.backends.docker_session import (
|
||||
DockerSandboxSession,
|
||||
)
|
||||
from context_engineering_research_agent.backends.docker_shared import (
|
||||
SharedDockerBackend,
|
||||
)
|
||||
from context_engineering_research_agent.backends.pyodide_sandbox import (
|
||||
PyodideSandboxBackend,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PyodideSandboxBackend",
|
||||
"SharedDockerBackend",
|
||||
"DockerSandboxBackend",
|
||||
"DockerSandboxSession",
|
||||
]
|
||||
211
context_engineering_research_agent/backends/docker_sandbox.py
Normal file
211
context_engineering_research_agent/backends/docker_sandbox.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Docker 샌드박스 백엔드 구현.
|
||||
|
||||
BaseSandbox를 상속하여 Docker 컨테이너 내에서 파일 작업 및 코드 실행을 수행합니다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import posixpath
|
||||
import tarfile
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from deepagents.backends.protocol import (
|
||||
ExecuteResponse,
|
||||
FileDownloadResponse,
|
||||
FileOperationError,
|
||||
FileUploadResponse,
|
||||
)
|
||||
from deepagents.backends.sandbox import BaseSandbox
|
||||
|
||||
from context_engineering_research_agent.backends.workspace_protocol import (
|
||||
WORKSPACE_ROOT,
|
||||
)
|
||||
|
||||
|
||||
class DockerSandboxBackend(BaseSandbox):
|
||||
"""Docker 컨테이너 기반 샌드박스 백엔드.
|
||||
|
||||
실행과 파일 작업을 모두 컨테이너 내부에서 수행하여
|
||||
SubAgent 간 공유 작업공간을 제공합니다.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
container_id: str,
|
||||
*,
|
||||
workspace_root: str = WORKSPACE_ROOT,
|
||||
docker_client: Any | None = None,
|
||||
) -> None:
|
||||
self._container_id = container_id
|
||||
self._workspace_root = workspace_root
|
||||
self._docker_client = docker_client
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""컨테이너 식별자를 반환합니다."""
|
||||
return self._container_id
|
||||
|
||||
def _get_docker_client(self) -> Any:
|
||||
if self._docker_client is None:
|
||||
try:
|
||||
import docker
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"docker 패키지가 설치되지 않았습니다: pip install docker"
|
||||
) from exc
|
||||
try:
|
||||
self._docker_client = docker.from_env()
|
||||
except Exception as exc:
|
||||
docker_exception = getattr(
|
||||
getattr(docker, "errors", None), "DockerException", None
|
||||
)
|
||||
if docker_exception and isinstance(exc, docker_exception):
|
||||
raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc
|
||||
raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc
|
||||
return self._docker_client
|
||||
|
||||
def _get_container(self) -> Any:
|
||||
client = self._get_docker_client()
|
||||
try:
|
||||
return client.containers.get(self._container_id)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"컨테이너 조회 실패: {exc}") from exc
|
||||
|
||||
def _resolve_path(self, path: str) -> str:
|
||||
if path.startswith("/"):
|
||||
return path
|
||||
return posixpath.join(self._workspace_root, path)
|
||||
|
||||
def _ensure_parent_dir(self, parent_dir: str) -> None:
|
||||
if not parent_dir:
|
||||
return
|
||||
result = self.execute(f"mkdir -p {parent_dir}")
|
||||
if result.exit_code not in (0, None):
|
||||
raise RuntimeError(result.output or "워크스페이스 디렉토리 생성 실패")
|
||||
|
||||
def _truncate_output(self, output: str) -> tuple[str, bool]:
|
||||
if len(output) <= 100000:
|
||||
return output, False
|
||||
return output[:100000] + "\n[출력이 잘렸습니다...]", True
|
||||
|
||||
def execute(self, command: str) -> ExecuteResponse:
|
||||
"""컨테이너 내부에서 명령을 실행합니다.
|
||||
|
||||
shell을 통해 실행하므로 리다이렉션(>), 파이프(|), &&, || 등을 사용할 수 있습니다.
|
||||
"""
|
||||
try:
|
||||
container = self._get_container()
|
||||
exec_result = container.exec_run(
|
||||
["sh", "-c", command],
|
||||
workdir=self._workspace_root,
|
||||
)
|
||||
raw_output = exec_result.output
|
||||
if isinstance(raw_output, bytes):
|
||||
output = raw_output.decode("utf-8", errors="replace")
|
||||
else:
|
||||
output = str(raw_output)
|
||||
output, truncated = self._truncate_output(output)
|
||||
return ExecuteResponse(
|
||||
output=output,
|
||||
exit_code=exec_result.exit_code,
|
||||
truncated=truncated,
|
||||
)
|
||||
except Exception as exc:
|
||||
return ExecuteResponse(output=f"Docker 실행 오류: {exc}", exit_code=1)
|
||||
|
||||
async def aexecute(self, command: str) -> ExecuteResponse:
|
||||
"""비동기 실행 래퍼."""
|
||||
return await asyncio.to_thread(self.execute, command)
|
||||
|
||||
def upload_files(self, files: list[tuple[str, bytes]]) -> list[FileUploadResponse]:
|
||||
"""파일을 컨테이너로 업로드합니다."""
|
||||
responses: list[FileUploadResponse] = []
|
||||
for path, content in files:
|
||||
try:
|
||||
full_path = self._resolve_path(path)
|
||||
parent_dir = posixpath.dirname(full_path)
|
||||
file_name = posixpath.basename(full_path)
|
||||
if not file_name:
|
||||
raise ValueError("업로드 경로에 파일명이 필요합니다")
|
||||
self._ensure_parent_dir(parent_dir)
|
||||
|
||||
tar_stream = io.BytesIO()
|
||||
with tarfile.open(fileobj=tar_stream, mode="w") as tar:
|
||||
info = tarfile.TarInfo(name=file_name)
|
||||
info.size = len(content)
|
||||
info.mtime = time.time()
|
||||
tar.addfile(info, io.BytesIO(content))
|
||||
tar_stream.seek(0)
|
||||
|
||||
container = self._get_container()
|
||||
container.put_archive(parent_dir or "/", tar_stream.getvalue())
|
||||
responses.append(FileUploadResponse(path=path))
|
||||
except Exception as exc:
|
||||
responses.append(
|
||||
FileUploadResponse(path=path, error=self._map_upload_error(exc))
|
||||
)
|
||||
return responses
|
||||
|
||||
async def aupload_files(
|
||||
self, files: list[tuple[str, bytes]]
|
||||
) -> list[FileUploadResponse]:
|
||||
"""비동기 업로드 래퍼."""
|
||||
return await asyncio.to_thread(self.upload_files, files)
|
||||
|
||||
def download_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
"""컨테이너에서 파일을 다운로드합니다."""
|
||||
responses: list[FileDownloadResponse] = []
|
||||
for path in paths:
|
||||
try:
|
||||
full_path = self._resolve_path(path)
|
||||
container = self._get_container()
|
||||
stream, _ = container.get_archive(full_path)
|
||||
raw = b"".join(chunk for chunk in stream)
|
||||
content = self._extract_tar_content(raw)
|
||||
if content is None:
|
||||
responses.append(
|
||||
FileDownloadResponse(path=path, error="is_directory")
|
||||
)
|
||||
continue
|
||||
responses.append(FileDownloadResponse(path=path, content=content))
|
||||
except Exception as exc:
|
||||
responses.append(
|
||||
FileDownloadResponse(path=path, error=self._map_download_error(exc))
|
||||
)
|
||||
return responses
|
||||
|
||||
async def adownload_files(self, paths: list[str]) -> list[FileDownloadResponse]:
|
||||
"""비동기 다운로드 래퍼."""
|
||||
return await asyncio.to_thread(self.download_files, paths)
|
||||
|
||||
def _extract_tar_content(self, raw: bytes) -> bytes | None:
|
||||
with tarfile.open(fileobj=io.BytesIO(raw)) as tar:
|
||||
members = [member for member in tar.getmembers() if member.isfile()]
|
||||
if not members:
|
||||
return None
|
||||
target = members[0]
|
||||
file_obj = tar.extractfile(target)
|
||||
if file_obj is None:
|
||||
return None
|
||||
return file_obj.read()
|
||||
|
||||
def _map_upload_error(self, exc: Exception) -> FileOperationError:
|
||||
message = str(exc).lower()
|
||||
if "permission" in message:
|
||||
return "permission_denied"
|
||||
if "is a directory" in message:
|
||||
return "is_directory"
|
||||
return "invalid_path"
|
||||
|
||||
def _map_download_error(self, exc: Exception) -> FileOperationError:
|
||||
message = str(exc).lower()
|
||||
if "permission" in message:
|
||||
return "permission_denied"
|
||||
if "no such file" in message or "not found" in message:
|
||||
return "file_not_found"
|
||||
if "is a directory" in message:
|
||||
return "is_directory"
|
||||
return "invalid_path"
|
||||
118
context_engineering_research_agent/backends/docker_session.py
Normal file
118
context_engineering_research_agent/backends/docker_session.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Docker 샌드박스 세션 관리 모듈.
|
||||
|
||||
요청 단위로 Docker 컨테이너를 생성하고, 모든 subagent가 동일한 /workspace를 공유합니다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from context_engineering_research_agent.backends.docker_sandbox import (
|
||||
DockerSandboxBackend,
|
||||
)
|
||||
from context_engineering_research_agent.backends.workspace_protocol import (
|
||||
META_DIR,
|
||||
SHARED_DIR,
|
||||
WORKSPACE_ROOT,
|
||||
)
|
||||
|
||||
|
||||
class DockerSandboxSession:
|
||||
"""Docker 컨테이너 라이프사이클을 관리하는 세션.
|
||||
|
||||
컨테이너는 요청 단위로 생성되며, SubAgent는 동일한 컨테이너를 공유합니다.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image: str = "python:3.11-slim",
|
||||
workspace_root: str = WORKSPACE_ROOT,
|
||||
) -> None:
|
||||
self.image = image
|
||||
self.workspace_root = workspace_root
|
||||
self._docker_client: Any | None = None
|
||||
self._container: Any | None = None
|
||||
self._backend: DockerSandboxBackend | None = None
|
||||
|
||||
def _get_docker_client(self) -> Any:
|
||||
if self._docker_client is None:
|
||||
try:
|
||||
import docker
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
"docker 패키지가 설치되지 않았습니다: pip install docker"
|
||||
) from exc
|
||||
try:
|
||||
self._docker_client = docker.from_env()
|
||||
except Exception as exc:
|
||||
docker_exception = getattr(
|
||||
getattr(docker, "errors", None), "DockerException", None
|
||||
)
|
||||
if docker_exception and isinstance(exc, docker_exception):
|
||||
raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc
|
||||
raise RuntimeError(f"Docker 클라이언트 초기화 실패: {exc}") from exc
|
||||
return self._docker_client
|
||||
|
||||
async def start(self) -> None:
|
||||
"""보안 옵션이 적용된 컨테이너를 생성합니다."""
|
||||
if self._container is not None:
|
||||
return
|
||||
|
||||
client = self._get_docker_client()
|
||||
try:
|
||||
self._container = await asyncio.to_thread(
|
||||
client.containers.run,
|
||||
self.image,
|
||||
command="tail -f /dev/null",
|
||||
detach=True,
|
||||
network_mode="none",
|
||||
cap_drop=["ALL"],
|
||||
security_opt=["no-new-privileges=true"],
|
||||
mem_limit="512m",
|
||||
pids_limit=128,
|
||||
working_dir=self.workspace_root,
|
||||
)
|
||||
await asyncio.to_thread(
|
||||
self._container.exec_run,
|
||||
f"mkdir -p {self.workspace_root}/{META_DIR} {self.workspace_root}/{SHARED_DIR}",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"Docker 컨테이너 생성 실패: {exc}") from exc
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""컨테이너를 중지하고 제거합니다."""
|
||||
container = self._container
|
||||
self._container = None
|
||||
self._backend = None
|
||||
if container is None:
|
||||
return
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(container.stop)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(container.remove)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_backend(self) -> DockerSandboxBackend:
|
||||
"""현재 세션용 백엔드를 반환합니다."""
|
||||
if self._container is None:
|
||||
raise RuntimeError("DockerSandboxSession이 시작되지 않았습니다")
|
||||
if self._backend is None:
|
||||
self._backend = DockerSandboxBackend(
|
||||
container_id=self._container.id,
|
||||
workspace_root=self.workspace_root,
|
||||
docker_client=self._get_docker_client(),
|
||||
)
|
||||
return self._backend
|
||||
|
||||
async def __aenter__(self) -> DockerSandboxSession:
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
|
||||
await self.stop()
|
||||
264
context_engineering_research_agent/backends/docker_shared.py
Normal file
264
context_engineering_research_agent/backends/docker_shared.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Docker 공유 작업공간 백엔드.
|
||||
|
||||
여러 SubAgent가 동일한 Docker 컨테이너 작업공간을 공유하는 백엔드입니다.
|
||||
|
||||
## 설계 배경
|
||||
|
||||
DeepAgents의 기본 구조에서 SubAgent들은 독립된 컨텍스트를 가지지만,
|
||||
파일시스템을 공유해야 하는 경우가 있습니다:
|
||||
- 연구 결과물 공유
|
||||
- 중간 생성물 활용
|
||||
- 협업 워크플로우
|
||||
|
||||
## 아키텍처
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Main Agent (Orchestrator) │
|
||||
│ │
|
||||
│ ┌───────────┐ ┌───────────┐ ┌───────────┐ │
|
||||
│ │SubAgent A │ │SubAgent B │ │SubAgent C │ │
|
||||
│ │(Research) │ │(Analysis) │ │(Synthesis)│ │
|
||||
│ └─────┬─────┘ └─────┬─────┘ └─────┬─────┘ │
|
||||
│ │ │ │ │
|
||||
│ └──────────────┼──────────────┘ │
|
||||
│ ▼ │
|
||||
│ SharedDockerBackend │
|
||||
│ │ │
|
||||
└───────────────────────┼─────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Docker Container │
|
||||
│ ┌─────────────────────────────────────────────────────────│
|
||||
│ │ /workspace (공유 작업 디렉토리) │
|
||||
│ │ ├── research/ (SubAgent A 출력) │
|
||||
│ │ ├── analysis/ (SubAgent B 출력) │
|
||||
│ │ └── synthesis/ (SubAgent C 출력) │
|
||||
│ └─────────────────────────────────────────────────────────│
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## 보안 고려사항
|
||||
|
||||
1. **컨테이너 격리**: 호스트 시스템과 격리
|
||||
2. **볼륨 마운트**: 필요한 디렉토리만 마운트
|
||||
3. **네트워크 정책**: 필요시 네트워크 격리
|
||||
4. **리소스 제한**: CPU/메모리 제한 설정
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class DockerConfig:
|
||||
image: str = "python:3.11-slim"
|
||||
workspace_path: str = "/workspace"
|
||||
memory_limit: str = "2g"
|
||||
cpu_limit: float = 2.0
|
||||
network_mode: str = "none"
|
||||
auto_remove: bool = True
|
||||
timeout_seconds: int = 300
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteResponse:
|
||||
output: str
|
||||
exit_code: int | None = None
|
||||
truncated: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class WriteResult:
|
||||
path: str
|
||||
error: str | None = None
|
||||
files_update: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditResult:
|
||||
path: str
|
||||
occurrences: int = 0
|
||||
error: str | None = None
|
||||
files_update: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class SharedDockerBackend:
|
||||
"""여러 SubAgent가 공유하는 Docker 작업공간 백엔드.
|
||||
|
||||
Args:
|
||||
config: Docker 설정
|
||||
container_id: 기존 컨테이너 ID (재사용 시)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DockerConfig | None = None,
|
||||
container_id: str | None = None,
|
||||
) -> None:
|
||||
self.config = config or DockerConfig()
|
||||
self._container_id = container_id
|
||||
self._docker_client: Any = None
|
||||
|
||||
def _get_docker_client(self) -> Any:
|
||||
if self._docker_client is None:
|
||||
try:
|
||||
import docker
|
||||
|
||||
self._docker_client = docker.from_env()
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"docker 패키지가 설치되지 않았습니다: pip install docker"
|
||||
)
|
||||
return self._docker_client
|
||||
|
||||
def _ensure_container(self) -> str:
|
||||
if self._container_id:
|
||||
return self._container_id
|
||||
|
||||
client = self._get_docker_client()
|
||||
container = client.containers.run(
|
||||
self.config.image,
|
||||
command="tail -f /dev/null",
|
||||
detach=True,
|
||||
mem_limit=self.config.memory_limit,
|
||||
nano_cpus=int(self.config.cpu_limit * 1e9),
|
||||
network_mode=self.config.network_mode,
|
||||
auto_remove=self.config.auto_remove,
|
||||
)
|
||||
self._container_id = container.id
|
||||
return container.id
|
||||
|
||||
def execute(self, command: str) -> ExecuteResponse:
|
||||
try:
|
||||
container_id = self._ensure_container()
|
||||
client = self._get_docker_client()
|
||||
container = client.containers.get(container_id)
|
||||
|
||||
exec_result = container.exec_run(
|
||||
command,
|
||||
workdir=self.config.workspace_path,
|
||||
)
|
||||
|
||||
output = exec_result.output.decode("utf-8", errors="replace")
|
||||
truncated = len(output) > 100000
|
||||
if truncated:
|
||||
output = output[:100000] + "\n[출력이 잘렸습니다...]"
|
||||
|
||||
return ExecuteResponse(
|
||||
output=output,
|
||||
exit_code=exec_result.exit_code,
|
||||
truncated=truncated,
|
||||
)
|
||||
except Exception as e:
|
||||
return ExecuteResponse(
|
||||
output="",
|
||||
exit_code=1,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
async def aexecute(self, command: str) -> ExecuteResponse:
|
||||
return self.execute(command)
|
||||
|
||||
def read(self, path: str, offset: int = 0, limit: int = 500) -> str:
|
||||
full_path = f"{self.config.workspace_path}{path}"
|
||||
result = self.execute(f"sed -n '{offset + 1},{offset + limit}p' {full_path}")
|
||||
|
||||
if result.error:
|
||||
return f"파일 읽기 오류: {result.error}"
|
||||
|
||||
return result.output
|
||||
|
||||
async def aread(self, path: str, offset: int = 0, limit: int = 500) -> str:
|
||||
return self.read(path, offset, limit)
|
||||
|
||||
def write(self, path: str, content: str) -> WriteResult:
|
||||
full_path = f"{self.config.workspace_path}{path}"
|
||||
|
||||
dir_path = "/".join(full_path.split("/")[:-1])
|
||||
self.execute(f"mkdir -p {dir_path}")
|
||||
|
||||
escaped_content = content.replace("'", "'\\''")
|
||||
result = self.execute(f"echo '{escaped_content}' > {full_path}")
|
||||
|
||||
if result.exit_code != 0:
|
||||
return WriteResult(path=path, error=result.output or result.error)
|
||||
|
||||
return WriteResult(path=path)
|
||||
|
||||
async def awrite(self, path: str, content: str) -> WriteResult:
|
||||
return self.write(path, content)
|
||||
|
||||
def ls_info(self, path: str) -> list[dict[str, Any]]:
|
||||
full_path = f"{self.config.workspace_path}{path}"
|
||||
result = self.execute(f"ls -la {full_path}")
|
||||
|
||||
if result.error:
|
||||
return []
|
||||
|
||||
files = []
|
||||
for line in result.output.strip().split("\n")[1:]:
|
||||
parts = line.split()
|
||||
if len(parts) >= 9:
|
||||
name = " ".join(parts[8:])
|
||||
files.append(
|
||||
{
|
||||
"path": f"{path}/{name}".replace("//", "/"),
|
||||
"is_dir": line.startswith("d"),
|
||||
}
|
||||
)
|
||||
|
||||
return files
|
||||
|
||||
async def als_info(self, path: str) -> list[dict[str, Any]]:
|
||||
return self.ls_info(path)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
if self._container_id and self._docker_client:
|
||||
try:
|
||||
container = self._docker_client.containers.get(self._container_id)
|
||||
container.stop()
|
||||
except Exception:
|
||||
pass
|
||||
self._container_id = None
|
||||
|
||||
def __enter__(self) -> "SharedDockerBackend":
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
self.cleanup()
|
||||
|
||||
|
||||
SHARED_WORKSPACE_DESIGN_DOC = """
|
||||
# 공유 작업공간 설계
|
||||
|
||||
## 문제점
|
||||
|
||||
DeepAgents의 SubAgent들은 독립된 컨텍스트를 가지지만,
|
||||
연구 워크플로우에서는 파일시스템 공유가 필요한 경우가 많습니다.
|
||||
|
||||
예시:
|
||||
1. Research Agent가 수집한 데이터를 Analysis Agent가 처리
|
||||
2. 여러 Agent가 생성한 결과물을 Synthesis Agent가 통합
|
||||
3. 중간 체크포인트를 다른 Agent가 이어서 작업
|
||||
|
||||
## 해결책: SharedDockerBackend
|
||||
|
||||
1. **단일 컨테이너**: 모든 SubAgent가 동일한 Docker 컨테이너 사용
|
||||
2. **격리된 디렉토리**: 각 SubAgent는 자신의 디렉토리에서 작업
|
||||
3. **공유 영역**: /workspace/shared 같은 공유 디렉토리 운영
|
||||
|
||||
## 장점
|
||||
|
||||
- 파일 복사 오버헤드 없음
|
||||
- 실시간 결과 공유
|
||||
- 일관된 실행 환경
|
||||
|
||||
## 단점
|
||||
|
||||
- 컨테이너 수명 관리 필요
|
||||
- 동시 접근 충돌 가능성
|
||||
- 보안 경계가 SubAgent 간에는 없음
|
||||
"""
|
||||
189
context_engineering_research_agent/backends/pyodide_sandbox.py
Normal file
189
context_engineering_research_agent/backends/pyodide_sandbox.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Pyodide 기반 WASM 샌드박스 백엔드.
|
||||
|
||||
WebAssembly 환경에서 Python 코드를 안전하게 실행하기 위한 백엔드입니다.
|
||||
|
||||
## Pyodide란?
|
||||
|
||||
Pyodide는 CPython을 WebAssembly로 컴파일한 프로젝트입니다.
|
||||
브라우저나 Node.js 환경에서 Python 코드를 실행할 수 있습니다.
|
||||
|
||||
## 보안 모델
|
||||
|
||||
WASM 샌드박스는 다음과 같은 격리를 제공합니다:
|
||||
- 호스트 파일시스템 접근 불가
|
||||
- 네트워크 접근 제한 (JavaScript API 통해서만)
|
||||
- 메모리 격리
|
||||
|
||||
## 한계
|
||||
|
||||
1. 네이티브 C 확장 라이브러리 제한적 지원
|
||||
2. 성능 오버헤드 (네이티브 대비 ~3-10x 느림)
|
||||
3. 초기 로딩 시간 (Pyodide 런타임 + 패키지)
|
||||
|
||||
## 권장 사용 사례
|
||||
|
||||
- 신뢰할 수 없는 사용자 코드 실행
|
||||
- 브라우저 기반 Python 환경
|
||||
- 격리된 데이터 분석 작업
|
||||
|
||||
## 사용 예시 (JavaScript 환경)
|
||||
|
||||
```javascript
|
||||
const pyodide = await loadPyodide();
|
||||
await pyodide.loadPackagesFromImports(pythonCode);
|
||||
const result = await pyodide.runPythonAsync(pythonCode);
|
||||
```
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class PyodideRuntime(Protocol):
|
||||
def runPython(self, code: str) -> str: ...
|
||||
async def runPythonAsync(self, code: str) -> str: ...
|
||||
def loadPackagesFromImports(self, code: str) -> None: ...
|
||||
|
||||
|
||||
@dataclass
|
||||
class PyodideConfig:
|
||||
timeout_seconds: int = 30
|
||||
max_memory_mb: int = 512
|
||||
allowed_packages: list[str] = field(
|
||||
default_factory=lambda: [
|
||||
"numpy",
|
||||
"pandas",
|
||||
"scipy",
|
||||
"matplotlib",
|
||||
"scikit-learn",
|
||||
]
|
||||
)
|
||||
enable_network: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteResponse:
|
||||
output: str
|
||||
exit_code: int | None = None
|
||||
truncated: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class PyodideSandboxBackend:
|
||||
"""Pyodide WASM 샌드박스 백엔드.
|
||||
|
||||
Note: 이 클래스는 설계 문서입니다.
|
||||
실제 구현은 JavaScript/TypeScript 환경에서 이루어집니다.
|
||||
|
||||
Python에서 직접 Pyodide를 실행하려면 별도의 subprocess나
|
||||
JS 런타임(Node.js) 연동이 필요합니다.
|
||||
"""
|
||||
|
||||
def __init__(self, config: PyodideConfig | None = None) -> None:
|
||||
self.config = config or PyodideConfig()
|
||||
self._runtime: PyodideRuntime | None = None
|
||||
|
||||
def execute(self, code: str) -> ExecuteResponse:
|
||||
"""Python 코드를 WASM 샌드박스에서 실행합니다.
|
||||
|
||||
실제 구현에서는 Node.js subprocess나
|
||||
WebWorker를 통해 Pyodide를 실행합니다.
|
||||
"""
|
||||
return ExecuteResponse(
|
||||
output="Pyodide 실행은 JavaScript 환경에서만 지원됩니다.",
|
||||
exit_code=1,
|
||||
error="NotImplemented: Python에서 직접 Pyodide 실행 불가",
|
||||
)
|
||||
|
||||
async def aexecute(self, code: str) -> ExecuteResponse:
|
||||
return self.execute(code)
|
||||
|
||||
def get_pyodide_js_code(self, python_code: str) -> str:
|
||||
"""주어진 Python 코드를 실행하는 JavaScript 코드를 생성합니다.
|
||||
|
||||
이 JavaScript 코드를 브라우저나 Node.js에서 실행하면
|
||||
Pyodide 환경에서 Python 코드가 실행됩니다.
|
||||
"""
|
||||
escaped_code = python_code.replace("`", "\\`").replace("$", "\\$")
|
||||
|
||||
return f"""
|
||||
import {{ loadPyodide }} from "pyodide";
|
||||
|
||||
async function runPythonInSandbox() {{
|
||||
const pyodide = await loadPyodide();
|
||||
|
||||
const pythonCode = `{escaped_code}`;
|
||||
|
||||
await pyodide.loadPackagesFromImports(pythonCode);
|
||||
|
||||
try {{
|
||||
const result = await pyodide.runPythonAsync(pythonCode);
|
||||
return {{ success: true, result }};
|
||||
}} catch (error) {{
|
||||
return {{ success: false, error: error.message }};
|
||||
}}
|
||||
}}
|
||||
|
||||
runPythonInSandbox().then(console.log);
|
||||
"""
|
||||
|
||||
|
||||
PYODIDE_DESIGN_DOC = """
|
||||
# Pyodide 기반 안전한 코드 실행 설계
|
||||
|
||||
## 아키텍처
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Python Agent │
|
||||
│ (LangGraph/DeepAgents) │
|
||||
└─────────────────┬───────────────────────────────────────────┘
|
||||
│ execute() 호출
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ PyodideSandboxBackend │
|
||||
│ - 코드 전처리 │
|
||||
│ - 보안 검증 │
|
||||
│ - JS 코드 생성 │
|
||||
└─────────────────┬───────────────────────────────────────────┘
|
||||
│ subprocess 또는 IPC
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Node.js Runtime │
|
||||
│ ┌───────────────────────────────────────────────────────┐ │
|
||||
│ │ WebWorker │ │
|
||||
│ │ ┌─────────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Pyodide (WASM) │ │ │
|
||||
│ │ │ - Python 인터프리터 │ │ │
|
||||
│ │ │ - 제한된 패키지 │ │ │
|
||||
│ │ │ - 격리된 메모리 │ │ │
|
||||
│ │ └─────────────────────────────────────────────────┘ │ │
|
||||
│ └───────────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
## Docker vs Pyodide 비교
|
||||
|
||||
| 측면 | Docker | Pyodide (WASM) |
|
||||
|------|--------|----------------|
|
||||
| 격리 수준 | 컨테이너 (OS 레벨) | WASM 샌드박스 |
|
||||
| 시작 시간 | ~1-2초 | ~3-5초 (최초), 이후 빠름 |
|
||||
| 메모리 | 높음 | 낮음 |
|
||||
| 패키지 지원 | 완전 | 제한적 |
|
||||
| 보안 | 높음 | 매우 높음 |
|
||||
| 호스트 접근 | 마운트 통해 가능 | 불가 |
|
||||
|
||||
## 권장 사용 시나리오
|
||||
|
||||
### Docker 사용
|
||||
- 복잡한 라이브러리 필요
|
||||
- 파일 I/O 필요
|
||||
- 장시간 실행 작업
|
||||
|
||||
### Pyodide 사용
|
||||
- 간단한 계산/분석
|
||||
- 신뢰할 수 없는 코드
|
||||
- 브라우저 환경
|
||||
- 빠른 피드백 필요
|
||||
"""
|
||||
@@ -0,0 +1,25 @@
|
||||
"""워크스페이스 경로 및 파일 통신 프로토콜 유틸리티."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import posixpath
|
||||
|
||||
WORKSPACE_ROOT = "/workspace"
|
||||
META_DIR = "_meta"
|
||||
SHARED_DIR = "shared"
|
||||
|
||||
|
||||
def _sanitize_segment(segment: str) -> str:
|
||||
return segment.strip().strip("/")
|
||||
|
||||
|
||||
def get_subagent_dir(subagent_type: str) -> str:
|
||||
"""SubAgent별 전용 작업 디렉토리를 반환합니다."""
|
||||
safe_segment = _sanitize_segment(subagent_type)
|
||||
return posixpath.join(WORKSPACE_ROOT, safe_segment)
|
||||
|
||||
|
||||
def get_result_path(subagent_type: str) -> str:
|
||||
"""SubAgent 결과 파일 경로를 반환합니다."""
|
||||
safe_segment = _sanitize_segment(subagent_type)
|
||||
return posixpath.join(WORKSPACE_ROOT, META_DIR, safe_segment, "result.json")
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Context Engineering 전략 모듈.
|
||||
|
||||
DeepAgents에서 구현된 5가지 Context Engineering 핵심 전략을
|
||||
명시적으로 분리하고 문서화한 모듈입니다.
|
||||
|
||||
각 전략은 독립적으로 사용하거나 조합하여 사용할 수 있습니다.
|
||||
"""
|
||||
|
||||
from context_engineering_research_agent.context_strategies.caching import (
|
||||
ContextCachingStrategy,
|
||||
OpenRouterSubProvider,
|
||||
ProviderType,
|
||||
detect_openrouter_sub_provider,
|
||||
detect_provider,
|
||||
requires_cache_control_marker,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.caching_telemetry import (
|
||||
CacheTelemetry,
|
||||
PromptCachingTelemetryMiddleware,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.isolation import (
|
||||
ContextIsolationStrategy,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.offloading import (
|
||||
ContextOffloadingStrategy,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.reduction import (
|
||||
ContextReductionStrategy,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.retrieval import (
|
||||
ContextRetrievalStrategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContextOffloadingStrategy",
|
||||
"ContextReductionStrategy",
|
||||
"ContextRetrievalStrategy",
|
||||
"ContextIsolationStrategy",
|
||||
"ContextCachingStrategy",
|
||||
"ProviderType",
|
||||
"OpenRouterSubProvider",
|
||||
"detect_provider",
|
||||
"detect_openrouter_sub_provider",
|
||||
"requires_cache_control_marker",
|
||||
"CacheTelemetry",
|
||||
"PromptCachingTelemetryMiddleware",
|
||||
]
|
||||
408
context_engineering_research_agent/context_strategies/caching.py
Normal file
408
context_engineering_research_agent/context_strategies/caching.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""Context Caching 전략 구현.
|
||||
|
||||
Multi-Provider Prompt Caching 전략을 구현합니다.
|
||||
각 Provider별 특성에 맞게 캐싱을 적용합니다:
|
||||
|
||||
## Direct Provider Access
|
||||
- Anthropic: Explicit caching (cache_control 마커 필요), Write 1.25x, Read 0.1x
|
||||
- OpenAI: Automatic caching (자동), Read 0.5x, 1024+ tokens
|
||||
- Gemini 2.5: Implicit caching (자동), Read 0.25x, 1028-2048+ tokens
|
||||
- Gemini 3: Implicit caching (자동), Read 0.1x (90% 할인), 1024-4096+ tokens
|
||||
|
||||
## OpenRouter (Multi-Model Gateway)
|
||||
OpenRouter는 기반 모델에 따라 다른 caching 방식 적용:
|
||||
- Anthropic Claude: Explicit (cache_control 필요), Write 1.25x, Read 0.1x
|
||||
- OpenAI: Automatic (자동), Read 0.5x
|
||||
- Google Gemini: Implicit + Explicit 지원, Read 0.25x
|
||||
- DeepSeek: Automatic (자동), Write 1x, Read 0.1x
|
||||
- Groq (Kimi K2): Automatic (자동), Read 0.25x
|
||||
- Grok (xAI): Automatic (자동), Read 0.25x
|
||||
|
||||
DeepAgents의 AnthropicPromptCachingMiddleware와 함께 사용 권장.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage, SystemMessage
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
"""LLM Provider 유형."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
GEMINI = "gemini"
|
||||
GEMINI_3 = "gemini_3"
|
||||
OPENROUTER = "openrouter"
|
||||
DEEPSEEK = "deepseek"
|
||||
GROQ = "groq"
|
||||
GROK = "grok"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class OpenRouterSubProvider(Enum):
|
||||
"""OpenRouter를 통해 접근하는 기반 모델 Provider."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
GEMINI = "gemini"
|
||||
DEEPSEEK = "deepseek"
|
||||
GROQ = "groq"
|
||||
GROK = "grok"
|
||||
META_LLAMA = "meta-llama"
|
||||
MISTRAL = "mistral"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
PROVIDERS_REQUIRING_CACHE_CONTROL = {
|
||||
ProviderType.ANTHROPIC,
|
||||
OpenRouterSubProvider.ANTHROPIC,
|
||||
}
|
||||
|
||||
PROVIDERS_WITH_AUTOMATIC_CACHING = {
|
||||
ProviderType.OPENAI,
|
||||
ProviderType.GEMINI,
|
||||
ProviderType.GEMINI_3,
|
||||
ProviderType.DEEPSEEK,
|
||||
ProviderType.GROQ,
|
||||
ProviderType.GROK,
|
||||
OpenRouterSubProvider.OPENAI,
|
||||
OpenRouterSubProvider.DEEPSEEK,
|
||||
OpenRouterSubProvider.GROQ,
|
||||
OpenRouterSubProvider.GROK,
|
||||
}
|
||||
|
||||
|
||||
def detect_provider(model: BaseChatModel | None) -> ProviderType:
|
||||
"""모델 객체에서 Provider 유형을 감지합니다."""
|
||||
if model is None:
|
||||
return ProviderType.UNKNOWN
|
||||
|
||||
class_name = model.__class__.__name__.lower()
|
||||
module_name = model.__class__.__module__.lower()
|
||||
|
||||
if "anthropic" in class_name or "anthropic" in module_name:
|
||||
return ProviderType.ANTHROPIC
|
||||
|
||||
if "openai" in class_name or "openai" in module_name:
|
||||
base_url = _get_base_url(model)
|
||||
if "openrouter" in base_url:
|
||||
return ProviderType.OPENROUTER
|
||||
return ProviderType.OPENAI
|
||||
|
||||
if "google" in class_name or "gemini" in class_name or "google" in module_name:
|
||||
model_name = _get_model_name(model)
|
||||
if "gemini-3" in model_name or "gemini/3" in model_name:
|
||||
return ProviderType.GEMINI_3
|
||||
return ProviderType.GEMINI
|
||||
|
||||
if "deepseek" in class_name or "deepseek" in module_name:
|
||||
return ProviderType.DEEPSEEK
|
||||
|
||||
if "groq" in class_name or "groq" in module_name:
|
||||
return ProviderType.GROQ
|
||||
|
||||
return ProviderType.UNKNOWN
|
||||
|
||||
|
||||
def _get_base_url(model: BaseChatModel) -> str:
|
||||
for attr in ("openai_api_base", "base_url", "api_base"):
|
||||
if hasattr(model, attr):
|
||||
url = getattr(model, attr, "") or ""
|
||||
if url:
|
||||
return url.lower()
|
||||
return ""
|
||||
|
||||
|
||||
def _get_model_name(model: BaseChatModel) -> str:
|
||||
for attr in ("model_name", "model", "model_id"):
|
||||
if hasattr(model, attr):
|
||||
name = getattr(model, attr, "") or ""
|
||||
if name:
|
||||
return name.lower()
|
||||
return ""
|
||||
|
||||
|
||||
def detect_openrouter_sub_provider(model_name: str) -> OpenRouterSubProvider:
|
||||
"""OpenRouter 모델명에서 기반 Provider를 감지합니다.
|
||||
|
||||
OpenRouter 모델명 패턴: "provider/model-name" (예: "anthropic/claude-3-sonnet")
|
||||
"""
|
||||
name_lower = model_name.lower()
|
||||
|
||||
if "anthropic" in name_lower or "claude" in name_lower:
|
||||
return OpenRouterSubProvider.ANTHROPIC
|
||||
if "openai" in name_lower or "gpt" in name_lower or name_lower.startswith("o1"):
|
||||
return OpenRouterSubProvider.OPENAI
|
||||
if "google" in name_lower or "gemini" in name_lower:
|
||||
return OpenRouterSubProvider.GEMINI
|
||||
if "deepseek" in name_lower:
|
||||
return OpenRouterSubProvider.DEEPSEEK
|
||||
if "groq" in name_lower or "kimi" in name_lower:
|
||||
return OpenRouterSubProvider.GROQ
|
||||
if "grok" in name_lower or "xai" in name_lower:
|
||||
return OpenRouterSubProvider.GROK
|
||||
if "meta" in name_lower or "llama" in name_lower:
|
||||
return OpenRouterSubProvider.META_LLAMA
|
||||
if "mistral" in name_lower:
|
||||
return OpenRouterSubProvider.MISTRAL
|
||||
|
||||
return OpenRouterSubProvider.UNKNOWN
|
||||
|
||||
|
||||
def requires_cache_control_marker(
|
||||
provider: ProviderType,
|
||||
sub_provider: OpenRouterSubProvider | None = None,
|
||||
) -> bool:
|
||||
"""해당 Provider가 cache_control 마커를 필요로 하는지 확인합니다.
|
||||
|
||||
Anthropic (직접 또는 OpenRouter 경유) 만 True 반환.
|
||||
"""
|
||||
if provider == ProviderType.ANTHROPIC:
|
||||
return True
|
||||
if (
|
||||
provider == ProviderType.OPENROUTER
|
||||
and sub_provider == OpenRouterSubProvider.ANTHROPIC
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachingConfig:
|
||||
"""Context Caching 설정.
|
||||
|
||||
Attributes:
|
||||
min_cacheable_tokens: 캐싱할 최소 토큰 수 (기본값: 1024)
|
||||
cache_control_type: 캐시 제어 유형 (기본값: "ephemeral")
|
||||
enable_for_system_prompt: 시스템 프롬프트 캐싱 활성화 (기본값: True)
|
||||
enable_for_tools: 도구 정의 캐싱 활성화 (기본값: True)
|
||||
"""
|
||||
|
||||
min_cacheable_tokens: int = 1024
|
||||
cache_control_type: str = "ephemeral"
|
||||
enable_for_system_prompt: bool = True
|
||||
enable_for_tools: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachingResult:
|
||||
"""Context Caching 결과.
|
||||
|
||||
Attributes:
|
||||
was_cached: 캐싱이 적용되었는지 여부
|
||||
cached_content_type: 캐싱된 컨텐츠 유형 (예: "system_prompt")
|
||||
estimated_tokens_cached: 캐싱된 추정 토큰 수
|
||||
"""
|
||||
|
||||
was_cached: bool
|
||||
cached_content_type: str | None = None
|
||||
estimated_tokens_cached: int = 0
|
||||
|
||||
|
||||
class ContextCachingStrategy(AgentMiddleware):
|
||||
"""Multi-Provider Prompt Caching 전략.
|
||||
|
||||
Anthropic (직접 또는 OpenRouter 경유)만 cache_control 마커를 적용하고,
|
||||
OpenAI/Gemini/DeepSeek/Groq/Grok은 자동 캐싱이므로 pass-through합니다.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: CachingConfig | None = None,
|
||||
model: BaseChatModel | None = None,
|
||||
openrouter_model_name: str | None = None,
|
||||
) -> None:
|
||||
self.config = config or CachingConfig()
|
||||
self._model = model
|
||||
self._provider: ProviderType | None = None
|
||||
self._sub_provider: OpenRouterSubProvider | None = None
|
||||
self._openrouter_model_name = openrouter_model_name
|
||||
|
||||
def set_model(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
openrouter_model_name: str | None = None,
|
||||
) -> None:
|
||||
"""런타임에 모델을 설정합니다."""
|
||||
self._model = model
|
||||
self._provider = None
|
||||
self._sub_provider = None
|
||||
if openrouter_model_name:
|
||||
self._openrouter_model_name = openrouter_model_name
|
||||
|
||||
@property
|
||||
def provider(self) -> ProviderType:
|
||||
if self._provider is None:
|
||||
self._provider = detect_provider(self._model)
|
||||
return self._provider
|
||||
|
||||
@property
|
||||
def sub_provider(self) -> OpenRouterSubProvider | None:
|
||||
if self.provider != ProviderType.OPENROUTER:
|
||||
return None
|
||||
if self._sub_provider is None and self._openrouter_model_name:
|
||||
self._sub_provider = detect_openrouter_sub_provider(
|
||||
self._openrouter_model_name
|
||||
)
|
||||
return self._sub_provider
|
||||
|
||||
@property
|
||||
def should_apply_cache_markers(self) -> bool:
|
||||
return requires_cache_control_marker(self.provider, self.sub_provider)
|
||||
|
||||
def _add_cache_control(self, content: Any) -> Any:
|
||||
if isinstance(content, str):
|
||||
return {
|
||||
"type": "text",
|
||||
"text": content,
|
||||
"cache_control": {"type": self.config.cache_control_type},
|
||||
}
|
||||
elif isinstance(content, dict) and content.get("type") == "text":
|
||||
return {
|
||||
**content,
|
||||
"cache_control": {"type": self.config.cache_control_type},
|
||||
}
|
||||
elif isinstance(content, list):
|
||||
if not content:
|
||||
return content
|
||||
result = list(content)
|
||||
last_item = result[-1]
|
||||
if isinstance(last_item, dict):
|
||||
result[-1] = {
|
||||
**last_item,
|
||||
"cache_control": {"type": self.config.cache_control_type},
|
||||
}
|
||||
return result
|
||||
return content
|
||||
|
||||
def _process_system_message(self, message: SystemMessage) -> SystemMessage:
|
||||
cached_content = self._add_cache_control(message.content)
|
||||
# Ensure cached_content is a list of dicts for SystemMessage compatibility
|
||||
if isinstance(cached_content, str):
|
||||
cached_content = [{"type": "text", "text": cached_content}]
|
||||
elif isinstance(cached_content, dict):
|
||||
cached_content = [cached_content]
|
||||
# Type ignore is needed due to complex SystemMessage content type requirements
|
||||
return SystemMessage(content=cached_content) # type: ignore[arg-type]
|
||||
|
||||
def _estimate_tokens(self, content: Any) -> int:
|
||||
if isinstance(content, str):
|
||||
return len(content) // 4
|
||||
elif isinstance(content, list):
|
||||
return sum(self._estimate_tokens(item) for item in content)
|
||||
elif isinstance(content, dict):
|
||||
return self._estimate_tokens(content.get("text", ""))
|
||||
return 0
|
||||
|
||||
def _should_cache(self, content: Any) -> bool:
|
||||
estimated_tokens = self._estimate_tokens(content)
|
||||
return estimated_tokens >= self.config.min_cacheable_tokens
|
||||
|
||||
def apply_caching(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
model: BaseChatModel | None = None,
|
||||
openrouter_model_name: str | None = None,
|
||||
) -> tuple[list[BaseMessage], CachingResult]:
|
||||
"""메시지 리스트에 캐싱을 적용합니다.
|
||||
|
||||
Anthropic (직접 또는 OpenRouter 경유)만 cache_control 마커를 적용합니다.
|
||||
다른 provider는 자동 캐싱이므로 메시지를 변형하지 않습니다.
|
||||
|
||||
Args:
|
||||
messages: 처리할 메시지 리스트
|
||||
model: LLM 모델 (Provider 감지용)
|
||||
openrouter_model_name: OpenRouter 사용 시 모델명 (예: "anthropic/claude-3-sonnet")
|
||||
|
||||
Returns:
|
||||
캐싱이 적용된 메시지 리스트와 캐싱 결과 튜플
|
||||
"""
|
||||
if model is not None:
|
||||
self.set_model(model, openrouter_model_name)
|
||||
|
||||
if not messages:
|
||||
return messages, CachingResult(was_cached=False)
|
||||
|
||||
if not self.should_apply_cache_markers:
|
||||
provider_info = self.provider.value
|
||||
if self.provider == ProviderType.OPENROUTER and self.sub_provider:
|
||||
provider_info = f"openrouter/{self.sub_provider.value}"
|
||||
return messages, CachingResult(
|
||||
was_cached=False,
|
||||
cached_content_type=f"auto_cached_by_{provider_info}",
|
||||
)
|
||||
|
||||
result_messages = list(messages)
|
||||
cached = False
|
||||
cached_type = None
|
||||
tokens_cached = 0
|
||||
|
||||
for i, msg in enumerate(result_messages):
|
||||
if isinstance(msg, SystemMessage):
|
||||
if self.config.enable_for_system_prompt and self._should_cache(
|
||||
msg.content
|
||||
):
|
||||
result_messages[i] = self._process_system_message(msg)
|
||||
cached = True
|
||||
cached_type = "system_prompt"
|
||||
tokens_cached = self._estimate_tokens(msg.content)
|
||||
break
|
||||
|
||||
return result_messages, CachingResult(
|
||||
was_cached=cached,
|
||||
cached_content_type=cached_type,
|
||||
estimated_tokens_cached=tokens_cached,
|
||||
)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""동기 모델 호출을 래핑합니다 (기본 동작).
|
||||
|
||||
Args:
|
||||
request: 모델 요청
|
||||
handler: 다음 핸들러 함수
|
||||
|
||||
Returns:
|
||||
모델 응답
|
||||
"""
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""비동기 모델 호출을 래핑합니다 (기본 동작).
|
||||
|
||||
Args:
|
||||
request: 모델 요청
|
||||
handler: 다음 핸들러 함수
|
||||
|
||||
Returns:
|
||||
모델 응답
|
||||
"""
|
||||
return await handler(request)
|
||||
|
||||
|
||||
CACHING_SYSTEM_PROMPT = """## Context Caching
|
||||
|
||||
시스템 프롬프트와 도구 정의가 자동으로 캐싱됩니다.
|
||||
|
||||
이점:
|
||||
- API 호출 비용 절감
|
||||
- 응답 속도 향상
|
||||
- 동일 세션 내 반복 호출 최적화
|
||||
"""
|
||||
@@ -0,0 +1,244 @@
|
||||
"""Prompt Caching Telemetry Middleware.
|
||||
|
||||
모든 Provider의 cache 사용량을 모니터링하고 로깅합니다.
|
||||
요청 변형 없이 응답의 cache 관련 메타데이터만 수집합니다.
|
||||
|
||||
Provider별 캐시 메타데이터 위치:
|
||||
- Anthropic: cache_read_input_tokens, cache_creation_input_tokens
|
||||
- OpenAI: cached_tokens (usage.prompt_tokens_details)
|
||||
- Gemini 2.5/3: cached_content_token_count
|
||||
- DeepSeek: cache_hit_tokens, cache_miss_tokens
|
||||
- OpenRouter: 기반 모델의 메타데이터 형식 따름
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
from context_engineering_research_agent.context_strategies.caching import (
|
||||
ProviderType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheTelemetry:
|
||||
"""Provider별 캐시 사용량 데이터."""
|
||||
|
||||
provider: ProviderType
|
||||
cache_read_tokens: int = 0
|
||||
cache_write_tokens: int = 0
|
||||
total_input_tokens: int = 0
|
||||
cache_hit_ratio: float = 0.0
|
||||
raw_metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
def extract_anthropic_cache_metrics(response: ModelResponse) -> CacheTelemetry:
|
||||
"""Anthropic 응답에서 캐시 메트릭을 추출합니다."""
|
||||
usage = getattr(response, "usage_metadata", {}) or {}
|
||||
response_meta = getattr(response, "response_metadata", {}) or {}
|
||||
usage_from_meta = response_meta.get("usage", {})
|
||||
|
||||
cache_read = usage_from_meta.get("cache_read_input_tokens", 0)
|
||||
cache_creation = usage_from_meta.get("cache_creation_input_tokens", 0)
|
||||
input_tokens = usage.get("input_tokens", 0) or usage_from_meta.get(
|
||||
"input_tokens", 0
|
||||
)
|
||||
|
||||
hit_ratio = cache_read / input_tokens if input_tokens > 0 else 0.0
|
||||
|
||||
return CacheTelemetry(
|
||||
provider=ProviderType.ANTHROPIC,
|
||||
cache_read_tokens=cache_read,
|
||||
cache_write_tokens=cache_creation,
|
||||
total_input_tokens=input_tokens,
|
||||
cache_hit_ratio=hit_ratio,
|
||||
raw_metadata={"usage": usage, "response_metadata": response_meta},
|
||||
)
|
||||
|
||||
|
||||
def extract_openai_cache_metrics(response: ModelResponse) -> CacheTelemetry:
|
||||
"""OpenAI 응답에서 캐시 메트릭을 추출합니다."""
|
||||
usage = getattr(response, "usage_metadata", {}) or {}
|
||||
response_meta = getattr(response, "response_metadata", {}) or {}
|
||||
|
||||
token_usage = response_meta.get("token_usage", {})
|
||||
prompt_details = token_usage.get("prompt_tokens_details", {})
|
||||
cached_tokens = prompt_details.get("cached_tokens", 0)
|
||||
input_tokens = usage.get("input_tokens", 0) or token_usage.get("prompt_tokens", 0)
|
||||
|
||||
hit_ratio = cached_tokens / input_tokens if input_tokens > 0 else 0.0
|
||||
|
||||
return CacheTelemetry(
|
||||
provider=ProviderType.OPENAI,
|
||||
cache_read_tokens=cached_tokens,
|
||||
cache_write_tokens=0,
|
||||
total_input_tokens=input_tokens,
|
||||
cache_hit_ratio=hit_ratio,
|
||||
raw_metadata={"usage": usage, "token_usage": token_usage},
|
||||
)
|
||||
|
||||
|
||||
def extract_gemini_cache_metrics(
|
||||
response: ModelResponse, provider: ProviderType = ProviderType.GEMINI
|
||||
) -> CacheTelemetry:
|
||||
"""Gemini 2.5/3 응답에서 캐시 메트릭을 추출합니다."""
|
||||
usage = getattr(response, "usage_metadata", {}) or {}
|
||||
response_meta = getattr(response, "response_metadata", {}) or {}
|
||||
|
||||
cached_tokens = response_meta.get("cached_content_token_count", 0)
|
||||
input_tokens = usage.get("input_tokens", 0) or response_meta.get(
|
||||
"prompt_token_count", 0
|
||||
)
|
||||
|
||||
hit_ratio = cached_tokens / input_tokens if input_tokens > 0 else 0.0
|
||||
|
||||
return CacheTelemetry(
|
||||
provider=provider,
|
||||
cache_read_tokens=cached_tokens,
|
||||
cache_write_tokens=0,
|
||||
total_input_tokens=input_tokens,
|
||||
cache_hit_ratio=hit_ratio,
|
||||
raw_metadata={"usage": usage, "response_metadata": response_meta},
|
||||
)
|
||||
|
||||
|
||||
def extract_deepseek_cache_metrics(response: ModelResponse) -> CacheTelemetry:
|
||||
"""DeepSeek 응답에서 캐시 메트릭을 추출합니다."""
|
||||
usage = getattr(response, "usage_metadata", {}) or {}
|
||||
response_meta = getattr(response, "response_metadata", {}) or {}
|
||||
|
||||
cache_hit = response_meta.get("cache_hit_tokens", 0)
|
||||
cache_miss = response_meta.get("cache_miss_tokens", 0)
|
||||
input_tokens = usage.get("input_tokens", 0) or (cache_hit + cache_miss)
|
||||
|
||||
hit_ratio = cache_hit / input_tokens if input_tokens > 0 else 0.0
|
||||
|
||||
return CacheTelemetry(
|
||||
provider=ProviderType.DEEPSEEK,
|
||||
cache_read_tokens=cache_hit,
|
||||
cache_write_tokens=cache_miss,
|
||||
total_input_tokens=input_tokens,
|
||||
cache_hit_ratio=hit_ratio,
|
||||
raw_metadata={"usage": usage, "response_metadata": response_meta},
|
||||
)
|
||||
|
||||
|
||||
def extract_cache_telemetry(
|
||||
response: ModelResponse, provider: ProviderType
|
||||
) -> CacheTelemetry:
|
||||
"""응답에서 Provider별 캐시 텔레메트리를 추출합니다."""
|
||||
extractors: dict[ProviderType, Callable[[ModelResponse], CacheTelemetry]] = {
|
||||
ProviderType.ANTHROPIC: extract_anthropic_cache_metrics,
|
||||
ProviderType.OPENAI: extract_openai_cache_metrics,
|
||||
ProviderType.GEMINI: extract_gemini_cache_metrics,
|
||||
ProviderType.GEMINI_3: lambda r: extract_gemini_cache_metrics(
|
||||
r, ProviderType.GEMINI_3
|
||||
),
|
||||
ProviderType.DEEPSEEK: extract_deepseek_cache_metrics,
|
||||
}
|
||||
|
||||
extractor = extractors.get(provider)
|
||||
if extractor:
|
||||
return extractor(response)
|
||||
|
||||
return CacheTelemetry(
|
||||
provider=provider,
|
||||
raw_metadata={
|
||||
"usage": getattr(response, "usage_metadata", {}),
|
||||
"response_metadata": getattr(response, "response_metadata", {}),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PromptCachingTelemetryMiddleware(AgentMiddleware):
|
||||
"""모든 Provider의 캐시 사용량을 모니터링하는 Middleware.
|
||||
|
||||
요청을 변형하지 않고, 응답의 cache 관련 메타데이터만 수집/로깅합니다.
|
||||
"""
|
||||
|
||||
def __init__(self, log_level: int = logging.DEBUG) -> None:
|
||||
self._log_level = log_level
|
||||
self._telemetry_history: list[CacheTelemetry] = []
|
||||
|
||||
@property
|
||||
def telemetry_history(self) -> list[CacheTelemetry]:
|
||||
return self._telemetry_history
|
||||
|
||||
def get_aggregate_stats(self) -> dict[str, Any]:
|
||||
if not self._telemetry_history:
|
||||
return {"total_calls": 0, "total_cache_read_tokens": 0}
|
||||
|
||||
total_read = sum(t.cache_read_tokens for t in self._telemetry_history)
|
||||
total_write = sum(t.cache_write_tokens for t in self._telemetry_history)
|
||||
total_input = sum(t.total_input_tokens for t in self._telemetry_history)
|
||||
|
||||
return {
|
||||
"total_calls": len(self._telemetry_history),
|
||||
"total_cache_read_tokens": total_read,
|
||||
"total_cache_write_tokens": total_write,
|
||||
"total_input_tokens": total_input,
|
||||
"overall_cache_hit_ratio": total_read / total_input if total_input else 0.0,
|
||||
}
|
||||
|
||||
def _log_telemetry(self, telemetry: CacheTelemetry) -> None:
|
||||
if telemetry.cache_read_tokens > 0 or telemetry.cache_write_tokens > 0:
|
||||
logger.log(
|
||||
self._log_level,
|
||||
f"[CacheTelemetry] {telemetry.provider.value}: "
|
||||
f"read={telemetry.cache_read_tokens}, "
|
||||
f"write={telemetry.cache_write_tokens}, "
|
||||
f"hit_ratio={telemetry.cache_hit_ratio:.2%}",
|
||||
)
|
||||
|
||||
def _process_response(self, response: ModelResponse) -> ModelResponse:
|
||||
model = getattr(response, "response_metadata", {}).get("model", "")
|
||||
provider = self._detect_provider_from_response(response, model)
|
||||
telemetry = extract_cache_telemetry(response, provider)
|
||||
self._telemetry_history.append(telemetry)
|
||||
self._log_telemetry(telemetry)
|
||||
return response
|
||||
|
||||
def _detect_provider_from_response(
|
||||
self, response: ModelResponse, model_name: str
|
||||
) -> ProviderType:
|
||||
model_lower = model_name.lower()
|
||||
if "claude" in model_lower or "anthropic" in model_lower:
|
||||
return ProviderType.ANTHROPIC
|
||||
if "gpt" in model_lower or "o1" in model_lower or "o3" in model_lower:
|
||||
return ProviderType.OPENAI
|
||||
if "gemini-3" in model_lower or "gemini/3" in model_lower:
|
||||
return ProviderType.GEMINI_3
|
||||
if "gemini" in model_lower:
|
||||
return ProviderType.GEMINI
|
||||
if "deepseek" in model_lower:
|
||||
return ProviderType.DEEPSEEK
|
||||
if "groq" in model_lower or "kimi" in model_lower:
|
||||
return ProviderType.GROQ
|
||||
if "grok" in model_lower:
|
||||
return ProviderType.GROK
|
||||
return ProviderType.UNKNOWN
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
response = handler(request)
|
||||
return self._process_response(response)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
response = await handler(request)
|
||||
return self._process_response(response)
|
||||
@@ -0,0 +1,234 @@
|
||||
"""Context Isolation 전략 구현.
|
||||
|
||||
SubAgent를 통해 독립된 컨텍스트 윈도우에서 작업을 수행하여
|
||||
메인 에이전트의 컨텍스트를 오염시키지 않는 전략입니다.
|
||||
|
||||
DeepAgents의 SubAgentMiddleware에서 task() 도구로 구현되어 있습니다.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain.tools import BaseTool, ToolRuntime
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
from langchain_core.runnables import Runnable
|
||||
from langchain_core.tools import StructuredTool
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
class SubAgentSpec(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
tools: Sequence[BaseTool | Callable | dict[str, Any]]
|
||||
model: NotRequired[str | BaseChatModel]
|
||||
middleware: NotRequired[list[AgentMiddleware]]
|
||||
|
||||
|
||||
class CompiledSubAgentSpec(TypedDict):
|
||||
name: str
|
||||
description: str
|
||||
runnable: Runnable
|
||||
|
||||
|
||||
@dataclass
|
||||
class IsolationConfig:
|
||||
default_model: str | BaseChatModel = "gpt-4.1"
|
||||
include_general_purpose_agent: bool = True
|
||||
excluded_state_keys: tuple[str, ...] = ("messages", "todos", "structured_response")
|
||||
|
||||
|
||||
@dataclass
|
||||
class IsolationResult:
|
||||
subagent_name: str
|
||||
was_successful: bool
|
||||
result_length: int
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ContextIsolationStrategy(AgentMiddleware):
|
||||
"""SubAgent를 통한 Context Isolation 구현.
|
||||
|
||||
Args:
|
||||
config: Isolation 설정
|
||||
subagents: SubAgent 명세 목록
|
||||
agent_factory: 에이전트 생성 팩토리 함수
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: IsolationConfig | None = None,
|
||||
subagents: list[SubAgentSpec | CompiledSubAgentSpec] | None = None,
|
||||
agent_factory: Callable[..., Runnable] | None = None,
|
||||
) -> None:
|
||||
self.config = config or IsolationConfig()
|
||||
self._subagents = subagents or []
|
||||
self._agent_factory = agent_factory
|
||||
self._compiled_agents: dict[str, Runnable] = {}
|
||||
self.tools = [self._create_task_tool()]
|
||||
|
||||
def _compile_subagents(self) -> dict[str, Runnable]:
|
||||
if self._compiled_agents:
|
||||
return self._compiled_agents
|
||||
|
||||
agents: dict[str, Runnable] = {}
|
||||
|
||||
for spec in self._subagents:
|
||||
if "runnable" in spec:
|
||||
compiled = spec # type: ignore
|
||||
agents[compiled["name"]] = compiled["runnable"]
|
||||
elif self._agent_factory:
|
||||
simple = spec # type: ignore
|
||||
agents[simple["name"]] = self._agent_factory(
|
||||
model=simple.get("model", self.config.default_model),
|
||||
system_prompt=simple["system_prompt"],
|
||||
tools=simple["tools"],
|
||||
middleware=simple.get("middleware", []),
|
||||
)
|
||||
|
||||
self._compiled_agents = agents
|
||||
return agents
|
||||
|
||||
def _get_subagent_descriptions(self) -> str:
|
||||
descriptions = []
|
||||
for spec in self._subagents:
|
||||
descriptions.append(f"- {spec['name']}: {spec['description']}")
|
||||
return "\n".join(descriptions)
|
||||
|
||||
def _prepare_subagent_state(
|
||||
self, state: dict[str, Any], task_description: str
|
||||
) -> dict[str, Any]:
|
||||
filtered = {
|
||||
k: v for k, v in state.items() if k not in self.config.excluded_state_keys
|
||||
}
|
||||
filtered["messages"] = [HumanMessage(content=task_description)]
|
||||
return filtered
|
||||
|
||||
def _create_task_tool(self) -> BaseTool:
|
||||
strategy = self
|
||||
|
||||
def task(
|
||||
description: str,
|
||||
subagent_type: str,
|
||||
runtime: ToolRuntime,
|
||||
) -> str | Command:
|
||||
agents = strategy._compile_subagents()
|
||||
|
||||
if subagent_type not in agents:
|
||||
allowed = ", ".join(f"`{k}`" for k in agents)
|
||||
return f"SubAgent '{subagent_type}'가 존재하지 않습니다. 사용 가능: {allowed}"
|
||||
|
||||
subagent = agents[subagent_type]
|
||||
subagent_state = strategy._prepare_subagent_state(
|
||||
runtime.state, description
|
||||
)
|
||||
|
||||
result = subagent.invoke(subagent_state, runtime.config)
|
||||
|
||||
final_message = result["messages"][-1].text.rstrip()
|
||||
|
||||
state_update = {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k not in strategy.config.excluded_state_keys
|
||||
}
|
||||
|
||||
return Command(
|
||||
update={
|
||||
**state_update,
|
||||
"messages": [
|
||||
ToolMessage(final_message, tool_call_id=runtime.tool_call_id)
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
async def atask(
|
||||
description: str,
|
||||
subagent_type: str,
|
||||
runtime: ToolRuntime,
|
||||
) -> str | Command:
|
||||
agents = strategy._compile_subagents()
|
||||
|
||||
if subagent_type not in agents:
|
||||
allowed = ", ".join(f"`{k}`" for k in agents)
|
||||
return f"SubAgent '{subagent_type}'가 존재하지 않습니다. 사용 가능: {allowed}"
|
||||
|
||||
subagent = agents[subagent_type]
|
||||
subagent_state = strategy._prepare_subagent_state(
|
||||
runtime.state, description
|
||||
)
|
||||
|
||||
result = await subagent.ainvoke(subagent_state, runtime.config)
|
||||
|
||||
final_message = result["messages"][-1].text.rstrip()
|
||||
|
||||
state_update = {
|
||||
k: v
|
||||
for k, v in result.items()
|
||||
if k not in strategy.config.excluded_state_keys
|
||||
}
|
||||
|
||||
return Command(
|
||||
update={
|
||||
**state_update,
|
||||
"messages": [
|
||||
ToolMessage(final_message, tool_call_id=runtime.tool_call_id)
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
subagent_list = self._get_subagent_descriptions()
|
||||
|
||||
return StructuredTool.from_function(
|
||||
name="task",
|
||||
func=task,
|
||||
coroutine=atask,
|
||||
description=f"""SubAgent에게 작업을 위임합니다.
|
||||
|
||||
사용 가능한 SubAgent:
|
||||
{subagent_list}
|
||||
|
||||
사용법:
|
||||
- description: 위임할 작업 상세 설명
|
||||
- subagent_type: 사용할 SubAgent 이름
|
||||
|
||||
SubAgent는 독립된 컨텍스트에서 실행되어 메인 에이전트의 컨텍스트를 오염시키지 않습니다.
|
||||
복잡하고 다단계 작업에 적합합니다.""",
|
||||
)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
return await handler(request)
|
||||
|
||||
|
||||
ISOLATION_SYSTEM_PROMPT = """## Context Isolation (task 도구)
|
||||
|
||||
task 도구로 SubAgent에게 작업을 위임할 수 있습니다.
|
||||
|
||||
장점:
|
||||
- 독립된 컨텍스트 윈도우
|
||||
- 메인 컨텍스트 오염 방지
|
||||
- 복잡한 작업의 격리 처리
|
||||
|
||||
사용 시점:
|
||||
- 다단계 복잡한 작업
|
||||
- 대량의 컨텍스트가 필요한 연구
|
||||
- 병렬 처리가 가능한 독립 작업
|
||||
"""
|
||||
@@ -0,0 +1,260 @@
|
||||
"""Context Offloading 전략 구현.
|
||||
|
||||
## 개요
|
||||
|
||||
Context Offloading은 대용량 도구 결과를 파일시스템으로 축출하여
|
||||
컨텍스트 윈도우 오버플로우를 방지하는 전략입니다.
|
||||
|
||||
## 핵심 원리
|
||||
|
||||
1. 도구 실행 결과가 특정 토큰 임계값을 초과하면 자동으로 파일로 저장
|
||||
2. 원본 메시지는 파일 경로 참조로 대체
|
||||
3. 에이전트가 필요할 때 read_file로 데이터 로드
|
||||
|
||||
## DeepAgents 구현
|
||||
|
||||
FilesystemMiddleware의 `_intercept_large_tool_result` 메서드에서 구현:
|
||||
- `tool_token_limit_before_evict`: 축출 임계값 (기본 20,000 토큰)
|
||||
- `/large_tool_results/{tool_call_id}` 경로에 저장
|
||||
- 처음 10줄 미리보기 제공
|
||||
|
||||
## 장점
|
||||
|
||||
- 컨텍스트 윈도우 절약
|
||||
- 대용량 데이터 처리 가능
|
||||
- 선택적 로딩으로 효율성 증가
|
||||
|
||||
## 사용 예시
|
||||
|
||||
```python
|
||||
from deepagents.middleware.filesystem import FilesystemMiddleware
|
||||
|
||||
middleware = FilesystemMiddleware(
|
||||
tool_token_limit_before_evict=15000 # 15,000 토큰 초과 시 축출
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
)
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain.tools.tool_node import ToolCallRequest
|
||||
from langchain_core.messages import ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConfig:
|
||||
"""Context Offloading 설정."""
|
||||
|
||||
token_limit_before_evict: int = 20000
|
||||
"""도구 결과를 파일로 축출하기 전 토큰 임계값. 기본값 20,000."""
|
||||
|
||||
eviction_path_prefix: str = "/large_tool_results"
|
||||
"""축출된 파일이 저장될 경로 접두사."""
|
||||
|
||||
preview_lines: int = 10
|
||||
"""축출 시 포함할 미리보기 줄 수."""
|
||||
|
||||
chars_per_token: int = 4
|
||||
"""토큰당 문자 수 근사값 (보수적 추정)."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingResult:
|
||||
"""Offloading 처리 결과."""
|
||||
|
||||
was_offloaded: bool
|
||||
"""실제로 축출이 발생했는지 여부."""
|
||||
|
||||
original_size: int
|
||||
"""원본 콘텐츠 크기 (문자 수)."""
|
||||
|
||||
file_path: str | None = None
|
||||
"""축출된 파일 경로 (축출된 경우)."""
|
||||
|
||||
preview: str | None = None
|
||||
"""축출 시 제공되는 미리보기."""
|
||||
|
||||
|
||||
class ContextOffloadingStrategy(AgentMiddleware):
|
||||
"""Context Offloading 전략 구현.
|
||||
|
||||
대용량 도구 결과를 파일시스템으로 자동 축출하여
|
||||
컨텍스트 윈도우 오버플로우를 방지합니다.
|
||||
|
||||
## 동작 원리
|
||||
|
||||
1. wrap_tool_call에서 도구 실행 결과를 가로챔
|
||||
2. 결과 크기가 임계값을 초과하면 파일로 저장
|
||||
3. 원본 메시지를 파일 경로 참조로 대체
|
||||
4. 에이전트는 필요시 read_file로 데이터 로드
|
||||
|
||||
## DeepAgents FilesystemMiddleware와의 관계
|
||||
|
||||
이 클래스는 FilesystemMiddleware의 offloading 로직을
|
||||
명시적으로 분리하여 전략 패턴으로 구현한 것입니다.
|
||||
|
||||
Args:
|
||||
config: Offloading 설정. None이면 기본값 사용.
|
||||
backend_factory: 파일 저장용 백엔드 팩토리 함수.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: OffloadingConfig | None = None,
|
||||
backend_factory: Callable[[ToolRuntime], Any] | None = None,
|
||||
) -> None:
|
||||
self.config = config or OffloadingConfig()
|
||||
self._backend_factory = backend_factory
|
||||
|
||||
def _estimate_tokens(self, content: str) -> int:
|
||||
"""콘텐츠의 토큰 수를 추정합니다.
|
||||
|
||||
보수적인 추정값을 사용하여 조기 축출을 방지합니다.
|
||||
실제 토큰 수는 모델과 콘텐츠에 따라 다릅니다.
|
||||
"""
|
||||
return len(content) // self.config.chars_per_token
|
||||
|
||||
def _should_offload(self, content: str) -> bool:
|
||||
"""주어진 콘텐츠가 축출 대상인지 판단합니다."""
|
||||
estimated_tokens = self._estimate_tokens(content)
|
||||
return estimated_tokens > self.config.token_limit_before_evict
|
||||
|
||||
def _create_preview(self, content: str) -> str:
|
||||
"""축출될 콘텐츠의 미리보기를 생성합니다."""
|
||||
lines = content.splitlines()[: self.config.preview_lines]
|
||||
truncated_lines = [line[:1000] for line in lines]
|
||||
return "\n".join(f"{i + 1:5}\t{line}" for i, line in enumerate(truncated_lines))
|
||||
|
||||
def _create_offload_message(
|
||||
self, tool_call_id: str, file_path: str, preview: str
|
||||
) -> str:
|
||||
"""축출 후 대체 메시지를 생성합니다."""
|
||||
return f"""도구 결과가 너무 커서 파일시스템에 저장되었습니다.
|
||||
|
||||
경로: {file_path}
|
||||
|
||||
read_file 도구로 결과를 읽을 수 있습니다.
|
||||
대용량 결과의 경우 offset과 limit 파라미터로 부분 읽기를 권장합니다.
|
||||
|
||||
처음 {self.config.preview_lines}줄 미리보기:
|
||||
{preview}
|
||||
"""
|
||||
|
||||
def process_tool_result(
|
||||
self,
|
||||
tool_result: ToolMessage,
|
||||
runtime: ToolRuntime,
|
||||
) -> tuple[ToolMessage | Command, OffloadingResult]:
|
||||
"""도구 결과를 처리하고 필요시 축출합니다.
|
||||
|
||||
Args:
|
||||
tool_result: 원본 도구 실행 결과.
|
||||
runtime: 도구 런타임 컨텍스트.
|
||||
|
||||
Returns:
|
||||
처리된 메시지와 Offloading 결과 튜플.
|
||||
"""
|
||||
content = (
|
||||
tool_result.content
|
||||
if isinstance(tool_result.content, str)
|
||||
else str(tool_result.content)
|
||||
)
|
||||
|
||||
result = OffloadingResult(
|
||||
was_offloaded=False,
|
||||
original_size=len(content),
|
||||
)
|
||||
|
||||
if not self._should_offload(content):
|
||||
return tool_result, result
|
||||
|
||||
if self._backend_factory is None:
|
||||
return tool_result, result
|
||||
|
||||
backend = self._backend_factory(runtime)
|
||||
|
||||
sanitized_id = self._sanitize_tool_call_id(tool_result.tool_call_id)
|
||||
file_path = f"{self.config.eviction_path_prefix}/{sanitized_id}"
|
||||
|
||||
write_result = backend.write(file_path, content)
|
||||
if write_result.error:
|
||||
return tool_result, result
|
||||
|
||||
preview = self._create_preview(content)
|
||||
replacement_text = self._create_offload_message(
|
||||
tool_result.tool_call_id, file_path, preview
|
||||
)
|
||||
|
||||
result.was_offloaded = True
|
||||
result.file_path = file_path
|
||||
result.preview = preview
|
||||
|
||||
if write_result.files_update is not None:
|
||||
return Command(
|
||||
update={
|
||||
"files": write_result.files_update,
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=replacement_text,
|
||||
tool_call_id=tool_result.tool_call_id,
|
||||
)
|
||||
],
|
||||
}
|
||||
), result
|
||||
|
||||
return ToolMessage(
|
||||
content=replacement_text,
|
||||
tool_call_id=tool_result.tool_call_id,
|
||||
), result
|
||||
|
||||
def _sanitize_tool_call_id(self, tool_call_id: str) -> str:
|
||||
"""파일명에 안전한 tool_call_id로 변환합니다."""
|
||||
return "".join(c if c.isalnum() or c in "-_" else "_" for c in tool_call_id)
|
||||
|
||||
def wrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], ToolMessage | Command],
|
||||
) -> ToolMessage | Command:
|
||||
"""도구 호출을 래핑하여 결과를 가로채고 필요시 축출합니다."""
|
||||
tool_result = handler(request)
|
||||
|
||||
if isinstance(tool_result, ToolMessage):
|
||||
processed, _ = self.process_tool_result(tool_result, request.runtime)
|
||||
return processed
|
||||
|
||||
return tool_result
|
||||
|
||||
async def awrap_tool_call(
|
||||
self,
|
||||
request: ToolCallRequest,
|
||||
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
|
||||
) -> ToolMessage | Command:
|
||||
"""비동기 도구 호출을 래핑합니다."""
|
||||
tool_result = await handler(request)
|
||||
|
||||
if isinstance(tool_result, ToolMessage):
|
||||
processed, _ = self.process_tool_result(tool_result, request.runtime)
|
||||
return processed
|
||||
|
||||
return tool_result
|
||||
|
||||
|
||||
OFFLOADING_SYSTEM_PROMPT = """## Context Offloading 안내
|
||||
|
||||
대용량 도구 결과는 자동으로 파일시스템에 저장됩니다.
|
||||
|
||||
결과가 축출된 경우:
|
||||
1. 파일 경로와 미리보기가 제공됩니다
|
||||
2. read_file(path, offset=0, limit=100)으로 부분 읽기하세요
|
||||
3. 전체 내용이 필요한 경우에만 전체 파일을 읽으세요
|
||||
|
||||
이 방식으로 컨텍스트 윈도우를 효율적으로 관리할 수 있습니다.
|
||||
"""
|
||||
@@ -0,0 +1,318 @@
|
||||
"""Context Reduction 전략 구현.
|
||||
|
||||
## 개요
|
||||
|
||||
Context Reduction은 컨텍스트 윈도우 사용량이 임계값을 초과할 때
|
||||
자동으로 대화 내용을 압축하는 전략입니다.
|
||||
|
||||
## 두 가지 기법
|
||||
|
||||
### 1. Compaction (압축)
|
||||
- 오래된 메시지에서 도구 호출(tool_calls)과 도구 결과(ToolMessage) 제거
|
||||
- 텍스트 응답만 유지
|
||||
- 세부 실행 이력은 제거하고 결론만 보존
|
||||
|
||||
### 2. Summarization (요약)
|
||||
- 컨텍스트가 임계값(기본 85%) 초과 시 트리거
|
||||
- LLM을 사용하여 대화 내용 요약
|
||||
- 핵심 정보만 유지하고 세부사항 압축
|
||||
|
||||
## DeepAgents 구현
|
||||
|
||||
SummarizationMiddleware에서 구현:
|
||||
- `context_threshold`: 요약 트리거 임계값 (기본 0.85 = 85%)
|
||||
- `model_context_window`: 모델의 컨텍스트 윈도우 크기
|
||||
- 자동으로 토큰 사용량 추적 및 요약 트리거
|
||||
|
||||
## 사용 예시
|
||||
|
||||
```python
|
||||
from deepagents.middleware.summarization import SummarizationMiddleware
|
||||
|
||||
middleware = SummarizationMiddleware(
|
||||
context_threshold=0.85, # 85% 사용 시 요약
|
||||
model_context_window=200000, # Claude의 컨텍스트 윈도우
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReductionConfig:
|
||||
"""Context Reduction 설정."""
|
||||
|
||||
context_threshold: float = 0.85
|
||||
"""컨텍스트 사용률 임계값. 기본 0.85 (85%)."""
|
||||
|
||||
model_context_window: int = 200000
|
||||
"""모델의 전체 컨텍스트 윈도우 크기."""
|
||||
|
||||
compaction_age_threshold: int = 10
|
||||
"""Compaction 대상이 되는 메시지 나이 (메시지 수 기준)."""
|
||||
|
||||
min_messages_to_keep: int = 5
|
||||
"""요약 후 유지할 최소 메시지 수."""
|
||||
|
||||
chars_per_token: int = 4
|
||||
"""토큰당 문자 수 근사값."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReductionResult:
|
||||
"""Reduction 처리 결과."""
|
||||
|
||||
was_reduced: bool
|
||||
"""실제로 축소가 발생했는지 여부."""
|
||||
|
||||
technique_used: str | None = None
|
||||
"""사용된 기법 ('compaction' 또는 'summarization')."""
|
||||
|
||||
original_message_count: int = 0
|
||||
"""원본 메시지 수."""
|
||||
|
||||
reduced_message_count: int = 0
|
||||
"""축소 후 메시지 수."""
|
||||
|
||||
estimated_tokens_saved: int = 0
|
||||
"""절약된 추정 토큰 수."""
|
||||
|
||||
|
||||
class ContextReductionStrategy(AgentMiddleware):
|
||||
"""Context Reduction 전략 구현.
|
||||
|
||||
컨텍스트 윈도우 사용량이 임계값을 초과할 때 자동으로
|
||||
대화 내용을 압축하는 전략입니다.
|
||||
|
||||
## 동작 원리
|
||||
|
||||
1. before_model_call에서 현재 토큰 사용량 추정
|
||||
2. 임계값 초과 시 먼저 Compaction 시도
|
||||
3. 여전히 초과하면 Summarization 실행
|
||||
4. 축소된 메시지로 요청 수정
|
||||
|
||||
## Compaction vs Summarization
|
||||
|
||||
- **Compaction**: 빠르고 저렴함. 도구 호출/결과만 제거.
|
||||
- **Summarization**: 느리고 비용 발생. LLM이 내용 요약.
|
||||
|
||||
우선순위: Compaction → Summarization
|
||||
|
||||
Args:
|
||||
config: Reduction 설정. None이면 기본값 사용.
|
||||
summarization_model: 요약에 사용할 LLM. None이면 요약 비활성화.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ReductionConfig | None = None,
|
||||
summarization_model: BaseChatModel | None = None,
|
||||
) -> None:
|
||||
self.config = config or ReductionConfig()
|
||||
self._summarization_model = summarization_model
|
||||
|
||||
def _estimate_tokens(self, messages: list[BaseMessage]) -> int:
|
||||
"""메시지 목록의 총 토큰 수를 추정합니다."""
|
||||
total_chars = sum(len(str(msg.content)) for msg in messages)
|
||||
return total_chars // self.config.chars_per_token
|
||||
|
||||
def _get_context_usage_ratio(self, messages: list[BaseMessage]) -> float:
|
||||
"""현재 컨텍스트 사용률을 계산합니다."""
|
||||
estimated_tokens = self._estimate_tokens(messages)
|
||||
return estimated_tokens / self.config.model_context_window
|
||||
|
||||
def _should_reduce(self, messages: list[BaseMessage]) -> bool:
|
||||
"""축소가 필요한지 판단합니다."""
|
||||
usage_ratio = self._get_context_usage_ratio(messages)
|
||||
return usage_ratio > self.config.context_threshold
|
||||
|
||||
def apply_compaction(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
) -> tuple[list[BaseMessage], ReductionResult]:
|
||||
"""Compaction을 적용합니다.
|
||||
|
||||
오래된 메시지에서 도구 호출과 도구 결과를 제거합니다.
|
||||
"""
|
||||
original_count = len(messages)
|
||||
compacted: list[BaseMessage] = []
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
age = len(messages) - i
|
||||
|
||||
if age <= self.config.compaction_age_threshold:
|
||||
compacted.append(msg)
|
||||
continue
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
if msg.tool_calls:
|
||||
text_content = (
|
||||
msg.text if hasattr(msg, "text") else str(msg.content)
|
||||
)
|
||||
if text_content.strip():
|
||||
compacted.append(AIMessage(content=text_content))
|
||||
else:
|
||||
compacted.append(msg)
|
||||
elif isinstance(msg, (HumanMessage, SystemMessage)):
|
||||
compacted.append(msg)
|
||||
|
||||
result = ReductionResult(
|
||||
was_reduced=len(compacted) < original_count,
|
||||
technique_used="compaction",
|
||||
original_message_count=original_count,
|
||||
reduced_message_count=len(compacted),
|
||||
estimated_tokens_saved=(
|
||||
self._estimate_tokens(messages) - self._estimate_tokens(compacted)
|
||||
),
|
||||
)
|
||||
|
||||
return compacted, result
|
||||
|
||||
def apply_summarization(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
) -> tuple[list[BaseMessage], ReductionResult]:
|
||||
"""Summarization을 적용합니다.
|
||||
|
||||
LLM을 사용하여 대화 내용을 요약합니다.
|
||||
"""
|
||||
if self._summarization_model is None:
|
||||
return messages, ReductionResult(was_reduced=False)
|
||||
|
||||
original_count = len(messages)
|
||||
|
||||
keep_count = self.config.min_messages_to_keep
|
||||
messages_to_summarize = (
|
||||
messages[:-keep_count] if len(messages) > keep_count else []
|
||||
)
|
||||
recent_messages = (
|
||||
messages[-keep_count:] if len(messages) > keep_count else messages
|
||||
)
|
||||
|
||||
if not messages_to_summarize:
|
||||
return messages, ReductionResult(was_reduced=False)
|
||||
|
||||
summary_prompt = self._create_summary_prompt(messages_to_summarize)
|
||||
|
||||
summary_response = self._summarization_model.invoke(
|
||||
[
|
||||
SystemMessage(
|
||||
content="당신은 대화 요약 전문가입니다. 핵심 정보만 간결하게 요약하세요."
|
||||
),
|
||||
HumanMessage(content=summary_prompt),
|
||||
]
|
||||
)
|
||||
|
||||
summary_message = SystemMessage(
|
||||
content=f"[이전 대화 요약]\n{summary_response.content}"
|
||||
)
|
||||
|
||||
summarized = [summary_message] + list(recent_messages)
|
||||
|
||||
result = ReductionResult(
|
||||
was_reduced=True,
|
||||
technique_used="summarization",
|
||||
original_message_count=original_count,
|
||||
reduced_message_count=len(summarized),
|
||||
estimated_tokens_saved=(
|
||||
self._estimate_tokens(messages) - self._estimate_tokens(summarized)
|
||||
),
|
||||
)
|
||||
|
||||
return summarized, result
|
||||
|
||||
def _create_summary_prompt(self, messages: list[BaseMessage]) -> str:
|
||||
"""요약을 위한 프롬프트를 생성합니다."""
|
||||
conversation_text = []
|
||||
for msg in messages:
|
||||
role = msg.__class__.__name__.replace("Message", "")
|
||||
content = str(msg.content)[:500]
|
||||
conversation_text.append(f"[{role}]: {content}")
|
||||
|
||||
return f"""다음 대화를 요약해주세요. 핵심 정보, 결정사항, 중요한 컨텍스트만 포함하세요.
|
||||
|
||||
대화 내용:
|
||||
{chr(10).join(conversation_text)}
|
||||
|
||||
요약 (한국어로, 500자 이내):"""
|
||||
|
||||
def reduce_context(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
) -> tuple[list[BaseMessage], ReductionResult]:
|
||||
"""컨텍스트를 축소합니다.
|
||||
|
||||
먼저 Compaction을 시도하고, 여전히 임계값을 초과하면
|
||||
Summarization을 적용합니다.
|
||||
"""
|
||||
if not self._should_reduce(messages):
|
||||
return messages, ReductionResult(was_reduced=False)
|
||||
|
||||
compacted, compaction_result = self.apply_compaction(messages)
|
||||
|
||||
if not self._should_reduce(compacted):
|
||||
return compacted, compaction_result
|
||||
|
||||
summarized, summarization_result = self.apply_summarization(compacted)
|
||||
|
||||
return summarized, summarization_result
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""모델 호출을 래핑하여 필요시 컨텍스트를 축소합니다."""
|
||||
messages = list(request.state.get("messages", []))
|
||||
|
||||
reduced_messages, result = self.reduce_context(messages)
|
||||
|
||||
if result.was_reduced:
|
||||
modified_state = {**request.state, "messages": reduced_messages}
|
||||
request = request.override(state=modified_state)
|
||||
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""비동기 모델 호출을 래핑합니다."""
|
||||
messages = list(request.state.get("messages", []))
|
||||
|
||||
reduced_messages, result = self.reduce_context(messages)
|
||||
|
||||
if result.was_reduced:
|
||||
modified_state = {**request.state, "messages": reduced_messages}
|
||||
request = request.override(state=modified_state)
|
||||
|
||||
return await handler(request)
|
||||
|
||||
|
||||
REDUCTION_SYSTEM_PROMPT = """## Context Reduction 안내
|
||||
|
||||
대화가 길어지면 자동으로 컨텍스트가 압축됩니다.
|
||||
|
||||
압축 방식:
|
||||
1. **Compaction**: 오래된 도구 호출/결과 제거
|
||||
2. **Summarization**: LLM이 이전 대화 요약
|
||||
|
||||
중요한 정보는 파일시스템에 저장하는 것을 권장합니다.
|
||||
요약으로 인해 세부사항이 손실될 수 있습니다.
|
||||
"""
|
||||
@@ -0,0 +1,319 @@
|
||||
"""Context Retrieval 전략 구현.
|
||||
|
||||
## 개요
|
||||
|
||||
Context Retrieval은 필요한 정보를 선택적으로 로드하여
|
||||
컨텍스트 윈도우를 효율적으로 사용하는 전략입니다.
|
||||
|
||||
## 핵심 원리
|
||||
|
||||
1. 벡터 DB나 복잡한 인덱싱 없이 직접 파일 검색
|
||||
2. grep/glob 기반의 단순하고 빠른 패턴 매칭
|
||||
3. 필요한 파일/내용만 선택적으로 로드
|
||||
|
||||
## DeepAgents 구현
|
||||
|
||||
FilesystemMiddleware에서 제공하는 도구들:
|
||||
- `read_file`: 파일 내용 읽기 (offset/limit으로 부분 읽기 지원)
|
||||
- `grep`: 텍스트 패턴 검색
|
||||
- `glob`: 파일명 패턴 매칭
|
||||
- `ls`: 디렉토리 목록 조회
|
||||
|
||||
## 벡터 검색을 사용하지 않는 이유
|
||||
|
||||
1. **단순성**: 추가 인프라 불필요
|
||||
2. **결정성**: 정확한 매칭, 모호함 없음
|
||||
3. **속도**: 인덱싱 오버헤드 없음
|
||||
4. **디버깅 용이**: 검색 결과 예측 가능
|
||||
|
||||
## 사용 예시
|
||||
|
||||
```python
|
||||
# 파일 검색
|
||||
grep(pattern="TODO", glob="*.py")
|
||||
|
||||
# 파일 목록
|
||||
glob(pattern="**/*.md")
|
||||
|
||||
# 부분 읽기
|
||||
read_file(path="/data.txt", offset=100, limit=50)
|
||||
```
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langchain.tools import ToolRuntime
|
||||
from langchain_core.tools import BaseTool, StructuredTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalConfig:
|
||||
"""Context Retrieval 설정."""
|
||||
|
||||
default_read_limit: int = 500
|
||||
"""read_file의 기본 줄 수 제한."""
|
||||
|
||||
max_grep_results: int = 100
|
||||
"""grep 결과 최대 개수."""
|
||||
|
||||
max_glob_results: int = 100
|
||||
"""glob 결과 최대 개수."""
|
||||
|
||||
truncate_line_length: int = 2000
|
||||
"""줄 길이 제한 (초과 시 자름)."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievalResult:
|
||||
"""검색 결과 메타데이터."""
|
||||
|
||||
tool_used: str
|
||||
"""사용된 도구 이름."""
|
||||
|
||||
query: str
|
||||
"""검색 쿼리/패턴."""
|
||||
|
||||
result_count: int
|
||||
"""결과 개수."""
|
||||
|
||||
was_truncated: bool = False
|
||||
"""결과가 잘렸는지 여부."""
|
||||
|
||||
|
||||
class ContextRetrievalStrategy(AgentMiddleware):
|
||||
"""Context Retrieval 전략 구현.
|
||||
|
||||
grep/glob 기반의 단순하고 빠른 검색으로
|
||||
필요한 정보만 선택적으로 로드합니다.
|
||||
|
||||
## 동작 원리
|
||||
|
||||
1. 파일시스템에서 직접 패턴 매칭
|
||||
2. 결과 개수 제한으로 컨텍스트 오버플로우 방지
|
||||
3. 부분 읽기로 대용량 파일 효율적 처리
|
||||
|
||||
## 제공 도구
|
||||
|
||||
- read_file: 파일 읽기 (offset/limit 지원)
|
||||
- grep: 텍스트 패턴 검색
|
||||
- glob: 파일명 패턴 매칭
|
||||
|
||||
## 벡터 DB를 사용하지 않는 이유
|
||||
|
||||
DeepAgents는 의도적으로 벡터 검색 대신 직접 파일 검색을 선택했습니다:
|
||||
- 결정적이고 예측 가능한 결과
|
||||
- 추가 인프라 불필요
|
||||
- 디버깅 용이
|
||||
|
||||
Args:
|
||||
config: Retrieval 설정. None이면 기본값 사용.
|
||||
backend_factory: 파일 작업용 백엔드 팩토리 함수.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: RetrievalConfig | None = None,
|
||||
backend_factory: Callable[[ToolRuntime], Any] | None = None,
|
||||
) -> None:
|
||||
self.config = config or RetrievalConfig()
|
||||
self._backend_factory = backend_factory
|
||||
self.tools = self._create_tools()
|
||||
|
||||
def _create_tools(self) -> list[BaseTool]:
|
||||
"""검색 도구들을 생성합니다."""
|
||||
return [
|
||||
self._create_read_file_tool(),
|
||||
self._create_grep_tool(),
|
||||
self._create_glob_tool(),
|
||||
]
|
||||
|
||||
def _create_read_file_tool(self) -> BaseTool:
|
||||
"""read_file 도구를 생성합니다."""
|
||||
config = self.config
|
||||
backend_factory = self._backend_factory
|
||||
|
||||
def read_file(
|
||||
file_path: str,
|
||||
runtime: ToolRuntime,
|
||||
offset: int = 0,
|
||||
limit: int | None = None,
|
||||
) -> str:
|
||||
"""파일을 읽습니다.
|
||||
|
||||
Args:
|
||||
file_path: 읽을 파일의 절대 경로.
|
||||
offset: 시작 줄 번호 (0부터 시작).
|
||||
limit: 읽을 최대 줄 수. 기본값은 설정에 따름.
|
||||
|
||||
Returns:
|
||||
줄 번호가 포함된 파일 내용.
|
||||
"""
|
||||
if backend_factory is None:
|
||||
return "백엔드가 설정되지 않았습니다."
|
||||
|
||||
backend = backend_factory(runtime)
|
||||
actual_limit = limit or config.default_read_limit
|
||||
return backend.read(file_path, offset=offset, limit=actual_limit)
|
||||
|
||||
return StructuredTool.from_function(
|
||||
name="read_file",
|
||||
description=f"""파일을 읽습니다.
|
||||
|
||||
사용법:
|
||||
- file_path: 절대 경로 필수
|
||||
- offset: 시작 줄 (기본 0)
|
||||
- limit: 읽을 줄 수 (기본 {config.default_read_limit})
|
||||
|
||||
대용량 파일은 offset/limit으로 부분 읽기를 권장합니다.""",
|
||||
func=read_file,
|
||||
)
|
||||
|
||||
def _create_grep_tool(self) -> BaseTool:
|
||||
"""Grep 도구를 생성합니다."""
|
||||
config = self.config
|
||||
backend_factory = self._backend_factory
|
||||
|
||||
def grep(
|
||||
pattern: str,
|
||||
runtime: ToolRuntime,
|
||||
path: str | None = None,
|
||||
glob_pattern: str | None = None,
|
||||
output_mode: Literal[
|
||||
"files_with_matches", "content", "count"
|
||||
] = "files_with_matches",
|
||||
) -> str:
|
||||
"""텍스트 패턴을 검색합니다.
|
||||
|
||||
Args:
|
||||
pattern: 검색할 텍스트 (정규식 아님).
|
||||
path: 검색 시작 디렉토리.
|
||||
glob_pattern: 파일 필터 (예: "*.py").
|
||||
output_mode: 출력 형식.
|
||||
|
||||
Returns:
|
||||
검색 결과.
|
||||
"""
|
||||
if backend_factory is None:
|
||||
return "백엔드가 설정되지 않았습니다."
|
||||
|
||||
backend = backend_factory(runtime)
|
||||
raw_results = backend.grep_raw(pattern, path=path, glob=glob_pattern)
|
||||
|
||||
if isinstance(raw_results, str):
|
||||
return raw_results
|
||||
|
||||
truncated = raw_results[: config.max_grep_results]
|
||||
|
||||
if output_mode == "files_with_matches":
|
||||
files = list(set(r.get("path", "") for r in truncated))
|
||||
return "\n".join(files)
|
||||
elif output_mode == "count":
|
||||
from collections import Counter
|
||||
|
||||
counts = Counter(r.get("path", "") for r in truncated)
|
||||
return "\n".join(f"{path}: {count}" for path, count in counts.items())
|
||||
else:
|
||||
lines = []
|
||||
for r in truncated:
|
||||
path = r.get("path", "")
|
||||
line_num = r.get("line_number", 0)
|
||||
content = r.get("content", "")[: config.truncate_line_length]
|
||||
lines.append(f"{path}:{line_num}: {content}")
|
||||
return "\n".join(lines)
|
||||
|
||||
return StructuredTool.from_function(
|
||||
name="grep",
|
||||
description=f"""텍스트 패턴을 검색합니다.
|
||||
|
||||
사용법:
|
||||
- pattern: 검색할 텍스트 (리터럴 문자열)
|
||||
- path: 검색 디렉토리 (선택)
|
||||
- glob_pattern: 파일 필터 예: "*.py" (선택)
|
||||
- output_mode: files_with_matches | content | count
|
||||
|
||||
최대 {config.max_grep_results}개 결과를 반환합니다.""",
|
||||
func=grep,
|
||||
)
|
||||
|
||||
def _create_glob_tool(self) -> BaseTool:
|
||||
"""Glob 도구를 생성합니다."""
|
||||
config = self.config
|
||||
backend_factory = self._backend_factory
|
||||
|
||||
def glob(
|
||||
pattern: str,
|
||||
runtime: ToolRuntime,
|
||||
path: str = "/",
|
||||
) -> str:
|
||||
"""파일명 패턴으로 파일을 찾습니다.
|
||||
|
||||
Args:
|
||||
pattern: glob 패턴 (예: "**/*.py").
|
||||
path: 검색 시작 경로.
|
||||
|
||||
Returns:
|
||||
매칭된 파일 경로 목록.
|
||||
"""
|
||||
if backend_factory is None:
|
||||
return "백엔드가 설정되지 않았습니다."
|
||||
|
||||
backend = backend_factory(runtime)
|
||||
infos = backend.glob_info(pattern, path=path)
|
||||
|
||||
paths = [fi.get("path", "") for fi in infos[: config.max_glob_results]]
|
||||
return "\n".join(paths)
|
||||
|
||||
return StructuredTool.from_function(
|
||||
name="glob",
|
||||
description=f"""파일명 패턴으로 파일을 찾습니다.
|
||||
|
||||
사용법:
|
||||
- pattern: glob 패턴 (*, **, ? 지원)
|
||||
- path: 검색 시작 경로 (기본 "/")
|
||||
|
||||
예시:
|
||||
- "**/*.py": 모든 Python 파일
|
||||
- "src/**/*.ts": src 아래 모든 TypeScript 파일
|
||||
|
||||
최대 {config.max_glob_results}개 결과를 반환합니다.""",
|
||||
func=glob,
|
||||
)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
"""모델 호출을 래핑합니다 (기본 동작)."""
|
||||
return handler(request)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
"""비동기 모델 호출을 래핑합니다 (기본 동작)."""
|
||||
return await handler(request)
|
||||
|
||||
|
||||
RETRIEVAL_SYSTEM_PROMPT = """## Context Retrieval 도구
|
||||
|
||||
파일시스템에서 정보를 검색할 수 있습니다.
|
||||
|
||||
도구:
|
||||
- read_file: 파일 읽기 (offset/limit으로 부분 읽기)
|
||||
- grep: 텍스트 패턴 검색
|
||||
- glob: 파일명 패턴 매칭
|
||||
|
||||
사용 팁:
|
||||
1. 대용량 파일은 먼저 구조 파악 (limit=100)
|
||||
2. grep으로 관련 파일 찾은 후 read_file
|
||||
3. glob으로 파일 위치 확인 후 탐색
|
||||
"""
|
||||
11
context_engineering_research_agent/research/__init__.py
Normal file
11
context_engineering_research_agent/research/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""연구 에이전트 모듈."""
|
||||
|
||||
from context_engineering_more_deep_research_agent.research.agent import (
|
||||
create_researcher_agent,
|
||||
get_researcher_subagent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_researcher_agent",
|
||||
"get_researcher_subagent",
|
||||
]
|
||||
96
context_engineering_research_agent/research/agent.py
Normal file
96
context_engineering_research_agent/research/agent.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""자율적 연구 에이전트."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from deepagents import create_deep_agent
|
||||
from deepagents.backends.protocol import BackendFactory, BackendProtocol
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
AUTONOMOUS_RESEARCHER_INSTRUCTIONS = """당신은 자율적 연구 에이전트입니다.
|
||||
"넓게 탐색 → 깊게 파기" 방법론을 따라 주제를 철저히 연구합니다.
|
||||
|
||||
오늘 날짜: {date}
|
||||
|
||||
## 연구 워크플로우
|
||||
|
||||
### Phase 1: 탐색적 검색 (1-2회)
|
||||
- 넓은 검색으로 분야 전체 파악
|
||||
- 핵심 개념, 주요 플레이어, 최근 트렌드 확인
|
||||
- think_tool로 유망한 방향 2-3개 식별
|
||||
|
||||
### Phase 2: 심층 연구 (방향당 1-2회)
|
||||
- 식별된 방향별 집중 검색
|
||||
- 각 검색 후 think_tool로 평가
|
||||
- 가치 있는 정보 획득 여부 판단
|
||||
|
||||
### Phase 3: 종합
|
||||
- 모든 발견 사항 검토
|
||||
- 패턴과 연결점 식별
|
||||
- 출처 일치/불일치 기록
|
||||
|
||||
## 도구 제한
|
||||
- 탐색: 최대 2회 검색
|
||||
- 심층: 최대 3-4회 검색
|
||||
- **총합: 5-6회**
|
||||
|
||||
## 종료 조건
|
||||
- 포괄적 답변 가능
|
||||
- 최근 2회 검색이 유사 정보 반환
|
||||
- 최대 검색 횟수 도달
|
||||
|
||||
## 응답 형식
|
||||
|
||||
```markdown
|
||||
## 핵심 발견
|
||||
|
||||
### 발견 1: [제목]
|
||||
[인용 포함 상세 설명 [1], [2]]
|
||||
|
||||
### 발견 2: [제목]
|
||||
[인용 포함 상세 설명]
|
||||
|
||||
## 출처 일치 분석
|
||||
- **높은 일치**: [출처들이 동의하는 주제]
|
||||
- **불일치/불확실**: [충돌 정보]
|
||||
|
||||
## 출처
|
||||
[1] 출처 제목: URL
|
||||
[2] 출처 제목: URL
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def create_researcher_agent(
|
||||
model: str | BaseChatModel | None = None,
|
||||
backend: BackendProtocol | BackendFactory | None = None,
|
||||
) -> CompiledStateGraph:
|
||||
if model is None:
|
||||
model = ChatOpenAI(model="gpt-4.1", temperature=0.0)
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
formatted_prompt = AUTONOMOUS_RESEARCHER_INSTRUCTIONS.format(date=current_date)
|
||||
|
||||
return create_deep_agent(
|
||||
model=model,
|
||||
system_prompt=formatted_prompt,
|
||||
backend=backend,
|
||||
)
|
||||
|
||||
|
||||
def get_researcher_subagent(
|
||||
model: str | BaseChatModel | None = None,
|
||||
backend: BackendProtocol | BackendFactory | None = None,
|
||||
) -> dict[str, Any]:
|
||||
researcher = create_researcher_agent(model=model, backend=backend)
|
||||
|
||||
return {
|
||||
"name": "researcher",
|
||||
"description": (
|
||||
"자율적 심층 연구 에이전트. '넓게 탐색 → 깊게 파기' 방법론 사용. "
|
||||
"복잡한 주제 연구, 다각적 질문, 트렌드 분석에 적합."
|
||||
),
|
||||
"runnable": researcher,
|
||||
}
|
||||
36
context_engineering_research_agent/skills/__init__.py
Normal file
36
context_engineering_research_agent/skills/__init__.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Context Engineering 연구 에이전트용 스킬 시스템.
|
||||
|
||||
이 모듈은 Anthropic의 Agent Skills 패턴을 구현하여
|
||||
에이전트에게 도메인별 전문 지식과 워크플로우를 제공합니다.
|
||||
|
||||
## Progressive Disclosure (점진적 공개)
|
||||
|
||||
스킬은 점진적 공개 패턴을 따릅니다:
|
||||
1. 세션 시작 시 스킬 메타데이터(이름 + 설명)만 로드
|
||||
2. 에이전트가 필요할 때 전체 SKILL.md 내용 읽기
|
||||
3. 컨텍스트 윈도우 효율적 사용
|
||||
|
||||
## Context Engineering 관점에서의 스킬
|
||||
|
||||
스킬 시스템은 Context Engineering의 핵심 전략들을 활용합니다:
|
||||
|
||||
1. **Context Retrieval**: read_file로 필요할 때만 스킬 내용 로드
|
||||
2. **Context Offloading**: 전체 스킬 내용 대신 메타데이터만 시스템 프롬프트에 포함
|
||||
3. **Context Isolation**: 각 스킬은 독립적인 도메인 지식 캡슐화
|
||||
"""
|
||||
|
||||
from context_engineering_research_agent.skills.load import (
|
||||
SkillMetadata,
|
||||
list_skills,
|
||||
)
|
||||
from context_engineering_research_agent.skills.middleware import (
|
||||
SkillsMiddleware,
|
||||
SkillsState,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SkillMetadata",
|
||||
"list_skills",
|
||||
"SkillsMiddleware",
|
||||
"SkillsState",
|
||||
]
|
||||
197
context_engineering_research_agent/skills/load.py
Normal file
197
context_engineering_research_agent/skills/load.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""SKILL.md 파일에서 에이전트 스킬을 파싱하고 로드하는 스킬 로더.
|
||||
|
||||
YAML 프론트매터 파싱을 통해 Anthropic Agent Skills 패턴을 구현합니다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024 # 10MB - DoS 방지
|
||||
MAX_SKILL_NAME_LENGTH = 64
|
||||
MAX_SKILL_DESCRIPTION_LENGTH = 1024
|
||||
|
||||
|
||||
class SkillMetadata(TypedDict):
|
||||
"""Agent Skills 명세를 따르는 스킬 메타데이터."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
source: str
|
||||
license: NotRequired[str | None]
|
||||
compatibility: NotRequired[str | None]
|
||||
metadata: NotRequired[dict[str, str] | None]
|
||||
allowed_tools: NotRequired[str | None]
|
||||
|
||||
|
||||
def _is_safe_path(path: Path, base_dir: Path) -> bool:
|
||||
"""경로가 base_dir 내에 안전하게 포함되어 있는지 확인합니다."""
|
||||
try:
|
||||
resolved_path = path.resolve()
|
||||
resolved_base = base_dir.resolve()
|
||||
resolved_path.relative_to(resolved_base)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
except (OSError, RuntimeError):
|
||||
return False
|
||||
|
||||
|
||||
def _validate_skill_name(name: str, directory_name: str) -> tuple[bool, str]:
|
||||
"""Agent Skills 명세에 따라 스킬 이름을 검증합니다."""
|
||||
if not name:
|
||||
return False, "이름은 필수입니다"
|
||||
if len(name) > MAX_SKILL_NAME_LENGTH:
|
||||
return False, "이름이 64자를 초과합니다"
|
||||
if not re.match(r"^[a-z0-9]+(-[a-z0-9]+)*$", name):
|
||||
return False, "이름은 소문자 영숫자와 단일 하이픈만 사용해야 합니다"
|
||||
if name != directory_name:
|
||||
return (
|
||||
False,
|
||||
f"이름 '{name}'은 디렉토리 이름 '{directory_name}'과 일치해야 합니다",
|
||||
)
|
||||
return True, ""
|
||||
|
||||
|
||||
def _parse_skill_metadata(skill_md_path: Path, source: str) -> SkillMetadata | None:
|
||||
"""SKILL.md 파일에서 YAML 프론트매터를 파싱합니다."""
|
||||
try:
|
||||
file_size = skill_md_path.stat().st_size
|
||||
if file_size > MAX_SKILL_FILE_SIZE:
|
||||
logger.warning(
|
||||
"%s 건너뜀: 파일이 너무 큼 (%d 바이트)", skill_md_path, file_size
|
||||
)
|
||||
return None
|
||||
|
||||
content = skill_md_path.read_text(encoding="utf-8")
|
||||
|
||||
frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n"
|
||||
match = re.match(frontmatter_pattern, content, re.DOTALL)
|
||||
|
||||
if not match:
|
||||
logger.warning(
|
||||
"%s 건너뜀: 유효한 YAML 프론트매터를 찾을 수 없음", skill_md_path
|
||||
)
|
||||
return None
|
||||
|
||||
frontmatter_str = match.group(1)
|
||||
|
||||
try:
|
||||
frontmatter_data = yaml.safe_load(frontmatter_str)
|
||||
except yaml.YAMLError as e:
|
||||
logger.warning("%s의 YAML이 유효하지 않음: %s", skill_md_path, e)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter_data, dict):
|
||||
logger.warning("%s 건너뜀: 프론트매터가 매핑이 아님", skill_md_path)
|
||||
return None
|
||||
|
||||
name = frontmatter_data.get("name")
|
||||
description = frontmatter_data.get("description")
|
||||
|
||||
if not name or not description:
|
||||
logger.warning(
|
||||
"%s 건너뜀: 필수 'name' 또는 'description' 누락", skill_md_path
|
||||
)
|
||||
return None
|
||||
|
||||
directory_name = skill_md_path.parent.name
|
||||
is_valid, error = _validate_skill_name(str(name), directory_name)
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
"'%s' 스킬 (%s)이 Agent Skills 명세를 따르지 않음: %s. "
|
||||
"명세 준수를 위해 이름 변경을 고려하세요.",
|
||||
name,
|
||||
skill_md_path,
|
||||
error,
|
||||
)
|
||||
|
||||
description_str = str(description)
|
||||
if len(description_str) > MAX_SKILL_DESCRIPTION_LENGTH:
|
||||
logger.warning(
|
||||
"%s의 설명이 %d자를 초과하여 잘림",
|
||||
skill_md_path,
|
||||
MAX_SKILL_DESCRIPTION_LENGTH,
|
||||
)
|
||||
description_str = description_str[:MAX_SKILL_DESCRIPTION_LENGTH]
|
||||
|
||||
return SkillMetadata(
|
||||
name=str(name),
|
||||
description=description_str,
|
||||
path=str(skill_md_path),
|
||||
source=source,
|
||||
license=frontmatter_data.get("license"),
|
||||
compatibility=frontmatter_data.get("compatibility"),
|
||||
metadata=frontmatter_data.get("metadata"),
|
||||
allowed_tools=frontmatter_data.get("allowed-tools"),
|
||||
)
|
||||
|
||||
except (OSError, UnicodeDecodeError) as e:
|
||||
logger.warning("%s 읽기 오류: %s", skill_md_path, e)
|
||||
return None
|
||||
|
||||
|
||||
def _list_skills_from_dir(skills_dir: Path, source: str) -> list[SkillMetadata]:
|
||||
"""단일 스킬 디렉토리에서 모든 스킬을 나열합니다."""
|
||||
skills_dir = skills_dir.expanduser()
|
||||
if not skills_dir.exists():
|
||||
return []
|
||||
|
||||
try:
|
||||
resolved_base = skills_dir.resolve()
|
||||
except (OSError, RuntimeError):
|
||||
return []
|
||||
|
||||
skills: list[SkillMetadata] = []
|
||||
|
||||
for skill_dir in skills_dir.iterdir():
|
||||
if not _is_safe_path(skill_dir, resolved_base):
|
||||
continue
|
||||
|
||||
if not skill_dir.is_dir():
|
||||
continue
|
||||
|
||||
skill_md_path = skill_dir / "SKILL.md"
|
||||
if not skill_md_path.exists():
|
||||
continue
|
||||
|
||||
if not _is_safe_path(skill_md_path, resolved_base):
|
||||
continue
|
||||
|
||||
metadata = _parse_skill_metadata(skill_md_path, source=source)
|
||||
if metadata:
|
||||
skills.append(metadata)
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def list_skills(
|
||||
*,
|
||||
user_skills_dir: Path | None = None,
|
||||
project_skills_dir: Path | None = None,
|
||||
) -> list[SkillMetadata]:
|
||||
"""사용자 및/또는 프로젝트 디렉토리에서 스킬을 나열합니다.
|
||||
|
||||
프로젝트 스킬이 같은 이름의 사용자 스킬을 오버라이드합니다.
|
||||
"""
|
||||
all_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
if user_skills_dir:
|
||||
user_skills = _list_skills_from_dir(user_skills_dir, source="user")
|
||||
for skill in user_skills:
|
||||
all_skills[skill["name"]] = skill
|
||||
|
||||
if project_skills_dir:
|
||||
project_skills = _list_skills_from_dir(project_skills_dir, source="project")
|
||||
for skill in project_skills:
|
||||
all_skills[skill["name"]] = skill
|
||||
|
||||
return list(all_skills.values())
|
||||
165
context_engineering_research_agent/skills/middleware.py
Normal file
165
context_engineering_research_agent/skills/middleware.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""스킬 시스템 미들웨어.
|
||||
|
||||
Progressive Disclosure 패턴으로 스킬 메타데이터를 시스템 프롬프트에 주입합니다.
|
||||
"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import NotRequired, TypedDict, cast
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
)
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from context_engineering_research_agent.skills.load import SkillMetadata, list_skills
|
||||
|
||||
|
||||
class SkillsState(AgentState):
|
||||
skills_metadata: NotRequired[list[SkillMetadata]]
|
||||
|
||||
|
||||
class SkillsStateUpdate(TypedDict):
|
||||
skills_metadata: list[SkillMetadata]
|
||||
|
||||
|
||||
SKILLS_SYSTEM_PROMPT = """
|
||||
|
||||
## 스킬 시스템
|
||||
|
||||
스킬 라이브러리를 통해 전문화된 기능과 도메인 지식을 사용할 수 있습니다.
|
||||
|
||||
{skills_locations}
|
||||
|
||||
**사용 가능한 스킬:**
|
||||
|
||||
{skills_list}
|
||||
|
||||
**스킬 사용법 (Progressive Disclosure):**
|
||||
|
||||
스킬은 점진적 공개 패턴을 따릅니다. 위에서 스킬의 존재(이름 + 설명)를 알 수 있지만,
|
||||
필요할 때만 전체 지침을 읽습니다:
|
||||
|
||||
1. 스킬 적용 여부 판단: 사용자 요청이 스킬 설명과 일치하는지 확인
|
||||
2. 전체 지침 읽기: read_file로 SKILL.md 경로 읽기
|
||||
3. 지침 따르기: SKILL.md에는 단계별 워크플로우와 예시 포함
|
||||
4. 지원 파일 활용: 스킬에 Python 스크립트나 설정 파일 포함 가능
|
||||
|
||||
**스킬 사용 시점:**
|
||||
- 사용자 요청이 스킬 도메인과 일치할 때
|
||||
- 전문 지식이나 구조화된 워크플로우가 도움될 때
|
||||
- 복잡한 작업에 검증된 패턴이 필요할 때
|
||||
"""
|
||||
|
||||
|
||||
class SkillsMiddleware(AgentMiddleware):
|
||||
"""Progressive Disclosure 패턴으로 스킬을 노출하는 미들웨어."""
|
||||
|
||||
state_schema = SkillsState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
skills_dir: str | Path,
|
||||
assistant_id: str,
|
||||
project_skills_dir: str | Path | None = None,
|
||||
) -> None:
|
||||
self.skills_dir = Path(skills_dir).expanduser()
|
||||
self.assistant_id = assistant_id
|
||||
self.project_skills_dir = (
|
||||
Path(project_skills_dir).expanduser() if project_skills_dir else None
|
||||
)
|
||||
self.user_skills_display = f"~/.deepagents/{assistant_id}/skills"
|
||||
self.system_prompt_template = SKILLS_SYSTEM_PROMPT
|
||||
|
||||
def _format_skills_locations(self) -> str:
|
||||
locations = [f"**사용자 스킬**: `{self.user_skills_display}`"]
|
||||
if self.project_skills_dir:
|
||||
locations.append(
|
||||
f"**프로젝트 스킬**: `{self.project_skills_dir}` (사용자 스킬 오버라이드)"
|
||||
)
|
||||
return "\n".join(locations)
|
||||
|
||||
def _format_skills_list(self, skills: list[SkillMetadata]) -> str:
|
||||
if not skills:
|
||||
locations = [f"{self.user_skills_display}/"]
|
||||
if self.project_skills_dir:
|
||||
locations.append(f"{self.project_skills_dir}/")
|
||||
return f"(사용 가능한 스킬 없음. {' 또는 '.join(locations)}에서 스킬 생성 가능)"
|
||||
|
||||
user_skills = [s for s in skills if s["source"] == "user"]
|
||||
project_skills = [s for s in skills if s["source"] == "project"]
|
||||
|
||||
lines = []
|
||||
|
||||
if user_skills:
|
||||
lines.append("**사용자 스킬:**")
|
||||
for skill in user_skills:
|
||||
lines.append(f"- **{skill['name']}**: {skill['description']}")
|
||||
lines.append(f" → 전체 지침: `{skill['path']}`")
|
||||
lines.append("")
|
||||
|
||||
if project_skills:
|
||||
lines.append("**프로젝트 스킬:**")
|
||||
for skill in project_skills:
|
||||
lines.append(f"- **{skill['name']}**: {skill['description']}")
|
||||
lines.append(f" → 전체 지침: `{skill['path']}`")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def before_agent(
|
||||
self, state: SkillsState, runtime: Runtime
|
||||
) -> SkillsStateUpdate | None:
|
||||
skills = list_skills(
|
||||
user_skills_dir=self.skills_dir,
|
||||
project_skills_dir=self.project_skills_dir,
|
||||
)
|
||||
return SkillsStateUpdate(skills_metadata=skills)
|
||||
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelResponse:
|
||||
skills_metadata = request.state.get("skills_metadata", [])
|
||||
|
||||
skills_locations = self._format_skills_locations()
|
||||
skills_list = self._format_skills_list(skills_metadata)
|
||||
|
||||
skills_section = self.system_prompt_template.format(
|
||||
skills_locations=skills_locations,
|
||||
skills_list=skills_list,
|
||||
)
|
||||
|
||||
if request.system_prompt:
|
||||
system_prompt = request.system_prompt + "\n\n" + skills_section
|
||||
else:
|
||||
system_prompt = skills_section
|
||||
|
||||
return handler(request.override(system_prompt=system_prompt))
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelResponse:
|
||||
state = cast("SkillsState", request.state)
|
||||
skills_metadata = state.get("skills_metadata", [])
|
||||
|
||||
skills_locations = self._format_skills_locations()
|
||||
skills_list = self._format_skills_list(skills_metadata)
|
||||
|
||||
skills_section = self.system_prompt_template.format(
|
||||
skills_locations=skills_locations,
|
||||
skills_list=skills_list,
|
||||
)
|
||||
|
||||
if request.system_prompt:
|
||||
system_prompt = request.system_prompt + "\n\n" + skills_section
|
||||
else:
|
||||
system_prompt = skills_section
|
||||
|
||||
return await handler(request.override(system_prompt=system_prompt))
|
||||
@@ -21,10 +21,12 @@ dependencies = [
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"docker>=7.1.0",
|
||||
"ipykernel>=7.1.0",
|
||||
"ipywidgets>=8.1.8",
|
||||
"langgraph-cli[inmem]>=0.4.11",
|
||||
"mypy>=1.19.1",
|
||||
"pytest>=9.0.2",
|
||||
"ruff>=0.14.10",
|
||||
]
|
||||
|
||||
@@ -48,14 +50,19 @@ lint.select = [
|
||||
"T201",
|
||||
"UP",
|
||||
]
|
||||
lint.ignore = ["UP006", "UP007", "UP035", "D417", "E501"]
|
||||
lint.ignore = ["UP006", "UP007", "UP035", "D417", "E501", "D101", "D102", "D103", "D105", "D107"]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/*" = ["D", "UP"]
|
||||
"tests/*" = ["D", "UP", "T201"]
|
||||
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"integration: 실제 외부 서비스(Docker, API 등)를 사용하는 통합 테스트",
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.13"
|
||||
exclude = [".venv"]
|
||||
|
||||
@@ -10,6 +10,7 @@ default = []
|
||||
checkpointer-sqlite = ["dep:rusqlite", "dep:tokio-rusqlite"]
|
||||
checkpointer-redis = ["dep:redis"]
|
||||
checkpointer-postgres = ["dep:sqlx"]
|
||||
tokenizer-tiktoken = ["dep:tiktoken-rs"]
|
||||
|
||||
[dependencies]
|
||||
rig-core = { version = "0.27", features = ["derive"] }
|
||||
@@ -31,6 +32,7 @@ num_cpus = "1" # For default parallelism configuration
|
||||
humantime-serde = "1" # For Duration serialization in configs
|
||||
zstd = "0.13" # For checkpoint compression
|
||||
regex = "1"
|
||||
tiktoken-rs = { version = "0.5", optional = true }
|
||||
|
||||
# HTTP client for external API tools (Tavily, etc.)
|
||||
reqwest = { version = "0.12", features = ["json"] }
|
||||
|
||||
@@ -70,6 +70,11 @@ pub trait Backend: Send + Sync {
|
||||
/// Returns: 라인 번호 포함된 포맷 (cat -n 스타일)
|
||||
async fn read(&self, path: &str, offset: usize, limit: usize) -> Result<String, BackendError>;
|
||||
|
||||
async fn read_plain(&self, path: &str) -> Result<String, BackendError> {
|
||||
let formatted = self.read(path, 0, 50_000).await?;
|
||||
Ok(strip_cat_n(&formatted))
|
||||
}
|
||||
|
||||
/// 파일 쓰기 (새 파일 생성)
|
||||
/// Python: write(file_path: str, content: str) -> WriteResult
|
||||
async fn write(&self, path: &str, content: &str) -> Result<WriteResult, BackendError>;
|
||||
@@ -122,3 +127,11 @@ pub trait Backend: Send + Sync {
|
||||
/// 파일 삭제
|
||||
async fn delete(&self, path: &str) -> Result<(), BackendError>;
|
||||
}
|
||||
|
||||
fn strip_cat_n(formatted: &str) -> String {
|
||||
formatted
|
||||
.lines()
|
||||
.map(|line| line.split_once('\t').map(|(_, s)| s).unwrap_or(line))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
@@ -167,7 +167,7 @@ impl AgentExecutor {
|
||||
model_request = model_request.with_config(config.clone());
|
||||
}
|
||||
|
||||
let before_control = self.middleware.before_model(&mut model_request, &state, &runtime).await
|
||||
let before_control = self.middleware.before_model(&mut model_request, &mut state, &runtime).await
|
||||
.map_err(DeepAgentError::Middleware)?;
|
||||
|
||||
// before_model 제어 흐름 처리
|
||||
@@ -235,7 +235,22 @@ impl AgentExecutor {
|
||||
|
||||
// 도구 호출 처리
|
||||
if let Some(tool_calls) = &response.tool_calls {
|
||||
let write_todos_count = tool_calls
|
||||
.iter()
|
||||
.filter(|call| call.name == "write_todos")
|
||||
.count();
|
||||
let has_duplicate_write_todos = write_todos_count > 1;
|
||||
|
||||
for call in tool_calls {
|
||||
if has_duplicate_write_todos && call.name == "write_todos" {
|
||||
let result = ToolResult::new(
|
||||
"Error: multiple write_todos calls in a single response are not allowed",
|
||||
);
|
||||
let tool_message = Message::tool_with_status(&result.message, &call.id, "error");
|
||||
state.add_message(tool_message);
|
||||
continue;
|
||||
}
|
||||
|
||||
let result = self
|
||||
.execute_tool_call(call, &tools, &state, runtime.config())
|
||||
.await;
|
||||
@@ -455,6 +470,54 @@ mod tests {
|
||||
assert_eq!(result.todos[0].content, "Test todo");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_executor_rejects_duplicate_write_todos() {
|
||||
let tool_calls = vec![
|
||||
ToolCall {
|
||||
id: "call_1".to_string(),
|
||||
name: "write_todos".to_string(),
|
||||
arguments: serde_json::json!({"todos": []}),
|
||||
},
|
||||
ToolCall {
|
||||
id: "call_2".to_string(),
|
||||
name: "write_todos".to_string(),
|
||||
arguments: serde_json::json!({"todos": []}),
|
||||
},
|
||||
];
|
||||
|
||||
let responses = vec![
|
||||
Message::assistant_with_tool_calls("", tool_calls),
|
||||
Message::assistant("Done."),
|
||||
];
|
||||
|
||||
let llm = Arc::new(MockLLM::new(responses));
|
||||
let backend = Arc::new(MemoryBackend::new());
|
||||
let middleware = MiddlewareStack::new();
|
||||
|
||||
let executor = AgentExecutor::new(llm, middleware, backend)
|
||||
.with_tools(vec![Arc::new(crate::tools::WriteTodosTool)]);
|
||||
|
||||
let initial_state = AgentState::with_messages(vec![
|
||||
Message::user("Update todos"),
|
||||
]);
|
||||
|
||||
let result = executor.run(initial_state).await.unwrap();
|
||||
|
||||
assert!(result.todos.is_empty());
|
||||
|
||||
let tool_messages: Vec<&Message> = result
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|message| message.role == Role::Tool)
|
||||
.collect();
|
||||
|
||||
assert_eq!(tool_messages.len(), 2);
|
||||
for message in tool_messages {
|
||||
assert_eq!(message.status.as_deref(), Some("error"));
|
||||
assert!(message.content.contains("multiple write_todos"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_executor_max_iterations() {
|
||||
// Create LLM that always returns tool calls
|
||||
|
||||
@@ -41,6 +41,7 @@ pub mod skills;
|
||||
pub mod research;
|
||||
pub mod config;
|
||||
pub mod compat;
|
||||
pub mod tokenization;
|
||||
mod tool_result_eviction;
|
||||
|
||||
// Re-exports for convenience
|
||||
|
||||
@@ -110,7 +110,7 @@ impl MiddlewareStack {
|
||||
pub async fn before_model(
|
||||
&self,
|
||||
request: &mut ModelRequest,
|
||||
state: &AgentState,
|
||||
state: &mut AgentState,
|
||||
runtime: &ToolRuntime,
|
||||
) -> Result<ModelControl, MiddlewareError> {
|
||||
for middleware in &self.middlewares {
|
||||
|
||||
@@ -53,9 +53,10 @@ use tracing::{debug, info, warn};
|
||||
|
||||
use crate::error::MiddlewareError;
|
||||
use crate::llm::LLMProvider;
|
||||
use crate::middleware::traits::{AgentMiddleware, DynTool, StateUpdate};
|
||||
use crate::middleware::traits::{AgentMiddleware, DynTool, ModelControl, ModelRequest};
|
||||
use crate::runtime::ToolRuntime;
|
||||
use crate::state::{AgentState, Message, Role};
|
||||
use crate::tokenization::{ApproxTokenCounter, TokenCounter};
|
||||
|
||||
/// Summarization Middleware for token budget management.
|
||||
///
|
||||
@@ -67,6 +68,7 @@ pub struct SummarizationMiddleware {
|
||||
llm_provider: Arc<dyn LLMProvider>,
|
||||
/// Configuration
|
||||
config: SummarizationConfig,
|
||||
token_counter: Arc<dyn TokenCounter>,
|
||||
}
|
||||
|
||||
impl SummarizationMiddleware {
|
||||
@@ -77,7 +79,27 @@ impl SummarizationMiddleware {
|
||||
/// * `llm_provider` - LLM provider for generating summaries
|
||||
/// * `config` - Configuration for triggers, keep size, and prompts
|
||||
pub fn new(llm_provider: Arc<dyn LLMProvider>, config: SummarizationConfig) -> Self {
|
||||
Self { llm_provider, config }
|
||||
let token_counter = Arc::new(ApproxTokenCounter::new(
|
||||
config.chars_per_token,
|
||||
config.overhead_per_message as usize,
|
||||
));
|
||||
Self {
|
||||
llm_provider,
|
||||
config,
|
||||
token_counter,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_token_counter(
|
||||
llm_provider: Arc<dyn LLMProvider>,
|
||||
config: SummarizationConfig,
|
||||
token_counter: Arc<dyn TokenCounter>,
|
||||
) -> Self {
|
||||
Self {
|
||||
llm_provider,
|
||||
config,
|
||||
token_counter,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with default configuration.
|
||||
@@ -92,11 +114,7 @@ impl SummarizationMiddleware {
|
||||
|
||||
/// Count tokens in the current messages.
|
||||
fn count_tokens(&self, messages: &[Message]) -> usize {
|
||||
count_tokens_approximately(
|
||||
messages,
|
||||
self.config.chars_per_token,
|
||||
self.config.overhead_per_message,
|
||||
)
|
||||
self.token_counter.count_messages(messages)
|
||||
}
|
||||
|
||||
/// Check if summarization should be triggered.
|
||||
@@ -152,11 +170,7 @@ impl SummarizationMiddleware {
|
||||
let mut count = 0;
|
||||
|
||||
for msg in messages.iter().rev() {
|
||||
let msg_tokens = count_tokens_approximately(
|
||||
std::slice::from_ref(msg),
|
||||
self.config.chars_per_token,
|
||||
self.config.overhead_per_message,
|
||||
);
|
||||
let msg_tokens = self.token_counter.count_message(msg);
|
||||
|
||||
if total_tokens + msg_tokens > token_budget {
|
||||
break;
|
||||
@@ -180,9 +194,8 @@ impl SummarizationMiddleware {
|
||||
|
||||
let mut cutoff = initial_cutoff;
|
||||
|
||||
// Advance past Tool messages (they should stay with their AI message)
|
||||
while cutoff < messages.len() && messages[cutoff].role == Role::Tool {
|
||||
cutoff += 1;
|
||||
while cutoff > 0 && messages[cutoff].role == Role::Tool {
|
||||
cutoff -= 1;
|
||||
}
|
||||
|
||||
cutoff
|
||||
@@ -233,11 +246,7 @@ impl SummarizationMiddleware {
|
||||
|
||||
// Take messages from the end (most recent first), respecting token budget
|
||||
for msg in messages.iter().rev() {
|
||||
let msg_tokens = count_tokens_approximately(
|
||||
std::slice::from_ref(msg),
|
||||
self.config.chars_per_token,
|
||||
self.config.overhead_per_message,
|
||||
);
|
||||
let msg_tokens = self.token_counter.count_message(msg);
|
||||
|
||||
if total_tokens + msg_tokens > max_tokens {
|
||||
break;
|
||||
@@ -297,11 +306,12 @@ impl AgentMiddleware for SummarizationMiddleware {
|
||||
prompt
|
||||
}
|
||||
|
||||
async fn after_agent(
|
||||
async fn before_model(
|
||||
&self,
|
||||
request: &mut ModelRequest,
|
||||
state: &mut AgentState,
|
||||
_runtime: &ToolRuntime,
|
||||
) -> Result<Option<StateUpdate>, MiddlewareError> {
|
||||
) -> Result<ModelControl, MiddlewareError> {
|
||||
let token_count = self.count_tokens(&state.messages);
|
||||
let message_count = state.messages.len();
|
||||
|
||||
@@ -314,7 +324,7 @@ impl AgentMiddleware for SummarizationMiddleware {
|
||||
|
||||
// Check if we should summarize
|
||||
if !self.should_summarize(token_count, message_count) {
|
||||
return Ok(None);
|
||||
return Ok(ModelControl::Continue);
|
||||
}
|
||||
|
||||
info!(
|
||||
@@ -328,7 +338,7 @@ impl AgentMiddleware for SummarizationMiddleware {
|
||||
|
||||
if to_summarize.is_empty() {
|
||||
debug!("No messages to summarize");
|
||||
return Ok(None);
|
||||
return Ok(ModelControl::Continue);
|
||||
}
|
||||
|
||||
debug!(
|
||||
@@ -342,7 +352,7 @@ impl AgentMiddleware for SummarizationMiddleware {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to generate summary, keeping original messages");
|
||||
return Ok(None);
|
||||
return Ok(ModelControl::Continue);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -362,7 +372,10 @@ impl AgentMiddleware for SummarizationMiddleware {
|
||||
"Summarization complete"
|
||||
);
|
||||
|
||||
Ok(Some(StateUpdate::SetMessages(new_messages)))
|
||||
state.messages = new_messages.clone();
|
||||
request.messages = new_messages.clone();
|
||||
|
||||
Ok(ModelControl::ModifyRequest(request.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -379,6 +392,8 @@ impl std::fmt::Debug for SummarizationMiddleware {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::llm::{LLMConfig, LLMResponse};
|
||||
use crate::middleware::{ModelControl, ModelRequest};
|
||||
use crate::runtime::ToolRuntime;
|
||||
|
||||
/// Mock LLM provider for testing
|
||||
struct MockProvider {
|
||||
@@ -449,10 +464,10 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_safe_cutoff_skips_tool_messages() {
|
||||
fn test_safe_cutoff_moves_backward_for_tool_messages() {
|
||||
let provider = Arc::new(MockProvider::new("Summary"));
|
||||
let config = SummarizationConfig::builder()
|
||||
.keep(KeepSize::Messages(2))
|
||||
.keep(KeepSize::Messages(3))
|
||||
.build();
|
||||
let middleware = SummarizationMiddleware::new(provider, config);
|
||||
|
||||
@@ -465,26 +480,16 @@ mod tests {
|
||||
arguments: serde_json::json!({"path": "/test"}),
|
||||
}
|
||||
]),
|
||||
Message::tool("File contents", "call_1"), // Tool result
|
||||
Message::tool("File contents", "call_1"),
|
||||
Message::assistant("Here's what I found"),
|
||||
Message::user("Thanks"),
|
||||
];
|
||||
|
||||
let (to_summarize, preserved) = middleware.partition_messages(&messages);
|
||||
|
||||
// Should not split in the middle of AI/Tool pair
|
||||
// If cutoff lands on Tool message, it advances past it
|
||||
for msg in &to_summarize {
|
||||
if msg.role == Role::Tool {
|
||||
// Check there's an AI message before it in to_summarize
|
||||
let has_ai_before = to_summarize.iter().any(|m| m.role == Role::Assistant);
|
||||
assert!(has_ai_before, "Tool message should have AI message before it");
|
||||
}
|
||||
}
|
||||
|
||||
// Preserved messages should include the recent context
|
||||
assert!(!preserved.is_empty(), "Should have preserved messages");
|
||||
// Combined should equal original
|
||||
assert_eq!(to_summarize.len(), 1);
|
||||
assert_eq!(preserved.len(), 4);
|
||||
assert!(preserved[0].tool_calls.is_some());
|
||||
assert_eq!(
|
||||
to_summarize.len() + preserved.len(),
|
||||
messages.len(),
|
||||
@@ -492,6 +497,39 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_before_model_summarizes_request_messages() {
|
||||
let provider = Arc::new(MockProvider::new("Summary text"));
|
||||
let config = SummarizationConfig::builder()
|
||||
.trigger(TriggerCondition::Messages(2))
|
||||
.keep(KeepSize::Messages(1))
|
||||
.build();
|
||||
let middleware = SummarizationMiddleware::new(provider, config);
|
||||
|
||||
let mut state = AgentState::with_messages(vec![
|
||||
Message::user("First"),
|
||||
Message::assistant("Second"),
|
||||
Message::user("Third"),
|
||||
]);
|
||||
|
||||
let mut request = ModelRequest::new(state.messages.clone(), vec![]);
|
||||
let backend = Arc::new(crate::backends::MemoryBackend::new());
|
||||
let runtime = ToolRuntime::new(state.clone(), backend);
|
||||
|
||||
let control = middleware
|
||||
.before_model(&mut request, &mut state, &runtime)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(control, ModelControl::ModifyRequest(_)));
|
||||
assert_eq!(request.messages.len(), state.messages.len());
|
||||
assert_eq!(request.messages[0].role, state.messages[0].role);
|
||||
assert_eq!(request.messages[0].content, state.messages[0].content);
|
||||
assert_eq!(request.messages[1].content, state.messages[1].content);
|
||||
assert_eq!(state.messages.len(), 2);
|
||||
assert!(state.messages[0].content.contains("Summary text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages() {
|
||||
let provider = Arc::new(MockProvider::new("Summary"));
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::middleware::{AgentMiddleware, DynTool};
|
||||
use crate::tools::WriteTodosTool;
|
||||
use crate::tools::{ReadTodosTool, WriteTodosTool};
|
||||
|
||||
/// Default system prompt for todo planning.
|
||||
pub const TODO_SYSTEM_PROMPT: &str = "## Planning with `write_todos`\n\
|
||||
@@ -17,7 +17,7 @@ Update the list as you work: mark items in_progress before starting and complete
|
||||
|
||||
/// Middleware that injects the write_todos tool and planning guidance.
|
||||
pub struct TodoListMiddleware {
|
||||
tool: DynTool,
|
||||
tools: Vec<DynTool>,
|
||||
system_prompt: String,
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ impl TodoListMiddleware {
|
||||
/// Create a TodoListMiddleware with a custom system prompt.
|
||||
pub fn with_system_prompt(prompt: impl Into<String>) -> Self {
|
||||
Self {
|
||||
tool: Arc::new(WriteTodosTool),
|
||||
tools: vec![Arc::new(ReadTodosTool), Arc::new(WriteTodosTool)],
|
||||
system_prompt: prompt.into(),
|
||||
}
|
||||
}
|
||||
@@ -43,7 +43,7 @@ impl AgentMiddleware for TodoListMiddleware {
|
||||
}
|
||||
|
||||
fn tools(&self) -> Vec<DynTool> {
|
||||
vec![self.tool.clone()]
|
||||
self.tools.clone()
|
||||
}
|
||||
|
||||
fn modify_system_prompt(&self, prompt: String) -> String {
|
||||
@@ -63,8 +63,9 @@ mod tests {
|
||||
fn test_todo_list_injects_tool() {
|
||||
let middleware = TodoListMiddleware::new();
|
||||
let tools = middleware.tools();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].definition().name, "write_todos");
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].definition().name, "read_todos");
|
||||
assert_eq!(tools[1].definition().name, "write_todos");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -466,7 +466,7 @@ pub trait AgentMiddleware: Send + Sync {
|
||||
async fn before_model(
|
||||
&self,
|
||||
_request: &mut ModelRequest,
|
||||
_state: &AgentState,
|
||||
_state: &mut AgentState,
|
||||
_runtime: &ToolRuntime,
|
||||
) -> Result<ModelControl, MiddlewareError> {
|
||||
Ok(ModelControl::Continue)
|
||||
|
||||
@@ -15,32 +15,40 @@ use tokio::sync::RwLock;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::types::{SkillContent, SkillMetadata, SkillSource};
|
||||
use crate::backends::Backend;
|
||||
use crate::error::MiddlewareError;
|
||||
|
||||
/// Type alias for metadata cache entry (metadata, file path, source)
|
||||
type MetadataCacheEntry = (SkillMetadata, PathBuf, SkillSource);
|
||||
|
||||
/// Skill loader with caching support
|
||||
pub enum SkillStorage {
|
||||
Filesystem {
|
||||
user_dir: Option<PathBuf>,
|
||||
project_dir: Option<PathBuf>,
|
||||
},
|
||||
Backend {
|
||||
backend: Arc<dyn Backend>,
|
||||
sources: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct SkillLoader {
|
||||
/// User skills directory (e.g., ~/.claude/skills)
|
||||
user_dir: Option<PathBuf>,
|
||||
|
||||
/// Project skills directory (e.g., ./skills)
|
||||
project_dir: Option<PathBuf>,
|
||||
|
||||
/// Cached metadata (loaded eagerly on init)
|
||||
storage: SkillStorage,
|
||||
metadata_cache: Arc<RwLock<HashMap<String, MetadataCacheEntry>>>,
|
||||
|
||||
/// Cached full content (loaded lazily on demand)
|
||||
content_cache: Arc<RwLock<HashMap<String, SkillContent>>>,
|
||||
}
|
||||
|
||||
impl SkillLoader {
|
||||
/// Create a new skill loader with specified directories
|
||||
pub fn new(user_dir: Option<PathBuf>, project_dir: Option<PathBuf>) -> Self {
|
||||
Self {
|
||||
user_dir,
|
||||
project_dir,
|
||||
storage: SkillStorage::Filesystem { user_dir, project_dir },
|
||||
metadata_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
content_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_backend(backend: Arc<dyn Backend>, sources: Vec<String>) -> Self {
|
||||
Self {
|
||||
storage: SkillStorage::Backend { backend, sources },
|
||||
metadata_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
content_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
@@ -70,18 +78,24 @@ impl SkillLoader {
|
||||
let mut cache = self.metadata_cache.write().await;
|
||||
cache.clear();
|
||||
|
||||
// Scan user skills first (lower priority)
|
||||
if let Some(user_dir) = &self.user_dir {
|
||||
if user_dir.exists() {
|
||||
self.scan_directory(user_dir, SkillSource::User, &mut cache)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
match &self.storage {
|
||||
SkillStorage::Filesystem { user_dir, project_dir } => {
|
||||
if let Some(user_dir) = user_dir {
|
||||
if user_dir.exists() {
|
||||
self.scan_directory(user_dir, SkillSource::User, &mut cache)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
|
||||
// Scan project skills (higher priority, can override user skills)
|
||||
if let Some(project_dir) = &self.project_dir {
|
||||
if project_dir.exists() {
|
||||
self.scan_directory(project_dir, SkillSource::Project, &mut cache)
|
||||
if let Some(project_dir) = project_dir {
|
||||
if project_dir.exists() {
|
||||
self.scan_directory(project_dir, SkillSource::Project, &mut cache)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
SkillStorage::Backend { backend, sources } => {
|
||||
self.scan_backend_sources(backend, sources, &mut cache)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
@@ -136,6 +150,55 @@ impl SkillLoader {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn scan_backend_sources(
|
||||
&self,
|
||||
backend: &Arc<dyn Backend>,
|
||||
sources: &[String],
|
||||
cache: &mut HashMap<String, (SkillMetadata, PathBuf, SkillSource)>,
|
||||
) -> Result<(), MiddlewareError> {
|
||||
for source in sources {
|
||||
let entries = match backend.ls(source).await {
|
||||
Ok(entries) => entries,
|
||||
Err(e) => {
|
||||
warn!("Failed to list backend source {}: {}", source, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for entry in entries {
|
||||
if !entry.is_dir {
|
||||
continue;
|
||||
}
|
||||
|
||||
let skill_file = format!("{}/SKILL.md", entry.path.trim_end_matches('/'));
|
||||
match backend.read_plain(&skill_file).await {
|
||||
Ok(content) => match parse_frontmatter(&content) {
|
||||
Ok(skill_meta) => {
|
||||
debug!(
|
||||
"Loaded skill metadata: {} from {} ({})",
|
||||
skill_meta.name,
|
||||
skill_file,
|
||||
SkillSource::Backend.as_str()
|
||||
);
|
||||
cache.insert(
|
||||
skill_meta.name.clone(),
|
||||
(skill_meta, PathBuf::from(&skill_file), SkillSource::Backend),
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse skill {}: {}", skill_file, e);
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
warn!("Failed to read skill {}: {}", skill_file, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse only metadata from YAML frontmatter (fast)
|
||||
async fn parse_metadata(&self, path: &Path) -> Result<SkillMetadata, MiddlewareError> {
|
||||
let content = tokio::fs::read_to_string(path)
|
||||
@@ -191,10 +254,18 @@ impl SkillLoader {
|
||||
}
|
||||
};
|
||||
|
||||
// Load and parse full content
|
||||
let raw_content = tokio::fs::read_to_string(&path)
|
||||
.await
|
||||
.map_err(|e| MiddlewareError::ToolExecution(format!("Failed to read skill: {}", e)))?;
|
||||
let raw_content = match &self.storage {
|
||||
SkillStorage::Filesystem { .. } => tokio::fs::read_to_string(&path)
|
||||
.await
|
||||
.map_err(|e| MiddlewareError::ToolExecution(format!("Failed to read skill: {}", e)))?,
|
||||
SkillStorage::Backend { backend, .. } => {
|
||||
let path_str = path.to_string_lossy();
|
||||
backend
|
||||
.read_plain(&path_str)
|
||||
.await
|
||||
.map_err(|e| MiddlewareError::ToolExecution(format!("Failed to read skill: {}", e)))?
|
||||
}
|
||||
};
|
||||
|
||||
let body = parse_body(&raw_content);
|
||||
let content = SkillContent::new(metadata, body, path.to_string_lossy().to_string());
|
||||
@@ -304,6 +375,7 @@ fn parse_body(content: &str) -> String {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::backends::{Backend, MemoryBackend};
|
||||
|
||||
#[test]
|
||||
fn test_parse_frontmatter_valid() {
|
||||
@@ -467,4 +539,69 @@ This is the test skill body.
|
||||
let result = loader.load_skill("nonexistent").await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_backend_loader_layering() {
|
||||
let backend: Arc<dyn Backend> = Arc::new(MemoryBackend::new());
|
||||
|
||||
backend
|
||||
.write(
|
||||
"/source-a/shared/SKILL.md",
|
||||
r#"---
|
||||
name: shared
|
||||
description: First description
|
||||
---
|
||||
# Shared Skill
|
||||
|
||||
First body.
|
||||
"#,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
backend
|
||||
.write(
|
||||
"/source-b/shared/SKILL.md",
|
||||
r#"---
|
||||
name: shared
|
||||
description: Second description
|
||||
---
|
||||
# Shared Skill
|
||||
|
||||
Second body.
|
||||
"#,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
backend
|
||||
.write(
|
||||
"/source-b/unique/SKILL.md",
|
||||
r#"---
|
||||
name: unique
|
||||
description: Unique description
|
||||
---
|
||||
# Unique Skill
|
||||
|
||||
Unique body.
|
||||
"#,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let loader = SkillLoader::from_backend(
|
||||
Arc::clone(&backend),
|
||||
vec!["/source-a".to_string(), "/source-b".to_string()],
|
||||
);
|
||||
loader.initialize().await.unwrap();
|
||||
|
||||
let metadata = loader.get_metadata("shared").await.unwrap();
|
||||
assert_eq!(metadata.description, "Second description");
|
||||
|
||||
let content = loader.load_skill("shared").await.unwrap();
|
||||
assert!(content.body.contains("Second body"));
|
||||
|
||||
let unique = loader.get_metadata("unique").await.unwrap();
|
||||
assert_eq!(unique.description, "Unique description");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,8 +11,9 @@ use tokio::sync::RwLock;
|
||||
use super::loader::SkillLoader;
|
||||
use super::types::{SkillMetadata, SkillSource};
|
||||
use crate::error::MiddlewareError;
|
||||
use crate::middleware::{AgentMiddleware, DynTool, Tool, ToolDefinition, ToolResult};
|
||||
use crate::middleware::{AgentMiddleware, DynTool, Tool, ToolDefinition, ToolResult, StateUpdate};
|
||||
use crate::runtime::ToolRuntime;
|
||||
use crate::state::AgentState;
|
||||
|
||||
/// Skills middleware for progressive skill disclosure
|
||||
///
|
||||
@@ -126,6 +127,16 @@ impl AgentMiddleware for SkillsMiddleware {
|
||||
None => prompt,
|
||||
}
|
||||
}
|
||||
|
||||
async fn before_agent(
|
||||
&self,
|
||||
_state: &mut AgentState,
|
||||
_runtime: &ToolRuntime,
|
||||
) -> Result<Option<StateUpdate>, MiddlewareError> {
|
||||
self.loader.initialize().await?;
|
||||
self.refresh_cache().await;
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for loading skill content on-demand
|
||||
|
||||
@@ -93,6 +93,7 @@ pub enum SkillSource {
|
||||
User,
|
||||
/// Project-level skills (PROJECT_ROOT/skills/)
|
||||
Project,
|
||||
Backend,
|
||||
}
|
||||
|
||||
impl SkillSource {
|
||||
@@ -101,6 +102,7 @@ impl SkillSource {
|
||||
match self {
|
||||
Self::User => "user",
|
||||
Self::Project => "project",
|
||||
Self::Backend => "backend",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -186,5 +188,6 @@ description: Minimal skill
|
||||
fn test_skill_source() {
|
||||
assert_eq!(SkillSource::User.as_str(), "user");
|
||||
assert_eq!(SkillSource::Project.as_str(), "project");
|
||||
assert_eq!(SkillSource::Backend.as_str(), "backend");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,15 +105,29 @@ pub struct Message {
|
||||
pub tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub status: Option<String>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn user(content: &str) -> Self {
|
||||
Self { role: Role::User, content: content.to_string(), tool_call_id: None, tool_calls: None }
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: content.to_string(),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
status: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: &str) -> Self {
|
||||
Self { role: Role::Assistant, content: content.to_string(), tool_call_id: None, tool_calls: None }
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: content.to_string(),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
status: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant_with_tool_calls(content: &str, tool_calls: Vec<ToolCall>) -> Self {
|
||||
@@ -121,12 +135,19 @@ impl Message {
|
||||
role: Role::Assistant,
|
||||
content: content.to_string(),
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls)
|
||||
tool_calls: Some(tool_calls),
|
||||
status: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn system(content: &str) -> Self {
|
||||
Self { role: Role::System, content: content.to_string(), tool_call_id: None, tool_calls: None }
|
||||
Self {
|
||||
role: Role::System,
|
||||
content: content.to_string(),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
status: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(content: &str, tool_call_id: &str) -> Self {
|
||||
@@ -135,6 +156,17 @@ impl Message {
|
||||
content: content.to_string(),
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
tool_calls: None,
|
||||
status: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_with_status(content: &str, tool_call_id: &str, status: &str) -> Self {
|
||||
Self {
|
||||
role: Role::Tool,
|
||||
content: content.to_string(),
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
tool_calls: None,
|
||||
status: Some(status.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
139
rust-research-agent/rig-deepagents/src/tokenization/mod.rs
Normal file
139
rust-research-agent/rig-deepagents/src/tokenization/mod.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use crate::middleware::summarization::token_counter::{
|
||||
count_tokens_approximately, DEFAULT_CHARS_PER_TOKEN, DEFAULT_OVERHEAD_PER_MESSAGE,
|
||||
};
|
||||
use crate::state::Message;
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
use crate::state::Role;
|
||||
|
||||
pub trait TokenCounter: Send + Sync {
|
||||
fn count_text(&self, text: &str) -> usize;
|
||||
fn count_message(&self, message: &Message) -> usize;
|
||||
fn count_messages(&self, messages: &[Message]) -> usize {
|
||||
messages.iter().map(|msg| self.count_message(msg)).sum()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ApproxTokenCounter {
|
||||
pub chars_per_token: f32,
|
||||
pub overhead_per_message: usize,
|
||||
}
|
||||
|
||||
impl ApproxTokenCounter {
|
||||
pub fn new(chars_per_token: f32, overhead_per_message: usize) -> Self {
|
||||
Self {
|
||||
chars_per_token,
|
||||
overhead_per_message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ApproxTokenCounter {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
chars_per_token: DEFAULT_CHARS_PER_TOKEN,
|
||||
overhead_per_message: DEFAULT_OVERHEAD_PER_MESSAGE as usize,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for ApproxTokenCounter {
|
||||
fn count_text(&self, text: &str) -> usize {
|
||||
(text.len() as f32 / self.chars_per_token).ceil() as usize
|
||||
}
|
||||
|
||||
fn count_message(&self, message: &Message) -> usize {
|
||||
count_tokens_approximately(
|
||||
std::slice::from_ref(message),
|
||||
self.chars_per_token,
|
||||
self.overhead_per_message as f32,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TiktokenTokenCounter {
|
||||
encoder: tiktoken_rs::CoreBPE,
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
impl TiktokenTokenCounter {
|
||||
pub fn new(encoder: tiktoken_rs::CoreBPE) -> Self {
|
||||
Self { encoder }
|
||||
}
|
||||
|
||||
pub fn cl100k_base() -> Result<Self, tiktoken_rs::Error> {
|
||||
Ok(Self {
|
||||
encoder: tiktoken_rs::cl100k_base()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
impl TokenCounter for TiktokenTokenCounter {
|
||||
fn count_text(&self, text: &str) -> usize {
|
||||
self.encoder.encode_with_special_tokens(text).len()
|
||||
}
|
||||
|
||||
fn count_message(&self, message: &Message) -> usize {
|
||||
let message_text = build_message_text(message);
|
||||
self.count_text(&message_text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
fn role_name(role: &Role) -> &'static str {
|
||||
match role {
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::System => "system",
|
||||
Role::Tool => "tool",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
fn build_message_text(message: &Message) -> String {
|
||||
let mut text = String::new();
|
||||
text.push_str(&message.content);
|
||||
text.push_str(role_name(&message.role));
|
||||
|
||||
if let Some(ref tool_call_id) = message.tool_call_id {
|
||||
text.push_str(tool_call_id);
|
||||
}
|
||||
|
||||
if let Some(ref tool_calls) = message.tool_calls {
|
||||
for tc in tool_calls {
|
||||
text.push_str(&tc.id);
|
||||
text.push_str(&tc.name);
|
||||
if let Ok(args_str) = serde_json::to_string(&tc.arguments) {
|
||||
text.push_str(&args_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
text
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::state::Message;
|
||||
|
||||
#[test]
|
||||
fn test_approx_counter_counts_non_zero() {
|
||||
let counter = ApproxTokenCounter::new(4.0, 3);
|
||||
let messages = vec![Message::user("Hello there")];
|
||||
assert!(counter.count_messages(&messages) > 0);
|
||||
assert!(counter.count_text("Hello there") > 0);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokenizer-tiktoken")]
|
||||
#[test]
|
||||
fn test_tiktoken_counter_counts_non_zero() {
|
||||
let counter = TiktokenTokenCounter::cl100k_base().unwrap();
|
||||
let messages = vec![Message::assistant("Hello there")];
|
||||
assert!(counter.count_messages(&messages) > 0);
|
||||
assert!(counter.count_text("Hello there") > 0);
|
||||
}
|
||||
}
|
||||
@@ -17,6 +17,7 @@ mod edit_file;
|
||||
mod ls;
|
||||
mod glob;
|
||||
mod grep;
|
||||
mod read_todos;
|
||||
mod write_todos;
|
||||
mod task;
|
||||
|
||||
@@ -30,6 +31,7 @@ pub use edit_file::EditFileTool;
|
||||
pub use ls::LsTool;
|
||||
pub use glob::GlobTool;
|
||||
pub use grep::GrepTool;
|
||||
pub use read_todos::ReadTodosTool;
|
||||
pub use write_todos::WriteTodosTool;
|
||||
pub use task::TaskTool;
|
||||
|
||||
@@ -49,6 +51,7 @@ pub fn default_tools() -> Vec<DynTool> {
|
||||
Arc::new(LsTool),
|
||||
Arc::new(GlobTool),
|
||||
Arc::new(GrepTool),
|
||||
Arc::new(ReadTodosTool),
|
||||
Arc::new(WriteTodosTool),
|
||||
]
|
||||
}
|
||||
|
||||
61
rust-research-agent/rig-deepagents/src/tools/read_todos.rs
Normal file
61
rust-research-agent/rig-deepagents/src/tools/read_todos.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::error::MiddlewareError;
|
||||
use crate::middleware::{Tool, ToolDefinition, ToolResult};
|
||||
use crate::runtime::ToolRuntime;
|
||||
|
||||
pub struct ReadTodosTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ReadTodosTool {
|
||||
fn definition(&self) -> ToolDefinition {
|
||||
ToolDefinition {
|
||||
name: "read_todos".to_string(),
|
||||
description: "Read the current todo list state.".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
&self,
|
||||
_args: serde_json::Value,
|
||||
runtime: &ToolRuntime,
|
||||
) -> Result<ToolResult, MiddlewareError> {
|
||||
let todos = &runtime.state().todos;
|
||||
let json = serde_json::to_string(todos)
|
||||
.map_err(|e| MiddlewareError::ToolExecution(format!("Failed to serialize todos: {e}")))?;
|
||||
Ok(ToolResult::new(json))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::backends::MemoryBackend;
|
||||
use crate::state::{AgentState, Todo, TodoStatus};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_read_todos_returns_state() {
|
||||
let tool = ReadTodosTool;
|
||||
let backend = Arc::new(MemoryBackend::new());
|
||||
let mut state = AgentState::new();
|
||||
state.todos = vec![
|
||||
Todo::with_status("First", TodoStatus::Pending),
|
||||
Todo::with_status("Second", TodoStatus::Completed),
|
||||
];
|
||||
let runtime = ToolRuntime::new(state, backend);
|
||||
|
||||
let result = tool.execute(serde_json::json!({}), &runtime).await.unwrap();
|
||||
let todos: Vec<Todo> = serde_json::from_str(&result.message).unwrap();
|
||||
|
||||
assert_eq!(todos.len(), 2);
|
||||
assert_eq!(todos[0].content, "First");
|
||||
assert_eq!(todos[0].status, TodoStatus::Pending);
|
||||
assert_eq!(todos[1].content, "Second");
|
||||
assert_eq!(todos[1].status, TodoStatus::Completed);
|
||||
}
|
||||
}
|
||||
@@ -228,6 +228,7 @@ impl<S: WorkflowState + serde::Serialize> Vertex<S, WorkflowMessage> for AgentVe
|
||||
content: self.config.system_prompt.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
status: None,
|
||||
}];
|
||||
|
||||
// Add any incoming workflow messages as user messages
|
||||
@@ -238,6 +239,7 @@ impl<S: WorkflowState + serde::Serialize> Vertex<S, WorkflowMessage> for AgentVe
|
||||
content: value.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
status: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -249,6 +251,7 @@ impl<S: WorkflowState + serde::Serialize> Vertex<S, WorkflowMessage> for AgentVe
|
||||
content: "Begin processing.".to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
status: None,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -354,6 +357,7 @@ mod tests {
|
||||
content: content.into(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
status: None,
|
||||
};
|
||||
self.responses.lock().unwrap().push(message);
|
||||
self
|
||||
@@ -369,6 +373,7 @@ mod tests {
|
||||
arguments: serde_json::json!({}),
|
||||
}]),
|
||||
tool_call_id: None,
|
||||
status: None,
|
||||
};
|
||||
self.responses.lock().unwrap().push(message);
|
||||
self
|
||||
@@ -563,6 +568,7 @@ mod tests {
|
||||
content: "Done".to_string(),
|
||||
tool_calls: Some(vec![]),
|
||||
tool_call_id: None,
|
||||
status: None,
|
||||
};
|
||||
|
||||
// State with non-matching phase
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/backends/__init__.py
Normal file
0
tests/backends/__init__.py
Normal file
14
tests/backends/conftest.py
Normal file
14
tests/backends/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import docker # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
_VENDORED_DEEPAGENTS = _REPO_ROOT / "deepagents_sourcecode" / "libs" / "deepagents"
|
||||
if _VENDORED_DEEPAGENTS.exists() and str(_VENDORED_DEEPAGENTS) not in sys.path:
|
||||
sys.path.insert(0, str(_VENDORED_DEEPAGENTS))
|
||||
if str(_REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_REPO_ROOT))
|
||||
496
tests/backends/test_docker_sandbox_integration.py
Normal file
496
tests/backends/test_docker_sandbox_integration.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""DockerSandboxBackend/DockerSandboxSession 실환경 통합 테스트.
|
||||
|
||||
주의: 이 테스트는 실제 Docker 데몬과 컨테이너 실행이 필요합니다.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from context_engineering_research_agent.backends.docker_sandbox import (
|
||||
DockerSandboxBackend,
|
||||
)
|
||||
from context_engineering_research_agent.backends.docker_session import (
|
||||
DockerSandboxSession,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.offloading import (
|
||||
ContextOffloadingStrategy,
|
||||
OffloadingConfig,
|
||||
)
|
||||
|
||||
|
||||
def _docker_available() -> bool:
|
||||
"""Docker 사용 가능 여부를 확인합니다."""
|
||||
try:
|
||||
import docker
|
||||
|
||||
client = docker.from_env()
|
||||
# ping은 Docker 데몬 연결 여부를 가장 빠르게 확인합니다.
|
||||
client.ping()
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"DEBUG: Docker not available: {type(e).__name__}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
DOCKER_AVAILABLE = _docker_available()
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.integration,
|
||||
pytest.mark.skipif(
|
||||
not DOCKER_AVAILABLE,
|
||||
reason="Docker 데몬 또는 python docker SDK를 사용할 수 없습니다.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def docker_backend() -> Iterator[DockerSandboxBackend]:
|
||||
"""테스트용 Docker 샌드박스 백엔드를 제공합니다."""
|
||||
session = DockerSandboxSession()
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.start())
|
||||
backend = session.get_backend()
|
||||
yield backend
|
||||
finally:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.stop())
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_workspace(docker_backend: DockerSandboxBackend) -> None:
|
||||
"""각 테스트마다 워크스페이스 내 테스트 디렉토리를 초기화합니다."""
|
||||
rm_result = docker_backend.execute("rm -rf test_docker_sandbox")
|
||||
assert rm_result.exit_code == 0, f"rm failed: {rm_result.output}"
|
||||
mkdir_result = docker_backend.execute("mkdir -p test_docker_sandbox")
|
||||
assert mkdir_result.exit_code == 0, f"mkdir failed: {mkdir_result.output}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1) Code Execution Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_execute_basic_commands(docker_backend: DockerSandboxBackend) -> None:
|
||||
"""기본 명령( echo/ls/pwd )이 컨테이너 내부에서 정상 동작하는지 확인합니다."""
|
||||
echo = docker_backend.execute("echo 'hello'")
|
||||
assert echo.exit_code == 0
|
||||
assert echo.output.strip() == "hello"
|
||||
|
||||
pwd = docker_backend.execute("pwd")
|
||||
assert pwd.exit_code == 0
|
||||
# DockerSandboxBackend는 workdir=/workspace로 실행합니다.
|
||||
assert pwd.output.strip() == "/workspace"
|
||||
|
||||
docker_backend.execute("echo 'x' > test_docker_sandbox/file.txt")
|
||||
ls = docker_backend.execute("ls -la test_docker_sandbox")
|
||||
assert ls.exit_code == 0
|
||||
assert "file.txt" in ls.output
|
||||
|
||||
|
||||
def test_execute_python_and_exit_codes(docker_backend: DockerSandboxBackend) -> None:
|
||||
"""파이썬 실행 및 exit code(성공/실패)가 정확히 전달되는지 확인합니다."""
|
||||
py = docker_backend.execute('python3 -c "print(2 + 2)"')
|
||||
assert py.exit_code == 0
|
||||
assert py.output.strip() == "4"
|
||||
|
||||
fail = docker_backend.execute('python3 -c "import sys; sys.exit(42)"')
|
||||
assert fail.exit_code == 42
|
||||
|
||||
|
||||
def test_execute_truncates_large_output(docker_backend: DockerSandboxBackend) -> None:
|
||||
"""대용량 출력이 100,000자 기준으로 잘리는지(truncated) 확인합니다."""
|
||||
# 110k 이상 출력 생성
|
||||
big = docker_backend.execute("python3 -c \"print('x' * 110500)\"")
|
||||
assert big.exit_code == 0
|
||||
assert big.truncated is True
|
||||
assert "[출력이 잘렸습니다" in big.output
|
||||
assert len(big.output) <= 100000 + 200 # 안내 문구 포함 여유
|
||||
|
||||
|
||||
def test_execute_timeout_handling_via_alarm(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""장시간 작업이 자체 타임아웃(알람)으로 빠르게 종료되는지 확인합니다."""
|
||||
start = time.monotonic()
|
||||
res = docker_backend.execute(
|
||||
'python3 -c "\n'
|
||||
"import signal, time\n"
|
||||
"def _handler(signum, frame):\n"
|
||||
" raise TimeoutError('alarm')\n"
|
||||
"signal.signal(signal.SIGALRM, _handler)\n"
|
||||
"signal.alarm(1)\n"
|
||||
"time.sleep(10)\n"
|
||||
'"'
|
||||
)
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
assert elapsed < 5, f"예상보다 오래 걸렸습니다: {elapsed:.2f}s"
|
||||
assert res.exit_code != 0
|
||||
assert "TimeoutError" in res.output or "alarm" in res.output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2) File Operations Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_upload_files_single_and_nested(docker_backend: DockerSandboxBackend) -> None:
|
||||
"""upload_files가 단일 파일 및 중첩 디렉토리 업로드를 지원하는지 확인합니다."""
|
||||
files = [
|
||||
("test_docker_sandbox/one.txt", b"one"),
|
||||
("test_docker_sandbox/nested/dir/two.bin", b"\x00\x01\x02"),
|
||||
]
|
||||
responses = docker_backend.upload_files(files)
|
||||
|
||||
assert [r.path for r in responses] == [p for p, _ in files]
|
||||
assert all(r.error is None for r in responses)
|
||||
|
||||
# 컨테이너 내 파일 존재/내용 확인
|
||||
cat_one = docker_backend.execute("cat test_docker_sandbox/one.txt")
|
||||
assert cat_one.exit_code == 0
|
||||
assert cat_one.output.strip() == "one"
|
||||
|
||||
# 이진 파일은 다운로드로 검증
|
||||
dl = docker_backend.download_files(["test_docker_sandbox/nested/dir/two.bin"])
|
||||
assert dl[0].error is None
|
||||
assert dl[0].content == b"\x00\x01\x02"
|
||||
|
||||
|
||||
def test_upload_download_multiple_roundtrip(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""여러 파일을 업로드한 뒤 다운로드하여 내용이 동일한지 확인합니다."""
|
||||
files = [
|
||||
("test_docker_sandbox/a.txt", b"A"),
|
||||
("test_docker_sandbox/b.txt", b"B"),
|
||||
("test_docker_sandbox/sub/c.txt", b"C"),
|
||||
]
|
||||
|
||||
up = docker_backend.upload_files(files)
|
||||
assert len(up) == 3
|
||||
assert all(r.error is None for r in up)
|
||||
|
||||
paths = [p for p, _ in files]
|
||||
dl = docker_backend.download_files(paths)
|
||||
assert [r.path for r in dl] == paths
|
||||
assert all(r.error is None for r in dl)
|
||||
|
||||
got = {r.path: r.content for r in dl}
|
||||
expected = {p: c for p, c in files}
|
||||
assert got == expected
|
||||
|
||||
|
||||
def test_download_files_nonexistent_and_directory(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""download_files가 없는 파일/디렉토리 대상에서 올바른 에러를 반환하는지 확인합니다."""
|
||||
docker_backend.execute("mkdir -p test_docker_sandbox/dir_only")
|
||||
|
||||
responses = docker_backend.download_files(
|
||||
[
|
||||
"test_docker_sandbox/does_not_exist.txt",
|
||||
"test_docker_sandbox/dir_only",
|
||||
]
|
||||
)
|
||||
|
||||
assert responses[0].error == "file_not_found"
|
||||
assert responses[0].content is None
|
||||
|
||||
assert responses[1].error == "is_directory"
|
||||
assert responses[1].content is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3) Context Offloading Tests (WITHOUT Agent)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_context_offloading_writes_large_tool_result_to_docker_filesystem(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""ContextOffloadingStrategy가 대용량 결과를 Docker 파일시스템에 저장하는지 확인합니다."""
|
||||
|
||||
# backend_factory는 runtime을 무시하고 현재 DockerSandboxBackend를 반환합니다.
|
||||
strategy = ContextOffloadingStrategy(
|
||||
config=OffloadingConfig(token_limit_before_evict=10, chars_per_token=1),
|
||||
backend_factory=lambda _runtime: docker_backend,
|
||||
)
|
||||
|
||||
content_lines = [f"line_{i:03d}: {'x' * 20}" for i in range(50)]
|
||||
large_content = "\n".join(content_lines)
|
||||
|
||||
tool_result = ToolMessage(
|
||||
content=large_content,
|
||||
tool_call_id="call/with:special@chars!",
|
||||
)
|
||||
|
||||
class MinimalRuntime:
|
||||
"""ToolRuntime 대체용 최소 객체입니다(backend_factory 호출을 위해서만 사용)."""
|
||||
|
||||
state: dict = {}
|
||||
config: dict = {}
|
||||
|
||||
processed, offload = strategy.process_tool_result(tool_result, MinimalRuntime()) # type: ignore[arg-type]
|
||||
|
||||
assert offload.was_offloaded is True
|
||||
assert offload.file_path is not None
|
||||
assert offload.file_path.startswith("/large_tool_results/")
|
||||
|
||||
# 반환 메시지는 원문 전체가 아니라 경로 참조를 포함해야 합니다.
|
||||
if hasattr(processed, "content"):
|
||||
replacement_text = processed.content # ToolMessage
|
||||
else:
|
||||
# Command(update={messages:[ToolMessage...]}) 형태
|
||||
update = processed.update # type: ignore[attr-defined]
|
||||
replacement_text = update["messages"][0].content
|
||||
|
||||
assert offload.file_path in replacement_text
|
||||
assert "read_file" in replacement_text
|
||||
assert len(replacement_text) < len(large_content)
|
||||
|
||||
# 실제 파일이 컨테이너에 저장되었는지 다운로드로 검증합니다.
|
||||
downloaded = docker_backend.download_files([offload.file_path])
|
||||
assert downloaded[0].error is None
|
||||
assert downloaded[0].content is not None
|
||||
assert downloaded[0].content.decode("utf-8") == large_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4) Session Lifecycle Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_initializes_workspace_dirs() -> None:
|
||||
"""세션 시작 시 /workspace/_meta 및 /workspace/shared 디렉토리가 생성되는지 확인합니다."""
|
||||
session = DockerSandboxSession()
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.start())
|
||||
backend = session.get_backend()
|
||||
|
||||
meta = backend.execute("test -d /workspace/_meta && echo ok")
|
||||
shared = backend.execute("test -d /workspace/shared && echo ok")
|
||||
assert meta.exit_code == 0
|
||||
assert shared.exit_code == 0
|
||||
assert meta.output.strip() == "ok"
|
||||
assert shared.output.strip() == "ok"
|
||||
finally:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.stop())
|
||||
|
||||
|
||||
def test_multiple_backends_share_same_container_workspace() -> None:
|
||||
"""동일 컨테이너 ID를 사용하는 여러 백엔드가 파일을 공유하는지 확인합니다."""
|
||||
try:
|
||||
import docker
|
||||
except Exception:
|
||||
pytest.skip("python docker SDK가 필요합니다")
|
||||
|
||||
session = DockerSandboxSession()
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.start())
|
||||
backend1 = session.get_backend()
|
||||
backend2 = DockerSandboxBackend(
|
||||
container_id=backend1.id,
|
||||
docker_client=docker.from_env(),
|
||||
)
|
||||
|
||||
backend1.execute("mkdir -p test_docker_sandbox")
|
||||
backend1.write("/workspace/test_docker_sandbox/shared.txt", "shared")
|
||||
|
||||
read_back = backend2.execute("cat /workspace/test_docker_sandbox/shared.txt")
|
||||
assert read_back.exit_code == 0
|
||||
assert read_back.output.strip() == "shared"
|
||||
finally:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.stop())
|
||||
|
||||
|
||||
def test_session_stop_removes_container() -> None:
|
||||
"""세션 종료 시 컨테이너가 중지/삭제되는지 확인합니다."""
|
||||
try:
|
||||
import docker
|
||||
except Exception:
|
||||
pytest.skip("python docker SDK가 필요합니다")
|
||||
|
||||
client = docker.from_env()
|
||||
session = DockerSandboxSession()
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(session.start())
|
||||
backend = session.get_backend()
|
||||
container_id = backend.id
|
||||
|
||||
# 실제로 컨테이너가 존재하는지 확인
|
||||
client.containers.get(container_id)
|
||||
|
||||
asyncio.run(session.stop())
|
||||
|
||||
# stop()이 swallow하므로, 실제 삭제 여부는 inspect로 확인
|
||||
with pytest.raises(Exception):
|
||||
client.containers.get(container_id)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5) Security Verification Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_container_security_options_applied(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""컨테이너 생성 시 네트워크/권한/메모리 제한 옵션이 적용되는지 확인합니다."""
|
||||
try:
|
||||
import docker
|
||||
except Exception:
|
||||
pytest.skip("python docker SDK가 필요합니다")
|
||||
|
||||
client = docker.from_env()
|
||||
container = client.containers.get(docker_backend.id)
|
||||
container.reload()
|
||||
host_cfg = container.attrs.get("HostConfig", {})
|
||||
|
||||
assert host_cfg.get("NetworkMode") == "none"
|
||||
|
||||
cap_drop = host_cfg.get("CapDrop") or []
|
||||
assert "ALL" in cap_drop
|
||||
|
||||
security_opt = host_cfg.get("SecurityOpt") or []
|
||||
assert any("no-new-privileges" in opt for opt in security_opt)
|
||||
|
||||
# Docker가 바이트 단위로 변환합니다(512m ≈ 536,870,912 bytes)
|
||||
memory = host_cfg.get("Memory")
|
||||
assert memory is not None
|
||||
assert memory >= 512 * 1024 * 1024
|
||||
|
||||
pids_limit = host_cfg.get("PidsLimit")
|
||||
assert pids_limit == 128
|
||||
|
||||
|
||||
def test_network_isolation_blocks_outbound(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""network_mode='none' 설정으로 외부 네트워크 연결이 차단되는지 확인합니다."""
|
||||
res = docker_backend.execute(
|
||||
'python3 -c "\n'
|
||||
"import socket\n"
|
||||
"s = socket.socket()\n"
|
||||
"s.settimeout(1.0)\n"
|
||||
"try:\n"
|
||||
" s.connect(('1.1.1.1', 53))\n"
|
||||
" print('UNEXPECTED_CONNECTED')\n"
|
||||
" raise SystemExit(1)\n"
|
||||
"except OSError as e:\n"
|
||||
" print('blocked', type(e).__name__)\n"
|
||||
" raise SystemExit(0)\n"
|
||||
"finally:\n"
|
||||
" s.close()\n"
|
||||
'"'
|
||||
)
|
||||
|
||||
assert res.exit_code == 0
|
||||
assert "blocked" in res.output
|
||||
assert "UNEXPECTED_CONNECTED" not in res.output
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6) LLM Output Formatting Tests (코드 실행 결과가 LLM에 전달되는지 검증)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _format_execute_result_for_llm(result) -> str:
|
||||
"""DeepAgents _execute_tool_generator와 동일한 포맷팅 로직.
|
||||
|
||||
FilesystemMiddleware의 execute tool이 LLM에 반환하는 형식을 재현합니다.
|
||||
"""
|
||||
parts = [result.output]
|
||||
|
||||
if result.exit_code is not None:
|
||||
status = "succeeded" if result.exit_code == 0 else "failed"
|
||||
parts.append(f"\n[Command {status} with exit code {result.exit_code}]")
|
||||
|
||||
if result.truncated:
|
||||
parts.append("\n[Output was truncated due to size limits]")
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def test_execute_result_formatted_for_llm_success(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""성공한 코드 실행 결과가 LLM이 인지할 수 있는 형태로 포맷팅되는지 확인합니다."""
|
||||
result = docker_backend.execute('python3 -c "print(42 * 2)"')
|
||||
|
||||
llm_output = _format_execute_result_for_llm(result)
|
||||
|
||||
assert "84" in llm_output
|
||||
assert "[Command succeeded with exit code 0]" in llm_output
|
||||
assert "truncated" not in llm_output.lower()
|
||||
|
||||
|
||||
def test_execute_result_formatted_for_llm_failure(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""실패한 코드 실행 결과가 LLM이 인지할 수 있는 형태로 포맷팅되는지 확인합니다."""
|
||||
result = docker_backend.execute("python3 -c \"raise ValueError('test error')\"")
|
||||
|
||||
llm_output = _format_execute_result_for_llm(result)
|
||||
|
||||
assert "ValueError" in llm_output
|
||||
assert "test error" in llm_output
|
||||
assert "[Command failed with exit code 1]" in llm_output
|
||||
|
||||
|
||||
def test_execute_result_formatted_for_llm_multiline_output(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""여러 줄 출력이 LLM에 그대로 전달되는지 확인합니다."""
|
||||
result = docker_backend.execute(
|
||||
"python3 -c \"for i in range(5): print(f'line {i}')\""
|
||||
)
|
||||
|
||||
llm_output = _format_execute_result_for_llm(result)
|
||||
|
||||
for i in range(5):
|
||||
assert f"line {i}" in llm_output
|
||||
assert "[Command succeeded with exit code 0]" in llm_output
|
||||
|
||||
|
||||
def test_execute_result_formatted_for_llm_truncation_notice(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""대용량 출력이 잘릴 때 LLM에 truncation 알림이 포함되는지 확인합니다."""
|
||||
result = docker_backend.execute("python3 -c \"print('x' * 110500)\"")
|
||||
|
||||
llm_output = _format_execute_result_for_llm(result)
|
||||
|
||||
assert result.truncated is True
|
||||
assert "[Output was truncated due to size limits]" in llm_output
|
||||
assert "[Command succeeded with exit code 0]" in llm_output
|
||||
|
||||
|
||||
def test_execute_result_contains_stderr_for_llm(
|
||||
docker_backend: DockerSandboxBackend,
|
||||
) -> None:
|
||||
"""stderr 출력이 LLM에 전달되는지 확인합니다."""
|
||||
result = docker_backend.execute(
|
||||
"python3 -c \"import sys; sys.stderr.write('error message\\n')\""
|
||||
)
|
||||
|
||||
llm_output = _format_execute_result_for_llm(result)
|
||||
|
||||
assert "error message" in llm_output
|
||||
0
tests/context_engineering/__init__.py
Normal file
0
tests/context_engineering/__init__.py
Normal file
665
tests/context_engineering/test_caching.py
Normal file
665
tests/context_engineering/test_caching.py
Normal file
@@ -0,0 +1,665 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from context_engineering_research_agent.context_strategies.caching import (
|
||||
CachingConfig,
|
||||
CachingResult,
|
||||
ContextCachingStrategy,
|
||||
OpenRouterSubProvider,
|
||||
ProviderType,
|
||||
detect_openrouter_sub_provider,
|
||||
detect_provider,
|
||||
requires_cache_control_marker,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.caching_telemetry import (
|
||||
CacheTelemetry,
|
||||
PromptCachingTelemetryMiddleware,
|
||||
extract_anthropic_cache_metrics,
|
||||
extract_cache_telemetry,
|
||||
extract_deepseek_cache_metrics,
|
||||
extract_gemini_cache_metrics,
|
||||
extract_openai_cache_metrics,
|
||||
)
|
||||
|
||||
|
||||
class TestCachingConfig:
|
||||
def test_default_values(self):
|
||||
config = CachingConfig()
|
||||
|
||||
assert config.min_cacheable_tokens == 1024
|
||||
assert config.cache_control_type == "ephemeral"
|
||||
assert config.enable_for_system_prompt is True
|
||||
assert config.enable_for_tools is True
|
||||
|
||||
def test_custom_values(self):
|
||||
config = CachingConfig(
|
||||
min_cacheable_tokens=2048,
|
||||
cache_control_type="permanent",
|
||||
enable_for_system_prompt=False,
|
||||
enable_for_tools=False,
|
||||
)
|
||||
|
||||
assert config.min_cacheable_tokens == 2048
|
||||
assert config.cache_control_type == "permanent"
|
||||
assert config.enable_for_system_prompt is False
|
||||
assert config.enable_for_tools is False
|
||||
|
||||
|
||||
class TestCachingResult:
|
||||
def test_not_cached(self):
|
||||
result = CachingResult(was_cached=False)
|
||||
|
||||
assert result.was_cached is False
|
||||
assert result.cached_content_type is None
|
||||
assert result.estimated_tokens_cached == 0
|
||||
|
||||
def test_cached(self):
|
||||
result = CachingResult(
|
||||
was_cached=True,
|
||||
cached_content_type="system_prompt",
|
||||
estimated_tokens_cached=5000,
|
||||
)
|
||||
|
||||
assert result.was_cached is True
|
||||
assert result.cached_content_type == "system_prompt"
|
||||
assert result.estimated_tokens_cached == 5000
|
||||
|
||||
|
||||
class TestContextCachingStrategy:
|
||||
@pytest.fixture
|
||||
def mock_anthropic_model(self) -> MagicMock:
|
||||
mock = MagicMock()
|
||||
mock.__class__.__name__ = "ChatAnthropic"
|
||||
mock.__class__.__module__ = "langchain_anthropic"
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def strategy(self, mock_anthropic_model: MagicMock) -> ContextCachingStrategy:
|
||||
return ContextCachingStrategy(model=mock_anthropic_model)
|
||||
|
||||
@pytest.fixture
|
||||
def strategy_low_threshold(
|
||||
self, mock_anthropic_model: MagicMock
|
||||
) -> ContextCachingStrategy:
|
||||
return ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_anthropic_model,
|
||||
)
|
||||
|
||||
def test_estimate_tokens_string(self, strategy: ContextCachingStrategy):
|
||||
content = "a" * 400
|
||||
estimated = strategy._estimate_tokens(content)
|
||||
|
||||
assert estimated == 100
|
||||
|
||||
def test_estimate_tokens_list(self, strategy: ContextCachingStrategy):
|
||||
content = [
|
||||
{"type": "text", "text": "a" * 200},
|
||||
{"type": "text", "text": "b" * 200},
|
||||
]
|
||||
estimated = strategy._estimate_tokens(content)
|
||||
|
||||
assert estimated == 100
|
||||
|
||||
def test_estimate_tokens_dict(self, strategy: ContextCachingStrategy):
|
||||
content = {"type": "text", "text": "a" * 400}
|
||||
estimated = strategy._estimate_tokens(content)
|
||||
|
||||
assert estimated == 100
|
||||
|
||||
def test_should_cache_small_content(self, strategy: ContextCachingStrategy):
|
||||
small_content = "short text"
|
||||
|
||||
assert strategy._should_cache(small_content) is False
|
||||
|
||||
def test_should_cache_large_content(self, strategy: ContextCachingStrategy):
|
||||
large_content = "x" * 5000
|
||||
|
||||
assert strategy._should_cache(large_content) is True
|
||||
|
||||
def test_add_cache_control_string(self, strategy: ContextCachingStrategy):
|
||||
content = "test content"
|
||||
cached = strategy._add_cache_control(content)
|
||||
|
||||
assert cached["type"] == "text"
|
||||
assert cached["text"] == "test content"
|
||||
assert cached["cache_control"]["type"] == "ephemeral"
|
||||
|
||||
def test_add_cache_control_dict(self, strategy: ContextCachingStrategy):
|
||||
content = {"type": "text", "text": "test content"}
|
||||
cached = strategy._add_cache_control(content)
|
||||
|
||||
assert cached["cache_control"]["type"] == "ephemeral"
|
||||
|
||||
def test_add_cache_control_list(self, strategy: ContextCachingStrategy):
|
||||
content = [
|
||||
{"type": "text", "text": "first"},
|
||||
{"type": "text", "text": "second"},
|
||||
]
|
||||
cached = strategy._add_cache_control(content)
|
||||
|
||||
assert "cache_control" not in cached[0]
|
||||
assert cached[1]["cache_control"]["type"] == "ephemeral"
|
||||
|
||||
def test_add_cache_control_empty_list(self, strategy: ContextCachingStrategy):
|
||||
content: list = []
|
||||
cached = strategy._add_cache_control(content)
|
||||
|
||||
assert cached == []
|
||||
|
||||
def test_process_system_message(self, strategy: ContextCachingStrategy):
|
||||
message = SystemMessage(content="You are a helpful assistant")
|
||||
processed = strategy._process_system_message(message)
|
||||
|
||||
assert isinstance(processed, SystemMessage)
|
||||
assert isinstance(processed.content, list)
|
||||
assert processed.content[0]["cache_control"]["type"] == "ephemeral"
|
||||
|
||||
def test_apply_caching_empty_messages(self, strategy: ContextCachingStrategy):
|
||||
messages: list = []
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is False
|
||||
assert cached == []
|
||||
|
||||
def test_apply_caching_no_system_message(self, strategy: ContextCachingStrategy):
|
||||
messages = [
|
||||
HumanMessage(content="Hello"),
|
||||
AIMessage(content="Hi there!"),
|
||||
]
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is False
|
||||
|
||||
def test_apply_caching_small_system_message(self, strategy: ContextCachingStrategy):
|
||||
messages = [
|
||||
SystemMessage(content="Be helpful"),
|
||||
HumanMessage(content="Hello"),
|
||||
]
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is False
|
||||
|
||||
def test_apply_caching_large_system_message(
|
||||
self, strategy_low_threshold: ContextCachingStrategy
|
||||
):
|
||||
large_prompt = "You are a helpful assistant. " * 50
|
||||
messages = [
|
||||
SystemMessage(content=large_prompt),
|
||||
HumanMessage(content="Hello"),
|
||||
]
|
||||
cached, result = strategy_low_threshold.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is True
|
||||
assert result.cached_content_type == "system_prompt"
|
||||
assert result.estimated_tokens_cached > 0
|
||||
|
||||
def test_apply_caching_preserves_message_order(
|
||||
self, strategy_low_threshold: ContextCachingStrategy
|
||||
):
|
||||
large_prompt = "System prompt " * 100
|
||||
messages = [
|
||||
SystemMessage(content=large_prompt),
|
||||
HumanMessage(content="Question 1"),
|
||||
AIMessage(content="Answer 1"),
|
||||
HumanMessage(content="Question 2"),
|
||||
]
|
||||
cached, _ = strategy_low_threshold.apply_caching(messages)
|
||||
|
||||
assert len(cached) == 4
|
||||
assert isinstance(cached[0], SystemMessage)
|
||||
assert isinstance(cached[1], HumanMessage)
|
||||
assert isinstance(cached[2], AIMessage)
|
||||
assert isinstance(cached[3], HumanMessage)
|
||||
|
||||
def test_custom_cache_control_type(self):
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(
|
||||
cache_control_type="permanent",
|
||||
min_cacheable_tokens=10,
|
||||
)
|
||||
)
|
||||
content = "test content"
|
||||
cached = strategy._add_cache_control(content)
|
||||
|
||||
assert cached["cache_control"]["type"] == "permanent"
|
||||
|
||||
def test_disabled_system_prompt_caching(self):
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(
|
||||
enable_for_system_prompt=False,
|
||||
min_cacheable_tokens=10,
|
||||
)
|
||||
)
|
||||
large_prompt = "System prompt " * 100
|
||||
messages = [
|
||||
SystemMessage(content=large_prompt),
|
||||
HumanMessage(content="Hello"),
|
||||
]
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is False
|
||||
|
||||
|
||||
class TestProviderDetection:
|
||||
def test_detect_anthropic_from_class_name(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatAnthropic"
|
||||
mock_model.__class__.__module__ = "langchain_anthropic.chat_models"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.ANTHROPIC
|
||||
|
||||
def test_detect_openai_from_class_name(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatOpenAI"
|
||||
mock_model.__class__.__module__ = "langchain_openai.chat_models"
|
||||
mock_model.openai_api_base = None
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.OPENAI
|
||||
|
||||
def test_detect_gemini_from_class_name(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatGoogleGenerativeAI"
|
||||
mock_model.__class__.__module__ = "langchain_google_genai"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.GEMINI
|
||||
|
||||
def test_detect_openrouter_from_base_url(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatOpenAI"
|
||||
mock_model.__class__.__module__ = "langchain_openai"
|
||||
mock_model.openai_api_base = "https://openrouter.ai/api/v1"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.OPENROUTER
|
||||
|
||||
def test_detect_unknown_provider(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "CustomModel"
|
||||
mock_model.__class__.__module__ = "custom_module"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.UNKNOWN
|
||||
|
||||
def test_detect_none_model(self):
|
||||
assert detect_provider(None) == ProviderType.UNKNOWN
|
||||
|
||||
def test_requires_cache_control_marker_anthropic(self):
|
||||
assert requires_cache_control_marker(ProviderType.ANTHROPIC) is True
|
||||
|
||||
def test_requires_cache_control_marker_openai(self):
|
||||
assert requires_cache_control_marker(ProviderType.OPENAI) is False
|
||||
|
||||
def test_requires_cache_control_marker_gemini(self):
|
||||
assert requires_cache_control_marker(ProviderType.GEMINI) is False
|
||||
|
||||
|
||||
class TestContextCachingStrategyMultiProvider:
|
||||
def test_anthropic_applies_cache_markers(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatAnthropic"
|
||||
mock_model.__class__.__module__ = "langchain_anthropic"
|
||||
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_model,
|
||||
)
|
||||
large_prompt = "System prompt " * 100
|
||||
messages = [SystemMessage(content=large_prompt)]
|
||||
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is True
|
||||
assert result.cached_content_type == "system_prompt"
|
||||
|
||||
def test_openai_skips_cache_markers(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatOpenAI"
|
||||
mock_model.__class__.__module__ = "langchain_openai"
|
||||
mock_model.openai_api_base = None
|
||||
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_model,
|
||||
)
|
||||
large_prompt = "System prompt " * 100
|
||||
messages = [SystemMessage(content=large_prompt)]
|
||||
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is False
|
||||
assert result.cached_content_type == "auto_cached_by_openai"
|
||||
|
||||
def test_gemini_skips_cache_markers(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatGoogleGenerativeAI"
|
||||
mock_model.__class__.__module__ = "langchain_google_genai"
|
||||
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_model,
|
||||
)
|
||||
large_prompt = "System prompt " * 100
|
||||
messages = [SystemMessage(content=large_prompt)]
|
||||
|
||||
cached, result = strategy.apply_caching(messages)
|
||||
|
||||
assert result.was_cached is False
|
||||
assert result.cached_content_type == "auto_cached_by_gemini"
|
||||
|
||||
def test_set_model_updates_provider(self):
|
||||
strategy = ContextCachingStrategy()
|
||||
|
||||
mock_anthropic = MagicMock()
|
||||
mock_anthropic.__class__.__name__ = "ChatAnthropic"
|
||||
mock_anthropic.__class__.__module__ = "langchain_anthropic"
|
||||
|
||||
strategy.set_model(mock_anthropic)
|
||||
assert strategy.provider == ProviderType.ANTHROPIC
|
||||
|
||||
mock_openai = MagicMock()
|
||||
mock_openai.__class__.__name__ = "ChatOpenAI"
|
||||
mock_openai.__class__.__module__ = "langchain_openai"
|
||||
mock_openai.openai_api_base = None
|
||||
|
||||
strategy.set_model(mock_openai)
|
||||
assert strategy.provider == ProviderType.OPENAI
|
||||
|
||||
|
||||
class TestCacheTelemetry:
|
||||
def test_default_values(self):
|
||||
telemetry = CacheTelemetry(provider=ProviderType.OPENAI)
|
||||
|
||||
assert telemetry.cache_read_tokens == 0
|
||||
assert telemetry.cache_write_tokens == 0
|
||||
assert telemetry.cache_hit_ratio == 0.0
|
||||
|
||||
def test_with_values(self):
|
||||
telemetry = CacheTelemetry(
|
||||
provider=ProviderType.ANTHROPIC,
|
||||
cache_read_tokens=1000,
|
||||
cache_write_tokens=500,
|
||||
total_input_tokens=2000,
|
||||
cache_hit_ratio=0.5,
|
||||
)
|
||||
|
||||
assert telemetry.cache_read_tokens == 1000
|
||||
assert telemetry.cache_write_tokens == 500
|
||||
assert telemetry.cache_hit_ratio == 0.5
|
||||
|
||||
|
||||
class TestCacheMetricsExtraction:
|
||||
def test_extract_anthropic_metrics(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage_metadata = {"input_tokens": 1000}
|
||||
mock_response.response_metadata = {
|
||||
"usage": {
|
||||
"input_tokens": 1000,
|
||||
"cache_read_input_tokens": 800,
|
||||
"cache_creation_input_tokens": 200,
|
||||
}
|
||||
}
|
||||
|
||||
telemetry = extract_anthropic_cache_metrics(mock_response)
|
||||
|
||||
assert telemetry.provider == ProviderType.ANTHROPIC
|
||||
assert telemetry.cache_read_tokens == 800
|
||||
assert telemetry.cache_write_tokens == 200
|
||||
assert telemetry.cache_hit_ratio == 0.8
|
||||
|
||||
def test_extract_openai_metrics(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage_metadata = {"input_tokens": 1000}
|
||||
mock_response.response_metadata = {
|
||||
"token_usage": {
|
||||
"prompt_tokens": 1000,
|
||||
"prompt_tokens_details": {"cached_tokens": 500},
|
||||
}
|
||||
}
|
||||
|
||||
telemetry = extract_openai_cache_metrics(mock_response)
|
||||
|
||||
assert telemetry.provider == ProviderType.OPENAI
|
||||
assert telemetry.cache_read_tokens == 500
|
||||
assert telemetry.cache_hit_ratio == 0.5
|
||||
|
||||
def test_extract_gemini_metrics(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage_metadata = {"input_tokens": 1000}
|
||||
mock_response.response_metadata = {
|
||||
"prompt_token_count": 1000,
|
||||
"cached_content_token_count": 750,
|
||||
}
|
||||
|
||||
telemetry = extract_gemini_cache_metrics(mock_response)
|
||||
|
||||
assert telemetry.provider == ProviderType.GEMINI
|
||||
assert telemetry.cache_read_tokens == 750
|
||||
assert telemetry.cache_hit_ratio == 0.75
|
||||
|
||||
def test_extract_cache_telemetry_unknown_provider(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage_metadata = {}
|
||||
mock_response.response_metadata = {}
|
||||
|
||||
telemetry = extract_cache_telemetry(mock_response, ProviderType.UNKNOWN)
|
||||
|
||||
assert telemetry.provider == ProviderType.UNKNOWN
|
||||
assert telemetry.cache_read_tokens == 0
|
||||
|
||||
|
||||
class TestPromptCachingTelemetryMiddleware:
|
||||
def test_initialization(self):
|
||||
middleware = PromptCachingTelemetryMiddleware()
|
||||
|
||||
assert middleware.telemetry_history == []
|
||||
|
||||
def test_get_aggregate_stats_empty(self):
|
||||
middleware = PromptCachingTelemetryMiddleware()
|
||||
stats = middleware.get_aggregate_stats()
|
||||
|
||||
assert stats["total_calls"] == 0
|
||||
assert stats["total_cache_read_tokens"] == 0
|
||||
|
||||
def test_get_aggregate_stats_with_data(self):
|
||||
middleware = PromptCachingTelemetryMiddleware()
|
||||
middleware._telemetry_history = [
|
||||
CacheTelemetry(
|
||||
provider=ProviderType.ANTHROPIC,
|
||||
cache_read_tokens=800,
|
||||
cache_write_tokens=200,
|
||||
total_input_tokens=1000,
|
||||
),
|
||||
CacheTelemetry(
|
||||
provider=ProviderType.ANTHROPIC,
|
||||
cache_read_tokens=900,
|
||||
cache_write_tokens=100,
|
||||
total_input_tokens=1000,
|
||||
),
|
||||
]
|
||||
|
||||
stats = middleware.get_aggregate_stats()
|
||||
|
||||
assert stats["total_calls"] == 2
|
||||
assert stats["total_cache_read_tokens"] == 1700
|
||||
assert stats["total_cache_write_tokens"] == 300
|
||||
assert stats["total_input_tokens"] == 2000
|
||||
assert stats["overall_cache_hit_ratio"] == 0.85
|
||||
|
||||
def test_wrap_model_call_collects_telemetry(self):
|
||||
middleware = PromptCachingTelemetryMiddleware()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.response_metadata = {
|
||||
"model": "claude-3-sonnet",
|
||||
"usage": {"cache_read_input_tokens": 500},
|
||||
}
|
||||
mock_response.usage_metadata = {"input_tokens": 1000}
|
||||
|
||||
def mock_handler(request):
|
||||
return mock_response
|
||||
|
||||
mock_request = MagicMock()
|
||||
result = middleware.wrap_model_call(mock_request, mock_handler)
|
||||
|
||||
assert result == mock_response
|
||||
assert len(middleware.telemetry_history) == 1
|
||||
|
||||
|
||||
class TestOpenRouterSubProvider:
|
||||
def test_detect_anthropic_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("anthropic/claude-3-sonnet")
|
||||
== OpenRouterSubProvider.ANTHROPIC
|
||||
)
|
||||
assert (
|
||||
detect_openrouter_sub_provider("anthropic/claude-3.5-sonnet")
|
||||
== OpenRouterSubProvider.ANTHROPIC
|
||||
)
|
||||
|
||||
def test_detect_openai_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("openai/gpt-4o")
|
||||
== OpenRouterSubProvider.OPENAI
|
||||
)
|
||||
assert (
|
||||
detect_openrouter_sub_provider("openai/o1-preview")
|
||||
== OpenRouterSubProvider.OPENAI
|
||||
)
|
||||
|
||||
def test_detect_gemini_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("google/gemini-2.5-pro")
|
||||
== OpenRouterSubProvider.GEMINI
|
||||
)
|
||||
assert (
|
||||
detect_openrouter_sub_provider("google/gemini-3-flash")
|
||||
== OpenRouterSubProvider.GEMINI
|
||||
)
|
||||
|
||||
def test_detect_deepseek_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("deepseek/deepseek-chat")
|
||||
== OpenRouterSubProvider.DEEPSEEK
|
||||
)
|
||||
|
||||
def test_detect_groq_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("groq/kimi-k2") == OpenRouterSubProvider.GROQ
|
||||
)
|
||||
|
||||
def test_detect_grok_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("xai/grok-2") == OpenRouterSubProvider.GROK
|
||||
)
|
||||
|
||||
def test_detect_llama_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("meta-llama/llama-3.3-70b")
|
||||
== OpenRouterSubProvider.META_LLAMA
|
||||
)
|
||||
|
||||
def test_detect_mistral_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("mistral/mistral-large")
|
||||
== OpenRouterSubProvider.MISTRAL
|
||||
)
|
||||
|
||||
def test_detect_unknown_via_openrouter(self):
|
||||
assert (
|
||||
detect_openrouter_sub_provider("some-provider/some-model")
|
||||
== OpenRouterSubProvider.UNKNOWN
|
||||
)
|
||||
|
||||
|
||||
class TestOpenRouterCachingStrategy:
|
||||
def test_openrouter_anthropic_applies_cache_markers(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatOpenAI"
|
||||
mock_model.__class__.__module__ = "langchain_openai"
|
||||
mock_model.openai_api_base = "https://openrouter.ai/api/v1"
|
||||
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_model,
|
||||
openrouter_model_name="anthropic/claude-3-sonnet",
|
||||
)
|
||||
|
||||
assert strategy.provider == ProviderType.OPENROUTER
|
||||
assert strategy.sub_provider == OpenRouterSubProvider.ANTHROPIC
|
||||
assert strategy.should_apply_cache_markers is True
|
||||
|
||||
def test_openrouter_openai_skips_cache_markers(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatOpenAI"
|
||||
mock_model.__class__.__module__ = "langchain_openai"
|
||||
mock_model.openai_api_base = "https://openrouter.ai/api/v1"
|
||||
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_model,
|
||||
openrouter_model_name="openai/gpt-4o",
|
||||
)
|
||||
|
||||
assert strategy.provider == ProviderType.OPENROUTER
|
||||
assert strategy.sub_provider == OpenRouterSubProvider.OPENAI
|
||||
assert strategy.should_apply_cache_markers is False
|
||||
|
||||
def test_openrouter_deepseek_skips_cache_markers(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatOpenAI"
|
||||
mock_model.__class__.__module__ = "langchain_openai"
|
||||
mock_model.openai_api_base = "https://openrouter.ai/api/v1"
|
||||
|
||||
strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
model=mock_model,
|
||||
openrouter_model_name="deepseek/deepseek-chat",
|
||||
)
|
||||
|
||||
assert strategy.provider == ProviderType.OPENROUTER
|
||||
assert strategy.sub_provider == OpenRouterSubProvider.DEEPSEEK
|
||||
assert strategy.should_apply_cache_markers is False
|
||||
|
||||
|
||||
class TestGemini3Detection:
|
||||
def test_detect_gemini_3_from_model_name(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatGoogleGenerativeAI"
|
||||
mock_model.__class__.__module__ = "langchain_google_genai"
|
||||
mock_model.model_name = "gemini-3-pro-preview"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.GEMINI_3
|
||||
|
||||
def test_detect_gemini_3_flash(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatGoogleGenerativeAI"
|
||||
mock_model.__class__.__module__ = "langchain_google_genai"
|
||||
mock_model.model_name = "gemini-3-flash-preview"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.GEMINI_3
|
||||
|
||||
def test_detect_gemini_25_not_gemini_3(self):
|
||||
mock_model = MagicMock()
|
||||
mock_model.__class__.__name__ = "ChatGoogleGenerativeAI"
|
||||
mock_model.__class__.__module__ = "langchain_google_genai"
|
||||
mock_model.model_name = "gemini-2.5-pro"
|
||||
|
||||
assert detect_provider(mock_model) == ProviderType.GEMINI
|
||||
|
||||
|
||||
class TestDeepSeekMetrics:
|
||||
def test_extract_deepseek_metrics(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.usage_metadata = {"input_tokens": 1000}
|
||||
mock_response.response_metadata = {
|
||||
"cache_hit_tokens": 700,
|
||||
"cache_miss_tokens": 300,
|
||||
}
|
||||
|
||||
telemetry = extract_deepseek_cache_metrics(mock_response)
|
||||
|
||||
assert telemetry.provider == ProviderType.DEEPSEEK
|
||||
assert telemetry.cache_read_tokens == 700
|
||||
assert telemetry.cache_write_tokens == 300
|
||||
assert telemetry.cache_hit_ratio == 0.7
|
||||
156
tests/context_engineering/test_integration.py
Normal file
156
tests/context_engineering/test_integration.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from context_engineering_research_agent import (
|
||||
ContextCachingStrategy,
|
||||
ContextIsolationStrategy,
|
||||
ContextOffloadingStrategy,
|
||||
ContextReductionStrategy,
|
||||
ContextRetrievalStrategy,
|
||||
create_context_aware_agent,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.caching import CachingConfig
|
||||
from context_engineering_research_agent.context_strategies.offloading import (
|
||||
OffloadingConfig,
|
||||
)
|
||||
from context_engineering_research_agent.context_strategies.reduction import (
|
||||
ReductionConfig,
|
||||
)
|
||||
|
||||
SKIP_OPENAI = not os.environ.get("OPENAI_API_KEY")
|
||||
|
||||
|
||||
class TestModuleExports:
|
||||
def test_exports_all_strategies(self):
|
||||
assert ContextOffloadingStrategy is not None
|
||||
assert ContextReductionStrategy is not None
|
||||
assert ContextRetrievalStrategy is not None
|
||||
assert ContextIsolationStrategy is not None
|
||||
assert ContextCachingStrategy is not None
|
||||
|
||||
def test_exports_create_context_aware_agent(self):
|
||||
assert create_context_aware_agent is not None
|
||||
assert callable(create_context_aware_agent)
|
||||
|
||||
|
||||
class TestStrategyInstantiation:
|
||||
def test_offloading_strategy_instantiation(self):
|
||||
strategy = ContextOffloadingStrategy()
|
||||
assert strategy.config.token_limit_before_evict == 20000
|
||||
|
||||
custom_strategy = ContextOffloadingStrategy(
|
||||
config=OffloadingConfig(token_limit_before_evict=10000)
|
||||
)
|
||||
assert custom_strategy.config.token_limit_before_evict == 10000
|
||||
|
||||
def test_reduction_strategy_instantiation(self):
|
||||
strategy = ContextReductionStrategy()
|
||||
assert strategy.config.context_threshold == 0.85
|
||||
|
||||
custom_strategy = ContextReductionStrategy(
|
||||
config=ReductionConfig(context_threshold=0.9)
|
||||
)
|
||||
assert custom_strategy.config.context_threshold == 0.9
|
||||
|
||||
def test_retrieval_strategy_instantiation(self):
|
||||
strategy = ContextRetrievalStrategy()
|
||||
assert len(strategy.tools) == 3
|
||||
|
||||
def test_isolation_strategy_instantiation(self):
|
||||
strategy = ContextIsolationStrategy()
|
||||
assert len(strategy.tools) == 1
|
||||
|
||||
def test_caching_strategy_instantiation(self):
|
||||
strategy = ContextCachingStrategy()
|
||||
assert strategy.config.min_cacheable_tokens == 1024
|
||||
|
||||
custom_strategy = ContextCachingStrategy(
|
||||
config=CachingConfig(min_cacheable_tokens=2048)
|
||||
)
|
||||
assert custom_strategy.config.min_cacheable_tokens == 2048
|
||||
|
||||
|
||||
@pytest.mark.skipif(SKIP_OPENAI, reason="OPENAI_API_KEY not set")
|
||||
class TestCreateContextAwareAgent:
|
||||
def test_create_agent_default_settings(self):
|
||||
agent = create_context_aware_agent(model_name="gpt-4.1")
|
||||
|
||||
assert agent is not None
|
||||
|
||||
def test_create_agent_all_disabled(self):
|
||||
agent = create_context_aware_agent(
|
||||
model_name="gpt-4.1",
|
||||
enable_offloading=False,
|
||||
enable_reduction=False,
|
||||
enable_caching=False,
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
def test_create_agent_all_enabled(self):
|
||||
agent = create_context_aware_agent(
|
||||
model_name="gpt-4.1",
|
||||
enable_offloading=True,
|
||||
enable_reduction=True,
|
||||
enable_caching=True,
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
def test_create_agent_custom_thresholds(self):
|
||||
agent = create_context_aware_agent(
|
||||
model_name="gpt-4.1",
|
||||
enable_offloading=True,
|
||||
enable_reduction=True,
|
||||
offloading_token_limit=10000,
|
||||
reduction_threshold=0.9,
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
def test_create_agent_offloading_only(self):
|
||||
agent = create_context_aware_agent(
|
||||
model_name="gpt-4.1",
|
||||
enable_offloading=True,
|
||||
enable_reduction=False,
|
||||
enable_caching=False,
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
def test_create_agent_reduction_only(self):
|
||||
agent = create_context_aware_agent(
|
||||
model_name="gpt-4.1",
|
||||
enable_offloading=False,
|
||||
enable_reduction=True,
|
||||
enable_caching=False,
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
|
||||
class TestStrategyCombinations:
|
||||
def test_offloading_with_reduction(self):
|
||||
offloading = ContextOffloadingStrategy(
|
||||
config=OffloadingConfig(token_limit_before_evict=15000)
|
||||
)
|
||||
reduction = ContextReductionStrategy(
|
||||
config=ReductionConfig(context_threshold=0.8)
|
||||
)
|
||||
|
||||
assert offloading.config.token_limit_before_evict == 15000
|
||||
assert reduction.config.context_threshold == 0.8
|
||||
|
||||
def test_all_strategies_together(self):
|
||||
offloading = ContextOffloadingStrategy()
|
||||
reduction = ContextReductionStrategy()
|
||||
retrieval = ContextRetrievalStrategy()
|
||||
isolation = ContextIsolationStrategy()
|
||||
caching = ContextCachingStrategy()
|
||||
|
||||
strategies = [offloading, reduction, retrieval, isolation, caching]
|
||||
|
||||
assert len(strategies) == 5
|
||||
for strategy in strategies:
|
||||
assert hasattr(strategy, "config")
|
||||
167
tests/context_engineering/test_isolation.py
Normal file
167
tests/context_engineering/test_isolation.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import pytest
|
||||
|
||||
from context_engineering_research_agent.context_strategies.isolation import (
|
||||
ContextIsolationStrategy,
|
||||
IsolationConfig,
|
||||
IsolationResult,
|
||||
)
|
||||
|
||||
|
||||
class TestIsolationConfig:
|
||||
def test_default_values(self):
|
||||
config = IsolationConfig()
|
||||
|
||||
assert config.default_model == "gpt-4.1"
|
||||
assert config.include_general_purpose_agent is True
|
||||
assert config.excluded_state_keys == (
|
||||
"messages",
|
||||
"todos",
|
||||
"structured_response",
|
||||
)
|
||||
|
||||
def test_custom_values(self):
|
||||
config = IsolationConfig(
|
||||
default_model="claude-3-5-sonnet",
|
||||
include_general_purpose_agent=False,
|
||||
excluded_state_keys=("messages", "custom_key"),
|
||||
)
|
||||
|
||||
assert config.default_model == "claude-3-5-sonnet"
|
||||
assert config.include_general_purpose_agent is False
|
||||
assert config.excluded_state_keys == ("messages", "custom_key")
|
||||
|
||||
|
||||
class TestIsolationResult:
|
||||
def test_successful_result(self):
|
||||
result = IsolationResult(
|
||||
subagent_name="researcher",
|
||||
was_successful=True,
|
||||
result_length=500,
|
||||
)
|
||||
|
||||
assert result.subagent_name == "researcher"
|
||||
assert result.was_successful is True
|
||||
assert result.result_length == 500
|
||||
assert result.error is None
|
||||
|
||||
def test_failed_result(self):
|
||||
result = IsolationResult(
|
||||
subagent_name="researcher",
|
||||
was_successful=False,
|
||||
result_length=0,
|
||||
error="SubAgent not found",
|
||||
)
|
||||
|
||||
assert result.was_successful is False
|
||||
assert result.error == "SubAgent not found"
|
||||
|
||||
|
||||
class TestContextIsolationStrategy:
|
||||
@pytest.fixture
|
||||
def strategy(self) -> ContextIsolationStrategy:
|
||||
return ContextIsolationStrategy()
|
||||
|
||||
@pytest.fixture
|
||||
def strategy_with_subagents(self) -> ContextIsolationStrategy:
|
||||
subagents = [
|
||||
{
|
||||
"name": "researcher",
|
||||
"description": "Research agent",
|
||||
"system_prompt": "You are a researcher",
|
||||
"tools": [],
|
||||
},
|
||||
{
|
||||
"name": "coder",
|
||||
"description": "Coding agent",
|
||||
"system_prompt": "You are a coder",
|
||||
"tools": [],
|
||||
},
|
||||
]
|
||||
return ContextIsolationStrategy(subagents=subagents) # type: ignore
|
||||
|
||||
def test_initialization(self, strategy: ContextIsolationStrategy):
|
||||
assert strategy.config is not None
|
||||
assert len(strategy.tools) == 1
|
||||
|
||||
def test_creates_task_tool(self, strategy: ContextIsolationStrategy):
|
||||
assert strategy.tools[0].name == "task"
|
||||
|
||||
def test_task_tool_description(
|
||||
self, strategy_with_subagents: ContextIsolationStrategy
|
||||
):
|
||||
task_tool = strategy_with_subagents.tools[0]
|
||||
|
||||
assert "researcher" in task_tool.description
|
||||
assert "coder" in task_tool.description
|
||||
assert "SubAgent" in task_tool.description
|
||||
|
||||
def test_get_subagent_descriptions(
|
||||
self, strategy_with_subagents: ContextIsolationStrategy
|
||||
):
|
||||
descriptions = strategy_with_subagents._get_subagent_descriptions()
|
||||
|
||||
assert "researcher" in descriptions
|
||||
assert "Research agent" in descriptions
|
||||
assert "coder" in descriptions
|
||||
assert "Coding agent" in descriptions
|
||||
|
||||
def test_prepare_subagent_state(self, strategy: ContextIsolationStrategy):
|
||||
state = {
|
||||
"messages": [{"role": "user", "content": "old message"}],
|
||||
"todos": ["task1", "task2"],
|
||||
"structured_response": {"key": "value"},
|
||||
"files": {"path": "/test"},
|
||||
}
|
||||
|
||||
prepared = strategy._prepare_subagent_state(state, "New task description")
|
||||
|
||||
assert "todos" not in prepared
|
||||
assert "structured_response" not in prepared
|
||||
assert "files" in prepared
|
||||
assert len(prepared["messages"]) == 1
|
||||
assert prepared["messages"][0].content == "New task description"
|
||||
|
||||
def test_prepare_subagent_state_custom_excluded_keys(self):
|
||||
strategy = ContextIsolationStrategy(
|
||||
config=IsolationConfig(excluded_state_keys=("messages", "custom_exclude"))
|
||||
)
|
||||
state = {
|
||||
"messages": [{"role": "user", "content": "old"}],
|
||||
"custom_exclude": "should be excluded",
|
||||
"keep_this": "should be kept",
|
||||
}
|
||||
|
||||
prepared = strategy._prepare_subagent_state(state, "New task")
|
||||
|
||||
assert "custom_exclude" not in prepared
|
||||
assert "keep_this" in prepared
|
||||
|
||||
def test_compile_subagents_empty(self, strategy: ContextIsolationStrategy):
|
||||
agents = strategy._compile_subagents()
|
||||
|
||||
assert agents == {}
|
||||
|
||||
def test_compile_subagents_caches_result(
|
||||
self, strategy_with_subagents: ContextIsolationStrategy
|
||||
):
|
||||
strategy_with_subagents._compiled_agents = {"cached": "value"} # type: ignore
|
||||
|
||||
agents = strategy_with_subagents._compile_subagents()
|
||||
|
||||
assert agents == {"cached": "value"}
|
||||
|
||||
def test_task_without_subagents(self, strategy: ContextIsolationStrategy):
|
||||
task_tool = strategy.tools[0]
|
||||
|
||||
class MockRuntime:
|
||||
state = {}
|
||||
config = {}
|
||||
tool_call_id = "call_123"
|
||||
|
||||
result = task_tool.func( # type: ignore
|
||||
description="Test task",
|
||||
subagent_type="researcher",
|
||||
runtime=MockRuntime(),
|
||||
)
|
||||
|
||||
assert "존재하지 않습니다" in str(result)
|
||||
184
tests/context_engineering/test_offloading.py
Normal file
184
tests/context_engineering/test_offloading.py
Normal file
@@ -0,0 +1,184 @@
|
||||
import pytest
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
from context_engineering_research_agent.context_strategies.offloading import (
|
||||
ContextOffloadingStrategy,
|
||||
OffloadingConfig,
|
||||
OffloadingResult,
|
||||
)
|
||||
|
||||
|
||||
class TestOffloadingConfig:
|
||||
def test_default_values(self):
|
||||
config = OffloadingConfig()
|
||||
|
||||
assert config.token_limit_before_evict == 20000
|
||||
assert config.eviction_path_prefix == "/large_tool_results"
|
||||
assert config.preview_lines == 10
|
||||
assert config.chars_per_token == 4
|
||||
|
||||
def test_custom_values(self):
|
||||
config = OffloadingConfig(
|
||||
token_limit_before_evict=15000,
|
||||
eviction_path_prefix="/custom_path",
|
||||
preview_lines=5,
|
||||
chars_per_token=3,
|
||||
)
|
||||
|
||||
assert config.token_limit_before_evict == 15000
|
||||
assert config.eviction_path_prefix == "/custom_path"
|
||||
assert config.preview_lines == 5
|
||||
assert config.chars_per_token == 3
|
||||
|
||||
|
||||
class TestOffloadingResult:
|
||||
def test_not_offloaded(self):
|
||||
result = OffloadingResult(was_offloaded=False, original_size=100)
|
||||
|
||||
assert result.was_offloaded is False
|
||||
assert result.original_size == 100
|
||||
assert result.file_path is None
|
||||
assert result.preview is None
|
||||
|
||||
def test_offloaded(self):
|
||||
result = OffloadingResult(
|
||||
was_offloaded=True,
|
||||
original_size=100000,
|
||||
file_path="/large_tool_results/call_123",
|
||||
preview="first 10 lines...",
|
||||
)
|
||||
|
||||
assert result.was_offloaded is True
|
||||
assert result.original_size == 100000
|
||||
assert result.file_path == "/large_tool_results/call_123"
|
||||
assert result.preview == "first 10 lines..."
|
||||
|
||||
|
||||
class TestContextOffloadingStrategy:
|
||||
@pytest.fixture
|
||||
def strategy(self) -> ContextOffloadingStrategy:
|
||||
return ContextOffloadingStrategy()
|
||||
|
||||
@pytest.fixture
|
||||
def strategy_low_threshold(self) -> ContextOffloadingStrategy:
|
||||
return ContextOffloadingStrategy(
|
||||
config=OffloadingConfig(token_limit_before_evict=100)
|
||||
)
|
||||
|
||||
def test_estimate_tokens(self, strategy: ContextOffloadingStrategy):
|
||||
content = "a" * 400
|
||||
estimated = strategy._estimate_tokens(content)
|
||||
|
||||
assert estimated == 100
|
||||
|
||||
def test_estimate_tokens_with_custom_chars_per_token(self):
|
||||
strategy = ContextOffloadingStrategy(config=OffloadingConfig(chars_per_token=2))
|
||||
content = "a" * 400
|
||||
estimated = strategy._estimate_tokens(content)
|
||||
|
||||
assert estimated == 200
|
||||
|
||||
def test_should_offload_small_content(self, strategy: ContextOffloadingStrategy):
|
||||
small_content = "short text" * 100
|
||||
|
||||
assert strategy._should_offload(small_content) is False
|
||||
|
||||
def test_should_offload_large_content(self, strategy: ContextOffloadingStrategy):
|
||||
large_content = "x" * 100000
|
||||
|
||||
assert strategy._should_offload(large_content) is True
|
||||
|
||||
def test_should_offload_boundary(self):
|
||||
config = OffloadingConfig(token_limit_before_evict=100, chars_per_token=4)
|
||||
strategy = ContextOffloadingStrategy(config=config)
|
||||
|
||||
exactly_at_limit = "x" * 400
|
||||
just_over_limit = "x" * 404
|
||||
|
||||
assert strategy._should_offload(exactly_at_limit) is False
|
||||
assert strategy._should_offload(just_over_limit) is True
|
||||
|
||||
def test_create_preview_short_content(self, strategy: ContextOffloadingStrategy):
|
||||
content = "line1\nline2\nline3"
|
||||
preview = strategy._create_preview(content)
|
||||
|
||||
assert "line1" in preview
|
||||
assert "line2" in preview
|
||||
assert "line3" in preview
|
||||
|
||||
def test_create_preview_long_content(self, strategy: ContextOffloadingStrategy):
|
||||
lines = [f"line_{i}" for i in range(100)]
|
||||
content = "\n".join(lines)
|
||||
preview = strategy._create_preview(content)
|
||||
|
||||
assert "line_0" in preview
|
||||
assert "line_9" in preview
|
||||
assert "line_10" not in preview
|
||||
|
||||
def test_create_preview_with_custom_lines(self):
|
||||
strategy = ContextOffloadingStrategy(config=OffloadingConfig(preview_lines=3))
|
||||
lines = [f"line_{i}" for i in range(10)]
|
||||
content = "\n".join(lines)
|
||||
preview = strategy._create_preview(content)
|
||||
|
||||
assert "line_0" in preview
|
||||
assert "line_2" in preview
|
||||
assert "line_3" not in preview
|
||||
|
||||
def test_create_preview_truncates_long_lines(
|
||||
self, strategy: ContextOffloadingStrategy
|
||||
):
|
||||
long_line = "x" * 2000
|
||||
preview = strategy._create_preview(long_line)
|
||||
|
||||
assert len(preview.split("\t")[1]) == 1000
|
||||
|
||||
def test_create_offload_message(self, strategy: ContextOffloadingStrategy):
|
||||
message = strategy._create_offload_message(
|
||||
tool_call_id="call_123",
|
||||
file_path="/large_tool_results/call_123",
|
||||
preview="preview content",
|
||||
)
|
||||
|
||||
assert "/large_tool_results/call_123" in message
|
||||
assert "preview content" in message
|
||||
assert "read_file" in message
|
||||
|
||||
def test_sanitize_tool_call_id(self, strategy: ContextOffloadingStrategy):
|
||||
normal_id = "call_abc123"
|
||||
special_id = "call/with:special@chars!"
|
||||
|
||||
assert strategy._sanitize_tool_call_id(normal_id) == "call_abc123"
|
||||
assert strategy._sanitize_tool_call_id(special_id) == "call_with_special_chars_"
|
||||
|
||||
def test_process_tool_result_small_content(
|
||||
self, strategy: ContextOffloadingStrategy
|
||||
):
|
||||
tool_result = ToolMessage(content="small content", tool_call_id="call_123")
|
||||
|
||||
class MockRuntime:
|
||||
state = {}
|
||||
config = {}
|
||||
|
||||
processed, result = strategy.process_tool_result(tool_result, MockRuntime()) # type: ignore
|
||||
|
||||
assert result.was_offloaded is False
|
||||
assert processed.content == "small content"
|
||||
|
||||
def test_process_tool_result_no_backend(
|
||||
self, strategy_low_threshold: ContextOffloadingStrategy
|
||||
):
|
||||
large_content = "x" * 1000
|
||||
tool_result = ToolMessage(content=large_content, tool_call_id="call_123")
|
||||
|
||||
class MockRuntime:
|
||||
state = {}
|
||||
config = {}
|
||||
|
||||
processed, result = strategy_low_threshold.process_tool_result(
|
||||
tool_result,
|
||||
MockRuntime(), # type: ignore
|
||||
)
|
||||
|
||||
assert result.was_offloaded is False
|
||||
assert processed.content == large_content
|
||||
187
tests/context_engineering/test_openrouter_models.py
Normal file
187
tests/context_engineering/test_openrouter_models.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""OpenRouter 15개 모델 통합 테스트.
|
||||
|
||||
실제 OpenRouter API를 통해 다양한 모델의 캐싱 전략과 provider 감지를 검증합니다.
|
||||
OPENROUTER_API_KEY 환경 변수가 필요합니다.
|
||||
|
||||
테스트 대상 모델 (Anthropic/OpenAI/Google 제외):
|
||||
- deepseek/deepseek-v3.2
|
||||
- x-ai/grok-*
|
||||
- xiaomi/mimo-v2-flash
|
||||
- minimax/minimax-m2.1
|
||||
- bytedance-seed/seed-1.6
|
||||
- z-ai/glm-4.7
|
||||
- allenai/olmo-3.1-32b-instruct
|
||||
- mistralai/mistral-small-creative
|
||||
- nvidia/nemotron-3-nano-30b-a3b
|
||||
- qwen/qwen3-max, qwen3-coder-plus, qwen3-coder-flash, qwen3-vl-32b-instruct, qwen3-next-80b-a3b-thinking
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from context_engineering_research_agent.context_strategies.caching import (
|
||||
CachingConfig,
|
||||
ContextCachingStrategy,
|
||||
OpenRouterSubProvider,
|
||||
ProviderType,
|
||||
detect_openrouter_sub_provider,
|
||||
detect_provider,
|
||||
)
|
||||
|
||||
|
||||
def _openrouter_available() -> bool:
|
||||
"""OpenRouter API 키가 설정되어 있는지 확인합니다."""
|
||||
return bool(os.environ.get("OPENROUTER_API_KEY"))
|
||||
|
||||
|
||||
OPENROUTER_AVAILABLE = _openrouter_available()
|
||||
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
pytestmark = [
|
||||
pytest.mark.integration,
|
||||
pytest.mark.skipif(
|
||||
not OPENROUTER_AVAILABLE,
|
||||
reason="OPENROUTER_API_KEY 환경 변수가 설정되지 않았습니다.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
OPENROUTER_MODELS = [
|
||||
("deepseek/deepseek-chat-v3-0324", OpenRouterSubProvider.DEEPSEEK),
|
||||
("x-ai/grok-3-mini-beta", OpenRouterSubProvider.GROK),
|
||||
("qwen/qwen3-235b-a22b", OpenRouterSubProvider.UNKNOWN),
|
||||
("qwen/qwen3-32b", OpenRouterSubProvider.UNKNOWN),
|
||||
("mistralai/mistral-small-3.1-24b-instruct", OpenRouterSubProvider.MISTRAL),
|
||||
("meta-llama/llama-4-maverick", OpenRouterSubProvider.META_LLAMA),
|
||||
("meta-llama/llama-4-scout", OpenRouterSubProvider.META_LLAMA),
|
||||
("nvidia/llama-3.1-nemotron-70b-instruct", OpenRouterSubProvider.UNKNOWN),
|
||||
("microsoft/phi-4", OpenRouterSubProvider.UNKNOWN),
|
||||
("google/gemma-3-27b-it", OpenRouterSubProvider.UNKNOWN),
|
||||
("cohere/command-a", OpenRouterSubProvider.UNKNOWN),
|
||||
("perplexity/sonar-pro", OpenRouterSubProvider.UNKNOWN),
|
||||
("ai21/jamba-1.6-large", OpenRouterSubProvider.UNKNOWN),
|
||||
("inflection/inflection-3-pi", OpenRouterSubProvider.UNKNOWN),
|
||||
("amazon/nova-pro-v1", OpenRouterSubProvider.UNKNOWN),
|
||||
]
|
||||
|
||||
|
||||
def _create_openrouter_model(model_name: str) -> ChatOpenAI:
|
||||
"""OpenRouter 모델 인스턴스를 생성합니다."""
|
||||
return ChatOpenAI(
|
||||
model=model_name,
|
||||
openai_api_key=os.environ.get("OPENROUTER_API_KEY"),
|
||||
openai_api_base=OPENROUTER_BASE_URL,
|
||||
temperature=0.0,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
|
||||
class TestOpenRouterProviderDetection:
|
||||
"""OpenRouter Provider 감지 테스트."""
|
||||
|
||||
@pytest.mark.parametrize("model_name,expected_sub", OPENROUTER_MODELS)
|
||||
def test_detect_provider_openrouter(
|
||||
self, model_name: str, expected_sub: OpenRouterSubProvider
|
||||
) -> None:
|
||||
"""OpenRouter 모델이 ProviderType.OPENROUTER로 감지되는지 확인합니다."""
|
||||
model = _create_openrouter_model(model_name)
|
||||
provider = detect_provider(model)
|
||||
assert provider == ProviderType.OPENROUTER
|
||||
|
||||
@pytest.mark.parametrize("model_name,expected_sub", OPENROUTER_MODELS)
|
||||
def test_detect_openrouter_sub_provider(
|
||||
self, model_name: str, expected_sub: OpenRouterSubProvider
|
||||
) -> None:
|
||||
"""OpenRouter 모델명에서 sub-provider를 올바르게 감지하는지 확인합니다."""
|
||||
sub_provider = detect_openrouter_sub_provider(model_name)
|
||||
assert sub_provider == expected_sub
|
||||
|
||||
|
||||
class TestOpenRouterCachingStrategy:
|
||||
"""OpenRouter 캐싱 전략 테스트."""
|
||||
|
||||
@pytest.fixture
|
||||
def low_threshold_config(self) -> CachingConfig:
|
||||
return CachingConfig(min_cacheable_tokens=10)
|
||||
|
||||
@pytest.mark.parametrize("model_name,expected_sub", OPENROUTER_MODELS[:5])
|
||||
def test_caching_strategy_initialization(
|
||||
self,
|
||||
model_name: str,
|
||||
expected_sub: OpenRouterSubProvider,
|
||||
) -> None:
|
||||
"""ContextCachingStrategy가 OpenRouter 모델로 올바르게 초기화되는지 확인합니다."""
|
||||
model = _create_openrouter_model(model_name)
|
||||
strategy = ContextCachingStrategy(model=model, openrouter_model_name=model_name)
|
||||
|
||||
assert strategy._provider == ProviderType.OPENROUTER
|
||||
assert strategy._openrouter_sub_provider == expected_sub
|
||||
|
||||
|
||||
class TestOpenRouterModelInvocation:
|
||||
"""OpenRouter 모델 실제 호출 테스트.
|
||||
|
||||
실제 API 비용이 발생하므로 소수의 모델만 테스트합니다.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"deepseek/deepseek-chat-v3-0324",
|
||||
"qwen/qwen3-32b",
|
||||
"mistralai/mistral-small-3.1-24b-instruct",
|
||||
],
|
||||
)
|
||||
def test_simple_invocation(self, model_name: str) -> None:
|
||||
"""모델이 간단한 프롬프트에 응답하는지 확인합니다."""
|
||||
model = _create_openrouter_model(model_name)
|
||||
response = model.invoke("Say 'hello' in one word.")
|
||||
assert response.content
|
||||
assert len(response.content) > 0
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
[
|
||||
"deepseek/deepseek-chat-v3-0324",
|
||||
"x-ai/grok-3-mini-beta",
|
||||
],
|
||||
)
|
||||
def test_caching_strategy_apply_does_not_error(self, model_name: str) -> None:
|
||||
"""ContextCachingStrategy.apply()가 에러 없이 동작하는지 확인합니다."""
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
model = _create_openrouter_model(model_name)
|
||||
strategy = ContextCachingStrategy(
|
||||
model=model,
|
||||
openrouter_model_name=model_name,
|
||||
config=CachingConfig(min_cacheable_tokens=10),
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant. " * 100),
|
||||
HumanMessage(content="Hello"),
|
||||
]
|
||||
|
||||
result = strategy.apply(messages)
|
||||
assert result.was_cached is not None
|
||||
|
||||
|
||||
class TestOpenRouterModelNameExtraction:
|
||||
"""OpenRouter 모델명 추출 테스트."""
|
||||
|
||||
def test_model_name_extraction_from_model_attribute(self) -> None:
|
||||
"""model 속성에서 모델명이 올바르게 추출되는지 확인합니다."""
|
||||
model = _create_openrouter_model("deepseek/deepseek-chat-v3-0324")
|
||||
name = getattr(model, "model_name", None) or getattr(model, "model", None)
|
||||
assert name == "deepseek/deepseek-chat-v3-0324"
|
||||
|
||||
def test_openrouter_base_url_detection(self) -> None:
|
||||
"""OpenRouter base URL이 올바르게 감지되는지 확인합니다."""
|
||||
model = _create_openrouter_model("qwen/qwen3-32b")
|
||||
base_url = getattr(model, "openai_api_base", None)
|
||||
assert base_url is not None
|
||||
assert "openrouter" in base_url.lower()
|
||||
255
tests/context_engineering/test_reduction.py
Normal file
255
tests/context_engineering/test_reduction.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import pytest
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from context_engineering_research_agent.context_strategies.reduction import (
|
||||
ContextReductionStrategy,
|
||||
ReductionConfig,
|
||||
ReductionResult,
|
||||
)
|
||||
|
||||
|
||||
class TestReductionConfig:
|
||||
def test_default_values(self):
|
||||
config = ReductionConfig()
|
||||
|
||||
assert config.context_threshold == 0.85
|
||||
assert config.model_context_window == 200000
|
||||
assert config.compaction_age_threshold == 10
|
||||
assert config.min_messages_to_keep == 5
|
||||
assert config.chars_per_token == 4
|
||||
|
||||
def test_custom_values(self):
|
||||
config = ReductionConfig(
|
||||
context_threshold=0.90,
|
||||
model_context_window=100000,
|
||||
compaction_age_threshold=5,
|
||||
min_messages_to_keep=3,
|
||||
)
|
||||
|
||||
assert config.context_threshold == 0.90
|
||||
assert config.model_context_window == 100000
|
||||
assert config.compaction_age_threshold == 5
|
||||
assert config.min_messages_to_keep == 3
|
||||
|
||||
|
||||
class TestReductionResult:
|
||||
def test_not_reduced(self):
|
||||
result = ReductionResult(was_reduced=False)
|
||||
|
||||
assert result.was_reduced is False
|
||||
assert result.technique_used is None
|
||||
assert result.original_message_count == 0
|
||||
assert result.reduced_message_count == 0
|
||||
assert result.estimated_tokens_saved == 0
|
||||
|
||||
def test_reduced_with_compaction(self):
|
||||
result = ReductionResult(
|
||||
was_reduced=True,
|
||||
technique_used="compaction",
|
||||
original_message_count=50,
|
||||
reduced_message_count=30,
|
||||
estimated_tokens_saved=5000,
|
||||
)
|
||||
|
||||
assert result.was_reduced is True
|
||||
assert result.technique_used == "compaction"
|
||||
assert result.original_message_count == 50
|
||||
assert result.reduced_message_count == 30
|
||||
assert result.estimated_tokens_saved == 5000
|
||||
|
||||
|
||||
class TestContextReductionStrategy:
|
||||
@pytest.fixture
|
||||
def strategy(self) -> ContextReductionStrategy:
|
||||
return ContextReductionStrategy()
|
||||
|
||||
@pytest.fixture
|
||||
def strategy_low_threshold(self) -> ContextReductionStrategy:
|
||||
return ContextReductionStrategy(
|
||||
config=ReductionConfig(
|
||||
context_threshold=0.1,
|
||||
model_context_window=1000,
|
||||
)
|
||||
)
|
||||
|
||||
def test_estimate_tokens(self, strategy: ContextReductionStrategy):
|
||||
messages = [
|
||||
HumanMessage(content="a" * 400),
|
||||
AIMessage(content="b" * 400),
|
||||
]
|
||||
estimated = strategy._estimate_tokens(messages)
|
||||
|
||||
assert estimated == 200
|
||||
|
||||
def test_get_context_usage_ratio(self, strategy: ContextReductionStrategy):
|
||||
messages = [
|
||||
HumanMessage(content="x" * 40000),
|
||||
]
|
||||
ratio = strategy._get_context_usage_ratio(messages)
|
||||
|
||||
assert ratio == pytest.approx(0.05, rel=0.01)
|
||||
|
||||
def test_should_reduce_below_threshold(self, strategy: ContextReductionStrategy):
|
||||
messages = [HumanMessage(content="short message")]
|
||||
|
||||
assert strategy._should_reduce(messages) is False
|
||||
|
||||
def test_should_reduce_above_threshold(
|
||||
self, strategy_low_threshold: ContextReductionStrategy
|
||||
):
|
||||
messages = [HumanMessage(content="x" * 1000)]
|
||||
|
||||
assert strategy_low_threshold._should_reduce(messages) is True
|
||||
|
||||
def test_apply_compaction_removes_old_tool_calls(
|
||||
self, strategy: ContextReductionStrategy
|
||||
):
|
||||
messages = []
|
||||
for i in range(20):
|
||||
messages.append(HumanMessage(content=f"question {i}"))
|
||||
ai_msg = AIMessage(
|
||||
content=f"answer {i}",
|
||||
tool_calls=(
|
||||
[{"id": f"call_{i}", "name": "search", "args": {"q": "test"}}]
|
||||
if i < 15
|
||||
else []
|
||||
),
|
||||
)
|
||||
messages.append(ai_msg)
|
||||
if i < 15:
|
||||
messages.append(
|
||||
ToolMessage(content=f"result {i}", tool_call_id=f"call_{i}")
|
||||
)
|
||||
|
||||
compacted, result = strategy.apply_compaction(messages)
|
||||
|
||||
assert result.was_reduced is True
|
||||
assert result.technique_used == "compaction"
|
||||
assert len(compacted) < len(messages)
|
||||
|
||||
def test_apply_compaction_keeps_recent_messages(
|
||||
self, strategy: ContextReductionStrategy
|
||||
):
|
||||
messages = []
|
||||
for i in range(5):
|
||||
messages.append(HumanMessage(content=f"recent question {i}"))
|
||||
messages.append(AIMessage(content=f"recent answer {i}"))
|
||||
|
||||
compacted, result = strategy.apply_compaction(messages)
|
||||
|
||||
assert len(compacted) == len(messages)
|
||||
|
||||
def test_apply_compaction_preserves_text_content(self):
|
||||
strategy = ContextReductionStrategy(
|
||||
config=ReductionConfig(compaction_age_threshold=2)
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="old question"),
|
||||
AIMessage(
|
||||
content="old answer with important info",
|
||||
tool_calls=[{"id": "call_old", "name": "search", "args": {}}],
|
||||
),
|
||||
ToolMessage(content="old result", tool_call_id="call_old"),
|
||||
HumanMessage(content="recent question"),
|
||||
AIMessage(content="recent answer"),
|
||||
]
|
||||
|
||||
compacted, _ = strategy.apply_compaction(messages)
|
||||
|
||||
text_contents = [str(m.content) for m in compacted]
|
||||
has_important_info = any("important info" in c for c in text_contents)
|
||||
|
||||
assert has_important_info is True
|
||||
|
||||
def test_apply_compaction_removes_tool_messages(self):
|
||||
strategy = ContextReductionStrategy(
|
||||
config=ReductionConfig(compaction_age_threshold=2)
|
||||
)
|
||||
messages = [
|
||||
HumanMessage(content="old question"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{"id": "call_old", "name": "search", "args": {}}],
|
||||
),
|
||||
ToolMessage(content="old tool result", tool_call_id="call_old"),
|
||||
HumanMessage(content="recent question"),
|
||||
AIMessage(content="recent answer"),
|
||||
]
|
||||
|
||||
compacted, _ = strategy.apply_compaction(messages)
|
||||
|
||||
tool_message_count = sum(1 for m in compacted if isinstance(m, ToolMessage))
|
||||
|
||||
assert tool_message_count == 0
|
||||
|
||||
def test_reduce_context_no_reduction_needed(
|
||||
self, strategy: ContextReductionStrategy
|
||||
):
|
||||
messages = [HumanMessage(content="short")]
|
||||
|
||||
reduced, result = strategy.reduce_context(messages)
|
||||
|
||||
assert result.was_reduced is False
|
||||
assert reduced == messages
|
||||
|
||||
def test_reduce_context_with_very_large_messages(self):
|
||||
strategy = ContextReductionStrategy(
|
||||
config=ReductionConfig(
|
||||
context_threshold=0.01,
|
||||
model_context_window=100,
|
||||
compaction_age_threshold=5,
|
||||
)
|
||||
)
|
||||
messages = []
|
||||
for i in range(20):
|
||||
messages.append(HumanMessage(content=f"question {i} " * 100))
|
||||
messages.append(
|
||||
AIMessage(
|
||||
content=f"answer {i} " * 100,
|
||||
tool_calls=[{"id": f"c{i}", "name": "s", "args": {}}],
|
||||
)
|
||||
)
|
||||
messages.append(ToolMessage(content="result " * 100, tool_call_id=f"c{i}"))
|
||||
|
||||
compacted, result = strategy.apply_compaction(messages)
|
||||
|
||||
assert result.was_reduced is True
|
||||
assert len(compacted) < len(messages)
|
||||
|
||||
def test_create_summary_prompt(self, strategy: ContextReductionStrategy):
|
||||
messages = [
|
||||
HumanMessage(content="What is Python?"),
|
||||
AIMessage(content="Python is a programming language."),
|
||||
]
|
||||
|
||||
prompt = strategy._create_summary_prompt(messages)
|
||||
|
||||
assert "Python" in prompt
|
||||
assert "Human" in prompt
|
||||
assert "AI" in prompt
|
||||
|
||||
def test_apply_summarization_no_model(self, strategy: ContextReductionStrategy):
|
||||
messages = [HumanMessage(content="test")]
|
||||
|
||||
summarized, result = strategy.apply_summarization(messages)
|
||||
|
||||
assert result.was_reduced is False
|
||||
assert summarized == messages
|
||||
|
||||
def test_compaction_preserves_system_messages(self):
|
||||
strategy = ContextReductionStrategy(
|
||||
config=ReductionConfig(compaction_age_threshold=2)
|
||||
)
|
||||
messages = [
|
||||
SystemMessage(content="You are a helpful assistant"),
|
||||
HumanMessage(content="old question"),
|
||||
AIMessage(content="old answer"),
|
||||
HumanMessage(content="recent question"),
|
||||
AIMessage(content="recent answer"),
|
||||
]
|
||||
|
||||
compacted, _ = strategy.apply_compaction(messages)
|
||||
|
||||
system_messages = [m for m in compacted if isinstance(m, SystemMessage)]
|
||||
assert len(system_messages) == 1
|
||||
assert "helpful assistant" in str(system_messages[0].content)
|
||||
163
tests/context_engineering/test_retrieval.py
Normal file
163
tests/context_engineering/test_retrieval.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import pytest
|
||||
|
||||
from context_engineering_research_agent.context_strategies.retrieval import (
|
||||
ContextRetrievalStrategy,
|
||||
RetrievalConfig,
|
||||
RetrievalResult,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrievalConfig:
|
||||
def test_default_values(self):
|
||||
config = RetrievalConfig()
|
||||
|
||||
assert config.default_read_limit == 500
|
||||
assert config.max_grep_results == 100
|
||||
assert config.max_glob_results == 100
|
||||
assert config.truncate_line_length == 2000
|
||||
|
||||
def test_custom_values(self):
|
||||
config = RetrievalConfig(
|
||||
default_read_limit=1000,
|
||||
max_grep_results=50,
|
||||
max_glob_results=200,
|
||||
truncate_line_length=3000,
|
||||
)
|
||||
|
||||
assert config.default_read_limit == 1000
|
||||
assert config.max_grep_results == 50
|
||||
assert config.max_glob_results == 200
|
||||
assert config.truncate_line_length == 3000
|
||||
|
||||
|
||||
class TestRetrievalResult:
|
||||
def test_basic_result(self):
|
||||
result = RetrievalResult(
|
||||
tool_used="grep",
|
||||
query="TODO",
|
||||
result_count=10,
|
||||
)
|
||||
|
||||
assert result.tool_used == "grep"
|
||||
assert result.query == "TODO"
|
||||
assert result.result_count == 10
|
||||
assert result.was_truncated is False
|
||||
|
||||
def test_truncated_result(self):
|
||||
result = RetrievalResult(
|
||||
tool_used="glob",
|
||||
query="**/*.py",
|
||||
result_count=100,
|
||||
was_truncated=True,
|
||||
)
|
||||
|
||||
assert result.was_truncated is True
|
||||
|
||||
|
||||
class TestContextRetrievalStrategy:
|
||||
@pytest.fixture
|
||||
def strategy(self) -> ContextRetrievalStrategy:
|
||||
return ContextRetrievalStrategy()
|
||||
|
||||
@pytest.fixture
|
||||
def strategy_with_custom_config(self) -> ContextRetrievalStrategy:
|
||||
return ContextRetrievalStrategy(
|
||||
config=RetrievalConfig(
|
||||
default_read_limit=100,
|
||||
max_grep_results=10,
|
||||
max_glob_results=10,
|
||||
)
|
||||
)
|
||||
|
||||
def test_initialization(self, strategy: ContextRetrievalStrategy):
|
||||
assert strategy.config is not None
|
||||
assert len(strategy.tools) == 3
|
||||
|
||||
def test_creates_read_file_tool(self, strategy: ContextRetrievalStrategy):
|
||||
tool_names = [t.name for t in strategy.tools]
|
||||
|
||||
assert "read_file" in tool_names
|
||||
|
||||
def test_creates_grep_tool(self, strategy: ContextRetrievalStrategy):
|
||||
tool_names = [t.name for t in strategy.tools]
|
||||
|
||||
assert "grep" in tool_names
|
||||
|
||||
def test_creates_glob_tool(self, strategy: ContextRetrievalStrategy):
|
||||
tool_names = [t.name for t in strategy.tools]
|
||||
|
||||
assert "glob" in tool_names
|
||||
|
||||
def test_read_file_tool_description(self, strategy: ContextRetrievalStrategy):
|
||||
read_file_tool = next(t for t in strategy.tools if t.name == "read_file")
|
||||
|
||||
assert "500" in read_file_tool.description
|
||||
assert "offset" in read_file_tool.description.lower()
|
||||
assert "limit" in read_file_tool.description.lower()
|
||||
|
||||
def test_grep_tool_description(self, strategy: ContextRetrievalStrategy):
|
||||
grep_tool = next(t for t in strategy.tools if t.name == "grep")
|
||||
|
||||
assert "100" in grep_tool.description
|
||||
assert "pattern" in grep_tool.description.lower()
|
||||
|
||||
def test_glob_tool_description(self, strategy: ContextRetrievalStrategy):
|
||||
glob_tool = next(t for t in strategy.tools if t.name == "glob")
|
||||
|
||||
assert "100" in glob_tool.description
|
||||
assert "**/*.py" in glob_tool.description
|
||||
|
||||
def test_custom_config_affects_tool_descriptions(
|
||||
self, strategy_with_custom_config: ContextRetrievalStrategy
|
||||
):
|
||||
read_file_tool = next(
|
||||
t for t in strategy_with_custom_config.tools if t.name == "read_file"
|
||||
)
|
||||
grep_tool = next(
|
||||
t for t in strategy_with_custom_config.tools if t.name == "grep"
|
||||
)
|
||||
|
||||
assert "100" in read_file_tool.description
|
||||
assert "10" in grep_tool.description
|
||||
|
||||
def test_no_backend_read_file(self, strategy: ContextRetrievalStrategy):
|
||||
read_file_tool = next(t for t in strategy.tools if t.name == "read_file")
|
||||
|
||||
class MockRuntime:
|
||||
state = {}
|
||||
config = {}
|
||||
|
||||
result = read_file_tool.func( # type: ignore
|
||||
file_path="/test.txt",
|
||||
runtime=MockRuntime(),
|
||||
)
|
||||
|
||||
assert "백엔드가 설정되지 않았습니다" in result
|
||||
|
||||
def test_no_backend_grep(self, strategy: ContextRetrievalStrategy):
|
||||
grep_tool = next(t for t in strategy.tools if t.name == "grep")
|
||||
|
||||
class MockRuntime:
|
||||
state = {}
|
||||
config = {}
|
||||
|
||||
result = grep_tool.func( # type: ignore
|
||||
pattern="TODO",
|
||||
runtime=MockRuntime(),
|
||||
)
|
||||
|
||||
assert "백엔드가 설정되지 않았습니다" in result
|
||||
|
||||
def test_no_backend_glob(self, strategy: ContextRetrievalStrategy):
|
||||
glob_tool = next(t for t in strategy.tools if t.name == "glob")
|
||||
|
||||
class MockRuntime:
|
||||
state = {}
|
||||
config = {}
|
||||
|
||||
result = glob_tool.func( # type: ignore
|
||||
pattern="**/*.py",
|
||||
runtime=MockRuntime(),
|
||||
)
|
||||
|
||||
assert "백엔드가 설정되지 않았습니다" in result
|
||||
65
uv.lock
generated
65
uv.lock
generated
@@ -340,10 +340,12 @@ dependencies = [
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "docker" },
|
||||
{ name = "ipykernel" },
|
||||
{ name = "ipywidgets" },
|
||||
{ name = "langgraph-cli", extra = ["inmem"] },
|
||||
{ name = "mypy" },
|
||||
{ name = "pytest" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
@@ -361,10 +363,12 @@ requires-dist = [
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "docker", specifier = ">=7.1.0" },
|
||||
{ name = "ipykernel", specifier = ">=7.1.0" },
|
||||
{ name = "ipywidgets", specifier = ">=8.1.8" },
|
||||
{ name = "langgraph-cli", extras = ["inmem"], specifier = ">=0.4.11" },
|
||||
{ name = "mypy", specifier = ">=1.19.1" },
|
||||
{ name = "pytest", specifier = ">=9.0.2" },
|
||||
{ name = "ruff", specifier = ">=0.14.10" },
|
||||
]
|
||||
|
||||
@@ -393,6 +397,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277, upload-time = "2023-12-24T09:54:30.421Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docker"
|
||||
version = "7.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pywin32", marker = "sys_platform == 'win32'" },
|
||||
{ name = "requests" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/91/9b/4a2ea29aeba62471211598dac5d96825bb49348fa07e906ea930394a83ce/docker-7.1.0.tar.gz", hash = "sha256:ad8c70e6e3f8926cb8a92619b832b4ea5299e2831c14284663184e200546fa6c", size = 117834, upload-time = "2024-05-23T11:13:57.216Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/26/57c6fb270950d476074c087527a558ccb6f4436657314bfb6cdf484114c4/docker-7.1.0-py3-none-any.whl", hash = "sha256:c96b93b7f0a746f9e77d325bcfb87422a3d8bd4f03136ae8a85b37f1898d5fc0", size = 147774, upload-time = "2024-05-23T11:13:55.01Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docstring-parser"
|
||||
version = "0.17.0"
|
||||
@@ -624,6 +642,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipykernel"
|
||||
version = "7.1.0"
|
||||
@@ -1408,6 +1435,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prompt-toolkit"
|
||||
version = "3.0.52"
|
||||
@@ -1597,6 +1633,22 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "9.0.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "pygments" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@@ -1627,6 +1679,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pywin32"
|
||||
version = "311"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c9/31/097f2e132c4f16d99a22bfb777e0fd88bd8e1c634304e102f313af69ace5/pywin32-311-cp314-cp314-win32.whl", hash = "sha256:b7a2c10b93f8986666d0c803ee19b5990885872a7de910fc460f9b0c2fbf92ee", size = 8840714, upload-time = "2025-07-14T20:13:32.449Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/4b/07c77d8ba0e01349358082713400435347df8426208171ce297da32c313d/pywin32-311-cp314-cp314-win_amd64.whl", hash = "sha256:3aca44c046bd2ed8c90de9cb8427f581c479e594e99b5c0bb19b29c10fd6cb87", size = 9656800, upload-time = "2025-07-14T20:13:34.312Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/d2/21af5c535501a7233e734b8af901574572da66fcc254cb35d0609c9080dd/pywin32-311-cp314-cp314-win_arm64.whl", hash = "sha256:a508e2d9025764a8270f93111a970e1d0fbfc33f4153b388bb649b7eec4f9b42", size = 8932540, upload-time = "2025-07-14T20:13:36.379Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.3"
|
||||
|
||||
Reference in New Issue
Block a user