update litellm completion apis and fix launch of AIOS terminal (#459)

This commit is contained in:
dongyuanjushi
2025-03-05 15:30:01 -05:00
committed by GitHub
parent 965f69e559
commit e82528bfac
9 changed files with 378 additions and 113 deletions

View File

@@ -57,12 +57,6 @@ llms:
# An example command to run the vllm server is:
# vllm serve meta-llama/Llama-3.2-3B-Instruct --port 8091
# - name: "meta-llama/Llama-3.2-3B-Instruct"
# backend: "vllm"
# max_new_tokens: 1024
# temperature: 1.0
# hostname: "http://localhost:8091/v1" # Make sure to run the vllm server
# # SGLang Models
# - name: "meta-llama/Llama-3.2-3B-Instruct"
# backend: "sglang"

View File

@@ -17,15 +17,42 @@ class SimpleContextManager(BaseContextManager):
def start(self):
pass
def save_context(self, model_name,model, messages, tools, temperature, pid, time_limit):
def save_context(self,
model_name,
model,
messages,
tools,
message_return_type,
temperature,
pid,
time_limit
):
if isinstance(model, str):
response = completion(
model=model,
messages=messages,
# tools=tools,
temperature=temperature,
stream=True
)
if tools:
response = completion(
model=model,
messages=messages,
tools=tools,
temperature=temperature,
stream=True
)
elif message_return_type == "json":
response = completion(
model=model,
messages=messages,
temperature=temperature,
# format="json"
)
else:
response = completion(
model=model,
messages=messages,
temperature=temperature
)
start_time = time.time()
completed_response = ""
@@ -42,12 +69,32 @@ class SimpleContextManager(BaseContextManager):
return completed_response, finished
elif isinstance(model, OpenAI):
completed_response = model.chat.completions.create(
if tools:
response = model.chat.completions.create(
model=model_name,
messages=messages,
tools=tools,
temperature=temperature,
stream=True
)
elif message_return_type == "json":
response = model.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
# format="json"
stream=True
)
else:
response = model.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
stream=True
)
start_time = time.time()
completed_response = ""

View File

