feat(acp): support loading sessions in acp (#5942)
This commit is contained in:
@@ -11,7 +11,7 @@ use goose::mcp_utils::ToolResult;
|
||||
use goose::providers::create;
|
||||
use goose::session::session_manager::SessionType;
|
||||
use goose::session::SessionManager;
|
||||
use rmcp::model::{Content, RawContent, ResourceContents};
|
||||
use rmcp::model::{Content, RawContent, ResourceContents, Role};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
@@ -570,7 +570,7 @@ impl acp::Agent for GooseAcpAgent {
|
||||
|
||||
// Advertise Goose's capabilities
|
||||
let agent_capabilities = acp::AgentCapabilities {
|
||||
load_session: false, // TODO: Implement session persistence
|
||||
load_session: true,
|
||||
prompt_capabilities: acp::PromptCapabilities {
|
||||
image: true, // Goose supports image inputs via providers
|
||||
audio: false, // TODO: Add audio support when providers support it
|
||||
@@ -638,19 +638,108 @@ impl acp::Agent for GooseAcpAgent {
|
||||
args: acp::LoadSessionRequest,
|
||||
) -> Result<acp::LoadSessionResponse, acp::Error> {
|
||||
info!("ACP: Received load session request {:?}", args);
|
||||
// For now, will start a new session. We could use goose session storage as an enhancement
|
||||
// we would need to map ACP session IDs to goose session ids (which by default are auto generated)
|
||||
// normal goose session restore in CLI doesn't load conversation visually.
|
||||
//
|
||||
// Example flow:
|
||||
// - Load session file by session_id (might need to map ACP session IDs to Goose session paths)
|
||||
// - For each message in history:
|
||||
// - If user message: send user_message_chunk notification
|
||||
// - If assistant message: send agent_message_chunk notification
|
||||
// - If tool calls/responses: send appropriate notifications
|
||||
|
||||
// For now, we don't support loading previous sessions
|
||||
Err(acp::Error::method_not_found())
|
||||
let session_id = args.session_id.0.to_string();
|
||||
|
||||
let goose_session = SessionManager::get_session(&session_id, true)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to load session {}: {}", session_id, e);
|
||||
acp::Error::invalid_params()
|
||||
})?;
|
||||
|
||||
let conversation = goose_session.conversation.ok_or_else(|| {
|
||||
error!("Session {} has no conversation data", session_id);
|
||||
acp::Error::internal_error()
|
||||
})?;
|
||||
|
||||
SessionManager::update_session(&session_id)
|
||||
.working_dir(args.cwd.clone())
|
||||
.apply()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
error!("Failed to update session working directory: {}", e);
|
||||
acp::Error::internal_error()
|
||||
})?;
|
||||
|
||||
let mut session = GooseAcpSession {
|
||||
messages: conversation.clone(),
|
||||
tool_call_ids: HashMap::new(),
|
||||
tool_requests: HashMap::new(),
|
||||
cancel_token: None,
|
||||
};
|
||||
|
||||
// Replay conversation history to client
|
||||
for message in conversation.messages() {
|
||||
// Only replay user-visible messages
|
||||
if !message.metadata.user_visible {
|
||||
continue;
|
||||
}
|
||||
|
||||
for content_item in &message.content {
|
||||
match content_item {
|
||||
MessageContent::Text(text) => {
|
||||
let update = match message.role {
|
||||
Role::User => acp::SessionUpdate::UserMessageChunk {
|
||||
content: text.text.clone().into(),
|
||||
},
|
||||
Role::Assistant => acp::SessionUpdate::AgentMessageChunk {
|
||||
content: text.text.clone().into(),
|
||||
},
|
||||
};
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.session_update_tx
|
||||
.send((
|
||||
SessionNotification {
|
||||
session_id: args.session_id.clone(),
|
||||
update,
|
||||
meta: None,
|
||||
},
|
||||
tx,
|
||||
))
|
||||
.map_err(|_| acp::Error::internal_error())?;
|
||||
rx.await.map_err(|_| acp::Error::internal_error())?;
|
||||
}
|
||||
MessageContent::ToolRequest(tool_request) => {
|
||||
self.handle_tool_request(tool_request, &args.session_id, &mut session)
|
||||
.await?;
|
||||
}
|
||||
MessageContent::ToolResponse(tool_response) => {
|
||||
self.handle_tool_response(tool_response, &args.session_id, &mut session)
|
||||
.await?;
|
||||
}
|
||||
MessageContent::Thinking(thinking) => {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.session_update_tx
|
||||
.send((
|
||||
SessionNotification {
|
||||
session_id: args.session_id.clone(),
|
||||
update: acp::SessionUpdate::AgentThoughtChunk {
|
||||
content: thinking.thinking.clone().into(),
|
||||
},
|
||||
meta: None,
|
||||
},
|
||||
tx,
|
||||
))
|
||||
.map_err(|_| acp::Error::internal_error())?;
|
||||
rx.await.map_err(|_| acp::Error::internal_error())?;
|
||||
}
|
||||
_ => {
|
||||
// Ignore other content types
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut sessions = self.sessions.lock().await;
|
||||
sessions.insert(session_id.clone(), session);
|
||||
|
||||
info!("Loaded ACP session {}", session_id);
|
||||
|
||||
Ok(acp::LoadSessionResponse {
|
||||
modes: None,
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn prompt(&self, args: acp::PromptRequest) -> Result<acp::PromptResponse, acp::Error> {
|
||||
|
||||
@@ -2,14 +2,23 @@
|
||||
"""
|
||||
Simple ACP client to test the goose ACP agent.
|
||||
Connects to goose acp running on stdio.
|
||||
|
||||
Tests:
|
||||
1. Initialize - Establish connection and verify capabilities
|
||||
2. session/new - Create a new session
|
||||
3. session/prompt - Send a prompt to the session
|
||||
4. session/load - Load an existing session (new feature)
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
class AcpClient:
|
||||
def __init__(self):
|
||||
# Start the goose acp process
|
||||
self.process = subprocess.Popen(
|
||||
['cargo', 'run', '-p', 'goose-cli', '--', 'acp'],
|
||||
stdin=subprocess.PIPE,
|
||||
@@ -19,8 +28,19 @@ class AcpClient:
|
||||
bufsize=0
|
||||
)
|
||||
self.request_id = 0
|
||||
|
||||
def send_request(self, method, params=None):
|
||||
|
||||
def send_request(self, method, params=None, collect_notifications=False):
|
||||
"""Send a request and wait for the response.
|
||||
|
||||
Args:
|
||||
method: The JSON-RPC method name
|
||||
params: Optional parameters for the request
|
||||
collect_notifications: If True, collect notifications until response arrives
|
||||
|
||||
Returns:
|
||||
Tuple of (response, notifications) if collect_notifications is True,
|
||||
otherwise just the response.
|
||||
"""
|
||||
self.request_id += 1
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
@@ -29,22 +49,42 @@ class AcpClient:
|
||||
}
|
||||
if params:
|
||||
request["params"] = params
|
||||
|
||||
# Send the request
|
||||
|
||||
request_str = json.dumps(request)
|
||||
print(f">>> Sending: {request_str}")
|
||||
self.process.stdin.write(request_str + '\n')
|
||||
self.process.stdin.flush()
|
||||
|
||||
# Read response
|
||||
response_line = self.process.stdout.readline()
|
||||
if not response_line:
|
||||
return None
|
||||
|
||||
print(f"<<< Response: {response_line}")
|
||||
return json.loads(response_line)
|
||||
|
||||
|
||||
notifications = []
|
||||
|
||||
# Read responses until we get one with our request ID
|
||||
while True:
|
||||
response_line = self.process.stdout.readline()
|
||||
if not response_line:
|
||||
if collect_notifications:
|
||||
return None, notifications
|
||||
return None
|
||||
|
||||
response = json.loads(response_line)
|
||||
|
||||
# Check if this is a notification (has 'method' but no 'id')
|
||||
if 'method' in response and 'id' not in response:
|
||||
print(f"<<< Notification: {response['method']}: {response.get('params', {}).get('update', {}).get('sessionUpdate', 'unknown')}")
|
||||
if collect_notifications:
|
||||
notifications.append(response)
|
||||
continue
|
||||
|
||||
if response.get('id') == self.request_id:
|
||||
print(f"<<< Response: {response_line.strip()}")
|
||||
if collect_notifications:
|
||||
return response, notifications
|
||||
return response
|
||||
else:
|
||||
# Response for a different request ID, skip
|
||||
print(f"<<< Unexpected response ID: {response}")
|
||||
|
||||
def initialize(self):
|
||||
"""Initialize the ACP connection and verify capabilities."""
|
||||
return self.send_request("initialize", {
|
||||
"protocolVersion": "v1",
|
||||
"clientCapabilities": {},
|
||||
@@ -53,14 +93,33 @@ class AcpClient:
|
||||
"version": "1.0.0"
|
||||
}
|
||||
})
|
||||
|
||||
def new_session(self):
|
||||
return self.send_request("newSession", {
|
||||
"context": {}
|
||||
})
|
||||
|
||||
|
||||
def new_session(self, cwd=None):
|
||||
"""Create a new session (session/new)."""
|
||||
params = {
|
||||
"mcpServers": [],
|
||||
"cwd": cwd or os.getcwd()
|
||||
}
|
||||
return self.send_request("session/new", params)
|
||||
|
||||
def load_session(self, session_id, cwd=None):
|
||||
"""Load an existing session (session/load).
|
||||
|
||||
Returns: (response, notifications) tuple with session history notifications.
|
||||
"""
|
||||
params = {
|
||||
"sessionId": session_id,
|
||||
"mcpServers": [],
|
||||
"cwd": cwd or os.getcwd()
|
||||
}
|
||||
return self.send_request("session/load", params, collect_notifications=True)
|
||||
|
||||
def prompt(self, session_id, text):
|
||||
return self.send_request("prompt", {
|
||||
"""Send a prompt to the session (session/prompt).
|
||||
|
||||
Returns: (response, notifications) tuple with streaming notifications.
|
||||
"""
|
||||
return self.send_request("session/prompt", {
|
||||
"sessionId": session_id,
|
||||
"prompt": [
|
||||
{
|
||||
@@ -68,48 +127,138 @@ class AcpClient:
|
||||
"text": text
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
}, collect_notifications=True)
|
||||
|
||||
def close(self):
|
||||
if self.process:
|
||||
self.process.terminate()
|
||||
self.process.wait()
|
||||
|
||||
|
||||
def test_new_session(client):
|
||||
"""Test creating a new session and sending a prompt."""
|
||||
print("\n" + "="*60)
|
||||
print("TEST: New Session Flow")
|
||||
print("="*60)
|
||||
|
||||
print("\n2. Creating new session (session/new)...")
|
||||
session_response = client.new_session()
|
||||
if session_response and 'result' in session_response:
|
||||
session_id = session_response['result']['sessionId']
|
||||
print(f" ✓ Created session: {session_id}")
|
||||
return session_id
|
||||
else:
|
||||
print(f" ✗ Failed to create session: {session_response}")
|
||||
return None
|
||||
|
||||
|
||||
def test_load_session(client, session_id):
|
||||
"""Test loading an existing session."""
|
||||
print("\n" + "="*60)
|
||||
print("TEST: Load Session Flow")
|
||||
print("="*60)
|
||||
|
||||
print(f"\n4. Loading existing session (session/load) with ID: {session_id}")
|
||||
load_response, notifications = client.load_session(session_id)
|
||||
|
||||
# Show notifications received (these are the session history)
|
||||
if notifications:
|
||||
print(f" 📝 Received {len(notifications)} notification(s) (session history replay):")
|
||||
for n in notifications:
|
||||
update = n.get('params', {}).get('update', {})
|
||||
update_type = update.get('sessionUpdate', 'unknown')
|
||||
content = update.get('content', {})
|
||||
if isinstance(content, dict):
|
||||
text = content.get('text', '')[:50]
|
||||
else:
|
||||
text = str(content)[:50]
|
||||
print(f" - {update_type}: {text}...")
|
||||
|
||||
if load_response and 'result' in load_response:
|
||||
print(f" ✓ Session loaded successfully")
|
||||
print(f" Response: {load_response['result']}")
|
||||
return True
|
||||
else:
|
||||
print(f" ✗ Failed to load session: {load_response}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
print("Starting ACP client test...")
|
||||
print("="*60)
|
||||
print("ACP Client Test Suite")
|
||||
print("="*60)
|
||||
print("\nStarting ACP client test...")
|
||||
|
||||
client = AcpClient()
|
||||
|
||||
|
||||
try:
|
||||
# Initialize the agent
|
||||
print("\n1. Initializing agent...")
|
||||
init_response = client.initialize()
|
||||
if init_response and 'result' in init_response:
|
||||
print(f" Initialized successfully: {init_response['result']}")
|
||||
capabilities = init_response['result'].get('agentCapabilities', {})
|
||||
print(f" ✓ Initialized successfully")
|
||||
print(f" - loadSession capability: {capabilities.get('loadSession', False)}")
|
||||
print(f" - promptCapabilities: {capabilities.get('promptCapabilities', {})}")
|
||||
|
||||
if not capabilities.get('loadSession'):
|
||||
print(" ⚠ Warning: loadSession capability is not advertised")
|
||||
else:
|
||||
print(f" Failed to initialize: {init_response}")
|
||||
return
|
||||
|
||||
# Create a new session
|
||||
print("\n2. Creating new session...")
|
||||
session_response = client.new_session()
|
||||
if session_response and 'result' in session_response:
|
||||
session_id = session_response['result']['sessionId']
|
||||
print(f" Created session: {session_id}")
|
||||
print(f" ✗ Failed to initialize: {init_response}")
|
||||
return 1
|
||||
|
||||
session_id = test_new_session(client)
|
||||
if not session_id:
|
||||
return 1
|
||||
|
||||
print("\n3. Sending prompt (session/prompt)...")
|
||||
prompt_response, notifications = client.prompt(session_id, "Hello! Say 'test successful' if you can hear me.")
|
||||
if notifications:
|
||||
print(f" 📝 Received {len(notifications)} streaming notification(s)")
|
||||
if prompt_response and 'result' in prompt_response:
|
||||
print(f" ✓ Got response: {prompt_response['result']}")
|
||||
elif prompt_response and 'error' in prompt_response:
|
||||
print(f" ✗ Error: {prompt_response['error']}")
|
||||
else:
|
||||
print(f" Failed to create session: {session_response}")
|
||||
return
|
||||
|
||||
# Send a prompt
|
||||
print("\n3. Sending prompt...")
|
||||
prompt_response = client.prompt(session_id, "Hello! What is 2 + 2?")
|
||||
if prompt_response:
|
||||
print(f" Got response: {prompt_response}")
|
||||
print(f" ✗ Failed to get prompt response: {prompt_response}")
|
||||
|
||||
# Close the client and start a new one to simulate reconnection
|
||||
print("\n--- Simulating client restart ---")
|
||||
client.close()
|
||||
time.sleep(1)
|
||||
|
||||
client = AcpClient()
|
||||
|
||||
print("\n5. Re-initializing after restart...")
|
||||
init_response = client.initialize()
|
||||
if init_response and 'result' in init_response:
|
||||
print(f" ✓ Re-initialized successfully")
|
||||
else:
|
||||
print(" Failed to get prompt response")
|
||||
|
||||
print(f" ✗ Failed to re-initialize: {init_response}")
|
||||
return 1
|
||||
|
||||
if not test_load_session(client, session_id):
|
||||
return 1
|
||||
|
||||
print("\n6. Sending prompt to loaded session...")
|
||||
prompt_response, notifications = client.prompt(session_id, "What was my previous message?")
|
||||
if notifications:
|
||||
print(f" 📝 Received {len(notifications)} streaming notification(s)")
|
||||
if prompt_response and 'result' in prompt_response:
|
||||
print(f" ✓ Got response: {prompt_response['result']}")
|
||||
elif prompt_response and 'error' in prompt_response:
|
||||
print(f" ✗ Error: {prompt_response['error']}")
|
||||
else:
|
||||
print(f" ✗ Failed to get prompt response: {prompt_response}")
|
||||
|
||||
print("\n" + "="*60)
|
||||
print("All tests completed!")
|
||||
print("="*60)
|
||||
return 0
|
||||
|
||||
finally:
|
||||
client.close()
|
||||
print("\nTest complete.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
sys.exit(main())
|
||||
|
||||
Reference in New Issue
Block a user