Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions haystack/components/generators/hugging_face_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def run(
replies = [o["generated_text"] for o in output if "generated_text" in o]

if self.stop_words:
# the output of the pipeline includes the stop word
replies = [reply.replace(stop_word, "").rstrip() for reply in replies for stop_word in self.stop_words]
# The output of the pipeline includes the stop word, so strip each stop word
# from each reply without duplicating replies when multiple stop words are set.
for stop_word in self.stop_words:
replies = [reply.replace(stop_word, "").rstrip() for reply in replies]

return {"replies": replies}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
Fixes `HuggingFaceLocalGenerator` so using multiple `stop_words` no longer duplicates replies while stripping stop words from generated text.
17 changes: 17 additions & 0 deletions test/components/generators/test_hugging_face_local_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,23 @@ def test_run_stop_words_removal(self):
results = generator.run(prompt="irrelevant")
assert results == {"replies": ["Hello"]}

def test_run_stop_words_removal_multiple_entries(self):
"""Test that multiple stop words are removed sequentially without duplicating replies."""
generator = HuggingFaceLocalGenerator(
model="Qwen/Qwen3-0.6B", task="text-generation", stop_words=[" STOP", " END"]
)
generator.pipeline = Mock(
return_value=[
{"generated_text": "Paris is the capital. STOP"},
{"generated_text": "France is in Europe. END"},
]
)
generator.stopping_criteria_list = Mock()

results = generator.run(prompt="irrelevant")

assert results == {"replies": ["Paris is the capital.", "France is in Europe."]}

@pytest.mark.integration
def test_stop_words_criteria_using_hf_tokenizer(self):
"""
Expand Down