Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def to_dict(self) -> dict[str, Any]:
self,
labels=self.labels,
model=self.huggingface_pipeline_kwargs["model"],
classification_field=self.classification_field,
multi_label=self.multi_label,
huggingface_pipeline_kwargs=self.huggingface_pipeline_kwargs,
token=self.token,
multi_label=self.multi_label,
Expand Down
2 changes: 1 addition & 1 deletion haystack/components/joiners/document_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _concatenate(document_lists: list[list[Document]]) -> list[Document]:
for doc in itertools.chain.from_iterable(document_lists):
docs_per_id[doc.id].append(doc)
for docs in docs_per_id.values():
doc_with_best_score = max(docs, key=lambda doc: doc.score if doc.score else -inf)
doc_with_best_score = max(docs, key=lambda doc: doc.score if doc.score is not None else -inf)
output.append(doc_with_best_score)
return output

Expand Down
10 changes: 6 additions & 4 deletions haystack/utils/requests_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ def request_with_retry(
after=after_log(logger, logging.DEBUG),
)
def run() -> httpx.Response:
timeout = kwargs.pop("timeout", 10)
request_kwargs = dict(kwargs)
timeout = request_kwargs.pop("timeout", 10)
with httpx.Client() as client:
res = client.request(**kwargs, timeout=timeout)
res = client.request(**request_kwargs, timeout=timeout)

if res.status_code in status_codes_to_retry:
# We raise only for the status codes that must trigger a retry
Expand Down Expand Up @@ -177,9 +178,10 @@ async def example_5xx():
after=after_log(logger, logging.DEBUG),
)
async def run() -> httpx.Response:
timeout = kwargs.pop("timeout", 10)
request_kwargs = dict(kwargs)
timeout = request_kwargs.pop("timeout", 10)
async with httpx.AsyncClient() as client:
res = await client.request(**kwargs, timeout=timeout)
res = await client.request(**request_kwargs, timeout=timeout)

if res.status_code in status_codes_to_retry:
# We raise only for the status codes that must trigger a retry
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Fixed `DocumentJoiner` in `concatenate` mode treating documents with a score of `0.0` as unscored when deduplicating by ID.
Duplicate documents with a zero score could lose to documents with negative or missing scores.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
`TransformersZeroShotDocumentClassifier.to_dict()` now serializes `classification_field` and `multi_label`, so pipeline dump/load round-trips preserve the configured classification behavior.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
fixes:
- |
Preserve user-specified timeouts across retries in `request_with_retry` and
`async_request_with_retry`.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def test_to_dict(self):
"init_parameters": {
"model": "cross-encoder/nli-deberta-v3-xsmall",
"labels": ["positive", "negative"],
"classification_field": None,
"multi_label": False,
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"multi_label": True,
"classification_field": "title",
Expand All @@ -47,6 +49,17 @@ def test_to_dict(self):
},
}

def test_to_dict_from_dict_round_trip(self):
component = TransformersZeroShotDocumentClassifier(
model="cross-encoder/nli-deberta-v3-xsmall",
labels=["a", "b"],
classification_field="title",
multi_label=True,
)
restored = TransformersZeroShotDocumentClassifier.from_dict(component.to_dict())
assert restored.classification_field == "title"
assert restored.multi_label is True

def test_from_dict(self, del_hf_env_vars):
data = {
"type": "haystack.components.classifiers.zero_shot_document_classifier.TransformersZeroShotDocumentClassifier", # noqa: E501
Expand Down
11 changes: 11 additions & 0 deletions test/components/joiners/test_document_joiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ def test_run_with_concatenate_join_mode_and_duplicate_documents(self):
output["documents"], key=lambda d: d.id
)

def test_concatenate_keeps_highest_score_for_zero_and_negative_scores(self):
joiner = DocumentJoiner(sort_by_score=False)
documents_1 = [Document(id="dup", content="no score")]
documents_2 = [Document(id="dup", content="zero score", score=0.0)]
documents_3 = [Document(id="dup", content="negative score", score=-0.1)]

output = joiner.run([documents_1, documents_2, documents_3])
assert len(output["documents"]) == 1
assert output["documents"][0].content == "zero score"
assert output["documents"][0].score == 0.0

def test_run_with_merge_join_mode(self):
joiner = DocumentJoiner(join_mode="merge", weights=[1.5, 0.5])
documents_1 = [Document(content="a", score=1.0), Document(content="b", score=2.0)]
Expand Down