- Context_Engineering.md: 에이전트 컨텍스트 엔지니어링 개념 정리 문서 추가 - Context_Engineering_Research.ipynb: 연구 노트북 업데이트 - deepagents_sourcecode/: docstring과 주석을 한국어로 번역
218 lines
7.4 KiB
Python
218 lines
7.4 KiB
Python
"""LangGraph 체크포인트 저장 기능을 사용한 스레드/세션 관리입니다.
|
|
|
|
Thread management using LangGraph's built-in checkpoint persistence.
|
|
"""
|
|
|
|
import uuid
|
|
from collections.abc import AsyncIterator
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
|
from rich.table import Table
|
|
|
|
from deepagents_cli.config import COLORS, console
|
|
|
|
# langgraph-checkpoint>=2.1.0 calls ``is_alive()`` on the connection object,
# so add that method to aiosqlite.Connection when it is not already present.
# See: https://github.com/langchain-ai/langgraph/issues/6583
if not hasattr(aiosqlite.Connection, "is_alive"):

    def _is_alive(self: aiosqlite.Connection) -> bool:
        """Return True while the wrapped ``_connection`` attribute is set."""
        underlying = self._connection
        return underlying is not None

    aiosqlite.Connection.is_alive = _is_alive
|
|
|
|
|
|
def _format_timestamp(iso_timestamp: str | None) -> str:
|
|
"""Format ISO timestamp for display (e.g., 'Dec 30, 6:10pm')."""
|
|
if not iso_timestamp:
|
|
return ""
|
|
try:
|
|
dt = datetime.fromisoformat(iso_timestamp).astimezone()
|
|
return dt.strftime("%b %d, %-I:%M%p").lower().replace("am", "am").replace("pm", "pm")
|
|
except (ValueError, TypeError):
|
|
return ""
|
|
|
|
|
|
def get_db_path() -> Path:
    """Return the path of the global sessions database.

    The file lives at ``~/.deepagents/sessions.db``. The ``.deepagents``
    directory is created (with parents) if it is missing; the database file
    itself is not created here.
    """
    sessions_db = Path.home().joinpath(".deepagents", "sessions.db")
    sessions_db.parent.mkdir(parents=True, exist_ok=True)
    return sessions_db
|
|
|
|
|
|
def generate_thread_id() -> str:
    """Return a fresh thread ID: the first 8 hex characters of a random UUID4."""
    full_hex = uuid.uuid4().hex
    return full_hex[:8]
|
|
|
|
|
|
async def _table_exists(conn: aiosqlite.Connection, table: str) -> bool:
    """Report whether ``sqlite_master`` lists a table with the given name.

    Args:
        conn: Open connection to query.
        table: Table name to look up (bound as a SQL parameter).

    Returns:
        ``True`` if a matching row is found, otherwise ``False``.
    """
    lookup_sql = "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?"
    async with conn.execute(lookup_sql, (table,)) as cursor:
        match = await cursor.fetchone()
    return match is not None
