refactor llm routing and update tests for routers (#491)
* update * update * update launch * add more error messages for adapter * fix smart routing * update config * update * update test * update test * update test * update workflow * update * update * update * update * update * update * update * update
This commit is contained in:
21
.github/workflows/cancel-workflow.yml
vendored
Normal file
21
.github/workflows/cancel-workflow.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Cancel PR Workflows on Merge
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [closed] # runs when the PR is closed (merged or not)
|
||||
|
||||
permissions:
|
||||
actions: write # lets the workflow cancel other runs
|
||||
|
||||
jobs:
|
||||
cancel:
|
||||
if: github.event.pull_request.merged == true
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Cancel previous PR runs
|
||||
uses: styfle/cancel-workflow-action@v1
|
||||
with:
|
||||
workflow_id: all # cancel every workflow in this repo
|
||||
access_token: ${{ github.token }} # default token is fine
|
||||
pr_number: ${{ github.event.pull_request.number }}
|
||||
ignore_sha: true # required for pull_request.closed events
|
||||
90
.github/workflows/test-ollama.yml
vendored
Normal file
90
.github/workflows/test-ollama.yml
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
# .github/workflows/test-ollama.yml
|
||||
name: Test Ollama
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
actions: write
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
OLLAMA_MAX_WAIT: 60
|
||||
KERNEL_MAX_WAIT: 60
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
|
||||
- name: Clone Cerebrum
|
||||
uses: sudosubin/git-clone-action@v1.0.1
|
||||
with:
|
||||
repository: agiresearch/Cerebrum
|
||||
path: Cerebrum
|
||||
|
||||
- name: Install Cerebrum (editable)
|
||||
run: python -m pip install -e Cerebrum/
|
||||
|
||||
- name: Install Ollama
|
||||
run: curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
- name: Pull Ollama model
|
||||
run: ollama pull qwen2.5:7b
|
||||
|
||||
- name: Start Ollama
|
||||
run: |
|
||||
ollama serve > ollama-llm.log 2>&1 &
|
||||
for ((i=1;i<=${OLLAMA_MAX_WAIT};i++)); do
|
||||
curl -s http://localhost:11434/api/version && echo "Ollama ready" && break
|
||||
echo "Waiting for Ollama… ($i/${OLLAMA_MAX_WAIT})"; sleep 1
|
||||
done
|
||||
curl -s http://localhost:11434/api/version > /dev/null \
|
||||
|| { echo "❌ Ollama failed to start"; exit 1; }
|
||||
|
||||
- name: Start AIOS kernel
|
||||
run: |
|
||||
bash runtime/launch_kernel.sh > kernel.log 2>&1 &
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
mkdir -p test_results
|
||||
mapfile -t TESTS < <(find . -type f -path "*/llm/ollama/*" -name "*.py")
|
||||
if [ "${#TESTS[@]}" -eq 0 ]; then
|
||||
echo "⚠️ No llm/ollama tests found – skipping."
|
||||
exit 0
|
||||
fi
|
||||
for t in "${TESTS[@]}"; do
|
||||
echo "▶️ Running $t"
|
||||
python "$t" | tee -a test_results/ollama_tests.log
|
||||
echo "----------------------------------------"
|
||||
done
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: logs
|
||||
path: |
|
||||
ollama-llm.log
|
||||
kernel.log
|
||||
test_results/
|
||||
156
.github/workflows/test_ollama.yml
vendored
156
.github/workflows/test_ollama.yml
vendored
@@ -1,156 +0,0 @@
|
||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
||||
|
||||
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
||||
|
||||
name: AIOS Application
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ "main" ]
|
||||
pull_request:
|
||||
branches: [ "main" ]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v3
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
||||
|
||||
- name: Git Clone Action
|
||||
# You may pin to the exact commit or the version.
|
||||
# uses: sudosubin/git-clone-action@8a93ce24d47782e30077508cccacf8a05a891bae
|
||||
uses: sudosubin/git-clone-action@v1.0.1
|
||||
with:
|
||||
# Repository owner and name. Ex: sudosubin/git-clone-action
|
||||
repository: agiresearch/Cerebrum
|
||||
path: Cerebrum
|
||||
|
||||
- name: Install cerebrum special edition
|
||||
run: |
|
||||
python -m pip install -e Cerebrum/
|
||||
|
||||
- name: Download and install Ollama
|
||||
run: |
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
- name: Pull Ollama models
|
||||
run: |
|
||||
ollama pull qwen2.5:7b
|
||||
|
||||
- name: Run Ollama serve
|
||||
run: |
|
||||
ollama serve 2>&1 | tee ollama-llm.log &
|
||||
# Wait for ollama to start
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:11434/api/version > /dev/null; then
|
||||
echo "Ollama is running"
|
||||
break
|
||||
fi
|
||||
echo "Waiting for ollama to start... ($i/30)"
|
||||
sleep 1
|
||||
done
|
||||
# Verify ollama is running
|
||||
curl -s http://localhost:11434/api/version || (echo "Failed to start ollama" && exit 1)
|
||||
|
||||
- name: Run AIOS kernel in background
|
||||
run: |
|
||||
bash runtime/launch_kernel.sh &>logs &
|
||||
KERNEL_PID=$!
|
||||
|
||||
# Set maximum wait time (60 seconds)
|
||||
max_wait=60
|
||||
start_time=$SECONDS
|
||||
|
||||
# Dynamically check if the process is running until it succeeds or times out
|
||||
while true; do
|
||||
if ! ps -p $KERNEL_PID > /dev/null; then
|
||||
echo "Kernel process died. Checking logs:"
|
||||
cat logs
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if nc -z localhost 8000; then
|
||||
if curl -s http://localhost:8000/health > /dev/null; then
|
||||
echo "Kernel successfully started and healthy"
|
||||
break
|
||||
fi
|
||||
fi
|
||||
|
||||
# Check if timed out
|
||||
elapsed=$((SECONDS - start_time))
|
||||
if [ $elapsed -ge $max_wait ]; then
|
||||
echo "Timeout after ${max_wait} seconds. Kernel failed to start properly."
|
||||
cat logs
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Waiting for kernel to start... (${elapsed}s elapsed)"
|
||||
sleep 1
|
||||
done
|
||||
|
||||
- name: Run all tests
|
||||
run: |
|
||||
# Create test results directory
|
||||
mkdir -p test_results
|
||||
|
||||
# Function to check if a path contains agent or llm
|
||||
contains_agent_or_llm() {
|
||||
local dir_path=$(dirname "$1")
|
||||
if [[ "$dir_path" == *"agent"* || "$dir_path" == *"llm"* ]]; then
|
||||
return 0 # True in bash
|
||||
else
|
||||
return 1 # False in bash
|
||||
fi
|
||||
}
|
||||
|
||||
# Process test files
|
||||
find tests -type f -name "*.py" | while read -r test_file; do
|
||||
if contains_agent_or_llm "$test_file"; then
|
||||
# For agent or llm directories, only run ollama tests
|
||||
if [[ "$test_file" == *"ollama"* ]]; then
|
||||
echo "Running Ollama test in agent/llm directory: $test_file"
|
||||
python $test_file | tee -a ollama_tests.log
|
||||
echo "----------------------------------------"
|
||||
fi
|
||||
else
|
||||
# For other directories, run all tests
|
||||
echo "Running test: $test_file"
|
||||
python $test_file | tee -a all_tests.log
|
||||
echo "----------------------------------------"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Upload a Build Artifact
|
||||
if: always() # Upload logs even if job fails
|
||||
uses: actions/upload-artifact@v4.4.3
|
||||
with:
|
||||
name: logs
|
||||
path: |
|
||||
logs
|
||||
agent.log
|
||||
|
||||
- name: Collect debug information
|
||||
if: failure()
|
||||
run: |
|
||||
echo "=== Kernel Logs ==="
|
||||
cat logs
|
||||
echo "=== Environment Variables ==="
|
||||
env | grep -i api_key || true
|
||||
echo "=== Process Status ==="
|
||||
ps aux | grep kernel
|
||||
@@ -18,9 +18,12 @@ llms:
|
||||
# - name: "gpt-4o-mini"
|
||||
# backend: "openai"
|
||||
|
||||
# - name: "gpt-4o"
|
||||
# backend: "openai"
|
||||
|
||||
|
||||
# Google Models
|
||||
# - name: "gemini-1.5-flash"
|
||||
# - name: "gemini-2.0-flash"
|
||||
# backend: "google"
|
||||
|
||||
|
||||
@@ -47,26 +50,27 @@ llms:
|
||||
# backend: "vllm"
|
||||
# hostname: "http://localhost:8091"
|
||||
|
||||
router:
|
||||
strategy: "sequential"
|
||||
bootstrap_url: "https://drive.google.com/file/d/1SF7MAvtnsED7KMeMdW3JDIWYNGPwIwL7/view"
|
||||
|
||||
|
||||
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
use_context_manager: false
|
||||
|
||||
memory:
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
|
||||
storage:
|
||||
root_dir: "root"
|
||||
use_vector_db: true
|
||||
|
||||
scheduler:
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
|
||||
agent_factory:
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
max_workers: 64
|
||||
|
||||
server:
|
||||
host: "localhost"
|
||||
port: 8000
|
||||
port: 8000
|
||||
|
||||
@@ -6,7 +6,6 @@ api_keys:
|
||||
gemini: "" # Google Gemini API key
|
||||
groq: "" # Groq API key
|
||||
anthropic: "" # Anthropic API key
|
||||
novita: "" # Novita AI API key
|
||||
huggingface:
|
||||
auth_token: "" # Your HuggingFace auth token for authorized models
|
||||
cache_dir: "" # Your cache directory for saving huggingface models
|
||||
@@ -16,15 +15,15 @@ llms:
|
||||
models:
|
||||
|
||||
# OpenAI Models
|
||||
- name: "gpt-4o"
|
||||
backend: "openai"
|
||||
# - name: "gpt-4o"
|
||||
# backend: "openai"
|
||||
|
||||
- name: "gpt-4o-mini"
|
||||
backend: "openai"
|
||||
# - name: "gpt-4o-mini"
|
||||
# backend: "openai"
|
||||
|
||||
# Google Models
|
||||
- name: "gemini-2.0-flash"
|
||||
backend: "google"
|
||||
# - name: "gemini-2.0-flash"
|
||||
# backend: "google"
|
||||
|
||||
|
||||
# Anthropic Models
|
||||
@@ -32,9 +31,9 @@ llms:
|
||||
# backend: "anthropic"
|
||||
|
||||
# Ollama Models
|
||||
- name: "qwen2.5:72b"
|
||||
- name: "qwen2.5:7b"
|
||||
backend: "ollama"
|
||||
hostname: "http://localhost:8091" # Make sure to run ollama server
|
||||
hostname: "http://localhost:11434" # Make sure to run ollama server
|
||||
|
||||
# HuggingFace Models
|
||||
# - name: "meta-llama/Llama-3.1-8B-Instruct"
|
||||
@@ -43,7 +42,7 @@ llms:
|
||||
# eval_device: "cuda:0" # Device for model evaluation
|
||||
|
||||
# vLLM Models
|
||||
# To use vllm as backend, you need to install vllm and run the vllm server: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
||||
# To use vllm as backend, you need to install vllm and run the vllm server https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
|
||||
# An example command to run the vllm server is:
|
||||
# vllm serve meta-llama/Llama-3.2-3B-Instruct --port 8091
|
||||
# - name: "meta-llama/Llama-3.1-8B-Instruct"
|
||||
@@ -51,34 +50,33 @@ llms:
|
||||
# hostname: "http://localhost:8091"
|
||||
|
||||
# SGLang Models
|
||||
# To use sglang as backend, you need to install sglang and run the sglang server: https://docs.sglang.ai/backend/openai_api_completions.html
|
||||
# python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-72B-Instruct --grammar-backend outlines --tool-call-parser qwen25 --host 127.0.0.1 --port 30001 --tp 4 --disable-custom-all-reduce
|
||||
# python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --tool-call-parser qwen25 --host 127.0.0.1 --port 30001
|
||||
|
||||
- name: "Qwen/Qwen2.5-72B-Instruct"
|
||||
backend: "sglang"
|
||||
hostname: "http://localhost:30001/v1"
|
||||
# - name: "Qwen/Qwen2.5-VL-72B-Instruct"
|
||||
# backend: "sglang"
|
||||
# hostname: "http://localhost:30001/v1"
|
||||
|
||||
# Novita Models
|
||||
- name: "meta-llama/llama-4-scout-17b-16e-instruct"
|
||||
backend: "novita"
|
||||
router:
|
||||
strategy: "smart" # choose from [sequential, smart]
|
||||
bootstrap_url: "https://drive.google.com/file/d/1SF7MAvtnsED7KMeMdW3JDIWYNGPwIwL7/view"
|
||||
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
use_context_manager: false
|
||||
|
||||
memory:
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
|
||||
storage:
|
||||
root_dir: "root"
|
||||
use_vector_db: true
|
||||
|
||||
scheduler:
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
|
||||
agent_factory:
|
||||
log_mode: "console"
|
||||
log_mode: "console" # choose from [console, file]
|
||||
max_workers: 64
|
||||
|
||||
server:
|
||||
host: "localhost"
|
||||
host: "0.0.0.0"
|
||||
port: 8000
|
||||
|
||||
@@ -162,6 +162,15 @@ class ConfigManager:
|
||||
"""
|
||||
return self.config.get('llms', {})
|
||||
|
||||
def get_router_config(self) -> dict:
|
||||
"""
|
||||
Retrieves the router configuration settings.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing router configurations
|
||||
"""
|
||||
return self.config.get("llms", {}).get("router", {})
|
||||
|
||||
def get_storage_config(self) -> dict:
|
||||
"""
|
||||
Retrieves the storage configuration settings.
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,11 +0,0 @@
|
||||
[
|
||||
{"name": "Qwen-2.5-7B-Instruct", "hosted_model": "ollama/qwen2.5:7b", "cost_per_input_token": 0.000000267, "cost_per_output_token": 0.000000267},
|
||||
{"name": "Qwen-2.5-14B-Instruct", "hosted_model": "ollama/qwen2.5:14b", "cost_per_input_token": 0.000000534, "cost_per_output_token": 0.000000534},
|
||||
{"name": "Qwen-2.5-32B-Instruct", "hosted_model": "ollama/qwen2.5:32b", "cost_per_input_token": 0.00000122, "cost_per_output_token": 0.00000122},
|
||||
{"name": "Llama-3.1-8B-Instruct", "hosted_model": "ollama/llama3.1:8b", "cost_per_input_token": 0.000000305, "cost_per_output_token": 0.000000305},
|
||||
{"name": "Deepseek-r1-7b", "hosted_model": "ollama/deepseek-r1:7b", "cost_per_input_token": 0.000000267, "cost_per_output_token": 0.000000267},
|
||||
{"name": "Deepseek-r1-14b", "hosted_model": "ollama/deepseek-r1:14b", "cost_per_input_token": 0.000000534, "cost_per_output_token": 0.000000534},
|
||||
{"name": "gpt-4o-mini", "hosted_model": "gpt-4o-mini", "cost_per_input_token": 0.00000015, "cost_per_output_token": 0.0000006},
|
||||
{"name": "gpt-4o", "hosted_model": "gpt-4o", "cost_per_input_token": 0.0000025, "cost_per_output_token": 0.00001},
|
||||
{"name": "gemini-1.5-flash", "hosted_model": "gemini/gemini-1.5-flash", "cost_per_input_token": 0.000000075, "cost_per_output_token": 0.0000003}
|
||||
]
|
||||
@@ -12,6 +12,8 @@ import json
|
||||
|
||||
from threading import Lock
|
||||
|
||||
import openai
|
||||
|
||||
import os
|
||||
|
||||
from pulp import (
|
||||
@@ -23,6 +25,13 @@ from pulp import (
|
||||
value
|
||||
)
|
||||
|
||||
import litellm
|
||||
|
||||
import tempfile
|
||||
import gdown
|
||||
|
||||
from litellm import token_counter
|
||||
|
||||
"""
|
||||
Load balancing strategies. Each class represents a strategy which returns the
|
||||
next endpoint that the router should use.
|
||||
@@ -37,9 +46,9 @@ used whenever the strategy is called in __call__, and then return the name of
|
||||
the specific LLM endpoint.
|
||||
"""
|
||||
|
||||
class RouterStrategy(Enum):
|
||||
Sequential = 0,
|
||||
Smart = 1
|
||||
class RouterStrategy:
|
||||
Sequential = "sequential"
|
||||
Smart = "smart"
|
||||
|
||||
class SequentialRouting:
|
||||
"""
|
||||
@@ -75,13 +84,13 @@ class SequentialRouting:
|
||||
# def __call__(self):
|
||||
# return self.get_model()
|
||||
|
||||
def get_model_idxs(self, selected_llms: List[str], n_queries: int=1):
|
||||
def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries: List[List[Dict[str, Any]]]):
|
||||
"""
|
||||
Selects model indices from the available LLM configurations using a round-robin strategy.
|
||||
|
||||
Args:
|
||||
selected_llms (List[str]): A list of selected LLM names from which models will be chosen.
|
||||
n_queries (int): The number of queries to distribute among the selected models. Defaults to 1.
|
||||
selected_llm_lists (List[List[str]]): A list of selected LLM names from which models will be chosen.
|
||||
queries (List[List[Dict[str, Any]]]): A list of queries to distribute among the selected models.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of indices corresponding to the selected models in `self.llm_configs`.
|
||||
@@ -90,477 +99,304 @@ class SequentialRouting:
|
||||
# current = self.selected_llms[self.idx]
|
||||
model_idxs = []
|
||||
|
||||
# breakpoint()
|
||||
available_models = [llm.name for llm in self.llm_configs]
|
||||
|
||||
for _ in range(n_queries):
|
||||
current = selected_llms[self.idx]
|
||||
for i, llm_config in enumerate(self.llm_configs):
|
||||
# breakpoint()
|
||||
if llm_config["name"] == current["name"]:
|
||||
model_idxs.append(i)
|
||||
n_queries = len(queries)
|
||||
|
||||
for i in range(n_queries):
|
||||
selected_llm_list = selected_llm_lists[i]
|
||||
|
||||
if not selected_llm_list or len(selected_llm_list) == 0:
|
||||
model_idxs.append(0)
|
||||
continue
|
||||
|
||||
model_idx = -1
|
||||
for selected_llm in selected_llm_list:
|
||||
if selected_llm["name"] in available_models:
|
||||
model_idx = available_models.index(selected_llm["name"])
|
||||
break
|
||||
self.idx = (self.idx + 1) % len(selected_llms)
|
||||
|
||||
# breakpoint()
|
||||
|
||||
model_idxs.append(model_idx)
|
||||
|
||||
return model_idxs
|
||||
|
||||
bucket_size = max_output_length / num_buckets
|
||||
|
||||
total_queries = len(test_data)
|
||||
model_metrics = defaultdict(lambda: {
|
||||
"correct_predictions": 0,
|
||||
"total_predictions": 0,
|
||||
"correct_length_predictions": 0,
|
||||
"total_length_predictions": 0,
|
||||
"correct_bucket_predictions": 0,
|
||||
})
|
||||
|
||||
# Initialize metrics for each model
|
||||
model_metrics = defaultdict(lambda: {
|
||||
"correct_predictions": 0,
|
||||
"total_predictions": 0,
|
||||
"correct_length_predictions": 0,
|
||||
"total_length_predictions": 0,
|
||||
"correct_bucket_predictions": 0
|
||||
})
|
||||
|
||||
for test_item in tqdm(test_data, desc="Evaluating"):
|
||||
similar_results = self.query_similar(
|
||||
test_item["query"],
|
||||
"train",
|
||||
n_results=n_similar
|
||||
)
|
||||
|
||||
model_stats = defaultdict(lambda: {
|
||||
"total_length": 0,
|
||||
"count": 0,
|
||||
"correct_count": 0
|
||||
})
|
||||
|
||||
|
||||
for metadata in similar_results["metadatas"][0]:
|
||||
for model_info in json.loads(metadata["models"]):
|
||||
model_name = model_info["model_name"]
|
||||
stats = model_stats[model_name]
|
||||
stats["total_length"] += model_info["output_token_length"]
|
||||
stats["count"] += 1
|
||||
stats["correct_count"] += int(model_info["correctness"])
|
||||
|
||||
for model_output in test_item["outputs"]:
|
||||
model_name = model_output["model_name"]
|
||||
if model_name in model_stats:
|
||||
stats = model_stats[model_name]
|
||||
|
||||
if stats["count"] > 0:
|
||||
predicted_correctness = stats["correct_count"] / stats["count"] >= 0.5
|
||||
actual_correctness = model_output["correctness"]
|
||||
|
||||
if predicted_correctness == actual_correctness:
|
||||
model_metrics[model_name]["correct_predictions"] += 1
|
||||
|
||||
predicted_length = stats["total_length"] / stats["count"]
|
||||
|
||||
predicted_bucket = min(int(predicted_length / bucket_size), num_buckets - 1)
|
||||
|
||||
actual_length = model_output["output_token_length"]
|
||||
|
||||
# length_error = abs(predicted_length - actual_length)
|
||||
actual_bucket = min(int(actual_length / bucket_size), num_buckets - 1)
|
||||
|
||||
if predicted_bucket == actual_bucket:
|
||||
model_metrics[model_name]["correct_length_predictions"] += 1
|
||||
|
||||
if abs(predicted_bucket - actual_bucket) <= 1:
|
||||
model_metrics[model_name]["correct_bucket_predictions"] += 1
|
||||
|
||||
model_metrics[model_name]["total_predictions"] += 1
|
||||
model_metrics[model_name]["total_length_predictions"] += 1
|
||||
|
||||
results = {}
|
||||
|
||||
# Calculate overall accuracy across all models
|
||||
total_correct_predictions = sum(metrics["correct_predictions"] for metrics in model_metrics.values())
|
||||
total_correct_length_predictions = sum(metrics["correct_length_predictions"] for metrics in model_metrics.values())
|
||||
total_correct_bucket_predictions = sum(metrics["correct_bucket_predictions"] for metrics in model_metrics.values())
|
||||
total_predictions = sum(metrics["total_predictions"] for metrics in model_metrics.values())
|
||||
|
||||
if total_predictions > 0:
|
||||
results["overall"] = {
|
||||
"performance_accuracy": total_correct_predictions / total_predictions,
|
||||
"length_accuracy": total_correct_length_predictions / total_predictions,
|
||||
"bucket_accuracy": total_correct_bucket_predictions / total_predictions
|
||||
}
|
||||
for model_name, metrics in model_metrics.items():
|
||||
total = metrics["total_predictions"]
|
||||
if total > 0:
|
||||
results[model_name] = {
|
||||
"performance_accuracy": metrics["correct_predictions"] / total,
|
||||
"length_accuracy": metrics["correct_length_predictions"] / total,
|
||||
"bucket_accuracy": metrics["correct_bucket_predictions"] / total
|
||||
}
|
||||
results["overall"] = {
|
||||
"performance_accuracy": total_correct_predictions / total_predictions,
|
||||
"length_accuracy": total_correct_length_predictions / total_predictions,
|
||||
"bucket_accuracy": total_correct_bucket_predictions / total_predictions
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
def get_cost_per_token(model_name: str) -> tuple[float, float]:
|
||||
"""Fetch the latest *per‑token* input/output pricing from LiteLLM.
|
||||
|
||||
This pulls the live `model_cost` map which LiteLLM refreshes from
|
||||
`api.litellm.ai`, so you always have current pricing information.
|
||||
If the model is unknown, graceful fall‑back to zero cost.
|
||||
"""
|
||||
cost_map = litellm.model_cost # Live dict {model: {input_cost_per_token, output_cost_per_token, ...}}
|
||||
info = cost_map.get(model_name, {})
|
||||
return info.get("input_cost_per_token", 0.0), info.get("output_cost_per_token", 0.0)
|
||||
|
||||
def get_token_lengths(queries: List[List[Dict[str, Any]]]):
|
||||
"""
|
||||
Get the token lengths of a list of queries.
|
||||
"""
|
||||
return [token_counter(model="gpt-4o-mini", messages=query) for query in queries]
|
||||
|
||||
def messages_to_query(messages: List[Dict[str, str]],
|
||||
strategy: str = "last_user") -> str:
|
||||
"""
|
||||
Convert OpenAI ChatCompletion-style messages into a single query string.
|
||||
strategy:
|
||||
- "last_user": last user message only
|
||||
- "concat_users": concat all user messages
|
||||
- "concat_all": concat role-labelled full history
|
||||
- "summarize": use GPT to summarize into a short query
|
||||
"""
|
||||
if strategy == "last_user":
|
||||
for msg in reversed(messages):
|
||||
if msg["role"] == "user":
|
||||
return msg["content"].strip()
|
||||
return "" # fallback
|
||||
|
||||
if strategy == "concat_users":
|
||||
return "\n\n".join(m["content"].strip()
|
||||
for m in messages if m["role"] == "user")
|
||||
|
||||
if strategy == "concat_all":
|
||||
return "\n\n".join(f'{m["role"].upper()}: {m["content"].strip()}'
|
||||
for m in messages)
|
||||
|
||||
if strategy == "summarize":
|
||||
full_text = "\n\n".join(f'{m["role"]}: {m["content"]}'
|
||||
for m in messages)
|
||||
rsp = openai.chat.completions.create(
|
||||
model="gpt-4o-mini",
|
||||
messages=[
|
||||
{"role": "system",
|
||||
"content": ("You are a helpful assistant that rewrites a "
|
||||
"dialogue into a concise search query.")},
|
||||
{"role": "user", "content": full_text}
|
||||
],
|
||||
max_tokens=32,
|
||||
temperature=0.2
|
||||
)
|
||||
return rsp.choices[0].message.content.strip()
|
||||
|
||||
raise ValueError("unknown strategy")
|
||||
|
||||
class SmartRouting:
|
||||
"""
|
||||
The SmartRouting class implements a cost-performance optimized selection strategy for LLM requests.
|
||||
It uses historical performance data to predict which models will perform best for a given query
|
||||
while minimizing cost.
|
||||
"""Cost‑/performance‑aware LLM router.
|
||||
|
||||
This strategy ensures that models are selected based on their predicted performance and cost,
|
||||
optimizing for both quality and efficiency.
|
||||
|
||||
Args:
|
||||
llm_configs (List[Dict[str, Any]]): A list of LLM configurations, where each dictionary contains
|
||||
model information such as name, backend, cost parameters, etc.
|
||||
performance_requirement (float): The minimum performance score required (default: 0.7)
|
||||
n_similar (int): Number of similar queries to retrieve for prediction (default: 16)
|
||||
Two **key upgrades** compared with the original version:
|
||||
1. **Bootstrap local ChromaDB** automatically on first run by pulling a
|
||||
prepared JSONL corpus from Google Drive (uses `gdown`).
|
||||
2. **Live pricing**: no more hard‑coded token prices – we call LiteLLM to
|
||||
fetch up‑to‑date `input_cost_per_token` / `output_cost_per_token` for
|
||||
every model the moment we need them.
|
||||
"""
|
||||
def __init__(self, llm_configs: List[Dict[str, Any]], performance_requirement: float=0.7, n_similar: int=16):
|
||||
self.num_buckets = 10
|
||||
self.max_output_limit = 1024
|
||||
self.n_similar = n_similar
|
||||
self.bucket_size = self.max_output_limit / self.num_buckets
|
||||
self.performance_requirement = performance_requirement
|
||||
|
||||
print(f"Performance requirement: {self.performance_requirement}")
|
||||
|
||||
self.llm_configs = llm_configs
|
||||
self.store = self.QueryStore()
|
||||
self.lock = Lock()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# Construction helpers
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
class QueryStore:
|
||||
"""
|
||||
Internal class for storing and retrieving query embeddings and related model performance data.
|
||||
Uses ChromaDB for vector similarity search.
|
||||
"""
|
||||
def __init__(self,
|
||||
model_name: str = "BAAI/bge-small-en-v1.5",
|
||||
persist_directory: str = "llm_router"):
|
||||
|
||||
file_path = os.path.join(os.path.dirname(__file__), persist_directory)
|
||||
self.client = chromadb.PersistentClient(path=file_path)
|
||||
|
||||
self.embedding_function = embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=model_name
|
||||
)
|
||||
"""Simple wrapper around ChromaDB persistent collections."""
|
||||
|
||||
def __init__(self,
|
||||
model_name: str = "all-MiniLM-L6-v2",
|
||||
persist_directory: str = "llm_router",
|
||||
bootstrap_url: str | None = None):
|
||||
self._persist_root = os.path.join(os.path.dirname(__file__), persist_directory)
|
||||
os.makedirs(self._persist_root, exist_ok=True)
|
||||
|
||||
self.client = chromadb.PersistentClient(path=self._persist_root)
|
||||
self.embedding_function = embedding_functions.DefaultEmbeddingFunction()
|
||||
|
||||
# Always create/get collections up‑front so we can inspect counts.
|
||||
# self.train_collection = self._get_or_create_collection("train_queries")
|
||||
# self.val_collection = self._get_or_create_collection("val_queries")
|
||||
# self.test_collection = self._get_or_create_collection("test_queries")
|
||||
self.collection = self._get_or_create_collection("historical_queries")
|
||||
|
||||
# If DB is empty and we have a bootstrap URL – populate it.
|
||||
if bootstrap_url and self.collection.count() == 0:
|
||||
self._bootstrap_from_drive(bootstrap_url)
|
||||
|
||||
# .................................................................
|
||||
# Chroma helpers
|
||||
# .................................................................
|
||||
|
||||
def _get_or_create_collection(self, name: str):
|
||||
try:
|
||||
self.train_collection = self.client.get_collection(
|
||||
name="train_queries",
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
except:
|
||||
self.train_collection = self.client.create_collection(
|
||||
name="train_queries",
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
|
||||
try:
|
||||
self.val_collection = self.client.get_collection(
|
||||
name="val_queries",
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
except:
|
||||
self.val_collection = self.client.create_collection(
|
||||
name="val_queries",
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
|
||||
try:
|
||||
self.test_collection = self.client.get_collection(
|
||||
name="test_queries",
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
except:
|
||||
self.test_collection = self.client.create_collection(
|
||||
name="test_queries",
|
||||
embedding_function=self.embedding_function
|
||||
)
|
||||
|
||||
def add_data(self, data: List[Dict], split: str = "train"):
|
||||
"""
|
||||
Add query data to the appropriate collection.
|
||||
|
||||
Args:
|
||||
data (List[Dict]): List of query data items
|
||||
split (str): Data split ("train", "val", or "test")
|
||||
"""
|
||||
collection = getattr(self, f"{split}_collection")
|
||||
queries = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
correct_count = 0
|
||||
total_count = 0
|
||||
|
||||
for idx, item in enumerate(tqdm(data, desc=f"Adding {split} data")):
|
||||
return self.client.get_collection(name=name, embedding_function=self.embedding_function)
|
||||
except Exception:
|
||||
return self.client.create_collection(name=name, embedding_function=self.embedding_function)
|
||||
|
||||
# .................................................................
|
||||
# Bootstrap logic – download + ingest
|
||||
# .................................................................
|
||||
|
||||
def _bootstrap_from_drive(self, url_or_id: str):
|
||||
print("\n[SmartRouting] Bootstrapping ChromaDB from Google Drive…\n")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
# NB: gdown accepts both share links and raw IDs.
|
||||
local_path = os.path.join(tmp, "bootstrap.json")
|
||||
|
||||
gdown.download(url_or_id, local_path, quiet=False, fuzzy=True)
|
||||
|
||||
# Expect JSONL with {"query": ..., "split": "train"|"val"|"test", ...}
|
||||
# split_map: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
with open(local_path, "r") as f:
|
||||
data = json.load(f)
|
||||
|
||||
self.add_data(data)
|
||||
|
||||
print("[SmartRouting] Bootstrap complete – collections populated.\n")
|
||||
|
||||
# .................................................................
|
||||
# Public data API
|
||||
# .................................................................
|
||||
|
||||
def add_data(self, data: List[Dict[str, Any]]):
|
||||
collection = self.collection
|
||||
queries, metadatas, ids = [], [], []
|
||||
correct_count = total_count = 0
|
||||
|
||||
for idx, item in enumerate(tqdm(data, desc=f"Ingesting historical queries")):
|
||||
query = item["query"]
|
||||
|
||||
metadata = {
|
||||
model_metadatas = item["outputs"]
|
||||
for model_metadata in model_metadatas:
|
||||
model_metadata.pop("prediction")
|
||||
meta = {
|
||||
"input_token_length": item["input_token_length"],
|
||||
"models": []
|
||||
"models": json.dumps(model_metadatas), # store raw list
|
||||
}
|
||||
|
||||
for output in item["outputs"]:
|
||||
model_info = {
|
||||
"model_name": output["model_name"],
|
||||
"correctness": output["correctness"],
|
||||
"output_token_length": output["output_token_length"]
|
||||
}
|
||||
total_count += 1
|
||||
if output["correctness"]:
|
||||
correct_count += 1
|
||||
total_count += 1
|
||||
metadata["models"].append(model_info)
|
||||
|
||||
metadata["models"] = json.dumps(metadata["models"])
|
||||
|
||||
queries.append(query)
|
||||
metadatas.append(metadata)
|
||||
ids.append("id"+str(idx))
|
||||
|
||||
print(f"Correctness: {correct_count} / {total_count}")
|
||||
|
||||
collection.add(
|
||||
documents=queries,
|
||||
metadatas=metadatas,
|
||||
ids = ids
|
||||
)
|
||||
|
||||
def query_similar(self, query: str | List[str], split: str = "train", n_results: int = 16):
|
||||
"""
|
||||
Find similar queries in the database.
|
||||
|
||||
Args:
|
||||
query (str|List[str]): Query or list of queries to find similar items for
|
||||
split (str): Data split to search in
|
||||
n_results (int): Number of similar results to return
|
||||
|
||||
Returns:
|
||||
Dict: Results from ChromaDB query
|
||||
"""
|
||||
collection = getattr(self, f"{split}_collection")
|
||||
|
||||
results = collection.query(
|
||||
query_texts=query if isinstance(query, List) else [query],
|
||||
n_results=n_results
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def predict(self, query: str | List[str], model_configs: List[Dict], n_similar: int = 16):
|
||||
"""
|
||||
Predict performance and output length for models on the given query.
|
||||
|
||||
Args:
|
||||
query (str|List[str]): Query or list of queries
|
||||
model_configs (List[Dict]): List of model configurations
|
||||
n_similar (int): Number of similar queries to use for prediction
|
||||
|
||||
Returns:
|
||||
Tuple[np.array, np.array]: Performance scores and length scores
|
||||
"""
|
||||
# Get similar results from training data
|
||||
similar_results = self.query_similar(query, "train", n_results=n_similar)
|
||||
|
||||
total_performance_scores = []
|
||||
total_length_scores = []
|
||||
# Aggregate stats from similar results
|
||||
for i in range(len(similar_results["metadatas"])):
|
||||
model_stats = defaultdict(lambda: {
|
||||
"total_length": 0,
|
||||
"count": 0,
|
||||
"correct_count": 0
|
||||
})
|
||||
metadatas = similar_results["metadatas"][i]
|
||||
for metadata in metadatas:
|
||||
for model_info in json.loads(metadata["models"]):
|
||||
model_name = model_info["model_name"]
|
||||
stats = model_stats[model_name]
|
||||
stats["total_length"] += model_info["output_token_length"]
|
||||
stats["count"] += 1
|
||||
stats["correct_count"] += int(model_info["correctness"])
|
||||
|
||||
# Calculate performance and length scores for each model
|
||||
performance_scores = []
|
||||
length_scores = []
|
||||
|
||||
for model in model_configs:
|
||||
model_name = model["name"]
|
||||
if model_name in model_stats and model_stats[model_name]["count"] > 0:
|
||||
stats = model_stats[model_name]
|
||||
# Calculate performance score as accuracy
|
||||
perf_score = stats["correct_count"] / stats["count"]
|
||||
# Calculate average length
|
||||
avg_length = stats["total_length"] / stats["count"]
|
||||
|
||||
performance_scores.append(perf_score)
|
||||
length_scores.append(avg_length)
|
||||
metadatas.append(meta)
|
||||
ids.append(f"{idx}")
|
||||
|
||||
collection.add(documents=queries, metadatas=metadatas, ids=ids)
|
||||
print(f"[SmartRouting]: {total_count} historical queries ingested.")
|
||||
|
||||
# ..................................................................
|
||||
def query_similar(self, query: str | List[str], n_results: int = 16):
|
||||
collection = self.collection
|
||||
return collection.query(query_texts=query if isinstance(query, list) else [query], n_results=n_results)
|
||||
|
||||
# ..................................................................
|
||||
def predict(self, query: str | List[str], model_configs: List[Dict[str, Any]], n_similar: int = 16):
|
||||
similar = self.query_similar(query, n_results=n_similar)
|
||||
perf_mat, len_mat = [], []
|
||||
for meta_group in similar["metadatas"]:
|
||||
model_stats: dict[str, dict[str, float]] = defaultdict(lambda: {"total_len": 0, "cnt": 0, "correct": 0})
|
||||
for meta in meta_group:
|
||||
for m in json.loads(meta["models"]):
|
||||
s = model_stats[m["model_name"]]
|
||||
s["total_len"] += m["output_token_length"]
|
||||
s["cnt"] += 1
|
||||
s["correct"] += int(m["correctness"])
|
||||
perf_row, len_row = [], []
|
||||
for cfg in model_configs:
|
||||
stats = model_stats.get(cfg["name"], None)
|
||||
if stats and stats["cnt"]:
|
||||
perf_row.append(stats["correct"] / stats["cnt"])
|
||||
len_row.append(stats["total_len"] / stats["cnt"])
|
||||
else:
|
||||
# If no data for model, use default scores
|
||||
performance_scores.append(0.0)
|
||||
length_scores.append(0.0)
|
||||
|
||||
total_performance_scores.append(performance_scores)
|
||||
total_length_scores.append(length_scores)
|
||||
|
||||
return np.array(total_performance_scores), np.array(total_length_scores)
|
||||
|
||||
def optimize_model_selection_local(self, model_configs, perf_scores, cost_scores):
|
||||
"""
|
||||
Optimize model selection for a single query based on performance and cost.
|
||||
perf_row.append(0.0)
|
||||
len_row.append(0.0)
|
||||
perf_mat.append(perf_row)
|
||||
len_mat.append(len_row)
|
||||
return np.array(perf_mat), np.array(len_mat)
|
||||
|
||||
# ---------------------------------------------------------------------
|
||||
# SmartRouting main methods
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
def __init__(self,
|
||||
llm_configs: List[Dict[str, Any]],
|
||||
bootstrap_url: str ,
|
||||
performance_requirement: float = 0.7,
|
||||
n_similar: int = 16,
|
||||
):
|
||||
self.llm_configs = llm_configs
|
||||
self.available_models = [llm.name for llm in llm_configs]
|
||||
self.bootstrap_url = bootstrap_url
|
||||
self.performance_requirement = performance_requirement
|
||||
self.n_similar = n_similar
|
||||
self.lock = Lock()
|
||||
self.max_output_limit = 1024
|
||||
self.num_buckets = 10
|
||||
self.bucket_size = self.max_output_limit / self.num_buckets
|
||||
|
||||
# Initialise query store – will self‑populate if empty
|
||||
self.store = self.QueryStore(bootstrap_url=bootstrap_url)
|
||||
|
||||
print(f"[SmartRouting] Ready – performance threshold: {self.performance_requirement}\n")
|
||||
|
||||
# .....................................................................
|
||||
# Local (per‑query) optimisation helper
|
||||
# .....................................................................
|
||||
|
||||
def _select_model_single(self, model_cfgs: List[Dict[str, Any]], perf: np.ndarray, cost: np.ndarray) -> int | None:
|
||||
qualified = [i for i, p in enumerate(perf) if p >= self.performance_requirement]
|
||||
if qualified:
|
||||
# Pick cheapest among qualified
|
||||
return min(qualified, key=lambda i: cost[i])
|
||||
# Else, fallback – best performance overall
|
||||
return int(np.argmax(perf)) if len(perf) else 0
|
||||
|
||||
# .....................................................................
|
||||
# Public API – batch selection
|
||||
# .....................................................................
|
||||
|
||||
def get_model_idxs(self, selected_llm_lists: List[List[Dict[str, Any]]], queries: List[str]):
|
||||
if len(selected_llm_lists) != len(queries):
|
||||
raise ValueError("selected_llm_lists must have same length as queries")
|
||||
|
||||
input_lens = get_token_lengths(queries)
|
||||
chosen_indices: list[int] = []
|
||||
|
||||
Args:
|
||||
model_configs (List[Dict]): List of model configurations
|
||||
perf_scores (np.array): Performance scores for each model
|
||||
cost_scores (np.array): Cost scores for each model
|
||||
|
||||
Returns:
|
||||
int: Index of the selected model
|
||||
"""
|
||||
n_models = len(model_configs)
|
||||
|
||||
# Get all available models
|
||||
available_models = list(range(n_models))
|
||||
|
||||
if not available_models:
|
||||
return None
|
||||
|
||||
# Find models that meet performance requirement
|
||||
qualified_models = []
|
||||
for i in available_models:
|
||||
if perf_scores[i] >= self.performance_requirement:
|
||||
qualified_models.append(i)
|
||||
|
||||
if qualified_models:
|
||||
# If there are models meeting performance requirement,
|
||||
# select the one with lowest cost
|
||||
min_cost = float('inf')
|
||||
selected_model = None
|
||||
for i in qualified_models:
|
||||
if cost_scores[i] < min_cost:
|
||||
min_cost = cost_scores[i]
|
||||
selected_model = i
|
||||
return selected_model
|
||||
else:
|
||||
# If no model meets performance requirement,
|
||||
# select available model with highest performance
|
||||
max_perf = float('-inf')
|
||||
selected_model = None
|
||||
for i in available_models:
|
||||
if perf_scores[i] > max_perf:
|
||||
max_perf = perf_scores[i]
|
||||
selected_model = i
|
||||
return selected_model
|
||||
|
||||
def get_model_idxs(self, selected_llms: List[Dict[str, Any]], queries: List[str]=None, input_token_lengths: List[int]=None):
|
||||
"""
|
||||
Selects model indices from the available LLM configurations based on predicted performance and cost.
|
||||
|
||||
Args:
|
||||
selected_llms (List[Dict]): A list of selected LLM configurations from which models will be chosen.
|
||||
n_queries (int): The number of queries to process. Defaults to 1.
|
||||
queries (List[str], optional): List of query strings. If provided, will be used for model selection.
|
||||
input_token_lengths (List[int], optional): List of input token lengths. Required if queries is provided.
|
||||
|
||||
Returns:
|
||||
List[int]: A list of indices corresponding to the selected models in `self.llm_configs`.
|
||||
"""
|
||||
model_idxs = []
|
||||
|
||||
# Ensure we have matching number of queries and token lengths
|
||||
if len(queries) != len(input_token_lengths):
|
||||
raise ValueError("Number of queries must match number of input token lengths")
|
||||
|
||||
# Process each query
|
||||
for i in range(len(queries)):
|
||||
query = queries[i]
|
||||
input_token_length = input_token_lengths[i]
|
||||
|
||||
# Get performance and length predictions
|
||||
perf_scores, length_scores = self.store.predict(query, selected_llms, n_similar=self.n_similar)
|
||||
perf_scores = perf_scores[0] # First query's scores
|
||||
length_scores = length_scores[0] # First query's length predictions
|
||||
|
||||
# Calculate cost scores
|
||||
converted_queries = [messages_to_query(query) for query in queries]
|
||||
|
||||
for q, q_len, candidate_cfgs in zip(converted_queries, input_lens, selected_llm_lists):
|
||||
perf, out_len = self.store.predict(q, candidate_cfgs, n_similar=self.n_similar)
|
||||
perf, out_len = perf[0], out_len[0] # unpack single query
|
||||
|
||||
# Dynamic price lookup via LiteLLM
|
||||
cost_scores = []
|
||||
for j in range(len(selected_llms)):
|
||||
pred_output_length = length_scores[j]
|
||||
input_cost = input_token_length * selected_llms[j].get("cost_per_input_token", 0)
|
||||
output_cost = pred_output_length * selected_llms[j].get("cost_per_output_token", 0)
|
||||
weighted_score = input_cost + output_cost
|
||||
cost_scores.append(weighted_score)
|
||||
|
||||
for cfg, pred_len in zip(candidate_cfgs, out_len):
|
||||
in_cost_pt, out_cost_pt = get_cost_per_token(cfg["name"])
|
||||
cost_scores.append(q_len * in_cost_pt + pred_len * out_cost_pt)
|
||||
cost_scores = np.array(cost_scores)
|
||||
|
||||
sel_local_idx = self._select_model_single(candidate_cfgs, perf, cost_scores)
|
||||
if sel_local_idx is None:
|
||||
chosen_indices.append(0) # safe fallback
|
||||
continue
|
||||
|
||||
# Map back to global llm_configs index
|
||||
sel_name = candidate_cfgs[sel_local_idx]["name"]
|
||||
|
||||
# Select optimal model
|
||||
selected_idx = self.optimize_model_selection_local(
|
||||
selected_llms,
|
||||
perf_scores,
|
||||
cost_scores
|
||||
)
|
||||
|
||||
# Find the index in the original llm_configs
|
||||
for idx, config in enumerate(self.llm_configs):
|
||||
if config["name"] == selected_llms[selected_idx]["name"]:
|
||||
model_idxs.append(idx)
|
||||
break
|
||||
else:
|
||||
# If not found, use the first model as fallback
|
||||
model_idxs.append(0)
|
||||
|
||||
return model_idxs
|
||||
|
||||
def optimize_model_selection_global(self, perf_scores, cost_scores):
|
||||
"""
|
||||
Globally optimize model selection for multiple queries using linear programming.
|
||||
|
||||
Args:
|
||||
perf_scores (np.array): Performance scores matrix [queries × models]
|
||||
cost_scores (np.array): Cost scores matrix [queries × models]
|
||||
|
||||
Returns:
|
||||
np.array: Array of selected model indices for each query
|
||||
"""
|
||||
n_models = len(self.llm_configs)
|
||||
n_queries = len(perf_scores)
|
||||
|
||||
sel_idx = self.available_models.index(sel_name)
|
||||
chosen_indices.append(sel_idx)
|
||||
|
||||
return chosen_indices
|
||||
|
||||
# .....................................................................
|
||||
# Global optimisation (unchanged except for cost lookup)
|
||||
# .....................................................................
|
||||
|
||||
def optimize_model_selection_global(self, perf_scores: np.ndarray, cost_scores: np.ndarray):
|
||||
n_queries, n_models = perf_scores.shape
|
||||
prob = LpProblem("LLM_Scheduling", LpMinimize)
|
||||
|
||||
# Decision variables
|
||||
x = LpVariable.dicts("assign",
|
||||
((i, j) for i in range(n_queries)
|
||||
for j in range(n_models)),
|
||||
cat='Binary')
|
||||
|
||||
# Objective function: minimize total cost
|
||||
prob += lpSum(x[i,j] * cost_scores[i,j]
|
||||
for i in range(n_queries)
|
||||
for j in range(n_models))
|
||||
|
||||
# Quality constraint: ensure overall performance meets requirement
|
||||
prob += lpSum(x[i,j] * perf_scores[i,j]
|
||||
for i in range(n_queries)
|
||||
for j in range(n_models)) >= self.performance_requirement * n_queries
|
||||
|
||||
# Assignment constraints: each query must be assigned to exactly one model
|
||||
x = LpVariable.dicts("assign", ((i, j) for i in range(n_queries) for j in range(n_models)), cat="Binary")
|
||||
prob += lpSum(x[i, j] * cost_scores[i, j] for i in range(n_queries) for j in range(n_models))
|
||||
prob += lpSum(x[i, j] * perf_scores[i, j] for i in range(n_queries) for j in range(n_models)) >= self.performance_requirement * n_queries
|
||||
for i in range(n_queries):
|
||||
prob += lpSum(x[i,j] for j in range(n_models)) == 1
|
||||
|
||||
# Solve
|
||||
prob += lpSum(x[i, j] for j in range(n_models)) == 1
|
||||
prob.solve(PULP_CBC_CMD(msg=False))
|
||||
|
||||
# Extract solution
|
||||
solution = np.zeros((n_queries, n_models))
|
||||
sol = np.zeros((n_queries, n_models))
|
||||
for i in range(n_queries):
|
||||
for j in range(n_models):
|
||||
solution[i,j] = value(x[i,j])
|
||||
|
||||
solution = np.argmax(solution, axis=1)
|
||||
return solution
|
||||
sol[i, j] = value(x[i, j])
|
||||
return np.argmax(sol, axis=1)
|
||||
|
||||
@@ -3,6 +3,8 @@ import re
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
def merge_messages_with_tools(messages: list, tools: list) -> list:
|
||||
"""
|
||||
Integrate tool information into the messages for open-sourced LLMs which don't support tool calling.
|
||||
@@ -338,4 +340,16 @@ def pre_process_tools(tools):
|
||||
if "/" in tool_name:
|
||||
tool_name = "__".join(tool_name.split("/"))
|
||||
tool["function"]["name"] = tool_name
|
||||
return tools
|
||||
return tools
|
||||
|
||||
def check_availability_for_selected_llm_lists(available_llm_names: List[str], selected_llm_lists: List[List[Dict[str, Any]]]):
|
||||
selected_llm_lists_availability = []
|
||||
for selected_llm_list in selected_llm_lists:
|
||||
all_available = True
|
||||
|
||||
for llm in selected_llm_list:
|
||||
if llm["name"] not in available_llm_names:
|
||||
all_available = False
|
||||
break
|
||||
selected_llm_lists_availability.append(all_available)
|
||||
return selected_llm_lists_availability
|
||||
|
||||
@@ -24,7 +24,7 @@ from queue import Empty
|
||||
import traceback
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, Any, Dict
|
||||
from typing import Optional, Any, Dict, List
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
@@ -34,8 +34,9 @@ class FIFOScheduler(BaseScheduler):
|
||||
"""
|
||||
A FIFO (First-In-First-Out) task scheduler implementation.
|
||||
|
||||
This scheduler processes tasks in the order they arrive, with a 1-second timeout
|
||||
for each task. It handles different types of system calls: LLM, Memory, Storage, and Tool.
|
||||
This scheduler processes tasks in the order they arrive.
|
||||
LLM tasks are batched based on a time interval. Other tasks (Memory, Storage, Tool)
|
||||
are processed individually as they arrive.
|
||||
|
||||
Example:
|
||||
```python
|
||||
@@ -48,7 +49,8 @@ class FIFOScheduler(BaseScheduler):
|
||||
get_llm_syscall=llm_queue.get,
|
||||
get_memory_syscall=memory_queue.get,
|
||||
get_storage_syscall=storage_queue.get,
|
||||
get_tool_syscall=tool_queue.get
|
||||
get_tool_syscall=tool_queue.get,
|
||||
batch_interval=0.1 # Process LLM requests every 100ms
|
||||
)
|
||||
scheduler.start()
|
||||
```
|
||||
@@ -61,14 +63,11 @@ class FIFOScheduler(BaseScheduler):
|
||||
storage_manager: StorageManager,
|
||||
tool_manager: ToolManager,
|
||||
log_mode: str,
|
||||
# llm_request_queue: LLMRequestQueue,
|
||||
# memory_request_queue: MemoryRequestQueue,
|
||||
# storage_request_queue: StorageRequestQueue,
|
||||
# tool_request_queue: ToolRequestQueue,
|
||||
get_llm_syscall: LLMRequestQueueGetMessage,
|
||||
get_memory_syscall: MemoryRequestQueueGetMessage,
|
||||
get_storage_syscall: StorageRequestQueueGetMessage,
|
||||
get_tool_syscall: ToolRequestQueueGetMessage,
|
||||
batch_interval: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize the FIFO Scheduler.
|
||||
@@ -83,6 +82,7 @@ class FIFOScheduler(BaseScheduler):
|
||||
get_memory_syscall: Function to get Memory syscalls
|
||||
get_storage_syscall: Function to get Storage syscalls
|
||||
get_tool_syscall: Function to get Tool syscalls
|
||||
batch_interval: Time interval in seconds to batch LLM requests. Defaults to 0.1.
|
||||
"""
|
||||
super().__init__(
|
||||
llm,
|
||||
@@ -95,81 +95,89 @@ class FIFOScheduler(BaseScheduler):
|
||||
get_storage_syscall,
|
||||
get_tool_syscall,
|
||||
)
|
||||
self.batch_interval = batch_interval
|
||||
|
||||
def _execute_syscall(
|
||||
self,
|
||||
syscall: Any,
|
||||
def _execute_batch_syscalls(
|
||||
self,
|
||||
batch: List[Any],
|
||||
executor: Any,
|
||||
syscall_type: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> None:
|
||||
"""
|
||||
Execute a system call with proper status tracking and error handling.
|
||||
|
||||
Execute a batch of system calls with proper status tracking and error handling.
|
||||
|
||||
Args:
|
||||
syscall: The system call to execute
|
||||
executor: Function to execute the syscall
|
||||
syscall_type: Type of the syscall for logging
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Response from the syscall execution
|
||||
|
||||
Example:
|
||||
```python
|
||||
response = scheduler._execute_syscall(
|
||||
llm_syscall,
|
||||
self.llm.execute_llm_syscall,
|
||||
"LLM"
|
||||
)
|
||||
```
|
||||
batch: The list of system calls to execute
|
||||
executor: Function to execute the batch of syscalls
|
||||
syscall_type: Type of the syscalls for logging
|
||||
"""
|
||||
if not batch:
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
for syscall in batch:
|
||||
try:
|
||||
syscall.set_status("executing")
|
||||
|
||||
logger.info(f"{syscall.agent_name} preparing batched {syscall_type} syscall.")
|
||||
syscall.set_start_time(start_time)
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing syscall {getattr(syscall, 'agent_name', 'unknown')} for batch execution: {str(e)}")
|
||||
continue
|
||||
|
||||
if not batch:
|
||||
logger.warning(f"Empty batch after preparation for {syscall_type}, skipping execution.")
|
||||
return
|
||||
|
||||
# self.logger.log(
|
||||
# f"Executing batch of {len(batch)} {syscall_type} syscalls.\n",
|
||||
# "executing_batch"
|
||||
# )
|
||||
logger.info(f"Executing batch of {len(batch)} {syscall_type} syscalls.")
|
||||
|
||||
try:
|
||||
syscall.set_status("executing")
|
||||
self.logger.log(
|
||||
f"{syscall.agent_name} is executing {syscall_type} syscall.\n",
|
||||
"executing"
|
||||
)
|
||||
syscall.set_start_time(time.time())
|
||||
responses = executor(batch)
|
||||
|
||||
response = executor(syscall)
|
||||
syscall.set_response(response)
|
||||
for i, syscall in enumerate(batch):
|
||||
logger.info(f"Completed batched {syscall_type} syscall for {syscall.agent_name}. "
|
||||
f"Thread ID: {syscall.get_pid()}\n")
|
||||
|
||||
syscall.event.set()
|
||||
syscall.set_status("done")
|
||||
syscall.set_end_time(time.time())
|
||||
|
||||
self.logger.log(
|
||||
f"Completed {syscall_type} syscall for {syscall.agent_name}. "
|
||||
f"Thread ID: {syscall.get_pid()}\n",
|
||||
"done"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing {syscall_type} syscall: {str(e)}")
|
||||
logger.error(f"Error executing {syscall_type} syscall batch: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
def process_llm_requests(self) -> None:
|
||||
"""
|
||||
Process LLM requests from the queue.
|
||||
Process LLM requests from the queue in batches based on batch_interval.
|
||||
|
||||
Example:
|
||||
```python
|
||||
scheduler.process_llm_requests()
|
||||
# Processes LLM requests like:
|
||||
# {
|
||||
# "messages": [{"role": "user", "content": "Hello"}],
|
||||
# "temperature": 0.7
|
||||
# }
|
||||
# Collects LLM requests for 0.1 seconds (default) then processes them:
|
||||
# Batch = [
|
||||
# {"messages": [{"role": "user", "content": "Hello"}]},
|
||||
# {"messages": [{"role": "user", "content": "World"}]}
|
||||
# ]
|
||||
# Calls self.llm.execute_batch_llm_syscall(Batch)
|
||||
```
|
||||
"""
|
||||
while self.active:
|
||||
try:
|
||||
llm_syscall = self.get_llm_syscall()
|
||||
self._execute_syscall(llm_syscall, self.llm.execute_llm_syscall, "LLM")
|
||||
except Empty:
|
||||
pass
|
||||
time.sleep(self.batch_interval)
|
||||
|
||||
batch = []
|
||||
while True:
|
||||
try:
|
||||
llm_syscall = self.get_llm_syscall()
|
||||
# Add logging here: print(f"Retrieved syscall: {llm_syscall.pid}, Queue size now: {self.llm_queue.qsize()}")
|
||||
batch.append(llm_syscall)
|
||||
except Empty:
|
||||
# This is the expected way to finish collecting the batch
|
||||
# print(f"Queue empty, finishing batch collection with {len(batch)} items.")
|
||||
break
|
||||
|
||||
if batch:
|
||||
self._execute_batch_syscalls(batch, self.llm.execute_llm_syscalls, "LLM")
|
||||
|
||||
def process_memory_requests(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -4,11 +4,9 @@
|
||||
|
||||
from .base import BaseScheduler
|
||||
|
||||
|
||||
# allows for memory to be shared safely between threads
|
||||
from queue import Queue, Empty
|
||||
|
||||
|
||||
from ..context.simple_context import SimpleContextManager
|
||||
|
||||
from aios.hooks.types.llm import LLMRequestQueueGetMessage
|
||||
@@ -27,7 +25,7 @@ from threading import Thread
|
||||
from .base import BaseScheduler
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,73 +57,54 @@ class RRScheduler(BaseScheduler):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.time_slice = time_slice
|
||||
self.context_manager = SimpleContextManager()
|
||||
|
||||
def _execute_syscall(
|
||||
self,
|
||||
syscall: Any,
|
||||
|
||||
def _execute_batch_syscalls(
|
||||
self,
|
||||
batch: List[Any],
|
||||
executor: Any,
|
||||
syscall_type: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> None:
|
||||
"""
|
||||
Execute a system call with time slice enforcement.
|
||||
|
||||
Execute a batch of system calls with proper status tracking and error handling.
|
||||
|
||||
Args:
|
||||
syscall: The system call to execute
|
||||
executor: Function to execute the syscall
|
||||
syscall_type: Type of the syscall for logging
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: Response from the syscall execution
|
||||
|
||||
Example:
|
||||
```python
|
||||
response = scheduler._execute_syscall(
|
||||
llm_syscall,
|
||||
self.llm.execute_llm_syscall,
|
||||
"LLM"
|
||||
)
|
||||
```
|
||||
batch: The list of system calls to execute
|
||||
executor: Function to execute the batch of syscalls
|
||||
syscall_type: Type of the syscalls for logging
|
||||
"""
|
||||
try:
|
||||
syscall.set_time_limit(self.time_slice)
|
||||
syscall.set_status("executing")
|
||||
self.logger.log(
|
||||
f"{syscall.agent_name} is executing {syscall_type} syscall.\n",
|
||||
"executing"
|
||||
)
|
||||
syscall.set_start_time(time.time())
|
||||
if not batch:
|
||||
return
|
||||
|
||||
response = executor(syscall)
|
||||
|
||||
# breakpoint()
|
||||
|
||||
syscall.set_response(response)
|
||||
|
||||
if response.finished:
|
||||
syscall.set_status("done")
|
||||
log_status = "done"
|
||||
else:
|
||||
syscall.set_status("suspending")
|
||||
log_status = "suspending"
|
||||
start_time = time.time()
|
||||
for syscall in batch:
|
||||
try:
|
||||
syscall.set_status("executing")
|
||||
syscall.set_time_limit(self.time_slice)
|
||||
|
||||
# breakpoint()
|
||||
logger.info(f"{syscall.agent_name} preparing batched {syscall_type} syscall.")
|
||||
syscall.set_start_time(start_time)
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing syscall {getattr(syscall, 'agent_name', 'unknown')} for batch execution: {str(e)}")
|
||||
continue
|
||||
|
||||
if not batch:
|
||||
message = f"Empty batch after preparation for {syscall_type}, skipping execution."
|
||||
logger.warning(message)
|
||||
return
|
||||
|
||||
logger.info(f"Executing batch of {len(batch)} {syscall_type} syscalls.")
|
||||
|
||||
try:
|
||||
responses = executor(batch)
|
||||
|
||||
for i, syscall in enumerate(batch):
|
||||
logger.info(f"Completed batched {syscall_type} syscall for {syscall.agent_name}. "
|
||||
f"Thread ID: {syscall.get_pid()}\n")
|
||||
|
||||
syscall.set_end_time(time.time())
|
||||
|
||||
syscall.event.set()
|
||||
|
||||
self.logger.log(
|
||||
f"{syscall_type} syscall for {syscall.agent_name} is {log_status}. "
|
||||
f"Thread ID: {syscall.get_pid()}\n",
|
||||
log_status
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing {syscall_type} syscall: {str(e)}")
|
||||
logger.error(f"Error executing {syscall_type} syscall batch: {str(e)}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
def process_llm_requests(self) -> None:
|
||||
"""
|
||||
@@ -144,7 +123,7 @@ class RRScheduler(BaseScheduler):
|
||||
while self.active:
|
||||
try:
|
||||
llm_syscall = self.get_llm_syscall()
|
||||
self._execute_syscall(llm_syscall, self.llm.execute_llm_syscall, "LLM")
|
||||
self._execute_batch_syscalls(llm_syscall, self.llm.execute_llm_syscalls, "LLM")
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
|
||||
@@ -13,11 +13,7 @@ from aios.hooks.stores._global import (
|
||||
global_llm_req_queue_add_message,
|
||||
global_memory_req_queue_add_message,
|
||||
global_storage_req_queue_add_message,
|
||||
global_tool_req_queue_add_message,
|
||||
# global_llm_req_queue,
|
||||
# global_memory_req_queue,
|
||||
# global_storage_req_queue,
|
||||
# global_tool_req_queue,
|
||||
global_tool_req_queue_add_message
|
||||
)
|
||||
|
||||
import threading
|
||||
@@ -113,6 +109,8 @@ class SyscallExecutor:
|
||||
|
||||
if isinstance(syscall, LLMSyscall):
|
||||
global_llm_req_queue_add_message(syscall)
|
||||
print(f"Syscall {syscall.agent_name} added to LLM queue")
|
||||
|
||||
elif isinstance(syscall, StorageSyscall):
|
||||
global_storage_req_queue_add_message(syscall)
|
||||
elif isinstance(syscall, MemorySyscall):
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
# Installation directory
|
||||
INSTALL_DIR="$HOME/.aios-1"
|
||||
REPO_URL="https://github.com/agiresearch/AIOS"
|
||||
TAG="v0.2.0.beta" # Replace with your specific tag
|
||||
BIN_DIR="$HOME/.local/bin" # User-level bin directory
|
||||
|
||||
echo "Installing AIOS..."
|
||||
@@ -26,7 +25,6 @@ git clone "$REPO_URL" "$INSTALL_DIR/src"
|
||||
|
||||
# Checkout the specific tag
|
||||
cd "$INSTALL_DIR/src"
|
||||
git checkout tags/"$TAG" -b "$TAG-branch"
|
||||
|
||||
# Find Python executable path
|
||||
PYTHON_PATH=$(command -v python3)
|
||||
|
||||
@@ -19,4 +19,5 @@ redis>=4.5.1
|
||||
sentence-transformers
|
||||
nltk
|
||||
scikit-learn
|
||||
pulp
|
||||
pulp
|
||||
gdown
|
||||
@@ -32,6 +32,8 @@ from cerebrum.storage.apis import StorageQuery, StorageResponse
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
import asyncio
|
||||
|
||||
import uvicorn
|
||||
|
||||
load_dotenv()
|
||||
@@ -340,6 +342,22 @@ def restart_kernel():
|
||||
print(f"Stack trace: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
@app.get("/status")
|
||||
async def get_server_status():
|
||||
"""Check if the server is running and core components are initialized."""
|
||||
inactive_components = [
|
||||
component for component, instance in active_components.items() if not instance
|
||||
]
|
||||
|
||||
if not inactive_components:
|
||||
return {"status": "ok", "message": "All core components are active."}
|
||||
else:
|
||||
return {
|
||||
"status": "warning",
|
||||
"message": f"Server is running, but some components are inactive: {', '.join(inactive_components)}",
|
||||
"inactive_components": inactive_components
|
||||
}
|
||||
|
||||
@app.post("/core/refresh")
|
||||
async def refresh_configuration():
|
||||
"""Refresh all component configurations"""
|
||||
@@ -601,7 +619,14 @@ async def handle_query(request: QueryRequest):
|
||||
action_type=request.query_data.action_type,
|
||||
message_return_type=request.query_data.message_return_type,
|
||||
)
|
||||
return execute_request(request.agent_name, query)
|
||||
result_dict = await asyncio.to_thread(
|
||||
execute_request, # The method to call
|
||||
request.agent_name, # First arg to execute_request
|
||||
query # Second arg to execute_request
|
||||
)
|
||||
|
||||
return result_dict
|
||||
|
||||
elif request.query_type == "storage":
|
||||
query = StorageQuery(
|
||||
params=request.query_data.params,
|
||||
|
||||
117
tests/modules/llm/ollama/test_single.py
Normal file
117
tests/modules/llm/ollama/test_single.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import unittest
|
||||
from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output
|
||||
from cerebrum.utils.communication import aios_kernel_url
|
||||
from cerebrum.utils.utils import _parse_json_output
|
||||
|
||||
class TestAgent:
|
||||
"""
|
||||
TestAgent class is responsible for interacting with OpenAI's API using ChatCompletion.
|
||||
It maintains a conversation history to simulate real dialogue behavior.
|
||||
"""
|
||||
def __init__(self, agent_name):
|
||||
self.agent_name = agent_name
|
||||
self.messages = []
|
||||
|
||||
def llm_chat_run(self, task_input):
|
||||
"""Sends the input to the OpenAI API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
tool_response = llm_chat(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url,
|
||||
llms=[{
|
||||
"name": "qwen2.5:7b",
|
||||
"backend": "ollama"
|
||||
}]
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
def llm_chat_with_json_output_run(self, task_input):
|
||||
"""Sends the input to the OpenAI API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "thinking",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"thinking": {
|
||||
"type": "string",
|
||||
},
|
||||
"answer": {
|
||||
"type": "string",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tool_response = llm_chat_with_json_output(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url,
|
||||
llms=[{
|
||||
"name": "qwen2.5:7b",
|
||||
"backend": "ollama"
|
||||
}],
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
|
||||
class TestLLMAPI(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for OpenAI's API using the TestAgent class.
|
||||
Each test checks if the API returns a non-empty string response.
|
||||
Here has various category test cases, including greeting, math question, science question, history question, technology question.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.agent = TestAgent("test_agent")
|
||||
|
||||
def assert_chat_response(self, response):
|
||||
"""Helper method to validate common response conditions."""
|
||||
self.assertIsInstance(response, str)
|
||||
self.assertGreater(len(response), 0)
|
||||
|
||||
def assert_json_output_response(self, response):
|
||||
"""Helper method to validate common response conditions."""
|
||||
parsed_response = _parse_json_output(response)
|
||||
self.assertIsInstance(parsed_response, dict)
|
||||
self.assertIn("thinking", parsed_response)
|
||||
self.assertIn("answer", parsed_response)
|
||||
self.assertIsInstance(parsed_response["thinking"], str)
|
||||
self.assertIsInstance(parsed_response["answer"], str)
|
||||
|
||||
def test_agent_with_greeting(self):
|
||||
chat_response = self.agent.llm_chat_run("Hello, how are you?")
|
||||
self.assert_chat_response(chat_response)
|
||||
# json_output_response = self.agent.llm_chat_with_json_output_run("Hello, how are you? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
# self.assert_json_output_response(json_output_response)
|
||||
|
||||
def test_agent_with_math_question(self):
|
||||
chat_response = self.agent.llm_chat_run("What is 25 times 4?")
|
||||
self.assert_chat_response(chat_response)
|
||||
# json_output_response = self.agent.llm_chat_with_json_output_run("What is 25 times 4? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
# self.assert_json_output_response(json_output_response)
|
||||
|
||||
|
||||
def test_agent_with_history_question(self):
|
||||
chat_response = self.agent.llm_chat_run("Who was the first president of the United States?")
|
||||
self.assert_chat_response(chat_response)
|
||||
# json_output_response = self.agent.llm_chat_with_json_output_run("Who was the first president of the United States? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
# self.assert_json_output_response(json_output_response)
|
||||
|
||||
def test_agent_with_technology_question(self):
|
||||
chat_response = self.agent.llm_chat_run("What is quantum computing?")
|
||||
self.assert_chat_response(chat_response)
|
||||
# json_output_response = self.agent.llm_chat_with_json_output_run("What is quantum computing? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
# self.assert_json_output_response(json_output_response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
161
tests/modules/llm/openai/test_concurrent.py
Normal file
161
tests/modules/llm/openai/test_concurrent.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# test_client.py
|
||||
import unittest
|
||||
import requests
|
||||
import threading
|
||||
import json
|
||||
import time
|
||||
from typing import List, Dict, Any, Tuple
|
||||
|
||||
from cerebrum.llm.apis import llm_chat
|
||||
|
||||
from cerebrum.utils.communication import aios_kernel_url
|
||||
|
||||
# --- Helper function to send a single request ---
|
||||
def send_request(payload: Dict[str, Any]) -> Tuple[Dict[str, Any] | None, float]:
|
||||
"""Sends a POST request to the query endpoint and returns status code, response JSON, and duration."""
|
||||
start_time = time.time()
|
||||
try:
|
||||
response = llm_chat(
|
||||
agent_name=payload["agent_name"],
|
||||
messages=payload["query_data"]["messages"],
|
||||
llms=payload["query_data"]["llms"] if "llms" in payload["query_data"] else None,
|
||||
)
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
return response["response"], duration
|
||||
|
||||
except Exception as e:
|
||||
end_time = time.time()
|
||||
duration = end_time - start_time
|
||||
# General unexpected errors
|
||||
return {"error": f"An unexpected error occurred: {e}"}, duration
|
||||
|
||||
|
||||
class TestConcurrentLLMQueries(unittest.TestCase):
|
||||
|
||||
def _run_concurrent_requests(self, payloads: List[Dict[str, Any]]):
|
||||
results = [None] * len(payloads)
|
||||
threads = []
|
||||
|
||||
def worker(index, payload):
|
||||
response_data, duration = send_request(payload)
|
||||
results[index] = {"data": response_data, "duration": duration}
|
||||
print(f"Thread {index}: Completed in {duration:.2f}s with response: {json.dumps(response_data)}")
|
||||
|
||||
print(f"\n--- Running test: {self._testMethodName} ---")
|
||||
print(f"Sending {len(payloads)} concurrent requests to {aios_kernel_url}...")
|
||||
for i, payload in enumerate(payloads):
|
||||
thread = threading.Thread(target=worker, args=(i, payload))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
for i, thread in enumerate(threads):
|
||||
thread.join()
|
||||
print(f"Thread {i} finished.")
|
||||
|
||||
print("--- All threads completed ---")
|
||||
return results
|
||||
|
||||
def test_both_llms(self):
|
||||
payload1 = {
|
||||
"agent_name": "test_agent_1",
|
||||
"query_type": "llm",
|
||||
"query_data": {
|
||||
"llms": [{"name": "gpt-4o", "backend": "openai"}],
|
||||
"messages": [{"role": "user", "content": "What is the capital of France?"}],
|
||||
"action_type": "chat",
|
||||
"message_return_type": "text",
|
||||
}
|
||||
}
|
||||
payload2 = {
|
||||
"agent_name": "test_agent_2",
|
||||
"query_type": "llm",
|
||||
"query_data": {
|
||||
"llms": [{"name": "gpt-4o-mini", "backend": "openai"}], # Using a different model for variety
|
||||
"messages": [{"role": "user", "content": "What is the capital of the United States?"}],
|
||||
"action_type": "chat",
|
||||
"message_return_type": "text",
|
||||
}
|
||||
}
|
||||
results = self._run_concurrent_requests([payload1, payload2])
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"Result {i} (No LLM): {result}")
|
||||
# Both should succeed using defaults
|
||||
status, response_message, error_message, finished = result["data"]["status_code"], result["data"]["response_message"], result["data"]["error"], result["data"]["finished"]
|
||||
|
||||
self.assertEqual(status, 200, f"Request {i} (No LLM) should succeed, but failed with status {status}")
|
||||
self.assertIsNone(error_message, f"Request {i} (No LLM) returned an unexpected error: {error_message}")
|
||||
self.assertIsInstance(response_message, str, f"Request {i} (No LLM) result is not a string")
|
||||
self.assertTrue(finished, f"Request {i} (No LLM) result is empty") # Check not empty
|
||||
|
||||
def test_one_llm_one_empty(self):
|
||||
payload_llm = {
|
||||
"agent_name": "test_agent_1",
|
||||
"query_type": "llm",
|
||||
"query_data": {
|
||||
"llms": [{"name": "gpt-4o", "backend": "openai"}],
|
||||
"messages": [{"role": "user", "content": "What is the capital of France?"}],
|
||||
"action_type": "chat",
|
||||
"message_return_type": "text",
|
||||
}
|
||||
}
|
||||
payload_no_llm = {
|
||||
"agent_name": "test_agent_2",
|
||||
"query_type": "llm",
|
||||
"query_data": {
|
||||
# 'llms' key is omitted entirely
|
||||
"messages": [{"role": "user", "content": "What is the capital of the United States?"}],
|
||||
"action_type": "chat",
|
||||
"message_return_type": "text",
|
||||
}
|
||||
}
|
||||
results = self._run_concurrent_requests([payload_llm, payload_no_llm])
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"Result {i} (No LLM): {result}")
|
||||
# Both should succeed using defaults
|
||||
status, response_message, error_message, finished = result["data"]["status_code"], result["data"]["response_message"], result["data"]["error"], result["data"]["finished"]
|
||||
|
||||
self.assertEqual(status, 200, f"Request {i} (No LLM) should succeed, but failed with status {status}")
|
||||
self.assertIsNone(error_message, f"Request {i} (No LLM) returned an unexpected error: {error_message}")
|
||||
self.assertIsInstance(response_message, str, f"Request {i} (No LLM) result is not a string")
|
||||
self.assertTrue(finished, f"Request {i} (No LLM) result is empty") # Check not empty
|
||||
|
||||
def test_no_llms(self):
|
||||
"""Case 2: Both payloads have no LLMs defined. Should succeed using defaults."""
|
||||
payload1 = {
|
||||
"agent_name": "test_agent_1",
|
||||
"query_type": "llm",
|
||||
"query_data": {
|
||||
# 'llms' key is omitted
|
||||
"messages": [{"role": "user", "content": "What is the capital of France?"}],
|
||||
"action_type": "chat",
|
||||
"message_return_type": "text",
|
||||
}
|
||||
}
|
||||
payload2 = {
|
||||
"agent_name": "test_agent_2",
|
||||
"query_type": "llm",
|
||||
"query_data": {
|
||||
# 'llms' key is omitted
|
||||
"messages": [{"role": "user", "content": "What is the capital of the United States?"}],
|
||||
"action_type": "chat",
|
||||
"message_return_type": "text",
|
||||
}
|
||||
}
|
||||
results = self._run_concurrent_requests([payload1, payload2])
|
||||
|
||||
for i, result in enumerate(results):
|
||||
print(f"Result {i} (No LLM): {result}")
|
||||
# Both should succeed using defaults
|
||||
status, response_message, error_message, finished = result["data"]["status_code"], result["data"]["response_message"], result["data"]["error"], result["data"]["finished"]
|
||||
|
||||
self.assertEqual(status, 200, f"Request {i} (No LLM) should succeed, but failed with status {status}")
|
||||
self.assertIsNone(error_message, f"Request {i} (No LLM) returned an unexpected error: {error_message}")
|
||||
self.assertIsInstance(response_message, str, f"Request {i} (No LLM) result is not a string")
|
||||
self.assertTrue(finished, f"Request {i} (No LLM) result is empty") # Check not empty
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
117
tests/modules/llm/openai/test_single.py
Normal file
117
tests/modules/llm/openai/test_single.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import unittest
|
||||
from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output
|
||||
from cerebrum.utils.communication import aios_kernel_url
|
||||
from cerebrum.utils.utils import _parse_json_output
|
||||
|
||||
class TestAgent:
|
||||
"""
|
||||
TestAgent class is responsible for interacting with OpenAI's API using ChatCompletion.
|
||||
It maintains a conversation history to simulate real dialogue behavior.
|
||||
"""
|
||||
def __init__(self, agent_name):
|
||||
self.agent_name = agent_name
|
||||
self.messages = []
|
||||
|
||||
def llm_chat_run(self, task_input):
|
||||
"""Sends the input to the OpenAI API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
tool_response = llm_chat(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url,
|
||||
llms=[{
|
||||
"name": "gpt-4o-mini",
|
||||
"backend": "openai"
|
||||
}]
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
def llm_chat_with_json_output_run(self, task_input):
|
||||
"""Sends the input to the OpenAI API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "thinking",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"thinking": {
|
||||
"type": "string",
|
||||
},
|
||||
"answer": {
|
||||
"type": "string",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tool_response = llm_chat_with_json_output(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url,
|
||||
llms=[{
|
||||
"name": "gpt-4o-mini",
|
||||
"backend": "openai"
|
||||
}],
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
|
||||
class TestLLMAPI(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for OpenAI's API using the TestAgent class.
|
||||
Each test checks if the API returns a non-empty string response.
|
||||
Here has various category test cases, including greeting, math question, science question, history question, technology question.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.agent = TestAgent("test_agent")
|
||||
|
||||
def assert_chat_response(self, response):
|
||||
"""Helper method to validate common response conditions."""
|
||||
self.assertIsInstance(response, str)
|
||||
self.assertGreater(len(response), 0)
|
||||
|
||||
def assert_json_output_response(self, response):
|
||||
"""Helper method to validate common response conditions."""
|
||||
parsed_response = _parse_json_output(response)
|
||||
self.assertIsInstance(parsed_response, dict)
|
||||
self.assertIn("thinking", parsed_response)
|
||||
self.assertIn("answer", parsed_response)
|
||||
self.assertIsInstance(parsed_response["thinking"], str)
|
||||
self.assertIsInstance(parsed_response["answer"], str)
|
||||
|
||||
def test_agent_with_greeting(self):
|
||||
chat_response = self.agent.llm_chat_run("Hello, how are you?")
|
||||
self.assert_chat_response(chat_response)
|
||||
json_output_response = self.agent.llm_chat_with_json_output_run("Hello, how are you? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
self.assert_json_output_response(json_output_response)
|
||||
|
||||
def test_agent_with_math_question(self):
|
||||
chat_response = self.agent.llm_chat_run("What is 25 times 4?")
|
||||
self.assert_chat_response(chat_response)
|
||||
json_output_response = self.agent.llm_chat_with_json_output_run("What is 25 times 4? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
self.assert_json_output_response(json_output_response)
|
||||
|
||||
|
||||
def test_agent_with_history_question(self):
|
||||
chat_response = self.agent.llm_chat_run("Who was the first president of the United States?")
|
||||
self.assert_chat_response(chat_response)
|
||||
json_output_response = self.agent.llm_chat_with_json_output_run("Who was the first president of the United States? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
self.assert_json_output_response(json_output_response)
|
||||
|
||||
def test_agent_with_technology_question(self):
|
||||
chat_response = self.agent.llm_chat_run("What is quantum computing?")
|
||||
self.assert_chat_response(chat_response)
|
||||
json_output_response = self.agent.llm_chat_with_json_output_run("What is quantum computing? Output in JSON format of {{'thinking': 'your thinking', 'answer': 'your answer'}}.")
|
||||
self.assert_json_output_response(json_output_response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,62 +0,0 @@
|
||||
import unittest
|
||||
from cerebrum.llm.apis import llm_chat
|
||||
from cerebrum.utils.communication import aios_kernel_url
|
||||
|
||||
class TestAgent:
|
||||
"""
|
||||
TestAgent class is responsible for interacting with the LLM API using llm_chat.
|
||||
It maintains a conversation history to simulate real dialogue behavior.
|
||||
"""
|
||||
def __init__(self, agent_name):
|
||||
self.agent_name = agent_name
|
||||
self.messages = []
|
||||
|
||||
def run(self, task_input):
|
||||
"""Sends the input to the LLM API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
tool_response = llm_chat(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
|
||||
class TestLLMAPI(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for the LLM API using the TestAgent class.
|
||||
Each test checks if the API returns a non-empty string response.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.agent = TestAgent("test_agent")
|
||||
|
||||
def assert_valid_response(self, response):
|
||||
"""Helper method to validate common response conditions."""
|
||||
self.assertIsInstance(response, str)
|
||||
self.assertGreater(len(response), 0)
|
||||
|
||||
def test_agent_with_greeting(self):
|
||||
response = self.agent.run("Hello, how are you?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_math_question(self):
|
||||
response = self.agent.run("What is 25 times 4?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_science_question(self):
|
||||
response = self.agent.run("Explain the theory of relativity.")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_history_question(self):
|
||||
response = self.agent.run("Who was the first president of the United States?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_technology_question(self):
|
||||
response = self.agent.run("What is quantum computing?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,112 +0,0 @@
|
||||
import unittest
|
||||
from cerebrum.llm.apis import llm_chat, llm_chat_with_json_output
|
||||
from cerebrum.utils.communication import aios_kernel_url
|
||||
|
||||
class TestAgent:
|
||||
"""
|
||||
TestAgent class is responsible for interacting with OpenAI's API using ChatCompletion.
|
||||
It maintains a conversation history to simulate real dialogue behavior.
|
||||
"""
|
||||
def __init__(self, agent_name, api_key):
|
||||
self.agent_name = agent_name
|
||||
self.api_key = api_key
|
||||
self.messages = []
|
||||
|
||||
def llm_chat_run(self, task_input):
|
||||
"""Sends the input to the OpenAI API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
tool_response = llm_chat(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url,
|
||||
llms=[{
|
||||
"name": "gpt-4o-mini",
|
||||
"backend": "openai"
|
||||
}]
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
def llm_chat_with_json_output_run(self, task_input):
|
||||
"""Sends the input to the OpenAI API and returns the response."""
|
||||
self.messages.append({"role": "user", "content": task_input})
|
||||
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "keywords",
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keywords": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tool_response = llm_chat_with_json_output(
|
||||
agent_name=self.agent_name,
|
||||
messages=self.messages,
|
||||
base_url=aios_kernel_url,
|
||||
llms=[{
|
||||
"name": "gpt-4o-mini",
|
||||
"backend": "openai"
|
||||
}],
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
return tool_response["response"].get("response_message", "")
|
||||
|
||||
|
||||
class TestLLMAPI(unittest.TestCase):
|
||||
"""
|
||||
Unit tests for OpenAI's API using the TestAgent class.
|
||||
Each test checks if the API returns a non-empty string response.
|
||||
Here has various category test cases, including greeting, math question, science question, history question, technology question.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.api_key = "your-openai-api-key" # Replace with your actual OpenAI API key
|
||||
self.agent = TestAgent("test_agent", self.api_key)
|
||||
|
||||
def assert_valid_response(self, response):
|
||||
"""Helper method to validate common response conditions."""
|
||||
self.assertIsInstance(response, str)
|
||||
self.assertGreater(len(response), 0)
|
||||
|
||||
def test_agent_with_greeting(self):
|
||||
response = self.agent.llm_chat_run("Hello, how are you?")
|
||||
self.assert_valid_response(response)
|
||||
response = self.agent.llm_chat_with_json_output_run("Hello, how are you?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_math_question(self):
|
||||
response = self.agent.llm_chat_run("What is 25 times 4?")
|
||||
self.assert_valid_response(response)
|
||||
response = self.agent.llm_chat_with_json_output_run("What is 25 times 4?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_science_question(self):
|
||||
response = self.agent.llm_chat_run("Explain the theory of relativity.")
|
||||
self.assert_valid_response(response)
|
||||
response = self.agent.llm_chat_with_json_output_run("Explain the theory of relativity.")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_history_question(self):
|
||||
response = self.agent.llm_chat_run("Who was the first president of the United States?")
|
||||
self.assert_valid_response(response)
|
||||
response = self.agent.llm_chat_with_json_output_run("Who was the first president of the United States?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
def test_agent_with_technology_question(self):
|
||||
response = self.agent.llm_chat_run("What is quantum computing?")
|
||||
self.assert_valid_response(response)
|
||||
response = self.agent.llm_chat_with_json_output_run("What is quantum computing?")
|
||||
self.assert_valid_response(response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user