Skip to content
Open
62 changes: 62 additions & 0 deletions sdks/python/apache_beam/yaml/tests/runinference_huggingface.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

pipelines:
  - pipeline:
      type: chain
      transforms:
        - type: Create
          config:
            elements:
              - text: "I love Apache Beam!"
              - text: "I hate this error."
        - type: RunInference
          config:
            model_handler:
              type: "HuggingFacePipeline"
              config:
                task: "text-classification"
                inference_fn:
                  callable: |
                    def real_inference(batch, pipeline, inference_args):
                      predictions = pipeline(batch, **inference_args)

                      # If it's a single dictionary (batch size of 1), wrap it in a list
                      if isinstance(predictions, dict):
                        predictions = [predictions]

                      return {
                          'label': [p['label'] for p in predictions],
                          'score': [p['score'] for p in predictions]
                      }
                preprocess:
                  callable: 'lambda x: x.text'
        - type: MapToFields
          config:
            language: python
            fields:
              text: text
              sentiment:
                callable: 'lambda x: x.inference.inference["label"]'
        - type: AssertEqual
          config:
            elements:
              - text: "I love Apache Beam!"
                sentiment: "POSITIVE"
              - text: "I hate this error."
                sentiment: "NEGATIVE"

    options:
      yaml_experimental_features: ['ML']
49 changes: 49 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,55 @@ def inference_output_type(self):
('model_id', Optional[str])])


@ModelHandlerProvider.register_handler_type('HuggingFacePipeline')
class HuggingFacePipelineProvider(ModelHandlerProvider):
  """Model handler provider backed by a Hugging Face pipeline.

  Registers the ``HuggingFacePipeline`` model handler type for Beam YAML
  ``RunInference`` transforms. Wraps
  ``HuggingFacePipelineModelHandler`` from
  ``apache_beam.ml.inference.huggingface_inference``.
  """
  def __init__(
      self,
      task: Optional[str] = None,
      model: Optional[str] = None,
      preprocess: Optional[dict[str, str]] = None,
      postprocess: Optional[dict[str, str]] = None,
      device: Optional[Any] = None,
      inference_fn: Optional[dict[str, str]] = None,
      load_pipeline_args: Optional[dict[str, Any]] = None,
      **kwargs):
    """Initializes the provider.

    Args:
      task: The Hugging Face pipeline task name (e.g.
        ``"text-classification"``). Either ``task`` or ``model`` must be
        provided (enforced by :meth:`validate`).
      model: The model identifier to load the pipeline from.
      preprocess: Optional processing-transform spec applied to each
        element before inference.
      postprocess: Optional processing-transform spec applied to each
        inference result.
      device: Device to run the pipeline on.
      inference_fn: Optional processing-transform spec (e.g. a
        ``callable``) that overrides the handler's default inference
        function.
      load_pipeline_args: Extra keyword arguments forwarded when loading
        the pipeline.
      **kwargs: Additional keyword arguments forwarded to
        ``HuggingFacePipelineModelHandler``; keys starting with ``_`` are
        treated as framework-internal and dropped.

    Raises:
      ValueError: If the transformers dependencies needed for
        ``HuggingFacePipelineModelHandler`` are not installed.
    """
    try:
      from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler
    except ImportError as exc:
      # Surface a YAML-friendly error instead of a raw ImportError, and
      # chain the original exception for debuggability.
      raise ValueError(
          'Unable to import HuggingFacePipelineModelHandler. Please '
          'install transformers dependencies.') from exc

    # Strip framework-injected private kwargs (prefixed with '_') so they
    # are not forwarded to the underlying model handler.
    kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')}

    # Resolve the user-supplied inference_fn spec (e.g. an inline
    # callable) into an actual callable, if one was given.
    inference_fn_obj = self.parse_processing_transform(
        inference_fn, 'inference_fn') if inference_fn else None

    # Only pass inference_fn through when explicitly configured, so the
    # handler's own default inference function is used otherwise.
    handler_kwargs = {}
    if inference_fn_obj:
      handler_kwargs['inference_fn'] = inference_fn_obj

    # NOTE(review): load_pipeline_args is forwarded even when None —
    # assumes the handler treats None as "no extra args"; confirm against
    # HuggingFacePipelineModelHandler's signature.
    _handler = HuggingFacePipelineModelHandler(
        task=task,
        model=model,
        device=device,
        load_pipeline_args=load_pipeline_args,
        **handler_kwargs,
        **kwargs)

    super().__init__(_handler, preprocess, postprocess)

  @staticmethod
  def validate(config):
    """Validates the YAML config; requires at least one of task/model."""
    if not config.get('task') and not config.get('model'):
      raise ValueError(
          "HuggingFacePipeline requires either 'task' or "
          "'model' to be specified.")

  def inference_output_type(self):
    # Pipeline outputs vary by task, so no stronger type is declared.
    return Any


@beam.ptransform.ptransform_fn
def run_inference(
pcoll,
Expand Down
14 changes: 12 additions & 2 deletions sdks/python/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,20 @@ tasks.register("generateYamlDocs") {
outputs.file "${buildDir}/yaml-examples.html"
}

// Installs the extra Python dependencies needed by the YAML integration
// tests (GCP test deps plus the ml_test, yaml, and transformers extras of
// the locally built apache-beam tarball) into the test virtualenv.
tasks.register("installYamlIntegrationTestDeps") {
  dependsOn installGcpTest
  doLast {
    exec {
      // Activate the shared virtualenv, then pip-install the built sdist
      // with the extras required by the YAML ML tests.
      executable 'sh'
      args '-c', ". ${envdir}/bin/activate && pip install --pre --retries 10 ${buildDir}/apache-beam.tar.gz[ml_test,yaml,transformers]"
    }
  }
}

tasks.register("yamlIntegrationTests") {
description "Runs precommit integration tests for yaml pipelines."

dependsOn installGcpTest
dependsOn installYamlIntegrationTestDeps
// Need to build all expansion services referenced in apache_beam/yaml/*.*
// grep -oh 'sdk.*Jar' sdks/python/apache_beam/yaml/*.yaml | sort | uniq
dependsOn ":sdks:java:extensions:schemaio-expansion-service:shadowJar"
Expand All @@ -146,7 +156,7 @@ tasks.register("yamlIntegrationTests") {
tasks.register("postCommitYamlIntegrationTests") {
description "Runs postcommit integration tests for yaml pipelines - parameterized by yamlTestSet."

dependsOn installGcpTest
dependsOn installYamlIntegrationTestDeps
// Need to build all expansion services referenced in apache_beam/yaml/*.*
// grep -oh 'sdk.*Jar' sdks/python/apache_beam/yaml/*.yaml | sort | uniq
dependsOn ":sdks:java:extensions:schemaio-expansion-service:shadowJar"
Expand Down
Loading