From 1ce23679d920704ebd16448be608036836572319 Mon Sep 17 00:00:00 2001 From: dongyuanjushi Date: Sun, 27 Apr 2025 17:40:11 -0400 Subject: [PATCH] refactor llm routing and update tests for routers (#491) * update * update * update launch * add more error messages for adapter * fix smart routing * update config * update * update test * update test * update test * update workflow * update * update * update * update * update * update * update * update --- .github/workflows/cancel-workflow.yml | 21 + .github/workflows/test-ollama.yml | 90 ++ .github/workflows/test_ollama.yml | 156 --- aios/config/config.yaml | 20 +- aios/config/config.yaml.example | 44 +- aios/config/config_manager.py | 9 + aios/llm_core/adapter.py | 1079 +++++++++++++------ aios/llm_core/llm_cost_map.json | 11 - aios/llm_core/routing.py | 740 +++++-------- aios/llm_core/utils.py | 16 +- aios/scheduler/fifo_scheduler.py | 132 +-- aios/scheduler/rr_scheduler.py | 99 +- aios/syscall/syscall.py | 8 +- install/install.sh | 2 - requirements.txt | 3 +- runtime/launch.py | 27 +- tests/modules/llm/ollama/test_single.py | 117 ++ tests/modules/llm/openai/test_concurrent.py | 161 +++ tests/modules/llm/openai/test_single.py | 117 ++ tests/modules/llm/test_ollama.py | 62 -- tests/modules/llm/test_openai.py | 112 -- 21 files changed, 1716 insertions(+), 1310 deletions(-) create mode 100644 .github/workflows/cancel-workflow.yml create mode 100644 .github/workflows/test-ollama.yml delete mode 100644 .github/workflows/test_ollama.yml delete mode 100644 aios/llm_core/llm_cost_map.json create mode 100644 tests/modules/llm/ollama/test_single.py create mode 100644 tests/modules/llm/openai/test_concurrent.py create mode 100644 tests/modules/llm/openai/test_single.py delete mode 100644 tests/modules/llm/test_ollama.py delete mode 100644 tests/modules/llm/test_openai.py diff --git a/.github/workflows/cancel-workflow.yml b/.github/workflows/cancel-workflow.yml new file mode 100644 index 0000000..ae56e25 --- /dev/null +++ b/.github/workflows/cancel-workflow.yml @@ -0,0 +1,21 @@ +name: Cancel PR Workflows on Merge + +on: + pull_request_target: + types: [closed] # runs when the PR is closed (merged or not) + +permissions: + actions: write # lets the workflow cancel other runs + +jobs: + cancel: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + steps: + - name: Cancel previous PR runs + uses: styfle/cancel-workflow-action@v1 + with: + workflow_id: all # cancel every workflow in this repo + access_token: ${{ github.token }} # default token is fine + pr_number: ${{ github.event.pull_request.number }} + ignore_sha: true # required for pull_request.closed events diff --git a/.github/workflows/test-ollama.yml b/.github/workflows/test-ollama.yml new file mode 100644 index 0000000..d794903 --- /dev/null +++ b/.github/workflows/test-ollama.yml @@ -0,0 +1,90 @@ +# .github/workflows/test-ollama.yml +name: Test Ollama + +on: + pull_request: + branches: [ "main" ] + push: + branches: [ "main" ] + +permissions: + contents: read + actions: write + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + build: + runs-on: ubuntu-latest + + env: + OLLAMA_MAX_WAIT: 60 + KERNEL_MAX_WAIT: 60 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + + - name: Clone Cerebrum + uses: sudosubin/git-clone-action@v1.0.1 + with: + repository: agiresearch/Cerebrum + path: Cerebrum + + - name: Install Cerebrum (editable) + run: python -m pip install -e Cerebrum/ + + - name: Install Ollama + run: curl -fsSL https://ollama.com/install.sh | sh + + - name: Pull Ollama model + run: ollama pull qwen2.5:7b + + - name: Start Ollama + run: | + ollama serve > ollama-llm.log 2>&1 & + for ((i=1;i<=${OLLAMA_MAX_WAIT};i++)); do + curl -s http://localhost:11434/api/version && echo "Ollama ready" && break + echo "Waiting for Ollama… ($i/${OLLAMA_MAX_WAIT})"; sleep 1 + done + curl -s http://localhost:11434/api/version > /dev/null \ + || { echo "❌ Ollama failed to start"; exit 1; } + + - name: Start AIOS kernel + run: | + bash runtime/launch_kernel.sh > kernel.log 2>&1 & + + - name: Run tests + run: | + mkdir -p test_results + mapfile -t TESTS < <(find . -type f -path "*/llm/ollama/*" -name "*.py") + if [ "${#TESTS[@]}" -eq 0 ]; then + echo "⚠️ No llm/ollama tests found – skipping." + exit 0 + fi + for t in "${TESTS[@]}"; do + echo "▶️ Running $t" + python "$t" | tee -a test_results/ollama_tests.log + echo "----------------------------------------" + done + + - name: Upload logs + if: always() + uses: actions/upload-artifact@v4 + with: + name: logs + path: | + ollama-llm.log + kernel.log + test_results/ diff --git a/.github/workflows/test_ollama.yml b/.github/workflows/test_ollama.yml deleted file mode 100644 index 4b451fd..0000000 --- a/.github/workflows/test_ollama.yml +++ /dev/null @@ -1,156 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - -name: AIOS Application - -on: - push: - branches: [ "main" ] - pull_request: - branches: [ "main" ] - -permissions: - contents: read - -jobs: - build: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: "3.10" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - - name: Git Clone Action - # You may pin to the exact commit or the version. - # uses: sudosubin/git-clone-action@8a93ce24d47782e30077508cccacf8a05a891bae - uses: sudosubin/git-clone-action@v1.0.1 - with: - # Repository owner and name. Ex: sudosubin/git-clone-action - repository: agiresearch/Cerebrum - path: Cerebrum - - - name: Install cerebrum special edition - run: | - python -m pip install -e Cerebrum/ - - - name: Download and install Ollama - run: | - curl -fsSL https://ollama.com/install.sh | sh - - - name: Pull Ollama models - run: | - ollama pull qwen2.5:7b - - - name: Run Ollama serve - run: | - ollama serve 2>&1 | tee ollama-llm.log & - # Wait for ollama to start - for i in {1..30}; do - if curl -s http://localhost:11434/api/version > /dev/null; then - echo "Ollama is running" - break - fi - echo "Waiting for ollama to start... ($i/30)" - sleep 1 - done - # Verify ollama is running - curl -s http://localhost:11434/api/version || (echo "Failed to start ollama" && exit 1) - - - name: Run AIOS kernel in background - run: | - bash runtime/launch_kernel.sh &>logs & - KERNEL_PID=$! - - # Set maximum wait time (60 seconds) - max_wait=60 - start_time=$SECONDS - - # Dynamically check if the process is running until it succeeds or times out - while true; do - if ! ps -p $KERNEL_PID > /dev/null; then - echo "Kernel process died. Checking logs:" - cat logs - exit 1 - fi - - if nc -z localhost 8000; then - if curl -s http://localhost:8000/health > /dev/null; then - echo "Kernel successfully started and healthy" - break - fi - fi - - # Check if timed out - elapsed=$((SECONDS - start_time)) - if [ $elapsed -ge $max_wait ]; then - echo "Timeout after ${max_wait} seconds. Kernel failed to start properly." - cat logs - exit 1 - fi - - echo "Waiting for kernel to start... (${elapsed}s elapsed)" - sleep 1 - done - - - name: Run all tests - run: | - # Create test results directory - mkdir -p test_results - - # Function to check if a path contains agent or llm - contains_agent_or_llm() { - local dir_path=$(dirname "$1") - if [[ "$dir_path" == *"agent"* || "$dir_path" == *"llm"* ]]; then - return 0 # True in bash - else - return 1 # False in bash - fi - } - - # Process test files - find tests -type f -name "*.py" | while read -r test_file; do - if contains_agent_or_llm "$test_file"; then - # For agent or llm directories, only run ollama tests - if [[ "$test_file" == *"ollama"* ]]; then - echo "Running Ollama test in agent/llm directory: $test_file" - python $test_file | tee -a ollama_tests.log - echo "----------------------------------------" - fi - else - # For other directories, run all tests - echo "Running test: $test_file" - python $test_file | tee -a all_tests.log - echo "----------------------------------------" - fi - done - - - name: Upload a Build Artifact - if: always() # Upload logs even if job fails - uses: actions/upload-artifact@v4.4.3 - with: - name: logs - path: | - logs - agent.log - - - name: Collect debug information - if: failure() - run: | - echo "=== Kernel Logs ===" - cat logs - echo "=== Environment Variables ===" - env | grep -i api_key || true - echo "=== Process Status ===" - ps aux | grep kernel \ No newline at end of file diff --git a/aios/config/config.yaml b/aios/config/config.yaml index 612f77c..5e980bc 100644 --- a/aios/config/config.yaml +++ b/aios/config/config.yaml @@ -18,9 +18,12 @@ llms: # - name: "gpt-4o-mini" # backend: "openai" + # - name: "gpt-4o" + # backend: "openai" + # Google Models - # - name: "gemini-1.5-flash" + # - name: "gemini-2.0-flash" # backend: "google" @@ -47,26 +50,27 @@ llms: # backend: "vllm" # hostname: "http://localhost:8091" + router: + strategy: "sequential" + bootstrap_url: "https://drive.google.com/file/d/1SF7MAvtnsED7KMeMdW3JDIWYNGPwIwL7/view" - - - log_mode: "console" + log_mode: "console" # choose from [console, file] use_context_manager: false memory: - log_mode: "console" + log_mode: "console" # choose from [console, file] storage: root_dir: "root" use_vector_db: true scheduler: - log_mode: "console" + log_mode: "console" # choose from [console, file] agent_factory: - log_mode: "console" + log_mode: "console" # choose from [console, file] max_workers: 64 server: host: "localhost" - port: 8000 \ No newline at end of file + port: 8000 diff --git a/aios/config/config.yaml.example b/aios/config/config.yaml.example index 40396ec..d2835ee 100644 --- a/aios/config/config.yaml.example +++ b/aios/config/config.yaml.example @@ -6,7 +6,6 @@ api_keys: gemini: "" # Google Gemini API key groq: "" # Groq API key anthropic: "" # Anthropic API key - novita: "" # Novita AI API key huggingface: auth_token: "" # Your HuggingFace auth token for authorized models cache_dir: "" # Your cache directory for saving huggingface models @@ -16,15 +15,15 @@ llms: models: # OpenAI Models - - name: "gpt-4o" - backend: "openai" + # - name: "gpt-4o" + # backend: "openai" - - name: "gpt-4o-mini" - backend: "openai" + # - name: "gpt-4o-mini" + # backend: "openai" # Google Models - - name: "gemini-2.0-flash" - backend: "google" + # - name: "gemini-2.0-flash" + # backend: "google" # Anthropic Models @@ -32,9 +31,9 @@ llms: # backend: "anthropic" # Ollama Models - - name: "qwen2.5:72b" + - name: "qwen2.5:7b" backend: "ollama" - hostname: "http://localhost:8091" # Make sure to run ollama server + hostname: "http://localhost:11434" # Make sure to run ollama server # HuggingFace Models # - name: "meta-llama/Llama-3.1-8B-Instruct" @@ -43,7 +42,7 @@ llms: # eval_device: "cuda:0" # Device for model evaluation # vLLM Models - # To use vllm as backend, you need to install vllm and run the vllm server: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html + # To use vllm as backend, you need to install vllm and run the vllm server https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html # An example command to run the vllm server is: # vllm serve meta-llama/Llama-3.2-3B-Instruct --port 8091 # - name: "meta-llama/Llama-3.1-8B-Instruct" @@ -51,34 +50,33 @@ llms: # hostname: "http://localhost:8091" # SGLang Models - # To use sglang as backend, you need to install sglang and run the sglang server: https://docs.sglang.ai/backend/openai_api_completions.html - # python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-72B-Instruct --grammar-backend outlines --tool-call-parser qwen25 --host 127.0.0.1 --port 30001 --tp 4 --disable-custom-all-reduce + # python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 127.0.0.1 --port 30001 - - name: "Qwen/Qwen2.5-72B-Instruct" - backend: "sglang" - hostname: "http://localhost:30001/v1" + # - name: "Qwen/Qwen2.5-VL-72B-Instruct" + # backend: "sglang" + # hostname: "http://localhost:30001/v1" - # Novita Models - - name: "meta-llama/llama-4-scout-17b-16e-instruct" - backend: "novita" + router: + strategy: "smart" # choose from [sequential, smart] + bootstrap_url: "https://drive.google.com/file/d/1SF7MAvtnsED7KMeMdW3JDIWYNGPwIwL7/view" - log_mode: "console" + log_mode: "console" # choose from [console, file] use_context_manager: false memory: - log_mode: "console" + log_mode: "console" # choose from [console, file] storage: root_dir: "root" use_vector_db: true scheduler: - log_mode: "console" + log_mode: "console" # choose from [console, file] agent_factory: - log_mode: "console" + log_mode: "console" # choose from [console, file] max_workers: 64 server: - host: "localhost" + host: "0.0.0.0" port: 8000 diff --git a/aios/config/config_manager.py b/aios/config/config_manager.py index 0c3d276..ddb2b03 100644 --- a/aios/config/config_manager.py +++ b/aios/config/config_manager.py @@ -162,6 +162,15 @@ class ConfigManager: """ return self.config.get('llms', {}) + def get_router_config(self) -> dict: + """ + Retrieves the router configuration settings. + + Returns: + dict: Dictionary containing router configurations + """ + return self.config.get("llms", {}).get("router", {}) + def get_storage_config(self) -> dict: """ Retrieves the storage configuration settings. diff --git a/aios/llm_core/adapter.py b/aios/llm_core/adapter.py index eb93b25..293099f 100644 --- a/aios/llm_core/adapter.py +++ b/aios/llm_core/adapter.py @@ -14,12 +14,18 @@ from aios.config.config_manager import config from dataclasses import dataclass import logging from typing import Any +import concurrent.futures +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +import traceback +import litellm +from .utils import check_availability_for_selected_llm_lists # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -from openai import OpenAI +from openai import OpenAI, APIError, RateLimitError, AuthenticationError, BadRequestError, APITimeoutError, APIConnectionError @dataclass class LLMConfig: @@ -32,6 +38,7 @@ class LLMConfig: max_gpu_memory (Optional[str]): Maximum GPU memory allocation eval_device (Optional[str]): Device for model evaluation hostname (Optional[str]): Hostname for the LLM service + api_key (Optional[str]): API key for the LLM Example: ```python @@ -46,6 +53,7 @@ class LLMConfig: max_gpu_memory: Optional[str] = None eval_device: Optional[str] = None hostname: Optional[str] = None + api_key: Optional[str] = None class LLMAdapter: """ @@ -106,10 +114,22 @@ class LLMAdapter: self._setup_api_keys() self._initialize_llms() + routing_strategy = config.get_router_config().get("strategy", RouterStrategy.Sequential) + + # breakpoint() + if routing_strategy == RouterStrategy.Sequential: - self.strategy = SequentialRouting(self.llm_configs) + self.router = SequentialRouting( + llm_configs=self.llm_configs + ) elif routing_strategy == RouterStrategy.Smart: - self.strategy = SmartRouting(self.llm_configs) + self.router = SmartRouting( + llm_configs=self.llm_configs, + bootstrap_url=config.get_router_config().get("bootstrap_url", None) + ) + + else: + raise ValueError(f"Invalid routing strategy: {routing_strategy}") def _setup_api_keys(self) -> None: """ @@ -148,214 +168,457 @@ class LLMAdapter: def _initialize_llms(self) -> None: """Initialize LLM backends based on configurations.""" - for config in self.llm_configs: - llm_name = config.get("name") - llm_backend = config.get("backend") - max_gpu_memory = config.get("max_gpu_memory") - eval_device = config.get("eval_device") - hostname = config.get("hostname") - - if not llm_backend: - continue - + initialized_configs = [] + for config_dict in self.llm_configs: try: - self._initialize_single_llm( - LLMConfig( - name=llm_name, - backend=llm_backend, - max_gpu_memory=max_gpu_memory, - eval_device=eval_device, - hostname=hostname - ) + # Validate config dict structure if necessary before creating LLMConfig + llm_config = LLMConfig( + name=config_dict.get("name"), + backend=config_dict.get("backend"), + max_gpu_memory=config_dict.get("max_gpu_memory"), + eval_device=config_dict.get("eval_device"), + hostname=config_dict.get("hostname"), + api_key=config_dict.get("api_key") ) + if not llm_config.name or not llm_config.backend: + logger.warning(f"Skipping incomplete LLM config: {config_dict}") + continue + + initialized_model = self._initialize_single_llm(llm_config) + if initialized_model: + self.llms.append(initialized_model) + initialized_configs.append(llm_config) + logger.info(f"Successfully initialized LLM: {llm_config.name} ({llm_config.backend})") + else: + # _initialize_single_llm logs the error, just skip adding + pass + except Exception as e: - logger.error(f"Failed to initialize LLM {llm_name}: {e}") + logger.error(f"Failed to process LLM configuration {config_dict.get('name', 'UNKNOWN')}: {e}", exc_info=True) - for llm_config in self.llm_configs: - logger.info(f"Initialized LLM: {llm_config}") + # Update self.llm_configs to only contain successfully initialized ones + self.llm_configs = initialized_configs + + self.available_llm_names = [llm_config.name for llm_config in self.llm_configs] + + if not self.llms: + logger.error("No LLMs were successfully initialized. LLMAdapter may not function.") + else: + logger.info(f"Total successfully initialized LLMs: {len(self.llms)}") - def _initialize_single_llm(self, config: LLMConfig) -> None: + def _initialize_single_llm(self, config: LLMConfig) -> Optional[Union[str, HfLocalBackend, OpenAI]]: """ - Initialize a single LLM based on its configuration. + Initialize a single LLM based on its configuration. Logs errors and returns None on failure. Args: config: Configuration for the LLM - Example: - ```python - config = LLMConfig( - name="mistral-7b", - backend="hflocal", - max_gpu_memory="12GiB" - ) - adapter._initialize_single_llm(config) - ``` - - Raises: - ValueError: If required API keys are missing + Returns: + Initialized model instance or identifier, or None if initialization fails. """ - match config.backend: - case "huggingface": - self.llms.append(HfLocalBackend( - model_name=config.name, - max_gpu_memory=config.max_gpu_memory, - eval_device=config.eval_device - )) + try: + match config.backend: + case "huggingface" | "hflocal": # Allow alias + # HF specific API key check if needed, though HF_TOKEN env var is typical + if not os.getenv("HF_TOKEN"): + logger.warning(f"HF_TOKEN environment variable not set. May impact private model access for {config.name}") + # Add try-except around HfLocalBackend initialization if it can raise specific errors + return HfLocalBackend( + model_name=config.name, + max_gpu_memory=config.max_gpu_memory, + eval_device=config.eval_device + ) - # due to the compatibility issue of the vllm and sglang in the litellm, we use the openai backend instead - case "vllm": - self.llms.append(OpenAI( - base_url=config.hostname, - api_key="sk-1234" - )) - - case "sglang": - self.llms.append(OpenAI( - base_url=config.hostname, - api_key="sk-1234" - )) - - case _: - - # Change the backend name of google to gemini to fit the litellm - if config.backend == "google": - config.backend = "gemini" - - prefix = f"{config.backend}/" - if not config.name.startswith(prefix): - self.llms.append(prefix + config.name) + case "vllm" | "sglang": + # These use OpenAI compatible endpoints + if not config.hostname: + logger.error(f"Hostname (base_url) required for {config.backend} backend ({config.name}) but not provided.") + return None + # OpenAI client init can fail if URL is malformed, though less common. + return OpenAI( + base_url=config.hostname, + api_key=config.api_key or "sk-placeholder" # Use provided key or a placeholder + ) + + case _: + # Handle LiteLLM supported backends + backend_name = config.backend + if backend_name == "google": + backend_name = "gemini" # LiteLLM uses 'gemini' - def _handle_completion_error(self, error: Exception) -> LLMResponse: + # Check for necessary API keys via environment variables for common LiteLLM backends + key_var_map = { + "openai": "OPENAI_API_KEY", "gemini": "GEMINI_API_KEY", + "anthropic": "ANTHROPIC_API_KEY", "groq": "GROQ_API_KEY", + # Add others as needed + } + if backend_name in key_var_map and not os.getenv(key_var_map[backend_name]): + # Allow explicit key from config as fallback + if config.api_key: + os.environ[key_var_map[backend_name]] = config.api_key + logger.info(f"Using API key from config for {backend_name}") + else: + logger.warning(f"Required API key environment variable '{key_var_map[backend_name]}' not set for {config.name} ({backend_name}).") + # Depending on strictness, you might return None here or let LiteLLM handle it. + # Let LiteLLM handle it for now, it will raise an error later if needed. + + # Construct the model string for LiteLLM + prefix = f"{backend_name}/" + if config.name.startswith(prefix): + return config.name # Already prefixed + else: + return prefix + config.name + + except Exception as e: + logger.error(f"Error initializing LLM {config.name} ({config.backend}): {e}", exc_info=True) + return None + + def _handle_completion_error(self, error: Exception, model_name: Optional[str] = "Unknown") -> LLMResponse: """ - Handle errors that occur during LLM completion. + Handle errors that occur during LLM completion, mapping them to LLMResponse. Args: error: The exception that occurred + model_name: Name of the model that caused the error Returns: - LLMResponse with appropriate error message - - Example: - ```python - # Input error - error = ValueError("Invalid API key provided: sk-...") - - # Output LLMResponse - { - "response_message": "Error: Invalid or missing API key for the selected model.", - "error": "Invalid API key provided: sk-****..", - "finished": True, - "status_code": 402 - } - ``` + LLMResponse with appropriate error message and status code """ error_msg = str(error) + status_code = 500 # Default internal server error + user_message = f"LLM Error with model '{model_name}': An unexpected error occurred." - # Mask API key in error message - if "API key provided:" in error_msg: - key_start = error_msg.find("API key provided:") + len("API key provided: ") - key_end = error_msg.find(".", key_start) - if key_end == -1: - key_end = error_msg.find(" ", key_start) - if key_end != -1: - api_key = error_msg[key_start:key_end] - masked_key = f"{api_key[:2]}****{api_key[-2:]}" if len(api_key) > 4 else "****" - error_msg = error_msg[:key_start] + masked_key + error_msg[key_end:] + # --- Specific OpenAI/LiteLLM Error Handling --- + # Note: LiteLLM often wraps underlying provider errors. + if isinstance(error, AuthenticationError) or "invalid api key" in error_msg.lower() or "authentication" in error_msg.lower(): + status_code = 401 # Unauthorized + user_message = f"Authentication Error with model '{model_name}': Invalid or missing API key." + # Mask API key if present in the raw error message + try: + # Simple regex to find potential keys (sk-..., gsk_..., etc.) + masked_error_msg = re.sub(r"([a-zA-Z0-9_]{2})[a-zA-Z0-9_.-]+([a-zA-Z0-9_]{4})", r"\1****\2", error_msg) + error_msg = masked_error_msg + except Exception: + pass # Ignore masking errors + elif isinstance(error, RateLimitError) or "rate limit" in error_msg.lower(): + status_code = 429 # Too Many Requests + user_message = f"Rate Limit Exceeded for model '{model_name}'. Please try again later." + elif isinstance(error, BadRequestError) or "bad request" in error_msg.lower() or "invalid parameter" in error_msg.lower() : + status_code = 400 # Bad Request + user_message = f"Invalid Request for model '{model_name}'. Check your input parameters (e.g., messages, tools format)." + elif isinstance(error, APITimeoutError) or "timeout" in error_msg.lower(): + status_code = 408 # Request Timeout + user_message = f"Request Timed Out for model '{model_name}'. The operation took too long." + elif isinstance(error, APIConnectionError) or "connection error" in error_msg.lower(): + status_code = 503 # Service Unavailable + user_message = f"Connection Error with model '{model_name}'. Could not reach the LLM service." + elif isinstance(error, APIError): # General API error from provider + status_code = 502 # Bad Gateway (if error seems provider-side) + user_message = f"API Error from model '{model_name}'. The provider reported an issue." + elif isinstance(error, litellm.exceptions.NotFound): + status_code = 404 # Not Found + user_message = f"Model Not Found: The specified model '{model_name}' could not be found or accessed." + # --- Add more specific exception checks as needed --- + logger.error(f"LLM completion error for model '{model_name}': {error_msg}", exc_info=True) # Log full traceback - if "Invalid API key" in error_msg or "API key not found" in error_msg: - return LLMResponse( - response_message="Error: Invalid or missing API key for the selected model.", - error=error_msg, - finished=True, - status_code=402 - ) - return LLMResponse( - response_message=f"LLM Error: {error_msg}", - error=error_msg, + response_message=None, + error=user_message, finished=True, - status_code=500 + status_code=status_code ) - - def execute_llm_syscall( + + def execute_llm_syscalls( self, - llm_syscall, - temperature: float = 0.0 - ) -> LLMResponse: + llm_syscalls: List[LLMQuery], + ) -> None: """ - Address request sent from the agent. + Execute a batch of LLM syscalls using the configured routing strategy and parallel execution. Args: - llm_syscall: LLMSyscall object containing the request - temperature: Parameter to control output randomness + llm_syscalls: List of LLMQuery objects Returns: - LLMResponse containing the model's response or error information - - Example: - ```python - # Input - syscall = LLMSyscall( - query=LLMQuery( - messages=[{"role": "user", "content": "Hello!"}], - tools=[{"name": "calculator", "description": "..."}], - message_return_type="json" - ) - ) - - # Output LLMResponse - { - "response_message": "Hello! How can I help you today?", - "finished": True, - "tool_calls": None, - "error": None, - "status_code": 200 - } - ``` + List of LLMResponse objects corresponding to each input syscall. + If an error occurs *before* dispatching a syscall, its corresponding response will be an error LLMResponse. """ - try: - messages = llm_syscall.query.messages - tools = llm_syscall.query.tools - message_return_type = llm_syscall.query.message_return_type - selected_llms = llm_syscall.query.llms if llm_syscall.query.llms else self.llm_configs - response_format = llm_syscall.query.response_format - temperature = llm_syscall.query.temperature - max_tokens = llm_syscall.query.max_new_tokens + num_syscalls = len(llm_syscalls) + if num_syscalls == 0: + return [] + + start_exec_time = time.time() + logger.info(f"Starting batch execution for {num_syscalls} LLM syscalls...") + + if not self.llms: + logger.error("Cannot execute syscalls: No LLMs were successfully initialized.") + error_response = LLMResponse( + response_message=None, + error="System Error: No LLM backends available.", + finished=True, + status_code=503 # Service Unavailable + ) + for llm_syscall in llm_syscalls: + llm_syscall.set_status("done") + llm_syscall.set_response(error_response) + llm_syscall.set_end_time(time.time()) + llm_syscall.event.set() + return + + + selected_llm_lists = [syscall.query.llms for syscall in llm_syscalls] + + for i, selected_llm_list in enumerate(selected_llm_lists): + if not selected_llm_list: + selected_llm_lists[i] = [{"name": llm_config.name, "backend": llm_config.backend} for llm_config in self.llm_configs] + + selected_llm_lists_availability = check_availability_for_selected_llm_lists(self.available_llm_names, selected_llm_lists) + + executable_llm_syscalls = [] + available_selected_llm_lists = [] + + # breakpoint() + + for i, selected_llm_list_availability in enumerate(selected_llm_lists_availability): + if not selected_llm_list_availability: + logger.error(f"Selected LLMs are not available for syscall at index {i}") + llm_syscall = llm_syscalls[i] + llm_syscall.set_status("done") + llm_syscall.set_response(LLMResponse(response_message=None, error="Selected LLMs are not all available. Please check the available LLMs.", finished=True, status_code=500)) + llm_syscall.set_end_time(time.time()) + llm_syscall.event.set() + else: + executable_llm_syscalls.append(llm_syscalls[i]) + available_selected_llm_lists.append(selected_llm_lists[i]) + + queries = [syscall.query.messages for syscall in executable_llm_syscalls] + + model_idxs = self.router.get_model_idxs(available_selected_llm_lists, queries) + + grouped_tasks = defaultdict(list) + + for i, llm_syscall in enumerate(executable_llm_syscalls): + model_idx = model_idxs[i] + + if model_idx not in grouped_tasks: + grouped_tasks[model_idx] = [] + grouped_tasks[model_idx].append(llm_syscall) + + # --- 3. Parallel Execution using ThreadPoolExecutor --- + if not grouped_tasks: + logger.warning("No tasks were grouped for execution.") + # Fill remaining Nones with a generic routing error if needed + error_response = LLMResponse(response_message=None, error="System Error: LLM routing failed.", finished=True, status_code=500) + for llm_syscall in executable_llm_syscalls: + llm_syscall.set_status("done") + llm_syscall.set_response(error_response) + llm_syscall.set_end_time(time.time()) + llm_syscall.event.set() + return + + # Determine max workers for the outer executor (managing groups) + max_group_workers = len(grouped_tasks) + logger.info(f"Processing {len(grouped_tasks)} model groups...") + + # Dictionary to store results keyed by original index + results_dict = {} + + try: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_group_workers, thread_name_prefix="LLMGroupWorker") as group_executor: + future_to_group = { + group_executor.submit(self._process_batch_for_model, model_idx, tasks): model_idx + for model_idx, tasks in grouped_tasks.items() + } + + for future in concurrent.futures.as_completed(future_to_group): + model_idx_completed = future_to_group[future] + model_config = self.llm_configs[model_idx_completed] # Get config using index + model_name_completed = model_config.name + + try: + # Worker returns a list of (original_index, result_tuple) + # where result_tuple is (syscall_object, LLMResponse) + group_results = future.result() # Raises exceptions from the worker function + + for original_idx, (syscall_obj, response) in group_results: + results_dict[original_idx] = response # Store the LLMResponse + # Update syscall object status based on response + if response.finished: + syscall_obj.set_status("done") + else: + # This case implies interruption or streaming, adapt if needed + syscall_obj.set_status("suspend") # Or another status? + syscall_obj.set_response(response) + syscall_obj.set_end_time(time.time()) + syscall_obj.event.set() # Notify anyone waiting on this specific syscall + + logger.info(f"Group processing completed for model '{model_name_completed}' (Index {model_idx_completed}). Processed {len(group_results)} tasks.") + + except Exception as exc: + # Handle errors occurring within the _process_batch_for_model worker + logger.error(f"Worker thread for model group '{model_name_completed}' (Index {model_idx_completed}) failed: {exc}", exc_info=True) + # Assign an error response to all tasks originally assigned to this failed group worker + error_response = LLMResponse( + response_message=None, + error=f"Worker exception: {exc}", + finished=True, + status_code=500 + ) + tasks_in_failed_group = grouped_tasks[model_idx_completed] + for failed_original_idx, failed_syscall in tasks_in_failed_group: + if failed_original_idx not in results_dict: # Avoid overwriting if processed before crash + results_dict[failed_original_idx] = error_response + failed_syscall.set_status("error") + failed_syscall.set_response(error_response) + failed_syscall.set_end_time(time.time()) + failed_syscall.event.set() + + + except Exception as outer_exc: + # Handle errors during ThreadPoolExecutor setup or management + logger.error(f"Critical error during batch execution setup: {outer_exc}", exc_info=True) + error_response = LLMResponse( + response_message=None, + error=str(outer_exc), + finished=True, + status_code=500 + ) + # Assign error to all tasks that haven't received a result yet + for llm_syscall in executable_llm_syscalls: + llm_syscall.set_status("error") + llm_syscall.set_response(error_response) + llm_syscall.set_end_time(time.time()) + llm_syscall.event.set() + + end_exec_time = time.time() + logger.info(f"Batch execution finished for {num_syscalls} syscalls in {end_exec_time - start_exec_time:.2f} seconds.") + + return + + + def _process_batch_for_model(self, model_idx, tasks_with_indices): + """ + Process a batch of tasks assigned to a specific model index using another ThreadPoolExecutor. + + Args: + model_idx: Index of the target model configuration. + tasks_with_indices: List of tuples: [(original_index, llm_syscall), ...] + + Returns: + List of tuples: [(original_index, (llm_syscall, LLMResponse)), ...] + Raises exceptions if the inner execution fails critically. + """ + model_config = self.llm_configs[model_idx] + logger.info(f"Starting processing for {len(tasks_with_indices)} tasks on model '{model_config.name}' (Index {model_idx}).") + + batch_results = [] + max_workers = len(tasks_with_indices) + + try: + with ThreadPoolExecutor(max_workers=max_workers, thread_name_prefix=f"LLMWorker_M{model_idx}") as executor: + # Map future to original index for result correlation + future_to_original_index = { + executor.submit(self.execute_llm_syscall, model_idx, llm_syscall): original_idx + for original_idx, llm_syscall in enumerate(tasks_with_indices) + } + + for future in concurrent.futures.as_completed(future_to_original_index): + original_idx = future_to_original_index[future] + syscall_obj = tasks_with_indices[original_idx] + + try: + # execute_llm_syscall now returns (syscall_obj, LLMResponse) + result_tuple = future.result() + batch_results.append((original_idx, result_tuple)) + except Exception as exc: + # Handle errors from individual execute_llm_syscall calls + logger.error(f"Error executing syscall for original index {original_idx} on model '{model_config.name}': {exc}", exc_info=False) # Log less detail maybe + # Create an error response for this specific task + error_response = self._handle_completion_error(exc, model_config.name) + batch_results.append((original_idx, (syscall_obj, error_response))) # Return error response associated with the syscall + + except Exception as batch_exc: + # Handle errors during the inner ThreadPoolExecutor setup/management + logger.error(f"Critical error processing batch for model '{model_config.name}': {batch_exc}", exc_info=True) + # Re-raise the exception to be caught by the outer executor's error handling + raise batch_exc + + logger.info(f"Finished processing batch for model '{model_config.name}'.") + return batch_results + + + def execute_llm_syscall( + self, + model_idx, + llm_syscall, + temperature: float = 0.0 + ) -> tuple[LLMQuery, LLMResponse]: # Return tuple of (syscall, response) + """ + Execute a single LLM syscall request. Handles setup, execution, and response processing. + + Args: + model_idx: Index of the LLM configuration to use. + llm_syscall: LLMQuery object containing the request. + + Returns: + A tuple containing the original LLMQuery object and the resulting LLMResponse. + """ + model_config = self.llm_configs[model_idx] + model_name = model_config.name + model_identifier = self.llms[model_idx] # This is the actual model object or string ID + api_base = model_config.hostname # Use hostname from the validated config + + try: + # --- Parameter Extraction and Validation --- + try: + messages = llm_syscall.query.messages + tools = llm_syscall.query.tools + message_return_type = llm_syscall.query.message_return_type + response_format = llm_syscall.query.response_format + temperature = llm_syscall.query.temperature if llm_syscall.query.temperature is not None else 1.0 # Default temp if not set + max_tokens = llm_syscall.query.max_new_tokens if llm_syscall.query.max_new_tokens is not None else 1000 # Default max tokens + + # Basic validation + if not messages or not isinstance(messages, list): + raise ValueError("Syscall query must contain a non-empty list of messages.") + # Add more validation as needed (e.g., role/content structure) + + except AttributeError as e: + logger.error(f"Syscall object missing expected attributes: {e}", exc_info=True) + return (llm_syscall, LLMResponse( + response_message=None, + error=f"Missing attribute: {e}", finished=True, status_code=400 + )) + except ValueError as e: + logger.error(f"Syscall validation failed: {e}", exc_info=True) + return (llm_syscall, LLMResponse( + response_message=None, + error=str(e), finished=True, status_code=400 + )) + + llm_syscall.set_status("executing") llm_syscall.set_start_time(time.time()) - - # breakpoint() - - model_idxs = self.strategy.get_model_idxs(selected_llms) - model_idx = model_idxs[0] - - model = self.llms[model_idx] - - model_name = self.llm_configs[model_idx].get("name") - - api_base = self.llm_configs[model_idx].get("hostname", None) - - # breakpoint() - + + # --- Tool Preparation --- + prepared_tools = None if tools: - tools = slash_to_double_underscore(tools) - - # deprecated as the tools are already supported in Litellm completion - # messages = self._prepare_messages( - # model=model, - # messages=messages, - # tools=tools, - # return_type=message_return_type, - # response_format=response_format - # ) + try: + prepared_tools = slash_to_double_underscore(tools) + except Exception as e: + logger.error(f"Error processing tools for syscall: {e}", exc_info=True) + return (llm_syscall, LLMResponse( + response_message=None, + error=f"Tool processing error: {e}", finished=True, status_code=400 + )) + # --- Model Response Generation --- try: completed_response, finished = self._get_model_response( model_name=model_name, - model=model, - messages=messages, - tools=tools, + model=model_identifier, + messages=messages, + tools=prepared_tools, # Use prepared tools llm_syscall=llm_syscall, api_base=api_base, message_return_type=message_return_type, @@ -363,25 +626,48 @@ class LLMAdapter: temperature=temperature, max_tokens=max_tokens ) - except Exception as e: - return self._handle_completion_error(e) + # Handle errors specifically from _get_model_response (API errors, timeouts etc.) + # The exception 'e' here could be a custom exception raised by _get_model_response + # or a standard one. _handle_completion_error should classify it. + logger.warning(f"Model response generation failed for {model_name}: {e}") # Warning level as it's handled + return (llm_syscall, self._handle_completion_error(e, model_name)) - return self._process_response( - completed_response=completed_response, - model=model, - finished=finished, - tools=tools, - message_return_type=message_return_type - ) + # --- Response Processing --- + try: + processed_response = self._process_response( + completed_response=completed_response, + finished=finished, + tools=tools, # Pass original tools for context if needed by processing logic + model=model_identifier, # Pass model identifier if needed + message_return_type=message_return_type + ) + return (llm_syscall, processed_response) + + except Exception as e: + logger.error(f"Failed to process LLM response for {model_name}: {e}", exc_info=True) + return (llm_syscall, LLMResponse( + response_message=None, + error=f"Response processing error: {e}", + finished=True, # Mark as finished even if processing failed + status_code=500 + )) except Exception as e: - return LLMResponse( - response_message=f"System Error: {str(e)}", - error=str(e), + # Catch-all for unexpected errors within the execute_llm_syscall function itself + logger.error(f"Unexpected critical error during syscall execution for {model_name}: {e}", exc_info=True) + # Ensure syscall status is updated if possible before returning + try: + llm_syscall.set_status("error") + except Exception: + pass # Ignore errors during error handling status update + + return (llm_syscall, LLMResponse( + response_message=None, + error=f"Unhandled exception: {str(e)}", finished=True, status_code=500 - ) + )) def _get_model_response( self, @@ -395,195 +681,280 @@ class LLMAdapter: response_format: Optional[Dict[str, Dict]] = None, temperature: float = 1.0, max_tokens: int = 1000 - ) -> Any: + ) -> tuple[Union[str, List, Dict], bool]: # Return type depends on success/tool use """ - Get response from the model. + Get response from the specific model backend. Handles API calls and context management. + Raises exceptions on failure (e.g., API errors, timeouts). Args: - model_name: Name of the model to use - model: The LLM model instance or identifier - messages: Prepared messages - tools: Optional list of tools - temperature: Temperature parameter - llm_syscall: The syscall object - api_base: Optional API base URL - message_return_type: Expected return type ("json" or "text") - response_format: Optional response format specification + model_name: Name of the model (for logging/errors). + model: The LLM model instance or LiteLLM identifier string. + messages: Prepared messages. + tools: Optional list of tools (with double underscores). + llm_syscall: The syscall object (for context manager). + api_base: Optional API base URL. + message_return_type: Expected return type ("json" or "text"). + response_format: Optional response format specification. + temperature: Temperature parameter. + max_tokens: Max tokens parameter. Returns: - Tuple of (model_response, finished_flag) + Tuple of (model_response, finished_flag). Model response can be str, list (tool calls), or dict (json). + + Raises: + Various exceptions from API calls (APIError, Timeout, etc.) or context manager. """ - # Handle context management if enabled - if self.use_context_manager: - pid = llm_syscall.get_pid() - time_limit = llm_syscall.get_time_limit() - completed_response, finished = self.context_manager.generate_response_with_interruption( - model_name=model_name, - model=model, - messages=messages, - tools=tools, - pid=pid, - time_limit=time_limit, - message_return_type=message_return_type, - response_format=response_format, - temperature=temperature, - max_tokens=max_tokens - ) - - # breakpoint() - - return completed_response, finished - - # Process request without context management - completion_kwargs = { - "messages": messages, - "temperature": temperature, - "max_tokens": max_tokens - } - - # Add tools if provided - if tools: - completion_kwargs["tools"] = tools - completion_kwargs["tool_choice"] = "required" - - # Add JSON formatting if requested - if message_return_type == "json": - # completion_kwargs["format"] = "json" - if isinstance(model, str): - completion_kwargs["format"] = "json" # do not use format for sglang and vllm - - if response_format: - completion_kwargs["response_format"] = response_format - - # Add API base if provided - if api_base: - if isinstance(model, str): - completion_kwargs["api_base"] = api_base - - # Handle different model types - # breakpoint() - if isinstance(model, str): - # Use litellm completion for string model identifiers - # breakpoint() - completed_response = completion(model=model, **completion_kwargs) - - # breakpoint() - - if tools: - completed_response = decode_litellm_tool_calls(completed_response) - - return completed_response, True + start_time = time.time() + logger.debug(f"[{model_name}] Getting model response. Tools: {'Yes' if tools else 'No'}, Type: {message_return_type}, Temp: {temperature}, MaxTokens: {max_tokens}") + + try: + # --- Context Management Handling --- + if self.use_context_manager and self.context_manager: + pid = llm_syscall.get_pid() + time_limit = llm_syscall.get_time_limit() + # Assuming generate_response_with_interruption handles its own internal errors + # or propagates them up. Add try-except here if it can fail before calling the model. + logger.debug(f"[{model_name}] Using context manager (PID: {pid}, Limit: {time_limit}s)") + completed_response, finished = self.context_manager.generate_response_with_interruption( + model_name=model_name, + model=model, # Pass the actual model object/ID + messages=messages, + tools=tools, # Pass processed tools + pid=pid, + time_limit=time_limit, + message_return_type=message_return_type, + response_format=response_format, + temperature=temperature, + max_tokens=max_tokens, + api_base=api_base # Pass api_base to context manager if needed + ) + # The context manager should return the raw response (str, dict, or tool call list) + # It might raise exceptions if interrupted or if the underlying call fails. + logger.debug(f"[{model_name}] Context manager returned. Finished: {finished}") + return completed_response, finished + + # --- Direct Model Call Handling (No Context Manager) --- else: - return completed_response.choices[0].message.content, True - - elif isinstance(model, OpenAI): - # Use OpenAI client for OpenAI model instances - # (Used for vllm and sglang endpoints due to litellm compatibility issues) - # breakpoint() - completed_response = model.chat.completions.create( - model=model_name, - **completion_kwargs - ) - - if tools: - breakpoint() - completed_response = decode_litellm_tool_calls(completed_response) - return completed_response, True - else: - return completed_response.choices[0].message.content, True - - elif isinstance(model, HfLocalBackend): - # Use Hugging Face local backend for model instances - # (Used for local model instances) - # breakpoint() - if tools: - new_messages = merge_messages_with_tools(messages, tools) - # breakpoint() - completion_kwargs["messages"] = new_messages - elif message_return_type == "json": - new_messages = merge_messages_with_response_format(messages, response_format) - completion_kwargs["messages"] = new_messages - # breakpoint() - completed_response = model.generate(**completion_kwargs) - return completed_response, True - - # For other model types (should be handled by their respective classes) - else: - raise ValueError(f"Unsupported model type: {type(model)}") + logger.debug(f"[{model_name}] Calling model directly.") + completion_kwargs = { + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens + } + + # Add tools if provided (use processed tool names) + if tools: + completion_kwargs["tools"] = tools + completion_kwargs["tool_choice"] = "auto" # Or "required" if always needed? Let model decide. + + # Add JSON formatting if requested + # Note: Some models handle "format" kwarg, others "response_format". LiteLLM standardizes. + if message_return_type == "json": + # Standard way via response_format + completion_kwargs["response_format"] = {"type": "json_object"} + if response_format: # Allow more specific schema if provided + # Be careful: merging might be complex depending on provider support + logger.warning(f"[{model_name}] Overriding standard JSON format with provided response_format schema. Compatibility depends on model.") + completion_kwargs["response_format"] = response_format + + # Add API base if provided (primarily for LiteLLM string models) + if api_base and isinstance(model, str): + completion_kwargs["api_base"] = api_base + logger.debug(f"[{model_name}] Using api_base: {api_base}") + + + # --- Execute Call Based on Model Type --- + if isinstance(model, str): + # Use LiteLLM completion + logger.debug(f"[{model_name}] Calling litellm.completion for model: {model}") + # LiteLLM raises specific exceptions on failure + response = litellm.completion(model=model, **completion_kwargs) + logger.debug(f"[{model_name}] LiteLLM response received.") + + # Extract content or tool calls from LiteLLM response + message = response.choices[0].message + if message.tool_calls: + # Decode directly here or let _process_response handle? + # Let's return the raw tool calls for _process_response + # logger.debug(f"[{model_name}] LiteLLM returned tool calls: {message.tool_calls}") + # decoded_calls = decode_litellm_tool_calls(response) # Assuming this returns the desired list format + return message.tool_calls, True # Return raw tool calls + else: + # logger.debug(f"[{model_name}] LiteLLM returned content: {message.content[:100]}...") + return message.content, True + + elif isinstance(model, OpenAI): + # Use OpenAI client (for vLLM, SGLang, or direct OpenAI) + logger.debug(f"[{model_name}] Calling OpenAI client for model: {model_name}") + # OpenAI client raises specific exceptions + response = model.chat.completions.create( + model=model_name, # Pass the specific model name if needed by the endpoint + **completion_kwargs + ) + logger.debug(f"[{model_name}] OpenAI client response received.") + + message = response.choices[0].message + if message.tool_calls: + # logger.debug(f"[{model_name}] OpenAI client returned tool calls: {message.tool_calls}") + return message.tool_calls, True # Return raw tool calls + else: + # logger.debug(f"[{model_name}] OpenAI client returned content: {message.content[:100]}...") + return message.content, True + + elif isinstance(model, HfLocalBackend): + # Use Hugging Face local backend + logger.debug(f"[{model_name}] Calling HfLocalBackend.generate") + # Prepare messages specifically for HF backend if needed + if tools: + # HfLocalBackend expects tools merged into messages + final_messages = merge_messages_with_tools(messages, tools) + completion_kwargs["messages"] = final_messages + elif message_return_type == "json": + # HfLocalBackend expects JSON format instruction merged + final_messages = merge_messages_with_response_format(messages, response_format or {"type": "json_object"}) + completion_kwargs["messages"] = final_messages + # Remove tool/format kwargs as they are merged into messages for HF + completion_kwargs.pop("tools", None) + completion_kwargs.pop("tool_choice", None) + completion_kwargs.pop("response_format", None) + + # HfLocalBackend generate might raise its own errors + generated_text = model.generate(**completion_kwargs) + # logger.debug(f"[{model_name}] HfLocalBackend generated: {generated_text[:100]}...") + # HfLocalBackend returns a single string. Tool/JSON decoding happens in _process_response. + return generated_text, True + + + except (APIError, APITimeoutError, APIConnectionError, RateLimitError, AuthenticationError, BadRequestError) as api_err: + # Catch specific API/LiteLLM errors and re-raise them for _handle_completion_error + logger.warning(f"[{model_name}] API call failed: {type(api_err).__name__} - {api_err}") + raise api_err # Propagate the specific error + except Exception as e: + # Catch any other unexpected error during the call + logger.error(f"[{model_name}] Unexpected error during model response retrieval: {e}", exc_info=True) + # Wrap in a generic Exception or re-raise? Re-raise for now. + raise e # Propagate for handling upstream + finally: + end_time = time.time() + logger.debug(f"[{model_name}] _get_model_response took {end_time - start_time:.2f} seconds.") + def _process_response( - self, - completed_response: str | List, # either a response message of a string or a list of tool calls + self, + completed_response: Union[str, List, Dict], # Raw response: str, list of tool calls (OpenAI/LiteLLM style), or dict (JSON) finished: bool, - tools: Optional[List] = None, - model: Union[str, OpenAI, HfLocalBackend] = None, - message_return_type: Optional[str] = None + tools: Optional[List] = None, # Original tools (with slashes) might be needed for context + model: Union[str, OpenAI, HfLocalBackend] = None, # Model identifier/object + message_return_type: Optional[str] = None # "json" or "text" ) -> LLMResponse: """ - Process the model's response into the appropriate format. + Process the raw model's response into a structured LLMResponse. + Handles tool call decoding and JSON parsing. Args: - response: Raw response from the model - tools: Optional list of tools - message_return_type: Expected return type ("json" or None) + completed_response: Raw response from _get_model_response. + finished: Flag indicating if the generation finished. + tools: Original list of tools provided in the request (with slashes). + model: The model identifier or object used. + message_return_type: Expected return type ("json" or "text"). Returns: - Formatted LLMResponse - - Example: - ```python - # Input - response = '{"result": "4", "operation": "2 + 2"}' - tools = None - message_return_type = "json" - - # Output LLMResponse - { - "response_message": {"result": "4", "operation": "2 + 2"}, - "finished": True, - "tool_calls": None, - "error": None, - "status_code": 200 - } - - # Input with tool calls - response = 'I will use the calculator tool...' - tools = [{"name": "calculator", ...}] - - # Output LLMResponse with tool calls - { - "response_message": None, - "finished": True, - "tool_calls": [{ - "id": "call_abc123", - "name": "calculator", - "arguments": {"operation": "add", "numbers": [2, 2]} - }], - "error": None, - "status_code": 200 - } - ``` + Formatted LLMResponse. """ - # breakpoint() + logger.debug(f"Processing response. Finished: {finished}, Type: {type(completed_response)}, Expected: {message_return_type}") - if tools: - if isinstance(model, HfLocalBackend): - if finished: - tool_calls = decode_hf_tool_calls(completed_response) - tool_calls = double_underscore_to_slash(tool_calls) - return LLMResponse( - response_message=None, - tool_calls=tool_calls, - finished=finished - ) - else: - return LLMResponse(response_message=completed_response, finished=finished) - else: - tool_calls = double_underscore_to_slash(completed_response) + try: + # --- Tool Call Handling --- + # Check if tools were expected *and* if the response looks like tool calls + if tools: + if isinstance(completed_response, list) and all(hasattr(item, 'function') for item in completed_response): + # Likely OpenAI/LiteLLM style tool calls list + logger.debug("Processing list of tool calls (OpenAI/LiteLLM style).") + try: + # Need to convert OpenAI/LiteLLM ToolCall objects to our dict format + decoded_calls = decode_litellm_tool_calls({"choices": [{"message": {"tool_calls": completed_response}}]}) # Wrap for decoder + final_tool_calls = double_underscore_to_slash(decoded_calls) + logger.debug(f"Decoded tool calls: {final_tool_calls}") + return LLMResponse( + response_message=None, + tool_calls=final_tool_calls, + finished=finished, + status_code=200 + ) + except Exception as e: + logger.error(f"Error decoding LiteLLM/OpenAI tool calls: {e}", exc_info=True) + return LLMResponse( + response_message=None, + error=f"Tool call decoding error: {e}", + finished=True, # Treat as finished with error + status_code=500 + ) + elif isinstance(model, HfLocalBackend) and isinstance(completed_response, str): + # HF models return text that needs parsing for tool calls + logger.debug("Attempting to decode tool calls from HfLocalBackend string response.") + try: + # decode_hf_tool_calls expects the raw string + tool_calls = decode_hf_tool_calls(completed_response) + if tool_calls: # Check if decoding was successful + tool_calls = double_underscore_to_slash(tool_calls) + logger.debug(f"Decoded HF tool calls: {tool_calls}") + return LLMResponse( + response_message=None, # No text message if tools were called + tool_calls=tool_calls, + finished=finished, # Usually True for HF unless streaming implemented differently + status_code=200 + ) + else: + # Model tried to call tools but failed, or just generated text? + logger.warning("HfLocalBackend response received when tools expected, but decode_hf_tool_calls returned empty. Treating as text.") + # Fall through to text/JSON processing + pass + except Exception as e: + logger.error(f"Error decoding tool calls from HfLocalBackend response: {e}", exc_info=True) + return LLMResponse( + response_message=None, + error=f"Tool call decoding error: {e}", + finished=True, + status_code=500 + ) + # else: Handle cases where tools were expected but response isn't recognized tool format? + # For now, fall through to text/JSON processing. + + # --- Plain Text Response Handling --- + if isinstance(completed_response, str): + logger.debug("Processing as plain text response.") + return LLMResponse( + response_message=completed_response, + finished=finished, + status_code=200 + ) + + # --- Fallback for Unexpected Types --- + logger.warning(f"Unexpected response type received in _process_response: {type(completed_response)}. Content: {str(completed_response)[:200]}...") + # Attempt to convert to string as a last resort + try: + fallback_message = str(completed_response) + return LLMResponse( + response_message=fallback_message, + finished=finished, # Assume finished? + status_code=200, # Or maybe an error code? Let's assume success but log warning. + error="Warning: Unexpected response type processed." + ) + except Exception as str_e: + logger.error(f"Could not convert unexpected response type {type(completed_response)} to string: {str_e}") return LLMResponse( response_message=None, - tool_calls=tool_calls, - finished=finished + error=f"Cannot handle type {type(completed_response)}", + finished=True, + status_code=500 ) - else: - return LLMResponse(response_message=completed_response, finished=finished) + + except Exception as e: + # Catch-all for errors during the processing itself + logger.error(f"Critical error during response processing: {e}", exc_info=True) + return LLMResponse( + response_message=None, + error=f"Unhandled processing exception: {e}", + finished=True, # Mark as finished even if processing failed + status_code=500 + ) diff --git a/aios/llm_core/llm_cost_map.json b/aios/llm_core/llm_cost_map.json deleted file mode 100644 index e98b047..0000000 --- a/aios/llm_core/llm_cost_map.json +++ /dev/null @@ -1,11 +0,0 @@ -[ - {"name": "Qwen-2.5-7B-Instruct", "hosted_model": "ollama/qwen2.5:7b", "cost_per_input_token": 0.000000267, "cost_per_output_token": 0.000000267}, - {"name": "Qwen-2.5-14B-Instruct", "hosted_model": "ollama/qwen2.5:14b", "cost_per_input_token": 0.000000534, "cost_per_output_token": 0.000000534}, - {"name": "Qwen-2.5-32B-Instruct", "hosted_model": "ollama/qwen2.5:32b", "cost_per_input_token": 0.00000122, "cost_per_output_token": 0.00000122}, - {"name": "Llama-3.1-8B-Instruct", "hosted_model": "ollama/llama3.1:8b", "cost_per_input_token": 0.000000305, "cost_per_output_token": 0.000000305}, - {"name": "Deepseek-r1-7b", "hosted_model": "ollama/deepseek-r1:7b", "cost_per_input_token": 0.000000267, "cost_per_output_token": 0.000000267}, - {"name": "Deepseek-r1-14b", "hosted_model": "ollama/deepseek-r1:14b", "cost_per_input_token": 0.000000534, "cost_per_output_token": 0.000000534}, - {"name": "gpt-4o-mini", "hosted_model": "gpt-4o-mini", "cost_per_input_token": 0.00000015, "cost_per_output_token": 0.0000006}, - {"name": "gpt-4o", "hosted_model": "gpt-4o", "cost_per_input_token": 0.0000025, "cost_per_output_token": 0.00001}, - {"name": "gemini-1.5-flash", "hosted_model": "gemini/gemini-1.5-flash", "cost_per_input_token": 0.000000075, "cost_per_output_token": 0.0000003} - ] \ No newline at end of file diff --git a/aios/llm_core/routing.py b/aios/llm_core/routing.py index d6296cf..2eec6c8 100644 --- a/aios/llm_core/routing.py +++ b/aios/llm_core/routing.py @@ -12,6 +12,8 @@ import json from threading import Lock +import openai + import os from pulp import ( @@ -23,6 +25,13 @@ from pulp import ( value ) +import litellm + +import tempfile +import gdown + +from litellm import token_counter + """ Load balancing strategies. Each class represents a strategy which returns the next endpoint that the router should use. @@ -37,9 +46,9 @@ used whenever the strategy is called in __call__, and then return the name of the specific LLM endpoint. """ -class RouterStrategy(Enum): - Sequential = 0, - Smart = 1 +class RouterStrategy: + Sequential = "sequential" + Smart = "smart" class SequentialRouting: """ @@ -75,13 +84,13 @@ class SequentialRouting: # def __call__(self): # return self.get_model() - def get_model_idxs(self, selected_llms: List[str], n_queries: int=1): + def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries: List[List[Dict[str, Any]]]): """ Selects model indices from the available LLM configurations using a round-robin strategy. Args: - selected_llms (List[str]): A list of selected LLM names from which models will be chosen. - n_queries (int): The number of queries to distribute among the selected models. Defaults to 1. + selected_llm_lists (List[List[str]]): A list of selected LLM names from which models will be chosen. + queries (List[List[Dict[str, Any]]]): A list of queries to distribute among the selected models. Returns: List[int]: A list of indices corresponding to the selected models in `self.llm_configs`. @@ -90,477 +99,304 @@ class SequentialRouting: # current = self.selected_llms[self.idx] model_idxs = [] - # breakpoint() + available_models = [llm.name for llm in self.llm_configs] - for _ in range(n_queries): - current = selected_llms[self.idx] - for i, llm_config in enumerate(self.llm_configs): - # breakpoint() - if llm_config["name"] == current["name"]: - model_idxs.append(i) + n_queries = len(queries) + + for i in range(n_queries): + selected_llm_list = selected_llm_lists[i] + + if not selected_llm_list or len(selected_llm_list) == 0: + model_idxs.append(0) + continue + + model_idx = -1 + for selected_llm in selected_llm_list: + if selected_llm["name"] in available_models: + model_idx = available_models.index(selected_llm["name"]) break - self.idx = (self.idx + 1) % len(selected_llms) - # breakpoint() - + model_idxs.append(model_idx) + return model_idxs - - bucket_size = max_output_length / num_buckets - - total_queries = len(test_data) - model_metrics = defaultdict(lambda: { - "correct_predictions": 0, - "total_predictions": 0, - "correct_length_predictions": 0, - "total_length_predictions": 0, - "correct_bucket_predictions": 0, - }) - - # Initialize metrics for each model - model_metrics = defaultdict(lambda: { - "correct_predictions": 0, - "total_predictions": 0, - "correct_length_predictions": 0, - "total_length_predictions": 0, - "correct_bucket_predictions": 0 - }) - - for test_item in tqdm(test_data, desc="Evaluating"): - similar_results = self.query_similar( - test_item["query"], - "train", - n_results=n_similar - ) - - model_stats = defaultdict(lambda: { - "total_length": 0, - "count": 0, - "correct_count": 0 - }) - - - for metadata in similar_results["metadatas"][0]: - for model_info in json.loads(metadata["models"]): - model_name = model_info["model_name"] - stats = model_stats[model_name] - stats["total_length"] += model_info["output_token_length"] - stats["count"] += 1 - stats["correct_count"] += int(model_info["correctness"]) - - for model_output in test_item["outputs"]: - model_name = model_output["model_name"] - if model_name in model_stats: - stats = model_stats[model_name] - - if stats["count"] > 0: - predicted_correctness = stats["correct_count"] / stats["count"] >= 0.5 - actual_correctness = model_output["correctness"] - - if predicted_correctness == actual_correctness: - model_metrics[model_name]["correct_predictions"] += 1 - - predicted_length = stats["total_length"] / stats["count"] - - predicted_bucket = min(int(predicted_length / bucket_size), num_buckets - 1) - - actual_length = model_output["output_token_length"] - - # length_error = abs(predicted_length - actual_length) - actual_bucket = min(int(actual_length / bucket_size), num_buckets - 1) - - if predicted_bucket == actual_bucket: - model_metrics[model_name]["correct_length_predictions"] += 1 - - if abs(predicted_bucket - actual_bucket) <= 1: - model_metrics[model_name]["correct_bucket_predictions"] += 1 - - model_metrics[model_name]["total_predictions"] += 1 - model_metrics[model_name]["total_length_predictions"] += 1 - - results = {} - - # Calculate overall accuracy across all models - total_correct_predictions = sum(metrics["correct_predictions"] for metrics in model_metrics.values()) - total_correct_length_predictions = sum(metrics["correct_length_predictions"] for metrics in model_metrics.values()) - total_correct_bucket_predictions = sum(metrics["correct_bucket_predictions"] for metrics in model_metrics.values()) - total_predictions = sum(metrics["total_predictions"] for metrics in model_metrics.values()) - - if total_predictions > 0: - results["overall"] = { - "performance_accuracy": total_correct_predictions / total_predictions, - "length_accuracy": total_correct_length_predictions / total_predictions, - "bucket_accuracy": total_correct_bucket_predictions / total_predictions - } - for model_name, metrics in model_metrics.items(): - total = metrics["total_predictions"] - if total > 0: - results[model_name] = { - "performance_accuracy": metrics["correct_predictions"] / total, - "length_accuracy": metrics["correct_length_predictions"] / total, - "bucket_accuracy": metrics["correct_bucket_predictions"] / total - } - results["overall"] = { - "performance_accuracy": total_correct_predictions / total_predictions, - "length_accuracy": total_correct_length_predictions / total_predictions, - "bucket_accuracy": total_correct_bucket_predictions / total_predictions - } - - return results + +def get_cost_per_token(model_name: str) -> tuple[float, float]: + """Fetch the latest *per‑token* input/output pricing from LiteLLM. + + This pulls the live `model_cost` map which LiteLLM refreshes from + `api.litellm.ai`, so you always have current pricing information. + If the model is unknown, graceful fall‑back to zero cost. + """ + cost_map = litellm.model_cost # Live dict {model: {input_cost_per_token, output_cost_per_token, ...}} + info = cost_map.get(model_name, {}) + return info.get("input_cost_per_token", 0.0), info.get("output_cost_per_token", 0.0) + +def get_token_lengths(queries: List[List[Dict[str, Any]]]): + """ + Get the token lengths of a list of queries. + """ + return [token_counter(model="gpt-4o-mini", messages=query) for query in queries] + +def messages_to_query(messages: List[Dict[str, str]], + strategy: str = "last_user") -> str: + """ + Convert OpenAI ChatCompletion-style messages into a single query string. + strategy: + - "last_user": last user message only + - "concat_users": concat all user messages + - "concat_all": concat role-labelled full history + - "summarize": use GPT to summarize into a short query + """ + if strategy == "last_user": + for msg in reversed(messages): + if msg["role"] == "user": + return msg["content"].strip() + return "" # fallback + + if strategy == "concat_users": + return "\n\n".join(m["content"].strip() + for m in messages if m["role"] == "user") + + if strategy == "concat_all": + return "\n\n".join(f'{m["role"].upper()}: {m["content"].strip()}' + for m in messages) + + if strategy == "summarize": + full_text = "\n\n".join(f'{m["role"]}: {m["content"]}' + for m in messages) + rsp = openai.chat.completions.create( + model="gpt-4o-mini", + messages=[ + {"role": "system", + "content": ("You are a helpful assistant that rewrites a " + "dialogue into a concise search query.")}, + {"role": "user", "content": full_text} + ], + max_tokens=32, + temperature=0.2 + ) + return rsp.choices[0].message.content.strip() + + raise ValueError("unknown strategy") class SmartRouting: - """ - The SmartRouting class implements a cost-performance optimized selection strategy for LLM requests. - It uses historical performance data to predict which models will perform best for a given query - while minimizing cost. + """Cost‑/performance‑aware LLM router. - This strategy ensures that models are selected based on their predicted performance and cost, - optimizing for both quality and efficiency. - - Args: - llm_configs (List[Dict[str, Any]]): A list of LLM configurations, where each dictionary contains - model information such as name, backend, cost parameters, etc. - performance_requirement (float): The minimum performance score required (default: 0.7) - n_similar (int): Number of similar queries to retrieve for prediction (default: 16) + Two **key upgrades** compared with the original version: + 1. **Bootstrap local ChromaDB** automatically on first run by pulling a + prepared JSONL corpus from Google Drive (uses `gdown`). + 2. **Live pricing**: no more hard‑coded token prices – we call LiteLLM to + fetch up‑to‑date `input_cost_per_token` / `output_cost_per_token` for + every model the moment we need them. """ - def __init__(self, llm_configs: List[Dict[str, Any]], performance_requirement: float=0.7, n_similar: int=16): - self.num_buckets = 10 - self.max_output_limit = 1024 - self.n_similar = n_similar - self.bucket_size = self.max_output_limit / self.num_buckets - self.performance_requirement = performance_requirement - - print(f"Performance requirement: {self.performance_requirement}") - - self.llm_configs = llm_configs - self.store = self.QueryStore() - self.lock = Lock() - + + # --------------------------------------------------------------------- + # Construction helpers + # --------------------------------------------------------------------- + class QueryStore: - """ - Internal class for storing and retrieving query embeddings and related model performance data. - Uses ChromaDB for vector similarity search. - """ - def __init__(self, - model_name: str = "BAAI/bge-small-en-v1.5", - persist_directory: str = "llm_router"): - - file_path = os.path.join(os.path.dirname(__file__), persist_directory) - self.client = chromadb.PersistentClient(path=file_path) - - self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction( - model_name=model_name - ) + """Simple wrapper around ChromaDB persistent collections.""" + + def __init__(self, + model_name: str = "all-MiniLM-L6-v2", + persist_directory: str = "llm_router", + bootstrap_url: str | None = None): + self._persist_root = os.path.join(os.path.dirname(__file__), persist_directory) + os.makedirs(self._persist_root, exist_ok=True) + + self.client = chromadb.PersistentClient(path=self._persist_root) + self.embedding_function = embedding_functions.DefaultEmbeddingFunction() + + # Always create/get collections up‑front so we can inspect counts. + # self.train_collection = self._get_or_create_collection("train_queries") + # self.val_collection = self._get_or_create_collection("val_queries") + # self.test_collection = self._get_or_create_collection("test_queries") + self.collection = self._get_or_create_collection("historical_queries") + # If DB is empty and we have a bootstrap URL – populate it. + if bootstrap_url and self.collection.count() == 0: + self._bootstrap_from_drive(bootstrap_url) + + # ................................................................. + # Chroma helpers + # ................................................................. + + def _get_or_create_collection(self, name: str): try: - self.train_collection = self.client.get_collection( - name="train_queries", - embedding_function=self.embedding_function - ) - except: - self.train_collection = self.client.create_collection( - name="train_queries", - embedding_function=self.embedding_function - ) - - try: - self.val_collection = self.client.get_collection( - name="val_queries", - embedding_function=self.embedding_function - ) - except: - self.val_collection = self.client.create_collection( - name="val_queries", - embedding_function=self.embedding_function - ) - - try: - self.test_collection = self.client.get_collection( - name="test_queries", - embedding_function=self.embedding_function - ) - except: - self.test_collection = self.client.create_collection( - name="test_queries", - embedding_function=self.embedding_function - ) - - def add_data(self, data: List[Dict], split: str = "train"): - """ - Add query data to the appropriate collection. - - Args: - data (List[Dict]): List of query data items - split (str): Data split ("train", "val", or "test") - """ - collection = getattr(self, f"{split}_collection") - queries = [] - metadatas = [] - ids = [] - - correct_count = 0 - total_count = 0 - - for idx, item in enumerate(tqdm(data, desc=f"Adding {split} data")): + return self.client.get_collection(name=name, embedding_function=self.embedding_function) + except Exception: + return self.client.create_collection(name=name, embedding_function=self.embedding_function) + + # ................................................................. + # Bootstrap logic – download + ingest + # ................................................................. + + def _bootstrap_from_drive(self, url_or_id: str): + print("\n[SmartRouting] Bootstrapping ChromaDB from Google Drive…\n") + + with tempfile.TemporaryDirectory() as tmp: + # NB: gdown accepts both share links and raw IDs. + local_path = os.path.join(tmp, "bootstrap.json") + + gdown.download(url_or_id, local_path, quiet=False, fuzzy=True) + + # Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...} + # split_map: dict[str, list[dict[str, Any]]] = defaultdict(list) + with open(local_path, "r") as f: + data = json.load(f) + + self.add_data(data) + + print("[SmartRouting] Bootstrap complete – collections populated.\n") + + # ................................................................. + # Public data API + # ................................................................. + + def add_data(self, data: List[Dict[str, Any]]): + collection = self.collection + queries, metadatas, ids = [], [], [] + correct_count = total_count = 0 + + for idx, item in enumerate(tqdm(data, desc=f"Ingesting historical queries")): query = item["query"] - - metadata = { + model_metadatas = item["outputs"] + for model_metadata in model_metadatas: + model_metadata.pop("prediction") + meta = { "input_token_length": item["input_token_length"], - "models": [] + "models": json.dumps(model_metadatas), # store raw list } - for output in item["outputs"]: - model_info = { - "model_name": output["model_name"], - "correctness": output["correctness"], - "output_token_length": output["output_token_length"] - } + total_count += 1 if output["correctness"]: correct_count += 1 - total_count += 1 - metadata["models"].append(model_info) - - metadata["models"] = json.dumps(metadata["models"]) - queries.append(query) - metadatas.append(metadata) - ids.append("id"+str(idx)) - - print(f"Correctness: {correct_count} / {total_count}") - - collection.add( - documents=queries, - metadatas=metadatas, - ids = ids - ) - - def query_similar(self, query: str | List[str], split: str = "train", n_results: int = 16): - """ - Find similar queries in the database. - - Args: - query (str|List[str]): Query or list of queries to find similar items for - split (str): Data split to search in - n_results (int): Number of similar results to return - - Returns: - Dict: Results from ChromaDB query - """ - collection = getattr(self, f"{split}_collection") - - results = collection.query( - query_texts=query if isinstance(query, List) else [query], - n_results=n_results - ) - - return results - - def predict(self, query: str | List[str], model_configs: List[Dict], n_similar: int = 16): - """ - Predict performance and output length for models on the given query. - - Args: - query (str|List[str]): Query or list of queries - model_configs (List[Dict]): List of model configurations - n_similar (int): Number of similar queries to use for prediction - - Returns: - Tuple[np.array, np.array]: Performance scores and length scores - """ - # Get similar results from training data - similar_results = self.query_similar(query, "train", n_results=n_similar) - - total_performance_scores = [] - total_length_scores = [] - # Aggregate stats from similar results - for i in range(len(similar_results["metadatas"])): - model_stats = defaultdict(lambda: { - "total_length": 0, - "count": 0, - "correct_count": 0 - }) - metadatas = similar_results["metadatas"][i] - for metadata in metadatas: - for model_info in json.loads(metadata["models"]): - model_name = model_info["model_name"] - stats = model_stats[model_name] - stats["total_length"] += model_info["output_token_length"] - stats["count"] += 1 - stats["correct_count"] += int(model_info["correctness"]) - - # Calculate performance and length scores for each model - performance_scores = [] - length_scores = [] - - for model in model_configs: - model_name = model["name"] - if model_name in model_stats and model_stats[model_name]["count"] > 0: - stats = model_stats[model_name] - # Calculate performance score as accuracy - perf_score = stats["correct_count"] / stats["count"] - # Calculate average length - avg_length = stats["total_length"] / stats["count"] - - performance_scores.append(perf_score) - length_scores.append(avg_length) + metadatas.append(meta) + ids.append(f"{idx}") + + collection.add(documents=queries, metadatas=metadatas, ids=ids) + print(f"[SmartRouting]: {total_count} historical queries ingested.") + + # .................................................................. + def query_similar(self, query: str | List[str], n_results: int = 16): + collection = self.collection + return collection.query(query_texts=query if isinstance(query, list) else [query], n_results=n_results) + + # .................................................................. + def predict(self, query: str | List[str], model_configs: List[Dict[str, Any]], n_similar: int = 16): + similar = self.query_similar(query, n_results=n_similar) + perf_mat, len_mat = [], [] + for meta_group in similar["metadatas"]: + model_stats: dict[str, dict[str, float]] = defaultdict(lambda: {"total_len": 0, "cnt": 0, "correct": 0}) + for meta in meta_group: + for m in json.loads(meta["models"]): + s = model_stats[m["model_name"]] + s["total_len"] += m["output_token_length"] + s["cnt"] += 1 + s["correct"] += int(m["correctness"]) + perf_row, len_row = [], [] + for cfg in model_configs: + stats = model_stats.get(cfg["name"], None) + if stats and stats["cnt"]: + perf_row.append(stats["correct"] / stats["cnt"]) + len_row.append(stats["total_len"] / stats["cnt"]) else: - # If no data for model, use default scores - performance_scores.append(0.0) - length_scores.append(0.0) - - total_performance_scores.append(performance_scores) - total_length_scores.append(length_scores) - - return np.array(total_performance_scores), np.array(total_length_scores) - - def optimize_model_selection_local(self, model_configs, perf_scores, cost_scores): - """ - Optimize model selection for a single query based on performance and cost. + perf_row.append(0.0) + len_row.append(0.0) + perf_mat.append(perf_row) + len_mat.append(len_row) + return np.array(perf_mat), np.array(len_mat) + + # --------------------------------------------------------------------- + # SmartRouting main methods + # --------------------------------------------------------------------- + + def __init__(self, + llm_configs: List[Dict[str, Any]], + bootstrap_url: str , + performance_requirement: float = 0.7, + n_similar: int = 16, + ): + self.llm_configs = llm_configs + self.available_models = [llm.name for llm in llm_configs] + self.bootstrap_url = bootstrap_url + self.performance_requirement = performance_requirement + self.n_similar = n_similar + self.lock = Lock() + self.max_output_limit = 1024 + self.num_buckets = 10 + self.bucket_size = self.max_output_limit / self.num_buckets + + # Initialise query store – will self‑populate if empty + self.store = self.QueryStore(bootstrap_url=bootstrap_url) + + print(f"[SmartRouting] Ready – performance threshold: {self.performance_requirement}\n") + + # ..................................................................... + # Local (per‑query) optimisation helper + # ..................................................................... + + def _select_model_single(self, model_cfgs: List[Dict[str, Any]], perf: np.ndarray, cost: np.ndarray) -> int | None: + qualified = [i for i, p in enumerate(perf) if p >= self.performance_requirement] + if qualified: + # Pick cheapest among qualified + return min(qualified, key=lambda i: cost[i]) + # Else, fallback – best performance overall + return int(np.argmax(perf)) if len(perf) else 0 + + # ..................................................................... + # Public API – batch selection + # ..................................................................... + + def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries: List[str]): + if len(selected_llm_lists) != len(queries): + raise ValueError("selected_llm_lists must have same length as queries") + + input_lens = get_token_lengths(queries) + chosen_indices: list[int] = [] - Args: - model_configs (List[Dict]): List of model configurations - perf_scores (np.array): Performance scores for each model - cost_scores (np.array): Cost scores for each model - - Returns: - int: Index of the selected model - """ - n_models = len(model_configs) - - # Get all available models - available_models = list(range(n_models)) - - if not available_models: - return None - - # Find models that meet performance requirement - qualified_models = [] - for i in available_models: - if perf_scores[i] >= self.performance_requirement: - qualified_models.append(i) - - if qualified_models: - # If there are models meeting performance requirement, - # select the one with lowest cost - min_cost = float('inf') - selected_model = None - for i in qualified_models: - if cost_scores[i] < min_cost: - min_cost = cost_scores[i] - selected_model = i - return selected_model - else: - # If no model meets performance requirement, - # select available model with highest performance - max_perf = float('-inf') - selected_model = None - for i in available_models: - if perf_scores[i] > max_perf: - max_perf = perf_scores[i] - selected_model = i - return selected_model - - def get_model_idxs(self, selected_llms: List[Dict[str, Any]], queries: List[str]=None, input_token_lengths: List[int]=None): - """ - Selects model indices from the available LLM configurations based on predicted performance and cost. - - Args: - selected_llms (List[Dict]): A list of selected LLM configurations from which models will be chosen. - n_queries (int): The number of queries to process. Defaults to 1. - queries (List[str], optional): List of query strings. If provided, will be used for model selection. - input_token_lengths (List[int], optional): List of input token lengths. Required if queries is provided. - - Returns: - List[int]: A list of indices corresponding to the selected models in `self.llm_configs`. - """ - model_idxs = [] - - # Ensure we have matching number of queries and token lengths - if len(queries) != len(input_token_lengths): - raise ValueError("Number of queries must match number of input token lengths") - - # Process each query - for i in range(len(queries)): - query = queries[i] - input_token_length = input_token_lengths[i] - - # Get performance and length predictions - perf_scores, length_scores = self.store.predict(query, selected_llms, n_similar=self.n_similar) - perf_scores = perf_scores[0] # First query's scores - length_scores = length_scores[0] # First query's length predictions - - # Calculate cost scores + converted_queries = [messages_to_query(query) for query in queries] + + for q, q_len, candidate_cfgs in zip(converted_queries, input_lens, selected_llm_lists): + perf, out_len = self.store.predict(q, candidate_cfgs, n_similar=self.n_similar) + perf, out_len = perf[0], out_len[0] # unpack single query + + # Dynamic price lookup via LiteLLM cost_scores = [] - for j in range(len(selected_llms)): - pred_output_length = length_scores[j] - input_cost = input_token_length * selected_llms[j].get("cost_per_input_token", 0) - output_cost = pred_output_length * selected_llms[j].get("cost_per_output_token", 0) - weighted_score = input_cost + output_cost - cost_scores.append(weighted_score) - + for cfg, pred_len in zip(candidate_cfgs, out_len): + in_cost_pt, out_cost_pt = get_cost_per_token(cfg["name"]) + cost_scores.append(q_len * in_cost_pt + pred_len * out_cost_pt) cost_scores = np.array(cost_scores) + + sel_local_idx = self._select_model_single(candidate_cfgs, perf, cost_scores) + if sel_local_idx is None: + chosen_indices.append(0) # safe fallback + continue + + # Map back to global llm_configs index + sel_name = candidate_cfgs[sel_local_idx]["name"] - # Select optimal model - selected_idx = self.optimize_model_selection_local( - selected_llms, - perf_scores, - cost_scores - ) - - # Find the index in the original llm_configs - for idx, config in enumerate(self.llm_configs): - if config["name"] == selected_llms[selected_idx]["name"]: - model_idxs.append(idx) - break - else: - # If not found, use the first model as fallback - model_idxs.append(0) - - return model_idxs - - def optimize_model_selection_global(self, perf_scores, cost_scores): - """ - Globally optimize model selection for multiple queries using linear programming. - - Args: - perf_scores (np.array): Performance scores matrix [queries × models] - cost_scores (np.array): Cost scores matrix [queries × models] - - Returns: - np.array: Array of selected model indices for each query - """ - n_models = len(self.llm_configs) - n_queries = len(perf_scores) - + sel_idx = self.available_models.index(sel_name) + chosen_indices.append(sel_idx) + + return chosen_indices + + # ..................................................................... + # Global optimisation (unchanged except for cost lookup) + # ..................................................................... + + def optimize_model_selection_global(self, perf_scores: np.ndarray, cost_scores: np.ndarray): + n_queries, n_models = perf_scores.shape prob = LpProblem("LLM_Scheduling", LpMinimize) - - # Decision variables - x = LpVariable.dicts("assign", - ((i, j) for i in range(n_queries) - for j in range(n_models)), - cat='Binary') - - # Objective function: minimize total cost - prob += lpSum(x[i,j] * cost_scores[i,j] - for i in range(n_queries) - for j in range(n_models)) - - # Quality constraint: ensure overall performance meets requirement - prob += lpSum(x[i,j] * perf_scores[i,j] - for i in range(n_queries) - for j in range(n_models)) >= self.performance_requirement * n_queries - - # Assignment constraints: each query must be assigned to exactly one model + x = LpVariable.dicts("assign", ((i, j) for i in range(n_queries) for j in range(n_models)), cat="Binary") + prob += lpSum(x[i, j] * cost_scores[i, j] for i in range(n_queries) for j in range(n_models)) + prob += lpSum(x[i, j] * perf_scores[i, j] for i in range(n_queries) for j in range(n_models)) >= self.performance_requirement * n_queries for i in range(n_queries): - prob += lpSum(x[i,j] for j in range(n_models)) == 1 - - # Solve + prob += lpSum(x[i, j] for j in range(n_models)) == 1 prob.solve(PULP_CBC_CMD(msg=False)) - - # Extract solution - solution = np.zeros((n_queries, n_models)) + sol = np.zeros((n_queries, n_models)) for i in range(n_queries): for j in range(n_models): - solution[i,j] = value(x[i,j]) - - solution = np.argmax(solution, axis=1) - return solution + sol[i, j] = value(x[i, j]) + return np.argmax(sol, axis=1) diff --git a/aios/llm_core/utils.py b/aios/llm_core/utils.py index cc34207..183a4b0 100644 --- a/aios/llm_core/utils.py +++ b/aios/llm_core/utils.py @@ -3,6 +3,8 @@ import re import uuid from copy import deepcopy +from typing import List, Dict, Any + def merge_messages_with_tools(messages: list, tools: list) -> list: """ Integrate tool information into the messages for open-sourced LLMs which don't support tool calling. @@ -338,4 +340,16 @@ def pre_process_tools(tools): if "/" in tool_name: tool_name = "__".join(tool_name.split("/")) tool["function"]["name"] = tool_name - return tools \ No newline at end of file + return tools + +def check_availability_for_selected_llm_lists(available_llm_names: List[str], selected_llm_lists: List[List[Dict[str, Any]]]): + selected_llm_lists_availability = [] + for selected_llm_list in selected_llm_lists: + all_available = True + + for llm in selected_llm_list: + if llm["name"] not in available_llm_names: + all_available = False + break + selected_llm_lists_availability.append(all_available) + return selected_llm_lists_availability diff --git a/aios/scheduler/fifo_scheduler.py b/aios/scheduler/fifo_scheduler.py index 91a0bed..647aa2c 100644 --- a/aios/scheduler/fifo_scheduler.py +++ b/aios/scheduler/fifo_scheduler.py @@ -24,7 +24,7 @@ from queue import Empty import traceback import time import logging -from typing import Optional, Any, Dict +from typing import Optional, Any, Dict, List # Configure logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') @@ -34,8 +34,9 @@ class FIFOScheduler(BaseScheduler): """ A FIFO (First-In-First-Out) task scheduler implementation. - This scheduler processes tasks in the order they arrive, with a 1-second timeout - for each task. It handles different types of system calls: LLM, Memory, Storage, and Tool. + This scheduler processes tasks in the order they arrive. + LLM tasks are batched based on a time interval. Other tasks (Memory, Storage, Tool) + are processed individually as they arrive. Example: ```python @@ -48,7 +49,8 @@ class FIFOScheduler(BaseScheduler): get_llm_syscall=llm_queue.get, get_memory_syscall=memory_queue.get, get_storage_syscall=storage_queue.get, - get_tool_syscall=tool_queue.get + get_tool_syscall=tool_queue.get, + batch_interval=0.1 # Process LLM requests every 100ms ) scheduler.start() ``` @@ -61,14 +63,11 @@ class FIFOScheduler(BaseScheduler): storage_manager: StorageManager, tool_manager: ToolManager, log_mode: str, - # llm_request_queue: LLMRequestQueue, - # memory_request_queue: MemoryRequestQueue, - # storage_request_queue: StorageRequestQueue, - # tool_request_queue: ToolRequestQueue, get_llm_syscall: LLMRequestQueueGetMessage, get_memory_syscall: MemoryRequestQueueGetMessage, get_storage_syscall: StorageRequestQueueGetMessage, get_tool_syscall: ToolRequestQueueGetMessage, + batch_interval: float = 1.0, ): """ Initialize the FIFO Scheduler. @@ -83,6 +82,7 @@ class FIFOScheduler(BaseScheduler): get_memory_syscall: Function to get Memory syscalls get_storage_syscall: Function to get Storage syscalls get_tool_syscall: Function to get Tool syscalls + batch_interval: Time interval in seconds to batch LLM requests. Defaults to 0.1. """ super().__init__( llm, @@ -95,81 +95,89 @@ class FIFOScheduler(BaseScheduler): get_storage_syscall, get_tool_syscall, ) + self.batch_interval = batch_interval - def _execute_syscall( - self, - syscall: Any, + def _execute_batch_syscalls( + self, + batch: List[Any], executor: Any, syscall_type: str - ) -> Optional[Dict[str, Any]]: + ) -> None: """ - Execute a system call with proper status tracking and error handling. - + Execute a batch of system calls with proper status tracking and error handling. + Args: - syscall: The system call to execute - executor: Function to execute the syscall - syscall_type: Type of the syscall for logging - - Returns: - Optional[Dict[str, Any]]: Response from the syscall execution - - Example: - ```python - response = scheduler._execute_syscall( - llm_syscall, - self.llm.execute_llm_syscall, - "LLM" - ) - ``` + batch: The list of system calls to execute + executor: Function to execute the batch of syscalls + syscall_type: Type of the syscalls for logging """ + if not batch: + return + + start_time = time.time() + for syscall in batch: + try: + syscall.set_status("executing") + + logger.info(f"{syscall.agent_name} preparing batched {syscall_type} syscall.") + syscall.set_start_time(start_time) + except Exception as e: + logger.error(f"Error preparing syscall {getattr(syscall, 'agent_name', 'unknown')} for batch execution: {str(e)}") + continue + + if not batch: + logger.warning(f"Empty batch after preparation for {syscall_type}, skipping execution.") + return + + # self.logger.log( + # f"Executing batch of {len(batch)} {syscall_type} syscalls.\n", + # "executing_batch" + # ) + logger.info(f"Executing batch of {len(batch)} {syscall_type} syscalls.") + try: - syscall.set_status("executing") - self.logger.log( - f"{syscall.agent_name} is executing {syscall_type} syscall.\n", - "executing" - ) - syscall.set_start_time(time.time()) + responses = executor(batch) - response = executor(syscall) - syscall.set_response(response) + for i, syscall in enumerate(batch): + logger.info(f"Completed batched {syscall_type} syscall for {syscall.agent_name}. " + f"Thread ID: {syscall.get_pid()}\n") - syscall.event.set() - syscall.set_status("done") - syscall.set_end_time(time.time()) - - self.logger.log( - f"Completed {syscall_type} syscall for {syscall.agent_name}. " - f"Thread ID: {syscall.get_pid()}\n", - "done" - ) - - return response except Exception as e: - logger.error(f"Error executing {syscall_type} syscall: {str(e)}") + logger.error(f"Error executing {syscall_type} syscall batch: {str(e)}") traceback.print_exc() - return None + def process_llm_requests(self) -> None: """ - Process LLM requests from the queue. + Process LLM requests from the queue in batches based on batch_interval. Example: ```python - scheduler.process_llm_requests() - # Processes LLM requests like: - # { - # "messages": [{"role": "user", "content": "Hello"}], - # "temperature": 0.7 - # } + # Collects LLM requests for 0.1 seconds (default) then processes them: + # Batch = [ + # {"messages": [{"role": "user", "content": "Hello"}]}, + # {"messages": [{"role": "user", "content": "World"}]} + # ] + # Calls self.llm.execute_batch_llm_syscall(Batch) ``` """ while self.active: - try: - llm_syscall = self.get_llm_syscall() - self._execute_syscall(llm_syscall, self.llm.execute_llm_syscall, "LLM") - except Empty: - pass + time.sleep(self.batch_interval) + + batch = [] + while True: + try: + llm_syscall = self.get_llm_syscall() + # Add logging here: print(f"Retrieved syscall: {llm_syscall.pid}, Queue size now: {self.llm_queue.qsize()}") + batch.append(llm_syscall) + except Empty: + # This is the expected way to finish collecting the batch + # print(f"Queue empty, finishing batch collection with {len(batch)} items.") + break + + if batch: + self._execute_batch_syscalls(batch, self.llm.execute_llm_syscalls, "LLM") def process_memory_requests(self) -> None: """ diff --git a/aios/scheduler/rr_scheduler.py b/aios/scheduler/rr_scheduler.py index 683414a..24217ab 100644 --- a/aios/scheduler/rr_scheduler.py +++ b/aios/scheduler/rr_scheduler.py @@ -4,11 +4,9 @@ from .base import BaseScheduler - # allows for memory to be shared safely between threads from queue import Queue, Empty - from ..context.simple_context import SimpleContextManager from aios.hooks.types.llm import LLMRequestQueueGetMessage @@ -27,7 +25,7 @@ from threading import Thread from .base import BaseScheduler import logging -from typing import Optional, Dict, Any +from typing import Optional, Dict, Any, List logger = logging.getLogger(__name__) @@ -59,73 +57,54 @@ class RRScheduler(BaseScheduler): super().__init__(*args, **kwargs) self.time_slice = time_slice self.context_manager = SimpleContextManager() - - def _execute_syscall( - self, - syscall: Any, + + def _execute_batch_syscalls( + self, + batch: List[Any], executor: Any, syscall_type: str - ) -> Optional[Dict[str, Any]]: + ) -> None: """ - Execute a system call with time slice enforcement. - + Execute a batch of system calls with proper status tracking and error handling. + Args: - syscall: The system call to execute - executor: Function to execute the syscall - syscall_type: Type of the syscall for logging - - Returns: - Optional[Dict[str, Any]]: Response from the syscall execution - - Example: - ```python - response = scheduler._execute_syscall( - llm_syscall, - self.llm.execute_llm_syscall, - "LLM" - ) - ``` + batch: The list of system calls to execute + executor: Function to execute the batch of syscalls + syscall_type: Type of the syscalls for logging """ - try: - syscall.set_time_limit(self.time_slice) - syscall.set_status("executing") - self.logger.log( - f"{syscall.agent_name} is executing {syscall_type} syscall.\n", - "executing" - ) - syscall.set_start_time(time.time()) + if not batch: + return - response = executor(syscall) - - # breakpoint() - - syscall.set_response(response) - - if response.finished: - syscall.set_status("done") - log_status = "done" - else: - syscall.set_status("suspending") - log_status = "suspending" + start_time = time.time() + for syscall in batch: + try: + syscall.set_status("executing") + syscall.set_time_limit(self.time_slice) - # breakpoint() + logger.info(f"{syscall.agent_name} preparing batched {syscall_type} syscall.") + syscall.set_start_time(start_time) + except Exception as e: + logger.error(f"Error preparing syscall {getattr(syscall, 'agent_name', 'unknown')} for batch execution: {str(e)}") + continue + + if not batch: + message = f"Empty batch after preparation for {syscall_type}, skipping execution." + logger.warning(message) + return + + logger.info(f"Executing batch of {len(batch)} {syscall_type} syscalls.") + + try: + responses = executor(batch) + + for i, syscall in enumerate(batch): + logger.info(f"Completed batched {syscall_type} syscall for {syscall.agent_name}. " + f"Thread ID: {syscall.get_pid()}\n") - syscall.set_end_time(time.time()) - - syscall.event.set() - - self.logger.log( - f"{syscall_type} syscall for {syscall.agent_name} is {log_status}. " - f"Thread ID: {syscall.get_pid()}\n", - log_status - ) - - return response except Exception as e: - logger.error(f"Error executing {syscall_type} syscall: {str(e)}") + logger.error(f"Error executing {syscall_type} syscall batch: {str(e)}") traceback.print_exc() - return None def process_llm_requests(self) -> None: """ @@ -144,7 +123,7 @@ class RRScheduler(BaseScheduler): while self.active: try: llm_syscall = self.get_llm_syscall() - self._execute_syscall(llm_syscall, self.llm.execute_llm_syscall, "LLM") + self._execute_batch_syscalls(llm_syscall, self.llm.execute_llm_syscalls, "LLM") except Empty: pass diff --git a/aios/syscall/syscall.py b/aios/syscall/syscall.py index c03ffee..25e0a3d 100755 --- a/aios/syscall/syscall.py +++ b/aios/syscall/syscall.py @@ -13,11 +13,7 @@ from aios.hooks.stores._global import ( global_llm_req_queue_add_message, global_memory_req_queue_add_message, global_storage_req_queue_add_message, - global_tool_req_queue_add_message, - # global_llm_req_queue, - # global_memory_req_queue, - # global_storage_req_queue, - # global_tool_req_queue, + global_tool_req_queue_add_message ) import threading @@ -113,6 +109,8 @@ class SyscallExecutor: if isinstance(syscall, LLMSyscall): global_llm_req_queue_add_message(syscall) + print(f"Syscall {syscall.agent_name} added to LLM queue") + elif isinstance(syscall, StorageSyscall): global_storage_req_queue_add_message(syscall) elif isinstance(syscall, MemorySyscall): diff --git a/install/install.sh b/install/install.sh index f7967ac..51df4db 100644 --- a/install/install.sh +++ b/install/install.sh @@ -3,7 +3,6 @@ # Installation directory INSTALL_DIR="$HOME/.aios-1" REPO_URL="https://github.com/agiresearch/AIOS" -TAG="v0.2.0.beta" # Replace with your specific tag BIN_DIR="$HOME/.local/bin" # User-level bin directory echo "Installing AIOS..." @@ -26,7 +25,6 @@ git clone "$REPO_URL" "$INSTALL_DIR/src" # Checkout the specific tag cd "$INSTALL_DIR/src" -git checkout tags/"$TAG" -b "$TAG-branch" # Find Python executable path PYTHON_PATH=$(command -v python3) diff --git a/requirements.txt b/requirements.txt index 8db69ff..bf3dc00 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,5 @@ redis>=4.5.1 sentence-transformers nltk scikit-learn -pulp \ No newline at end of file +pulp +gdown \ No newline at end of file diff --git a/runtime/launch.py b/runtime/launch.py index 428ede7..39883cf 100644 --- a/runtime/launch.py +++ b/runtime/launch.py @@ -32,6 +32,8 @@ from cerebrum.storage.apis import StorageQuery, StorageResponse from fastapi.middleware.cors import CORSMiddleware +import asyncio + import uvicorn load_dotenv() @@ -340,6 +342,22 @@ def restart_kernel(): print(f"Stack trace: {traceback.format_exc()}") raise +@app.get("/status") +async def get_server_status(): + """Check if the server is running and core components are initialized.""" + inactive_components = [ + component for component, instance in active_components.items() if not instance + ] + + if not inactive_components: + return {"status": "ok", "message": "All core components are active."} + else: + return { + "status": "warning", + "message": f"Server is running, but some components are inactive: {', '.join(inactive_components)}", + "inactive_components": inactive_components + } + @app.post("/core/refresh") async def refresh_configuration(): """Refresh all component configurations""" @@ -601,7 +619,14 @@ async def handle_query(request: QueryRequest): action_type=request.query_data.action_type, message_return_type=request.query_data.message_return_type, ) - return execute_request(request.agent_name, query) + result_dict = await asyncio.to_thread( + execute_request, # The method to call + request.agent_name, # First arg to execute_request + query # Second arg to execute_request + ) + + return result_dict + elif request.query_type == "storage": query = StorageQuery( params=request.query_data.params, diff --git a/tests/modules/llm/ollama/test_single.py b/tests/modules/llm/ollama/test_single.py new file mode 100644 index 0000000..5381950 --- /dev/null +++ b/tests/modules/llm/ollama/test_single.py @@ -0,0 +1,117 @@ +import unittest +from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output +from cerebrum.utils.communication import aios_kernel_url +from cerebrum.utils.utils import _parse_json_output + +class TestAgent: + """ + TestAgent class is responsible for interacting with OpenAI's API using ChatCompletion. + It maintains a conversation history to simulate real dialogue behavior. + """ + def __init__(self, agent_name): + self.agent_name = agent_name + self.messages = [] + + def llm_chat_run(self, task_input): + """Sends the input to the OpenAI API and returns the response.""" + self.messages.append({"role": "user", "content": task_input}) + + tool_response = llm_chat( + agent_name=self.agent_name, + messages=self.messages, + base_url=aios_kernel_url, + llms=[{ + "name": "qwen2.5:7b", + "backend": "ollama" + }] + ) + + return tool_response["response"].get("response_message", "") + + def llm_chat_with_json_output_run(self, task_input): + """Sends the input to the OpenAI API and returns the response.""" + self.messages.append({"role": "user", "content": task_input}) + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "thinking", + "schema": { + "type": "object", + "properties": { + "thinking": { + "type": "string", + }, + "answer": { + "type": "string", + } + } + } + } + } + + tool_response = llm_chat_with_json_output( + agent_name=self.agent_name, + messages=self.messages, + base_url=aios_kernel_url, + llms=[{ + "name": "qwen2.5:7b", + "backend": "ollama" + }], + response_format=response_format + ) + + return tool_response["response"].get("response_message", "") + + +class TestLLMAPI(unittest.TestCase): + """ + Unit tests for OpenAI's API using the TestAgent class. + Each test checks if the API returns a non-empty string response. + Here has various category test cases, including greeting, math question, science question, history question, technology question. + """ + def setUp(self): + self.agent = TestAgent("test_agent") + + def assert_chat_response(self, response): + """Helper method to validate common response conditions.""" + self.assertIsInstance(response, str) + self.assertGreater(len(response), 0) + + def assert_json_output_response(self, response): + """Helper method to validate common response conditions.""" + parsed_response = _parse_json_output(response) + self.assertIsInstance(parsed_response, dict) + self.assertIn("thinking", parsed_response) + self.assertIn("answer", parsed_response) + self.assertIsInstance(parsed_response["thinking"], str) + self.assertIsInstance(parsed_response["answer"], str) + + def test_agent_with_greeting(self): + chat_response = self.agent.llm_chat_run("Hello, how are you?") + self.assert_chat_response(chat_response) + # json_output_response = self.agent.llm_chat_with_json_output_run("Hello, how are you? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + # self.assert_json_output_response(json_output_response) + + def test_agent_with_math_question(self): + chat_response = self.agent.llm_chat_run("What is 25 times 4?") + self.assert_chat_response(chat_response) + # json_output_response = self.agent.llm_chat_with_json_output_run("What is 25 times 4? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + # self.assert_json_output_response(json_output_response) + + + def test_agent_with_history_question(self): + chat_response = self.agent.llm_chat_run("Who was the first president of the United States?") + self.assert_chat_response(chat_response) + # json_output_response = self.agent.llm_chat_with_json_output_run("Who was the first president of the United States? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + # self.assert_json_output_response(json_output_response) + + def test_agent_with_technology_question(self): + chat_response = self.agent.llm_chat_run("What is quantum computing?") + self.assert_chat_response(chat_response) + # json_output_response = self.agent.llm_chat_with_json_output_run("What is quantum computing? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + # self.assert_json_output_response(json_output_response) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/modules/llm/openai/test_concurrent.py b/tests/modules/llm/openai/test_concurrent.py new file mode 100644 index 0000000..435c311 --- /dev/null +++ b/tests/modules/llm/openai/test_concurrent.py @@ -0,0 +1,161 @@ +# test_client.py +import unittest +import requests +import threading +import json +import time +from typing import List, Dict, Any, Tuple + +from cerebrum.llm.apis import llm_chat + +from cerebrum.utils.communication import aios_kernel_url + +# --- Helper function to send a single request --- +def send_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, float]: + """Sends a POST request to the query endpoint and returns status code, response JSON, and duration.""" + start_time = time.time() + try: + response = llm_chat( + agent_name=payload["agent_name"], + messages=payload["query_data"]["messages"], + llms=payload["query_data"]["llms"] if "llms" in payload["query_data"] else None, + ) + end_time = time.time() + duration = end_time - start_time + return response["response"], duration + + except Exception as e: + end_time = time.time() + duration = end_time - start_time + # General unexpected errors + return {"error": f"An unexpected error occurred: {e}"}, duration + + +class TestConcurrentLLMQueries(unittest.TestCase): + + def _run_concurrent_requests(self, payloads: List[Dict[str, Any]]): + results = [None] * len(payloads) + threads = [] + + def worker(index, payload): + response_data, duration = send_request(payload) + results[index] = {"data": response_data, "duration": duration} + print(f"Thread {index}: Completed in {duration:.2f}s with response: {json.dumps(response_data)}") + + print(f"\n--- Running test: {self._testMethodName} ---") + print(f"Sending {len(payloads)} concurrent requests to {aios_kernel_url}...") + for i, payload in enumerate(payloads): + thread = threading.Thread(target=worker, args=(i, payload)) + threads.append(thread) + thread.start() + + for i, thread in enumerate(threads): + thread.join() + print(f"Thread {i} finished.") + + print("--- All threads completed ---") + return results + + def test_both_llms(self): + payload1 = { + "agent_name": "test_agent_1", + "query_type": "llm", + "query_data": { + "llms": [{"name": "gpt-4o", "backend": "openai"}], + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "action_type": "chat", + "message_return_type": "text", + } + } + payload2 = { + "agent_name": "test_agent_2", + "query_type": "llm", + "query_data": { + "llms": [{"name": "gpt-4o-mini", "backend": "openai"}], # Using a different model for variety + "messages": [{"role": "user", "content": "What is the capital of the United States?"}], + "action_type": "chat", + "message_return_type": "text", + } + } + results = self._run_concurrent_requests([payload1, payload2]) + + for i, result in enumerate(results): + print(f"Result {i} (No LLM): {result}") + # Both should succeed using defaults + status, response_message, error_message, finished = result["data"]["status_code"], result["data"]["response_message"], result["data"]["error"], result["data"]["finished"] + + self.assertEqual(status, 200, f"Request {i} (No LLM) should succeed, but failed with status {status}") + self.assertIsNone(error_message, f"Request {i} (No LLM) returned an unexpected error: {error_message}") + self.assertIsInstance(response_message, str, f"Request {i} (No LLM) result is not a string") + self.assertTrue(finished, f"Request {i} (No LLM) result is empty") # Check not empty + + def test_one_llm_one_empty(self): + payload_llm = { + "agent_name": "test_agent_1", + "query_type": "llm", + "query_data": { + "llms": [{"name": "gpt-4o", "backend": "openai"}], + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "action_type": "chat", + "message_return_type": "text", + } + } + payload_no_llm = { + "agent_name": "test_agent_2", + "query_type": "llm", + "query_data": { + # 'llms' key is omitted entirely + "messages": [{"role": "user", "content": "What is the capital of the United States?"}], + "action_type": "chat", + "message_return_type": "text", + } + } + results = self._run_concurrent_requests([payload_llm, payload_no_llm]) + + for i, result in enumerate(results): + print(f"Result {i} (No LLM): {result}") + # Both should succeed using defaults + status, response_message, error_message, finished = result["data"]["status_code"], result["data"]["response_message"], result["data"]["error"], result["data"]["finished"] + + self.assertEqual(status, 200, f"Request {i} (No LLM) should succeed, but failed with status {status}") + self.assertIsNone(error_message, f"Request {i} (No LLM) returned an unexpected error: {error_message}") + self.assertIsInstance(response_message, str, f"Request {i} (No LLM) result is not a string") + self.assertTrue(finished, f"Request {i} (No LLM) result is empty") # Check not empty + + def test_no_llms(self): + """Case 2: Both payloads have no LLMs defined. Should succeed using defaults.""" + payload1 = { + "agent_name": "test_agent_1", + "query_type": "llm", + "query_data": { + # 'llms' key is omitted + "messages": [{"role": "user", "content": "What is the capital of France?"}], + "action_type": "chat", + "message_return_type": "text", + } + } + payload2 = { + "agent_name": "test_agent_2", + "query_type": "llm", + "query_data": { + # 'llms' key is omitted + "messages": [{"role": "user", "content": "What is the capital of the United States?"}], + "action_type": "chat", + "message_return_type": "text", + } + } + results = self._run_concurrent_requests([payload1, payload2]) + + for i, result in enumerate(results): + print(f"Result {i} (No LLM): {result}") + # Both should succeed using defaults + status, response_message, error_message, finished = result["data"]["status_code"], result["data"]["response_message"], result["data"]["error"], result["data"]["finished"] + + self.assertEqual(status, 200, f"Request {i} (No LLM) should succeed, but failed with status {status}") + self.assertIsNone(error_message, f"Request {i} (No LLM) returned an unexpected error: {error_message}") + self.assertIsInstance(response_message, str, f"Request {i} (No LLM) result is not a string") + self.assertTrue(finished, f"Request {i} (No LLM) result is empty") # Check not empty + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/modules/llm/openai/test_single.py b/tests/modules/llm/openai/test_single.py new file mode 100644 index 0000000..34a1bd6 --- /dev/null +++ b/tests/modules/llm/openai/test_single.py @@ -0,0 +1,117 @@ +import unittest +from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output +from cerebrum.utils.communication import aios_kernel_url +from cerebrum.utils.utils import _parse_json_output + +class TestAgent: + """ + TestAgent class is responsible for interacting with OpenAI's API using ChatCompletion. + It maintains a conversation history to simulate real dialogue behavior. + """ + def __init__(self, agent_name): + self.agent_name = agent_name + self.messages = [] + + def llm_chat_run(self, task_input): + """Sends the input to the OpenAI API and returns the response.""" + self.messages.append({"role": "user", "content": task_input}) + + tool_response = llm_chat( + agent_name=self.agent_name, + messages=self.messages, + base_url=aios_kernel_url, + llms=[{ + "name": "gpt-4o-mini", + "backend": "openai" + }] + ) + + return tool_response["response"].get("response_message", "") + + def llm_chat_with_json_output_run(self, task_input): + """Sends the input to the OpenAI API and returns the response.""" + self.messages.append({"role": "user", "content": task_input}) + + response_format = { + "type": "json_schema", + "json_schema": { + "name": "thinking", + "schema": { + "type": "object", + "properties": { + "thinking": { + "type": "string", + }, + "answer": { + "type": "string", + } + } + } + } + } + + tool_response = llm_chat_with_json_output( + agent_name=self.agent_name, + messages=self.messages, + base_url=aios_kernel_url, + llms=[{ + "name": "gpt-4o-mini", + "backend": "openai" + }], + response_format=response_format + ) + + return tool_response["response"].get("response_message", "") + + +class TestLLMAPI(unittest.TestCase): + """ + Unit tests for OpenAI's API using the TestAgent class. + Each test checks if the API returns a non-empty string response. + Here has various category test cases, including greeting, math question, science question, history question, technology question. + """ + def setUp(self): + self.agent = TestAgent("test_agent") + + def assert_chat_response(self, response): + """Helper method to validate common response conditions.""" + self.assertIsInstance(response, str) + self.assertGreater(len(response), 0) + + def assert_json_output_response(self, response): + """Helper method to validate common response conditions.""" + parsed_response = _parse_json_output(response) + self.assertIsInstance(parsed_response, dict) + self.assertIn("thinking", parsed_response) + self.assertIn("answer", parsed_response) + self.assertIsInstance(parsed_response["thinking"], str) + self.assertIsInstance(parsed_response["answer"], str) + + def test_agent_with_greeting(self): + chat_response = self.agent.llm_chat_run("Hello, how are you?") + self.assert_chat_response(chat_response) + json_output_response = self.agent.llm_chat_with_json_output_run("Hello, how are you? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + self.assert_json_output_response(json_output_response) + + def test_agent_with_math_question(self): + chat_response = self.agent.llm_chat_run("What is 25 times 4?") + self.assert_chat_response(chat_response) + json_output_response = self.agent.llm_chat_with_json_output_run("What is 25 times 4? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + self.assert_json_output_response(json_output_response) + + + def test_agent_with_history_question(self): + chat_response = self.agent.llm_chat_run("Who was the first president of the United States?") + self.assert_chat_response(chat_response) + json_output_response = self.agent.llm_chat_with_json_output_run("Who was the first president of the United States? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + self.assert_json_output_response(json_output_response) + + def test_agent_with_technology_question(self): + chat_response = self.agent.llm_chat_run("What is quantum computing?") + self.assert_chat_response(chat_response) + json_output_response = self.agent.llm_chat_with_json_output_run("What is quantum computing? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.") + self.assert_json_output_response(json_output_response) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/modules/llm/test_ollama.py b/tests/modules/llm/test_ollama.py deleted file mode 100644 index 9073e90..0000000 --- a/tests/modules/llm/test_ollama.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest -from cerebrum.llm.apis import llm_chat -from cerebrum.utils.communication import aios_kernel_url - -class TestAgent: - """ - TestAgent class is responsible for interacting with the LLM API using llm_chat. - It maintains a conversation history to simulate real dialogue behavior. - """ - def __init__(self, agent_name): - self.agent_name = agent_name - self.messages = [] - - def run(self, task_input): - """Sends the input to the LLM API and returns the response.""" - self.messages.append({"role": "user", "content": task_input}) - - tool_response = llm_chat( - agent_name=self.agent_name, - messages=self.messages, - base_url=aios_kernel_url - ) - - return tool_response["response"].get("response_message", "") - - -class TestLLMAPI(unittest.TestCase): - """ - Unit tests for the LLM API using the TestAgent class. - Each test checks if the API returns a non-empty string response. - """ - def setUp(self): - self.agent = TestAgent("test_agent") - - def assert_valid_response(self, response): - """Helper method to validate common response conditions.""" - self.assertIsInstance(response, str) - self.assertGreater(len(response), 0) - - def test_agent_with_greeting(self): - response = self.agent.run("Hello, how are you?") - self.assert_valid_response(response) - - def test_agent_with_math_question(self): - response = self.agent.run("What is 25 times 4?") - self.assert_valid_response(response) - - def test_agent_with_science_question(self): - response = self.agent.run("Explain the theory of relativity.") - self.assert_valid_response(response) - - def test_agent_with_history_question(self): - response = self.agent.run("Who was the first president of the United States?") - self.assert_valid_response(response) - - def test_agent_with_technology_question(self): - response = self.agent.run("What is quantum computing?") - self.assert_valid_response(response) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/modules/llm/test_openai.py b/tests/modules/llm/test_openai.py deleted file mode 100644 index 5f0b737..0000000 --- a/tests/modules/llm/test_openai.py +++ /dev/null @@ -1,112 +0,0 @@ -import unittest -from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output -from cerebrum.utils.communication import aios_kernel_url - -class TestAgent: - """ - TestAgent class is responsible for interacting with OpenAI's API using ChatCompletion. - It maintains a conversation history to simulate real dialogue behavior. - """ - def __init__(self, agent_name, api_key): - self.agent_name = agent_name - self.api_key = api_key - self.messages = [] - - def llm_chat_run(self, task_input): - """Sends the input to the OpenAI API and returns the response.""" - self.messages.append({"role": "user", "content": task_input}) - - tool_response = llm_chat( - agent_name=self.agent_name, - messages=self.messages, - base_url=aios_kernel_url, - llms=[{ - "name": "gpt-4o-mini", - "backend": "openai" - }] - ) - - return tool_response["response"].get("response_message", "") - - def llm_chat_with_json_output_run(self, task_input): - """Sends the input to the OpenAI API and returns the response.""" - self.messages.append({"role": "user", "content": task_input}) - - response_format = { - "type": "json_schema", - "json_schema": { - "name": "keywords", - "schema": { - "type": "object", - "properties": { - "keywords": { - "type": "array", - "items": {"type": "string"} - } - } - } - } - } - - tool_response = llm_chat_with_json_output( - agent_name=self.agent_name, - messages=self.messages, - base_url=aios_kernel_url, - llms=[{ - "name": "gpt-4o-mini", - "backend": "openai" - }], - response_format=response_format - ) - - return tool_response["response"].get("response_message", "") - - -class TestLLMAPI(unittest.TestCase): - """ - Unit tests for OpenAI's API using the TestAgent class. - Each test checks if the API returns a non-empty string response. - Here has various category test cases, including greeting, math question, science question, history question, technology question. - """ - def setUp(self): - self.api_key = "your-openai-api-key" # Replace with your actual OpenAI API key - self.agent = TestAgent("test_agent", self.api_key) - - def assert_valid_response(self, response): - """Helper method to validate common response conditions.""" - self.assertIsInstance(response, str) - self.assertGreater(len(response), 0) - - def test_agent_with_greeting(self): - response = self.agent.llm_chat_run("Hello, how are you?") - self.assert_valid_response(response) - response = self.agent.llm_chat_with_json_output_run("Hello, how are you?") - self.assert_valid_response(response) - - def test_agent_with_math_question(self): - response = self.agent.llm_chat_run("What is 25 times 4?") - self.assert_valid_response(response) - response = self.agent.llm_chat_with_json_output_run("What is 25 times 4?") - self.assert_valid_response(response) - - def test_agent_with_science_question(self): - response = self.agent.llm_chat_run("Explain the theory of relativity.") - self.assert_valid_response(response) - response = self.agent.llm_chat_with_json_output_run("Explain the theory of relativity.") - self.assert_valid_response(response) - - def test_agent_with_history_question(self): - response = self.agent.llm_chat_run("Who was the first president of the United States?") - self.assert_valid_response(response) - response = self.agent.llm_chat_with_json_output_run("Who was the first president of the United States?") - self.assert_valid_response(response) - - def test_agent_with_technology_question(self): - response = self.agent.llm_chat_run("What is quantum computing?") - self.assert_valid_response(response) - response = self.agent.llm_chat_with_json_output_run("What is quantum computing?") - self.assert_valid_response(response) - - -if __name__ == "__main__": - unittest.main()