|
|
|
|
|
|
async def list_threads(
    agent_name: str | None = None,
    limit: int = 20,
) -> list[dict]:
    """List threads recorded in the ``checkpoints`` table.

    Args:
        agent_name: When truthy, only threads whose metadata ``agent_name``
            equals this value are returned; otherwise all threads are listed.
        limit: Value bound to the query's ``LIMIT`` clause.

    Returns:
        One dict per thread with the keys ``thread_id``, ``agent_name`` and
        ``updated_at``, ordered by ``updated_at`` descending. An empty list
        is returned when the ``checkpoints`` table does not exist.
    """
    filtered_sql = """
        SELECT thread_id,
        json_extract(metadata, '$.agent_name') as agent_name,
        MAX(json_extract(metadata, '$.updated_at')) as updated_at
        FROM checkpoints
        WHERE json_extract(metadata, '$.agent_name') = ?
        GROUP BY thread_id
        ORDER BY updated_at DESC
        LIMIT ?
    """
    unfiltered_sql = """
        SELECT thread_id,
        json_extract(metadata, '$.agent_name') as agent_name,
        MAX(json_extract(metadata, '$.updated_at')) as updated_at
        FROM checkpoints
        GROUP BY thread_id
        ORDER BY updated_at DESC
        LIMIT ?
    """
    columns = ("thread_id", "agent_name", "updated_at")

    async with aiosqlite.connect(str(get_db_path()), timeout=30.0) as conn:
        # A fresh install has no checkpoints table yet, so there is nothing to list.
        if not await _table_exists(conn, "checkpoints"):
            return []

        sql: str
        args: tuple
        if agent_name:
            sql, args = filtered_sql, (agent_name, limit)
        else:
            sql, args = unfiltered_sql, (limit,)

        async with conn.execute(sql, args) as cursor:
            records = await cursor.fetchall()
        return [dict(zip(columns, record)) for record in records]
|
|
|
|
|
|
async def get_most_recent(agent_name: str | None = None) -> str | None:
    """Return the ``thread_id`` of the row with the highest ``checkpoint_id``.

    Args:
        agent_name: When truthy, only checkpoints whose metadata
            ``agent_name`` equals this value are considered.

    Returns:
        The matching thread ID, or ``None`` when the ``checkpoints`` table is
        missing or no row matches.
    """
    async with aiosqlite.connect(str(get_db_path()), timeout=30.0) as conn:
        has_table = await _table_exists(conn, "checkpoints")
        if not has_table:
            return None

        sql: str
        args: tuple
        if not agent_name:
            sql = "SELECT thread_id FROM checkpoints ORDER BY checkpoint_id DESC LIMIT 1"
            args = ()
        else:
            sql = """
                SELECT thread_id FROM checkpoints
                WHERE json_extract(metadata, '$.agent_name') = ?
                ORDER BY checkpoint_id DESC
                LIMIT 1
            """
            args = (agent_name,)

        async with conn.execute(sql, args) as cursor:
            latest = await cursor.fetchone()
        if not latest:
            return None
        return latest[0]
|
|
|
|
|
|
async def get_thread_agent(thread_id: str) -> str | None:
    """Look up the ``agent_name`` stored in a thread's checkpoint metadata.

    Args:
        thread_id: Thread whose checkpoint row is inspected.

    Returns:
        The ``agent_name`` extracted from one checkpoint row of the thread, or
        ``None`` when the ``checkpoints`` table is missing or the thread has
        no rows.
    """
    agent_sql = """
        SELECT json_extract(metadata, '$.agent_name')
        FROM checkpoints
        WHERE thread_id = ?
        LIMIT 1
    """
    async with aiosqlite.connect(str(get_db_path()), timeout=30.0) as conn:
        has_table = await _table_exists(conn, "checkpoints")
        if not has_table:
            return None

        async with conn.execute(agent_sql, (thread_id,)) as cursor:
            found = await cursor.fetchone()
        if not found:
            return None
        return found[0]
|
|
|
|
|
|
async def thread_exists(thread_id: str) -> bool:
    """Return True when at least one checkpoint row exists for ``thread_id``.

    Returns False when the ``checkpoints`` table has not been created yet.
    """
    probe_sql = "SELECT 1 FROM checkpoints WHERE thread_id = ? LIMIT 1"
    async with aiosqlite.connect(str(get_db_path()), timeout=30.0) as conn:
        has_table = await _table_exists(conn, "checkpoints")
        if not has_table:
            return False

        async with conn.execute(probe_sql, (thread_id,)) as cursor:
            hit = await cursor.fetchone()
        return hit is not None
