819 lines
28 KiB
Python
Executable File
819 lines
28 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Provider Error Proxy - Simulates provider errors for testing Goose error handling.
|
|
|
|
This proxy intercepts HTTP traffic to AI providers and can inject errors interactively.
|
|
It supports the major providers: OpenAI, Anthropic, Google, OpenRouter, Tetrate, and Databricks.
|
|
|
|
Usage:
|
|
uv run python proxy.py [--port PORT]
|
|
|
|
Interactive commands:
|
|
n - No error (pass through) - permanent mode
|
|
c - Context length exceeded error (1 error by default)
|
|
c 4 - Context length exceeded error (4 errors in a row)
|
|
c 0.3 or c 30% - Context length exceeded error (30% of requests)
|
|
c * - Context length exceeded error (100% of requests)
|
|
r - Rate limit error
|
|
u - Unknown server error (500)
|
|
q - Quit
|
|
|
|
To use with Goose, set the provider host environment variables:
|
|
export OPENAI_HOST=http://localhost:8888
|
|
export ANTHROPIC_HOST=http://localhost:8888
|
|
# etc.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import random
|
|
import threading
|
|
from argparse import ArgumentParser
|
|
from enum import Enum
|
|
from typing import Optional
|
|
|
|
from aiohttp import web, ClientSession, ClientTimeout
|
|
from aiohttp.web import Request, Response, StreamResponse
|
|
|
|
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Provider endpoint mappings: default upstream base URL per provider.
# get_target_url() falls back to these when no <PROVIDER>_REAL_HOST
# environment variable is set.
PROVIDER_HOSTS = {
    'openai': 'https://api.openai.com',
    'anthropic': 'https://api.anthropic.com',
    'google': 'https://generativelanguage.googleapis.com',
    'openrouter': 'https://openrouter.ai',
    'tetrate': 'https://api.tetrate.io',
    'databricks': 'https://api.databricks.com',
}

# Paths that should always be forwarded without error injection
# These are typically authentication, configuration, or metadata endpoints
# Matching is substring-based (see ErrorProxy.should_always_forward).
ALWAYS_FORWARD_PATHS = [
    '/oidc/',         # OIDC authentication endpoints
    '/.well-known/',  # Well-known endpoints for discovery
    '/oauth',         # OAuth endpoints
    '/api/2.0/',      # Databricks management API
]
|
|
|
|
|
|
class ErrorMode(Enum):
    """Error injection modes selectable interactively or via --mode."""
    NO_ERROR = 1        # Pass every request through untouched
    CONTEXT_LENGTH = 2  # Simulate a "context length exceeded" provider error
    RATE_LIMIT = 3      # Simulate a 429 rate-limit response
    SERVER_ERROR = 4    # Simulate a server-side failure (5xx)
