Skip to content

Commit dc49ef7

Browse files
authored
Add FuzzedDataProvider utility for structured fuzzing
1 parent 7344b52 commit dc49ef7

1 file changed

Lines changed: 244 additions & 0 deletions

File tree

fuzzeddataprovider.py

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
"""Pure-Python FuzzedDataProvider matching the atheris API.
2+
3+
This is a drop-in replacement for atheris.FuzzedDataProvider that requires
4+
no native compilation. It matches atheris's consumption semantics:
5+
- ConsumeBytes/ConsumeInt/ConsumeFloat/ConsumeUnicode consume from the FRONT
6+
- ConsumeIntInRange/ConsumeBool/PickValueInList consume from the BACK
7+
8+
Reference: https://github.com/google/atheris
9+
"""
10+
11+
import struct
12+
13+
14+
class FuzzedDataProvider:
    """Carve typed values out of a fuzzer-supplied byte string.

    The input is treated as a double-ended buffer.  Raw/fixed-width
    consumers (bytes, ints, floats, unicode) take bytes from the front;
    range-style consumers (int-in-range, bool, probability, list picks)
    take bytes from the back.  Every consumer tolerates an exhausted
    buffer and falls back to a deterministic default instead of raising.
    """

    def __init__(self, data):
        """Wrap *data* for consumption.

        Args:
            data: ``bytes`` or ``bytearray``; a private immutable copy is kept.

        Raises:
            TypeError: if *data* is neither ``bytes`` nor ``bytearray``.
        """
        if not isinstance(data, (bytes, bytearray)):
            raise TypeError("data must be bytes or bytearray")
        self._data = bytes(data)
        # Unconsumed bytes are self._data[self._front:self._back].
        self._front = 0
        self._back = len(self._data)

    def remaining_bytes(self):
        """Return the number of bytes not yet consumed (never negative)."""
        return max(0, self._back - self._front)

    def buffer(self):
        """Return the unconsumed bytes without consuming them."""
        return self._data[self._front : self._back]

    # -- Front-consuming methods (ConsumeBytes, ConsumeInt, etc.) --

    def _consume_front(self, n):
        """Take up to *n* bytes from the front; returns fewer if exhausted."""
        n = min(n, self.remaining_bytes())
        result = self._data[self._front : self._front + n]
        self._front += n
        return result

    def ConsumeBytes(self, count):
        """Return up to *count* raw bytes from the front (negative -> 0)."""
        count = max(0, int(count))
        return self._consume_front(count)

    def ConsumeInt(self, byte_count):
        """Return a signed little-endian integer from the front.

        Up to *byte_count* bytes are read.  If fewer are available, the sign
        bit is taken from the bytes actually read.  Returns 0 when nothing
        could be read.
        """
        byte_count = max(0, int(byte_count))
        raw = self._consume_front(byte_count)
        if not raw:
            return 0
        val = int.from_bytes(raw, "little")
        bits = len(raw) * 8
        # Two's-complement reinterpretation over the width actually read.
        if val >= (1 << (bits - 1)):
            val -= 1 << bits
        return val

    def ConsumeUInt(self, byte_count):
        """Return an unsigned little-endian integer from the front (0 if empty)."""
        byte_count = max(0, int(byte_count))
        raw = self._consume_front(byte_count)
        if not raw:
            return 0
        return int.from_bytes(raw, "little")

    def ConsumeFloat(self):
        """Return a float decoded from 8 front bytes as a little-endian double.

        A short read is right-padded with zero bytes, so an exhausted buffer
        yields 0.0.  The result may be NaN or infinite.
        """
        raw = self._consume_front(8)
        if len(raw) < 8:
            raw = raw + b"\x00" * (8 - len(raw))
        return struct.unpack("<d", raw)[0]

    def ConsumeRegularFloat(self):
        """Like ConsumeFloat, but NaN and +/-infinity are replaced by 0.0."""
        val = self.ConsumeFloat()
        # `val != val` is the NaN test.
        if val != val or val == float("inf") or val == float("-inf"):
            return 0.0
        return val

    def _consume_unicode(self, count, filter_surrogates):
        """Shared implementation of the ConsumeUnicode* methods.

        The first front byte selects the encoding: 1 -> 7-bit ASCII,
        2 -> 16-bit units, anything else -> 32-bit units.  At most *count*
        characters are produced; trailing bytes that do not fill a whole
        unit are consumed but dropped.

        Args:
            count: maximum number of characters (negative -> 0).
            filter_surrogates: when true, code points in U+D800..U+DFFF are
                shifted down by 0xD800 so the result has no surrogates.
        """
        count = max(0, int(count))
        if count == 0 or self.remaining_bytes() == 0:
            return ""
        # remaining_bytes() > 0 was just checked, so one byte is available.
        mode = self._consume_front(1)[0]
        if mode == 1:
            # ASCII mode: one byte per character, masked to 0-127.
            return "".join(chr(b & 0x7F) for b in self._consume_front(count))
        if mode == 2:
            width, mask = 2, 0xFFFF
        else:
            # NOTE(review): 0x10FFFF is applied as a bit mask, not a clamp.
            # It keeps the result <= U+10FFFF (so chr() cannot fail) but it
            # also clears bits 16-19, so planes 1-15 are never produced.
            width, mask = 4, 0x10FFFF
        raw = self._consume_front(count * width)
        chars = []
        # Only whole units are decoded; raw holds at most `count` of them.
        for i in range(0, len(raw) - width + 1, width):
            cp = int.from_bytes(raw[i : i + width], "little") & mask
            if filter_surrogates and 0xD800 <= cp <= 0xDFFF:
                cp -= 0xD800
            chars.append(chr(cp))
        return "".join(chars)

    def ConsumeUnicode(self, count):
        """Return a string of up to *count* characters (may hold surrogates)."""
        return self._consume_unicode(count, False)

    def ConsumeUnicodeNoSurrogates(self, count):
        """Return a string of up to *count* characters with no surrogates."""
        return self._consume_unicode(count, True)

    def ConsumeString(self, count):
        """Alias for ConsumeUnicode."""
        return self.ConsumeUnicode(count)

    # -- Back-consuming methods (ConsumeIntInRange, ConsumeBool, etc.) --

    def _consume_back(self, n):
        """Take up to *n* bytes from the back, returned in memory order."""
        n = min(n, self.remaining_bytes())
        result = self._data[self._back - n : self._back]
        self._back -= n
        return result

    def ConsumeIntInRange(self, lo, hi):
        """Return an integer in ``[lo, hi]`` using bytes from the back.

        Swapped bounds are reordered.  Nothing is consumed when ``lo == hi``.
        Returns *lo* when the buffer is exhausted.
        """
        lo, hi = int(lo), int(hi)
        if lo > hi:
            lo, hi = hi, lo
        if lo == hi:
            return lo
        rng = hi - lo
        # Consume just enough bytes to cover the range: ceil(bits / 8).
        nbytes = (rng.bit_length() + 7) // 8
        raw = self._consume_back(nbytes)
        if not raw:
            return lo
        # The LLVM provider accumulates `result = (result << 8) | byte` while
        # walking backwards, making the last byte the most significant.  With
        # raw in memory order that is exactly a little-endian decode.
        val = int.from_bytes(raw, "little")
        return lo + (val % (rng + 1))

    # Alias for LLVM naming compatibility
    ConsumeIntegralInRange = ConsumeIntInRange

    def ConsumeBool(self):
        """Return a bool taken from the low bit of one back byte."""
        return (self.ConsumeIntInRange(0, 255) & 1) == 1

    def ConsumeProbability(self):
        """Return a float in ``[0.0, 1.0]`` from up to 8 back bytes.

        The bytes are read as an unsigned integer and divided by 2**64 - 1.
        Returns 0.0 when the buffer is exhausted.
        """
        raw = self._consume_back(8)
        if not raw:
            return 0.0
        val = int.from_bytes(raw, "little")
        return val / float((1 << 64) - 1)

    def ConsumeFloatInRange(self, lo, hi):
        """Return a float in ``[lo, hi]`` using bytes from the back.

        Swapped bounds are reordered.  When ``hi - lo`` would overflow a
        double, the span is halved and one extra bool is consumed to choose
        the lower or upper half (as LLVM's ConsumeFloatingPointInRange does),
        so the result stays finite and inside the range instead of being
        ``inf``/``nan``.
        """
        import sys  # local import: only needed for the overflow guard

        lo, hi = float(lo), float(hi)
        if lo > hi:
            lo, hi = hi, lo
        base = lo
        if hi > 0.0 and lo < 0.0 and hi > lo + sys.float_info.max:
            # hi - lo is not representable; work on half the span.
            span = (hi / 2.0) - (lo / 2.0)
            if self.ConsumeBool():
                base += span
        else:
            span = hi - lo
        return base + span * self.ConsumeProbability()

    def PickValueInList(self, lst):
        """Return one element of *lst*, chosen by bytes from the back.

        Raises:
            ValueError: if *lst* is empty.
        """
        if not lst:
            raise ValueError("list must not be empty")
        idx = self.ConsumeIntInRange(0, len(lst) - 1)
        return lst[idx]

    # -- List methods --

    def ConsumeIntList(self, count, byte_count):
        """Return *count* results of ``ConsumeInt(byte_count)``."""
        count = max(0, int(count))
        return [self.ConsumeInt(byte_count) for _ in range(count)]

    def ConsumeIntListInRange(self, count, lo, hi):
        """Return *count* results of ``ConsumeIntInRange(lo, hi)``."""
        count = max(0, int(count))
        return [self.ConsumeIntInRange(lo, hi) for _ in range(count)]

    def ConsumeFloatList(self, count):
        """Return *count* results of ``ConsumeFloat()``."""
        count = max(0, int(count))
        return [self.ConsumeFloat() for _ in range(count)]

    def ConsumeFloatListInRange(self, count, lo, hi):
        """Return *count* results of ``ConsumeFloatInRange(lo, hi)``."""
        count = max(0, int(count))
        return [self.ConsumeFloatInRange(lo, hi) for _ in range(count)]

    def ConsumeProbabilityList(self, count):
        """Return *count* results of ``ConsumeProbability()``."""
        count = max(0, int(count))
        return [self.ConsumeProbability() for _ in range(count)]

    def ConsumeRegularFloatList(self, count):
        """Return *count* results of ``ConsumeRegularFloat()``."""
        count = max(0, int(count))
        return [self.ConsumeRegularFloat() for _ in range(count)]

    # -- Arbitrary value --

    # Type selectors used by ConsumeRandomValue.
    _ANY_TYPE_INT = 0
    _ANY_TYPE_FLOAT = 1
    _ANY_TYPE_BOOL = 2
    _ANY_TYPE_BYTES = 3
    _ANY_TYPE_STRING = 4
    _ANY_TYPE_NONE = 5

    def ConsumeRandomValue(self):
        """Return a value of a randomly chosen primitive type.

        A selector is drawn from the back first; the value is then an int
        (4 bytes), float, bool, bytes (length 0-64), str (length 0-64) or
        None.
        """
        t = self.ConsumeIntInRange(self._ANY_TYPE_INT, self._ANY_TYPE_NONE)
        if t == self._ANY_TYPE_INT:
            return self.ConsumeInt(4)
        elif t == self._ANY_TYPE_FLOAT:
            return self.ConsumeFloat()
        elif t == self._ANY_TYPE_BOOL:
            return self.ConsumeBool()
        elif t == self._ANY_TYPE_BYTES:
            return self.ConsumeBytes(self.ConsumeIntInRange(0, 64))
        elif t == self._ANY_TYPE_STRING:
            return self.ConsumeUnicode(self.ConsumeIntInRange(0, 64))
        else:
            return None

0 commit comments

Comments
 (0)