Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 77 additions & 26 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,12 @@
read_bytes,
read_int,
read_int_list,
read_int_opt,
read_str,
read_str_list,
read_str_opt,
write_bytes,
write_int,
write_int_list,
write_int_opt,
write_json_value,
write_str,
write_str_list,
Expand Down Expand Up @@ -212,6 +210,10 @@
"https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules"
)

# Padding when estimating how much time it will take to process a file. This is to avoid
# situations where 100 empty __init__.py files cost less than 1 trivial module.
MIN_SIZE_HINT: Final = 256


class SCC:
"""A simple class that represents a strongly connected component (import cycle)."""
Expand Down Expand Up @@ -240,7 +242,7 @@ def __init__(
self.direct_dependents: list[int] = []
# Rough estimate of how much time processing this SCC will take, this
# is used for more efficient scheduling across multiple build workers.
self.size_hint: int = 0
self.size_hint: int = MIN_SIZE_HINT


# TODO: Get rid of BuildResult. We might as well return a BuildManager.
Expand Down Expand Up @@ -450,7 +452,7 @@ def connect(wc: WorkerClient, data: bytes) -> None:
if not worker.connected:
continue
try:
send(worker.conn, SccRequestMessage(scc_id=None, import_errors={}, mod_data={}))
send(worker.conn, SccRequestMessage(scc_ids=[], import_errors={}, mod_data={}))
except (OSError, IPCException):
pass
for worker in workers:
Expand Down Expand Up @@ -968,6 +970,8 @@ def __init__(
# Stale SCCs that are queued for processing. Each tuple contains SCC size hint,
# SCC adding order (tie-breaker), and the SCC itself.
self.scc_queue: list[tuple[int, int, SCC]] = []
# Total size hint for SCCs currently in queue.
self.size_in_queue: int = 0
# SCCs that have been fully processed.
self.done_sccs: set[int] = set()
# Parallel build workers, list is empty for in-process type-checking.
Expand Down Expand Up @@ -1081,6 +1085,8 @@ def parse_parallel(self, sequential_states: list[State], parallel_states: list[S
state.semantic_analysis_pass1()
self.ast_cache[state.id] = (state.tree, state.early_errors, state.source_hash)
self.modules[state.id] = state.tree
if state.tree.raw_data is not None:
state.size_hint = len(state.tree.raw_data.defs) + MIN_SIZE_HINT
state.check_blockers()
state.setup_errors()

Expand Down Expand Up @@ -1362,7 +1368,11 @@ def receive_worker_message(self, idx: int) -> ReadBuffer:
try:
return receive(self.workers[idx].conn)
except OSError as exc:
exit_code = self.workers[idx].proc.poll()
try:
# Give worker process a chance to actually terminate before reporting.
exit_code = self.workers[idx].proc.wait(timeout=WORKER_SHUTDOWN_TIMEOUT)
except TimeoutError:
exit_code = None
exit_status = f"exit code {exit_code}" if exit_code is not None else "still running"
raise OSError(
f"Worker {idx} disconnected before sending data ({exit_status})"
Expand All @@ -1375,24 +1385,57 @@ def submit(self, graph: Graph, sccs: list[SCC]) -> None:
else:
self.scc_queue.extend([(0, 0, scc) for scc in sccs])

def get_scc_batch(self, max_size_in_batch: int) -> list[SCC]:
    """Pop a batch of SCCs off the priority queue for one worker.

    Batching amortizes per-message communication overhead, while the
    size cap keeps any single worker from grabbing too large a share
    of the remaining work (avoiding long poles).
    """
    batch: list[SCC] = []
    total = 0
    while self.scc_queue:
        # Heap entries are keyed by *negative* size hint, so the largest
        # SCC is always at the front and -key is its (positive) size.
        next_size = -self.scc_queue[0][0]
        # Always take at least one SCC; otherwise stop before the batch
        # would exceed the allowed maximum size.
        if batch and total + next_size > max_size_in_batch:
            break
        neg_size, _order, scc = heappop(self.scc_queue)
        total -= neg_size
        # Keep the running total of queued work in sync with the heap.
        self.size_in_queue += neg_size
        batch.append(scc)
    return batch

def max_batch_size(self) -> int:
    """Return the maximum total size hint allowed in one worker batch.

    The cap is an even per-worker fraction of the work currently queued;
    on Linux the fraction is halved, since socket round-trip latency is
    low there and finer-grained batches schedule better.
    """
    per_worker = 1 / len(self.workers)
    # Linux is good with socket roundtrip latency, so we can use
    # more fine-grained batches.
    frac = per_worker / 2 if sys.platform == "linux" else per_worker
    return int(self.size_in_queue * frac)

def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None:
if sccs is not None:
for scc in sccs:
heappush(self.scc_queue, (-scc.size_hint, self.queue_order, scc))
self.size_in_queue += scc.size_hint
self.queue_order += 1
max_size_in_batch = self.max_batch_size()
while self.scc_queue and self.free_workers:
idx = self.free_workers.pop()
_, _, scc = heappop(self.scc_queue)
scc_batch = self.get_scc_batch(max_size_in_batch)
import_errors = {
mod_id: self.errors.recorded[path]
for scc in scc_batch
for mod_id in scc.mod_ids
if (path := graph[mod_id].xpath) in self.errors.recorded
}
t0 = time.time()
send(
self.workers[idx].conn,
SccRequestMessage(
scc_id=scc.id,
scc_ids=[scc.id for scc in scc_batch],
import_errors=import_errors,
mod_data={
mod_id: (
Expand All @@ -1402,11 +1445,12 @@ def submit_to_workers(self, graph: Graph, sccs: list[SCC] | None = None) -> None
graph[mod_id].suppressed_deps_opts(),
tree.raw_data if (tree := graph[mod_id].tree) else None,
)
for scc in scc_batch
for mod_id in scc.mod_ids
},
),
)
self.add_stats(scc_send_time=time.time() - t0)
self.add_stats(scc_requests_sent=1, scc_send_time=time.time() - t0)

def wait_for_done(self, graph: Graph) -> tuple[list[SCC], bool, dict[str, ModuleResult]]:
"""Wait for a stale SCC processing to finish.
Expand Down Expand Up @@ -1443,13 +1487,12 @@ def wait_for_done_workers(
if not data.is_interface:
# Mark worker as free after it finished checking implementation.
self.free_workers.add(idx)
scc_id = data.scc_id
if data.blocker is not None:
raise data.blocker
assert data.result is not None
results.update(data.result)
if data.is_interface:
done_sccs.append(self.scc_by_id[scc_id])
done_sccs.extend([self.scc_by_id[scc_id] for scc_id in data.scc_ids])
self.add_stats(scc_wait_time=t1 - t0, scc_receive_time=time.time() - t1)
self.submit_to_workers(graph) # advance after some workers are free.
return (
Expand Down Expand Up @@ -2756,7 +2799,7 @@ def new_state(
# import pkg.mod
if exist_removed_submodules(dependencies, manager):
state.needs_parse = True # Same as above, the current state is stale anyway.
state.size_hint = meta.size
state.size_hint = meta.size + MIN_SIZE_HINT
else:
# When doing a fine-grained cache load, pretend we only
# know about modules that have cache information and defer
Expand Down Expand Up @@ -3120,7 +3163,7 @@ def get_source(self) -> str:

self.parse_inline_configuration(source)

self.size_hint = len(source)
self.size_hint = len(source) + MIN_SIZE_HINT
self.time_spent_us += time_spent_us(t0)
return source

Expand Down Expand Up @@ -3195,6 +3238,14 @@ def parse_file(self, *, temporary: bool = False, raw_data: FileRawData | None =
self.check_blockers()

manager.ast_cache[self.id] = (self.tree, self.early_errors, self.source_hash)
assert self.tree is not None
if self.tree.raw_data is not None:
# Size of serialized tree is a better proxy for file complexity than
# file size, so we use that when possible. Note that we rely on lucky
# coincidence that serialized tree size has same order of magnitude as
# file size, so we don't need any normalization factor in situations
# where parsed and cached files are mixed.
self.size_hint = len(self.tree.raw_data.defs) + MIN_SIZE_HINT
self.setup_errors()

def setup_errors(self) -> None:
Expand Down Expand Up @@ -5084,26 +5135,26 @@ def write(self, buf: WriteBuffer) -> None:

class SccRequestMessage(IPCMessage):
"""
A message representing a request to type check an SCC.
A message representing a request to type check a batch of SCCs.

If scc_id is None, then it means that the coordinator requested a shutdown.
If scc_ids is empty, then it means that the coordinator requested a shutdown.
"""

def __init__(
    self,
    *,
    scc_ids: list[int],
    import_errors: dict[str, list[ErrorInfo]],
    mod_data: dict[str, tuple[bytes, FileRawData | None]],
) -> None:
    """Initialize a request to type check a batch of SCCs.

    An empty ``scc_ids`` list signals a coordinator-requested shutdown
    (per the class docstring).

    ``import_errors`` maps module id to errors recorded while importing it;
    ``mod_data`` maps module id to its serialized metadata and (optionally)
    raw parsed-file data.
    """
    # NOTE(review): the visible span retained the pre-change `scc_id`
    # parameter and `self.scc_id` assignment alongside their `scc_ids`
    # replacements (diff residue); only the merged `scc_ids` form is kept.
    self.scc_ids = scc_ids
    self.import_errors = import_errors
    self.mod_data = mod_data

@classmethod
def read(cls, buf: ReadBuffer) -> SccRequestMessage:
return SccRequestMessage(
scc_id=read_int_opt(buf),
scc_ids=read_int_list(buf),
import_errors={
read_str(buf): [ErrorInfo.read(buf) for _ in range(read_int_bare(buf))]
for _ in range(read_int_bare(buf))
Expand All @@ -5119,7 +5170,7 @@ def read(cls, buf: ReadBuffer) -> SccRequestMessage:

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, SCC_REQUEST_MESSAGE)
write_int_opt(buf, self.scc_id)
write_int_list(buf, self.scc_ids)
write_int_bare(buf, len(self.import_errors))
for path, errors in self.import_errors.items():
write_str(buf, path)
Expand Down Expand Up @@ -5160,17 +5211,17 @@ def write(self, buf: WriteBuffer) -> None:

class SccResponseMessage(IPCMessage):
"""
A message representing a result of type checking an SCC.
A message representing a result of type checking a batch of SCCs.

Only one of `result` or `blocker` can be non-None. The latter means there was
a blocking error while type checking the SCC. The `is_interface` flag indicates
a blocking error while type checking the SCCs. The `is_interface` flag indicates
whether this is a result for interface or implementation phase of type-checking.
"""

def __init__(
self,
*,
scc_id: int,
scc_ids: list[int],
is_interface: bool,
result: dict[str, ModuleResult] | None = None,
blocker: CompileError | None = None,
Expand All @@ -5179,26 +5230,26 @@ def __init__(
assert blocker is None
if blocker is not None:
assert result is None
self.scc_id = scc_id
self.scc_ids = scc_ids
self.is_interface = is_interface
self.result = result
self.blocker = blocker

@classmethod
def read(cls, buf: ReadBuffer) -> SccResponseMessage:
scc_id = read_int(buf)
scc_ids = read_int_list(buf)
is_interface = read_bool(buf)
tag = read_tag(buf)
if tag == LITERAL_NONE:
return SccResponseMessage(
scc_id=scc_id,
scc_ids=scc_ids,
is_interface=is_interface,
blocker=CompileError(read_str_list(buf), read_bool(buf), read_str_opt(buf)),
)
else:
assert tag == DICT_STR_GEN
return SccResponseMessage(
scc_id=scc_id,
scc_ids=scc_ids,
is_interface=is_interface,
result={
read_str_bare(buf): ModuleResult.read(buf) for _ in range(read_int_bare(buf))
Expand All @@ -5207,7 +5258,7 @@ def read(cls, buf: ReadBuffer) -> SccResponseMessage:

def write(self, buf: WriteBuffer) -> None:
write_tag(buf, SCC_RESPONSE_MESSAGE)
write_int(buf, self.scc_id)
write_int_list(buf, self.scc_ids)
write_bool(buf, self.is_interface)
if self.result is None:
assert self.blocker is not None
Expand Down
Loading
Loading