|
|
|
|
|
|
# Error responses for each provider and error type.
# Shape: {provider: {ErrorMode: {'status': HTTP status code,
#                                'body': JSON payload to return}}}
# Each body mimics the real provider's error envelope so that Goose's
# provider-specific error parsing is exercised realistically.
ERROR_CONFIGS = {
    # OpenAI: nested 'error' object with message/type/code fields.
    'openai': {
        ErrorMode.CONTEXT_LENGTH: {
            'status': 400,
            'body': {
                'error': {
                    'message': "This model's maximum context length is 128000 tokens. However, your messages resulted in 150000 tokens. Please reduce the length of the messages.",
                    'type': 'invalid_request_error',
                    'code': 'context_length_exceeded'
                }
            }
        },
        ErrorMode.RATE_LIMIT: {
            'status': 429,
            'body': {
                'error': {
                    'message': 'Rate limit exceeded. Please try again later.',
                    'type': 'rate_limit_error',
                    'code': 'rate_limit_exceeded'
                }
            }
        },
        ErrorMode.SERVER_ERROR: {
            'status': 500,
            'body': {
                'error': {
                    'message': 'The server had an error while processing your request. Sorry about that!',
                    'type': 'server_error',
                    'code': 'internal_server_error'
                }
            }
        }
    },
    # Anthropic: top-level 'type': 'error' plus nested error object.
    'anthropic': {
        ErrorMode.CONTEXT_LENGTH: {
            'status': 400,
            'body': {
                'type': 'error',
                'error': {
                    'type': 'invalid_request_error',
                    'message': 'prompt is too long: 150000 tokens > 100000 maximum'
                }
            }
        },
        ErrorMode.RATE_LIMIT: {
            'status': 429,
            'body': {
                'type': 'error',
                'error': {
                    'type': 'rate_limit_error',
                    'message': 'Rate limit exceeded. Please try again later.'
                }
            }
        },
        ErrorMode.SERVER_ERROR: {
            # 529 is Anthropic's non-standard "overloaded" status code.
            'status': 529,
            'body': {
                'type': 'error',
                'error': {
                    'type': 'overloaded_error',
                    'message': 'The API is temporarily overloaded. Please try again shortly.'
                }
            }
        }
    },
    # Google: google.rpc-style error with numeric code and status string.
    'google': {
        ErrorMode.CONTEXT_LENGTH: {
            'status': 400,
            'body': {
                'error': {
                    'code': 400,
                    'message': 'Request payload size exceeds the limit: 20000000 bytes.',
                    'status': 'INVALID_ARGUMENT'
                }
            }
        },
        ErrorMode.RATE_LIMIT: {
            'status': 429,
            'body': {
                'error': {
                    'code': 429,
                    'message': 'Resource has been exhausted (e.g. check quota).',
                    'status': 'RESOURCE_EXHAUSTED'
                }
            }
        },
        ErrorMode.SERVER_ERROR: {
            'status': 503,
            'body': {
                'error': {
                    'code': 503,
                    'message': 'Service temporarily unavailable',
                    'status': 'UNAVAILABLE'
                }
            }
        }
    },
    # OpenRouter: error object with numeric code mirroring the HTTP status.
    'openrouter': {
        ErrorMode.CONTEXT_LENGTH: {
            'status': 400,
            'body': {
                'error': {
                    'message': 'This model maximum context length is 128000 tokens, however you requested 150000 tokens',
                    'code': 400
                }
            }
        },
        ErrorMode.RATE_LIMIT: {
            'status': 429,
            'body': {
                'error': {
                    'message': 'Rate limit exceeded',
                    'code': 429
                }
            }
        },
        ErrorMode.SERVER_ERROR: {
            'status': 500,
            'body': {
                'error': {
                    'message': 'Internal server error',
                    'code': 500
                }
            }
        }
    },
    # Tetrate: error object with string error codes.
    'tetrate': {
        ErrorMode.CONTEXT_LENGTH: {
            'status': 400,
            'body': {
                'error': {
                    'message': 'Request exceeds maximum context length',
                    'code': 'context_length_exceeded'
                }
            }
        },
        ErrorMode.RATE_LIMIT: {
            'status': 429,
            'body': {
                'error': {
                    'message': 'Rate limit exceeded',
                    'code': 'rate_limit_exceeded'
                }
            }
        },
        ErrorMode.SERVER_ERROR: {
            'status': 503,
            'body': {
                'error': {
                    'message': 'Service unavailable',
                    'code': 'service_unavailable'
                }
            }
        }
    },
    # Databricks: flat body with error_code/message (no nested 'error' object).
    'databricks': {
        ErrorMode.CONTEXT_LENGTH: {
            'status': 400,
            'body': {
                'error_code': 'INVALID_PARAMETER_VALUE',
                'message': 'The total number of tokens in the request exceeds the maximum allowed'
            }
        },
        ErrorMode.RATE_LIMIT: {
            'status': 429,
            'body': {
                'error_code': 'RATE_LIMIT_EXCEEDED',
                'message': 'Rate limit exceeded'
            }
        },
        ErrorMode.SERVER_ERROR: {
            'status': 500,
            'body': {
                'error_code': 'INTERNAL_ERROR',
                'message': 'Internal server error'
            }
        }
    }
}
|
|
|
|
|
|
class ErrorProxy:
    """HTTP proxy that can inject errors into provider responses."""

    def __init__(self):
        """Initialize the error proxy with pass-through defaults."""
        # Current injection mode; NO_ERROR means all traffic passes through.
        self.error_mode = ErrorMode.NO_ERROR
        self.error_count = 0  # Remaining errors to inject (0 = unlimited/percentage mode)
        self.error_percentage = 0.0  # Percentage of requests to error (0.0 = count mode)
        # Total requests seen; informational only (updated without the lock).
        self.request_count = 0
        # Outbound HTTP session; created lazily by start_session() on app startup.
        self.session: Optional[ClientSession] = None
        # Guards the (error_mode, error_count, error_percentage) triple, which
        # is mutated from both the stdin reader thread and the event loop.
        self.lock = threading.Lock()
