-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathdata_sampler.py
More file actions
369 lines (306 loc) · 13 KB
/
data_sampler.py
File metadata and controls
369 lines (306 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Functionality for sampling elements during bundle execution."""
# pytype: skip-file
from __future__ import annotations
import collections
import logging
import threading
import time
import traceback
from dataclasses import dataclass
from threading import Timer
from typing import Any
from typing import Deque
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
from apache_beam.coders.coder_impl import CoderImpl
from apache_beam.coders.coder_impl import WindowedValueCoderImpl
from apache_beam.coders.coders import Coder
from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.utils.windowed_value import WindowedValue
# Module-level logger; used below to report sampling/encoding problems.
_LOGGER = logging.getLogger(__name__)
class SampleTimer:
  """Drives periodic element sampling via a recurring `threading.Timer`.

  Starts at a fast cadence (at most 0.5s between samples) for the first 30
  seconds of sampling, then falls back to the caller-requested period.
  """
  def __init__(self, timeout_secs: float, sampler: OutputSampler) -> None:
    # The period the caller ultimately wants between samples.
    self._target_timeout_secs = timeout_secs
    # Effective period: capped at 0.5s during warm-up; 0 disables sampling.
    if timeout_secs > 0:
      self._timeout_secs = min(timeout_secs, 0.5)
    else:
      self._timeout_secs = 0.0
    self._sampler = sampler
    # Total (approximate) time spent sampling so far; drives the warm-up.
    self._sample_duration_secs = 0.0
    self._timer = Timer(self._timeout_secs, self.sample)

  def reset(self) -> None:
    """Cancels any pending timer and schedules the next sample.

    After 30 seconds of accumulated sampling, switches from the fast warm-up
    cadence to the caller-requested period.
    """
    if self._sample_duration_secs >= 30.0:
      self._timeout_secs = self._target_timeout_secs
    self._sample_duration_secs += self._timeout_secs

    self._timer.cancel()
    self._timer = Timer(self._timeout_secs, self.sample)
    self._timer.start()

  def stop(self) -> None:
    """Cancels the pending timer, halting any future samples."""
    self._timer.cancel()

  def sample(self) -> None:
    """Timer callback: takes one sample, then re-arms the timer."""
    self._sampler.sample()
    self.reset()
@dataclass
class ExceptionMetadata:
  """Context describing where a sampled exception was raised."""
  # Stringified exception, i.e. the joined `traceback.format_exception` text.
  msg: str
  # Id of the PTransform in which the exception occurred.
  transform_id: str
  # Id of the bundle-processing instruction active when it occurred.
  instruction_id: str
@dataclass
class ElementSampler:
  """Record class to hold sampled elements.

  This class is used as an optimization to quickly sample elements. This is a
  shared reference between the Operation and the OutputSampler.
  """
  # True iff `el` currently holds a sample.
  has_element: bool = False
  # The sampled element. `None` is a legal element value, so it cannot serve
  # as a sentinel for "no sample" — check `has_element` instead.
  el: Any = None
class OutputSampler:
  """Represents a way to sample an output of a PTransform.

  This is configurable to only keep `max_samples` (see constructor) sampled
  elements in memory. Samples are taken every `sample_every_sec`.
  """
  def __init__(
      self,
      coder: Coder,
      max_samples: int = 10,
      sample_every_sec: float = 5) -> None:
    # Ring buffer of sampled elements; oldest samples are evicted first.
    self._samples: Deque[Any] = collections.deque(maxlen=max_samples)
    # Guards `_samples` and `_exceptions`; sampling, exception capture, and
    # flushing may happen on different threads.
    self._samples_lock: threading.Lock = threading.Lock()
    self._coder_impl: CoderImpl = coder.get_impl()
    self._sample_timer = SampleTimer(sample_every_sec, self)
    # Shared with the producing Operation; see ElementSampler.
    self.element_sampler = ElementSampler()
    self.element_sampler.has_element = False
    self._exceptions: Deque[Tuple[Any, ExceptionMetadata]] = collections.deque(
        maxlen=max_samples)

    # For testing, it's easier to disable the Timer and manually sample.
    if sample_every_sec > 0:
      self._sample_timer.reset()

  def stop(self) -> None:
    """Stops sampling."""
    self._sample_timer.stop()

  def remove_windowed_value(self, el: Union[WindowedValue, Any]) -> Any:
    """Retrieves the value from the WindowedValue.

    The Python SDK passes elements as WindowedValues, which may not match the
    coder for that particular PCollection.
    """
    while isinstance(el, WindowedValue):
      el = el.value
    return el

  def flush(self, clear: bool = True) -> List[beam_fn_api_pb2.SampledElement]:
    """Returns all samples and optionally clears buffer if clear is True."""
    with self._samples_lock:
      # TODO(rohdesamuel): There can be duplicates between the exceptions and
      # samples. This happens when the OutputSampler samples during an
      # exception. The fix is to create a OutputSampler per process bundle.
      # Until then use a set to keep track of the elements.
      seen = set(id(el) for el, _ in self._exceptions)

      if isinstance(self._coder_impl, WindowedValueCoderImpl):
        # Coder already expects WindowedValues; take elements as-is.
        exceptions = list(self._exceptions)
        samples = [s for s in self._samples if id(s) not in seen]
      else:
        # Unwrap WindowedValues so the element matches the coder.
        exceptions = [(self.remove_windowed_value(a), b)
                      for a, b in self._exceptions]
        samples = [
            self.remove_windowed_value(s) for s in self._samples
            if id(s) not in seen
        ]

      if clear:
        self._samples.clear()
        self._exceptions.clear()

      # Encode in the nested context b/c this ensures that the SDK can decode
      # the bytes with the ToStringFn.
      ret = []
      try:
        ret = [
            beam_fn_api_pb2.SampledElement(
                element=self._coder_impl.encode_nested(s),
            ) for s in samples
        ]
      except Exception as e:  # pylint: disable=broad-except
        # Best-effort: a sample that cannot be encoded is dropped, not fatal.
        _LOGGER.warning('Could not encode sampled values: %s', e)

      try:
        ret.extend(
            beam_fn_api_pb2.SampledElement(
                element=self._coder_impl.encode_nested(s),
                exception=beam_fn_api_pb2.SampledElement.Exception(
                    instruction_id=exn.instruction_id,
                    transform_id=exn.transform_id,
                    error=exn.msg)) for s, exn in exceptions)
      except Exception as e:  # pylint: disable=broad-except
        _LOGGER.warning('Could not encode sampled exception values: %s', e)

      return ret

  def sample(self) -> None:
    """Samples the given element to an internal buffer.

    Moves the element (if any) out of the shared ElementSampler into the
    bounded sample buffer and resets the `has_element` flag.
    """
    with self._samples_lock:
      if self.element_sampler.has_element:
        self._samples.append(self.element_sampler.el)
        self.element_sampler.has_element = False

  def sample_exception(
      self, el: Any, exc_info: Any, transform_id: str,
      instruction_id: str) -> None:
    """Adds the given exception (with its offending element) to the samples.

    `exc_info` is a `sys.exc_info()`-style tuple; the full traceback text is
    recorded alongside the transform and instruction ids.
    """
    with self._samples_lock:
      err_string = ''.join(traceback.format_exception(*exc_info))
      self._exceptions.append(
          (el, ExceptionMetadata(err_string, transform_id, instruction_id)))
class DataSampler:
  """A class for querying any samples generated during execution.

  This class is meant to be a singleton with regard to a particular
  `sdk_worker.SdkHarness`. When creating the operators, individual
  `OutputSampler`s are created from `DataSampler.initialize_samplers`. This
  allows for multi-threaded sampling of a PCollection across the SdkHarness.

  Samples generated during execution can then be sampled with the `samples`
  method. This filters samples from the given pcollection ids.
  """
  def __init__(
      self,
      max_samples: int = 10,
      sample_every_sec: float = 30,
      sample_only_exceptions: bool = False,
      clock=None) -> None:
    # Key is PCollection id. Is guarded by the _samplers_lock.
    self._samplers: Dict[str, OutputSampler] = {}
    # Bundles are processed in parallel, so new samplers may be added when the
    # runner queries for samples.
    self._samplers_lock: threading.Lock = threading.Lock()
    self._max_samples = max_samples
    # A period of 0 disables periodic sampling (exception-only mode).
    self._sample_every_sec = 0.0 if sample_only_exceptions else sample_every_sec
    # Maps transform id -> OutputSamplers ordered by output index.
    self._samplers_by_output: Dict[str, List[OutputSampler]] = {}
    self._clock = clock

  # Experiment names controlling whether/what sampling is enabled.
  _ENABLE_DATA_SAMPLING = 'enable_data_sampling'
  _ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING = 'enable_always_on_exception_sampling'
  _DISABLE_ALWAYS_ON_EXCEPTION_SAMPLING = 'disable_always_on_exception_sampling'

  @staticmethod
  def create(sdk_pipeline_options: PipelineOptions, **kwargs):
    """Returns a DataSampler if enabled via experiments, otherwise None."""
    experiments = sdk_pipeline_options.view_as(DebugOptions).experiments or []

    # When true, enables only the sampling of exceptions.
    always_on_exception_sampling = (
        DataSampler._ENABLE_ALWAYS_ON_EXCEPTION_SAMPLING in experiments and
        DataSampler._DISABLE_ALWAYS_ON_EXCEPTION_SAMPLING not in experiments)

    # When true, enables the sampling of all PCollections and exceptions.
    enable_data_sampling = DataSampler._ENABLE_DATA_SAMPLING in experiments

    if enable_data_sampling or always_on_exception_sampling:
      sample_only_exceptions = (
          always_on_exception_sampling and not enable_data_sampling)
      return DataSampler(
          sample_only_exceptions=sample_only_exceptions, **kwargs)
    else:
      return None

  def stop(self) -> None:
    """Stops all sampling, does not clear samplers in case there are
    outstanding samples.
    """
    with self._samplers_lock:
      for sampler in self._samplers.values():
        sampler.stop()

  def sampler_for_output(self, transform_id: str,
                         output_index: int) -> Optional[OutputSampler]:
    """Returns the OutputSampler for the given output, or None if unknown."""
    try:
      with self._samplers_lock:
        outputs = self._samplers_by_output[transform_id]
        return outputs[output_index]
    # KeyError covers an unknown transform id; IndexError covers an
    # out-of-bounds output index on the list (previously uncaught, so a bad
    # index crashed instead of warning).
    except (KeyError, IndexError):
      # Lazy %-style args also fix the former literal '{output_index}' that
      # resulted from a concatenated non-f-string.
      _LOGGER.warning(
          'Out-of-bounds access for transform "%s" and output "%s" '
          'OutputSampler. This may indicate that the transform was '
          'improperly initialized with the DataSampler.',
          transform_id,
          output_index)
      return None

  def initialize_samplers(
      self,
      transform_id: str,
      descriptor: beam_fn_api_pb2.ProcessBundleDescriptor,
      coder_factory) -> List[OutputSampler]:
    """Creates the OutputSamplers for the given PTransform.

    This initializes the samplers only once per PCollection Id. Note that an
    OutputSampler is created per PCollection and an ElementSampler is created
    per OutputSampler. This means that multiple ProcessBundles can and will
    share the same ElementSampler for a given PCollection.
    """
    transform_proto = descriptor.transforms[transform_id]
    with self._samplers_lock:
      if transform_id in self._samplers_by_output:
        return self._samplers_by_output[transform_id]

      # Initialize the samplers.
      for pcoll_id in transform_proto.outputs.values():
        # Only initialize new PCollections.
        if pcoll_id in self._samplers:
          continue

        # Create the sampler with the corresponding coder.
        coder_id = descriptor.pcollections[pcoll_id].coder_id
        coder = coder_factory(coder_id)
        sampler = OutputSampler(
            coder, self._max_samples, self._sample_every_sec)
        self._samplers[pcoll_id] = sampler

      # Next update the lookup table for ElementSamplers for a given
      # PTransform. Operations look up the ElementSampler for an output based
      # on the index of the tag in the PTransform's outputs. The following
      # code initializes the array with ElementSamplers in the correct
      # indices.
      outputs = transform_proto.outputs
      samplers = [self._samplers[pcoll_id] for pcoll_id in outputs.values()]
      self._samplers_by_output[transform_id] = samplers

      return samplers

  def samples(
      self,
      pcollection_ids: Optional[Iterable[str]] = None
  ) -> beam_fn_api_pb2.SampleDataResponse:
    """Returns samples filtered PCollection ids.

    All samples from the given PCollections are returned. Empty lists are
    wildcards.
    """
    ret = beam_fn_api_pb2.SampleDataResponse()

    # Snapshot under the lock; new samplers may be registered concurrently.
    with self._samplers_lock:
      samplers = self._samplers.copy()

    for pcoll_id in samplers:
      if pcollection_ids and pcoll_id not in pcollection_ids:
        continue

      samples = samplers[pcoll_id].flush()
      if samples:
        ret.element_samples[pcoll_id].elements.extend(samples)

    return ret

  def wait_for_samples(
      self, pcollection_ids: List[str]) -> beam_fn_api_pb2.SampleDataResponse:
    """Waits for samples to exist for the given PCollections (only testing)."""
    now = time.time()
    end = now + 30

    samples = beam_fn_api_pb2.SampleDataResponse()
    while now < end:
      time.sleep(0.1)
      now = time.time()

      samples.MergeFrom(self.samples(pcollection_ids))
      # NOTE(review): protobuf messages are generally always truthy, so this
      # guard may never trigger — confirm intent before relying on it.
      if not samples:
        continue

      has_all = all(
          pcoll_id in samples.element_samples for pcoll_id in pcollection_ids)
      if has_all:
        break

    return samples