Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 118 additions & 14 deletions huma.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func findParams(registry Registry, op *Operation, t reflect.Type) *findResult[*p
}

return pfi
}, false, "Body")
}, true, "Body")
}

// findResolvers searches a given type for resolvers matching a specified resolverType.
Expand Down Expand Up @@ -306,6 +306,27 @@ type headerInfo struct {
TimeFormat string
}

// isPromotedField reports whether the field identified by the index path is
// reachable from the root type purely through embedded (anonymous) structs, so
// Go promotes it as if it were declared at the top level. Such fields keep the
// legacy behavior of being auto-named as a header from the field name; named
// (non-embedded) nested struct fields must opt in with an explicit `header`
// tag instead.
func isPromotedField(root reflect.Type, path []int) bool {
t := baseType(root)
for _, idx := range path[:len(path)-1] {
if t.Kind() != reflect.Struct || idx >= t.NumField() {
return false
}
f := t.Field(idx)
if !f.Anonymous {
return false
}
t = baseType(f.Type)
}
return true
}

// findHeaders extracts header-related metadata from a given struct type using reflection.
// It returns a findResult containing headerInfo for fields tagged with "header" or
// defaulting to field names. Embedded fields or fields named "Status" and "Body" are
Expand All @@ -319,6 +340,20 @@ func findHeaders(t reflect.Type) *findResult[*headerInfo] {

header := sf.Tag.Get("header")
if header == "" {
// Only auto-name a header from the field name for "surface level"
// fields: literal top-level fields and fields promoted via embedded
// structs. Named nested struct fields must use an explicit `header`
// tag, which is handled by recursion above.
if !isPromotedField(t, i) {
return nil
}

// Never name a header after a struct we recurse into.
fieldType := baseType(sf.Type)
if fieldType.Kind() == reflect.Struct && fieldType != timeType {
return nil
}

header = sf.Name
}

Expand All @@ -331,7 +366,7 @@ func findHeaders(t reflect.Type) *findResult[*headerInfo] {
}

return &headerInfo{sf, header, timeFormat}
}, false, "Status", "Body")
}, true, "Status", "Body")
}

type findResultPath[T comparable] struct {
Expand Down Expand Up @@ -382,6 +417,65 @@ func (r *findResult[T]) Every(v reflect.Value, f func(reflect.Value, T)) {
}
}

// everyAlloc behaves like every, but allocates nil pointers encountered along
// the path so nested input fields can be populated. A pointer this call
// allocated is reset to nil when nothing below it was set, so an optional nested
// group that received no values stays nil rather than becoming an empty struct.
// The callback reports whether it set a value; everyAlloc returns whether
// anything below the current node was set.
func (r *findResult[T]) everyAlloc(current reflect.Value, path []int, v T, f func(reflect.Value, T) bool) bool {
if len(path) == 0 {
return f(current, v)
}

var allocated reflect.Value
if current.Kind() == reflect.Pointer {
if current.IsNil() {
if !current.CanSet() {
return false
}
current.Set(reflect.New(current.Type().Elem()))
allocated = current
}
current = current.Elem()
}

if current.Kind() == reflect.Invalid {
return false
}

set := false
switch current.Kind() {
case reflect.Struct:
set = r.everyAlloc(current.Field(path[0]), path[1:], v, f)
case reflect.Slice:
for j := 0; j < current.Len(); j++ {
if r.everyAlloc(current.Index(j), path, v, f) {
set = true
}
}
case reflect.Map:
for _, k := range current.MapKeys() {
if r.everyAlloc(current.MapIndex(k), path, v, f) {
set = true
}
}
default:
panic("unsupported")
}

if allocated.IsValid() && !set {
allocated.Set(reflect.Zero(allocated.Type()))
}
return set
}

func (r *findResult[T]) EveryAlloc(v reflect.Value, f func(reflect.Value, T) bool) {
for i := range r.Paths {
r.everyAlloc(v, r.Paths[i].Path, r.Paths[i].Value, f)
}
}