|
|
|
|
    def set_error_mode(self, mode: ErrorMode, count: int = 1, percentage: float = 0.0):
        """
        Set the error injection mode.

        Args:
            mode: The error mode to use
            count: Number of errors to inject (default 1). Ignored when
                percentage > 0. NOTE(review): despite the earlier "0 for
                unlimited" wording, should_inject_error() reverts to NO_ERROR
                when both count and percentage are 0 — unlimited injection is
                expressed as percentage=1.0 (the "*" command).
            percentage: Percentage of requests to error (0.0-1.0, 0.0 for count mode)
        """
        # All three fields are set atomically so a concurrent
        # should_inject_error() never sees a half-applied configuration.
        with self.lock:
            self.error_mode = mode
            self.error_count = count
            self.error_percentage = percentage
|
|
|
|
    def should_inject_error(self) -> bool:
        """
        Determine if we should inject an error for this request.

        Mutates state in count mode: each True consumes one remaining error,
        and the mode auto-resets to NO_ERROR once the count reaches zero.
        Percentage mode is a stateless coin flip and never resets.

        Returns:
            True if an error should be injected, False otherwise
        """
        with self.lock:
            if self.error_mode == ErrorMode.NO_ERROR:
                return False

            # Percentage mode: independent random decision per request.
            if self.error_percentage > 0.0:
                return random.random() < self.error_percentage

            # Count mode
            if self.error_count > 0:
                self.error_count -= 1
                # If this was the last error, switch back to NO_ERROR
                if self.error_count == 0:
                    self.error_mode = ErrorMode.NO_ERROR
                return True
            elif self.error_count == 0 and self.error_percentage == 0.0:
                # Count reached zero (e.g. mode was set with count=0):
                # switch back to NO_ERROR without injecting anything.
                self.error_mode = ErrorMode.NO_ERROR
                return False

            return False
|
|
|
|
def get_error_mode(self) -> ErrorMode:
|
|
"""Get the current error injection mode."""
|
|
with self.lock:
|
|
return self.error_mode
|
|
|
|
def get_error_config(self) -> tuple[ErrorMode, int, float]:
|
|
"""Get the current error configuration."""
|
|
with self.lock:
|
|
return (self.error_mode, self.error_count, self.error_percentage)
|
|
|
|
    async def start_session(self):
        """Start the aiohttp client session used for outbound provider calls."""
        timeout = ClientTimeout(total=600)  # Match provider timeout
        self.session = ClientSession(timeout=timeout)