@@ -5,7 +5,7 @@ from aios.utils.id_generator import generator_tool_call_id
from cerebrum.llm.apis import LLMQuery, LLMResponse
from litellm import completion
import json
from .utils import tool_calling_input_format, parse_json_format, parse_tool_calls, pre_process_tools
from .utils import tool_calling_input_format, parse_json_format, parse_tool_calls, slash_to_double_underscore, double_underscore_to_slash, decode_litellm_tool_calls
from typing import Dict, Optional, Any, List, Union
import time
import re
@@ -13,6 +13,7 @@ import os
from aios.config.config_manager import config
from dataclasses import dataclass
import logging
from typing import Any
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -20,8 +21,6 @@ logger = logging.getLogger(__name__)
from openai import OpenAI
from openai import OpenAI
@dataclass
class LLMConfig:
"""
@@ -188,27 +187,8 @@ class LLMAdapter:
ValueError: If required API keys are missing
"""
match config.backend:
case "hflocal":
if "HUGGING_FACE_API_KEY" not in os.environ:
raise ValueError("HUGGING_FACE_API_KEY not found in config or environment variables")
self.llms.append(HfLocalBackend(
config.name,
max_gpu_memory=config.max_gpu_memory,
hostname=config.hostname
))
# case "vllm":
# self.llms.append(VLLMLocalBackend(
# config.name,
# max_gpu_memory=config.max_gpu_memory,
# hostname=config.hostname
# ))
# case "ollama":
# self.llms.append(OllamaBackend(
# config.name,
# hostname=config.hostname
# ))
# due to the compatibility issue of the vllm and sglang in the litellm, we use the openai backend instead
case "vllm":
self.llms.append(OpenAI(
base_url=config.hostname,
@@ -222,11 +202,10 @@ class LLMAdapter:
))
case _:
# Change the backend name of google to gemini to fit the litellm
if config.backend == "google":
config.backend = "gemini"
# if config.backend == "vllm":
# config.backend = "hosted_vllm"
prefix = f"{config.backend}/"
if not config.name.startswith(prefix):
@@ -323,7 +302,7 @@ class LLMAdapter:
try:
messages = llm_syscall.query.messages
tools = llm_syscall.query.tools
ret_type = llm_syscall.query.message_return_type
message_return_type = llm_syscall.query.message_return_type
selected_llms = llm_syscall.query.llms if llm_syscall.query.llms else self.llm_configs
llm_syscall.set_status("executing")
@@ -341,9 +320,10 @@ class LLMAdapter:
# breakpoint()
# if tools:
# tools = pre_process_tools(tools)
if tools:
tools = slash_to_double_underscore(tools)
# deprecated as the tools are already supported in Litellm completion
messages = self._prepare_messages(
llm_syscall=llm_syscall,
model=model,
@@ -361,7 +341,8 @@ class LLMAdapter:
tools=tools,
temperature=temperature,
llm_syscall=llm_syscall,
api_base=api_base
api_base=api_base,
message_return_type=message_return_type
)
except Exception as e:
@@ -371,7 +352,7 @@ class LLMAdapter:
completed_response=completed_response,
finished=finished,
tools=tools,
ret_type=ret_type
message_return_type=message_return_type
)
except Exception as e:
@@ -425,9 +406,9 @@ class LLMAdapter:
}]
# if not isinstance(model, str):
if tools:
tools = pre_process_tools(tools)
messages = tool_calling_input_format(messages, tools)
# if tools:
# tools = pre_process_tools(tools)
# messages = tool_calling_input_format(messages, tools)
return messages
@@ -439,7 +420,8 @@ class LLMAdapter:
tools: Optional[List],
temperature: float,
llm_syscall,
api_base: Optional[str] = None
api_base: Optional[str] = None,
message_return_type: Optional[str] = "text"
) -> Any:
"""
Get response from the model.
@@ -473,65 +455,110 @@ class LLMAdapter:
}
```
"""
if isinstance(model, str):
if isinstance(model, str): # if the model is a string, it will use litellm completion
if self.use_context_manager:
pid = llm_syscall.get_pid()
time_limit = llm_syscall.get_time_limit()
completed_response, finished = self.context_manager.save_context(
model_name=model_name,
model=model,
messages=messages,
# tools=tools,
tools=tools,
temperature=temperature,
pid=pid,
time_limit=time_limit
)
time_limit=time_limit,
message_return_type=message_return_type
)
return completed_response, finished
else:
# breakpoint()
completed_response = completion(
model=model,
messages=messages,
# tools=tools,
temperature=temperature,
api_base=api_base
)
# if tools:
# breakpoint()
if tools:
completed_response = completion(
model=model,
messages=messages,
tools=tools,
temperature=temperature,
api_base=api_base
)
# breakpoint()
completed_response = decode_litellm_tool_calls(completed_response)
return completed_response, True
elif message_return_type == "json":
completed_response = completion(
model=model,
messages=messages,
temperature=temperature,
api_base=api_base,
format="json"
)
return completed_response.choices[0].message.content, True
else:
completed_response = completion(
model=model,
messages=messages,
temperature=temperature,
api_base=api_base
)
return completed_response.choices[0].message.content, True
return completed_response.choices[0].message.content, True
elif isinstance(model, OpenAI):
elif isinstance(model, OpenAI): # if the model is an OpenAI model, it will use OpenAI completion, as the litellm has compatibility issue with the vllm and sglang, here we use OpenAI server to launch vllm and sglang endpoints
if self.use_context_manager:
completed_response, finished = self.context_manager.save_context(
model_name=model_name,
model=model,
messages=messages,
# tools=tools,
tools=tools,
temperature=temperature,
pid=pid,
time_limit=time_limit
time_limit=time_limit,
message_return_type=message_return_type
)
return completed_response, finished
else:
# breakpoint()
# response = model.chat.completions.create(model=model_name, messages=messages, tools=tools, temperature=temperature)
completed_response = model.chat.completions.create(
model=model_name,
messages=messages,
# tools=tools,
temperature=temperature
)
if tools:
completed_response = model.chat.completions.create(
model=model_name,
messages=messages,
tools=tools,
temperature=temperature
)
completed_response = decode_litellm_tool_calls(completed_response)
return completed_response, True
elif message_return_type == "json":
completed_response = model.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature,
format="json"
)
return completed_response.choices[0].message.content, True
else:
completed_response = model.chat.completions.create(
model=model_name,
messages=messages,
temperature=temperature
)
return completed_response.choices[0].message.content, True
# breakpoint()
return completed_response.choices[0].message.content, True
def _process_response(
self,
completed_response: str,
completed_response: str | List, # either a response message of a string or a list of tool calls
finished: bool,
tools: Optional[List] = None,
ret_type: Optional[str] = None
message_return_type: Optional[str] = None
) -> LLMResponse:
"""
Process the model's response into the appropriate format.
@@ -539,7 +566,7 @@ class LLMAdapter:
Args:
response: Raw response from the model
tools: Optional list of tools
ret_type: Expected return type ("json" or None)
message_return_type: Expected return type ("json" or None)
Returns:
Formatted LLMResponse
@@ -549,7 +576,7 @@ class LLMAdapter:
# Input
response = '{"result": "4", "operation": "2 + 2"}'
tools = None
ret_type = "json"
message_return_type = "json"
# Output LLMResponse
{
@@ -581,14 +608,15 @@ class LLMAdapter:
# breakpoint()
if tools:
if tool_calls := parse_tool_calls(completed_response):
return LLMResponse(
response_message=None,
tool_calls=tool_calls,
finished=finished
)
# if tool_calls := parse_tool_calls(completed_response):
tool_calls = double_underscore_to_slash(completed_response)
return LLMResponse(
response_message=None,
tool_calls=tool_calls,
finished=finished
)
if ret_type == "json":
completed_response = parse_json_format(completed_response)
# if message_return_type == "json":
# completed_response = parse_json_format(completed_response)
return LLMResponse(response_message=completed_response, finished=finished)

View File

@@ -68,11 +68,27 @@ def parse_json_format(message: str) -> str:
def generator_tool_call_id():
return str(uuid.uuid4())
def decode_litellm_tool_calls(response):
    """Flatten a litellm/OpenAI chat-completion response into plain tool-call dicts.

    Args:
        response: A completion response whose ``choices[0].message.tool_calls``
            holds tool-call objects with ``.function.name``, ``.function.arguments``
            and ``.id`` attributes.

    Returns:
        A list of ``{"name", "parameters", "id"}`` dicts, one per tool call.
        Note that ``parameters`` is passed through as-is (still the raw
        JSON-encoded arguments string from the provider).
    """
    raw_calls = response.choices[0].message.tool_calls
    return [
        {
            "name": call.function.name,
            "parameters": call.function.arguments,
            "id": call.id,
        }
        for call in raw_calls
    ]
def parse_tool_calls(message):
# add tool call id and type for models don't support tool call
# if isinstance(message, dict):
# message = [message]
tool_calls = json.loads(parse_json_format(message))
# tool_calls = json.loads(parse_json_format(message))
tool_calls = json.loads(message)
# breakpoint()
# tool_calls = json.loads(message)
if isinstance(tool_calls, dict):
@@ -82,11 +98,24 @@ def parse_tool_calls(message):
tool_call["id"] = generator_tool_call_id()
# if "function" in tool_call:
# else:
tool_call["name"] = tool_call["name"].replace("__", "/")
tool_calls = double_underscore_to_slash(tool_calls)
# tool_call["type"] = "function"
return tool_calls
def slash_to_double_underscore(tools):
    """Sanitize tool names in place so providers accept them.

    Provider APIs reject "/" inside function names, so every tool's
    ``function.name`` has "/" rewritten to "__" (the inverse mapping is
    ``double_underscore_to_slash``).

    Args:
        tools: List of tool schemas shaped like ``{"function": {"name": ...}}``.

    Returns:
        The same list, mutated in place.
    """
    for entry in tools:
        entry["function"]["name"] = entry["function"]["name"].replace("/", "__")
    return tools
def double_underscore_to_slash(tool_calls):
    """Restore original tool names and decode call parameters in place.

    Reverses ``slash_to_double_underscore`` by mapping "__" back to "/" in
    each call's ``name``, and ensures ``parameters`` is a decoded object.

    Args:
        tool_calls: List of ``{"name", "parameters", ...}`` dicts. ``parameters``
            may be a JSON-encoded string (as produced by
            ``decode_litellm_tool_calls``) or an already-decoded dict (as
            produced by ``parse_tool_calls``).

    Returns:
        The same list, mutated in place.
    """
    for tool_call in tool_calls:
        tool_call["name"] = tool_call["name"].replace("__", "/")
        params = tool_call["parameters"]
        # Only decode strings: callers like parse_tool_calls already hold
        # decoded dicts, and json.loads on a dict raises TypeError.
        if isinstance(params, str):
            tool_call["parameters"] = json.loads(params)
    return tool_calls
def pre_process_tools(tools):
for tool in tools:
tool_name = tool["function"]["name"]

View File

@@ -97,6 +97,8 @@ class RRScheduler(BaseScheduler):
response = executor(syscall)
breakpoint()
syscall.set_response(response)
if response.finished:
@@ -105,6 +107,8 @@ class RRScheduler(BaseScheduler):
else:
syscall.set_status("suspending")
log_status = "suspending"
breakpoint()
syscall.set_end_time(time.time())

View File

@@ -279,6 +279,7 @@ class SyscallExecutor:
elif query.action_type == "tool_use":
llm_response = self.execute_llm_syscall(agent_name, query)["response"]
# breakpoint()
tool_response = self.execute_tool_syscall(agent_name, llm_response.tool_calls)
# breakpoint()
return tool_response

View File

@@ -19,6 +19,7 @@ class ToolManager:
def address_request(self, syscall) -> None:
tool_calls = syscall.tool_calls
# breakpoint()
try:
for tool_call in tool_calls:
tool_org_and_name, tool_params = (

View File

@@ -218,29 +218,30 @@ def initialize_scheduler(components: dict, scheduler_config: dict) -> Any:
# raise ValueError("FIFO scheduler cannot be used with context management enabled. Please either disable context management or use Round Robin scheduler.")
# Round Robin scheduler
scheduler = fifo_scheduler(
llm=components["llms"],
memory_manager=components["memory"],
storage_manager=components["storage"],
tool_manager=components["tool"],
log_mode=scheduler_config.get("log_mode", "console"),
get_llm_syscall=None,
get_memory_syscall=None,
get_storage_syscall=None,
get_tool_syscall=None,
)
# scheduler = rr_scheduler(
# llm=components["llms"],
# memory_manager=components["memory"],
# storage_manager=components["storage"],
# tool_manager=components["tool"],
# log_mode=scheduler_config.get("log_mode", "console"),
# get_llm_syscall=None,
# get_memory_syscall=None,
# get_storage_syscall=None,
# get_tool_syscall=None,
# )
if use_context:
scheduler = rr_scheduler(
llm=components["llms"],
memory_manager=components["memory"],
storage_manager=components["storage"],
tool_manager=components["tool"],
log_mode=scheduler_config.get("log_mode", "console"),
get_llm_syscall=None,
get_memory_syscall=None,
get_storage_syscall=None,
get_tool_syscall=None,
)
else:
scheduler = fifo_scheduler(
llm=components["llms"],
memory_manager=components["memory"],
storage_manager=components["storage"],
tool_manager=components["tool"],
log_mode=scheduler_config.get("log_mode", "console"),
get_llm_syscall=None,
get_memory_syscall=None,
get_storage_syscall=None,
get_tool_syscall=None,
)
scheduler.start()
print("✅ Scheduler initialized and started")
return scheduler

160
runtime/run_terminal.py Normal file
View File

@@ -0,0 +1,160 @@
from prompt_toolkit import PromptSession
from prompt_toolkit.styles import Style
from rich.console import Console
from rich.table import Table
from rich.syntax import Syntax
from rich.panel import Panel
from rich.text import Text
import os
import shutil
from datetime import datetime
import sys
import requests
import json
from typing_extensions import Literal
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
from enum import Enum
from cerebrum.config.config_manager import config
from list_agents import get_offline_agents, get_online_agents
from cerebrum.storage.apis import StorageQuery, StorageResponse, mount, retrieve_file, create_file, create_dir, write_file, rollback_file, share_file
from cerebrum.llm.apis import LLMQuery, LLMResponse, llm_call_tool, llm_chat, llm_operate_file
class AIOSTerminal:
    """Interactive REPL for AIOS.

    Wraps a prompt_toolkit session and a rich console: built-in commands
    (``help``, ``exit``, ``list agents``) are handled locally, and any other
    input is forwarded to the AIOS kernel as a natural-language file
    operation via ``llm_operate_file``.
    """

    def __init__(self):
        # Rich console for styled output; prompt_toolkit session for input.
        self.console = Console()
        self.style = Style.from_dict({
            'prompt': '#00ff00 bold',
            'path': '#0000ff bold',
            'arrow': '#ff0000',
        })
        self.session = PromptSession(style=self.style)
        # Working directory captured once at startup; used for the prompt
        # and as the default mount root in run().
        self.current_dir = os.getcwd()
        # self.storage_client = StorageClient()

    def get_prompt(self, extra_str = None):
        """Build the styled prompt fragments for prompt_toolkit.

        Args:
            extra_str: Optional trailing text (e.g. an inline question)
                appended after the user/path segments.

        Returns:
            A list of ``(style_class, text)`` tuples.
        """
        username = os.getenv('USER', 'user')
        path = os.path.basename(self.current_dir)
        if extra_str:
            return [
                ('class:prompt', f'🚀 {username}'),
                ('class:arrow', ''),
                ('class:path', f'{path}'),
                ('class:arrow', ''),
                ('class:prompt', extra_str)
            ]
        else:
            return [
                ('class:prompt', f'🚀 {username}'),
                ('class:arrow', ''),
                ('class:path', f'{path}'),
                ('class:arrow', '')
            ]

    def display_help(self):
        """Print the table of built-in commands inside a rich panel."""
        help_table = Table(show_header=True, header_style="bold magenta")
        help_table.add_column("Command", style="cyan")
        help_table.add_column("Description", style="green")
        # Add command descriptions
        help_table.add_row("help", "Show this help message")
        help_table.add_row("exit", "Exit the terminal")
        # help_table.add_row("list agents --offline", "List all available offline agents")
        help_table.add_row("list agents --online", "List all available agents on the agenthub")
        help_table.add_row("<natural language>", "Execute semantic file operations using natural language")
        self.console.print(Panel(help_table, title="Available Commands", border_style="blue"))

    def handle_list_agents(self, args: str):
        """Handle the 'list agents' command with different parameters.

        Args:
            args: The arguments passed to the list agents command (--offline or --online)
        """
        if '--offline' in args:
            agents = get_offline_agents()
            self.console.print("\nAgents that have been installed:")
            for agent in agents:
                self.console.print(f"- {agent}")
        elif '--online' in args:
            agents = get_online_agents()
            self.console.print("\nAvailable agents on the agenthub:")
            for agent in agents:
                self.console.print(f"- {agent}")
        else:
            self.console.print("[red]Invalid parameter. Use --offline or --online[/red]")
            self.console.print("Example: list agents --offline")

    def run(self):
        """Main loop: mount the semantic file system, then read-eval-print.

        First asks whether to mount the AIOS semantic file system at a custom
        directory (default ``<cwd>/root``), then loops on user commands until
        'exit' or EOF. Ctrl-C cancels the current line, not the terminal.
        """
        welcome_msg = Text("Welcome to AIOS Terminal! Type 'help' for available commands.", style="bold cyan")
        self.console.print(Panel(welcome_msg, border_style="green"))
        root_dir = self.current_dir + "/root"
        # Mount-location dialog: loop until the user answers 'y' or 'n'.
        while True:
            mount_choice = self.session.prompt(self.get_prompt(extra_str=f"Do you want to mount AIOS Semantic File System to a specific directory you want? By default, it will be mounted at {root_dir}. [y/n] "))
            if mount_choice == 'y':
                root_dir = self.session.prompt(self.get_prompt(extra_str=f"Enter the absolute path of the directory to mount: "))
                break
            elif mount_choice == 'n':
                # self.console.print("[red]Storage not mounted. Some features may be unavailable.[/red]")
                break
            else:
                self.console.print("[red]Invalid input. Please enter 'y' or 'n'.[/red]")
        # breakpoint()
        # Mounts even when the user answered 'n' — only the custom path is
        # skipped, not the mount itself.
        mount(agent_name="terminal", root_dir=root_dir)
        mounted_message = Text(f"The semantic file system is mounted at {root_dir}", style="bold cyan")
        self.console.print(mounted_message)
        # Command loop.
        while True:
            try:
                command = self.session.prompt(self.get_prompt())
                if command == 'exit':
                    self.console.print("[yellow]Goodbye! 👋[/yellow]")
                    break
                if command == 'help':
                    self.display_help()
                    continue
                if command.startswith('list agents'):
                    args = command[len('list agents'):].strip()
                    self.handle_list_agents(args)
                    continue
                # Anything else is treated as a natural-language file
                # operation and sent to the kernel at the configured base_url.
                command_response = llm_operate_file(
                    agent_name="terminal", messages=[{"role": "user", "content": command}], tools=[], base_url=config.get('kernel', 'base_url')
                )
                # NOTE(review): assumes llm_operate_file returns printable
                # text — confirm against the cerebrum.llm.apis contract.
                command_output = Text(command_response, style="bold green")
                self.console.print(command_output)
                # response = self._post_semantic_command(command)
                # if cmd in self.commands:
                #     self.commands[cmd](*args)
                # else:
                #     self.console.print(f"[red]Unknown command: {cmd}[/red]")
            except KeyboardInterrupt:
                continue
            except EOFError:
                break
            except Exception as e:
                self.console.print(f"[red]Error: {str(e)}[/red]")
if __name__ == "__main__":
    # Script entry point: create and run the interactive AIOS terminal.
    terminal = AIOSTerminal()
    terminal.run()