fix: Proper refresh of Groq models (#12158)

* fix: Proper refresh of Groq models

* Update groq_model_discovery.py

* Update src/lfx/src/lfx/base/models/groq_model_discovery.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update src/lfx/src/lfx/base/models/groq_model_discovery.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Add more unit tests

* Update src/lfx/src/lfx/base/models/groq_model_discovery.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* More thorough unit tests

* Update test_groq_model_discovery.py

* Update groq_model_discovery.py

* Update src/backend/tests/unit/groq/test_groq_model_discovery.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update src/backend/tests/unit/groq/test_groq_model_discovery.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update test_groq_model_discovery.py

* Update groq_model_discovery.py

* Update groq.py

* [autofix.ci] apply automated fixes

* PR review comments

* Redundant errors

* [autofix.ci] apply automated fixes

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Hare
2026-03-13 11:15:52 -07:00
committed by GitHub
parent e6d6d2e4b5
commit aea07965a0
11 changed files with 861 additions and 563 deletions

View File

@@ -261,6 +261,34 @@ def mock_groq_client_rate_limit():
return _create_mock_client
@pytest.fixture
def mock_groq_client_chat_not_supported():
    """Factory fixture: a Groq client whose chat call reports no chat-completion support."""

    def _factory(*_args, **_kwargs):
        client = MagicMock()
        # The discovery code inspects this message to classify the model.
        error = ValueError("Error: model 'some-model' does not support chat completions")
        client.chat.completions.create.side_effect = error
        return client

    return _factory
@pytest.fixture
def mock_groq_client_chat_terms_required():
    """Factory fixture: a Groq client whose chat call fails with a terms-acceptance error."""

    def _factory(*_args, **_kwargs):
        client = MagicMock()
        # Entitlement-style failure; discovery should treat it as indeterminate.
        message = "Error: model_terms_required - please accept the terms"
        client.chat.completions.create.side_effect = ValueError(message)
        return client

    return _factory
@pytest.fixture
def sample_models_metadata():
"""Sample model metadata dictionary for testing."""

View File

@@ -0,0 +1,157 @@
"""Tests for Groq _test_chat_completion method."""
import sys
from unittest.mock import MagicMock, Mock, patch
import pytest
from lfx.base.models.groq_model_discovery import GroqModelDiscovery
class TestChatCompletionDetection:
    """Test _test_chat_completion method.

    Contract under test: _test_chat_completion returns True when the probe call
    succeeds, False when the error message indicates the model does not support
    chat completions, and None for indeterminate failures (rate limits,
    entitlement errors). ImportError propagates to the caller.
    """

    @patch("groq.Groq")
    def test_chat_completion_success(self, mock_groq, mock_api_key, mock_groq_client_tool_calling_success):
        """Test successful chat completion returns True."""
        # A client whose chat.completions.create succeeds marks the model as chat-capable.
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        result = discovery._test_chat_completion("llama-3.1-8b-instant")
        assert result is True

    @patch("groq.Groq")
    def test_chat_completion_not_supported(self, mock_groq, mock_api_key, mock_groq_client_chat_not_supported):
        """Test model that does not support chat completions returns False."""
        # Fixture raises "does not support chat completions" — a definitive negative.
        mock_groq.return_value = mock_groq_client_chat_not_supported()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        result = discovery._test_chat_completion("speech-model")
        assert result is False

    @patch("groq.Groq")
    def test_chat_completion_terms_required_returns_none(
        self, mock_groq, mock_api_key, mock_groq_client_chat_terms_required
    ):
        """Test that access/entitlement errors cause _test_chat_completion to return None."""
        mock_groq.return_value = mock_groq_client_chat_terms_required()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        result = discovery._test_chat_completion("gated-model")
        # None = indeterminate: the model is neither confirmed nor ruled out.
        assert result is None

    @patch("groq.Groq")
    def test_chat_completion_transient_error_returns_none(self, mock_groq, mock_api_key, mock_groq_client_rate_limit):
        """Test that transient errors (e.g. rate limits) return None (indeterminate)."""
        mock_groq.return_value = mock_groq_client_rate_limit()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        result = discovery._test_chat_completion("llama-3.1-8b-instant")
        assert result is None

    def test_chat_completion_import_error_raises(self, mock_api_key):
        """Test that ImportError propagates when the groq package is not installed."""
        # Simulate groq not being installed by hiding it from sys.modules
        # (mapping a module name to None makes "import groq" raise ImportError).
        with patch.dict(sys.modules, {"groq": None}):
            discovery = GroqModelDiscovery(api_key=mock_api_key)
            with pytest.raises(ImportError):
                discovery._test_chat_completion("test-model")

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_chat_failure_marks_model_not_supported(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        temp_cache_dir,
    ):
        """Test that a model failing the chat test is marked not_supported in get_models."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "data": [
                {"id": "llama-3.1-8b-instant", "object": "model"},
                {"id": "speech-model-v1", "object": "model"},
            ]
        }
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        # The factory below is sequenced by Groq() instantiation order:
        # First Groq() call: chat test for llama (succeeds)
        # Second Groq() call: tool test for llama (succeeds)
        # Third Groq() call: chat test for speech-model (fails with "does not support chat completions")
        call_count = [0]

        def create_mock_client(*_args, **_kwargs):
            mock_client = MagicMock()
            if call_count[0] <= 1:
                # chat + tool test for llama: succeed
                mock_client.chat.completions.create.return_value = MagicMock()
            else:
                # chat test for speech-model: fails
                mock_client.chat.completions.create.side_effect = ValueError(
                    "Error: model 'speech-model-v1' does not support chat completions"
                )
            call_count[0] += 1
            return mock_client

        mock_groq.side_effect = create_mock_client
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        # Redirect the cache into the temp dir so the test never touches real state.
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        models = discovery.get_models(force_refresh=True)
        # llama should be a normal LLM model with tool_calling
        assert "tool_calling" in models["llama-3.1-8b-instant"]
        assert models["llama-3.1-8b-instant"].get("not_supported") is None
        # speech-model should be marked not_supported
        assert models["speech-model-v1"]["not_supported"] is True
        assert "tool_calling" not in models["speech-model-v1"]

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_transient_chat_error_does_not_exclude_model(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        temp_cache_dir,
    ):
        """Test that transient chat errors (rate limits) don't incorrectly exclude models."""
        mock_response = Mock()
        mock_response.json.return_value = {
            "data": [
                {"id": "rate-limited-model", "object": "model"},
            ]
        }
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        # Sequenced by Groq() instantiation order:
        # First Groq() call: chat test hits rate limit (transient error)
        # Second Groq() call: tool test succeeds
        call_count = [0]

        def create_mock_client(*_args, **_kwargs):
            mock_client = MagicMock()
            if call_count[0] == 0:
                mock_client.chat.completions.create.side_effect = RuntimeError("Rate limit exceeded")
            else:
                mock_client.chat.completions.create.return_value = MagicMock()
            call_count[0] += 1
            return mock_client

        mock_groq.side_effect = create_mock_client
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        models = discovery.get_models(force_refresh=True)
        # Model should NOT be excluded — it should be treated as a normal LLM
        assert "tool_calling" in models["rate-limited-model"]
        assert models["rate-limited-model"].get("not_supported") is None

View File

@@ -0,0 +1,39 @@
"""Tests for the get_groq_models() convenience function."""
from unittest.mock import patch
from lfx.base.models.groq_model_discovery import GroqModelDiscovery, get_groq_models
class TestGetGroqModelsConvenienceFunction:
    """Tests for the module-level get_groq_models() wrapper.

    In every case GroqModelDiscovery.get_models is patched, so these tests only
    verify argument plumbing from the convenience function to the instance method.
    """

    @patch.object(GroqModelDiscovery, "get_models")
    def test_get_groq_models_with_api_key(self, mock_get_models, mock_api_key):
        """An explicit API key is accepted; refresh defaults to False."""
        mock_get_models.return_value = {"llama-3.1-8b-instant": {}}
        result = get_groq_models(api_key=mock_api_key)
        mock_get_models.assert_called_once_with(force_refresh=False)
        assert "llama-3.1-8b-instant" in result

    @patch.object(GroqModelDiscovery, "get_models")
    def test_get_groq_models_without_api_key(self, mock_get_models):
        """Omitting the API key still delegates with refresh defaulted to False."""
        mock_get_models.return_value = {"llama-3.1-8b-instant": {}}
        result = get_groq_models()
        mock_get_models.assert_called_once_with(force_refresh=False)
        assert "llama-3.1-8b-instant" in result

    @patch.object(GroqModelDiscovery, "get_models")
    def test_get_groq_models_force_refresh(self, mock_get_models, mock_api_key):
        """force_refresh=True is forwarded unchanged to get_models."""
        mock_get_models.return_value = {"llama-3.1-8b-instant": {}}
        result = get_groq_models(api_key=mock_api_key, force_refresh=True)
        mock_get_models.assert_called_once_with(force_refresh=True)
        assert "llama-3.1-8b-instant" in result

View File

@@ -0,0 +1,152 @@
"""Tests for edge cases in Groq model discovery."""
from unittest.mock import MagicMock, Mock, patch
from lfx.base.models.groq_model_discovery import GroqModelDiscovery
class TestGroqModelDiscoveryEdgeCases:
    """Test edge cases in model discovery.

    Covers empty API responses, cache-file lifecycle, preview-model naming
    heuristics, mixed tool-calling support, and the hardcoded fallback list.
    """

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_empty_model_list_from_api(self, mock_get, mock_api_key, temp_cache_dir):
        """Test handling of empty model list from API."""
        # Mock empty response
        mock_response = Mock()
        mock_response.json.return_value = {"data": []}
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        models = discovery.get_models(force_refresh=True)
        # Should return empty dict (or potentially fallback)
        assert isinstance(models, dict)

    def test_cache_file_not_exists(self, mock_api_key, temp_cache_dir):
        """Test loading cache when file doesn't exist."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "nonexistent.json"
        loaded = discovery._load_cache()
        # Missing file is a cache miss, not an error.
        assert loaded is None

    def test_cache_directory_created_on_save(self, mock_api_key, temp_cache_dir, sample_models_metadata):
        """Test that cache directory is created if it doesn't exist."""
        cache_file = temp_cache_dir / "new_dir" / ".cache" / "test_cache.json"
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = cache_file
        # Directory shouldn't exist yet
        assert not cache_file.parent.exists()
        # Save cache
        discovery._save_cache(sample_models_metadata)
        # Directory should be created
        assert cache_file.parent.exists()
        assert cache_file.exists()

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_preview_model_detection(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        mock_groq_client_tool_calling_success,
        temp_cache_dir,
    ):
        """Test detection of preview models."""
        # Mock API with preview models
        mock_response = Mock()
        mock_response.json.return_value = {
            "data": [
                {"id": "llama-3.2-1b-preview", "object": "model"},
                {"id": "meta-llama/llama-3.2-90b-preview", "object": "model"},
            ]
        }
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        models = discovery.get_models(force_refresh=True)
        # Models with "preview" in name should be marked as preview
        assert models["llama-3.2-1b-preview"]["preview"] is True
        # Models with "/" should be marked as preview
        assert models["meta-llama/llama-3.2-90b-preview"]["preview"] is True

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_mixed_tool_calling_support(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        temp_cache_dir,
    ):
        """Test models with mixed tool calling support."""
        # Mock API
        mock_response = Mock()
        mock_response.json.return_value = {
            "data": [
                {"id": "llama-3.1-8b-instant", "object": "model"},
                {"id": "gemma-7b-it", "object": "model"},
            ]
        }
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        # Mock tool calling - each model goes through chat test then tool test
        # Call order: chat(llama), tool(llama), chat(gemma), tool(gemma)
        # NOTE: the branch below is keyed to that exact Groq() instantiation order.
        call_count = [0]

        def create_mock_client(*_args, **_kwargs):
            mock_client = MagicMock()
            if call_count[0] <= 2:
                # Calls 0-2: chat test for llama (success), tool test for llama (success),
                # chat test for gemma (success)
                mock_client.chat.completions.create.return_value = MagicMock()
            else:
                # Call 3: tool test for gemma (fails)
                mock_client.chat.completions.create.side_effect = ValueError("tool calling not supported")
            call_count[0] += 1
            return mock_client

        mock_groq.side_effect = create_mock_client
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        models = discovery.get_models(force_refresh=True)
        # First model should support tools
        assert models["llama-3.1-8b-instant"]["tool_calling"] is True
        # Second model should not support tools
        assert models["gemma-7b-it"]["tool_calling"] is False

    def test_fallback_models_structure(self, mock_api_key):
        """Test that fallback models have the correct structure."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        fallback = discovery._get_fallback_models()
        assert isinstance(fallback, dict)
        assert len(fallback) == 2
        # Every fallback entry must carry the full metadata schema.
        for metadata in fallback.values():
            assert "name" in metadata
            assert "provider" in metadata
            assert "tool_calling" in metadata
            assert "preview" in metadata
            assert metadata["tool_calling"] is True  # Fallback models should support tools

View File

@@ -0,0 +1,151 @@
"""Tests for error handling in Groq model discovery."""
import json
import sys
from unittest.mock import Mock, patch
import pytest
from lfx.base.models.groq_model_discovery import GroqModelDiscovery
class TestGroqModelDiscoveryErrors:
    """Error-path tests for GroqModelDiscovery.

    The common contract: any API failure (missing key, network error, timeout,
    401, malformed body, missing groq package) degrades to the hardcoded
    fallback model list, and cache problems surface as a cache miss (None),
    never as an exception.
    """

    def test_no_api_key_returns_fallback(self):
        """With no API key, get_models serves only the two fallback models."""
        disc = GroqModelDiscovery(api_key=None)
        result = disc.get_models(force_refresh=True)
        assert "llama-3.1-8b-instant" in result
        assert "llama-3.3-70b-versatile" in result
        assert len(result) == 2

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_api_connection_error_returns_fallback(self, mock_get, mock_api_key, mock_requests_get_failure):
        """A connection failure while listing models degrades to the fallback list."""
        mock_get.side_effect = mock_requests_get_failure
        disc = GroqModelDiscovery(api_key=mock_api_key)
        result = disc.get_models(force_refresh=True)
        assert "llama-3.1-8b-instant" in result
        assert "llama-3.3-70b-versatile" in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_api_timeout_returns_fallback(self, mock_get, mock_api_key, mock_requests_get_timeout):
        """A request timeout degrades to the fallback list."""
        mock_get.side_effect = mock_requests_get_timeout
        disc = GroqModelDiscovery(api_key=mock_api_key)
        result = disc.get_models(force_refresh=True)
        assert "llama-3.1-8b-instant" in result
        assert "llama-3.3-70b-versatile" in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_api_unauthorized_returns_fallback(self, mock_get, mock_api_key, mock_requests_get_unauthorized):
        """An unauthorized (401-style) response degrades to the fallback list."""
        mock_get.side_effect = mock_requests_get_unauthorized
        disc = GroqModelDiscovery(api_key=mock_api_key)
        result = disc.get_models(force_refresh=True)
        assert "llama-3.1-8b-instant" in result
        assert "llama-3.3-70b-versatile" in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_invalid_api_response_returns_fallback(self, mock_get, mock_api_key):
        """A response body without the expected 'data' key degrades to the fallback list."""
        bad_response = Mock()
        bad_response.json.return_value = {"error": "invalid"}
        bad_response.raise_for_status = Mock()
        mock_get.return_value = bad_response
        disc = GroqModelDiscovery(api_key=mock_api_key)
        result = disc.get_models(force_refresh=True)
        assert "llama-3.1-8b-instant" in result

    def test_corrupted_cache_returns_none(self, mock_api_key, mock_corrupted_cache_file):
        """Unparseable cache contents are treated as a cache miss."""
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = mock_corrupted_cache_file
        assert disc._load_cache() is None

    def test_cache_missing_fields_returns_none(self, mock_api_key, temp_cache_dir):
        """A cache payload lacking required fields is treated as a cache miss."""
        bad_cache = temp_cache_dir / ".cache" / "invalid_cache.json"
        bad_cache.parent.mkdir(parents=True, exist_ok=True)
        # Valid JSON, but no 'cached_at' timestamp — must be rejected.
        payload = {"models": {"llama-3.1-8b-instant": {}}}
        with bad_cache.open("w") as fh:
            json.dump(payload, fh)
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = bad_cache
        assert disc._load_cache() is None

    def test_cache_save_failure_logs_warning(self, mock_api_key, temp_cache_dir, sample_models_metadata):
        """An unwritable cache target must not raise from _save_cache."""
        disc = GroqModelDiscovery(api_key=mock_api_key)
        # Point the cache at an existing directory — opening it for write fails.
        disc.CACHE_FILE = temp_cache_dir
        disc._save_cache(sample_models_metadata)  # should swallow the error

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_import_error_during_chat_test_returns_fallback(self, mock_get, mock_api_key, temp_cache_dir):
        """Test that get_models returns fallback models when groq is not installed.

        Both _test_chat_completion and _test_tool_calling re-raise ImportError when
        the groq package is absent. get_models catches it and falls back to hardcoded
        model metadata instead of crashing.
        """
        listing = Mock()
        listing.json.return_value = {"data": [{"id": "llama-3.1-8b-instant", "object": "model"}]}
        listing.raise_for_status = Mock()
        mock_get.return_value = listing
        # Hide the groq package so "import groq" raises ImportError.
        with patch.dict(sys.modules, {"groq": None}):
            disc = GroqModelDiscovery(api_key=mock_api_key)
            disc.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
            result = disc.get_models(force_refresh=True)
        # Should return the hardcoded fallback list, not an empty dict
        assert "llama-3.1-8b-instant" in result
        assert "llama-3.3-70b-versatile" in result
        assert len(result) == 2  # exactly the two fallback models

    @patch("groq.Groq")
    def test_tool_calling_import_error_raises(self, mock_groq, mock_api_key):
        """ImportError raised while constructing the client propagates from _test_tool_calling."""
        mock_groq.side_effect = ImportError("groq module not found")
        disc = GroqModelDiscovery(api_key=mock_api_key)
        with pytest.raises(ImportError):
            disc._test_tool_calling("test-model")

    @patch("groq.Groq")
    def test_tool_calling_rate_limit_returns_none(self, mock_groq, mock_api_key, mock_groq_client_rate_limit):
        """Rate-limit errors leave tool support undetermined (None)."""
        mock_groq.return_value = mock_groq_client_rate_limit()
        disc = GroqModelDiscovery(api_key=mock_api_key)
        assert disc._test_tool_calling("test-model") is None

View File

@@ -0,0 +1,234 @@
"""Tests for successful Groq model discovery operations."""
from unittest.mock import Mock, patch
from lfx.base.models.groq_model_discovery import GroqModelDiscovery
class TestGroqModelDiscoverySuccess:
    """Test successful model discovery operations.

    Covers constructor defaults, API fetching, LLM/non-LLM categorization,
    tool-calling probes, cache round-trips and expiry, provider-name
    extraction, and the SKIP_PATTERNS heuristics.
    """

    def test_init_with_api_key(self, mock_api_key):
        """Test initialization with API key."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        assert discovery.api_key == mock_api_key
        assert discovery.base_url == "https://api.groq.com"

    def test_init_without_api_key(self):
        """Test initialization without API key."""
        discovery = GroqModelDiscovery()
        assert discovery.api_key is None
        assert discovery.base_url == "https://api.groq.com"

    def test_init_with_custom_base_url(self, mock_api_key):
        """Test initialization with custom base URL."""
        custom_url = "https://custom.groq.com"
        discovery = GroqModelDiscovery(api_key=mock_api_key, base_url=custom_url)
        assert discovery.base_url == custom_url

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_fetch_available_models_success(
        self, mock_groq, mock_get, mock_api_key, mock_groq_models_response, mock_groq_client_tool_calling_success
    ):
        """Test successfully fetching models from API."""
        # Mock API response
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = mock_groq_models_response
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        # Mock tool calling tests
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        models = discovery._fetch_available_models()
        # _fetch_available_models returns the raw id list, unfiltered
        # (whisper is still present at this stage).
        assert isinstance(models, list)
        assert len(models) == 8
        assert "llama-3.1-8b-instant" in models
        assert "whisper-large-v3" in models
        mock_get.assert_called_once()

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_get_models_categorizes_llm_and_non_llm(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        mock_groq_models_response,
        mock_groq_client_tool_calling_success,
        temp_cache_dir,
    ):
        """Test that models are correctly categorized as LLM vs non-LLM."""
        # Mock API response
        mock_response = Mock()
        mock_response.json.return_value = mock_groq_models_response
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        # Mock tool calling tests to always succeed
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        # Redirect the cache into the temp dir so the test never touches real state.
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        models = discovery.get_models(force_refresh=True)
        # LLM models should be in the result
        assert "llama-3.1-8b-instant" in models
        assert "llama-3.3-70b-versatile" in models
        assert "mixtral-8x7b-32768" in models
        assert "gemma-7b-it" in models
        # Non-LLM models should be marked as not_supported
        assert models["whisper-large-v3"]["not_supported"] is True
        assert models["distil-whisper-large-v3-en"]["not_supported"] is True
        assert models["meta-llama/llama-guard-4-12b"]["not_supported"] is True
        assert models["meta-llama/llama-prompt-guard-2-86m"]["not_supported"] is True
        # LLM models should have tool_calling field
        assert "tool_calling" in models["llama-3.1-8b-instant"]
        assert "tool_calling" in models["mixtral-8x7b-32768"]

    @patch("groq.Groq")
    def test_tool_calling_detection_success(self, mock_groq, mock_api_key, mock_groq_client_tool_calling_success):
        """Test successful tool calling detection."""
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        result = discovery._test_tool_calling("llama-3.1-8b-instant")
        assert result is True

    @patch("groq.Groq")
    def test_tool_calling_detection_not_supported(self, mock_groq, mock_api_key, mock_groq_client_tool_calling_failure):
        """Test tool calling detection when model doesn't support tools."""
        mock_groq.return_value = mock_groq_client_tool_calling_failure()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        result = discovery._test_tool_calling("gemma-7b-it")
        assert result is False

    def test_cache_save_and_load(self, mock_api_key, sample_models_metadata, temp_cache_dir):
        """Test saving and loading cache."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        # Save cache
        discovery._save_cache(sample_models_metadata)
        # Verify file was created
        assert discovery.CACHE_FILE.exists()
        # Load cache
        loaded = discovery._load_cache()
        assert loaded is not None
        assert len(loaded) == len(sample_models_metadata)
        assert "llama-3.1-8b-instant" in loaded
        assert loaded["llama-3.1-8b-instant"]["tool_calling"] is True

    def test_cache_respects_expiration(self, mock_api_key, mock_expired_cache_file):
        """Test that expired cache returns None."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = mock_expired_cache_file
        loaded = discovery._load_cache()
        # An expired cache is a cache miss, not an error.
        assert loaded is None

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_get_models_uses_cache_when_available(self, mock_groq, mock_get, mock_api_key, mock_cache_file):
        """Test that get_models uses cache when available and not expired."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = mock_cache_file
        models = discovery.get_models(force_refresh=False)
        # Should use cache, not call API
        mock_get.assert_not_called()
        mock_groq.assert_not_called()
        assert "llama-3.1-8b-instant" in models
        assert "llama-3.3-70b-versatile" in models

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_force_refresh_bypasses_cache(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        mock_groq_models_response,
        mock_groq_client_tool_calling_success,
        mock_cache_file,
    ):
        """Test that force_refresh bypasses cache and fetches fresh data."""
        # Mock API response
        mock_response = Mock()
        mock_response.json.return_value = mock_groq_models_response
        mock_response.raise_for_status = Mock()
        mock_get.return_value = mock_response
        # Mock tool calling
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        discovery.CACHE_FILE = mock_cache_file
        models = discovery.get_models(force_refresh=True)
        # Should call API despite cache
        mock_get.assert_called()
        assert len(models) > 0

    def test_provider_name_extraction(self, mock_api_key):
        """Test provider name extraction from model IDs."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        # Models with slash notation
        assert discovery._get_provider_name("meta-llama/llama-3.1-8b") == "Meta"
        assert discovery._get_provider_name("openai/gpt-oss-safeguard-20b") == "OpenAI"
        assert discovery._get_provider_name("qwen/qwen3-32b") == "Alibaba Cloud"
        assert discovery._get_provider_name("moonshotai/moonshot-v1") == "Moonshot AI"
        assert discovery._get_provider_name("groq/groq-model") == "Groq"
        # Models with prefixes
        assert discovery._get_provider_name("llama-3.1-8b-instant") == "Meta"
        assert discovery._get_provider_name("llama3-70b-8192") == "Meta"
        assert discovery._get_provider_name("qwen-2.5-32b") == "Alibaba Cloud"
        assert discovery._get_provider_name("allam-1-13b") == "SDAIA"
        # Unknown providers default to Groq
        assert discovery._get_provider_name("unknown-model") == "Groq"

    def test_skip_patterns(self, mock_api_key):
        """Test that SKIP_PATTERNS correctly identify non-LLM models."""
        discovery = GroqModelDiscovery(api_key=mock_api_key)
        # Speech, TTS, and guard/safeguard models must all match a skip pattern.
        skip_models = [
            "whisper-large-v3",
            "whisper-large-v3-turbo",
            "distil-whisper-large-v3-en",
            "playai-tts",
            "playai-tts-arabic",
            "meta-llama/llama-guard-4-12b",
            "meta-llama/llama-prompt-guard-2-86m",
            "openai/gpt-oss-safeguard-20b",
            "mistral-saba-24b",  # safeguard model
        ]
        for model in skip_models:
            should_skip = any(pattern in model.lower() for pattern in discovery.SKIP_PATTERNS)
            assert should_skip, f"Model {model} should be skipped but wasn't"
        # LLM models should not be skipped
        llm_models = ["llama-3.1-8b-instant", "mixtral-8x7b-32768", "gemma-7b-it"]
        for model in llm_models:
            should_skip = any(pattern in model.lower() for pattern in discovery.SKIP_PATTERNS)
            assert not should_skip, f"Model {model} should not be skipped"

View File

@@ -1,541 +0,0 @@
"""Comprehensive tests for Groq model discovery system.
Tests cover:
- Success paths: API fetching, caching, tool calling detection
- Error paths: API failures, network errors, invalid responses
- Edge cases: expired cache, corrupted cache, missing API key
"""
import json
from unittest.mock import MagicMock, Mock, patch
from lfx.base.models.groq_model_discovery import GroqModelDiscovery, get_groq_models
class TestGroqModelDiscoverySuccess:
"""Test successful model discovery operations."""
def test_init_with_api_key(self, mock_api_key):
"""Test initialization with API key."""
discovery = GroqModelDiscovery(api_key=mock_api_key)
assert discovery.api_key == mock_api_key
assert discovery.base_url == "https://api.groq.com"
def test_init_without_api_key(self):
"""Test initialization without API key."""
discovery = GroqModelDiscovery()
assert discovery.api_key is None
assert discovery.base_url == "https://api.groq.com"
def test_init_with_custom_base_url(self, mock_api_key):
"""Test initialization with custom base URL."""
custom_url = "https://custom.groq.com"
discovery = GroqModelDiscovery(api_key=mock_api_key, base_url=custom_url)
assert discovery.base_url == custom_url
@patch("lfx.base.models.groq_model_discovery.requests.get")
@patch("groq.Groq")
def test_fetch_available_models_success(
self, mock_groq, mock_get, mock_api_key, mock_groq_models_response, mock_groq_client_tool_calling_success
):
"""Test successfully fetching models from API."""
# Mock API response
mock_response = Mock()
mock_response.status_code = 200
mock_response.json.return_value = mock_groq_models_response
mock_response.raise_for_status = Mock()
mock_get.return_value = mock_response
# Mock tool calling tests
mock_groq.return_value = mock_groq_client_tool_calling_success()
discovery = GroqModelDiscovery(api_key=mock_api_key)
models = discovery._fetch_available_models()
assert isinstance(models, list)
assert len(models) == 8
assert "llama-3.1-8b-instant" in models
assert "whisper-large-v3" in models
mock_get.assert_called_once()
@patch("lfx.base.models.groq_model_discovery.requests.get")
@patch("groq.Groq")
def test_get_models_categorizes_llm_and_non_llm(
self,
mock_groq,
mock_get,
mock_api_key,
mock_groq_models_response,
mock_groq_client_tool_calling_success,
temp_cache_dir,
):
"""Test that models are correctly categorized as LLM vs non-LLM."""
# Mock API response
mock_response = Mock()
mock_response.json.return_value = mock_groq_models_response
mock_response.raise_for_status = Mock()
mock_get.return_value = mock_response
# Mock tool calling tests to always succeed
mock_groq.return_value = mock_groq_client_tool_calling_success()
discovery = GroqModelDiscovery(api_key=mock_api_key)
discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
models = discovery.get_models(force_refresh=True)
# LLM models should be in the result
assert "llama-3.1-8b-instant" in models
assert "llama-3.3-70b-versatile" in models
assert "mixtral-8x7b-32768" in models
assert "gemma-7b-it" in models
# Non-LLM models should be marked as not_supported
assert models["whisper-large-v3"]["not_supported"] is True
assert models["distil-whisper-large-v3-en"]["not_supported"] is True
assert models["meta-llama/llama-guard-4-12b"]["not_supported"] is True
assert models["meta-llama/llama-prompt-guard-2-86m"]["not_supported"] is True
# LLM models should have tool_calling field
assert "tool_calling" in models["llama-3.1-8b-instant"]
assert "tool_calling" in models["mixtral-8x7b-32768"]
@patch("groq.Groq")
def test_tool_calling_detection_success(self, mock_groq, mock_api_key, mock_groq_client_tool_calling_success):
"""Test successful tool calling detection."""
mock_groq.return_value = mock_groq_client_tool_calling_success()
discovery = GroqModelDiscovery(api_key=mock_api_key)
result = discovery._test_tool_calling("llama-3.1-8b-instant")
assert result is True
@patch("groq.Groq")
def test_tool_calling_detection_not_supported(self, mock_groq, mock_api_key, mock_groq_client_tool_calling_failure):
"""Test tool calling detection when model doesn't support tools."""
mock_groq.return_value = mock_groq_client_tool_calling_failure()
discovery = GroqModelDiscovery(api_key=mock_api_key)
result = discovery._test_tool_calling("gemma-7b-it")
assert result is False
def test_cache_save_and_load(self, mock_api_key, sample_models_metadata, temp_cache_dir):
"""Test saving and loading cache."""
discovery = GroqModelDiscovery(api_key=mock_api_key)
discovery.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
# Save cache
discovery._save_cache(sample_models_metadata)
# Verify file was created
assert discovery.CACHE_FILE.exists()
# Load cache
loaded = discovery._load_cache()
assert loaded is not None
assert len(loaded) == len(sample_models_metadata)
assert "llama-3.1-8b-instant" in loaded
assert loaded["llama-3.1-8b-instant"]["tool_calling"] is True
def test_cache_respects_expiration(self, mock_api_key, mock_expired_cache_file):
"""Test that expired cache returns None."""
discovery = GroqModelDiscovery(api_key=mock_api_key)
discovery.CACHE_FILE = mock_expired_cache_file
loaded = discovery._load_cache()
assert loaded is None
@patch("lfx.base.models.groq_model_discovery.requests.get")
@patch("groq.Groq")
def test_get_models_uses_cache_when_available(self, mock_groq, mock_get, mock_api_key, mock_cache_file):
"""Test that get_models uses cache when available and not expired."""
discovery = GroqModelDiscovery(api_key=mock_api_key)
discovery.CACHE_FILE = mock_cache_file
models = discovery.get_models(force_refresh=False)
# Should use cache, not call API
mock_get.assert_not_called()
mock_groq.assert_not_called()
assert "llama-3.1-8b-instant" in models
assert "llama-3.3-70b-versatile" in models
@patch("lfx.base.models.groq_model_discovery.requests.get")
@patch("groq.Groq")
def test_force_refresh_bypasses_cache(
self,
mock_groq,
mock_get,
mock_api_key,
mock_groq_models_response,
mock_groq_client_tool_calling_success,
mock_cache_file,
):
"""Test that force_refresh bypasses cache and fetches fresh data."""
# Mock API response
mock_response = Mock()
mock_response.json.return_value = mock_groq_models_response
mock_response.raise_for_status = Mock()
mock_get.return_value = mock_response
# Mock tool calling
mock_groq.return_value = mock_groq_client_tool_calling_success()
discovery = GroqModelDiscovery(api_key=mock_api_key)
discovery.CACHE_FILE = mock_cache_file
models = discovery.get_models(force_refresh=True)
# Should call API despite cache
mock_get.assert_called()
assert len(models) > 0
def test_provider_name_extraction(self, mock_api_key):
"""Test provider name extraction from model IDs."""
discovery = GroqModelDiscovery(api_key=mock_api_key)
# Models with slash notation
assert discovery._get_provider_name("meta-llama/llama-3.1-8b") == "Meta"
assert discovery._get_provider_name("openai/gpt-oss-safeguard-20b") == "OpenAI"
assert discovery._get_provider_name("qwen/qwen3-32b") == "Alibaba Cloud"
assert discovery._get_provider_name("moonshotai/moonshot-v1") == "Moonshot AI"
assert discovery._get_provider_name("groq/groq-model") == "Groq"
# Models with prefixes
assert discovery._get_provider_name("llama-3.1-8b-instant") == "Meta"
assert discovery._get_provider_name("llama3-70b-8192") == "Meta"
assert discovery._get_provider_name("qwen-2.5-32b") == "Alibaba Cloud"
assert discovery._get_provider_name("allam-1-13b") == "SDAIA"
# Unknown providers default to Groq
assert discovery._get_provider_name("unknown-model") == "Groq"
def test_skip_patterns(self, mock_api_key):
"""Test that SKIP_PATTERNS correctly identify non-LLM models."""
discovery = GroqModelDiscovery(api_key=mock_api_key)
skip_models = [
"whisper-large-v3",
"whisper-large-v3-turbo",
"distil-whisper-large-v3-en",
"playai-tts",
"playai-tts-arabic",
"meta-llama/llama-guard-4-12b",
"meta-llama/llama-prompt-guard-2-86m",
"openai/gpt-oss-safeguard-20b",
"mistral-saba-24b", # safeguard model
]
for model in skip_models:
should_skip = any(pattern in model.lower() for pattern in discovery.SKIP_PATTERNS)
assert should_skip, f"Model {model} should be skipped but wasn't"
# LLM models should not be skipped
llm_models = ["llama-3.1-8b-instant", "mixtral-8x7b-32768", "gemma-7b-it"]
for model in llm_models:
should_skip = any(pattern in model.lower() for pattern in discovery.SKIP_PATTERNS)
assert not should_skip, f"Model {model} should not be skipped"
class TestGroqModelDiscoveryErrors:
    """Exercise the failure paths of model discovery."""

    def test_no_api_key_returns_fallback(self):
        """Without an API key, discovery degrades to the hardcoded fallback pair."""
        disc = GroqModelDiscovery(api_key=None)
        result = disc.get_models(force_refresh=True)
        assert len(result) == 2
        assert "llama-3.1-8b-instant" in result
        assert "llama-3.3-70b-versatile" in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_api_connection_error_returns_fallback(self, mock_get, mock_api_key, mock_requests_get_failure):
        """A connection failure while listing models yields the fallback set."""
        mock_get.side_effect = mock_requests_get_failure
        result = GroqModelDiscovery(api_key=mock_api_key).get_models(force_refresh=True)
        for model_id in ("llama-3.1-8b-instant", "llama-3.3-70b-versatile"):
            assert model_id in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_api_timeout_returns_fallback(self, mock_get, mock_api_key, mock_requests_get_timeout):
        """A request timeout while listing models yields the fallback set."""
        mock_get.side_effect = mock_requests_get_timeout
        result = GroqModelDiscovery(api_key=mock_api_key).get_models(force_refresh=True)
        for model_id in ("llama-3.1-8b-instant", "llama-3.3-70b-versatile"):
            assert model_id in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_api_unauthorized_returns_fallback(self, mock_get, mock_api_key, mock_requests_get_unauthorized):
        """A 401 from the API yields the fallback set rather than raising."""
        mock_get.side_effect = mock_requests_get_unauthorized
        result = GroqModelDiscovery(api_key=mock_api_key).get_models(force_refresh=True)
        for model_id in ("llama-3.1-8b-instant", "llama-3.3-70b-versatile"):
            assert model_id in result

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_invalid_api_response_returns_fallback(self, mock_get, mock_api_key):
        """A payload missing the 'data' key is rejected and the fallback set is used."""
        bad_response = Mock()
        bad_response.json.return_value = {"error": "invalid"}
        bad_response.raise_for_status = Mock()
        mock_get.return_value = bad_response
        result = GroqModelDiscovery(api_key=mock_api_key).get_models(force_refresh=True)
        assert "llama-3.1-8b-instant" in result

    def test_corrupted_cache_returns_none(self, mock_api_key, mock_corrupted_cache_file):
        """Unparseable cache contents are reported as a cache miss."""
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = mock_corrupted_cache_file
        assert disc._load_cache() is None

    def test_cache_missing_fields_returns_none(self, mock_api_key, temp_cache_dir):
        """A cache payload lacking the 'cached_at' field is treated as a miss."""
        cache_path = temp_cache_dir / ".cache" / "invalid_cache.json"
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with cache_path.open("w") as fh:
            json.dump({"models": {"llama-3.1-8b-instant": {}}}, fh)
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = cache_path
        assert disc._load_cache() is None

    def test_cache_save_failure_logs_warning(self, mock_api_key, temp_cache_dir, sample_models_metadata):
        """A failed cache write is swallowed (logged) instead of raising."""
        disc = GroqModelDiscovery(api_key=mock_api_key)
        # Pointing CACHE_FILE at an existing directory makes the write fail.
        disc.CACHE_FILE = temp_cache_dir
        disc._save_cache(sample_models_metadata)  # must not raise

    @patch("groq.Groq")
    def test_tool_calling_import_error_returns_false(self, mock_groq, mock_api_key):
        """An ImportError while building the client maps to 'no tool calling'."""
        # NOTE(review): the updated _test_tool_calling appears to re-raise
        # ImportError rather than return False — confirm this expectation
        # still matches the implementation.
        mock_groq.side_effect = ImportError("groq module not found")
        disc = GroqModelDiscovery(api_key=mock_api_key)
        assert disc._test_tool_calling("test-model") is False

    @patch("groq.Groq")
    def test_tool_calling_rate_limit_returns_false(self, mock_groq, mock_api_key, mock_groq_client_rate_limit):
        """Rate-limit errors are answered conservatively as not tool-capable."""
        # NOTE(review): the updated _test_tool_calling appears to return None
        # (indeterminate) for rate limits — confirm `is False` is still the
        # intended expectation here.
        mock_groq.return_value = mock_groq_client_rate_limit()
        disc = GroqModelDiscovery(api_key=mock_api_key)
        assert disc._test_tool_calling("test-model") is False
class TestGroqModelDiscoveryEdgeCases:
    """Cover the less common but valid discovery scenarios."""

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    def test_empty_model_list_from_api(self, mock_get, mock_api_key, temp_cache_dir):
        """An empty 'data' array from the API must still produce a dict."""
        empty_response = Mock()
        empty_response.json.return_value = {"data": []}
        empty_response.raise_for_status = Mock()
        mock_get.return_value = empty_response
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        result = disc.get_models(force_refresh=True)
        assert isinstance(result, dict)

    def test_cache_file_not_exists(self, mock_api_key, temp_cache_dir):
        """Loading from a path that was never written is a cache miss."""
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = temp_cache_dir / ".cache" / "nonexistent.json"
        assert disc._load_cache() is None

    def test_cache_directory_created_on_save(self, mock_api_key, temp_cache_dir, sample_models_metadata):
        """_save_cache must create any missing parent directories."""
        target = temp_cache_dir / "new_dir" / ".cache" / "test_cache.json"
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = target
        # Precondition: the directory does not exist yet.
        assert not target.parent.exists()
        disc._save_cache(sample_models_metadata)
        assert target.parent.exists()
        assert target.exists()

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_preview_model_detection(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        mock_groq_client_tool_calling_success,
        temp_cache_dir,
    ):
        """Models named '*preview*' or namespaced with '/' are flagged as preview."""
        listing = Mock()
        listing.json.return_value = {
            "data": [
                {"id": "llama-3.2-1b-preview", "object": "model"},
                {"id": "meta-llama/llama-3.2-90b-preview", "object": "model"},
            ]
        }
        listing.raise_for_status = Mock()
        mock_get.return_value = listing
        mock_groq.return_value = mock_groq_client_tool_calling_success()
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        result = disc.get_models(force_refresh=True)
        assert result["llama-3.2-1b-preview"]["preview"] is True
        assert result["meta-llama/llama-3.2-90b-preview"]["preview"] is True

    @patch("lfx.base.models.groq_model_discovery.requests.get")
    @patch("groq.Groq")
    def test_mixed_tool_calling_support(
        self,
        mock_groq,
        mock_get,
        mock_api_key,
        temp_cache_dir,
    ):
        """One model supporting tools and one not must be reported individually."""
        listing = Mock()
        listing.json.return_value = {
            "data": [
                {"id": "llama-3.1-8b-instant", "object": "model"},
                {"id": "gemma-7b-it", "object": "model"},
            ]
        }
        listing.raise_for_status = Mock()
        mock_get.return_value = listing
        # First Groq() client succeeds at tool calling, second one fails.
        # NOTE(review): this counter assumes one Groq() construction per model;
        # if discovery now also probes chat completion per model (two clients
        # each), confirm the expectations below still hold.
        calls_seen = [0]

        def create_mock_client(*_args, **_kwargs):
            client = MagicMock()
            if calls_seen[0] == 0:
                client.chat.completions.create.return_value = MagicMock()
            else:
                client.chat.completions.create.side_effect = ValueError("tool calling not supported")
            calls_seen[0] += 1
            return client

        mock_groq.side_effect = create_mock_client
        disc = GroqModelDiscovery(api_key=mock_api_key)
        disc.CACHE_FILE = temp_cache_dir / ".cache" / "test_cache.json"
        result = disc.get_models(force_refresh=True)
        assert result["llama-3.1-8b-instant"]["tool_calling"] is True
        assert result["gemma-7b-it"]["tool_calling"] is False

    def test_fallback_models_structure(self, mock_api_key):
        """Every fallback entry carries the full metadata schema."""
        fallback = GroqModelDiscovery(api_key=mock_api_key)._get_fallback_models()
        assert isinstance(fallback, dict)
        assert len(fallback) == 2
        for metadata in fallback.values():
            for field in ("name", "provider", "tool_calling", "preview"):
                assert field in metadata
            # Fallback models are expected to support tools.
            assert metadata["tool_calling"] is True
class TestGetGroqModelsConvenienceFunction:
    """Validate the get_groq_models() convenience wrapper."""

    @patch.object(GroqModelDiscovery, "get_models")
    def test_get_groq_models_with_api_key(self, mock_get_models, mock_api_key):
        """With a key, the wrapper delegates with force_refresh=False."""
        mock_get_models.return_value = {"llama-3.1-8b-instant": {}}
        result = get_groq_models(api_key=mock_api_key)
        assert "llama-3.1-8b-instant" in result
        mock_get_models.assert_called_once_with(force_refresh=False)

    @patch.object(GroqModelDiscovery, "get_models")
    def test_get_groq_models_without_api_key(self, mock_get_models):
        """Without a key, the wrapper still delegates with force_refresh=False."""
        mock_get_models.return_value = {"llama-3.1-8b-instant": {}}
        result = get_groq_models()
        assert "llama-3.1-8b-instant" in result
        mock_get_models.assert_called_once_with(force_refresh=False)

    @patch.object(GroqModelDiscovery, "get_models")
    def test_get_groq_models_force_refresh(self, mock_get_models, mock_api_key):
        """force_refresh=True is forwarded through to the discovery object."""
        mock_get_models.return_value = {"llama-3.1-8b-instant": {}}
        result = get_groq_models(api_key=mock_api_key, force_refresh=True)
        assert "llama-3.1-8b-instant" in result
        mock_get_models.assert_called_once_with(force_refresh=True)

File diff suppressed because one or more lines are too long

View File

@@ -911,7 +911,7 @@
},
"GroqModel": {
"versions": {
"0.3.0": "2aee7de6ee88"
"0.3.0": "8a55d3fa8173"
}
},
"HomeAssistantControl": {

View File

@@ -22,8 +22,11 @@ class GroqModelDiscovery:
CACHE_FILE = Path(__file__).parent / ".cache" / "groq_models_cache.json"
CACHE_DURATION = timedelta(hours=24) # Refresh cache every 24 hours
# Models to skip from LLM list (audio, TTS, guards)
SKIP_PATTERNS = ["whisper", "tts", "guard", "safeguard", "prompt-guard", "saba"]
# Models to skip from LLM list (audio, TTS, guards, speech)
SKIP_PATTERNS = ["whisper", "tts", "guard", "safeguard", "prompt-guard", "saba", "orpheus", "playai"]
# Phrases that indicate an access/entitlement error rather than a capability error
ACCESS_ERROR_PHRASES = ["terms acceptance", "terms_required", "model_terms_required", "not available"]
def __init__(self, api_key: str | None = None, base_url: str = "https://api.groq.com"):
"""Initialize discovery with optional API key for testing.
@@ -83,10 +86,23 @@ class GroqModelDiscovery:
else:
llm_models.append(model_id)
# Step 3: Test LLM models for tool calling
logger.info(f"Testing {len(llm_models)} LLM models for tool calling support...")
# Step 3: Test LLM models for chat completion and tool calling
logger.info(f"Testing {len(llm_models)} LLM models for capabilities...")
for model_id in llm_models:
supports_chat = self._test_chat_completion(model_id)
if supports_chat is False:
# Model doesn't support chat completions at all (e.g. speech models)
non_llm_models.append(model_id)
logger.debug(f"{model_id}: does not support chat completions, skipping")
continue
if supports_chat is None:
# Transient/access error - assume chat is supported (benefit of the doubt)
logger.info(f"{model_id}: chat test indeterminate, assuming chat supported")
supports_tools = self._test_tool_calling(model_id)
if supports_tools is None:
# Transient/access error on tool test - skip to avoid caching a false negative
logger.info(f"{model_id}: tool test indeterminate, skipping (will retry next refresh)")
continue
models_metadata[model_id] = {
"name": model_id,
"provider": self._get_provider_name(model_id),
@@ -108,12 +124,16 @@ class GroqModelDiscovery:
# Save to cache
self._save_cache(models_metadata)
except (requests.RequestException, KeyError, ValueError, ImportError) as e:
logger.exception(f"Error discovering models: {e}")
except (requests.RequestException, KeyError, ValueError, ImportError):
logger.exception("Error discovering models")
return self._get_fallback_models()
else:
return models_metadata
def _is_access_error(self, error_msg: str) -> bool:
"""Return True if the lowercased error message indicates an access/entitlement issue."""
return any(phrase in error_msg for phrase in self.ACCESS_ERROR_PHRASES)
def _fetch_available_models(self) -> list[str]:
"""Fetch list of available models from Groq API."""
url = f"{self.base_url}/openai/v1/models"
@@ -126,19 +146,66 @@ class GroqModelDiscovery:
# Use direct access to raise KeyError if 'data' is missing
return [model["id"] for model in model_list["data"]]
def _test_tool_calling(self, model_id: str) -> bool:
def _test_chat_completion(self, model_id: str) -> bool | None:
"""Test if a model supports basic chat completions.
This filters out non-chat models (e.g. TTS, speech, embedding models)
that appear in the API model list but cannot handle chat requests.
Args:
model_id: The model ID to test
Returns:
True if model supports chat completions, False if it does not,
None if the result is indeterminate (transient/access errors).
"""
try:
import groq
client = groq.Groq(api_key=self.api_key, base_url=self.base_url)
messages = [{"role": "user", "content": "test"}]
client.chat.completions.create(model=model_id, messages=messages, max_tokens=1)
except ImportError:
logger.warning("groq package not installed, cannot test chat completion")
# Propagate the ImportError so callers can fall back to hardcoded model metadata
raise
except Exception as e: # noqa: BLE001
# The groq SDK does not expose a stable public exception hierarchy: errors can arrive as
# groq.APIStatusError, groq.BadRequestError, plain ValueError, or even undocumented
# runtime exceptions depending on the SDK version and the model being probed. We
# therefore catch Exception broadly and discriminate solely on the error message text,
# which is the only reliable signal available across SDK versions.
error_msg = str(e).lower()
# Genuine capability error: model does not support chat completions
if "does not support chat completions" in error_msg:
logger.debug(f"{model_id}: does not support chat completions")
return False
# Access/entitlement errors: model likely supports chat but is not accessible for this key
if self._is_access_error(error_msg):
logger.info(f"{model_id}: chat completion not accessible for this API key ({e})")
# Do not mark the model as non-chat; assume chat is supported but not usable with this key
return None
# Other errors (rate limits, transient failures) - indeterminate
logger.warning(f"Error testing chat for {model_id}: {e}")
return None
else:
return True
def _test_tool_calling(self, model_id: str) -> bool | None:
"""Test if a model supports tool calling.
Args:
model_id: The model ID to test
Returns:
True if model supports tool calling, False otherwise
True if model supports tool calling, False if it does not,
None if the result is indeterminate (transient/access errors).
"""
try:
import groq
client = groq.Groq(api_key=self.api_key)
client = groq.Groq(api_key=self.api_key, base_url=self.base_url)
# Simple tool definition
tools = [
@@ -163,14 +230,24 @@ class GroqModelDiscovery:
model=model_id, messages=messages, tools=tools, tool_choice="auto", max_tokens=10
)
except (ImportError, AttributeError, TypeError, ValueError, RuntimeError, KeyError) as e:
except ImportError:
logger.warning("groq package not installed, cannot test tool calling")
raise
except Exception as e: # noqa: BLE001
# Same rationale as _test_chat_completion: the groq SDK's exception types are not
# stable across versions, so broad catching with message-based discrimination is the
# only portable approach. See _test_chat_completion for a full explanation.
error_msg = str(e).lower()
# If error mentions tool calling, model doesn't support it
# Genuine capability error: model does not support tools
if "tool" in error_msg:
return False
# Other errors might be rate limits, etc - be conservative
logger.warning(f"Error testing {model_id}: {e}")
return False
# Access/entitlement errors: model may support tools but is not accessible for this key
if self._is_access_error(error_msg):
logger.info(f"{model_id}: tool calling not testable for this API key ({e})")
return None
# Any other API error (rate limits, transient failures, etc) - indeterminate
logger.warning(f"Error testing tool calling for {model_id}: {e}")
return None
else:
return True

View File

@@ -101,8 +101,8 @@ class GroqModel(LCModelComponent):
logger.info(f"Loaded {len(model_ids)} Groq models with tool calling support")
else:
logger.info(f"Loaded {len(model_ids)} Groq models")
except (ValueError, KeyError, TypeError, ImportError) as e:
logger.exception(f"Error getting model names: {e}")
except (ValueError, KeyError, TypeError, ImportError):
logger.exception("Error getting model names")
# Fallback to hardcoded list from groq_constants.py
return GROQ_MODELS
else:
@@ -114,9 +114,10 @@ class GroqModel(LCModelComponent):
if len(self.api_key) != 0:
try:
ids = self.get_models(tool_model_enabled=self.tool_model_enabled)
except (ValueError, KeyError, TypeError, ImportError) as e:
logger.exception(f"Error getting model names: {e}")
except (ValueError, KeyError, TypeError, ImportError):
logger.exception("Error getting model names")
ids = GROQ_MODELS
ids = ids or GROQ_MODELS
build_config.setdefault("model_name", {})
build_config["model_name"]["options"] = ids
build_config["model_name"].setdefault("value", ids[0])