diff --git a/src/croc/croc.go b/src/croc/croc.go index 47066ea4d..8675431fa 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -2,6 +2,7 @@ package croc import ( "bytes" + "context" "crypto/rand" "crypto/sha256" "encoding/binary" @@ -147,6 +148,9 @@ type Client struct { quit chan bool finishedNum int numberOfTransferredFiles int + + // ctx.go for graceful shutdown + *stop } // Chunk contains information about the @@ -260,6 +264,7 @@ func New(ops Options) (c *Client, err error) { } c.mutex = &sync.Mutex{} + c.stop = newStop(context.Background()) return } @@ -530,7 +535,7 @@ func (c *Client) sendCollectFiles(filesInfo []FileInfo) (err error) { c.Options.HashAlgorithm = "xxhash" } - c.FilesToTransfer[i].Hash, err = utils.HashFile(fullPath, c.Options.HashAlgorithm, fileInfo.Size > 1e7) + c.FilesToTransfer[i].Hash, err = c.stop.hash(fullPath, c.Options.HashAlgorithm, fileInfo.Size > 1e7) log.Debugf("hashed %s to %x using %s", fullPath, c.FilesToTransfer[i].Hash, c.Options.HashAlgorithm) totalFilesSize += fileInfo.Size if err != nil { @@ -578,7 +583,12 @@ func (c *Client) setupLocalRelay() { if c.Options.Debug { debugString = "debug" } - err := tcp.Run(debugString, "127.0.0.1", portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ",")) + err := c.stop.run( + debugString, + "127.0.0.1", + portStr, + c.Options.RelayPassword, + strings.Join(c.Options.RelayPorts[1:], ",")) if err != nil { panic(err) } @@ -600,6 +610,7 @@ func (c *Client) broadcastOnLocalNetwork(useipv6 bool) { Payload: []byte("croc" + c.Options.RelayPorts[0]), Delay: 20 * time.Millisecond, TimeLimit: timeLimit, + StopChan: c.stop.stopChan, } if useipv6 { settings.IPVersion = peerdiscovery.IPv6 @@ -629,11 +640,15 @@ func (c *Client) transferOverLocalRelay(errchan chan<- error) { } log.Debugf("local connection established: %+v", conn) for { + if err := c.ctxErr(); err != nil { + errchan <- err + return + } data, _ := conn.Receive() if bytes.Equal(data, handshakeRequest) { break } else if bytes.Equal(data, []byte{1}) { - log.Debug("got ping") + log.Trace("got ping") } else { log.Debugf("instead of handshake got: %s", data) } @@ -652,6 +667,8 @@ func (c *Client) transferOverLocalRelay(errchan chan<- error) { // Send will send the specified file func (c *Client) Send(filesInfo []FileInfo, emptyFoldersToTransfer []FileInfo, totalNumberFolders int) (err error) { + go c.stop.done() + defer c.stop.Cancel() c.EmptyFoldersToTransfer = emptyFoldersToTransfer c.TotalNumberFolders = totalNumberFolders c.TotalNumberOfContents = len(filesInfo) @@ -745,6 +762,10 @@ On the other computer run: var kB []byte B, _ := pake.InitCurve([]byte(c.Options.SharedSecret[5:]), 1, c.Options.Curve) for { + if err := c.ctxErr(); err != nil { + errchan <- err + return + } var dataMessage SimpleMessage log.Trace("waiting for bytes") data, errConn := conn.Receive() @@ -870,6 +891,8 @@ func showReceiveCommandQrCode(command string) { // Receive will receive a file func (c *Client) Receive() (err error) { + go c.stop.done() + defer c.stop.Cancel() fmt.Fprintf(os.Stderr, "connecting...") // recipient will look for peers first // and continue if it doesn't find any within 100 ms @@ -908,6 +931,7 @@ func (c *Client) Receive() (err error) { Delay: 20 * time.Millisecond, TimeLimit: 200 * time.Millisecond, MulticastAddress: c.Options.MulticastAddress, + StopChan: c.stop.stopChan, }) if err1 == nil && len(ipv4discoveries) > 0 { dmux.Lock() @@ -924,6 +948,7 @@ func (c *Client) Receive() (err error) { Delay: 20 * time.Millisecond, TimeLimit: 200 * time.Millisecond, IPVersion: peerdiscovery.IPv6, + StopChan: c.stop.stopChan, }) if err1 == nil && len(ipv6discoveries) > 0 { dmux.Lock() @@ -1128,6 +1153,8 @@ func (c *Client) Receive() (err error) { if c.numberOfTransferredFiles+len(c.EmptyFoldersToTransfer) == 0 { fmt.Fprintf(os.Stderr, "\rNo files transferred.\n") } + } else { + c.SendError() } return } @@ -1153,6 +1180,11 @@ func (c *Client) transfer() (err error) { // listen for incoming messages and process them for { + if e := c.ctxErr(); e != nil { + log.Tracef("transfer: %v", e) + err = e + break + } var data []byte var done bool data, err = c.conn[0].Receive() @@ -1173,6 +1205,10 @@ func (c *Client) transfer() (err error) { break } } + if err := c.ctxErr(); err != nil && c.SuccessfulTransfer { + c.SuccessfulTransfer = false + log.Tracef("SuccessfulTransfer: %v", err) + } // purge errors that come from successful transfer if c.SuccessfulTransfer { if err != nil { @@ -1223,6 +1259,9 @@ func (c *Client) transfer() (err error) { log.Debugf("error: %s", err.Error()) err = fmt.Errorf("room (secure channel) not ready, maybe peer disconnected") } + if err != nil { + c.SendError() + } return } @@ -1392,6 +1431,16 @@ func (c *Client) processMessageFileInfo(m message.Message) (done bool, err error } func (c *Client) processMessagePake(m message.Message) (err error) { + defer func() { + if r := recover(); r != nil { + if c.stop.gui { + log.Errorf("panic: %v", r) + c.stop.Cancel() + } else { + panic(r) + } + } + }() log.Debug("received pake payload") var salt []byte @@ -1541,6 +1590,8 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) { done, err = c.processExternalIP(m) case message.TypeError: // c.spinner.Stop() + log.Trace("Peer initiates interruption of my loops and goroutines") + c.stop.Cancel() fmt.Print("\r") err = fmt.Errorf("peer error: %s", m.Message) return true, err @@ -1603,6 +1654,15 @@ func (c *Client) processMessage(payload []byte) (done bool, err error) { func (c *Client) updateIfSenderChannelSecured() (err error) { if c.Options.IsSender && c.Step1ChannelSecured && !c.Step2FileInfoTransferred { + + if len(c.FilesToTransfer) == 1 && + c.FilesToTransfer[0].Name == filepath.Base(os.DevNull) && + c.FilesToTransfer[0].FolderSource == filepath.Dir(os.DevNull) { + log.Debug(os.DevNull) + c.Step2FileInfoTransferred = true + return + } + var b []byte machID, _ := machineid.ID() b, err = json.Marshal(SenderInfo{ @@ -2007,6 +2067,16 @@ func (c *Client) setBar() { } func (c *Client) receiveData(i int) { + defer func() { + if r := recover(); r != nil { + if c.stop.gui { + log.Errorf("panic: %v", r) + c.stop.Cancel() + } else { + panic(r) + } + } + }() log.Tracef("%d receiving data", i) for { data, err := c.conn[i+1].Receive() @@ -2036,6 +2106,28 @@ func (c *Client) receiveData(i int) { positionInt64 := int64(position) c.mutex.Lock() + if c.CurrentFileIsClosed || c.CurrentFile == nil { + c.mutex.Unlock() + log.Tracef("was closed %d", i) + return + } + if err := c.ctxErr(); err != nil { + c.CurrentFileIsClosed = true + defer c.mutex.Unlock() + log.Tracef("stopping: %v", err) + if err := c.CurrentFile.Close(); err != nil { + log.Tracef("closing %s: %v", c.CurrentFile.Name(), err) + } else { + log.Tracef("Successful closing %s", c.CurrentFile.Name()) + } + log.Tracef("sending close-sender") + if sendErr := message.Send(c.conn[0], c.Key, message.Message{ + Type: message.TypeCloseSender, + }); sendErr != nil { + log.Tracef("sending close-sender: %v", sendErr) + } + return + } _, err = c.CurrentFile.WriteAt(data[8:], positionInt64) if err != nil { panic(err) @@ -2075,6 +2167,14 @@ func (c *Client) receiveData(i int) { func (c *Client) sendData(i int) { defer func() { + if r := recover(); r != nil { + if c.stop.gui { + log.Errorf("panic: %v", r) + c.stop.Cancel() + } else { + panic(r) + } + } log.Debugf("finished with %d", i) c.numfinished++ if c.numfinished == len(c.Options.RelayPorts) { @@ -2089,6 +2189,10 @@ func (c *Client) sendData(i int) { pos := uint64(0) curi := float64(0) for { + if err := c.ctxErr(); err != nil { + log.Tracef("stopping send %d: %v", i, err) + return + } // Read file var n int var errRead error diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index ef84567df..762d4d6e9 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -1,6 +1,9 @@ package croc import ( + "context" + "fmt" + "math/rand" "os" "path" "path/filepath" @@ -426,3 +429,793 @@ func TestCleanUp(t *testing.T) { } } } + +func hashed(c *Client) bool { + if len(c.FilesToTransfer) == 0 { + return false + } + for _, file := range c.FilesToTransfer { + if len(file.Hash) == 0 { + return false + } + } + return true +} + +func waitHashed(sender *Client) (err error) { + err = fmt.Errorf("not hashed") + for i := 0; i < 300; i++ { // Max 3 seconds + if hashed(sender) { + time.Sleep(100 * time.Millisecond) + return nil + } + time.Sleep(10 * time.Millisecond) + } + return +} + +func createTestFile(t *testing.T, size int) (string, func()) { + tempFile, err := os.CreateTemp("", "test-*.dat") + if err != nil { + t.Fatal(err) + } + + data := make([]byte, size) + for i := 0; i < size; i++ { + data[i] = byte(i % 256) + } + + if _, err := tempFile.Write(data); err != nil { + tempFile.Close() + os.Remove(tempFile.Name()) + t.Fatal(err) + } + + if err := tempFile.Close(); err != nil { + os.Remove(tempFile.Name()) + t.Fatal(err) + } + + return tempFile.Name(), func() { + os.Remove(tempFile.Name()) + } +} + +func TestBase(t *testing.T) { + tempFile, cleanup := createTestFile(t, 1024*1024) // 1 МБ + defer cleanup() + receivedFile := filepath.Base(tempFile) + defer os.Remove(receivedFile) + + go tcp.Run("debug", "127.0.0.1", "8286", "pass123", "8287") + time.Sleep(200 * time.Millisecond) + go tcp.Run("debug", "127.0.0.1", "8287", "pass123") + time.Sleep(200 * time.Millisecond) + + uniqueSecret := fmt.Sprintf("test-%d-%d", time.Now().UnixNano(), rand.Intn(10000)) + + sender, err := New(Options{ + IsSender: true, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8286", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + GitIgnore: false, + }) + if err != nil { + t.Fatalf("Create sender failed: %v", err) + } + + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tempFile}, false, false, []string{}) + if errGet != nil { + t.Fatalf("Get file info failed: %v", errGet) + } + + receiver, err := New(Options{ + IsSender: false, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8286", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + }) + if err != nil { + t.Fatalf("Create receiver failed: %v", err) + } + + fatalErr := make(chan error, 1) + + failTest := func(err error) { + select { + case fatalErr <- err: + default: + } + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + log.Warn("Send") + if err := sender.Send(filesInfo, emptyFolders, totalNumberFolders); err != nil { + failTest(fmt.Errorf("Send failed: %w", err)) + } + }() + + go func() { + defer wg.Done() + + if err := waitHashed(sender); err != nil { + failTest(fmt.Errorf("waitHashed failed: %w", err)) + return + } + + log.Warn("Receive") + if err := receiver.Receive(); err != nil { + failTest(fmt.Errorf("Receive failed: %w", err)) + } + }() + + go func() { + for i := 0; i < 3000; i++ { + if sender.Step1ChannelSecured && receiver.Step1ChannelSecured { + time.Sleep(time.Millisecond) + if sender.Step2FileInfoTransferred && receiver.Step2FileInfoTransferred { + log.Warn("Step2FileInfoTransferred reached") + return + } + log.Warn("Step1ChannelSecured reached") + } + time.Sleep(time.Millisecond) + } + }() + + done := make(chan bool, 1) + go func() { + wg.Wait() + done <- true + }() + + select { + case err := <-fatalErr: + t.Fatal(err) + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Test timeout after 5 seconds") + } +} + +func TestCtx(t *testing.T) { + tempFile, cleanup := createTestFile(t, 1024*1024) // 1 МБ + defer cleanup() + receivedFile := filepath.Base(tempFile) + defer os.Remove(receivedFile) + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8288", "pass123", "8289") + time.Sleep(200 * time.Millisecond) + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8289", "pass123") + time.Sleep(200 * time.Millisecond) + + uniqueSecret := fmt.Sprintf("test-%d-%d", time.Now().UnixNano(), rand.Intn(10000)) + + sender, err := NewCtx(ctx, Options{ + IsSender: true, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8288", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + GitIgnore: false, + }) + if err != nil { + t.Fatalf("Create sender failed: %v", err) + } + + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tempFile}, false, false, []string{}) + if errGet != nil { + t.Fatalf("Get file info failed: %v", errGet) + } + + receiver, err := NewCtx(ctx, Options{ + IsSender: false, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8288", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + }) + if err != nil { + t.Fatalf("Create receiver failed: %v", err) + } + + fatalErr := make(chan error, 1) + + failTest := func(err error) { + select { + case fatalErr <- err: + default: + } + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + log.Warn("Send") + if err := sender.Send(filesInfo, emptyFolders, totalNumberFolders); err != nil { + failTest(fmt.Errorf("Send failed: %w", err)) + } + }() + + go func() { + defer wg.Done() + + if err := waitHashed(sender); err != nil { + failTest(fmt.Errorf("waitHashed failed: %w", err)) + return + } + + log.Warn("Receive") + if err := receiver.Receive(); err != nil { + failTest(fmt.Errorf("Receive failed: %w", err)) + } + }() + + go func() { + for i := 0; i < 3000; i++ { + if sender.Step1ChannelSecured && receiver.Step1ChannelSecured { + time.Sleep(time.Millisecond) + if sender.Step2FileInfoTransferred && receiver.Step2FileInfoTransferred { + log.Warn("Step2FileInfoTransferred reached") + return + } + log.Warn("Step1ChannelSecured reached") + } + time.Sleep(time.Millisecond) + } + }() + + done := make(chan bool, 1) + go func() { + wg.Wait() + done <- true + }() + + select { + case err := <-fatalErr: + t.Fatal(err) + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Test timeout after 5 seconds") + } +} + +func validErrors(err error) bool { + s := err.Error() + return strings.Contains(s, "cancel") || + strings.Contains(s, "context") || + strings.Contains(s, "reset") || + strings.Contains(s, "broken") || + strings.Contains(s, "refusing") || + strings.Contains(s, "EOF") || + strings.Contains(s, "closed") +} + +func result(t *testing.T, err error) { + if err != nil { + if validErrors(err) { + t.Logf("Expected error during context cancellation: %v", err) + } else { + t.Errorf("Unexpected error during cancellation: %v", err) + } + return + } + t.Error("Transfer should have been interrupted by context cancellation") +} + +func TestAllCtx(t *testing.T) { + tempFile, cleanup := createTestFile(t, 1024*1024) // 1 МБ + defer cleanup() + receivedFile := filepath.Base(tempFile) + defer os.Remove(receivedFile) + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8290", "pass123", "8291") + time.Sleep(200 * time.Millisecond) + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8291", "pass123") + time.Sleep(200 * time.Millisecond) + + uniqueSecret := fmt.Sprintf("test-%d-%d", time.Now().UnixNano(), rand.Intn(10000)) + + sender, err := NewCtx(ctx, Options{ + IsSender: true, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8290", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + GitIgnore: false, + }) + if err != nil { + t.Fatalf("Create sender failed: %v", err) + } + + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tempFile}, false, false, []string{}) + if errGet != nil { + t.Fatalf("Get file info failed: %v", errGet) + } + + receiver, err := NewCtx(ctx, Options{ + IsSender: false, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8290", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + }) + if err != nil { + t.Fatalf("Create receiver failed: %v", err) + } + + fatalErr := make(chan error, 1) + + failTest := func(err error) { + select { + case fatalErr <- err: + default: + } + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + log.Warn("Send") + if err := sender.Send(filesInfo, emptyFolders, totalNumberFolders); err != nil { + failTest(fmt.Errorf("Send failed: %w", err)) + } + }() + + go func() { + defer wg.Done() + + if err := waitHashed(sender); err != nil { + failTest(fmt.Errorf("waitHashed failed: %w", err)) + return + } + + log.Warn("Receive") + if err := receiver.Receive(); err != nil { + failTest(fmt.Errorf("Receive failed: %w", err)) + } + }() + + go func() { + for i := 0; i < 3000; i++ { + if sender.Step1ChannelSecured && receiver.Step1ChannelSecured { + time.Sleep(time.Millisecond) + if sender.Step2FileInfoTransferred && receiver.Step2FileInfoTransferred { + log.Warn("Step2FileInfoTransferred reached") + cancel() + return + } + log.Warn("Step1ChannelSecured reached") + } + time.Sleep(time.Millisecond) + } + }() + + done := make(chan bool, 1) + go func() { + wg.Wait() + done <- true + }() + + select { + case err := <-fatalErr: + result(t, err) + case <-done: + t.Error("Transfer should have been interrupted by context cancellation") + case <-time.After(5 * time.Second): + t.Fatal("Test timeout after 5 seconds") + } +} + +func TestSendCtx(t *testing.T) { + tempFile, cleanup := createTestFile(t, 1024*1024) // 1 МБ + defer cleanup() + receivedFile := filepath.Base(tempFile) + defer os.Remove(receivedFile) + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8292", "pass123", "8293") + time.Sleep(200 * time.Millisecond) + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8293", "pass123") + time.Sleep(200 * time.Millisecond) + + uniqueSecret := fmt.Sprintf("test-%d-%d", time.Now().UnixNano(), rand.Intn(10000)) + + sender, err := NewCtx(ctx2, Options{ + IsSender: true, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8292", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + GitIgnore: false, + }) + if err != nil { + t.Fatalf("Create sender failed: %v", err) + } + + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tempFile}, false, false, []string{}) + if errGet != nil { + t.Fatalf("Get file info failed: %v", errGet) + } + + receiver, err := NewCtx(ctx, Options{ + IsSender: false, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8292", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + }) + if err != nil { + t.Fatalf("Create receiver failed: %v", err) + } + + fatalErr := make(chan error, 1) + + failTest := func(err error) { + select { + case fatalErr <- err: + default: + } + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + log.Warn("Send") + if err := sender.Send(filesInfo, emptyFolders, totalNumberFolders); err != nil { + failTest(fmt.Errorf("Send failed: %w", err)) + } + }() + + go func() { + defer wg.Done() + + if err := waitHashed(sender); err != nil { + failTest(fmt.Errorf("waitHashed failed: %w", err)) + return + } + + log.Warn("Receive") + if err := receiver.Receive(); err != nil { + failTest(fmt.Errorf("Receive failed: %w", err)) + } + }() + + go func() { + for i := 0; i < 3000; i++ { + if sender.Step1ChannelSecured && receiver.Step1ChannelSecured { + time.Sleep(time.Millisecond) + if sender.Step2FileInfoTransferred && receiver.Step2FileInfoTransferred { + log.Warn("Step2FileInfoTransferred reached") + cancel2() + return + } + log.Warn("Step1ChannelSecured reached") + } + time.Sleep(time.Millisecond) + } + }() + + done := make(chan bool, 1) + go func() { + wg.Wait() + done <- true + }() + + select { + case err := <-fatalErr: + result(t, err) + case <-done: + t.Error("Transfer should have been interrupted by context cancellation") + case <-time.After(5 * time.Second): + t.Fatal("Test timeout after 5 seconds") + } +} + +func TestReceiveCtx(t *testing.T) { + tempFile, cleanup := createTestFile(t, 1024*1024) // 1 МБ + defer cleanup() + receivedFile := filepath.Base(tempFile) + defer os.Remove(receivedFile) + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8294", "pass123", "8295") + time.Sleep(200 * time.Millisecond) + go tcp.RunCtx(ctx, "debug", "127.0.0.1", "8295", "pass123") + time.Sleep(200 * time.Millisecond) + + uniqueSecret := fmt.Sprintf("test-%d-%d", time.Now().UnixNano(), rand.Intn(10000)) + + sender, err := NewCtx(ctx, Options{ + IsSender: true, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8294", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + GitIgnore: false, + }) + if err != nil { + t.Fatalf("Create sender failed: %v", err) + } + + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tempFile}, false, false, []string{}) + if errGet != nil { + t.Fatalf("Get file info failed: %v", errGet) + } + + receiver, err := NewCtx(ctx2, Options{ + IsSender: false, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8294", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + }) + if err != nil { + t.Fatalf("Create receiver failed: %v", err) + } + + fatalErr := make(chan error, 1) + + failTest := func(err error) { + select { + case fatalErr <- err: + default: + } + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + log.Warn("Send") + if err := sender.Send(filesInfo, emptyFolders, totalNumberFolders); err != nil { + failTest(fmt.Errorf("Send failed: %w", err)) + } + }() + + go func() { + defer wg.Done() + + if err := waitHashed(sender); err != nil { + failTest(fmt.Errorf("waitHashed failed: %w", err)) + return + } + + log.Warn("Receive") + if err := receiver.Receive(); err != nil { + failTest(fmt.Errorf("Receive failed: %w", err)) + } + }() + + go func() { + for i := 0; i < 3000; i++ { + if sender.Step1ChannelSecured && receiver.Step1ChannelSecured { + time.Sleep(time.Millisecond) + if sender.Step2FileInfoTransferred && receiver.Step2FileInfoTransferred { + log.Warn("Step2FileInfoTransferred reached") + cancel2() + return + } + log.Warn("Step1ChannelSecured reached") + } + time.Sleep(time.Millisecond) + } + }() + + done := make(chan bool, 1) + go func() { + wg.Wait() + done <- true + }() + + select { + case err := <-fatalErr: + result(t, err) + case <-done: + t.Error("Transfer should have been interrupted by context cancellation") + case <-time.After(5 * time.Second): + t.Fatal("Test timeout after 5 seconds") + } +} + +func TestRunCtx(t *testing.T) { + tempFile, cleanup := createTestFile(t, 1024*1024) // 1 МБ + defer cleanup() + receivedFile := filepath.Base(tempFile) + defer os.Remove(receivedFile) + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + go tcp.RunCtx(ctx2, "debug", "127.0.0.1", "8296", "pass123", "8297") + time.Sleep(200 * time.Millisecond) + go tcp.RunCtx(ctx2, "debug", "127.0.0.1", "8297", "pass123") + time.Sleep(200 * time.Millisecond) + + uniqueSecret := fmt.Sprintf("test-%d-%d", time.Now().UnixNano(), rand.Intn(10000)) + + sender, err := NewCtx(ctx, Options{ + IsSender: true, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8296", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + GitIgnore: false, + }) + if err != nil { + t.Fatalf("Create sender failed: %v", err) + } + + filesInfo, emptyFolders, totalNumberFolders, errGet := GetFilesInfo([]string{tempFile}, false, false, []string{}) + if errGet != nil { + t.Fatalf("Get file info failed: %v", errGet) + } + + receiver, err := NewCtx(ctx, Options{ + IsSender: false, + SharedSecret: uniqueSecret, + Debug: true, + RelayAddress: "127.0.0.1:8296", + RelayPassword: "pass123", + Stdout: false, + NoPrompt: true, + DisableLocal: true, + Curve: "siec", + Overwrite: true, + }) + if err != nil { + t.Fatalf("Create receiver failed: %v", err) + } + + fatalErr := make(chan error, 1) + + failTest := func(err error) { + select { + case fatalErr <- err: + default: + } + } + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + log.Warn("Send") + if err := sender.Send(filesInfo, emptyFolders, totalNumberFolders); err != nil { + failTest(fmt.Errorf("Send failed: %w", err)) + } + }() + + go func() { + defer wg.Done() + + if err := waitHashed(sender); err != nil { + failTest(fmt.Errorf("waitHashed failed: %w", err)) + return + } + + log.Warn("Receive") + if err := receiver.Receive(); err != nil { + failTest(fmt.Errorf("Receive failed: %w", err)) + } + }() + + go func() { + for i := 0; i < 3000; i++ { + if sender.Step1ChannelSecured && receiver.Step1ChannelSecured { + time.Sleep(time.Millisecond) + if sender.Step2FileInfoTransferred && receiver.Step2FileInfoTransferred { + log.Warn("Step2FileInfoTransferred reached") + cancel2() + return + } + log.Warn("Step1ChannelSecured reached") + } + time.Sleep(time.Millisecond) + } + }() + + done := make(chan bool, 1) + go func() { + wg.Wait() + done <- true + }() + + select { + case err := <-fatalErr: + result(t, err) + case <-done: + t.Error("Transfer should have been interrupted by context cancellation") + case <-time.After(5 * time.Second): + t.Fatal("Test timeout after 5 seconds") + } +} diff --git a/src/croc/ctx.go b/src/croc/ctx.go new file mode 100644 index 000000000..1f9a4171e --- /dev/null +++ b/src/croc/ctx.go @@ -0,0 +1,104 @@ +// ctx.go +package croc + +import ( + "context" + "time" + + "github.com/schollz/croc/v10/src/message" + "github.com/schollz/croc/v10/src/tcp" + "github.com/schollz/croc/v10/src/utils" + log "github.com/schollz/logger" +) + +// stop manages graceful shutdown +type stop struct { + ctx context.Context + cancel context.CancelFunc + stopChan chan struct{} //peerdiscovery + run func(debugLevel string, host string, port string, password string, banner ...string) (err error) + hash func(fname string, algorithm string, showProgress ...bool) (hash256 []byte, err error) + gui bool +} + +// newStop creates a new stop manager instance +func newStop(ctx context.Context) *stop { + s := &stop{ + stopChan: make(chan struct{}), + run: tcp.Run, + hash: utils.HashFile, + } + if ctx == nil { + ctx = context.Background() + } + s.ctx, s.cancel = context.WithCancel(ctx) + + return s +} + +func (s *stop) done() { + <-s.ctx.Done() + time.Sleep(time.Millisecond) + close(s.stopChan) + log.Trace("croc done") +} + +// NewCtx creates a client with context support +func NewCtx(ctx context.Context, ops Options) (*Client, error) { + // Create a regular c + c, err := New(ops) + if err != nil { + return nil, err + } + c.stop = newStop(ctx) + c.stop.gui = true + c.stop.run = func(debugLevel string, host string, port string, password string, banner ...string) (err error) { + return tcp.RunCtx(c.stop.ctx, debugLevel, host, port, password, banner...) + } + c.stop.hash = func(fname string, algorithm string, showProgress ...bool) (hash256 []byte, err error) { + return utils.HashFileCtx(c.stop.ctx, fname, algorithm, showProgress...) + } + + go func() { + select { + case <-ctx.Done(): + log.Trace("parent context canceled") + c.SendError() + case <-c.stopChan: + // for stop goroutine + } + log.Trace("croc NewCtx done") + }() + + return c, nil +} + +// ctxErr checks whether it is necessary to interrupt my loops and goroutines +func (s *stop) ctxErr() error { + select { + case <-s.ctx.Done(): + return s.ctx.Err() + default: + return nil + } +} + +// Cancel initiates interruption of my loops and goroutines +func (s *stop) Cancel() { + log.Trace("croc Cancel") + if s.cancel != nil { + s.cancel() + s.cancel = nil + } +} + +// SendError tells the peer to interrupt their loops and goroutines +func (c *Client) SendError() { + if c.Key != nil && len(c.conn) > 0 && c.conn[0] != nil { + message.Send(c.conn[0], c.Key, message.Message{ + Type: message.TypeError, + Message: "refusing files", + }) + time.Sleep(time.Millisecond) + } +} diff --git a/src/tcp/ctx.go b/src/tcp/ctx.go new file mode 100644 index 000000000..30661e680 --- /dev/null +++ b/src/tcp/ctx.go @@ -0,0 +1,69 @@ +// ctx.go +package tcp + +import ( + "context" + "errors" + "net" + "sync" + + log "github.com/schollz/logger" +) + +// stop manages graceful shutdown of the TCP server +type stop struct { + ctx context.Context + cancel context.CancelFunc + // Track connections + server net.Listener + wg sync.WaitGroup + gui bool +} + +// newStop creates a new stop manager +func newStop(ctx context.Context) *stop { + s := &stop{} + if ctx == nil { + ctx = context.Background() + } + s.ctx, s.cancel = context.WithCancel(ctx) + + return s +} + +// Cancel initiate graceful shutdown +func (s *stop) Cancel() { + log.Trace("tcp Cancel") + if s.cancel != nil { + s.cancel() + s.cancel = nil + } +} + +func RunCtx(ctx context.Context, debugLevel, host, port, password string, banner ...string) error { + return RunWithOptionsAsync(host, port, password, WithBanner(banner...), WithLogLevel(debugLevel), WithCtx(ctx)) +} + +func WithCtx(ctx context.Context) serverOptsFunc { + return func(s *server) error { + if s.stop.cancel != nil { + s.stop.cancel() + } + s.stop = newStop(ctx) + s.stop.gui = true + return nil + } +} + +// Ignore context cancellation error +func Ignore(err error) error { + if err != nil && (errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) || + // ignore Listener closed during cancellation + // strings.Contains(err.Error(), "use of closed network connection") || + errors.Is(err, net.ErrClosed)) { + log.Tracef("ignored: %v", err) + return nil + } + return err +} diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 177bf20be..5cd090773 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -2,6 +2,7 @@ package tcp import ( "bytes" + "context" "fmt" "net" "strings" @@ -27,7 +28,9 @@ type server struct { roomCleanupInterval time.Duration roomTTL time.Duration - stopRoomCleanup chan struct{} + // stopRoomCleanup chan struct{} + // replaced by stop ctx.go + *stop } type roomInfo struct { @@ -50,7 +53,8 @@ func newDefaultServer() *server { s.roomCleanupInterval = DEFAULT_ROOM_CLEANUP_INTERVAL s.roomTTL = DEFAULT_ROOM_TTL s.debugLevel = DEFAULT_LOG_LEVEL - s.stopRoomCleanup = make(chan struct{}) + // s.stopRoomCleanup = make(chan struct{}) replaced by stop + s.stop = newStop(context.Background()) return s } @@ -74,27 +78,38 @@ func Run(debugLevel, host, port, password string, banner ...string) (err error) return RunWithOptionsAsync(host, port, password, WithBanner(banner...), WithLogLevel(debugLevel)) } -func (s *server) start() (err error) { - log.SetLevel(s.debugLevel) - - // Mask our password in logs - maskedPassword := "" - if len(s.password) > 2 { - maskedPassword = fmt.Sprintf("%c***%c", s.password[0], s.password[len(s.password)-1]) +// Mask our password in logs +func maskedPassword(password string) (s string) { + if len(password) > 2 { + s = fmt.Sprintf("%c***%c", password[0], password[len(password)-1]) } else { - maskedPassword = s.password + s = password } + return +} - log.Debugf("starting with password '%s'", maskedPassword) +func (s *server) start() (err error) { + log.SetLevel(s.debugLevel) + + log.Debugf("starting with password '%s'", maskedPassword(s.password)) s.rooms.Lock() s.rooms.rooms = make(map[string]roomInfo) s.rooms.Unlock() - go s.deleteOldRooms() - defer s.stopRoomDeletion() + s.stop.wg.Add(1) + go func() { + defer s.stop.wg.Done() + s.deleteOldRooms() + }() + // defer s.stopRoomDeletion() + defer s.stop.Cancel() + if s.stop.gui { + defer s.stop.wg.Wait() + } err = s.run() + err = Ignore(err) if err != nil { log.Error(err) } @@ -124,19 +139,36 @@ func (s *server) run() (err error) { } addr = strings.Replace(addr, "127.0.0.1", "0.0.0.0", 1) log.Infof("starting TCP server on %s", addr) - server, err := net.Listen(network, addr) + lc := net.ListenConfig{} + s.stop.server, err = lc.Listen(s.stop.ctx, network, addr) if err != nil { return fmt.Errorf("error listening on %s: %w", addr, err) } - defer server.Close() + defer s.stop.server.Close() + + go func() { + dc := &net.Dialer{ + Timeout: 100 * time.Millisecond, + } + if conn, err := dc.DialContext(s.stop.ctx, network, addr); err == nil { + log.Debugf("started TCP server on %s", addr) + conn.Close() + } else { + log.Errorf("started TCP server on %s : %v", addr, err) + s.stop.Cancel() + } + }() + // spawn a new goroutine whenever a client connects for { - connection, err := server.Accept() + connection, err := s.stop.server.Accept() if err != nil { return fmt.Errorf("problem accepting connection: %w", err) } log.Debugf("client %s connected", connection.RemoteAddr().String()) + s.stop.wg.Add(1) go func(connection net.Conn) { + defer s.stop.wg.Done() c := comm.New(connection) room, errCommunication := s.clientCommunication(c) log.Debugf("room: %+v", room) @@ -151,9 +183,11 @@ func (s *server) run() (err error) { connection.Close() return } + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() for { // check connection - log.Debugf("checking connection of room %s for %+v", room, c) + log.Tracef("checking connection of room %s for %+v", room, c) deleteIt := false s.rooms.Lock() roomData, ok := s.rooms.rooms[room] @@ -162,14 +196,14 @@ func (s *server) run() (err error) { s.rooms.Unlock() return } - log.Debugf("room: %+v", roomData) + log.Tracef("room: %+v", roomData) if roomData.first != nil && roomData.second != nil { log.Debug("rooms ready") s.rooms.Unlock() break } - if s.rooms.rooms[room].first != nil { - errSend := s.rooms.rooms[room].first.Send([]byte{1}) + if roomData.first != nil { + errSend := roomData.first.Send([]byte{1}) if errSend != nil { log.Debug(errSend) deleteIt = true @@ -180,7 +214,14 @@ func (s *server) run() (err error) { s.deleteRoom(room) break } - time.Sleep(1 * time.Second) + select { + case <-s.stop.ctx.Done(): + log.Tracef("check: %v", s.stop.ctx.Err()) + s.deleteRoom(room) + return + case <-ticker.C: + // time.Sleep(1 * time.Second) + } } }(connection) } @@ -190,34 +231,47 @@ func (s *server) run() (err error) { // have exceeded their allocated TTL. func (s *server) deleteOldRooms() { ticker := time.NewTicker(s.roomCleanupInterval) - for { + defer func() { + ticker.Stop() + log.Debug("room cleanup stopped") + }() + for next := true; next; { + roomsToDelete := []string{} select { case <-ticker.C: - var roomsToDelete []string s.rooms.Lock() - for room := range s.rooms.rooms { - if time.Since(s.rooms.rooms[room].opened) > s.roomTTL { + for room, roomData := range s.rooms.rooms { + if time.Since(roomData.opened) > s.roomTTL { roomsToDelete = append(roomsToDelete, room) } } s.rooms.Unlock() - - for _, room := range roomsToDelete { - s.deleteRoom(room) - log.Debugf("room cleaned up: %s", room) + case <-s.stop.ctx.Done(): + if s.server != nil { + log.Debugf("stop TCP server on %s", s.server.Addr()) + s.server.Close() + time.Sleep(time.Millisecond) } - case <-s.stopRoomCleanup: - ticker.Stop() - log.Debug("room cleanup stopped") - return + log.Debug("stop room cleanup fired") + s.rooms.Lock() + for room := range s.rooms.rooms { + roomsToDelete = append(roomsToDelete, room) + } + s.rooms.Unlock() + next = false + } + for _, room := range roomsToDelete { + s.deleteRoom(room) + log.Debugf("room cleaned up: %s", room) } } } -func (s *server) stopRoomDeletion() { - log.Debug("stop room cleanup fired") - s.stopRoomCleanup <- struct{}{} -} +// replaced by stop +// func (s *server) stopRoomDeletion() { +// log.Debug("stop room cleanup fired") +// s.stopRoomCleanup <- struct{}{} +// } var weakKey = []byte{1, 2, 3} @@ -533,7 +587,7 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati return } - log.Debug("sending password") + log.Debugf("sending password '%s'", maskedPassword(password)) bSend, err := crypt.Encrypt([]byte(password), strongKeyForEncryption) if err != nil { log.Debug(err) diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index 9c9859940..5ebbd820c 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -2,6 +2,7 @@ package tcp import ( "bytes" + "context" "fmt" "testing" "time" @@ -17,14 +18,20 @@ func BenchmarkConnection(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { c, _, _, _ := ConnectToTCPServer("127.0.0.1:8283", "pass123", fmt.Sprintf("testroom%d", i), 1*time.Minute) - c.Close() + if c != nil { + c.Close() + } } } func TestTCP(t *testing.T) { log.SetLevel("error") timeToRoomDeletion := 100 * time.Millisecond - go RunWithOptionsAsync("127.0.0.1", "8381", "pass123", WithBanner("8382"), WithLogLevel("debug"), WithRoomTTL(timeToRoomDeletion)) + go RunWithOptionsAsync("127.0.0.1", "8381", "pass123", + WithBanner("8382"), + WithLogLevel("debug"), + WithRoomTTL(timeToRoomDeletion)) + time.Sleep(timeToRoomDeletion) err := PingServer("127.0.0.1:8381") assert.Nil(t, err) @@ -69,3 +76,233 @@ func TestTCP(t *testing.T) { c1.Close() time.Sleep(300 * time.Millisecond) } + +func TestTCPctx(t *testing.T) { + log.SetLevel("error") + // Set short room TTL for testing cleanup + timeToRoomDeletion := 100 * time.Millisecond + + // Create cancelable context + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Start server with custom options + go RunWithOptionsAsync("127.0.0.1", "8381", "pass123", + WithBanner("8382"), + WithLogLevel("debug"), + WithRoomTTL(timeToRoomDeletion), + WithCtx(ctx), + ) + + time.Sleep(timeToRoomDeletion) + + // Test ping to running server + err := PingServer("127.0.0.1:8381") + assert.Nil(t, err) + + // Test ping to non-existent server + err = PingServer("127.0.0.1:8333") + assert.NotNil(t, err) + + time.Sleep(timeToRoomDeletion) + + // Connect first client to room + c1, banner, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", 1*time.Minute) + assert.Equal(t, banner, "8382") + assert.Nil(t, err) + + // Connect second client to same room + c2, _, _, err := ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom") + assert.Nil(t, err) + + // Third client should fail - room is full + _, _, _, err = ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom") + assert.NotNil(t, err) + + // Connection with very short timeout should fail + _, _, _, err = ConnectToTCPServer("127.0.0.1:8381", "pass123", "testRoom", 1*time.Nanosecond) + assert.NotNil(t, err) + + // Test data exchange between clients + // Send from c1 to c2 + assert.Nil(t, c1.Send([]byte("hello, c2"))) + var data []byte + for { + data, err = c2.Receive() + if bytes.Equal(data, []byte{1}) { + continue // Skip heartbeat + } + break + } + assert.Nil(t, err) + assert.Equal(t, []byte("hello, c2"), data) + + // Send from c2 to c1 + assert.Nil(t, c2.Send([]byte("hello, c1"))) + for { + data, err = c1.Receive() + if bytes.Equal(data, []byte{1}) { + continue // Skip heartbeat + } + break + } + assert.Nil(t, err) + assert.Equal(t, []byte("hello, c1"), data) + + // Close server + cancel() + + // Test ping to non-existent server + err = PingServer("127.0.0.1:8331") + assert.NotNil(t, err) + + time.Sleep(300 * time.Millisecond) +} + +func TestWrongPassword(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8385", "pass123", "8386") + time.Sleep(100 * time.Millisecond) + + // Attempt to connect with wrong password + _, _, _, err := ConnectToTCPServer("127.0.0.1:8385", "wrongpass", "testRoom") + assert.NotNil(t, err) + assert.Contains(t, err.Error(), "bad password") +} + +func TestRoomIsolation(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8387", "pass123", "8388") + time.Sleep(100 * time.Millisecond) + + // Room 1 + c1, _, _, _ := ConnectToTCPServer("127.0.0.1:8387", "pass123", "room1") + c2, _, _, _ := ConnectToTCPServer("127.0.0.1:8387", "pass123", "room1") + + // Room 2 + c3, _, _, _ := ConnectToTCPServer("127.0.0.1:8387", "pass123", "room2") + c4, _, _, _ := ConnectToTCPServer("127.0.0.1:8387", "pass123", "room2") + + // Send data in different rooms + c1.Send([]byte("to_room_1")) + c3.Send([]byte("to_room_2")) + + // Verify reception + var data []byte + + // c2 should receive message from room1 + for { + data, _ = c2.Receive() + if bytes.Equal(data, []byte{1}) { + continue + } + break + } + assert.Equal(t, []byte("to_room_1"), data) + + // c4 should receive message from room2 + for { + data, _ = c4.Receive() + if bytes.Equal(data, []byte{1}) { + continue + } + break + } + assert.Equal(t, []byte("to_room_2"), data) + + c1.Close() + c2.Close() + c3.Close() + c4.Close() +} + +func TestRoomRecreationAfterTTL(t *testing.T) { + log.SetLevel("error") + shortTTL := 50 * time.Millisecond + + go RunWithOptionsAsync("127.0.0.1", "8389", "pass123", + WithRoomTTL(shortTTL), + WithLogLevel("error")) + time.Sleep(100 * time.Millisecond) + + roomName := "testRoomRecreate" + + // 1. Create a room + c1, _, _, _ := ConnectToTCPServer("127.0.0.1:8389", "pass123", roomName) + assert.NotNil(t, c1) + + // 2. Close first client, room becomes empty + c1.Close() + + // 3. Wait for room cleanup (TTL + buffer) + time.Sleep(shortTTL + 50*time.Millisecond) + + // 4. Try to connect to the same room again. + // If room wasn't deleted, we might get "room full" or weird behavior. + // If deleted — connection should succeed as the first client. + c3, _, _, err := ConnectToTCPServer("127.0.0.1:8389", "pass123", roomName) + assert.Nil(t, err) + assert.NotNil(t, c3) + + if c3 != nil { + c3.Close() + } +} + +func TestLargeDataTransfer(t *testing.T) { + log.SetLevel("error") + go Run("debug", "127.0.0.1", "8391", "pass123", "8392") + time.Sleep(100 * time.Millisecond) + + c1, _, _, _ := ConnectToTCPServer("127.0.0.1:8391", "pass123", "bigRoom") + c2, _, _, _ := ConnectToTCPServer("127.0.0.1:8391", "pass123", "bigRoom") + + // Generate data larger than standard buffer (e.g., 1 MB) + largeData := make([]byte, 1024*1024) + for i := range largeData { + largeData[i] = byte(i % 256) + } + + err := c1.Send(largeData) + assert.Nil(t, err) + + var received []byte + // Receive data, as it might arrive in chunks (though chanFromConn buffers it) + // In this case pipe passes full Read packets, but for safety let's verify tail + for { + data, err := c2.Receive() + if bytes.Equal(data, []byte{1}) { + continue + } + assert.Nil(t, err) + received = data + break + } + + assert.True(t, bytes.Equal(largeData, received), "Large data mismatch") + + c1.Close() + c2.Close() +} + +func TestServerReleasesPort(t *testing.T) { + log.SetLevel("trace") + host := "127.0.0.1" + port := "8394" + + // 1. Start and automatically stop first server using timeout + // RunCtx blocks the execution, so we don't need 'go' or channels + ctx1, cancel1 := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel1() + + err := RunCtx(ctx1, "trace", host, port, "pass123") + assert.Nil(t, err, "First server should stop gracefully") + + // 2. Try to start second server on the same port immediately + // If port is not released, this will fail with "address already in use" + ctx2, cancel2 := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel2() + + err = RunCtx(ctx2, "trace", host, port, "pass123") + assert.Nil(t, err, "Second server should start (port was released)") +} diff --git a/src/utils/ctx.go b/src/utils/ctx.go new file mode 100644 index 000000000..07b560ce0 --- /dev/null +++ b/src/utils/ctx.go @@ -0,0 +1,281 @@ +// ctx.go +package utils + +import ( + "context" + "crypto/md5" + "encoding/hex" + "fmt" + "io" + "os" + "path" + "time" + + "github.com/cespare/xxhash/v2" + "github.com/minio/highwayhash" + "github.com/schollz/progressbar/v3" +) + +// ctxFile wraps os.File with context cancellation support. +type ctxFile struct { + ctx context.Context + f *os.File +} + +// NewCtxFile creates a new context-aware file wrapper. +func NewCtxFile(ctx context.Context, f *os.File) *ctxFile { + return &ctxFile{ctx: ctx, f: f} +} + +// Read implements io.Reader interface with context cancellation. +func (c *ctxFile) Read(p []byte) (n int, err error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + n, err = c.f.Read(p) + if c.ctx.Err() != nil { + return 0, c.ctx.Err() + } + return n, err + } +} + +// ReadAt implements io.ReaderAt interface with context cancellation. +func (c *ctxFile) ReadAt(p []byte, off int64) (n int, err error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + n, err = c.f.ReadAt(p, off) + if c.ctx.Err() != nil { + return 0, c.ctx.Err() + } + return n, err + } +} + +// Seek implements io.Seeker interface with context cancellation. +func (c *ctxFile) Seek(offset int64, whence int) (n int64, err error) { + select { + case <-c.ctx.Done(): + return 0, c.ctx.Err() + default: + n, err = c.f.Seek(offset, whence) + if c.ctx.Err() != nil { + return 0, c.ctx.Err() + } + return n, err + } +} + +// HashFileCtx returns the hash of a file with context cancellation support. +func HashFileCtx(ctx context.Context, fname string, algorithm string, showProgress ...bool) ([]byte, error) { + // Quick context check before starting + if err := ctx.Err(); err != nil { + return nil, err + } + + fstats, err := os.Lstat(fname) + if err != nil { + return nil, err + } + + // Handle symlinks - quick operation, no context needed + if fstats.Mode()&os.ModeSymlink != 0 { + target, err := os.Readlink(fname) + if err != nil { + return nil, err + } + return []byte(SHA256(target)), nil + } + + f, err := os.Open(fname) + if err != nil { + return nil, err + } + defer f.Close() + + // Get file info for size (now file is opened, following symlinks if any) + fi, err := f.Stat() + if err != nil { + return nil, err + } + + // Wrap the file with context support + cf := NewCtxFile(ctx, f) + sr := io.NewSectionReader(cf, 0, fi.Size()) + + // Parse showProgress parameter + doShowProgress := false + if len(showProgress) > 0 { + doShowProgress = showProgress[0] + } + + // Create progress bar based on algorithm + var bar *progressbar.ProgressBar + if doShowProgress { + fnameShort := path.Base(fname) + if len(fnameShort) > 20 { + fnameShort = fnameShort[:20] + "..." + } + + if algorithm == "imohash" { + // Spinner for imohash (indeterminate progress, max = -1) + bar = progressbar.NewOptions64(-1, + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionShowBytes(false), + progressbar.OptionSetDescription(fmt.Sprintf("Sampling %s", fnameShort)), + progressbar.OptionClearOnFinish(), + progressbar.OptionFullWidth(), + progressbar.OptionShowElapsedTimeOnFinish(), + progressbar.OptionSpinnerType(14), + progressbar.OptionSetSpinnerChangeInterval(100*time.Millisecond), + ) + } else { + // Regular progress bar for other algorithms + bar = progressbar.NewOptions64(fi.Size(), + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionShowBytes(true), + progressbar.OptionSetDescription(fmt.Sprintf("Hashing %s", fnameShort)), + progressbar.OptionClearOnFinish(), + progressbar.OptionFullWidth(), + ) + } + } + + // Dispatch to appropriate hash function + switch algorithm { + case "imohash": + return IMOHashReader(sr, bar) + case "md5": + return MD5HashReader(sr, bar) + case "xxhash": + return XXHashReader(sr, bar) + case "highway": + return HighwayHashReader(sr, bar) + default: + return nil, fmt.Errorf("unsupported algorithm: %s", algorithm) + } +} + +// IMOHashReader returns imohash for a SectionReader. +// Uses spinner progress bar for indeterminate progress. +func IMOHashReader(sr *io.SectionReader, bar *progressbar.ProgressBar) ([]byte, error) { + // Start spinner if provided + if bar != nil { + // Add(0) triggers initial render for spinner + bar.Add(0) + } + + b, err := imopartial.SumSectionReader(sr) + if err != nil { + // If there's an error, finish the bar to clean up display + if bar != nil { + bar.Exit() + } + return nil, err + } + + // Finish the progress bar + if bar != nil { + bar.Finish() + } + + return b[:], nil +} + +// IMOHashReaderFull returns full imohash (no sampling) for a SectionReader. +func IMOHashReaderFull(sr *io.SectionReader, bar *progressbar.ProgressBar) ([]byte, error) { + // For full imohash (which reads entire file), use regular progress bar logic + if bar != nil { + bar.Add(0) // Start the spinner + } + + b, err := imofull.SumSectionReader(sr) + if err != nil { + if bar != nil { + bar.Exit() + } + return nil, err + } + + if bar != nil { + bar.Finish() + } + + return b[:], nil +} + +// MD5HashReader returns MD5 hash for a SectionReader. +func MD5HashReader(sr *io.SectionReader, bar *progressbar.ProgressBar) ([]byte, error) { + // Reset to beginning + if _, err := sr.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + h := md5.New() + if bar != nil { + // Copy with progress tracking (like original code) + if _, err := io.Copy(io.MultiWriter(h, bar), sr); err != nil { + return nil, err + } + } else { + if _, err := io.Copy(h, sr); err != nil { + return nil, err + } + } + return h.Sum(nil), nil +} + +// XXHashReader returns xxhash for a SectionReader. +func XXHashReader(sr *io.SectionReader, bar *progressbar.ProgressBar) ([]byte, error) { + if _, err := sr.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + h := xxhash.New() + if bar != nil { + if _, err := io.Copy(io.MultiWriter(h, bar), sr); err != nil { + return nil, err + } + } else { + if _, err := io.Copy(h, sr); err != nil { + return nil, err + } + } + return h.Sum(nil), nil +} + +// HighwayHashReader returns highwayhash for a SectionReader. +func HighwayHashReader(sr *io.SectionReader, bar *progressbar.ProgressBar) ([]byte, error) { + if _, err := sr.Seek(0, io.SeekStart); err != nil { + return nil, err + } + + key, err := hex.DecodeString("1553c5383fb0b86578c3310da665b4f6e0521acf22eb58a99532ffed02a6b115") + if err != nil { + return nil, err + } + + h, err := highwayhash.New(key) + if err != nil { + return nil, fmt.Errorf("could not create highwayhash: %w", err) + } + + if bar != nil { + if _, err := io.Copy(io.MultiWriter(h, bar), sr); err != nil { + return nil, err + } + } else { + if _, err := io.Copy(h, sr); err != nil { + return nil, err + } + } + return h.Sum(nil), nil +} + +// Helper function to update existing HashFile to use HashFileCtx +// func HashFile(fname string, algorithm string, showProgress ...bool) ([]byte, error) { +// return HashFileCtx(context.Background(), fname, algorithm, showProgress...) +// } diff --git a/src/utils/utils.go b/src/utils/utils.go index a3870cf35..22d30dba3 100644 --- a/src/utils/utils.go +++ b/src/utils/utils.go @@ -415,6 +415,9 @@ func ChunkRangesToChunks(chunkRanges []int64) (chunks []int64) { func GetLocalIPs() (ips []string, err error) { addrs, err := net.InterfaceAddrs() if err != nil { + if ip := LocalIP(); ip != "" { + return []string{ip}, nil + } return } ips = []string{} diff --git a/src/utils/utils_test.go b/src/utils/utils_test.go index 5d51dc763..6f9cc7fe3 100644 --- a/src/utils/utils_test.go +++ b/src/utils/utils_test.go @@ -3,6 +3,7 @@ package utils import ( "archive/zip" "bytes" + "context" "fmt" "log" "math/rand" @@ -658,3 +659,316 @@ func verifyFileModTime(t *testing.T, filePath string, expectedTime time.Time) { filePath, expected, actual) } } + +// TestHashFileCtxNoCancellation tests HashFileCtx without cancellation +func TestHashFileCtxNoCancellation(t *testing.T) { + // Use the same bigFile() function as other tests + bigFile() + defer os.Remove("bigfile.test") + + ctx := context.Background() + + // Test each algorithm - using the same expected values from existing tests + tests := []struct { + name string + algorithm string + wantHash string + }{ + { + name: "MD5 hash", + algorithm: "md5", + wantHash: "8304ff018e02baad0e3555bade29a405", // From TestMD5HashFile + }, + { + name: "XXHash", + algorithm: "xxhash", + wantHash: "4918740eb5ccb6f7", // From TestXXHashFile + }, + { + name: "imohash", + algorithm: "imohash", + wantHash: "c0d1e12301e6c635f6d4a8ea5c897437", // From TestIMOHashFile + }, + { + name: "highway", + algorithm: "highway", + wantHash: "3c32999529323ed66a67aeac5720c7bf1301dcc5dca87d8d46595e85ff990329", // From TestHighwayHashFile + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test without progress bar + hash, err := HashFileCtx(ctx, "bigfile.test", tt.algorithm) + assert.NoError(t, err, "HashFileCtx should not return error") + assert.Equal(t, tt.wantHash, fmt.Sprintf("%x", hash), + "Hash should match for algorithm %s", tt.algorithm) + + // Test with progress bar (false) + hash, err = HashFileCtx(ctx, "bigfile.test", tt.algorithm, false) + assert.NoError(t, err, "HashFileCtx with showProgress=false should not return error") + assert.Equal(t, tt.wantHash, fmt.Sprintf("%x", hash), + "Hash should match for algorithm %s with showProgress=false", tt.algorithm) + + // Test with progress bar (true) - only for non-imohash to avoid spinner issues in tests + if tt.algorithm != "imohash" { + hash, err = HashFileCtx(ctx, "bigfile.test", tt.algorithm, true) + assert.NoError(t, err, "HashFileCtx with showProgress=true should not return error") + assert.Equal(t, tt.wantHash, fmt.Sprintf("%x", hash), + "Hash should match for algorithm %s with showProgress=true", tt.algorithm) + } + }) + } + + // Test symlink handling - match original behavior + t.Run("Symlink handling", func(t *testing.T) { + // Create symlink to bigfile.test + symlinkPath := "bigfile.test.symlink" + defer os.Remove(symlinkPath) + + err := os.Symlink("bigfile.test", symlinkPath) + if err != nil && strings.Contains(err.Error(), "privilege") { + t.Skip("Skipping symlink test - requires privilege") + } + assert.NoError(t, err, "Should create symlink") + + // Hash the symlink + hash, err := HashFileCtx(ctx, symlinkPath, "md5") + assert.NoError(t, err, "Should hash symlink target path") + assert.NotNil(t, hash, "Should return hash for symlink") + + // The original HashFile returns []byte(SHA256(target)) + // SHA256("bigfile.test") = "3ae29e98bba80ccefc79289c59cc34cb7223954310bb61c6a26147bb9b08c4e4" + // []byte("3ae29e98...") = ASCII bytes of hex string + + // When converted back with fmt.Sprintf("%x", hash): + // ASCII '3' = 0x33, 'a' = 0x61, 'e' = 0x65, '2' = 0x32, etc. + // So fmt.Sprintf("%x", []byte("3ae2...")) = "33616532..." + + actualHex := fmt.Sprintf("%x", hash) + + // Let's compute what we SHOULD get: + targetPath := "bigfile.test" + expectedSHA256Hex := SHA256(targetPath) // "3ae29e98..." + expectedBytes := []byte(expectedSHA256Hex) + expectedResultHex := fmt.Sprintf("%x", expectedBytes) // hex of ASCII bytes + + // Debug + t.Logf("Target path: '%s'", targetPath) + t.Logf("SHA256(target) hex: %s", expectedSHA256Hex) + t.Logf("Expected result (hex of ASCII bytes): %s", expectedResultHex) + t.Logf("Actual result: %s", actualHex) + + // They should match! + assert.Equal(t, expectedResultHex, actualHex, + "HashFileCtx should behave exactly like HashFile for symlinks") + + // Also test with original HashFile to ensure consistency + originalHash, err := HashFile(symlinkPath, "md5") + assert.NoError(t, err) + originalHex := fmt.Sprintf("%x", originalHash) + + assert.Equal(t, originalHex, actualHex, + "HashFileCtx should return same result as HashFile for symlinks") + }) + // Test error cases + t.Run("Error cases", func(t *testing.T) { + // Non-existent file + hash, err := HashFileCtx(ctx, "non_existent_file_12345.test", "md5") + assert.Error(t, err, "Should return error for non-existent file") + assert.Nil(t, hash, "Hash should be nil on error") + + // Unsupported algorithm + hash, err = HashFileCtx(ctx, "bigfile.test", "unsupported_algo") + assert.Error(t, err, "Should return error for unsupported algorithm") + assert.Contains(t, err.Error(), "unsupported algorithm") + assert.Nil(t, hash, "Hash should be nil on error") + }) +} + +// TestHashFileCtxWithCancellation tests HashFileCtx with context cancellation +func TestHashFileCtxWithCancellation(t *testing.T) { + // Use the same bigFile() function + bigFile() + defer os.Remove("bigfile.test") + + // Test 1: Cancel before starting + t.Run("Cancel before start", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + hash, err := HashFileCtx(ctx, "bigfile.test", "md5") + assert.Error(t, err, "Should return error when context cancelled before start") + assert.Equal(t, context.Canceled, err, "Error should be context.Canceled") + assert.Nil(t, hash, "Hash should be nil when cancelled") + }) + + // Test 2: Cancel during operation + t.Run("Cancel during operation", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + // Start hash operation in goroutine + errCh := make(chan error, 1) + hashCh := make(chan []byte, 1) + + go func() { + hash, err := HashFileCtx(ctx, "bigfile.test", "md5", false) + if err != nil { + errCh <- err + hashCh <- nil + } else { + errCh <- nil + hashCh <- hash + } + }() + + // Cancel after a short delay + time.Sleep(10 * time.Millisecond) + cancel() + + // Wait for result + select { + case err := <-errCh: + hash := <-hashCh + // Either we got an error (cancelled) or a hash (completed before cancellation) + if err != nil { + // Check if it's a context error + if err == context.Canceled || err == context.DeadlineExceeded { + assert.Error(t, err, "Should return context error when cancelled") + } + assert.Nil(t, hash, "Hash should be nil when cancelled") + } else { + // Completed successfully before cancellation + assert.NotNil(t, hash, "If not cancelled, should return hash") + assert.Equal(t, 16, len(hash), "MD5 hash should be 16 bytes") + // Verify it's the correct hash + assert.Equal(t, "8304ff018e02baad0e3555bade29a405", fmt.Sprintf("%x", hash)) + } + case <-time.After(5 * time.Second): + t.Fatal("Test timed out") + } + }) + + // Test 3: Cancel with deadline + t.Run("Cancel with deadline", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + // For a 75MB file, MD5 should take more than 1ms + hash, err := HashFileCtx(ctx, "bigfile.test", "md5", false) + assert.Error(t, err, "Should timeout for 75MB file with 1ms deadline") + assert.Equal(t, context.DeadlineExceeded, err, "Error should be context.DeadlineExceeded") + assert.Nil(t, hash, "Hash should be nil when deadline exceeded") + }) + + // Test 4: Imohash should be fast enough to complete before cancellation + t.Run("Imohash fast completion", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Imohash samples the file, so it should complete quickly + hash, err := HashFileCtx(ctx, "bigfile.test", "imohash", false) + assert.NoError(t, err, "Imohash should complete before any cancellation") + assert.NotNil(t, hash, "Should return hash for imohash") + assert.Equal(t, 16, len(hash), "Imohash should be 16 bytes") + // Verify it's the correct hash + assert.Equal(t, "c0d1e12301e6c635f6d4a8ea5c897437", fmt.Sprintf("%x", hash)) + }) +} + +// TestHashFileCtxEquivalence tests that HashFileCtx produces same results as original HashFile +func TestHashFileCtxEquivalence(t *testing.T) { + // Use bigFile() for consistency + bigFile() + defer os.Remove("bigfile.test") + + algorithms := []string{"md5", "xxhash", "imohash", "highway"} + + for _, algorithm := range algorithms { + t.Run(algorithm, func(t *testing.T) { + // Get hash using original HashFile + originalHash, err1 := HashFile("bigfile.test", algorithm) + + // Get hash using HashFileCtx with background context + ctxHash, err2 := HashFileCtx(context.Background(), "bigfile.test", algorithm) + + // Both should succeed or fail together + if err1 != nil { + assert.Error(t, err2, "HashFileCtx should also fail if HashFile fails") + t.Logf("Both failed as expected: %v", err1) + } else { + assert.NoError(t, err2, "HashFileCtx should not fail if HashFile succeeds") + assert.NotNil(t, originalHash, "Original hash should not be nil") + assert.NotNil(t, ctxHash, "Context hash should not be nil") + + // Compare hex representations + originalHex := fmt.Sprintf("%x", originalHash) + ctxHex := fmt.Sprintf("%x", ctxHash) + assert.Equal(t, originalHex, ctxHex, + "HashFile and HashFileCtx should produce same hash for algorithm %s. Got %s vs %s", + algorithm, originalHex, ctxHex) + + // Also verify against known values from existing tests + switch algorithm { + case "md5": + assert.Equal(t, "8304ff018e02baad0e3555bade29a405", originalHex) + case "xxhash": + assert.Equal(t, "4918740eb5ccb6f7", originalHex) + case "imohash": + assert.Equal(t, "c0d1e12301e6c635f6d4a8ea5c897437", originalHex) + case "highway": + assert.Equal(t, "3c32999529323ed66a67aeac5720c7bf1301dcc5dca87d8d46595e85ff990329", originalHex) + } + } + }) + } +} + +// TestHashFileCtxLargeFile tests with larger files (already using bigfile.test) +func TestHashFileCtxLargeFile(t *testing.T) { + // Skip in short mode + if testing.Short() { + t.Skip("Skipping large file test in short mode") + } + + // Use bigFile() + bigFile() + defer os.Remove("bigfile.test") + + ctx := context.Background() + + // Test each algorithm with large file + algorithms := []string{"md5", "xxhash", "imohash", "highway"} + + for _, algorithm := range algorithms { + t.Run(algorithm, func(t *testing.T) { + hash, err := HashFileCtx(ctx, "bigfile.test", algorithm, false) + assert.NoError(t, err, "Should hash large file with algorithm %s", algorithm) + assert.NotNil(t, hash, "Should return hash for large file") + + // Verify hash size + switch algorithm { + case "md5": + assert.Equal(t, 16, len(hash), "MD5 should be 16 bytes") + case "xxhash": + assert.Equal(t, 8, len(hash), "XXHash should be 8 bytes") + case "imohash": + assert.Equal(t, 16, len(hash), "Imohash should be 16 bytes") + case "highway": + assert.Equal(t, 32, len(hash), "HighwayHash should be 32 bytes") + } + + // Verify against known values + switch algorithm { + case "md5": + assert.Equal(t, "8304ff018e02baad0e3555bade29a405", fmt.Sprintf("%x", hash)) + case "xxhash": + assert.Equal(t, "4918740eb5ccb6f7", fmt.Sprintf("%x", hash)) + case "imohash": + assert.Equal(t, "c0d1e12301e6c635f6d4a8ea5c897437", fmt.Sprintf("%x", hash)) + case "highway": + assert.Equal(t, "3c32999529323ed66a67aeac5720c7bf1301dcc5dca87d8d46595e85ff990329", fmt.Sprintf("%x", hash)) + } + }) + } +}