From 110df6b3526f71aa151d4ef069c05166badfb84a Mon Sep 17 00:00:00 2001 From: Mateusz Szewczyk <139469471+MateuszOssGit@users.noreply.github.com> Date: Mon, 9 Mar 2026 17:15:46 +0100 Subject: [PATCH] chore: Added support for `space_id` scope in `WatsonxAIComponent` (#11732) * chore: Added support for `space_id` scope in `WatsonxAIComponent` * [autofix.ci] apply automated fixes * chore: fix tests * chore: Added support for `space_id` scope in `WatsonxAIComponent` v2 * [autofix.ci] apply automated fixes * CR fix * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes (attempt 2/3) * [autofix.ci] apply automated fixes (attempt 3/3) * [autofix.ci] apply automated fixes * fix unit tests * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: himavarshagoutham --- src/lfx/src/lfx/_assets/component_index.json | 37 +- .../src/lfx/_assets/stable_hash_history.json | 2 +- src/lfx/src/lfx/components/ibm/watsonx.py | 28 +- src/lfx/tests/unit/components/ibm/__init__.py | 0 .../tests/unit/components/ibm/test_watsonx.py | 468 ++++++++++++++++++ 5 files changed, 519 insertions(+), 16 deletions(-) create mode 100644 src/lfx/tests/unit/components/ibm/__init__.py create mode 100644 src/lfx/tests/unit/components/ibm/test_watsonx.py diff --git a/src/lfx/src/lfx/_assets/component_index.json b/src/lfx/src/lfx/_assets/component_index.json index c7df0ecfc..c9357cd95 100644 --- a/src/lfx/src/lfx/_assets/component_index.json +++ b/src/lfx/src/lfx/_assets/component_index.json @@ -76390,6 +76390,7 @@ "stream", "base_url", "project_id", + "space_id", "api_key", "model_name", "max_tokens", @@ -76407,7 +76408,7 @@ "icon": "WatsonxAI", "legacy": false, "metadata": { - "code_hash": "689a87ecc73a", + "code_hash": "e7d46a4de547", "dependencies": { "dependencies": [ { @@ -76503,7 +76504,8 @@ "https://eu-gb.ml.cloud.ibm.com", "https://au-syd.ml.cloud.ibm.com", "https://jp-tok.ml.cloud.ibm.com", - 
"https://ca-tor.ml.cloud.ibm.com" + "https://ca-tor.ml.cloud.ibm.com", + "https://ap-south-1.aws.wxai.ibm.com" ], "options_metadata": [], "override_skip": false, @@ -76535,7 +76537,7 @@ "show": true, "title_case": false, "type": "code", - "value": "import json\nfrom typing import Any\n\nfrom langchain_ibm import ChatWatsonx\nfrom pydantic.v1 import SecretStr\n\nfrom lfx.base.models.model import LCModelComponent\nfrom lfx.base.models.model_utils import get_watsonx_llm_models\nfrom lfx.field_typing import LanguageModel\nfrom lfx.field_typing.range_spec import RangeSpec\nfrom lfx.inputs.inputs import BoolInput, DropdownInput, IntInput, SecretStrInput, SliderInput, StrInput\nfrom lfx.log.logger import logger\nfrom lfx.schema.dotdict import dotdict\n\n\nclass WatsonxAIComponent(LCModelComponent):\n display_name = \"IBM watsonx.ai\"\n description = \"Generate text using IBM watsonx.ai foundation models.\"\n icon = \"WatsonxAI\"\n name = \"IBMwatsonxModel\"\n beta = False\n\n _default_models = [\"ibm/granite-3-2b-instruct\", \"ibm/granite-3-8b-instruct\", \"ibm/granite-13b-instruct-v2\"]\n _urls = [\n \"https://us-south.ml.cloud.ibm.com\",\n \"https://eu-de.ml.cloud.ibm.com\",\n \"https://eu-gb.ml.cloud.ibm.com\",\n \"https://au-syd.ml.cloud.ibm.com\",\n \"https://jp-tok.ml.cloud.ibm.com\",\n \"https://ca-tor.ml.cloud.ibm.com\",\n ]\n inputs = [\n *LCModelComponent.get_base_inputs(),\n DropdownInput(\n name=\"base_url\",\n display_name=\"watsonx API Endpoint\",\n info=\"The base URL of the API.\",\n value=[],\n options=_urls,\n real_time_refresh=True,\n required=True,\n ),\n StrInput(\n name=\"project_id\",\n display_name=\"watsonx Project ID\",\n required=True,\n info=\"The project ID or deployment space ID that is associated with the foundation model.\",\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Watsonx API Key\",\n info=\"The API Key to use for the model.\",\n required=True,\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n 
options=[],\n value=None,\n real_time_refresh=True,\n required=True,\n refresh_button=True,\n ),\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n info=\"The maximum number of tokens to generate.\",\n range_spec=RangeSpec(min=1, max=4096),\n value=1000,\n ),\n StrInput(\n name=\"stop_sequence\",\n display_name=\"Stop Sequence\",\n advanced=True,\n info=\"Sequence where generation should stop.\",\n field_type=\"str\",\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n info=\"Controls randomness, higher values increase diversity.\",\n value=0.1,\n range_spec=RangeSpec(min=0, max=2, step=0.01),\n advanced=True,\n ),\n SliderInput(\n name=\"top_p\",\n display_name=\"Top P\",\n info=\"The cumulative probability cutoff for token selection. \"\n \"Lower values mean sampling from a smaller, more top-weighted nucleus.\",\n value=0.9,\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n advanced=True,\n ),\n SliderInput(\n name=\"frequency_penalty\",\n display_name=\"Frequency Penalty\",\n info=\"Penalty for frequency of token usage.\",\n value=0.5,\n range_spec=RangeSpec(min=-2.0, max=2.0, step=0.01),\n advanced=True,\n ),\n SliderInput(\n name=\"presence_penalty\",\n display_name=\"Presence Penalty\",\n info=\"Penalty for token presence in prior text.\",\n value=0.3,\n range_spec=RangeSpec(min=-2.0, max=2.0, step=0.01),\n advanced=True,\n ),\n IntInput(\n name=\"seed\",\n display_name=\"Random Seed\",\n advanced=True,\n info=\"The random seed for the model.\",\n value=8,\n ),\n BoolInput(\n name=\"logprobs\",\n display_name=\"Log Probabilities\",\n advanced=True,\n info=\"Whether to return log probabilities of the output tokens.\",\n value=True,\n ),\n IntInput(\n name=\"top_logprobs\",\n display_name=\"Top Log Probabilities\",\n advanced=True,\n info=\"Number of most likely tokens to return at each position.\",\n value=3,\n range_spec=RangeSpec(min=1, max=20),\n ),\n StrInput(\n name=\"logit_bias\",\n 
display_name=\"Logit Bias\",\n advanced=True,\n info='JSON string of token IDs to bias or suppress (e.g., {\"1003\": -100, \"1004\": 100}).',\n field_type=\"str\",\n ),\n ]\n\n @staticmethod\n def fetch_models(base_url: str) -> list[str]:\n \"\"\"Fetch available models from the watsonx.ai API.\n\n Uses centralized model fetching from model_utils.\n \"\"\"\n return get_watsonx_llm_models(base_url, default_models=WatsonxAIComponent._default_models)\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n \"\"\"Update model options when URL or API key changes.\"\"\"\n if field_name == \"base_url\" and field_value:\n try:\n models = self.fetch_models(base_url=field_value)\n build_config[\"model_name\"][\"options\"] = models\n if build_config[\"model_name\"][\"value\"]:\n build_config[\"model_name\"][\"value\"] = models[0]\n info_message = f\"Updated model options: {len(models)} models found in {field_value}\"\n logger.info(info_message)\n except Exception: # noqa: BLE001\n logger.exception(\"Error updating model options.\")\n if field_name == \"model_name\" and field_value and field_value in WatsonxAIComponent._urls:\n build_config[\"model_name\"][\"options\"] = self.fetch_models(base_url=field_value)\n build_config[\"model_name\"][\"value\"] = \"\"\n return build_config\n\n def build_model(self) -> LanguageModel:\n # Parse logit_bias from JSON string if provided\n logit_bias = None\n if hasattr(self, \"logit_bias\") and self.logit_bias:\n try:\n logit_bias = json.loads(self.logit_bias)\n except json.JSONDecodeError:\n logger.warning(\"Invalid logit_bias JSON format. 
Using default instead.\")\n logit_bias = {\"1003\": -100, \"1004\": -100}\n\n chat_params = {\n \"max_tokens\": getattr(self, \"max_tokens\", None),\n \"temperature\": getattr(self, \"temperature\", None),\n \"top_p\": getattr(self, \"top_p\", None),\n \"frequency_penalty\": getattr(self, \"frequency_penalty\", None),\n \"presence_penalty\": getattr(self, \"presence_penalty\", None),\n \"seed\": getattr(self, \"seed\", None),\n \"stop\": [self.stop_sequence] if self.stop_sequence else [],\n \"n\": 1,\n \"logprobs\": getattr(self, \"logprobs\", True),\n \"top_logprobs\": getattr(self, \"top_logprobs\", None),\n \"time_limit\": 600000,\n \"logit_bias\": logit_bias,\n }\n\n # Pass API key as plain string to avoid SecretStr serialization issues\n # when model is configured with with_config() or used in batch operations\n api_key_value = self.api_key\n if isinstance(api_key_value, SecretStr):\n api_key_value = api_key_value.get_secret_value()\n\n return ChatWatsonx(\n apikey=api_key_value,\n url=self.base_url,\n project_id=self.project_id,\n model_id=self.model_name,\n params=chat_params,\n streaming=self.stream,\n )\n" + "value": "import json\nfrom typing import Any\n\nfrom langchain_ibm import ChatWatsonx\nfrom pydantic.v1 import SecretStr\n\nfrom lfx.base.models.model import LCModelComponent\nfrom lfx.base.models.model_utils import get_watsonx_llm_models\nfrom lfx.field_typing import LanguageModel\nfrom lfx.field_typing.range_spec import RangeSpec\nfrom lfx.inputs.inputs import BoolInput, DropdownInput, IntInput, SecretStrInput, SliderInput, StrInput\nfrom lfx.log.logger import logger\nfrom lfx.schema.dotdict import dotdict\n\n\nclass WatsonxAIComponent(LCModelComponent):\n \"\"\"LFX component for IBM watsonx.ai text/chat generation.\"\"\"\n\n display_name = \"IBM watsonx.ai\"\n description = \"Generate text using IBM watsonx.ai foundation models.\"\n icon = \"WatsonxAI\"\n name = \"IBMwatsonxModel\"\n beta = False\n\n _default_models = 
[\"ibm/granite-3-2b-instruct\", \"ibm/granite-3-8b-instruct\", \"ibm/granite-13b-instruct-v2\"]\n _urls = [\n \"https://us-south.ml.cloud.ibm.com\",\n \"https://eu-de.ml.cloud.ibm.com\",\n \"https://eu-gb.ml.cloud.ibm.com\",\n \"https://au-syd.ml.cloud.ibm.com\",\n \"https://jp-tok.ml.cloud.ibm.com\",\n \"https://ca-tor.ml.cloud.ibm.com\",\n \"https://ap-south-1.aws.wxai.ibm.com\",\n ]\n inputs = [\n *LCModelComponent.get_base_inputs(),\n DropdownInput(\n name=\"base_url\",\n display_name=\"watsonx API Endpoint\",\n info=\"The base URL of the API.\",\n value=[],\n options=_urls,\n real_time_refresh=True,\n required=True,\n ),\n StrInput(\n name=\"project_id\",\n display_name=\"watsonx Project_ID\",\n required=False,\n info=\"The project ID associated with the foundation model.\",\n ),\n StrInput(\n name=\"space_id\",\n display_name=\"watsonx Space_ID\",\n required=False,\n info=\"The deployment space ID associated with the foundation model.\",\n ),\n SecretStrInput(\n name=\"api_key\",\n display_name=\"Watsonx API Key\",\n info=\"The API Key to use for the model.\",\n required=True,\n ),\n DropdownInput(\n name=\"model_name\",\n display_name=\"Model Name\",\n options=[],\n value=None,\n real_time_refresh=True,\n required=True,\n refresh_button=True,\n ),\n IntInput(\n name=\"max_tokens\",\n display_name=\"Max Tokens\",\n advanced=True,\n info=\"The maximum number of tokens to generate.\",\n range_spec=RangeSpec(min=1, max=4096),\n value=1000,\n ),\n StrInput(\n name=\"stop_sequence\",\n display_name=\"Stop Sequence\",\n advanced=True,\n info=\"Sequence where generation should stop.\",\n field_type=\"str\",\n ),\n SliderInput(\n name=\"temperature\",\n display_name=\"Temperature\",\n info=\"Controls randomness, higher values increase diversity.\",\n value=0.1,\n range_spec=RangeSpec(min=0, max=2, step=0.01),\n advanced=True,\n ),\n SliderInput(\n name=\"top_p\",\n display_name=\"Top P\",\n info=\"The cumulative probability cutoff for token selection. 
\"\n \"Lower values mean sampling from a smaller, more top-weighted nucleus.\",\n value=0.9,\n range_spec=RangeSpec(min=0, max=1, step=0.01),\n advanced=True,\n ),\n SliderInput(\n name=\"frequency_penalty\",\n display_name=\"Frequency Penalty\",\n info=\"Penalty for frequency of token usage.\",\n value=0.5,\n range_spec=RangeSpec(min=-2.0, max=2.0, step=0.01),\n advanced=True,\n ),\n SliderInput(\n name=\"presence_penalty\",\n display_name=\"Presence Penalty\",\n info=\"Penalty for token presence in prior text.\",\n value=0.3,\n range_spec=RangeSpec(min=-2.0, max=2.0, step=0.01),\n advanced=True,\n ),\n IntInput(\n name=\"seed\",\n display_name=\"Random Seed\",\n advanced=True,\n info=\"The random seed for the model.\",\n value=8,\n ),\n BoolInput(\n name=\"logprobs\",\n display_name=\"Log Probabilities\",\n advanced=True,\n info=\"Whether to return log probabilities of the output tokens.\",\n value=True,\n ),\n IntInput(\n name=\"top_logprobs\",\n display_name=\"Top Log Probabilities\",\n advanced=True,\n info=\"Number of most likely tokens to return at each position.\",\n value=3,\n range_spec=RangeSpec(min=1, max=20),\n ),\n StrInput(\n name=\"logit_bias\",\n display_name=\"Logit Bias\",\n advanced=True,\n info='JSON string of token IDs to bias or suppress (e.g., {\"1003\": -100, \"1004\": 100}).',\n field_type=\"str\",\n ),\n ]\n\n @staticmethod\n def fetch_models(base_url: str) -> list[str]:\n \"\"\"Fetch available models from the watsonx.ai API.\n\n Uses centralized model fetching from model_utils.\n \"\"\"\n return get_watsonx_llm_models(base_url, default_models=WatsonxAIComponent._default_models)\n\n def update_build_config(self, build_config: dotdict, field_value: Any, field_name: str | None = None):\n \"\"\"Update model options when URL or API key changes.\"\"\"\n if field_name == \"base_url\" and field_value:\n try:\n models = self.fetch_models(base_url=field_value)\n build_config[\"model_name\"][\"options\"] = models\n if 
build_config[\"model_name\"].get(\"value\") not in models:\n build_config[\"model_name\"][\"value\"] = models[0] if models else None\n info_message = f\"Updated model options: {len(models)} models found in {field_value}\"\n logger.info(info_message)\n except Exception: # noqa: BLE001\n logger.exception(\"Error updating model options.\")\n\n return build_config\n\n def build_model(self) -> LanguageModel:\n # Parse logit_bias from JSON string if provided\n logit_bias = None\n if hasattr(self, \"logit_bias\") and self.logit_bias:\n try:\n logit_bias = json.loads(self.logit_bias)\n except json.JSONDecodeError:\n logger.warning(\"Invalid logit_bias JSON format. Using default instead.\")\n logit_bias = {\"1003\": -100, \"1004\": -100}\n\n chat_params = {\n \"max_tokens\": getattr(self, \"max_tokens\", None),\n \"temperature\": getattr(self, \"temperature\", None),\n \"top_p\": getattr(self, \"top_p\", None),\n \"frequency_penalty\": getattr(self, \"frequency_penalty\", None),\n \"presence_penalty\": getattr(self, \"presence_penalty\", None),\n \"seed\": getattr(self, \"seed\", None),\n \"stop\": [self.stop_sequence] if self.stop_sequence else [],\n \"n\": 1,\n \"logprobs\": getattr(self, \"logprobs\", True),\n \"top_logprobs\": getattr(self, \"top_logprobs\", None),\n \"time_limit\": 600000,\n \"logit_bias\": logit_bias,\n }\n\n # Pass API key as plain string to avoid SecretStr serialization issues\n # when model is configured with with_config() or used in batch operations\n api_key_value = self.api_key\n if isinstance(api_key_value, SecretStr):\n api_key_value = api_key_value.get_secret_value()\n\n if bool(self.space_id) == bool(self.project_id):\n msg = \"Exactly one of Project_ID or Space_ID must be selected\"\n raise ValueError(msg)\n\n return ChatWatsonx(\n apikey=api_key_value,\n url=self.base_url,\n project_id=self.project_id,\n space_id=self.space_id,\n model_id=self.model_name,\n params=chat_params,\n streaming=self.stream,\n )\n" }, "frequency_penalty": { 
"_input_type": "SliderInput", @@ -76717,16 +76719,16 @@ "project_id": { "_input_type": "StrInput", "advanced": false, - "display_name": "watsonx Project ID", + "display_name": "watsonx Project_ID", "dynamic": false, - "info": "The project ID or deployment space ID that is associated with the foundation model.", + "info": "The project ID associated with the foundation model.", "list": false, "list_add_label": "Add More", "load_from_db": false, "name": "project_id", "override_skip": false, "placeholder": "", - "required": true, + "required": false, "show": true, "title_case": false, "tool_mode": false, @@ -76755,6 +76757,27 @@ "type": "int", "value": 8 }, + "space_id": { + "_input_type": "StrInput", + "advanced": false, + "display_name": "watsonx Space_ID", + "dynamic": false, + "info": "The deployment space ID associated with the foundation model.", + "list": false, + "list_add_label": "Add More", + "load_from_db": false, + "name": "space_id", + "override_skip": false, + "placeholder": "", + "required": false, + "show": true, + "title_case": false, + "tool_mode": false, + "trace_as_metadata": true, + "track_in_telemetry": false, + "type": "str", + "value": "" + }, "stop_sequence": { "_input_type": "StrInput", "advanced": true, @@ -118266,6 +118289,6 @@ "num_components": 359, "num_modules": 97 }, - "sha256": "0f9bfcab9c8747258f998b115b6975919f27938dff6112f891b2ed208a4fcefc", + "sha256": "d927edd2abfc5eb28e0b7338b0c0342eb5cb687a9a464f995c41bf40517e234d", "version": "0.3.0" } \ No newline at end of file diff --git a/src/lfx/src/lfx/_assets/stable_hash_history.json b/src/lfx/src/lfx/_assets/stable_hash_history.json index 511355135..c877021c9 100644 --- a/src/lfx/src/lfx/_assets/stable_hash_history.json +++ b/src/lfx/src/lfx/_assets/stable_hash_history.json @@ -936,7 +936,7 @@ }, "IBMwatsonxModel": { "versions": { - "0.3.0": "689a87ecc73a" + "0.3.0": "e7d46a4de547" } }, "WatsonxEmbeddingsComponent": { diff --git a/src/lfx/src/lfx/components/ibm/watsonx.py 
b/src/lfx/src/lfx/components/ibm/watsonx.py index 324cf7b59..1aa223454 100644 --- a/src/lfx/src/lfx/components/ibm/watsonx.py +++ b/src/lfx/src/lfx/components/ibm/watsonx.py @@ -14,6 +14,8 @@ from lfx.schema.dotdict import dotdict class WatsonxAIComponent(LCModelComponent): + """LFX component for IBM watsonx.ai text/chat generation.""" + display_name = "IBM watsonx.ai" description = "Generate text using IBM watsonx.ai foundation models." icon = "WatsonxAI" @@ -28,6 +30,7 @@ class WatsonxAIComponent(LCModelComponent): "https://au-syd.ml.cloud.ibm.com", "https://jp-tok.ml.cloud.ibm.com", "https://ca-tor.ml.cloud.ibm.com", + "https://ap-south-1.aws.wxai.ibm.com", ] inputs = [ *LCModelComponent.get_base_inputs(), @@ -42,9 +45,15 @@ class WatsonxAIComponent(LCModelComponent): ), StrInput( name="project_id", - display_name="watsonx Project ID", - required=True, - info="The project ID or deployment space ID that is associated with the foundation model.", + display_name="watsonx Project_ID", + required=False, + info="The project ID associated with the foundation model.", + ), + StrInput( + name="space_id", + display_name="watsonx Space_ID", + required=False, + info="The deployment space ID associated with the foundation model.", ), SecretStrInput( name="api_key", @@ -154,15 +163,13 @@ class WatsonxAIComponent(LCModelComponent): try: models = self.fetch_models(base_url=field_value) build_config["model_name"]["options"] = models - if build_config["model_name"]["value"]: - build_config["model_name"]["value"] = models[0] + if build_config["model_name"].get("value") not in models: + build_config["model_name"]["value"] = models[0] if models else None info_message = f"Updated model options: {len(models)} models found in {field_value}" logger.info(info_message) except Exception: # noqa: BLE001 logger.exception("Error updating model options.") - if field_name == "model_name" and field_value and field_value in WatsonxAIComponent._urls: - build_config["model_name"]["options"] = 
self.fetch_models(base_url=field_value) - build_config["model_name"]["value"] = "" + return build_config def build_model(self) -> LanguageModel: @@ -196,10 +203,15 @@ class WatsonxAIComponent(LCModelComponent): if isinstance(api_key_value, SecretStr): api_key_value = api_key_value.get_secret_value() + if bool(self.space_id) == bool(self.project_id): + msg = "Exactly one of Project_ID or Space_ID must be selected" + raise ValueError(msg) + return ChatWatsonx( apikey=api_key_value, url=self.base_url, project_id=self.project_id, + space_id=self.space_id, model_id=self.model_name, params=chat_params, streaming=self.stream, diff --git a/src/lfx/tests/unit/components/ibm/__init__.py b/src/lfx/tests/unit/components/ibm/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/lfx/tests/unit/components/ibm/test_watsonx.py b/src/lfx/tests/unit/components/ibm/test_watsonx.py new file mode 100644 index 000000000..9678ca70e --- /dev/null +++ b/src/lfx/tests/unit/components/ibm/test_watsonx.py @@ -0,0 +1,468 @@ +"""Unit tests for IBM watsonx.ai component.""" + +import sys +from unittest.mock import MagicMock, Mock, patch + +import pytest +from lfx.schema.dotdict import dotdict + +# Mock the langchain_ibm module before importing the component +sys.modules["langchain_ibm"] = MagicMock() + + +# Create a mock SecretStr class +class MockSecretStr: + """Mock SecretStr for testing.""" + + def __init__(self, value): + self._value = value + + def get_secret_value(self): + return self._value + + +class TestWatsonxAIComponent: + """Test suite for WatsonxAIComponent.""" + + @pytest.fixture + def wx_component(self): + """Create a WatsonxAIComponent instance for testing.""" + from lfx.components.ibm.watsonx import WatsonxAIComponent + + return WatsonxAIComponent() + + @pytest.fixture + def mock_response(self): + """Create a mock response for API calls.""" + mock_resp = Mock() + mock_resp.json.return_value = { + "resources": [ + {"model_id": "ibm/granite-3-2b-instruct"}, + 
{"model_id": "ibm/granite-3-8b-instruct"}, + {"model_id": "meta-llama/llama-3-70b-instruct"}, + ] + } + mock_resp.raise_for_status = Mock() + return mock_resp + + def test_component_attributes(self, wx_component): + """Test that component has correct attributes.""" + assert wx_component.display_name == "IBM watsonx.ai" + assert wx_component.description == "Generate text using IBM watsonx.ai foundation models." + assert wx_component.icon == "WatsonxAI" + assert wx_component.name == "IBMwatsonxModel" + assert wx_component.beta is False + + def test_default_models(self): + """Test that default models are defined.""" + from lfx.components.ibm.watsonx import WatsonxAIComponent + + assert len(WatsonxAIComponent._default_models) == 3 + assert "ibm/granite-3-2b-instruct" in WatsonxAIComponent._default_models + assert "ibm/granite-3-8b-instruct" in WatsonxAIComponent._default_models + assert "ibm/granite-13b-instruct-v2" in WatsonxAIComponent._default_models + + def test_urls_defined(self): + """Test that API URLs are defined.""" + from lfx.components.ibm.watsonx import WatsonxAIComponent + + assert len(WatsonxAIComponent._urls) > 0 + assert "https://us-south.ml.cloud.ibm.com" in WatsonxAIComponent._urls + assert "https://eu-de.ml.cloud.ibm.com" in WatsonxAIComponent._urls + + def test_inputs_defined(self, wx_component): + """Test that all required inputs are defined.""" + input_names = [inp.name for inp in wx_component.inputs] + + # Check for required inputs + assert "base_url" in input_names + assert "project_id" in input_names + assert "space_id" in input_names + assert "api_key" in input_names + assert "model_name" in input_names + assert "max_tokens" in input_names + assert "temperature" in input_names + assert "top_p" in input_names + assert "stream" in input_names + + @patch("lfx.base.models.model_utils.requests.get") + def test_fetch_models_success(self, mock_get, mock_response): + """Test successful model fetching from API.""" + from lfx.components.ibm.watsonx 
import WatsonxAIComponent + + mock_get.return_value = mock_response + + models = WatsonxAIComponent.fetch_models("https://us-south.ml.cloud.ibm.com") + + assert len(models) == 3 + assert "ibm/granite-3-2b-instruct" in models + assert "ibm/granite-3-8b-instruct" in models + assert "meta-llama/llama-3-70b-instruct" in models + + # Verify API call + mock_get.assert_called_once() + call_args = mock_get.call_args + assert "https://us-south.ml.cloud.ibm.com/ml/v1/foundation_model_specs" in call_args[0] + + @patch("lfx.base.models.model_utils.requests.get") + def test_fetch_models_api_error(self, mock_get): + """Test that default models are returned on API error.""" + from lfx.components.ibm.watsonx import WatsonxAIComponent + + mock_get.side_effect = Exception("API Error") + + models = WatsonxAIComponent.fetch_models("https://us-south.ml.cloud.ibm.com") + + # Should return default models on error + assert models == WatsonxAIComponent._default_models + + @patch("lfx.base.models.model_utils.requests.get") + def test_fetch_models_timeout(self, mock_get): + """Test that default models are returned on timeout.""" + from lfx.components.ibm.watsonx import WatsonxAIComponent + + mock_get.side_effect = TimeoutError("Request timeout") + + models = WatsonxAIComponent.fetch_models("https://us-south.ml.cloud.ibm.com") + + assert models == WatsonxAIComponent._default_models + + @patch("lfx.components.ibm.watsonx.WatsonxAIComponent.fetch_models") + def test_update_build_config_base_url(self, mock_fetch, wx_component): + """Test update_build_config when base_url changes.""" + mock_fetch.return_value = ["model1", "model2", "model3"] + + build_config = dotdict({"model_name": {"options": [], "value": None}}) + + result = wx_component.update_build_config( + build_config, field_value="https://us-south.ml.cloud.ibm.com", field_name="base_url" + ) + + assert result["model_name"]["options"] == ["model1", "model2", "model3"] + assert result["model_name"]["value"] == "model1" + 
mock_fetch.assert_called_once_with(base_url="https://us-south.ml.cloud.ibm.com") + + @patch("lfx.components.ibm.watsonx.WatsonxAIComponent.fetch_models") + def test_update_build_config_base_url_preserves_valid_model(self, mock_fetch, wx_component): + """Test that valid model selection is preserved when updating base_url.""" + mock_fetch.return_value = ["model1", "model2", "model3"] + + build_config = dotdict({"model_name": {"options": ["model1"], "value": "model2"}}) + + result = wx_component.update_build_config( + build_config, field_value="https://us-south.ml.cloud.ibm.com", field_name="base_url" + ) + + # model2 is in the new list, so it should be preserved + assert result["model_name"]["value"] == "model2" + + @patch("lfx.components.ibm.watsonx.ChatWatsonx") + def test_build_model_with_project_id(self, mock_chatwatsonx, wx_component): + """Test building model with ProjectID container scope.""" + wx_component.api_key = "test-api-key" # pragma: allowlist secret + wx_component.base_url = "https://us-south.ml.cloud.ibm.com" + wx_component.project_id = "test-project-id" + wx_component.space_id = None + wx_component.model_name = "ibm/granite-3-8b-instruct" + wx_component.stream = False + wx_component.max_tokens = 1000 + wx_component.temperature = 0.7 + wx_component.top_p = 0.9 + wx_component.frequency_penalty = 0.5 + wx_component.presence_penalty = 0.3 + wx_component.seed = 8 + wx_component.stop_sequence = None + wx_component.logprobs = True + wx_component.top_logprobs = 3 + wx_component.logit_bias = None + + wx_component.build_model() + + mock_chatwatsonx.assert_called_once() + call_kwargs = mock_chatwatsonx.call_args[1] + + assert call_kwargs["apikey"] == "test-api-key" # pragma: allowlist secret + assert call_kwargs["url"] == "https://us-south.ml.cloud.ibm.com" + assert call_kwargs["project_id"] == "test-project-id" + assert call_kwargs["space_id"] is None + assert call_kwargs["model_id"] == "ibm/granite-3-8b-instruct" + assert call_kwargs["streaming"] is False + 
+ @patch("lfx.components.ibm.watsonx.ChatWatsonx") + def test_build_model_with_space_id(self, mock_chatwatsonx, wx_component): + """Test building model with SpaceID container scope.""" + wx_component.api_key = "test-api-key" # pragma: allowlist secret + wx_component.base_url = "https://us-south.ml.cloud.ibm.com" + wx_component.project_id = None + wx_component.space_id = "test-space-id" + wx_component.model_name = "ibm/granite-3-8b-instruct" + wx_component.stream = True + wx_component.max_tokens = 2000 + wx_component.temperature = 0.5 + wx_component.top_p = 0.95 + wx_component.frequency_penalty = 0.0 + wx_component.presence_penalty = 0.0 + wx_component.seed = 42 + wx_component.stop_sequence = "END" + wx_component.logprobs = False + wx_component.top_logprobs = 5 + wx_component.logit_bias = None + + wx_component.build_model() + + mock_chatwatsonx.assert_called_once() + call_kwargs = mock_chatwatsonx.call_args[1] + + assert call_kwargs["apikey"] == "test-api-key" # pragma: allowlist secret + assert call_kwargs["url"] == "https://us-south.ml.cloud.ibm.com" + assert call_kwargs["project_id"] is None + assert call_kwargs["space_id"] == "test-space-id" + assert call_kwargs["model_id"] == "ibm/granite-3-8b-instruct" + assert call_kwargs["streaming"] is True + assert call_kwargs["params"]["stop"] == ["END"] + + @patch("lfx.components.ibm.watsonx.SecretStr", MockSecretStr) + @patch("lfx.components.ibm.watsonx.ChatWatsonx") + def test_build_model_with_secret_str_api_key(self, mock_chatwatsonx, wx_component): + """Test that SecretStr API key is properly converted to string.""" + wx_component.api_key = MockSecretStr("secret-api-key") + wx_component.base_url = "https://us-south.ml.cloud.ibm.com" + wx_component.project_id = "test-project-id" + wx_component.space_id = None + wx_component.model_name = "ibm/granite-3-8b-instruct" + wx_component.stream = False + wx_component.max_tokens = 1000 + wx_component.temperature = 0.7 + wx_component.top_p = 0.9 + wx_component.frequency_penalty 
# NOTE(review): this chunk starts mid-statement — the "= 0.5" below is the tail
# of an assignment whose left-hand side (and the enclosing test's `def` line)
# lie above this view; it is preserved verbatim rather than guessed at.
= 0.5
        wx_component.presence_penalty = 0.3
        wx_component.seed = 8
        wx_component.stop_sequence = None
        wx_component.logprobs = True
        wx_component.top_logprobs = 3
        wx_component.logit_bias = None

        wx_component.build_model()

        # The mocked ChatWatsonx captures its constructor kwargs; the API key
        # must be passed through as a plain string.
        call_kwargs = mock_chatwatsonx.call_args[1]
        assert call_kwargs["apikey"] == "secret-api-key"  # pragma: allowlist secret
        assert isinstance(call_kwargs["apikey"], str)

    @patch("lfx.components.ibm.watsonx.WatsonxAIComponent.fetch_models")
    @patch("lfx.components.ibm.watsonx.logger")
    def test_update_build_config_base_url_with_exception(self, mock_logger, mock_fetch, wx_component):
        """Test update_build_config handles exceptions when fetching models."""
        mock_fetch.side_effect = Exception("Network error")

        build_config = dotdict({"model_name": {"options": ["old_model"], "value": "old_model"}})

        result = wx_component.update_build_config(
            build_config, field_value="https://us-south.ml.cloud.ibm.com", field_name="base_url"
        )

        # Should log the exception but not crash
        mock_logger.exception.assert_called_once_with("Error updating model options.")
        # Original config should be preserved
        assert result["model_name"]["options"] == ["old_model"]
        assert result["model_name"]["value"] == "old_model"

    @patch("lfx.components.ibm.watsonx.WatsonxAIComponent.fetch_models")
    def test_update_build_config_base_url_empty_models_list(self, mock_fetch, wx_component):
        """Test update_build_config when fetch_models returns empty list."""
        mock_fetch.return_value = []

        build_config = dotdict({"model_name": {"options": ["old_model"], "value": "old_model"}})

        result = wx_component.update_build_config(
            build_config, field_value="https://us-south.ml.cloud.ibm.com", field_name="base_url"
        )

        # An empty fetch result clears both the options list and the selection.
        assert result["model_name"]["options"] == []
        assert result["model_name"]["value"] is None

    @patch("lfx.components.ibm.watsonx.WatsonxAIComponent.fetch_models")
    def test_update_build_config_base_url_resets_invalid_model(self, mock_fetch, wx_component):
        """Test that invalid model value is reset when base_url changes."""
        mock_fetch.return_value = ["model1", "model2"]

        build_config = dotdict({"model_name": {"options": ["old_model"], "value": "old_model"}})

        result = wx_component.update_build_config(
            build_config, field_value="https://us-south.ml.cloud.ibm.com", field_name="base_url"
        )

        # old_model is not in new list, so should be reset to first model
        assert result["model_name"]["value"] == "model1"

    def test_update_build_config_base_url_empty_value(self, wx_component):
        """Test update_build_config with empty base_url value."""
        build_config = dotdict({"model_name": {"options": ["model1"], "value": "model1"}})

        result = wx_component.update_build_config(build_config, field_value="", field_name="base_url")

        # Should not update when field_value is empty
        assert result["model_name"]["options"] == ["model1"]
        assert result["model_name"]["value"] == "model1"

    def test_update_build_config_base_url_none_value(self, wx_component):
        """Test update_build_config with None base_url value."""
        build_config = dotdict({"model_name": {"options": ["model1"], "value": "model1"}})

        result = wx_component.update_build_config(build_config, field_value=None, field_name="base_url")

        # Should not update when field_value is None
        assert result["model_name"]["options"] == ["model1"]
        assert result["model_name"]["value"] == "model1"

    def test_update_build_config_unrelated_field(self, wx_component):
        """Test update_build_config with unrelated field name."""
        build_config = dotdict(
            {
                "model_name": {"options": ["model1"], "value": "model1"},
                "space_id": {"advanced": True, "required": False, "value": None},
                "project_id": {"advanced": True, "required": False, "value": None},
            }
        )

        result = wx_component.update_build_config(build_config, field_value="some_value", field_name="unrelated_field")

        # Should return config unchanged
        assert result["model_name"]["options"] == ["model1"]
        assert result["model_name"]["value"] == "model1"
        assert result["space_id"]["advanced"] is True
        assert result["project_id"]["advanced"] is True

    def test_update_build_config_none_field_name(self, wx_component):
        """Test update_build_config with None field_name."""
        build_config = dotdict(
            {
                "model_name": {"options": ["model1"], "value": "model1"},
                "space_id": {"advanced": True, "required": False, "value": None},
                "project_id": {"advanced": True, "required": False, "value": None},
            }
        )

        result = wx_component.update_build_config(build_config, field_value="some_value", field_name=None)

        # Should return config unchanged
        assert result["model_name"]["options"] == ["model1"]
        assert result["model_name"]["value"] == "model1"

    @patch("lfx.components.ibm.watsonx.ChatWatsonx")
    def test_build_model_with_logit_bias_json(self, mock_chatwatsonx, wx_component):
        """Test building model with logit_bias as JSON string."""
        wx_component.api_key = "test-api-key"  # pragma: allowlist secret
        wx_component.base_url = "https://us-south.ml.cloud.ibm.com"
        wx_component.project_id = "test-project-id"
        wx_component.space_id = None
        wx_component.model_name = "ibm/granite-3-8b-instruct"
        wx_component.logprobs = True
        wx_component.top_logprobs = 3
        wx_component.logit_bias = '{"1003": -100, "1004": 100}'

        wx_component.build_model()

        # The JSON string is parsed into a dict before being passed to ChatWatsonx.
        call_kwargs = mock_chatwatsonx.call_args[1]
        assert call_kwargs["params"]["logit_bias"] == {"1003": -100, "1004": 100}

    @patch("lfx.components.ibm.watsonx.ChatWatsonx")
    @patch("lfx.components.ibm.watsonx.logger")
    def test_build_model_with_invalid_logit_bias_json(self, mock_logger, mock_chatwatsonx, wx_component):
        """Test that invalid logit_bias JSON uses default value."""
        wx_component.api_key = "test-api-key"  # pragma: allowlist secret
        wx_component.base_url = "https://us-south.ml.cloud.ibm.com"
        wx_component.project_id = "test-project-id"
        wx_component.space_id = None
        wx_component.model_name = "ibm/granite-3-8b-instruct"
        wx_component.logprobs = True
        wx_component.top_logprobs = 3
        wx_component.logit_bias = "invalid json"

        wx_component.build_model()

        # Unparseable JSON falls back to the default bias map and logs a warning.
        call_kwargs = mock_chatwatsonx.call_args[1]
        assert call_kwargs["params"]["logit_bias"] == {"1003": -100, "1004": -100}
        mock_logger.warning.assert_called_once()

    @patch("lfx.components.ibm.watsonx.ChatWatsonx")
    def test_build_model_params_structure(self, mock_chatwatsonx, wx_component):
        """Test that model params are structured correctly."""
        wx_component.api_key = "test-api-key"  # pragma: allowlist secret
        wx_component.base_url = "https://us-south.ml.cloud.ibm.com"
        wx_component.project_id = "test-project-id"
        wx_component.space_id = None
        wx_component.model_name = "ibm/granite-3-8b-instruct"
        wx_component.stream = False
        wx_component.max_tokens = 1500
        wx_component.temperature = 0.8
        wx_component.top_p = 0.85
        wx_component.frequency_penalty = 0.6
        wx_component.presence_penalty = 0.4
        wx_component.seed = 123
        wx_component.stop_sequence = "STOP"
        wx_component.logprobs = True
        wx_component.top_logprobs = 10
        wx_component.logit_bias = None

        wx_component.build_model()

        call_kwargs = mock_chatwatsonx.call_args[1]
        params = call_kwargs["params"]

        assert params["max_tokens"] == 1500
        assert params["temperature"] == 0.8
        assert params["top_p"] == 0.85
        assert params["frequency_penalty"] == 0.6
        assert params["presence_penalty"] == 0.4
        assert params["seed"] == 123
        # `stop_sequence` surfaces as the list-valued `stop` param.
        assert params["stop"] == ["STOP"]
        assert params["n"] == 1
        assert params["logprobs"] is True
        assert params["top_logprobs"] == 10
        assert params["time_limit"] == 600000
        assert params["logit_bias"] is None

    @patch("lfx.components.ibm.watsonx.ChatWatsonx")
    def test_build_model_with_both_project_and_space_id_raises_error(self, mock_chatwatsonx, wx_component):
        """Test that providing both project_id and space_id raises ValueError."""
        wx_component.api_key = "test-api-key"  # pragma: allowlist secret
        wx_component.base_url = "https://us-south.ml.cloud.ibm.com"
        wx_component.project_id = "test-project-id"
        wx_component.space_id = "test-space-id"
        wx_component.model_name = "ibm/granite-3-8b-instruct"

        with pytest.raises(ValueError, match="Exactly one of Project_ID or Space_ID must be selected"):
            wx_component.build_model()

        # Ensure ChatWatsonx was not called
        mock_chatwatsonx.assert_not_called()

    @patch("lfx.components.ibm.watsonx.ChatWatsonx")
    def test_build_model_with_neither_project_nor_space_id_raises_error(self, mock_chatwatsonx, wx_component):
        """Test that providing neither project_id nor space_id raises ValueError."""
        wx_component.api_key = "test-api-key"  # pragma: allowlist secret
        wx_component.base_url = "https://us-south.ml.cloud.ibm.com"
        wx_component.project_id = None
        wx_component.space_id = None
        wx_component.model_name = "ibm/granite-3-8b-instruct"

        with pytest.raises(ValueError, match="Exactly one of Project_ID or Space_ID must be selected"):
            wx_component.build_model()

        # Ensure ChatWatsonx was not called
        mock_chatwatsonx.assert_not_called()

    @patch("lfx.components.ibm.watsonx.ChatWatsonx")
    def test_build_model_with_empty_string_project_and_space_id_raises_error(self, mock_chatwatsonx, wx_component):
        """Test that providing empty strings for both project_id and space_id raises ValueError."""
        wx_component.api_key = "test-api-key"  # pragma: allowlist secret
        wx_component.base_url = "https://us-south.ml.cloud.ibm.com"
        # Empty strings are treated the same as missing scope identifiers.
        wx_component.project_id = ""
        wx_component.space_id = ""
        wx_component.model_name = "ibm/granite-3-8b-instruct"

        with pytest.raises(ValueError, match="Exactly one of Project_ID or Space_ID must be selected"):
            wx_component.build_model()

        # Ensure ChatWatsonx was not called
        mock_chatwatsonx.assert_not_called()