Skip to content

Commit e1e5be4

Browse files
committed
test: add coverage for dedup, collections, identities
Table-driven tests for dedup engine, collections CRUD, identity discovery, and source filter helpers. Incorporates review findings.
1 parent 93684d9 · commit e1e5be4

13 files changed

Lines changed: 516 additions & 47 deletions

File tree

cmd/msgvault/cmd/collections.go

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,6 @@ func runCollectionsDelete(_ *cobra.Command, args []string) error {
191191
return nil
192192
}
193193

194-
195194
func resolveAccountList(st *store.Store, accounts string) ([]int64, error) {
196195
if accounts == "" {
197196
return nil, fmt.Errorf("--accounts is required")

cmd/msgvault/cmd/deduplicate.go

Lines changed: 36 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -62,8 +62,15 @@ func runDeduplicate(cmd *cobra.Command, _ []string) error {
6262
preference := dedup.DefaultSourcePreference
6363
if dedupPrefer != "" {
6464
preference = strings.Split(dedupPrefer, ",")
65+
known := make(map[string]bool, len(dedup.DefaultSourcePreference))
66+
for _, t := range dedup.DefaultSourcePreference {
67+
known[t] = true
68+
}
6569
for i := range preference {
6670
preference[i] = strings.TrimSpace(preference[i])
71+
if !known[preference[i]] {
72+
fmt.Fprintf(os.Stderr, "Warning: unknown source type in --prefer: %q\n", preference[i])
73+
}
6774
}
6875
}
6976

@@ -119,7 +126,7 @@ func runDeduplicate(cmd *cobra.Command, _ []string) error {
119126
}
120127

121128
if len(accountSourceIDs) == 0 {
122-
return runDeduplicatePerSource(cmd, st, config)
129+
return runDeduplicatePerSource(cmd, st, dbPath, config)
123130
}
124131

125132
return runDeduplicateOnce(cmd, dbPath, config, engine)
@@ -128,6 +135,7 @@ func runDeduplicate(cmd *cobra.Command, _ []string) error {
128135
func runDeduplicatePerSource(
129136
cmd *cobra.Command,
130137
st *store.Store,
138+
dbPath string,
131139
cfgBase dedup.Config,
132140
) error {
133141
sources, err := st.ListSources("")
@@ -144,6 +152,7 @@ func runDeduplicatePerSource(
144152
)
145153
fmt.Println()
146154

155+
backedUp := false
147156
anyRan := false
148157
for _, src := range sources {
149158
cfgScoped := cfgBase
@@ -186,8 +195,21 @@ func runDeduplicatePerSource(
186195
}
187196
}
188197

198+
if !backedUp && !dedupNoBackup {
199+
backedUp = true
200+
backupPath := fmt.Sprintf(
201+
"%s.dedup-backup-%s", dbPath,
202+
time.Now().Format("20060102-150405"),
203+
)
204+
fmt.Printf("Backing up database to %s...\n",
205+
filepath.Base(backupPath))
206+
if err := copyFileForBackup(dbPath, backupPath); err != nil {
207+
return fmt.Errorf("backup database: %w", err)
208+
}
209+
}
210+
189211
batchID := fmt.Sprintf(
190-
"dedup-%s", time.Now().Format("20060102-150405"),
212+
"dedup-%s-%d-%s", time.Now().Format("20060102-150405"), src.ID, src.Identifier,
191213
)
192214
summary, err := engineScoped.Execute(
193215
cmd.Context(), report, batchID,
@@ -308,21 +330,21 @@ func readDedupYesNo(cmd *cobra.Command) (bool, error) {
308330
}
309331

310332
func copyFileForBackup(src, dst string) error {
311-
in, err := os.Open(src)
312-
if err != nil {
333+
// Copy the main database file.
334+
if err := copyFile(src, dst); err != nil {
313335
return err
314336
}
315-
defer func() { _ = in.Close() }()
316-
317-
out, err := os.Create(dst)
318-
if err != nil {
319-
return err
320-
}
321-
if _, err := io.Copy(out, in); err != nil {
322-
_ = out.Close()
323-
return err
337+
// Also copy WAL and SHM files if they exist, so the backup is
338+
// consistent even when SQLite has uncheckpointed WAL pages.
339+
for _, suffix := range []string{"-wal", "-shm"} {
340+
extra := src + suffix
341+
if _, err := os.Stat(extra); err == nil {
342+
if err := copyFile(extra, dst+suffix); err != nil {
343+
return fmt.Errorf("copy %s: %w", suffix, err)
344+
}
345+
}
324346
}
325-
return out.Close()
347+
return nil
326348
}
327349

328350
func init() {

internal/dedup/dedup.go

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,8 @@ func (e *Engine) Scan(ctx context.Context) (*Report, error) {
206206
if count > 0 {
207207
e.logger.Info("backfilling rfc822_message_id from stored MIME",
208208
"count", count)
209-
backfilledCount, err = e.store.BackfillRFC822IDs(
209+
var backfillFailed int64
210+
backfilledCount, backfillFailed, err = e.store.BackfillRFC822IDs(
210211
func(done, total int64) {
211212
e.logger.Info("backfill progress",
212213
"done", done, "total", total)
@@ -219,6 +220,10 @@ func (e *Engine) Scan(ctx context.Context) (*Report, error) {
219220
e.logger.Info("backfilled rfc822_message_id",
220221
"count", backfilledCount)
221222
}
223+
if backfillFailed > 0 {
224+
e.logger.Warn("backfill: some messages could not be parsed",
225+
"failed", backfillFailed)
226+
}
222227
}
223228

224229
totalMessages, err := e.store.CountActiveMessages(
@@ -517,7 +522,14 @@ func normalizeRawMIME(raw []byte) []byte {
517522
}
518523

519524
headerSection := raw[:headerEnd]
520-
body := raw[headerEnd:]
525+
// Find the start of the actual body after the blank line.
526+
bodyStart := headerEnd
527+
if bytes.HasPrefix(raw[headerEnd:], []byte("\r\n\r\n")) {
528+
bodyStart = headerEnd + 4
529+
} else {
530+
bodyStart = headerEnd + 2 // "\n\n"
531+
}
532+
body := raw[bodyStart:]
521533

522534
// Copy headerSection before appending to avoid mutating the
523535
// underlying raw buffer (headerSection is a sub-slice of raw).
@@ -544,6 +556,7 @@ func normalizeRawMIME(raw []byte) []byte {
544556
fmt.Fprintf(&buf, "%s: %s\n", key, val)
545557
}
546558
}
559+
buf.WriteString("\n") // canonical header/body separator
547560
buf.Write(body)
548561
return buf.Bytes()
549562
}

internal/dedup/dedup_test.go

Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -364,3 +364,84 @@ func TestEngine_FormatMethodology_MentionsSentPolicy(t *testing.T) {
364364
t.Errorf("methodology missing cross-account guarantee")
365365
}
366366
}
367+
368+
func TestEngine_SurvivorTiebreakers(t *testing.T) {
369+
t.Run("raw MIME wins over no raw MIME", func(t *testing.T) {
370+
f := storetest.New(t)
371+
st := f.Store
372+
373+
idNoRaw := addMessage(t, st, f.Source, "no-raw", "rfc-raw-tie", false)
374+
idHasRaw := addMessage(t, st, f.Source, "has-raw", "rfc-raw-tie", false)
375+
testutil.MustNoErr(t,
376+
st.UpsertMessageRaw(idHasRaw, []byte("Subject: test\r\n\r\nBody")),
377+
"UpsertMessageRaw",
378+
)
379+
380+
eng := dedup.NewEngine(st, dedup.Config{
381+
AccountSourceIDs: []int64{f.Source.ID},
382+
Account: "test",
383+
}, nil)
384+
report, err := eng.Scan(context.Background())
385+
testutil.MustNoErr(t, err, "Scan")
386+
if report.DuplicateGroups != 1 {
387+
t.Fatalf("groups = %d, want 1", report.DuplicateGroups)
388+
}
389+
survivor := report.Groups[0].Messages[report.Groups[0].Survivor]
390+
if survivor.ID != idHasRaw {
391+
t.Errorf("survivor = %d, want %d (has raw)", survivor.ID, idHasRaw)
392+
}
393+
_ = idNoRaw
394+
})
395+
396+
t.Run("more labels wins when raw MIME is equal", func(t *testing.T) {
397+
f := storetest.New(t)
398+
st := f.Store
399+
400+
idFew := addMessage(t, st, f.Source, "few", "rfc-label-tie", false)
401+
idMany := addMessage(t, st, f.Source, "many", "rfc-label-tie", false)
402+
403+
lid1, _ := st.EnsureLabel(f.Source.ID, "L1", "Label1", "user")
404+
lid2, _ := st.EnsureLabel(f.Source.ID, "L2", "Label2", "user")
405+
lid3, _ := st.EnsureLabel(f.Source.ID, "L3", "Label3", "user")
406+
_ = st.LinkMessageLabel(idFew, lid1)
407+
_ = st.LinkMessageLabel(idMany, lid1)
408+
_ = st.LinkMessageLabel(idMany, lid2)
409+
_ = st.LinkMessageLabel(idMany, lid3)
410+
411+
eng := dedup.NewEngine(st, dedup.Config{
412+
AccountSourceIDs: []int64{f.Source.ID},
413+
Account: "test",
414+
}, nil)
415+
report, err := eng.Scan(context.Background())
416+
testutil.MustNoErr(t, err, "Scan")
417+
if report.DuplicateGroups != 1 {
418+
t.Fatalf("groups = %d, want 1", report.DuplicateGroups)
419+
}
420+
survivor := report.Groups[0].Messages[report.Groups[0].Survivor]
421+
if survivor.ID != idMany {
422+
t.Errorf("survivor = %d, want %d (more labels)", survivor.ID, idMany)
423+
}
424+
})
425+
426+
t.Run("lower ID wins as final tiebreaker", func(t *testing.T) {
427+
f := storetest.New(t)
428+
st := f.Store
429+
430+
idFirst := addMessage(t, st, f.Source, "first", "rfc-id-tie", false)
431+
_ = addMessage(t, st, f.Source, "second", "rfc-id-tie", false)
432+
433+
eng := dedup.NewEngine(st, dedup.Config{
434+
AccountSourceIDs: []int64{f.Source.ID},
435+
Account: "test",
436+
}, nil)
437+
report, err := eng.Scan(context.Background())
438+
testutil.MustNoErr(t, err, "Scan")
439+
if report.DuplicateGroups != 1 {
440+
t.Fatalf("groups = %d, want 1", report.DuplicateGroups)
441+
}
442+
survivor := report.Groups[0].Messages[report.Groups[0].Survivor]
443+
if survivor.ID != idFirst {
444+
t.Errorf("survivor = %d, want %d (lower ID)", survivor.ID, idFirst)
445+
}
446+
})
447+
}

internal/dedup/normalize_test.go

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,92 @@
1+
package dedup
2+
3+
import (
4+
"bytes"
5+
"testing"
6+
)
7+
8+
func TestNormalizeRawMIME(t *testing.T) {
9+
tests := []struct {
10+
name string
11+
input []byte
12+
wantSame bool // true if output should equal input
13+
contains string // substring the output must contain
14+
excludes string // substring the output must NOT contain
15+
}{
16+
{
17+
name: "strips Received header (CRLF)",
18+
input: []byte("Received: from mx1.google.com\r\nFrom: alice@example.com\r\nSubject: Hi\r\n\r\nBody"),
19+
contains: "From: alice@example.com",
20+
excludes: "Received",
21+
},
22+
{
23+
name: "strips multiple transport headers",
24+
input: []byte("Delivered-To: bob@example.com\r\nX-Gmail-Labels: INBOX\r\nAuthentication-Results: spf=pass\r\nFrom: alice@example.com\r\nSubject: Test\r\n\r\nBody"),
25+
contains: "From: alice@example.com",
26+
excludes: "Delivered-To",
27+
},
28+
{
29+
name: "preserves non-transport headers",
30+
input: []byte("From: alice@example.com\r\nTo: bob@example.com\r\nSubject: Meeting\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nBody text"),
31+
contains: "Subject: Meeting",
32+
},
33+
{
34+
name: "handles LF-only line endings",
35+
input: []byte("Received: from mx1\nFrom: alice@example.com\nSubject: Test\n\nBody with LF"),
36+
contains: "From: alice@example.com",
37+
excludes: "Received",
38+
},
39+
{
40+
name: "no header/body separator returns raw unchanged",
41+
input: []byte("This is just a blob of text with no headers"),
42+
wantSame: true,
43+
},
44+
{
45+
name: "empty body preserved",
46+
input: []byte("From: alice@example.com\r\nSubject: Empty\r\n\r\n"),
47+
contains: "Subject: Empty",
48+
},
49+
{
50+
name: "preserves body content exactly",
51+
input: []byte("Received: from mx1\r\nFrom: a@b.com\r\n\r\nExact body content here."),
52+
contains: "Exact body content here.",
53+
},
54+
}
55+
56+
for _, tt := range tests {
57+
t.Run(tt.name, func(t *testing.T) {
58+
inputCopy := make([]byte, len(tt.input))
59+
copy(inputCopy, tt.input)
60+
61+
result := normalizeRawMIME(tt.input)
62+
63+
if !bytes.Equal(tt.input, inputCopy) {
64+
t.Error("normalizeRawMIME mutated its input buffer")
65+
}
66+
67+
if tt.wantSame {
68+
if !bytes.Equal(result, tt.input) {
69+
t.Errorf("expected unchanged output, got:\n%s", result)
70+
}
71+
return
72+
}
73+
if tt.contains != "" && !bytes.Contains(result, []byte(tt.contains)) {
74+
t.Errorf("output missing %q:\n%s", tt.contains, result)
75+
}
76+
if tt.excludes != "" && bytes.Contains(result, []byte(tt.excludes)) {
77+
t.Errorf("output should not contain %q:\n%s", tt.excludes, result)
78+
}
79+
})
80+
}
81+
}
82+
83+
func TestNormalizeRawMIME_DeterministicOutput(t *testing.T) {
84+
raw1 := []byte("Received: from mx1.google.com\r\nFrom: sender@example.com\r\nSubject: Meeting\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nLet's meet at 3pm.")
85+
raw2 := []byte("Received: from mx2.google.com\r\nDelivered-To: other@example.com\r\nFrom: sender@example.com\r\nSubject: Meeting\r\nDate: Mon, 1 Jan 2024 12:00:00 +0000\r\n\r\nLet's meet at 3pm.")
86+
87+
hash1 := sha256Hex(normalizeRawMIME(raw1))
88+
hash2 := sha256Hex(normalizeRawMIME(raw2))
89+
if hash1 != hash2 {
90+
t.Errorf("same message with different transport headers produced different hashes")
91+
}
92+
}

internal/query/duckdb.go

Lines changed: 4 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -682,10 +682,7 @@ func (e *DuckDBEngine) buildWhereClause(opts AggregateOptions, keyColumns ...str
682682
// message_type IS NULL and '' handle old data without the column.
683683
conditions = append(conditions, "(msg.message_type = 'email' OR msg.message_type IS NULL OR msg.message_type = '')")
684684

685-
if opts.SourceID != nil {
686-
conditions = append(conditions, "msg.source_id = ?")
687-
args = append(args, *opts.SourceID)
688-
}
685+
conditions, args = appendSourceFilter(conditions, args, "msg.", opts.SourceID, opts.SourceIDs)
689686

690687
if opts.After != nil {
691688
conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)")
@@ -888,10 +885,7 @@ func (e *DuckDBEngine) buildFilterConditions(filter MessageFilter) (string, []in
888885
// message_type IS NULL and '' handle old data without the column.
889886
conditions = append(conditions, "(msg.message_type = 'email' OR msg.message_type IS NULL OR msg.message_type = '')")
890887

891-
if filter.SourceID != nil {
892-
conditions = append(conditions, "msg.source_id = ?")
893-
args = append(args, *filter.SourceID)
894-
}
888+
conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs)
895889

896890
if filter.ConversationID != nil {
897891
conditions = append(conditions, "msg.conversation_id = ?")
@@ -1151,10 +1145,7 @@ func (e *DuckDBEngine) GetTotalStats(ctx context.Context, opts StatsOptions) (*T
11511145
// Restrict to email messages only; NULL and '' handle pre-message_type data.
11521146
conditions = append(conditions, emailOnlyFilterMsg)
11531147

1154-
if opts.SourceID != nil {
1155-
conditions = append(conditions, "msg.source_id = ?")
1156-
args = append(args, *opts.SourceID)
1157-
}
1148+
conditions, args = appendSourceFilter(conditions, args, "msg.", opts.SourceID, opts.SourceIDs)
11581149

11591150
if opts.WithAttachmentsOnly {
11601151
conditions = append(conditions, "msg.has_attachments = 1")
@@ -2343,10 +2334,7 @@ func (e *DuckDBEngine) buildSearchConditions(q *search.Query, filter MessageFilt
23432334
conditions = append(conditions, emailOnlyFilterMsg)
23442335

23452336
// Apply basic filter conditions (ignoring join flags for search - we handle those differently)
2346-
if filter.SourceID != nil {
2347-
conditions = append(conditions, "msg.source_id = ?")
2348-
args = append(args, *filter.SourceID)
2349-
}
2337+
conditions, args = appendSourceFilter(conditions, args, "msg.", filter.SourceID, filter.SourceIDs)
23502338
if filter.After != nil {
23512339
conditions = append(conditions, "msg.sent_at >= CAST(? AS TIMESTAMP)")
23522340
args = append(args, filter.After.Format("2006-01-02 15:04:05"))

0 commit comments

Comments
 (0)