diff --git a/conn.go b/conn.go index 2cee852..123081e 100644 --- a/conn.go +++ b/conn.go @@ -32,6 +32,12 @@ type Conn struct { // transaction within Execute. mu sync.RWMutex + // receiveMu serializes concurrent Receive and ReceiveIter calls to prevent + // races in multi-part message handling and the peek/allocate logic in + // conn_linux.go. It is separate from mu so that Send can proceed + // concurrently with Receive. + receiveMu sync.Mutex + // sock is the operating system-specific implementation of // a netlink sockets connection. sock Socket @@ -226,10 +232,13 @@ func (c *Conn) lockedSend(m Message) (Message, error) { // // If any of the messages indicate a netlink error, that error will be returned. func (c *Conn) Receive() ([]Message, error) { - // Wait for any concurrent calls to Execute and Receive to finish before - // proceeding. - c.mu.Lock() - defer c.mu.Unlock() + // Wait for any concurrent calls to Execute to finish before proceeding. + c.mu.RLock() + defer c.mu.RUnlock() + + // Serialize concurrent Receive calls. See receiveMu for details. + c.receiveMu.Lock() + defer c.receiveMu.Unlock() return c.lockedReceive() } @@ -242,8 +251,13 @@ func (c *Conn) Receive() ([]Message, error) { // response is multi-part, the remaining messages will be discarded. func (c *Conn) ReceiveIter() iter.Seq2[Message, error] { return func(yield func(Message, error) bool) { - c.mu.Lock() - defer c.mu.Unlock() + // Wait for any concurrent calls to Execute to finish before proceeding. + c.mu.RLock() + defer c.mu.RUnlock() + + // Serialize concurrent ReceiveIter calls. See receiveMu for details. + c.receiveMu.Lock() + defer c.receiveMu.Unlock() for msg, err := range c.lockedReceiveIter() { if err != nil { diff --git a/conn_linux_integration_test.go b/conn_linux_integration_test.go index 172edf7..017c8e8 100644 --- a/conn_linux_integration_test.go +++ b/conn_linux_integration_test.go @@ -6,6 +6,7 @@ import ( "encoding/binary" "errors" "fmt" + "iter" "math/rand" "net" "os" @@ -18,6 +19,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/mdlayher/netlink" + "github.com/mdlayher/netlink/nltest" "golang.org/x/net/bpf" "golang.org/x/sys/unix" ) @@ -491,6 +493,73 @@ func TestIntegrationConnConcurrentSerializeReceiveIter(t *testing.T) { } } +// TestConnSendConcurrentWithReceive verifies that Send can proceed concurrently +// with a blocking Receive, as required by patterns like NFQUEUE where one +// goroutine reads packets and another sends verdicts. +func TestConnSendConcurrentWithReceive(t *testing.T) { + t.Parallel() + + sock := &blockingSocket{ + receivingC: make(chan struct{}), + doneC: make(chan struct{}), + } + + c := netlink.NewConn(sock, nltest.PID) + + var wg sync.WaitGroup + wg.Go(func() { c.Receive() }) + + // Block until ReceiveIter has entered its blocking state + select { + case <-sock.receivingC: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for Receive to enter blocking state") + } + + // Send must not block while Receive is holding Conn.receiveMu + sendDone := make(chan error, 1) + go func() { + _, err := c.Send(netlink.Message{}) + sendDone <- err + }() + + select { + case err := <-sendDone: + if err != nil { + t.Fatalf("Send failed: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("deadlock: Send blocked while Receive was running") + } + + c.Close() + wg.Wait() +} + +// blockingSocket is a netlink.Socket whose ReceiveIter blocks until Close. +// receivingC is closed when ReceiveIter first enters its blocking state, +// providing a deterministic synchronisation point for tests. +type blockingSocket struct { + receivingOnce sync.Once + receivingC chan struct{} + doneC chan struct{} + closeOnce sync.Once +} + +func (s *blockingSocket) Close() error { + s.closeOnce.Do(func() { close(s.doneC) }) + return nil +} +func (s *blockingSocket) Send(_ netlink.Message) error { return nil } +func (s *blockingSocket) SendMessages(_ []netlink.Message) error { return nil } +func (s *blockingSocket) Receive() ([]netlink.Message, error) { return nil, nil } +func (s *blockingSocket) ReceiveIter() iter.Seq2[netlink.Message, error] { + return func(_ func(netlink.Message, error) bool) { + s.receivingOnce.Do(func() { close(s.receivingC) }) + <-s.doneC + } +} + func TestReceiveIter(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)