fix(core): Fix race condition in custom component tool invocations (#11994)

* fix custom component as tool

* fix unit tests

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Cristhian Zanforlin Lousa
2026-03-10 16:45:49 -03:00
committed by GitHub
parent 1710053e76
commit 409bd9ee31
3 changed files with 189 additions and 13 deletions

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import re
from copy import deepcopy
from typing import TYPE_CHECKING
import pandas as pd
@@ -83,14 +84,20 @@ def _patch_send_message_decorator(component, func):
def _build_output_function(component: Component, output_method: Callable, event_manager: EventManager | None = None):
method_name = output_method.__name__
def output_function(*args, **kwargs):
# Create an isolated copy to prevent race conditions when this
# tool is invoked concurrently by an agent (GitHub issue #8791)
comp = deepcopy(component)
local_method = getattr(comp, method_name)
try:
if event_manager:
event_manager.on_build_start(data={"id": component.get_id()})
component.set(*args, **kwargs)
result = output_method()
event_manager.on_build_start(data={"id": comp.get_id()})
comp.set(*args, **kwargs)
result = local_method()
if event_manager:
event_manager.on_build_end(data={"id": component.get_id()})
event_manager.on_build_end(data={"id": comp.get_id()})
except Exception as e:
raise ToolException(e) from e
@@ -107,14 +114,20 @@ def _build_output_function(component: Component, output_method: Callable, event_
def _build_output_async_function(
component: Component, output_method: Callable, event_manager: EventManager | None = None
):
method_name = output_method.__name__
async def output_function(*args, **kwargs):
# Create an isolated copy to prevent race conditions when this
# tool is invoked concurrently by an agent (GitHub issue #8791)
comp = deepcopy(component)
local_method = getattr(comp, method_name)
try:
if event_manager:
await asyncio.to_thread(event_manager.on_build_start, data={"id": component.get_id()})
component.set(*args, **kwargs)
result = await output_method()
await asyncio.to_thread(event_manager.on_build_start, data={"id": comp.get_id()})
comp.set(*args, **kwargs)
result = await local_method()
if event_manager:
await asyncio.to_thread(event_manager.on_build_end, data={"id": component.get_id()})
await asyncio.to_thread(event_manager.on_build_end, data={"id": comp.get_id()})
except Exception as e:
raise ToolException(e) from e
if isinstance(result, Message):

View File

@@ -378,16 +378,18 @@ class Component(CustomComponent):
def __deepcopy__(self, memo: dict) -> Component:
if id(self) in memo:
return memo[id(self)]
kwargs = deepcopy(self.__config, memo)
kwargs["inputs"] = deepcopy(self.__inputs, memo)
# Shallow-copy config/inputs: they may contain non-picklable services
# (e.g. _tracing_service holds ServiceManager with threading.RLock).
kwargs = dict(self.__config)
kwargs["inputs"] = dict(self.__inputs)
new_component = type(self)(**kwargs)
new_component._code = self._code
new_component._outputs_map = self._outputs_map
new_component._inputs = self._inputs
new_component._inputs = deepcopy(self._inputs, memo)
new_component._edges = self._edges
new_component._components = self._components
new_component._parameters = self._parameters
new_component._attributes = self._attributes
new_component._parameters = dict(self._parameters)
new_component._attributes = dict(self._attributes)
new_component._output_logs = self._output_logs
new_component._logs = self._logs # type: ignore[attr-defined]
memo[id(self)] = new_component

View File

@@ -0,0 +1,161 @@
"""Test for GitHub issue #8791: Race condition when component-tool is invoked concurrently.
When an Agent invokes the same component-based tool multiple times concurrently,
the same component instance is reused, causing inputs to be overwritten between
concurrent invocations (data corruption).
"""
import threading
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from lfx.base.tools.component_tool import ComponentToolkit
from lfx.custom.custom_component.component import Component
from lfx.inputs.inputs import MessageTextInput
from lfx.io import Output
from lfx.schema.data import Data
class SlowLabelComponent(Component):
"""Test component that simulates a tool with processing delay.
Records input values before and after a delay to detect
if concurrent invocations overwrite each other's inputs.
"""
display_name = "Slow Label Tool"
description = "Adds a label to a product with simulated delay."
name = "SlowLabelComponent"
inputs = [
MessageTextInput(name="product_id", display_name="Product ID", tool_mode=True),
MessageTextInput(name="label", display_name="Label", tool_mode=True),
]
outputs = [
Output(display_name="Result", name="result", method="process"),
]
def process(self) -> Data:
import time
captured_id = self.product_id
captured_label = self.label
time.sleep(0.2)
return Data(
data={
"product_id_before": captured_id,
"label_before": captured_label,
"product_id_after": self.product_id,
"label_after": self.label,
}
)
def test_should_isolate_inputs_when_tool_invoked_concurrently():
"""Bug #8791: concurrent tool invocations must not share mutable state.
GIVEN: A component converted to a StructuredTool via ComponentToolkit
WHEN: The tool is invoked concurrently with different inputs
THEN: Each invocation must see its own inputs (no cross-contamination)
"""
# Arrange
component = SlowLabelComponent()
toolkit = ComponentToolkit(component=component)
tools = toolkit.get_tools()
assert len(tools) == 1
tool = tools[0]
results = []
def invoke_tool(product_id: str, label: str) -> None:
result = tool.invoke({"product_id": product_id, "label": label})
results.append(result)
# Act - invoke the same tool concurrently with different inputs
with ThreadPoolExecutor(max_workers=2) as executor:
future1 = executor.submit(invoke_tool, "PROD-001", "Electronics")
future2 = executor.submit(invoke_tool, "PROD-002", "Clothing")
future1.result()
future2.result()
# Assert - each invocation must retain its own inputs throughout execution
assert len(results) == 2
for result in results:
# Inputs captured before and after the delay must be identical
assert result["product_id_before"] == result["product_id_after"], (
f"product_id changed during execution: '{result['product_id_before']}' -> '{result['product_id_after']}'"
)
assert result["label_before"] == result["label_after"], (
f"label changed during execution: '{result['label_before']}' -> '{result['label_after']}'"
)
# Both products must have been processed (not duplicated)
product_ids = {r["product_id_before"] for r in results}
assert product_ids == {"PROD-001", "PROD-002"}, (
f"Expected both products to be processed independently, got: {product_ids}"
)
def test_deepcopy_with_non_picklable_state():
"""Deepcopy must not fail when the component carries non-picklable objects.
Real components receive services (e.g. _tracing_service) that hold
threading.RLock instances. __deepcopy__ must handle these gracefully.
"""
component = SlowLabelComponent(_tracing_service=_FakeServiceWithLock())
# Must not raise "cannot pickle '_thread.RLock' object"
clone = deepcopy(component)
# The clone must be a distinct object
assert clone is not component
# The non-picklable service should be shared (shallow-copied), not duplicated
assert clone._tracing_service is component._tracing_service # type: ignore[attr-defined]
def test_should_isolate_inputs_when_component_has_non_picklable_state():
"""End-to-end: concurrent tool invocation must work even with non-picklable state.
Combines both bugs: race condition (#8791) + RLock deepcopy failure.
The component has a _tracing_service with RLock AND is invoked concurrently.
"""
# Arrange
component = SlowLabelComponent(_tracing_service=_FakeServiceWithLock())
toolkit = ComponentToolkit(component=component)
tools = toolkit.get_tools()
tool = tools[0]
results = []
def invoke_tool(product_id: str, label: str) -> None:
result = tool.invoke({"product_id": product_id, "label": label})
results.append(result)
# Act - invoke concurrently with a component that has non-picklable state
with ThreadPoolExecutor(max_workers=2) as executor:
future1 = executor.submit(invoke_tool, "P1", "LabelA")
future2 = executor.submit(invoke_tool, "P2", "LabelB")
future1.result()
future2.result()
# Assert - no race condition AND no pickle error
assert len(results) == 2
for result in results:
assert result["product_id_before"] == result["product_id_after"]
assert result["label_before"] == result["label_after"]
product_ids = {r["product_id_before"] for r in results}
assert product_ids == {"P1", "P2"}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _FakeServiceWithLock:
"""Mimics a service that holds a threading.RLock (like ServiceManager)."""
def __init__(self):
self._lock = threading.RLock()