|
|
|
|
async def close_session(self):
|
|
"""Close the aiohttp client session."""
|
|
if self.session:
|
|
await self.session.close()
|
|
|
|
def detect_provider(self, request: Request) -> str:
|
|
"""
|
|
Detect which provider this request is for based on headers and path.
|
|
|
|
Args:
|
|
request: The incoming HTTP request
|
|
|
|
Returns:
|
|
Provider name
|
|
"""
|
|
path = request.path.lower()
|
|
|
|
# Check for databricks-specific paths first (before header checks)
|
|
if '/serving-endpoints/' in path or '/api/2.0/' in path or '/oidc/' in path:
|
|
return 'databricks'
|
|
|
|
# Check for provider-specific headers
|
|
if 'x-api-key' in request.headers:
|
|
return 'anthropic'
|
|
if 'x-goog-api-key' in request.headers:
|
|
return 'google'
|
|
if 'authorization' in request.headers:
|
|
auth = request.headers['authorization'].lower()
|
|
if 'bearer' in auth:
|
|
# Most providers use bearer tokens, check path for hints
|
|
if 'anthropic' in path or 'messages' in path:
|
|
return 'anthropic'
|
|
if 'google' in path or 'generativelanguage' in path:
|
|
return 'google'
|
|
if 'openrouter' in path:
|
|
return 'openrouter'
|
|
if 'tetrate' in path:
|
|
return 'tetrate'
|
|
if 'databricks' in path:
|
|
return 'databricks'
|
|
# Default to openai for bearer tokens
|
|
return 'openai'
|
|
|
|
# Default to openai if we can't determine
|
|
return 'openai'
|
|
|
|
def should_always_forward(self, request: Request) -> bool:
|
|
"""
|
|
Check if this request should always be forwarded (never injected with errors).
|
|
|
|
Args:
|
|
request: The incoming HTTP request
|
|
|
|
Returns:
|
|
True if request should always be forwarded
|
|
"""
|
|
path = request.path
|
|
for forward_path in ALWAYS_FORWARD_PATHS:
|
|
if forward_path in path:
|
|
return True
|
|
return False
|
|
|
|
    def get_target_url(self, request: Request, provider: str) -> str:
        """
        Construct the target URL for the provider.

        Host resolution order:
          1. The provider-specific ``<PROVIDER>_REAL_HOST`` env var.
          2. For always-forward (auth/metadata) paths only: the first
             ``<NAME>_REAL_HOST`` env var found for any known provider,
             since provider detection may be wrong on those endpoints.
          3. The default host from PROVIDER_HOSTS (openai if unknown).

        Args:
            request: The incoming HTTP request
            provider: The detected provider name

        Returns:
            Full target URL
        """
        # Check for provider-specific real host in environment
        real_host_env = f"{provider.upper()}_REAL_HOST"
        base_host = os.environ.get(real_host_env)

        # If no provider-specific real host and this is an always-forward path,
        # check if ANY *_REAL_HOST is set (for auth endpoints where provider detection might fail)
        if base_host is None and self.should_always_forward(request):
            for provider_name in PROVIDER_HOSTS.keys():
                env_var = f"{provider_name.upper()}_REAL_HOST"
                if env_var in os.environ:
                    base_host = os.environ[env_var]
                    logger.info(f"Using {env_var} for always-forward path")
                    break

        # Fall back to default provider host
        if base_host is None:
            base_host = PROVIDER_HOSTS.get(provider, PROVIDER_HOSTS['openai'])

        # Preserve the original path and query string verbatim.
        path = request.path
        query = request.query_string

        url = f"{base_host}{path}"
        if query:
            url = f"{url}?{query}"

        return url
