From 35168b78d395a161ca9e1ba4053108a5ac30f45e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 22 Apr 2026 11:54:41 +0100 Subject: [PATCH] Batch SCCs for parallel processing --- mypy/build.py | 104 +++++++++++++++++++++++++++--------- mypy/build_worker/worker.py | 44 ++++++++------- 2 files changed, 103 insertions(+), 45 deletions(-) diff --git a/mypy/build.py b/mypy/build.py index 204451e2fa4d..275131d9db31 100644 --- a/mypy/build.py +++ b/mypy/build.py @@ -74,14 +74,12 @@ read_bytes, read_int, read_int_list, - read_int_opt, read_str, read_str_list, read_str_opt, write_bytes, write_int, write_int_list, - write_int_opt, write_json_value, write_str, write_str_list, @@ -212,6 +210,10 @@ "https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules" ) +# Padding when estimating how much time it will take to process a file. This is to avoid +# situations where 100 empty __init__.py files cost less than 1 trivial module. +MIN_SIZE_HINT: Final = 256 + class SCC: """A simple class that represents a strongly connected component (import cycle).""" @@ -240,7 +242,7 @@ def __init__( self.direct_dependents: list[int] = [] # Rough estimate of how much time processing this SCC will take, this # is used for more efficient scheduling across multiple build workers. - self.size_hint: int = 0 + self.size_hint: int = MIN_SIZE_HINT # TODO: Get rid of BuildResult. We might as well return a BuildManager. @@ -450,7 +452,7 @@ def connect(wc: WorkerClient, data: bytes) -> None: if not worker.connected: continue try: - send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={})) + send(worker.conn, SccRequestMessage(scc_ids=[], import_errors={}, mod_data={})) except (OSError, IPCException): pass for worker in workers: @@ -968,6 +970,8 @@ def __init__( # Stale SCCs that are queued for processing. Each tuple contains SCC size hint, # SCC adding order (tie-breaker), and the SCC itself. 
self.scc_queue: list[tuple[int, int, SCC]] = [] + # Total size hint for SCCs currently in queue. + self.size_in_queue: int = 0 # SCCs that have been fully processed. self.done_sccs: set[int] = set() # Parallel build workers, list is empty for in-process type-checking. @@ -1097,6 +1101,9 @@ def parse_parallel(self, sequential_states: list[State], parallel_states: list[S state.semantic_analysis_pass1() self.ast_cache[state.id] = (state.tree, state.early_errors) self.modules[state.id] = state.tree + assert state.tree is not None + if state.tree.raw_data is not None: + state.size_hint = len(state.tree.raw_data.defs) + MIN_SIZE_HINT state.check_blockers() state.setup_errors() @@ -1333,7 +1340,11 @@ def receive_worker_message(self, idx: int) -> ReadBuffer: try: return receive(self.workers[idx].conn) except OSError as exc: - exit_code = self.workers[idx].proc.poll() + try: + # Give worker process a chance to actually terminate before reporting. + exit_code = self.workers[idx].proc.wait(timeout=WORKER_SHUTDOWN_TIMEOUT) + except (TimeoutError, subprocess.TimeoutExpired): + exit_code = None exit_status = f"exit code {exit_code}" if exit_code is not None else "still running" raise OSError( f"Worker {idx} disconnected before sending data ({exit_status})" @@ -1346,16 +1357,49 @@ def submit(self, graph: Graph, sccs: list[SCC]) -> None: else: self.scc_queue.extend([(0, 0, scc) for scc in sccs]) + def get_scc_batch(self, max_size_in_batch: int) -> list[SCC]: + """Get a batch of SCCs from queue to submit to a worker. + + We batch SCCs to avoid communication overhead, but to avoid + long poles, we limit the fraction of work per worker. + """ + batch: list[SCC] = [] + size_in_batch = 0 + while self.scc_queue and ( + # Three notes to keep in mind here: + # * Heap key is *negative* size (so that larger SCCs appear first). + # * Each batch must have at least one item. + # * Adding another SCC to batch should not exceed maximum allowed size. 
+ size_in_batch - self.scc_queue[0][0] <= max_size_in_batch + or not batch + ): + size_key, _, scc = heappop(self.scc_queue) + size_in_batch -= size_key + self.size_in_queue += size_key + batch.append(scc) + return batch + + def max_batch_size(self) -> int: + batch_frac = 1 / len(self.workers) + if sys.platform == "linux": + # Linux is good with socket roundtrip latency, so we can use + # more fine-grained batches. + batch_frac /= 2 + return int(self.size_in_queue * batch_frac) + def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None: if sccs is not None: for scc in sccs: heappush(self.scc_queue, (-scc.size_hint, self.queue_order, scc)) + self.size_in_queue += scc.size_hint self.queue_order += 1 + max_size_in_batch = self.max_batch_size() while self.scc_queue and self.free_workers: idx = self.free_workers.pop() - _, _, scc = heappop(self.scc_queue) + scc_batch = self.get_scc_batch(max_size_in_batch) import_errors = { mod_id: self.errors.recorded[path] + for scc in scc_batch for mod_id in scc.mod_ids if (path := graph[mod_id].xpath) in self.errors.recorded } @@ -1363,7 +1407,7 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None send( self.workers[idx].conn, SccRequestMessage( - scc_id=scc.id, + scc_ids=[scc.id for scc in scc_batch], import_errors=import_errors, mod_data={ mod_id: ( @@ -1373,11 +1417,12 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None graph[mod_id].suppressed_deps_opts(), tree.raw_data if (tree := graph[mod_id].tree) else None, ) + for scc in scc_batch for mod_id in scc.mod_ids }, ), ) - self.add_stats(scc_send_time=time.time() - t0) + self.add_stats(scc_requests_sent=1, scc_send_time=time.time() - t0) def wait_for_done(self, graph: Graph) -> tuple[list[SCC], bool, dict[str, ModuleResult]]: """Wait for a stale SCC processing to finish. 
@@ -1414,13 +1459,12 @@ def wait_for_done_workers( if not data.is_interface: # Mark worker as free after it finished checking implementation. self.free_workers.add(idx) - scc_id = data.scc_id if data.blocker is not None: raise data.blocker assert data.result is not None results.update(data.result) if data.is_interface: - done_sccs.append(self.scc_by_id[scc_id]) + done_sccs.extend([self.scc_by_id[scc_id] for scc_id in data.scc_ids]) self.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1) self.submit_to_workers(graph) # advance after some workers are free. return ( @@ -2727,7 +2771,7 @@ def new_state( # import pkg.mod if exist_removed_submodules(dependencies, manager): state.needs_parse = True # Same as above, the current state is stale anyway. - state.size_hint = meta.size + state.size_hint = meta.size + MIN_SIZE_HINT else: # When doing a fine-grained cache load, pretend we only # know about modules that have cache information and defer @@ -3092,7 +3136,7 @@ def get_source(self) -> str: self.parse_inline_configuration(source) self.check_for_invalid_options() - self.size_hint = len(source) + self.size_hint = len(source) + MIN_SIZE_HINT self.time_spent_us += time_spent_us(t0) return source @@ -3157,6 +3201,14 @@ def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None = self.check_blockers() manager.ast_cache[self.id] = (self.tree, self.early_errors) + assert self.tree is not None + if self.tree.raw_data is not None: + # Size of serialized tree is a better proxy for file complexity than + # file size, so we use that when possible. Note that we rely on lucky + # coincidence that serialized tree size has same order of magnitude as + # file size, so we don't need any normalization factor in situations + # where parsed and cached files are mixed. 
+ self.size_hint = len(self.tree.raw_data.defs) + MIN_SIZE_HINT self.setup_errors() def setup_errors(self) -> None: @@ -5054,26 +5106,26 @@ def write(self, buf: WriteBuffer) -> None: class SccRequestMessage(IPCMessage): """ - A message representing a request to type check an SCC. + A message representing a request to type check a batch of SCCs. - If scc_id is None, then it means that the coordinator requested a shutdown. + If scc_ids is empty, then it means that the coordinator requested a shutdown. """ def __init__( self, *, - scc_id: int | None, + scc_ids: list[int], import_errors: dict[str, list[ErrorInfo]], mod_data: dict[str, tuple[bytes, FileRawData | None]], ) -> None: - self.scc_id = scc_id + self.scc_ids = scc_ids self.import_errors = import_errors self.mod_data = mod_data @classmethod def read(cls, buf: ReadBuffer) -> SccRequestMessage: return SccRequestMessage( - scc_id=read_int_opt(buf), + scc_ids=read_int_list(buf), import_errors={ read_str(buf): [ErrorInfo.read(buf) for _ in range(read_int_bare(buf))] for _ in range(read_int_bare(buf)) @@ -5089,7 +5141,7 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage: def write(self, buf: WriteBuffer) -> None: write_tag(buf, SCC_REQUEST_MESSAGE) - write_int_opt(buf, self.scc_id) + write_int_list(buf, self.scc_ids) write_int_bare(buf, len(self.import_errors)) for path, errors in self.import_errors.items(): write_str(buf, path) @@ -5130,17 +5182,17 @@ def write(self, buf: WriteBuffer) -> None: class SccResponseMessage(IPCMessage): """ - A message representing a result of type checking an SCC. + A message representing a result of type checking a batch of SCCs. Only one of `result` or `blocker` can be non-None. The latter means there was - a blocking error while type checking the SCC. The `is_interface` flag indicates + a blocking error while type checking the SCCs. The `is_interface` flag indicates whether this is a result for interface or implementation phase of type-checking. 
""" def __init__( self, *, - scc_id: int, + scc_ids: list[int], is_interface: bool, result: dict[str, ModuleResult] | None = None, blocker: CompileError | None = None, @@ -5149,26 +5201,26 @@ def __init__( assert blocker is None if blocker is not None: assert result is None - self.scc_id = scc_id + self.scc_ids = scc_ids self.is_interface = is_interface self.result = result self.blocker = blocker @classmethod def read(cls, buf: ReadBuffer) -> SccResponseMessage: - scc_id = read_int(buf) + scc_ids = read_int_list(buf) is_interface = read_bool(buf) tag = read_tag(buf) if tag == LITERAL_NONE: return SccResponseMessage( - scc_id=scc_id, + scc_ids=scc_ids, is_interface=is_interface, blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)), ) else: assert tag == DICT_STR_GEN return SccResponseMessage( - scc_id=scc_id, + scc_ids=scc_ids, is_interface=is_interface, result={ read_str_bare(buf): ModuleResult.read(buf) for _ in range(read_int_bare(buf)) @@ -5177,7 +5229,7 @@ def read(cls, buf: ReadBuffer) -> SccResponseMessage: def write(self, buf: WriteBuffer) -> None: write_tag(buf, SCC_RESPONSE_MESSAGE) - write_int(buf, self.scc_id) + write_int_list(buf, self.scc_ids) write_bool(buf, self.is_interface) if self.result is None: assert self.blocker is not None diff --git a/mypy/build_worker/worker.py b/mypy/build_worker/worker.py index bd71ffc68357..529fd67516e6 100644 --- a/mypy/build_worker/worker.py +++ b/mypy/build_worker/worker.py @@ -32,7 +32,6 @@ from mypy import util from mypy.build import ( GRAPH_MESSAGE, - SCC, SCC_REQUEST_MESSAGE, SCCS_DATA_MESSAGE, SOURCES_DATA_MESSAGE, @@ -48,7 +47,7 @@ process_stale_scc_implementation, process_stale_scc_interface, ) -from mypy.cache import Tag, read_int_opt +from mypy.cache import Tag, read_int_list from mypy.defaults import RECURSION_LIMIT, WORKER_CONNECTION_TIMEOUT, WORKER_IDLE_TIMEOUT from mypy.errors import CompileError, ErrorInfo, Errors, report_internal_error from mypy.fscache import FileSystemCache @@ 
-127,7 +126,7 @@ def should_shutdown(buf: ReadBuffer, expected_tag: Tag) -> bool: """Check if the message is a shutdown request.""" tag = read_tag(buf) if tag == SCC_REQUEST_MESSAGE: - assert read_int_opt(buf) is None + assert not read_int_list(buf) return True assert tag == expected_tag, f"Unexpected tag: {tag}" return False @@ -181,8 +180,8 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: assert read_tag(buf) == SCC_REQUEST_MESSAGE scc_message = SccRequestMessage.read(buf) manager.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1) - scc_id = scc_message.scc_id - if scc_id is None: + scc_ids = scc_message.scc_ids + if not scc_ids: # This indicates a shutdown request. Add GC stats before exiting. gc_stats = gc.get_stats() manager.add_stats( @@ -191,17 +190,23 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: ) manager.dump_stats() break - scc = manager.scc_by_id[scc_id] + sccs = [manager.scc_by_id[scc_id] for scc_id in scc_ids] + mod_ids: list[str] = [] + for scc in sccs: + mod_ids.extend(scc.mod_ids) t0 = time.time() try: - load_states(scc, graph, manager, scc_message.import_errors, scc_message.mod_data) - result = process_stale_scc_interface( - graph, scc, manager, from_cache=graph_data.from_cache - ) - # We must commit after each SCC, otherwise we break --sqlite-cache. - manager.commit() + load_states(mod_ids, graph, manager, scc_message.import_errors, scc_message.mod_data) + result = [] + for scc in sccs: + scc_result = process_stale_scc_interface( + graph, scc, manager, from_cache=graph_data.from_cache + ) + result.extend(scc_result) + # We must commit after each SCC, otherwise we break --sqlite-cache. 
+ manager.commit() except CompileError as blocker: - message = SccResponseMessage(scc_id=scc_id, is_interface=True, blocker=blocker) + message = SccResponseMessage(scc_ids=scc_ids, is_interface=True, blocker=blocker) timed_send(manager, server, message) else: mod_results = {} @@ -211,16 +216,17 @@ def serve(server: IPCServer, ctx: ServerContext) -> None: stale.append(id) mod_results[id] = mod_result meta_files.append(meta_file) - message = SccResponseMessage(scc_id=scc_id, is_interface=True, result=mod_results) + message = SccResponseMessage(scc_ids=scc_ids, is_interface=True, result=mod_results) timed_send(manager, server, message) try: + # We process implementations from all SCCs in the batch together. result = process_stale_scc_implementation(graph, stale, manager, meta_files) # Both phases write cache, so we should commit here as well. manager.commit() except CompileError as blocker: - message = SccResponseMessage(scc_id=scc_id, is_interface=False, blocker=blocker) + message = SccResponseMessage(scc_ids=scc_ids, is_interface=False, blocker=blocker) else: - message = SccResponseMessage(scc_id=scc_id, is_interface=False, result=result) + message = SccResponseMessage(scc_ids=scc_ids, is_interface=False, result=result) timed_send(manager, server, message) manager.add_stats(total_process_stale_time=time.time() - t0, stale_sccs_processed=1) @@ -232,7 +238,7 @@ def timed_send(manager: BuildManager, server: IPCServer, message: SccResponseMes def load_states( - scc: SCC, + mod_ids: list[str], graph: Graph, manager: BuildManager, import_errors: dict[str, list[ErrorInfo]], @@ -240,7 +246,7 @@ def load_states( ) -> None: """Re-create full state of an SCC as it would have been in coordinator.""" needs_parse = [] - for id in scc.mod_ids: + for id in mod_ids: state = graph[id] # Re-clone options since we don't send them, it is usually faster than deserializing. 
state.options = state.options.clone_for_module(state.id) @@ -254,7 +260,7 @@ def load_states( # Perform actual parsing in parallel (but we don't need to compute dependencies). if needs_parse: manager.parse_all(needs_parse, post_parse=False) - for id in scc.mod_ids: + for id in mod_ids: state = graph[id] assert state.tree is not None import_lines = {imp.line for imp in state.tree.imports}