633 lines
22 KiB
Python
633 lines
22 KiB
Python
from typing_extensions import Literal
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from pydantic import BaseModel, Field, root_validator
|
|
from typing import Optional, Dict, Any, Union
|
|
from dotenv import load_dotenv
|
|
import traceback
|
|
import json
|
|
import logging
|
|
import yaml
|
|
import os
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from aios.hooks.modules.llm import useCore
|
|
from aios.hooks.modules.memory import useMemoryManager
|
|
from aios.hooks.modules.storage import useStorageManager
|
|
from aios.hooks.modules.tool import useToolManager
|
|
from aios.hooks.modules.agent import useFactory
|
|
from aios.hooks.modules.scheduler import fifo_scheduler_nonblock as fifo_scheduler
|
|
from aios.hooks.modules.scheduler import rr_scheduler_nonblock as rr_scheduler
|
|
|
|
from aios.syscall.syscall import useSysCall
|
|
from aios.config.config_manager import config
|
|
|
|
from cerebrum.llm.apis import LLMQuery, LLMResponse
|
|
|
|
from cerebrum.memory.apis import MemoryQuery, MemoryResponse
|
|
|
|
from cerebrum.tool.apis import ToolQuery, ToolResponse
|
|
|
|
from cerebrum.storage.apis import StorageQuery, StorageResponse
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# from cerebrum.llm.layer import LLMLayer as LLMConfig
|
|
# from cerebrum.memory.layer import MemoryLayer as MemoryConfig
|
|
# from cerebrum.storage.layer import StorageLayer as StorageConfig
|
|
# from cerebrum.tool.layer import ToolLayer as ToolManagerConfig
|
|
|
|
# Load variables from a local .env file into the process environment before
# any component reads its configuration.
load_dotenv()

# CORS settings for the API. NOTE(review): a wildcard origin list together with
# allow_credentials=True means credentialed cross-origin requests are accepted
# from any site — restrict allow_origins before deploying.
_cors_options = dict(
    allow_origins=["*"],  # Modify this in production!
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app = FastAPI()
app.add_middleware(CORSMiddleware, **_cors_options)
|
|
|
|
# Placeholder slots for the core component instances. This dict is replaced by
# the result of initialize_components() further down in this module.
active_components = dict.fromkeys(("llm", "storage", "memory", "tool"))

# Syscall entry points; /query dispatches every request through execute_request.
execute_request, SysCallWrapper = useSysCall()
|
|
|
|
# Root logger setup: DEBUG and above go to a single console stream handler.
_console_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_console_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)

# Module-level logger used by the endpoints below.
logger = logging.getLogger(__name__)
|
|
|
|
class LLMConfig(BaseModel):
    # Pydantic schema for the settings of a single LLM; only llm_name is required.
    llm_name: str
    max_gpu_memory: Optional[dict] = None
    eval_device: str = "cuda:0"
    max_new_tokens: int = 2048
    log_mode: str = "INFO"
    llm_backend: str = "default"
    api_key: Optional[str] = None
|
|
|
|
|
|
class StorageConfig(BaseModel):
    # Pydantic schema for storage settings; every field has a default.
    root_dir: str = "root"
    use_vector_db: bool = False
    # Free-form backend options; None means "no extra options".
    vector_db_config: Optional[dict[str, Any]] = None
|
|
|
|
|
|
class MemoryConfig(BaseModel):
    # Pydantic schema for memory-manager settings; every field has a default.
    memory_limit: int = 100 * 1024 * 1024  # 100 MiB, expressed in bytes
    eviction_k: int = 10
    custom_eviction_policy: Optional[str] = None
|
|
|
|
|
|
class ToolManagerConfig(BaseModel):
    # Pydantic schema for tool-manager settings; None leaves a field unset.
    allowed_tools: Optional[list[str]] = None
    custom_tools: Optional[dict[str, Any]] = None
|
|
|
|
|
|
class SchedulerConfig(BaseModel):
    # Pydantic schema for scheduler settings; every field has a default.
    log_mode: str = "INFO"
    max_workers: int = 64
    custom_syscalls: Optional[dict[str, Any]] = None
|
|
|
|
class AgentSubmit(BaseModel):
    # Request body for POST /agents/submit. agent_config is free-form; the
    # endpoint reads its "task" entry as the agent's input.
    agent_id: str
    agent_config: dict[str, Any]