|
|
|
|
def _format_status_line(self) -> str:
|
|
"""Format a one-line status indicator."""
|
|
mode, count, percentage = self.get_error_config()
|
|
mode_symbols = {
|
|
ErrorMode.NO_ERROR: "✅",
|
|
ErrorMode.CONTEXT_LENGTH: "📏",
|
|
ErrorMode.RATE_LIMIT: "⏱️",
|
|
ErrorMode.SERVER_ERROR: "💥"
|
|
}
|
|
|
|
symbol = mode_symbols.get(mode, "❓")
|
|
mode_name = mode.name.replace('_', ' ').title()
|
|
|
|
if mode == ErrorMode.NO_ERROR:
|
|
return f"{symbol} {mode_name}"
|
|
elif percentage > 0.0:
|
|
return f"{symbol} {mode_name} ({percentage*100:.0f}%)"
|
|
elif count > 0:
|
|
return f"{symbol} {mode_name} ({count} remaining)"
|
|
else:
|
|
return f"{symbol} {mode_name}"
|
|
|
|
    async def handle_request(self, request: Request) -> Response:
        """
        Handle an incoming HTTP request.

        Flow: always-forward paths skip injection entirely; otherwise a
        configured error (if armed) is returned immediately without
        contacting the provider; everything else is proxied upstream, with
        SSE responses streamed chunk-by-chunk.

        Args:
            request: The incoming HTTP request

        Returns:
            HTTP response (either proxied or error)
        """
        # NOTE(review): request_count is incremented without self.lock; it is
        # informational only, so occasional lost updates are acceptable.
        self.request_count += 1
        provider = self.detect_provider(request)

        logger.info(f"📨 Request #{self.request_count}: {request.method} {request.path} -> {provider}")

        # Check if this request should always be forwarded
        if self.should_always_forward(request):
            logger.info(f"🔄 Always forwarding: {request.path}")
        else:
            # Capture the error mode BEFORE checking if we should inject (since that modifies state)
            mode_before_check = self.get_error_mode()

            # Check if we should inject an error
            should_error = self.should_inject_error()
            if should_error:
                # Use the mode captured before the check, since should_inject_error may have changed it
                error_config = ERROR_CONFIGS.get(provider, ERROR_CONFIGS['openai']).get(
                    mode_before_check, ERROR_CONFIGS['openai'][ErrorMode.SERVER_ERROR]
                )
                logger.warning(f"💥 Injecting {mode_before_check.name} error (status {error_config['status']}) for {provider}")
                # Show status after the injection to reflect the updated state
                logger.info(f"Status: {self._format_status_line()}")
                return web.json_response(
                    error_config['body'],
                    status=error_config['status']
                )

        # Forward the request to the actual provider
        target_url = self.get_target_url(request, provider)

        try:
            # Read request body
            body = await request.read()

            # Copy headers, excluding hop-by-hop headers
            headers = {k: v for k, v in request.headers.items()
                       if k.lower() not in ('host', 'connection', 'keep-alive',
                                            'proxy-authenticate', 'proxy-authorization',
                                            'te', 'trailers', 'transfer-encoding', 'upgrade')}

            # Make the proxied request
            async with self.session.request(
                method=request.method,
                url=target_url,
                headers=headers,
                data=body,
                allow_redirects=False
            ) as resp:
                # Copy response headers
                # For non-streaming responses, we need to exclude content-encoding and content-length
                # because aiohttp.Response.read() automatically decompresses the body
                response_headers = {k: v for k, v in resp.headers.items()
                                    if k.lower() not in ('connection', 'keep-alive',
                                                         'transfer-encoding', 'content-encoding',
                                                         'content-length')}

                # Check if this is a streaming response (SSE)
                content_type = resp.headers.get('content-type', '').lower()
                is_streaming = 'text/event-stream' in content_type

                if is_streaming:
                    # Stream the response (Server-Sent Events)
                    logger.info(f"🌊 Streaming response: {resp.status}")
                    response = StreamResponse(
                        status=resp.status,
                        headers=response_headers
                    )
                    await response.prepare(request)

                    # Stream chunks from provider to client
                    try:
                        async for chunk in resp.content.iter_any():
                            await response.write(chunk)
                        await response.write_eof()
                    except Exception as stream_error:
                        # Best-effort: the client may drop mid-stream; log and
                        # return whatever was delivered rather than raising.
                        logger.warning(f"Stream write error (client may have disconnected): {stream_error}")
                    logger.info(f"Status: {self._format_status_line()}")
                    return response
                else:
                    # Non-streaming response - read entire body
                    response_body = await resp.read()
                    logger.info(f"✅ Proxied response: {resp.status}")
                    logger.info(f"Status: {self._format_status_line()}")

                    return Response(
                        body=response_body,
                        status=resp.status,
                        headers=response_headers
                    )

        except Exception as e:
            # Any proxying failure is surfaced to the client as a 500 with a
            # generic error envelope rather than crashing the handler.
            logger.error(f"❌ Error proxying request: {e}", exc_info=True)
            return web.json_response(
                {'error': {'message': f'Proxy error: {str(e)}'}},
                status=500
            )
