Skip to content

Commit dbe3383

Browse files
refactor(prompt): fit builders through template context
1 parent 275f066 commit dbe3383

3 files changed

Lines changed: 141 additions & 68 deletions

File tree

internal/prompt/prompt.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func (b *Builder) BuildDirty(repoPath, diff string, repoID int64, contextCount i
216216
}
217217
}
218218

219-
body, err := fitDirtyPrompt(bodyLimit, view)
219+
body, err := fitDirtyPromptContext(bodyLimit, templateContextFromDirtyView(view))
220220
if err != nil {
221221
return "", err
222222
}
@@ -565,13 +565,13 @@ func (b *Builder) buildSinglePrompt(repoPath, sha string, repoID int64, contextC
565565
diffView.Body = inlineDiff
566566
}
567567

568-
body, err := fitSinglePrompt(
568+
body, err := fitSinglePromptContext(
569569
bodyLimit,
570-
singlePromptView{
570+
templateContextFromSingleView(singlePromptView{
571571
Optional: optional,
572572
Current: currentView,
573573
Diff: diffView,
574-
},
574+
}),
575575
)
576576
if err != nil {
577577
return "", err
@@ -683,13 +683,13 @@ func (b *Builder) buildRangePrompt(repoPath, rangeRef string, repoID int64, cont
683683
diffView.Body = inlineDiff
684684
}
685685

686-
body, err := fitRangePrompt(
686+
body, err := fitRangePromptContext(
687687
bodyLimit,
688-
rangePromptView{
688+
templateContextFromRangeView(rangePromptView{
689689
Optional: optional,
690690
Current: currentView,
691691
Diff: diffView,
692-
},
692+
}),
693693
)
694694
if err != nil {
695695
return "", err

internal/prompt/prompt_body_templates.go

Lines changed: 127 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -276,33 +276,49 @@ func templateContextFromSystemView(view systemPromptView) TemplateContext {
276276
return TemplateContext{System: &SystemTemplateContext{NoSkillsInstruction: view.NoSkillsInstruction, CurrentDate: view.CurrentDate}}
277277
}
278278

279+
func renderSinglePromptContext(ctx TemplateContext) (string, error) {
280+
return executePromptTemplate("assembled_single.md.gotmpl", ctx)
281+
}
282+
283+
func renderRangePromptContext(ctx TemplateContext) (string, error) {
284+
return executePromptTemplate("assembled_range.md.gotmpl", ctx)
285+
}
286+
287+
func renderDirtyPromptContext(ctx TemplateContext) (string, error) {
288+
return executePromptTemplate("assembled_dirty.md.gotmpl", ctx)
289+
}
290+
291+
func renderAddressPromptContext(ctx TemplateContext) (string, error) {
292+
return executePromptTemplate("assembled_address.md.gotmpl", ctx)
293+
}
294+
279295
func renderSinglePrompt(view singlePromptView) (string, error) {
280-
return executePromptTemplate("assembled_single.md.gotmpl", templateContextFromSingleView(view))
296+
return renderSinglePromptContext(templateContextFromSingleView(view))
281297
}
282298

283299
func renderRangePrompt(view rangePromptView) (string, error) {
284-
return executePromptTemplate("assembled_range.md.gotmpl", templateContextFromRangeView(view))
300+
return renderRangePromptContext(templateContextFromRangeView(view))
285301
}
286302

287303
func renderDirtyPrompt(view dirtyPromptView) (string, error) {
288-
return executePromptTemplate("assembled_dirty.md.gotmpl", templateContextFromDirtyView(view))
304+
return renderDirtyPromptContext(templateContextFromDirtyView(view))
289305
}
290306

291307
func renderAddressPrompt(view addressPromptView) (string, error) {
292-
return executePromptTemplate("assembled_address.md.gotmpl", templateContextFromAddressView(view))
308+
return renderAddressPromptContext(templateContextFromAddressView(view))
293309
}
294310

295-
func fitSinglePrompt(limit int, view singlePromptView) (string, error) {
296-
body, err := renderSinglePrompt(view)
311+
func fitSinglePromptContext(limit int, ctx TemplateContext) (string, error) {
312+
body, err := renderSinglePromptContext(ctx)
297313
if err != nil {
298314
return "", err
299315
}
300316
if len(body) <= limit {
301317
return body, nil
302318
}
303319

304-
for trimOptionalSections(&view.Optional) {
305-
body, err = renderSinglePrompt(view)
320+
for ctx.Review != nil && ctx.Review.Optional.TrimNext() {
321+
body, err = renderSinglePromptContext(ctx)
306322
if err != nil {
307323
return "", err
308324
}
@@ -311,9 +327,8 @@ func fitSinglePrompt(limit int, view singlePromptView) (string, error) {
311327
}
312328
}
313329

314-
if view.Current.Message != "" {
315-
view.Current.Message = ""
316-
body, err = renderSinglePrompt(view)
330+
if ctx.Review != nil && ctx.Review.Subject.TrimSingleMessage() {
331+
body, err = renderSinglePromptContext(ctx)
317332
if err != nil {
318333
return "", err
319334
}
@@ -322,9 +337,8 @@ func fitSinglePrompt(limit int, view singlePromptView) (string, error) {
322337
}
323338
}
324339

325-
if view.Current.Author != "" {
326-
view.Current.Author = ""
327-
body, err = renderSinglePrompt(view)
340+
if ctx.Review != nil && ctx.Review.Subject.TrimSingleAuthor() {
341+
body, err = renderSinglePromptContext(ctx)
328342
if err != nil {
329343
return "", err
330344
}
@@ -333,10 +347,10 @@ func fitSinglePrompt(limit int, view singlePromptView) (string, error) {
333347
}
334348
}
335349

336-
for len(body) > limit && view.Current.Subject != "" {
350+
for ctx.Review != nil && ctx.Review.Subject.Single != nil && len(body) > limit && ctx.Review.Subject.Single.Subject != "" {
337351
overflow := len(body) - limit
338-
view.Current.Subject = truncateUTF8(view.Current.Subject, max(0, len(view.Current.Subject)-overflow))
339-
body, err = renderSinglePrompt(view)
352+
ctx.Review.Subject.TrimSingleSubjectTo(max(0, len(ctx.Review.Subject.Single.Subject)-overflow))
353+
body, err = renderSinglePromptContext(ctx)
340354
if err != nil {
341355
return "", err
342356
}
@@ -345,76 +359,95 @@ func fitSinglePrompt(limit int, view singlePromptView) (string, error) {
345359
return hardCapPrompt(body, limit), nil
346360
}
347361

348-
func fitRangePrompt(limit int, view rangePromptView) (string, error) {
349-
_, body, err := trimRangePromptView(limit, view)
362+
func fitSinglePrompt(limit int, view singlePromptView) (string, error) {
363+
return fitSinglePromptContext(limit, templateContextFromSingleView(view))
364+
}
365+
366+
func fitRangePromptContext(limit int, ctx TemplateContext) (string, error) {
367+
_, body, err := trimRangePromptContext(limit, ctx)
350368
if err != nil {
351369
return "", err
352370
}
353371
return hardCapPrompt(body, limit), nil
354372
}
355373

356-
func cloneCommitRangeSectionView(view commitRangeSectionView) commitRangeSectionView {
357-
cloned := view
358-
if len(view.Entries) == 0 {
359-
return cloned
360-
}
361-
cloned.Entries = append([]commitRangeEntryView(nil), view.Entries...)
362-
return cloned
374+
func fitRangePrompt(limit int, view rangePromptView) (string, error) {
375+
return fitRangePromptContext(limit, templateContextFromRangeView(view))
363376
}
364377

365-
func trimRangePromptView(limit int, view rangePromptView) (rangePromptView, string, error) {
366-
view.Current = cloneCommitRangeSectionView(view.Current)
367-
body, err := renderRangePrompt(view)
378+
func trimRangePromptContext(limit int, ctx TemplateContext) (TemplateContext, string, error) {
379+
ctx = ctx.Clone()
380+
body, err := renderRangePromptContext(ctx)
368381
if err != nil {
369-
return rangePromptView{}, "", err
382+
return TemplateContext{}, "", err
370383
}
371384
if len(body) <= limit {
372-
return view, body, nil
385+
return ctx, body, nil
373386
}
374387

375-
for trimOptionalSections(&view.Optional) {
376-
body, err = renderRangePrompt(view)
388+
for ctx.Review != nil && ctx.Review.Optional.TrimNext() {
389+
body, err = renderRangePromptContext(ctx)
377390
if err != nil {
378-
return rangePromptView{}, "", err
391+
return TemplateContext{}, "", err
379392
}
380393
if len(body) <= limit {
381-
return view, body, nil
394+
return ctx, body, nil
382395
}
383396
}
384397

385-
for i := len(view.Current.Entries) - 1; i >= 0 && len(body) > limit; i-- {
386-
if view.Current.Entries[i].Subject == "" {
387-
continue
388-
}
389-
view.Current.Entries[i].Subject = ""
390-
body, err = renderRangePrompt(view)
398+
for ctx.Review != nil && len(body) > limit && ctx.Review.Subject.BlankNextRangeSubject() {
399+
body, err = renderRangePromptContext(ctx)
391400
if err != nil {
392-
return rangePromptView{}, "", err
401+
return TemplateContext{}, "", err
393402
}
394403
}
395404

396-
for len(view.Current.Entries) > 0 && len(body) > limit {
397-
view.Current.Entries = view.Current.Entries[:len(view.Current.Entries)-1]
398-
body, err = renderRangePrompt(view)
405+
for ctx.Review != nil && len(body) > limit && ctx.Review.Subject.DropLastRangeEntry() {
406+
body, err = renderRangePromptContext(ctx)
399407
if err != nil {
400-
return rangePromptView{}, "", err
408+
return TemplateContext{}, "", err
401409
}
402410
}
403411

404-
return view, body, nil
412+
return ctx, body, nil
405413
}
406414

407-
func fitDirtyPrompt(limit int, view dirtyPromptView) (string, error) {
408-
body, err := renderDirtyPrompt(view)
415+
func trimRangePromptView(limit int, view rangePromptView) (rangePromptView, string, error) {
416+
trimmed, body, err := trimRangePromptContext(limit, templateContextFromRangeView(view))
417+
if err != nil {
418+
return rangePromptView{}, "", err
419+
}
420+
rangeCtx := trimmed.Review.Subject.Range
421+
if trimmed.Review == nil || rangeCtx == nil {
422+
return rangePromptView{}, body, nil
423+
}
424+
entries := make([]commitRangeEntryView, 0, len(rangeCtx.Entries))
425+
for _, entry := range rangeCtx.Entries {
426+
entries = append(entries, commitRangeEntryView(entry))
427+
}
428+
return rangePromptView{
429+
Optional: optionalSectionsView{
430+
ProjectGuidelines: buildProjectGuidelinesSectionView(trimmed.Review.Optional.ProjectGuidelinesBody()),
431+
AdditionalContext: trimmed.Review.Optional.AdditionalContext,
432+
PreviousReviews: previousReviewViewsFromTemplateContext(trimmed.Review.Optional.PreviousReviews),
433+
PreviousAttempts: reviewAttemptViewsFromTemplateContext(trimmed.Review.Optional.PreviousAttempts),
434+
},
435+
Current: commitRangeSectionView{Count: rangeCtx.Count, Entries: entries},
436+
Diff: diffSectionView{Heading: trimmed.Review.Diff.Heading, Body: trimmed.Review.Diff.Body, Fallback: trimmed.Review.Fallback.Rendered()},
437+
}, body, nil
438+
}
439+
440+
func fitDirtyPromptContext(limit int, ctx TemplateContext) (string, error) {
441+
body, err := renderDirtyPromptContext(ctx)
409442
if err != nil {
410443
return "", err
411444
}
412445
if len(body) <= limit {
413446
return body, nil
414447
}
415448

416-
for trimOptionalSections(&view.Optional) {
417-
body, err = renderDirtyPrompt(view)
449+
for ctx.Review != nil && ctx.Review.Optional.TrimNext() {
450+
body, err = renderDirtyPromptContext(ctx)
418451
if err != nil {
419452
return "", err
420453
}
@@ -426,19 +459,22 @@ func fitDirtyPrompt(limit int, view dirtyPromptView) (string, error) {
426459
return hardCapPrompt(body, limit), nil
427460
}
428461

462+
func fitDirtyPrompt(limit int, view dirtyPromptView) (string, error) {
463+
return fitDirtyPromptContext(limit, templateContextFromDirtyView(view))
464+
}
465+
429466
func trimOptionalSections(view *optionalSectionsView) bool {
430-
switch {
431-
case len(view.PreviousAttempts) > 0:
432-
view.PreviousAttempts = nil
433-
case len(view.PreviousReviews) > 0:
434-
view.PreviousReviews = nil
435-
case view.AdditionalContext != "":
436-
view.AdditionalContext = ""
437-
case view.ProjectGuidelines != nil:
438-
view.ProjectGuidelines = nil
439-
default:
467+
if view == nil {
440468
return false
441469
}
470+
ctx := reviewOptionalContextFromView(*view)
471+
if !ctx.TrimNext() {
472+
return false
473+
}
474+
view.ProjectGuidelines = buildProjectGuidelinesSectionView(ctx.ProjectGuidelinesBody())
475+
view.AdditionalContext = ctx.AdditionalContext
476+
view.PreviousReviews = previousReviewViewsFromTemplateContext(ctx.PreviousReviews)
477+
view.PreviousAttempts = reviewAttemptViewsFromTemplateContext(ctx.PreviousAttempts)
442478
return true
443479
}
444480

@@ -524,6 +560,21 @@ func renderDirtyTruncatedDiffFallback(body string) (string, error) {
524560
return executePromptTemplate("dirty_truncated_diff_fallback", dirtyTruncatedDiffFallbackView{Body: body})
525561
}
526562

563+
func previousReviewViewsFromTemplateContext(contexts []PreviousReviewTemplateContext) []previousReviewView {
564+
views := make([]previousReviewView, 0, len(contexts))
565+
for _, ctx := range contexts {
566+
view := previousReviewView{Commit: ctx.Commit, Available: ctx.Available, Output: ctx.Output}
567+
if len(ctx.Comments) > 0 {
568+
view.Comments = make([]reviewCommentView, 0, len(ctx.Comments))
569+
for _, comment := range ctx.Comments {
570+
view.Comments = append(view.Comments, reviewCommentView(comment))
571+
}
572+
}
573+
views = append(views, view)
574+
}
575+
return views
576+
}
577+
527578
func previousReviewViews(contexts []HistoricalReviewContext) []previousReviewView {
528579
views := make([]previousReviewView, 0, len(contexts))
529580
for _, ctx := range contexts {
@@ -568,6 +619,21 @@ func renderPreviousAttemptsFromReviews(reviews []storage.Review) (string, error)
568619
return renderOptionalSectionsFromView(optionalSectionsView{PreviousAttempts: reviewAttemptViews(reviews)})
569620
}
570621

622+
func reviewAttemptViewsFromTemplateContext(attempts []ReviewAttemptTemplateContext) []reviewAttemptView {
623+
views := make([]reviewAttemptView, 0, len(attempts))
624+
for _, attempt := range attempts {
625+
view := reviewAttemptView{Label: attempt.Label, Agent: attempt.Agent, When: attempt.When, Output: attempt.Output}
626+
if len(attempt.Comments) > 0 {
627+
view.Comments = make([]reviewCommentView, 0, len(attempt.Comments))
628+
for _, comment := range attempt.Comments {
629+
view.Comments = append(view.Comments, reviewCommentView(comment))
630+
}
631+
}
632+
views = append(views, view)
633+
}
634+
return views
635+
}
636+
571637
func previousAttemptViewsFromContexts(attempts []reviewAttemptContext) []reviewAttemptView {
572638
views := make([]reviewAttemptView, 0, len(attempts))
573639
for i, attempt := range attempts {

internal/prompt/template_context.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,13 @@ func (o ReviewOptionalContext) IsEmpty() bool {
8888
len(o.PreviousAttempts) == 0
8989
}
9090

91+
func (o ReviewOptionalContext) ProjectGuidelinesBody() string {
92+
if o.ProjectGuidelines == nil {
93+
return ""
94+
}
95+
return o.ProjectGuidelines.Body
96+
}
97+
9198
func (o *ReviewOptionalContext) TrimNext() bool {
9299
if o == nil {
93100
return false

0 commit comments

Comments
 (0)