Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ type Conn struct {
// transaction within Execute.
mu sync.RWMutex

// receiveMu serializes concurrent Receive and ReceiveIter calls to prevent
// races in multi-part message handling and the peek/allocate logic in
// conn_linux.go. It is separate from mu so that Send can proceed
// concurrently with Receive.
receiveMu sync.Mutex

// sock is the operating system-specific implementation of
// a netlink sockets connection.
sock Socket
Expand Down Expand Up @@ -226,10 +232,13 @@ func (c *Conn) lockedSend(m Message) (Message, error) {
//
// If any of the messages indicate a netlink error, that error will be returned.
func (c *Conn) Receive() ([]Message, error) {
// Wait for any concurrent calls to Execute and Receive to finish before
// proceeding.
c.mu.Lock()
defer c.mu.Unlock()
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()

// Serialize concurrent Receive calls. See receiveMu for details.
c.receiveMu.Lock()
defer c.receiveMu.Unlock()

return c.lockedReceive()
}
Expand All @@ -242,8 +251,13 @@ func (c *Conn) Receive() ([]Message, error) {
// response is multi-part, the remaining messages will be discarded.
func (c *Conn) ReceiveIter() iter.Seq2[Message, error] {
return func(yield func(Message, error) bool) {
c.mu.Lock()
defer c.mu.Unlock()
// Wait for any concurrent calls to Execute to finish before proceeding.
c.mu.RLock()
defer c.mu.RUnlock()

// Serialize concurrent ReceiveIter calls. See receiveMu for details.
c.receiveMu.Lock()
defer c.receiveMu.Unlock()

for msg, err := range c.lockedReceiveIter() {
if err != nil {
Expand Down
49 changes: 49 additions & 0 deletions conn_linux_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,55 @@ func TestIntegrationConnConcurrentSerializeReceiveIter(t *testing.T) {
}
}

// TestIntegrationConnConcurrentSendWhileReceiving verifies that Send can
// proceed concurrently with a blocking Receive, as required by patterns like
// NFQUEUE where one goroutine reads packets and another sends verdicts.
func TestIntegrationConnConcurrentSendWhileReceiving(t *testing.T) {
t.Parallel()

c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)
if err != nil {
t.Fatalf("failed to dial: %v", err)
}

// Detect deadlock: if Receive holds an exclusive lock, Send blocks until
// Receive returns, hanging the test indefinitely.
timer := time.AfterFunc(10*time.Second, func() {
panic("deadlock: Send blocked waiting for Receive to release its lock")
})
defer timer.Stop()

var wg sync.WaitGroup
wg.Add(1)
defer wg.Wait()
defer c.Close()

req := netlink.Message{
Header: netlink.Header{
Flags: netlink.Request | netlink.Acknowledge,
},
}

sigC := make(chan struct{})
go func() {
defer wg.Done()

// Wait for the main goroutine to enter Receive, then give it time to
// block in recvmsg before attempting Send.
<-sigC
time.Sleep(50 * time.Millisecond)

if _, err := c.Send(req); err != nil {
panicf("Send failed while Receive was blocking: %v", err)
}

c.Close()
}()

close(sigC)
c.Receive()
}

func TestReceiveIter(t *testing.T) {
t.Parallel()
c, err := netlink.Dial(unix.NETLINK_GENERIC, nil)
Expand Down
Loading