|
|
|
|
|
|
def parse_command(command: str) -> tuple[Optional[ErrorMode], int, float, Optional[str]]:
    """
    Parse a command string and return the error mode, count, and percentage.

    Args:
        command: Command string (e.g., "c", "c 3", "r 30%", "u *")

    Returns:
        Tuple of (mode, count, percentage, error_message)
        If error_message is not None, parsing failed
    """
    # Whitespace is insignificant: "c 3" and "c3" are the same command.
    compact = command.strip().replace(" ", "")
    if not compact:
        return (None, 0, 0.0, "Empty command")

    # First character selects the error type; the remainder is the amount.
    error_letter = compact[0].lower()
    value_str = compact[1:]

    mode = {
        'n': ErrorMode.NO_ERROR,
        'c': ErrorMode.CONTEXT_LENGTH,
        'r': ErrorMode.RATE_LIMIT,
        'u': ErrorMode.SERVER_ERROR
    }.get(error_letter)
    if mode is None:
        return (None, 0, 0.0, f"Invalid command: '{error_letter}'. Use n, c, r, or u")

    # Defaults: one error, count mode.
    count = 1
    percentage = 0.0

    if value_str:
        try:
            if value_str == '*':
                # "*" is shorthand for 100% of requests.
                percentage = 1.0
                count = 0  # Percentage mode
            elif value_str.endswith('%'):
                # Explicit percentage, e.g. "30%".
                percentage = float(value_str[:-1]) / 100.0
                if percentage < 0.0 or percentage > 1.0:
                    return (None, 0, 0.0, f"Invalid percentage: {percentage*100:.0f}%. Must be between 0% and 100%")
                count = 0  # Percentage mode
            elif '.' in value_str:
                # Bare decimal is a fraction of requests, e.g. "0.3".
                percentage = float(value_str)
                if percentage < 0.0 or percentage > 1.0:
                    return (None, 0, 0.0, f"Invalid percentage: {percentage}. Must be between 0.0 and 1.0")
                count = 0  # Percentage mode
            else:
                # Plain integer: number of consecutive errors to inject.
                count = int(value_str)
                if count < 0:
                    return (None, 0, 0.0, f"Invalid count: {count}. Must be >= 0")
        except ValueError:
            return (None, 0, 0.0, f"Invalid value: '{value_str}'. Must be an integer, decimal, percentage (30%), or * (100%)")

    return (mode, count, percentage, None)
|
|
|
|
|
|
def print_status(proxy: ErrorProxy):
    """Print the current proxy status and the interactive command menu."""
    mode, count, percentage = proxy.get_error_config()
    mode_names = {
        ErrorMode.NO_ERROR: "✅ No error (pass through)",
        ErrorMode.CONTEXT_LENGTH: "📏 Context length exceeded",
        ErrorMode.RATE_LIMIT: "⏱️ Rate limit exceeded",
        ErrorMode.SERVER_ERROR: "💥 Server error (500)"
    }

    print("\n" + "=" * 60)
    mode_str = mode_names.get(mode, 'Unknown')
    # Qualify active error modes with how much injection remains.
    if mode != ErrorMode.NO_ERROR:
        if percentage > 0.0:
            mode_str += f" ({percentage*100:.0f}% of requests)"
        elif count > 0:
            mode_str += f" ({count} remaining)"
    print(f"Current mode: {mode_str}")
    print(f"Requests handled: {proxy.request_count}")
    print("=" * 60)
    print("\nCommands:")
    print(" n - No error (pass through) - permanent")
    print(" c - Context length exceeded (1 time)")
    print(" c 4 - Context length exceeded (4 times)")
    print(" c 0.3 - Context length exceeded (30% of requests)")
    print(" c 30% - Context length exceeded (30% of requests)")
    print(" c * - Context length exceeded (100% of requests)")
    print(" r - Rate limit error (1 time)")
    print(" u - Unknown server error (1 time)")
    print(" q - Quit")
    print()
|
|
|
|
|
|
def stdin_reader(proxy: ErrorProxy, loop):
    """
    Read commands from stdin in a separate thread.

    Runs in a daemon thread (started from main()); shutdown is handed back
    to the asyncio loop via run_coroutine_threadsafe because event-loop
    APIs must not be called directly from another thread.
    """
    print_status(proxy)

    while True:
        try:
            command = input("Enter command: ").strip()

            if command.lower() == 'q':
                print("\n🛑 Shutting down proxy...")
                # Schedule the shutdown in the event loop
                asyncio.run_coroutine_threadsafe(shutdown_server(loop), loop)
                break

            # Parse the command using the shared parser
            mode, count, percentage, error_msg = parse_command(command)

            if error_msg:
                print(f"❌ {error_msg}")
                continue

            # Set the error mode
            proxy.set_error_mode(mode, count, percentage)
            print_status(proxy)

        except EOFError:
            # Handle Ctrl+D
            print("\n🛑 Shutting down proxy...")
            asyncio.run_coroutine_threadsafe(shutdown_server(loop), loop)
            break
        except Exception as e:
            # Keep the reader loop alive on unexpected input errors.
            logger.error(f"Error reading stdin: {e}")
|
|
|
|
|
|
async def shutdown_server(loop):
    """Shutdown the server gracefully by stopping the event loop.

    This coroutine is scheduled onto the loop via run_coroutine_threadsafe
    (see stdin_reader), so calling loop.stop() here is thread-safe.
    """
    # Stop the event loop; main() then runs its finally-block cleanup.
    loop.stop()
