Skip to content

Commit 21af102

Browse files
Zac-HDclaude
andauthored
Permit either __aenter__ or __aexit__ without checkpoints (#444)
* Exempt async CM methods from ASYNC910/911 when partner checkpoints ASYNC910 and ASYNC911 no longer require every `__aenter__`/`__aexit__` to contain a checkpoint. Per Trio's documentation, an async context manager only needs one of entry/exit to act as a checkpoint. When a class defines both methods, the one without an `await` is exempt if its partner contains one. When a class defines only one of the two, the partner is charitably assumed to be inherited from a base class and to contain a checkpoint, so the defined method is also exempt. Closes #441 https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG * Tighten async CM exemption rules Two refinements in response to review feedback: - If an `__aenter__`/`__aexit__` method contains any checkpoint-like construct (`await`, `async with`, or `async for`), it must always checkpoint. We no longer exempt such methods even when the partner provides a checkpoint -- conditional checkpoints are still flagged. - Only charitably assume a missing partner is inherited (with a checkpoint) when the class actually inherits from something. Classes with no base classes are treated as flat, and methods that don't checkpoint are flagged. `metaclass=` and other keyword arguments do not count as inheriting, since they live in `ClassDef.keywords` rather than `ClassDef.bases`. https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG * Only flag __aenter__ when neither CM method checkpoints When both `__aenter__` and `__aexit__` are defined and neither contains a checkpoint, we used to flag (and autofix) both methods, which produced redundant `lowlevel.checkpoint()` calls -- only one is needed for the async context manager to checkpoint. Prefer to report and fix `__aenter__` in this case; `__aexit__` is exempted since adding a checkpoint to either satisfies the rule. https://claude.ai/code/session_014jAydKywq31Ew4fVYGJdiG --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent 9a49703 commit 21af102

5 files changed

Lines changed: 364 additions & 0 deletions

File tree

docs/changelog.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Changelog
77
Unreleased
88
==========
99
- Autofix for :ref:`ASYNC910 <async910>` / :ref:`ASYNC911 <async911>` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 <async120>`); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) <https://github.com/python-trio/flake8-async/issues/403>`_
10+
- :ref:`ASYNC910 <async910>` and :ref:`ASYNC911 <async911>` now accept ``__aenter__`` / ``__aexit__`` methods when the partner method provides the checkpoint, or when only one of the two is defined on a class that inherits from another class (charitably assuming the partner is inherited and contains a checkpoint). `(issue #441) <https://github.com/python-trio/flake8-async/issues/441>`_
1011

1112
25.7.1
1213
======

flake8_async/visitors/visitor91x.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -465,6 +465,21 @@ def __init__(self, *args: Any, **kwargs: Any):
465465
# used to transfer new body between visit_FunctionDef and leave_FunctionDef
466466
self.new_body: cst.BaseSuite | None = None
467467

468+
# Tracks whether the current scope is a class body and, if so, which of
469+
# `__aenter__`/`__aexit__` are directly defined on it (values: True if
470+
# that method contains a checkpoint-like construct, False otherwise,
471+
# missing key if not defined). Used to exempt async context manager
472+
# methods from ASYNC910/911 when their partner method provides the
473+
# checkpoint, or when the partner is inherited from a base class.
474+
self.async_cm_class: dict[str, bool] | None = None
475+
# Whether the enclosing class has an explicit base class (other than
476+
# implicit `object`). We only assume a missing partner is inherited if
477+
# the class actually inherits from something.
478+
self.async_cm_class_has_bases = False
479+
# Set on entry to an exempt `__aenter__`/`__aexit__` so that
480+
# `error_91x` skips emitting ASYNC910/911.
481+
self.exempt_async_cm_method = False
482+
468483
def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
469484
if code is None:
470485
code = "ASYNC911" if self.has_yield else "ASYNC910"
@@ -532,6 +547,60 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
532547
self.suppress_imported_as.append("suppress")
533548
return
534549

550+
# Async context manager methods may legitimately skip checkpointing if the
551+
# partner method provides the checkpoint, or if the partner is inherited
552+
# from a base class (which we charitably assume contains a checkpoint).
553+
# See https://github.com/python-trio/flake8-async/issues/441.
554+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
555+
self.save_state(node, "async_cm_class", "async_cm_class_has_bases")
556+
defined: dict[str, bool] = {}
557+
checkpointy = (
558+
m.Await()
559+
| m.With(asynchronous=m.Asynchronous())
560+
| m.For(asynchronous=m.Asynchronous())
561+
)
562+
if isinstance(node.body, cst.IndentedBlock):
563+
for stmt in node.body.body:
564+
if (
565+
isinstance(stmt, cst.FunctionDef)
566+
and stmt.asynchronous is not None
567+
and stmt.name.value in ("__aenter__", "__aexit__")
568+
):
569+
defined[stmt.name.value] = bool(m.findall(stmt, checkpointy))
570+
self.async_cm_class = defined
571+
# Keyword args like `metaclass=` are in `node.keywords`, not `bases`.
572+
self.async_cm_class_has_bases = bool(node.bases)
573+
574+
def leave_ClassDef(
575+
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
576+
) -> cst.ClassDef:
577+
self.restore_state(original_node)
578+
return updated_node
579+
580+
def _is_exempt_async_cm_method(self, node: cst.FunctionDef) -> bool:
581+
if self.async_cm_class is None:
582+
return False
583+
name = node.name.value
584+
if name not in ("__aenter__", "__aexit__"):
585+
return False
586+
if name not in self.async_cm_class:
587+
return False
588+
# A method that contains any checkpoint must always checkpoint: we
589+
# still check it normally so conditional checkpoints are flagged.
590+
if self.async_cm_class[name]:
591+
return False
592+
partner = "__aexit__" if name == "__aenter__" else "__aenter__"
593+
if partner in self.async_cm_class:
594+
# Partner is defined on the class; if it checkpoints, we're fine.
595+
if self.async_cm_class[partner]:
596+
return True
597+
# Neither method checkpoints -- to avoid double-flagging (and a
598+
# redundant autofix), we report and fix only `__aenter__`.
599+
return name == "__aexit__"
600+
# Partner is not defined on this class; only assume it is inherited
601+
# (and contains a checkpoint) if the class inherits from something.
602+
return self.async_cm_class_has_bases
603+
535604
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
536605
# `await` in default values happen in parent scope
537606
# we also know we don't ever modify parameters so we can ignore the return value
@@ -543,6 +612,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
543612
if func_has_decorator(node, "overload", "fixture") or func_empty_body(node):
544613
return False # subnodes can be ignored
545614

615+
is_exempt_cm = self._is_exempt_async_cm_method(node)
616+
546617
self.save_state(
547618
node,
548619
"has_yield",
@@ -557,6 +628,9 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
557628
"suppress_imported_as", # a copy is saved, but state is not reset
558629
"except_depth",
559630
"add_checkpoint_at_function_start",
631+
"async_cm_class",
632+
"async_cm_class_has_bases",
633+
"exempt_async_cm_method",
560634
copy=True,
561635
)
562636
self.uncheckpointed_statements = set()
@@ -567,6 +641,10 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
567641
self.taskgroup_has_start_soon = {}
568642
self.except_depth = 0
569643
self.add_checkpoint_at_function_start = False
644+
# Class-level context does not apply to nested scopes.
645+
self.async_cm_class = None
646+
self.async_cm_class_has_bases = False
647+
self.exempt_async_cm_method = is_exempt_cm
570648

571649
self.async_function = (
572650
node.asynchronous is not None
@@ -747,6 +825,12 @@ def error_91x(
747825
) -> bool:
748826
assert not isinstance(statement, ArtificialStatement), statement
749827

828+
# Exempt `__aenter__`/`__aexit__` when the partner method contains a
829+
# checkpoint, or when the partner is missing and charitably assumed
830+
# inherited.
831+
if self.exempt_async_cm_method:
832+
return False
833+
750834
if isinstance(node, cst.FunctionDef):
751835
msg = "exit"
752836
else:

tests/autofix_files/async910.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,123 @@ async def foo_nested_empty_async():
636636
async def bar(): ...
637637

638638
await foo()
639+
640+
641+
# Issue #441: async context manager methods may legitimately skip checkpointing
642+
# if the partner method provides the checkpoint, or if the partner is inherited.
643+
class ACM: # a dummy base to opt into the charitable-inheritance assumption
644+
pass
645+
646+
647+
class CtxWithSetup: # safe: __aenter__ checkpoints, __aexit__ can be fast
648+
async def __aenter__(self):
649+
await foo()
650+
651+
async def __aexit__(self, exc_type, exc, tb):
652+
print("fast exit")
653+
654+
655+
class CtxWithTeardown: # safe: __aexit__ checkpoints, __aenter__ can be fast
656+
async def __aenter__(self):
657+
print("fast setup")
658+
659+
async def __aexit__(self, exc_type, exc, tb):
660+
await foo()
661+
662+
663+
class CtxWithBothCheckpoint: # safe: both checkpoint
664+
async def __aenter__(self):
665+
await foo()
666+
667+
async def __aexit__(self, exc_type, exc, tb):
668+
await foo()
669+
670+
671+
# fmt: off
672+
class CtxNeitherCheckpoint:
673+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
674+
print("setup")
675+
await trio.lowlevel.checkpoint()
676+
677+
async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy
678+
print("teardown")
679+
# fmt: on
680+
681+
682+
# A method that contains any checkpoint is still required to always checkpoint.
683+
class CtxAenterConditionalAexitFast(ACM):
684+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
685+
if _:
686+
await foo()
687+
await trio.lowlevel.checkpoint()
688+
689+
async def __aexit__(self, *a):
690+
print("fast exit")
691+
692+
693+
# Only one method defined: charitably assume the other is inherited with a
694+
# checkpoint -- but only when the class inherits from something.
695+
class CtxOnlyAenterInherited(ACM): # safe: __aexit__ assumed inherited
696+
async def __aenter__(self):
697+
print("setup")
698+
699+
700+
class CtxOnlyAexitInherited(ACM): # safe: __aenter__ assumed inherited
701+
async def __aexit__(self, *a):
702+
print("teardown")
703+
704+
705+
# fmt: off
706+
class CtxOnlyAenter: # no base class -> don't assume inheritance
707+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
708+
print("setup")
709+
await trio.lowlevel.checkpoint()
710+
711+
712+
class CtxOnlyAexit: # no base class -> don't assume inheritance
713+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
714+
print("teardown")
715+
await trio.lowlevel.checkpoint()
716+
# fmt: on
717+
718+
719+
class CtxOnlyAenterWithCheckpoint: # safe
720+
async def __aenter__(self):
721+
await foo()
722+
723+
724+
class CtxOnlyAexitWithCheckpoint: # safe
725+
async def __aexit__(self, *a):
726+
await foo()
727+
728+
729+
# keyword-only bases (like `metaclass=`) don't count as inheriting.
730+
class Meta(type):
731+
pass
732+
733+
734+
class CtxMetaclassOnly(metaclass=Meta):
735+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
736+
print("setup")
737+
await trio.lowlevel.checkpoint()
738+
739+
740+
# a nested function named `__aenter__` inside another function is not a method
741+
def not_a_class():
742+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
743+
print("setup")
744+
await trio.lowlevel.checkpoint()
745+
746+
747+
# class nested inside a function still gets the exemption when it inherits
748+
def factory():
749+
class NestedCtx(ACM): # safe
750+
async def __aenter__(self):
751+
print("setup")
752+
753+
754+
# nested class; outer class has nothing relevant
755+
class Outer:
756+
class Inner(ACM): # safe: charitable inheritance for __aexit__
757+
async def __aenter__(self):
758+
print("setup")

tests/autofix_files/async910.py.diff

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,3 +223,48 @@
223223

224224

225225
async def foo_nested_empty_async():
226+
@@ x,6 x,7 @@
227+
class CtxNeitherCheckpoint:
228+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
229+
print("setup")
230+
+ await trio.lowlevel.checkpoint()
231+
232+
async def __aexit__(self, *a): # only __aenter__ is flagged to avoid redundancy
233+
print("teardown")
234+
@@ x,6 x,7 @@
235+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
236+
if _:
237+
await foo()
238+
+ await trio.lowlevel.checkpoint()
239+
240+
async def __aexit__(self, *a):
241+
print("fast exit")
242+
@@ x,11 x,13 @@
243+
class CtxOnlyAenter: # no base class -> don't assume inheritance
244+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
245+
print("setup")
246+
+ await trio.lowlevel.checkpoint()
247+
248+
249+
class CtxOnlyAexit: # no base class -> don't assume inheritance
250+
async def __aexit__(self, *a): # error: 4, "exit", Stmt("function definition", line)
251+
print("teardown")
252+
+ await trio.lowlevel.checkpoint()
253+
# fmt: on
254+
255+
256+
@@ x,12 x,14 @@
257+
class CtxMetaclassOnly(metaclass=Meta):
258+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
259+
print("setup")
260+
+ await trio.lowlevel.checkpoint()
261+
262+
263+
# a nested function named `__aenter__` inside another function is not a method
264+
def not_a_class():
265+
async def __aenter__(self): # error: 4, "exit", Stmt("function definition", line)
266+
print("setup")
267+
+ await trio.lowlevel.checkpoint()
268+
269+
270+
# class nested inside a function still gets the exemption when it inherits

0 commit comments

Comments
 (0)