// jsonName extracts the JSON name from a struct field or converts the field name
// to lowercase if no JSON tag is present.
func jsonName(field reflect.StructField) string {
Expand Down Expand Up @@ -740,6 +834,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if outputType.Kind() != reflect.Struct {
panic("output must be a struct")
}

outHeaders, outStatusIndex, outBodyIndex, outBodyFunc := processOutputType(outputType, &op, registry)

if len(op.Errors) > 0 {
Expand Down Expand Up @@ -838,10 +933,10 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
}
}

inputParams.Every(v, func(f reflect.Value, p *paramFieldInfo) {
inputParams.EveryAlloc(v, func(f reflect.Value, p *paramFieldInfo) bool {
f = reflect.Indirect(f)
if f.Kind() == reflect.Invalid {
return
return false
}

pb.Reset()
Expand All @@ -860,7 +955,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
// Special case: http.Cookie type, meaning we want the entire parsed
// cookie struct, not just the value.
f.Set(reflect.ValueOf(c).Elem())
return
return true
}
}

Expand All @@ -883,7 +978,7 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if !op.SkipValidateParams && p.Required {
res.Add(pb, "", "required "+p.Loc+" parameter is missing")
}
return
return false
}
pv = setDeepObjectValue(pb, res, receiver, value)
} else {
Expand All @@ -894,13 +989,13 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
// Path params are always required.
res.Add(pb, "", "required "+p.Loc+" parameter is missing")
}
return
return false
}
var err error
pv, err = parseInto(ctx, receiver, value, nil, *p)
if err != nil {
res.Add(pb, value, err.Error())
return
return false
}
}

Expand All @@ -911,6 +1006,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if !op.SkipValidateParams {
Validate(oapi.Components.Schemas, p.Schema, pb, ModeWriteToServer, pv, res)
}

return true
})

