|
| 1 | +"""Pure-Python FuzzedDataProvider matching the atheris API. |
| 2 | +
|
| 3 | +This is a drop-in replacement for atheris.FuzzedDataProvider that requires |
| 4 | +no native compilation. It matches atheris's consumption semantics: |
| 5 | + - ConsumeBytes/ConsumeInt/ConsumeFloat/ConsumeUnicode consume from the FRONT |
| 6 | + - ConsumeIntInRange/ConsumeBool/PickValueInList consume from the BACK |
| 7 | +
|
| 8 | +Reference: https://github.com/google/atheris |
| 9 | +""" |
| 10 | + |
| 11 | +import struct |
| 12 | + |
| 13 | + |
| 14 | +class FuzzedDataProvider: |
| 15 | + def __init__(self, data): |
| 16 | + if not isinstance(data, (bytes, bytearray)): |
| 17 | + raise TypeError("data must be bytes or bytearray") |
| 18 | + self._data = bytes(data) |
| 19 | + self._front = 0 |
| 20 | + self._back = len(self._data) |
| 21 | + |
| 22 | + def remaining_bytes(self): |
| 23 | + return max(0, self._back - self._front) |
| 24 | + |
| 25 | + def buffer(self): |
| 26 | + return self._data[self._front : self._back] |
| 27 | + |
| 28 | + # -- Front-consuming methods (ConsumeBytes, ConsumeInt, etc.) -- |
| 29 | + |
| 30 | + def _consume_front(self, n): |
| 31 | + n = min(n, self.remaining_bytes()) |
| 32 | + result = self._data[self._front : self._front + n] |
| 33 | + self._front += n |
| 34 | + return result |
| 35 | + |
| 36 | + def ConsumeBytes(self, count): |
| 37 | + count = max(0, int(count)) |
| 38 | + return self._consume_front(count) |
| 39 | + |
| 40 | + def ConsumeInt(self, byte_count): |
| 41 | + byte_count = max(0, int(byte_count)) |
| 42 | + raw = self._consume_front(byte_count) |
| 43 | + if not raw: |
| 44 | + return 0 |
| 45 | + val = int.from_bytes(raw, "little") |
| 46 | + bits = len(raw) * 8 |
| 47 | + if val >= (1 << (bits - 1)): |
| 48 | + val -= 1 << bits |
| 49 | + return val |
| 50 | + |
| 51 | + def ConsumeUInt(self, byte_count): |
| 52 | + byte_count = max(0, int(byte_count)) |
| 53 | + raw = self._consume_front(byte_count) |
| 54 | + if not raw: |
| 55 | + return 0 |
| 56 | + return int.from_bytes(raw, "little") |
| 57 | + |
| 58 | + def ConsumeFloat(self): |
| 59 | + raw = self._consume_front(8) |
| 60 | + if len(raw) < 8: |
| 61 | + raw = raw + b"\x00" * (8 - len(raw)) |
| 62 | + return struct.unpack("<d", raw)[0] |
| 63 | + |
| 64 | + def ConsumeRegularFloat(self): |
| 65 | + val = self.ConsumeFloat() |
| 66 | + if val != val or val == float("inf") or val == float("-inf"): |
| 67 | + return 0.0 |
| 68 | + return val |
| 69 | + |
| 70 | + def ConsumeUnicode(self, count): |
| 71 | + count = max(0, int(count)) |
| 72 | + if count == 0 or self.remaining_bytes() == 0: |
| 73 | + return "" |
| 74 | + # First byte selects encoding mode (matching atheris behavior) |
| 75 | + mode_byte = self._consume_front(1) |
| 76 | + mode = mode_byte[0] if mode_byte else 0 |
| 77 | + if mode == 1: |
| 78 | + # ASCII mode: one byte per character, masked to 0-127 |
| 79 | + raw = self._consume_front(count) |
| 80 | + return "".join(chr(b & 0x7F) for b in raw) |
| 81 | + elif mode == 2: |
| 82 | + # UTF-16 mode: two bytes per character |
| 83 | + raw = self._consume_front(count * 2) |
| 84 | + chars = [] |
| 85 | + for i in range(0, len(raw) - 1, 2): |
| 86 | + cp = int.from_bytes(raw[i : i + 2], "little") |
| 87 | + chars.append(chr(cp)) |
| 88 | + return "".join(chars[:count]) |
| 89 | + else: |
| 90 | + # UTF-32 mode: four bytes per character, clamped to valid range |
| 91 | + raw = self._consume_front(count * 4) |
| 92 | + chars = [] |
| 93 | + for i in range(0, len(raw) - 3, 4): |
| 94 | + cp = int.from_bytes(raw[i : i + 4], "little") & 0x10FFFF |
| 95 | + try: |
| 96 | + chars.append(chr(cp)) |
| 97 | + except (ValueError, OverflowError): |
| 98 | + chars.append(" ") |
| 99 | + return "".join(chars[:count]) |
| 100 | + |
| 101 | + def ConsumeUnicodeNoSurrogates(self, count): |
| 102 | + count = max(0, int(count)) |
| 103 | + if count == 0 or self.remaining_bytes() == 0: |
| 104 | + return "" |
| 105 | + mode_byte = self._consume_front(1) |
| 106 | + mode = mode_byte[0] if mode_byte else 0 |
| 107 | + if mode == 1: |
| 108 | + raw = self._consume_front(count) |
| 109 | + return "".join(chr(b & 0x7F) for b in raw) |
| 110 | + elif mode == 2: |
| 111 | + raw = self._consume_front(count * 2) |
| 112 | + chars = [] |
| 113 | + for i in range(0, len(raw) - 1, 2): |
| 114 | + cp = int.from_bytes(raw[i : i + 2], "little") |
| 115 | + if 0xD800 <= cp <= 0xDFFF: |
| 116 | + cp -= 0xD800 |
| 117 | + chars.append(chr(cp)) |
| 118 | + return "".join(chars[:count]) |
| 119 | + else: |
| 120 | + raw = self._consume_front(count * 4) |
| 121 | + chars = [] |
| 122 | + for i in range(0, len(raw) - 3, 4): |
| 123 | + cp = int.from_bytes(raw[i : i + 4], "little") & 0x10FFFF |
| 124 | + if 0xD800 <= cp <= 0xDFFF: |
| 125 | + cp -= 0xD800 |
| 126 | + try: |
| 127 | + chars.append(chr(cp)) |
| 128 | + except (ValueError, OverflowError): |
| 129 | + chars.append(" ") |
| 130 | + return "".join(chars[:count]) |
| 131 | + |
| 132 | + def ConsumeString(self, count): |
| 133 | + return self.ConsumeUnicode(count) |
| 134 | + |
| 135 | + # -- Back-consuming methods (ConsumeIntInRange, ConsumeBool, etc.) -- |
| 136 | + |
| 137 | + def _consume_back(self, n): |
| 138 | + n = min(n, self.remaining_bytes()) |
| 139 | + result = self._data[self._back - n : self._back] |
| 140 | + self._back -= n |
| 141 | + return result |
| 142 | + |
| 143 | + def ConsumeIntInRange(self, lo, hi): |
| 144 | + lo, hi = int(lo), int(hi) |
| 145 | + if lo > hi: |
| 146 | + lo, hi = hi, lo |
| 147 | + if lo == hi: |
| 148 | + return lo |
| 149 | + rng = hi - lo |
| 150 | + # Match LLVM: consume ceil(bits_needed/8) bytes from back |
| 151 | + # LLVM loops while offset < sizeof(T)*8 && (range >> offset) > 0 |
| 152 | + nbytes = (rng.bit_length() + 7) // 8 |
| 153 | + raw = self._consume_back(nbytes) |
| 154 | + if not raw: |
| 155 | + return lo |
| 156 | + # LLVM reads bytes from back as big-endian accumulation: |
| 157 | + # result = (result << 8) | next_byte_from_back |
| 158 | + # which equals int.from_bytes(reversed_bytes, 'big') |
| 159 | + # But since _consume_back returns bytes in memory order and |
| 160 | + # int.from_bytes(raw, 'little') produces the same value, we use that. |
| 161 | + val = int.from_bytes(raw, "little") |
| 162 | + return lo + (val % (rng + 1)) |
| 163 | + |
| 164 | + # Alias for LLVM naming compatibility |
| 165 | + ConsumeIntegralInRange = ConsumeIntInRange |
| 166 | + |
| 167 | + def ConsumeBool(self): |
| 168 | + # Matches LLVM: 1 & ConsumeIntegral<uint8_t>() |
| 169 | + # ConsumeIntegral<uint8_t>() = ConsumeIntegralInRange(0, 255) |
| 170 | + return (self.ConsumeIntInRange(0, 255) & 1) == 1 |
| 171 | + |
| 172 | + def ConsumeProbability(self): |
| 173 | + # Matches LLVM: ConsumeIntegral<uint64_t>() / UINT64_MAX |
| 174 | + # ConsumeIntegral<uint64_t>() = ConsumeIntegralInRange(0, 2^64-1) |
| 175 | + # When range == UINT64_MAX, no modulo is applied (special case) |
| 176 | + raw = self._consume_back(8) |
| 177 | + if not raw: |
| 178 | + return 0.0 |
| 179 | + val = int.from_bytes(raw, "little") |
| 180 | + return val / float((1 << 64) - 1) |
| 181 | + |
| 182 | + def ConsumeFloatInRange(self, lo, hi): |
| 183 | + lo, hi = float(lo), float(hi) |
| 184 | + if lo > hi: |
| 185 | + lo, hi = hi, lo |
| 186 | + p = self.ConsumeProbability() |
| 187 | + return lo + (hi - lo) * p |
| 188 | + |
| 189 | + def PickValueInList(self, lst): |
| 190 | + if not lst: |
| 191 | + raise ValueError("list must not be empty") |
| 192 | + idx = self.ConsumeIntInRange(0, len(lst) - 1) |
| 193 | + return lst[idx] |
| 194 | + |
| 195 | + # -- List methods -- |
| 196 | + |
| 197 | + def ConsumeIntList(self, count, byte_count): |
| 198 | + count = max(0, int(count)) |
| 199 | + return [self.ConsumeInt(byte_count) for _ in range(count)] |
| 200 | + |
| 201 | + def ConsumeIntListInRange(self, count, lo, hi): |
| 202 | + count = max(0, int(count)) |
| 203 | + return [self.ConsumeIntInRange(lo, hi) for _ in range(count)] |
| 204 | + |
| 205 | + def ConsumeFloatList(self, count): |
| 206 | + count = max(0, int(count)) |
| 207 | + return [self.ConsumeFloat() for _ in range(count)] |
| 208 | + |
| 209 | + def ConsumeFloatListInRange(self, count, lo, hi): |
| 210 | + count = max(0, int(count)) |
| 211 | + return [self.ConsumeFloatInRange(lo, hi) for _ in range(count)] |
| 212 | + |
| 213 | + def ConsumeProbabilityList(self, count): |
| 214 | + count = max(0, int(count)) |
| 215 | + return [self.ConsumeProbability() for _ in range(count)] |
| 216 | + |
| 217 | + def ConsumeRegularFloatList(self, count): |
| 218 | + count = max(0, int(count)) |
| 219 | + return [self.ConsumeRegularFloat() for _ in range(count)] |
| 220 | + |
| 221 | + # -- Arbitrary value -- |
| 222 | + |
| 223 | + _ANY_TYPE_INT = 0 |
| 224 | + _ANY_TYPE_FLOAT = 1 |
| 225 | + _ANY_TYPE_BOOL = 2 |
| 226 | + _ANY_TYPE_BYTES = 3 |
| 227 | + _ANY_TYPE_STRING = 4 |
| 228 | + _ANY_TYPE_NONE = 5 |
| 229 | + |
| 230 | + def ConsumeRandomValue(self): |
| 231 | + """Return a value of a randomly chosen primitive type.""" |
| 232 | + t = self.ConsumeIntInRange(self._ANY_TYPE_INT, self._ANY_TYPE_NONE) |
| 233 | + if t == self._ANY_TYPE_INT: |
| 234 | + return self.ConsumeInt(4) |
| 235 | + elif t == self._ANY_TYPE_FLOAT: |
| 236 | + return self.ConsumeFloat() |
| 237 | + elif t == self._ANY_TYPE_BOOL: |
| 238 | + return self.ConsumeBool() |
| 239 | + elif t == self._ANY_TYPE_BYTES: |
| 240 | + return self.ConsumeBytes(self.ConsumeIntInRange(0, 64)) |
| 241 | + elif t == self._ANY_TYPE_STRING: |
| 242 | + return self.ConsumeUnicode(self.ConsumeIntInRange(0, 64)) |
| 243 | + else: |
| 244 | + return None |
0 commit comments