|
|
|
|
class QueryRequest(BaseModel):
    # Request body for POST /query: the calling agent, the kind of query, and
    # the query payload whose model is selected by query_type.
    agent_name: str
    query_type: Literal["llm", "tool", "storage", "memory"]
    query_data: LLMQuery | ToolQuery | StorageQuery | MemoryQuery

    @root_validator(pre=True)
    def convert_query_data(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Coerce a raw ``query_data`` dict into the model named by ``query_type``.

        Runs before field validation so that the payload's model is chosen by
        ``query_type`` rather than left to the union annotation. Values are
        returned untouched when either key is absent, when ``query_data`` is
        already an instance of the right model, or when ``query_type`` is not a
        known kind. In the last case the ``Literal`` field reports a regular
        validation error instead of a ``KeyError`` escaping from this validator.
        """
        if 'query_type' not in values or 'query_data' not in values:
            return values

        query_type = values['query_type']
        query_data = values['query_data']

        type_mapping = {
            'llm': LLMQuery,
            'tool': ToolQuery,
            'storage': StorageQuery,
            'memory': MemoryQuery,
        }

        # Guard the lookup: an unknown (or unhashable) query_type must surface as
        # a validation error, not as an uncaught KeyError/TypeError.
        if not isinstance(query_type, str) or query_type not in type_mapping:
            return values

        target_model = type_mapping[query_type]
        if isinstance(query_data, target_model):
            return values

        if isinstance(query_data, dict):
            values['query_data'] = target_model(**query_data)

        return values
|
|
|
|
def initialize_llm_cores(config: dict) -> Any:
    """Create the LLM core(s) described by the LLM configuration section.

    Args:
        config: Mapping read with ``.get``. Uses ``models`` (required and
            non-empty), ``log_mode`` (default ``"console"``) and
            ``use_context_manager`` (default ``False``).

    Returns:
        The value returned by ``useCore`` when it is truthy; otherwise ``None``.
        Failures are printed and converted to ``None`` rather than raised.
    """
    try:
        models = config.get("models", [])
        if not models:
            raise ValueError("No LLM models configured")

        cores = useCore(
            llm_configs=models,
            log_mode=config.get("log_mode", "console"),
            use_context_manager=config.get("use_context_manager", False),
        )
        if not cores:
            raise ValueError("LLM core initialization returned None")

        print("✅ LLM cores initialized")
        return cores
    except Exception as e:
        # Deliberately swallowed: the caller re-checks for a missing "llms"
        # component and fails there.
        print(f"⚠️ Error initializing LLM core: {str(e)}")
        return None
|
|
|
|
def initialize_storage_manager(storage_config: dict) -> Any:
    """Create the storage manager from the storage configuration section.

    Args:
        storage_config: Mapping with optional keys ``root_dir`` (default
            ``"root"``), ``use_vector_db`` (default ``True``) and
            ``vector_db_config`` (extra keyword arguments forwarded to
            ``useStorageManager``; a missing or ``None`` value adds nothing).

    Returns:
        The object returned by ``useStorageManager``.

    Raises:
        Exception: wrapping any setup error; the original error is kept as
            ``__cause__`` so its traceback is not lost.
    """
    try:
        storage_manager = useStorageManager(
            root_dir=storage_config.get("root_dir", "root"),
            # NOTE(review): defaults to True here, while StorageConfig in this
            # module defaults use_vector_db to False — confirm which is intended.
            use_vector_db=storage_config.get("use_vector_db", True),
            **(storage_config.get("vector_db_config", {}) or {}),
        )
        print("✅ Storage manager initialized")
        return storage_manager
    except Exception as e:
        print(f"❌ Storage setup failed: {str(e)}")
        raise Exception(f"Failed to initialize storage manager: {str(e)}") from e
|
|
|
|
def initialize_memory_manager(memory_config: dict, storage_manager: Any) -> Any:
    """Create the memory manager on top of an existing storage manager.

    Args:
        memory_config: Mapping with optional keys ``memory_limit`` (default
            ``524288``) and ``eviction_k`` (default ``3``).
        storage_manager: The storage manager instance handed to
            ``useMemoryManager``.

    Returns:
        The object returned by ``useMemoryManager``.

    Raises:
        Exception: wrapping any setup error; the original error is kept as
            ``__cause__``.
    """
    try:
        memory_manager = useMemoryManager(
            # NOTE(review): these fallbacks (524288, 3) differ from the
            # MemoryConfig defaults in this module (104857600, 10) — confirm.
            memory_limit=memory_config.get("memory_limit", 524288),
            eviction_k=memory_config.get("eviction_k", 3),
            storage_manager=storage_manager,
        )
        print("✅ Memory manager initialized")
        return memory_manager
    except Exception as e:
        print(f"❌ Memory setup failed: {str(e)}")
        raise Exception(f"Failed to initialize memory manager: {str(e)}") from e
|
|
|
|
def initialize_tool_manager() -> Any:
    """Create the tool manager.

    Returns:
        The object returned by ``useToolManager``.

    Raises:
        Exception: whose single argument is a dict with ``error``, ``message``
            and ``traceback`` keys describing the failure; the original error
            is also kept as ``__cause__``.
    """
    try:
        print("\n[DEBUG] ===== Setting up Tool Manager =====")
        tool_manager = useToolManager()
        print("✅ Tool manager initialized")
        return tool_manager
    except Exception as e:
        error_msg = str(e)
        stack_trace = traceback.format_exc()
        print(f"[ERROR] Tool Manager Setup Failed: {error_msg}")
        print(f"[ERROR] Stack Trace:\n{stack_trace}")
        raise Exception({
            "error": "Failed to initialize tool manager",
            "message": error_msg,
            "traceback": stack_trace,
        }) from e
|
|
|
|
def initialize_scheduler(components: dict, scheduler_config: dict) -> Any:
    """Create and start the syscall scheduler.

    The scheduler kind follows ``use_context_manager`` in the global LLM
    configuration: round-robin when it is enabled, FIFO otherwise. Both kinds
    receive the same arguments.

    Args:
        components: Mapping holding the ``"llms"``, ``"memory"``, ``"storage"``
            and ``"tool"`` instances.
        scheduler_config: Mapping with an optional ``log_mode`` (default
            ``"console"``).

    Returns:
        The started scheduler.

    Raises:
        Exception: wrapping any setup error; the original error is kept as
            ``__cause__``.
    """
    try:
        llms_config = config.get_llms_config()
        use_context = llms_config.get("use_context_manager", False)

        # NOTE: a check rejecting an explicit FIFO "scheduler_type" while context
        # management is enabled used to live here but is disabled;
        # scheduler_config["scheduler_type"] is currently not consulted.
        scheduler_factory = rr_scheduler if use_context else fifo_scheduler

        scheduler = scheduler_factory(
            llm=components["llms"],
            memory_manager=components["memory"],
            storage_manager=components["storage"],
            tool_manager=components["tool"],
            log_mode=scheduler_config.get("log_mode", "console"),
            get_llm_syscall=None,
            get_memory_syscall=None,
            get_storage_syscall=None,
            get_tool_syscall=None,
        )
        scheduler.start()
        print("✅ Scheduler initialized and started")
        return scheduler
    except Exception as e:
        print(f"❌ Scheduler setup failed: {str(e)}")
        raise Exception(f"Failed to initialize scheduler: {str(e)}") from e
|
|
|
|
def initialize_agent_factory(agent_factory_config: dict) -> dict:
    """Create the agent factory and expose its submit/await callables.

    Args:
        agent_factory_config: Mapping with optional ``log_mode`` (default
            ``"console"``) and ``max_workers`` (default ``64``).

    Returns:
        ``{"submit": ..., "await": ...}`` holding the two callables returned by
        ``useFactory``.

    Raises:
        Exception: wrapping any setup error; the original error is kept as
            ``__cause__``.
    """
    try:
        # Local names chosen so they do not shadow the submit_agent endpoint.
        submit, await_execution = useFactory(
            log_mode=agent_factory_config.get("log_mode", "console"),
            max_workers=agent_factory_config.get("max_workers", 64),
        )
        print("✅ Agent factory initialized")
        return {
            "submit": submit,
            "await": await_execution,
        }
    except Exception as e:
        print(f"❌ Agent factory setup failed: {str(e)}")
        raise Exception(f"Failed to initialize agent factory: {str(e)}") from e
|
|
|
|
def initialize_components() -> dict:
    """Build every kernel component in dependency order and return them.

    Order:
        1. LLM cores.
        2. Storage.
        3. Memory, which is given the storage manager.
        4. Tools.
        5. The scheduler, which receives the four core components.
        6. The agent factory.

    Returns:
        Dict keyed ``"llms"``, ``"storage"``, ``"memory"``, ``"tool"``,
        ``"scheduler"`` and ``"factory"``.

    Raises:
        Exception: if a required component is missing or any helper fails.
            The failure is printed and the exception re-raised unchanged.
    """
    built = dict.fromkeys(
        ("llms", "storage", "memory", "tool", "scheduler", "factory")
    )

    try:
        llms_cfg = config.get_llms_config()
        storage_cfg = config.get_storage_config()
        memory_cfg = config.get_memory_config()
        scheduler_cfg = config.get_scheduler_config()
        factory_cfg = config.get_agent_factory_config()

        built["llms"] = initialize_llm_cores(llms_cfg)
        built["storage"] = initialize_storage_manager(storage_cfg)
        if not built["storage"]:
            raise Exception("Storage manager must be initialized first")

        built["memory"] = initialize_memory_manager(memory_cfg, built["storage"])
        built["tool"] = initialize_tool_manager()

        # initialize_llm_cores signals failure by returning None rather than
        # raising, so required components are re-checked before going on.
        absent = [
            name
            for name in ("llms", "memory", "storage", "tool")
            if not built[name]
        ]
        if absent:
            raise Exception(f"Missing required components: {', '.join(absent)}")

        built["scheduler"] = initialize_scheduler(built, scheduler_cfg)
        built["factory"] = initialize_agent_factory(factory_cfg)

        print("✅ All components initialized successfully")
        return built
    except Exception as e:
        print(f"❌ Component initialization failed: {str(e)}")
        raise
|
|
|
|
# Initialize components when starting up. The endpoints below look components
# up in this dict, and restart_kernel() refreshes it in place.
active_components = initialize_components()


def restart_kernel():
    """Tear down the running components and rebuild them from current config.

    Steps:
        1. Stop the scheduler first, since it is handed the other components.
        2. Clean up the tool, memory, storage and LLM components in reverse
           dependency order, calling ``cleanup()`` where available.
        3. Merge the freshly built components into the module-level
           ``active_components`` dict in place, so the endpoints see the new
           instances.

    Raises:
        Exception: if rebuilding fails. The error and stack trace are printed
            and the exception is re-raised. Components torn down before the
            failure remain ``None``.
    """
    try:
        scheduler = active_components.get("scheduler")
        if scheduler:
            scheduler.stop()
            active_components["scheduler"] = None

        for component in ["tool", "memory", "storage", "llms"]:
            instance = active_components.get(component)
            if instance:
                if hasattr(instance, "cleanup"):
                    instance.cleanup()
                active_components[component] = None

        new_components = initialize_components()
        if not new_components:
            raise Exception("Failed to initialize components")

        # Update in place: discarding the result (or rebinding a local) would
        # leave the endpoints pointing at the torn-down instances.
        active_components.update(new_components)

        print("✅ All components reinitialized successfully")

    except Exception as e:
        print(f"❌ Error restarting kernel: {str(e)}")
        print(f"Stack trace: {traceback.format_exc()}")
        raise
|
|
|
|
@app.post("/core/refresh")
async def refresh_configuration():
    """Reload configuration and rebuild all components.

    Returns:
        A success status/message dict once ``restart_kernel()`` has finished.

    Raises:
        HTTPException: 500 with the error text if reloading or restarting fails.
    """
    try:
        print("Received refresh request")
        config.refresh()
        print("Configuration reloaded")
        restart_kernel()
        print("Kernel restarted")
    except Exception as e:
        print(f"Error during refresh: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to refresh configuration: {str(e)}",
        )

    return {
        "status": "success",
        "message": "Configuration refreshed and components reinitialized",
    }
|
|
|
|
@app.get("/core/status")
async def get_status():
    """Report every tracked component as "active" (truthy) or "inactive"."""
    report = {}
    for name, instance in active_components.items():
        report[name] = "inactive" if not instance else "active"
    return report
|
|
|
|
|
|
@app.get("/core/llms/check")
async def check_llms():
    """Return the initialized LLM core object(s) as a one-element list.

    The response body is a one-element JSON array, the same shape FastAPI
    produces for a set. A list is used instead of a set literal so that an
    unhashable core object (e.g. a list or dict) cannot fail the request with
    a ``TypeError``.
    """
    return [active_components["llms"]]
|
|
|
|
# Directory holding one "<execution_id>.json" record per submitted agent.
# NOTE(review): the path is relative, so it resolves against the server's
# working directory.
PROC_DIR = Path("proc")

# Create it at import time; an existing directory is accepted as-is.
PROC_DIR.mkdir(parents=False, exist_ok=True)
|
|
|
|
def save_agent_process_info(agent_id: str, execution_id: int, config: Dict[str, Any]):
    """Write a "running" record for a newly submitted agent.

    The record is written to ``PROC_DIR/<execution_id>.json``. It holds:

    - ``agent_id`` and ``execution_id``;
    - ``status`` (``"running"``) and an ISO-format ``start_time``;
    - the submitted ``config``;
    - its ``task`` (``"No task specified"`` when absent).

    Best effort: any failure is printed and swallowed.
    """
    try:
        record = dict(
            agent_id=agent_id,
            execution_id=execution_id,
            status="running",
            start_time=datetime.now().isoformat(),
            config=config,
            task=config.get("task", "No task specified"),
        )

        target = PROC_DIR / f"{execution_id}.json"
        with target.open("w") as handle:
            json.dump(record, handle, indent=2)

    except Exception as e:
        # Not re-raised: process tracking is not critical functionality.
        print(f"Failed to save process info: {str(e)}")
|
|
|
|
def update_agent_process_status(execution_id: int, status: str, result: Any = None):
    """Update the status of an existing agent process record.

    Does nothing if ``PROC_DIR/<execution_id>.json`` does not exist. For
    ``status == "completed"`` the ``result`` is stored and ``end_time`` is
    recorded. ``end_time`` is set only once: status polling may call this
    repeatedly, and the first observed completion time must not be
    overwritten.

    The updated record is serialized before the file is opened for writing,
    so a result that cannot be JSON-encoded leaves the existing record intact
    instead of truncating it. Best effort: failures are printed and swallowed.
    """
    try:
        proc_file = PROC_DIR / f"{execution_id}.json"
        if not proc_file.exists():
            return

        with open(proc_file) as f:
            process_info = json.load(f)

        process_info["status"] = status
        if status == "completed":
            process_info.setdefault("end_time", datetime.now().isoformat())
            process_info["result"] = result

        # open(..., "w") truncates immediately, so encode first.
        payload = json.dumps(process_info, indent=2)
        with open(proc_file, "w") as f:
            f.write(payload)

    except Exception as e:
        print(f"Failed to update process status: {str(e)}")
|
|
|
|
@app.get("/agents/ps")
async def list_agent_processes():
    """List all recorded agent processes, sorted by execution ID.

    Reads every ``*.json`` file in ``PROC_DIR``. Files that cannot be read or
    parsed, or that are not a process record (not a JSON object, or lacking
    ``execution_id``), are reported and skipped. Otherwise a single stray file
    would make the sort fail and turn the whole listing into an error.

    Raises:
        HTTPException: 500 if listing fails for any other reason.
    """
    try:
        processes = []
        for proc_file in PROC_DIR.glob("*.json"):
            try:
                with open(proc_file) as f:
                    process_info = json.load(f)
            except Exception as e:
                print(f"Failed to read process file {proc_file}: {str(e)}")
                continue

            if not isinstance(process_info, dict) or "execution_id" not in process_info:
                print(f"Skipping malformed process file {proc_file}")
                continue

            processes.append(process_info)

        processes.sort(key=lambda info: info["execution_id"])

        return {
            "status": "success",
            "processes": processes,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Failed to list processes: {str(e)}",
        ) from e
|
|
|
|
@app.post("/agents/submit")
async def submit_agent(config: AgentSubmit):
    """Submit an agent for execution using the agent factory.

    ``config.agent_config["task"]`` is passed to the factory as the agent's
    input, and a process record is saved for ``/agents/ps``. Note that the
    parameter ``config`` shadows the module-level ``config`` inside this
    function.

    Returns:
        ``{"status": "success", "execution_id": ..., "message": ...}``.

    Raises:
        HTTPException: 400 if the agent factory is not initialized or
            ``agent_config`` has no ``"task"``; 500 if submission fails. The
            500 detail includes the traceback (NOTE(review): this exposes
            server internals to clients).
    """
    if "factory" not in active_components or not active_components["factory"]:
        raise HTTPException(status_code=400, detail="Agent factory not initialized")

    # Validated before the try block: an HTTPException raised inside it would
    # be caught by the generic handler below and reported as a 500.
    if "task" not in config.agent_config:
        raise HTTPException(
            status_code=400,
            detail="agent_config must contain a 'task' entry",
        )

    try:
        print(f"\n[DEBUG] ===== Agent Submission =====")
        print(f"[DEBUG] Agent ID: {config.agent_id}")
        print(f"[DEBUG] Task: {config.agent_config['task']}")

        _submit_agent = active_components["factory"]["submit"]
        execution_id = _submit_agent(
            agent_name=config.agent_id, task_input=config.agent_config["task"]
        )

        save_agent_process_info(
            agent_id=config.agent_id,
            execution_id=execution_id,
            config=config.agent_config,
        )

        return {
            "status": "success",
            "execution_id": execution_id,
            "message": f"Agent {config.agent_id} submitted for execution",
        }
    except Exception as e:
        error_msg = str(e)
        stack_trace = traceback.format_exc()
        print(f"[ERROR] Agent submission failed: {error_msg}")
        print(f"[ERROR] Stack Trace:\n{stack_trace}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Failed to submit agent",
                "message": error_msg,
                "traceback": stack_trace,
            },
        )
|
|
|
|
|
|
@app.get("/agents/{execution_id}/status")
async def get_agent_status(execution_id: int):
    """Get the status of a submitted agent.

    Asks the agent factory's ``"await"`` callable for the execution result.

    - ``None`` means the run is still in progress.
    - Otherwise the process record is marked completed and the result is
      returned.

    Raises:
        HTTPException:
            - 400 if the agent factory is not initialized;
            - 404 if the ``"await"`` callable raises ``FileNotFoundError``;
            - 500 for any other error (the detail includes the traceback).
    """
    factory_ready = "factory" in active_components and active_components["factory"]
    if not factory_ready:
        raise HTTPException(status_code=400, detail="Agent factory not initialized")

    try:
        print(f"\n[DEBUG] ===== Checking Agent Status =====")
        print(f"[DEBUG] Execution ID: {execution_id}")

        fetch_result = active_components["factory"]["await"]
        try:
            outcome = fetch_result(int(execution_id))
        except FileNotFoundError as missing:
            raise HTTPException(status_code=404, detail=str(missing))

        if outcome is not None:
            update_agent_process_status(execution_id, "completed", outcome)
            return {
                "status": "completed",
                "result": outcome,
                "execution_id": execution_id,
            }

        return {
            "status": "running",
            "message": "Execution in progress",
            "execution_id": execution_id,
        }

    except HTTPException:
        # Deliberate HTTP errors (the 404 above) pass through unchanged.
        raise
    except Exception as e:
        text = str(e)
        trace = traceback.format_exc()
        print(f"[ERROR] Failed to get agent status: {text}")
        print(f"[ERROR] Stack Trace:\n{trace}")
        # Anything unexpected is surfaced as HTTP 500.
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Failed to get agent status",
                "message": text,
                "traceback": trace,
            },
        )
|
|
|
|
@app.post("/core/cleanup")
async def cleanup_components():
    """Clean up all active components and delete the agent process records.

    Steps:
        1. Stop the scheduler.
        2. Call ``cleanup()`` (when present) on the tool, memory, storage and
           LLM components in reverse dependency order, and clear their slots.
        3. Remove every ``*.json`` file in ``PROC_DIR``.

    Calling it again is safe: components that are already cleared are skipped.

    NOTE(review): the ``"factory"`` entry is left in place, so the
    factory-initialized check in ``/agents/submit`` still passes after
    cleanup — confirm that is intended.

    Raises:
        HTTPException: 500 if any cleanup step fails.
    """
    try:
        # Stop the scheduler first: it is handed the other components. It is
        # already None after a previous cleanup, hence the guard.
        scheduler = active_components.get("scheduler")
        if scheduler:
            scheduler.stop()
        active_components["scheduler"] = None

        # Keys match those produced by initialize_components(): the LLM cores
        # live under "llms" (not "llm").
        for component in ["tool", "memory", "storage", "llms"]:
            instance = active_components.get(component)
            if instance:
                if hasattr(instance, "cleanup"):
                    instance.cleanup()
                active_components[component] = None

        for proc_file in PROC_DIR.glob("*.json"):
            try:
                proc_file.unlink()
            except Exception as e:
                print(f"Failed to remove process file {proc_file}: {str(e)}")

        return {"status": "success", "message": "All components cleaned up"}
    except Exception as e:
        print(f"Failed to cleanup components: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Failed to cleanup components: {str(e)}"
        ) from e
|
|
|
|
|
|
@app.post("/query")
|
|
async def handle_query(request: QueryRequest):
|
|
# breakpoint()
|
|
try:
|
|
if request.query_type == "llm":
|
|
query = LLMQuery(
|
|
llms=request.query_data.llms,
|
|
messages=request.query_data.messages,
|
|
tools=request.query_data.tools,
|
|
action_type=request.query_data.action_type,
|
|
message_return_type=request.query_data.message_return_type,
|
|
)
|
|
return execute_request(request.agent_name, query)
|
|
elif request.query_type == "storage":
|
|
query = StorageQuery(
|
|
params=request.query_data.params,
|
|
operation_type=request.query_data.operation_type
|
|
)
|
|
return execute_request(request.agent_name, query)
|
|
elif request.query_type == "tool":
|
|
query = ToolQuery(
|
|
params=request.query_data.params,
|
|
operation_type=request.query_data.operation_type
|
|
)
|
|
return execute_request(request.agent_name, query)
|
|
elif request.query_type == "memory":
|
|
query = MemoryQuery(
|
|
params=request.query_data.params,
|
|
operation_type=request.query_data.operation_type
|
|
)
|
|
return execute_request(request.agent_name, query)
|
|
except Exception as e:
|
|
raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.post("/core/config/update")
async def update_config(request: Request):
    """Set the API key for a provider, persist it and restart the kernel.

    Expects a JSON body with ``provider`` and ``api_key``. The key is stored in
    ``config.config["api_keys"][provider]`` and saved, then the kernel is
    rebuilt via ``refresh_configuration()``.

    If the restart fails, the previous key (or its absence) is restored and
    saved before the configuration is reloaded. ``save_config()`` has already
    persisted the new key at that point, so reloading alone would not roll
    anything back.

    Raises:
        HTTPException: 400 if ``provider`` or ``api_key`` is missing; 500 for
            any other failure.
    """
    try:
        data = await request.json()

        provider = data.get("provider")
        api_key = data.get("api_key")
        # Log only the provider: the request body carries the secret API key.
        logger.info("Received config update request for provider: %s", provider)

        if not all([provider, api_key]):
            raise HTTPException(
                status_code=400,
                detail="Failed to update configuration: Missing required fields: provider, api_key",
            )

        # Tolerate a missing or empty "api_keys" section.
        api_keys = config.config.get("api_keys") or {}
        config.config["api_keys"] = api_keys
        had_previous = provider in api_keys
        previous_key = api_keys.get(provider)

        api_keys[provider] = api_key
        config.save_config()

        # Try to reinitialize the components with the new key.
        try:
            await refresh_configuration()
            return {"status": "success", "message": "Configuration updated and services restarted"}
        except Exception as e:
            # Roll back on the *current* config object: refresh_configuration()
            # may already have reloaded it.
            try:
                current_keys = config.config.get("api_keys") or {}
                if had_previous:
                    current_keys[provider] = previous_key
                else:
                    current_keys.pop(provider, None)
                config.config["api_keys"] = current_keys
                config.save_config()
                config.refresh()  # Reload the restored configuration
            except Exception:
                logger.exception("Failed to roll back API key for provider: %s", provider)
            raise Exception(f"Failed to restart services with new configuration: {str(e)}") from e

    except HTTPException:
        # Deliberate client errors (the 400 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Config update failed: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to update configuration: {str(e)}",
        ) from e