// Read input body if defined.
Expand Down Expand Up @@ -938,15 +1035,15 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
formValueParser = func(val reflect.Value) {}
} else {
formValueParser = func(val reflect.Value) {
rawBodyInputParams.Every(val, func(f reflect.Value, p *paramFieldInfo) {
rawBodyInputParams.EveryAlloc(val, func(f reflect.Value, p *paramFieldInfo) bool {
f = reflect.Indirect(f)
if f.Kind() == reflect.Invalid {
return
return false
}

// Skip FormFile and []FormFile fields as they are handled separately.
if p.Type == formFileType || p.Type == formFilesType {
return
return false
}

pb.Reset()
Expand All @@ -959,14 +1056,14 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if !op.SkipValidateParams && p.Required && !isFile {
res.Add(pb, "", "required "+p.Loc+" parameter is missing")
}
return
return false
}

// Validation should fail if multiple values are
// provided but the type of f is not a slice.
if len(value) > 1 && f.Type().Kind() != reflect.Slice {
res.Add(pb, value, "expected at most one value, but received multiple values")
return
return false
}
pv, err := parseInto(ctx, f, value[0], value, *p)
if err != nil {
Expand All @@ -976,6 +1073,8 @@ func Register[I, O any](api API, op Operation, handler func(context.Context, *I)
if !op.SkipValidateParams {
Validate(oapi.Components.Schemas, p.Schema, pb, ModeWriteToServer, pv, res)
}

return true
})
}
}
Expand Down Expand Up @@ -1517,7 +1616,7 @@ func setRequestBodyRequired(rb *RequestBody) {
rb.Required = true
}

// processOutputType validates the output type, extracts possible responses and
// processOutputType validates the output type, extracts possible responses, and
// defines them on the operation op.
func processOutputType(outputType reflect.Type, op *Operation, registry Registry) (*findResult[*headerInfo], int, int, bool) {
outStatusIndex := -1
Expand Down Expand Up @@ -1593,6 +1692,7 @@ func processOutputType(outputType reflect.Type, op *Operation, registry Registry
Description: http.StatusText(op.DefaultStatus),
}
}

outHeaders := findHeaders(outputType)
for _, entry := range outHeaders.Paths {
v := entry.Value
Expand All @@ -1619,21 +1719,25 @@ func processOutputType(outputType reflect.Type, op *Operation, registry Registry
if op.Responses[defaultStatusStr].Headers == nil {
op.Responses[defaultStatusStr].Headers = map[string]*Param{}
}

f := v.Field
if f.Type.Kind() == reflect.Slice {
f.Type = deref(f.Type.Elem())
}

if reflect.PointerTo(f.Type).Implements(fmtStringerType) {
// Special case: this field will be written as a string by calling
// `.String()` on the value.
f.Type = stringType
}

op.Responses[defaultStatusStr].Headers[v.Name] = &Header{
// We need to generate the schema from the field to get validation info
// like min/max and enums. Useful to let the client know possible values.
Schema: SchemaFromField(registry, f, getHint(outputType, f.Name, op.OperationID+defaultStatusStr+v.Name)),
}
}

return outHeaders, outStatusIndex, outBodyIndex, outBodyFunc
}

Expand Down
111 changes: 111 additions & 0 deletions huma_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package huma

import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"
)

// TestEveryAlloc exercises findResult.everyAlloc directly. The public input
// path always operates on a freshly-zeroed input value, so slice/map elements
// along a param path are empty and never iterated at runtime. These branches
// exist for parity with every and to gracefully handle param paths that pass
// through slices/maps, so they're covered here with populated containers.
func TestEveryAlloc(t *testing.T) {
type leaf struct {
Val string
}

t.Run("allocates pointer and keeps it when a value is set", func(t *testing.T) {
type container struct {
Ptr *leaf
}
var c container
r := &findResult[int]{}
set := r.everyAlloc(reflect.ValueOf(&c).Elem(), []int{0, 0}, 1, func(v reflect.Value, _ int) bool {
v.SetString("set")
return true
})
assert.True(t, set)
if assert.NotNil(t, c.Ptr) {
assert.Equal(t, "set", c.Ptr.Val)
}
})

t.Run("rolls an allocated pointer back to nil when nothing is set", func(t *testing.T) {
type container struct {
Ptr *leaf
}
var c container
r := &findResult[int]{}
set := r.everyAlloc(reflect.ValueOf(&c).Elem(), []int{0, 0}, 1, func(v reflect.Value, _ int) bool {
return false
})
assert.False(t, set)
assert.Nil(t, c.Ptr, "pointer the call allocated should be reset to nil")
})

t.Run("reuses an already-allocated pointer without rolling it back", func(t *testing.T) {
type container struct {
Ptr *leaf
}
c := container{Ptr: &leaf{Val: "existing"}}
r := &findResult[int]{}
set := r.everyAlloc(reflect.ValueOf(&c).Elem(), []int{0, 0}, 1, func(v reflect.Value, _ int) bool {
return false
})
assert.False(t, set)
// We did not allocate the pointer, so it must be left untouched.
if assert.NotNil(t, c.Ptr) {
assert.Equal(t, "existing", c.Ptr.Val)
}
})

t.Run("recurses into slice elements", func(t *testing.T) {
type container struct {
Items []leaf
}
c := container{Items: []leaf{{}, {}}}
r := &findResult[int]{}
count := 0
set := r.everyAlloc(reflect.ValueOf(&c).Elem(), []int{0, 0}, 1, func(v reflect.Value, _ int) bool {
v.SetString("x")
count++
return true
})
assert.True(t, set)
assert.Equal(t, 2, count)
assert.Equal(t, "x", c.Items[1].Val)
})

t.Run("recurses into map elements", func(t *testing.T) {
type container struct {
M map[string]leaf
}
c := container{M: map[string]leaf{"a": {}, "b": {}}}
r := &findResult[int]{}
count := 0
set := r.everyAlloc(reflect.ValueOf(&c).Elem(), []int{0, 0}, 1, func(v reflect.Value, _ int) bool {
count++
return true
})
assert.True(t, set)
assert.Equal(t, 2, count)
})

t.Run("panics on an unsupported kind in the path", func(t *testing.T) {
type container struct {
N int
}
var c container
r := &findResult[int]{}
assert.PanicsWithValue(t, "unsupported", func() {
// Path descends into the int field, which is neither a struct,
// slice, map, nor pointer.
r.everyAlloc(reflect.ValueOf(&c).Elem(), []int{0, 0}, 1, func(v reflect.Value, _ int) bool {
return true
})
})
})
}
Loading
Loading