|
|
|
|
|
|
async def create_app(proxy: ErrorProxy) -> web.Application:
    """
    Create the aiohttp application.

    Args:
        proxy: The ErrorProxy instance

    Returns:
        Configured aiohttp application
    """
    app = web.Application()

    # Setup and teardown: tie the proxy's outbound client session
    # lifetime to the web application's lifecycle hooks.
    async def on_startup(app):
        await proxy.start_session()
        logger.info("🚀 Proxy session started")

    async def on_cleanup(app):
        await proxy.close_session()
        logger.info("🛑 Proxy session closed")

    app.on_startup.append(on_startup)
    app.on_cleanup.append(on_cleanup)

    # Route all requests through the proxy (catch-all path, any HTTP method).
    app.router.add_route('*', '/{path:.*}', proxy.handle_request)

    return app
|
|
|
|
|
|
def main():
    """Main entry point: parse args, optionally arm an initial error mode,
    start the interactive stdin thread, and run the proxy server."""
    parser = ArgumentParser(description='Provider Error Proxy for Goose testing')
    parser.add_argument(
        '--port',
        type=int,
        default=8888,
        help='Port to listen on (default: 8888)'
    )
    parser.add_argument(
        '--mode',
        type=str,
        help='Error mode command (e.g., "c 3", "r 30%%", "u *", "n")'
    )
    parser.add_argument(
        '--no-stdin',
        action='store_true',
        help='Disable stdin reader (for background/automated mode)'
    )

    args = parser.parse_args()

    # Startup banner with copy-pasteable environment variable hints.
    print("=" * 60)
    print("🔧 Provider Error Proxy")
    print("=" * 60)
    print(f"Port: {args.port}")
    print()
    print("To use with Goose, set these environment variables:")
    print(f" export OPENAI_HOST=http://localhost:{args.port}")
    print(f" export ANTHROPIC_HOST=http://localhost:{args.port}")
    print(f" export GOOGLE_HOST=http://localhost:{args.port}")
    print(f" export OPENROUTER_HOST=http://localhost:{args.port}")
    print(f" export TETRATE_HOST=http://localhost:{args.port}")
    print(f" export DATABRICKS_HOST=http://localhost:{args.port}")
    print("=" * 60)

    # Create proxy instance
    proxy = ErrorProxy()

    # Set initial error mode from command-line arguments
    if args.mode:
        mode, count, percentage, error_msg = parse_command(args.mode)

        if error_msg:
            # Invalid --mode is fatal: report and exit before binding the port.
            print(f"❌ Error parsing --mode argument: {error_msg}")
            print(f" Example usage: --mode \"c 3\" or --mode \"r 30%\"")
            return

        proxy.set_error_mode(mode, count, percentage)
        print()
        print(f"Initial mode set from command-line arguments:")
        print(f" Mode: {mode.name}")
        if percentage > 0.0:
            print(f" Percentage: {percentage*100:.0f}%")
        elif count > 0:
            print(f" Count: {count}")
        print()

    # Create the event loop explicitly so the stdin thread can schedule
    # coroutines onto it with run_coroutine_threadsafe.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Start stdin reader thread only if not disabled
    if not args.no_stdin:
        # Daemon thread: must not keep the process alive after the loop exits.
        stdin_thread = threading.Thread(target=stdin_reader, args=(proxy, loop), daemon=True)
        stdin_thread.start()
    else:
        print("Running in no-stdin mode (background/automated)")
        print("Use SIGINT (Ctrl+C) or SIGTERM to stop the proxy")
        print()

    # Create and run the app
    app = loop.run_until_complete(create_app(proxy))

    # Run the web server
    runner = web.AppRunner(app)
    loop.run_until_complete(runner.setup())
    site = web.TCPSite(runner, 'localhost', args.port)
    loop.run_until_complete(site.start())

    logger.info(f"Proxy running on http://localhost:{args.port}")

    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("\n🛑 Shutting down proxy...")
    finally:
        # Always tear down the runner (closes sockets, fires on_cleanup hooks).
        loop.run_until_complete(runner.cleanup())
        loop.close()
|
|
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    main()
|