|
|
|
|
|
|
async def delete_thread(thread_id: str) -> bool:
    """Delete a thread's rows from ``checkpoints`` and, if present, ``writes``.

    Both deletions are committed together at the end.

    Args:
        thread_id: Identifier of the thread to remove.

    Returns:
        ``True`` if at least one ``checkpoints`` row was deleted; ``False`` if
        no row matched or the ``checkpoints`` table does not exist.
    """
    db_path = str(get_db_path())
    async with aiosqlite.connect(db_path, timeout=30.0) as conn:
        if not await _table_exists(conn, "checkpoints"):
            return False

        # Run each statement under ``async with`` so its cursor is closed as
        # soon as we are done with it, matching the read helpers in this module.
        async with conn.execute(
            "DELETE FROM checkpoints WHERE thread_id = ?", (thread_id,)
        ) as cursor:
            deleted = cursor.rowcount > 0

        if await _table_exists(conn, "writes"):
            async with conn.execute("DELETE FROM writes WHERE thread_id = ?", (thread_id,)):
                # Nothing to read back; the block exists only to close the cursor.
                pass

        await conn.commit()
        return deleted
|
|
|
|
|
|
@asynccontextmanager
async def get_checkpointer() -> AsyncIterator[AsyncSqliteSaver]:
    """Yield an ``AsyncSqliteSaver`` opened on the global sessions database.

    The saver is created from the path returned by ``get_db_path()`` and is
    released when the surrounding ``async with`` block exits.
    """
    conn_string = str(get_db_path())
    async with AsyncSqliteSaver.from_conn_string(conn_string) as saver:
        yield saver
|
|
|
|
|
|
async def list_threads_command(
    agent_name: str | None = None,
    limit: int = 20,
) -> None:
    """CLI handler for: deepagents threads list.

    Prints a table of threads (ID, agent, last-used time) to the console, or
    a short hint when there are none.

    Args:
        agent_name: When truthy, restrict the listing to this agent.
        limit: Maximum number of threads requested from ``list_threads``.
    """
    # Imported here so this handler is self-contained with respect to the
    # module's top-level import block.
    from rich.markup import escape

    threads = await list_threads(agent_name, limit=limit)

    # Agent names and thread IDs are arbitrary text. Escape them so square
    # brackets are shown literally instead of being parsed as Rich markup
    # (which can raise MarkupError or silently swallow part of the text).
    safe_agent = escape(agent_name) if agent_name else None

    if not threads:
        if agent_name:
            console.print(f"[yellow]No threads found for agent '{safe_agent}'.[/yellow]")
        else:
            console.print("[yellow]No threads found.[/yellow]")
        console.print("[dim]Start a conversation with: deepagents[/dim]")
        return

    title = f"Threads for '{safe_agent}'" if agent_name else "All Threads"

    table = Table(title=title, show_header=True, header_style=f"bold {COLORS['primary']}")
    table.add_column("Thread ID", style="bold")
    table.add_column("Agent")
    table.add_column("Last Used", style="dim")

    for t in threads:
        table.add_row(
            escape(t["thread_id"]),
            # Rows without an agent_name in their metadata are shown as "unknown".
            escape(t["agent_name"] or "unknown"),
            _format_timestamp(t.get("updated_at")),
        )

    # Blank lines above and below keep the table visually separated.
    console.print()
    console.print(table)
    console.print()
|
|
|
|
|
|
async def delete_thread_command(thread_id: str) -> None:
    """CLI handler for: deepagents threads delete.

    Deletes the thread via ``delete_thread`` and prints whether anything was
    removed.

    Args:
        thread_id: Identifier of the thread to delete, as typed by the user.
    """
    # Imported here so this handler is self-contained with respect to the
    # module's top-level import block.
    from rich.markup import escape

    deleted = await delete_thread(thread_id)

    # The ID is echoed back inside a markup string; escape it so input such as
    # "[/x]" is printed literally rather than raising a Rich MarkupError.
    safe_id = escape(thread_id)

    if deleted:
        console.print(f"[green]Thread '{safe_id}' deleted.[/green]")
    else:
        console.print(f"[red]Thread '{safe_id}' not found.[/red]")
|