Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: split up huma.Register #705

Merged
merged 16 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 66 additions & 67 deletions formdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,51 @@ func (v MimeTypeValidator) Validate(fh *multipart.FileHeader, location string) (
}
}

func (m *MultipartFormFiles[T]) readFile(
fh *multipart.FileHeader,
location string,
validator MimeTypeValidator,
) (FormFile, *ErrorDetail) {
f, err := fh.Open()
if err != nil {
return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location}
}
contentType, validationErr := validator.Validate(fh, location)
if validationErr != nil {
return FormFile{}, validationErr
func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

// Decodes multipart.Form data into *T, returning []*ErrorDetail if any
// Schema is used to check for validation constraints
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
var (
dataType = reflect.TypeOf(m.data).Elem()
value = reflect.New(dataType)
errors []error
)
for i := 0; i < dataType.NumField(); i++ {
field := value.Elem().Field(i)
structField := dataType.Field(i)
key := structField.Tag.Get("form")
if key == "" {
key = structField.Name
}
fileHeaders := m.Form.File[key]
switch {
case field.Type() == reflect.TypeOf(FormFile{}):
file, err := readSingleFile(fileHeaders, key, opMediaType)
if err != nil {
errors = append(errors, err)
continue
}
field.Set(reflect.ValueOf(file))
case field.Type() == reflect.TypeOf([]FormFile{}):
files, errs := readMultipleFiles(fileHeaders, key, opMediaType)
if errs != nil {
errors = append(errors, errs...)
continue
}
field.Set(reflect.ValueOf(files))

default:
continue
}
}
return FormFile{
File: f,
ContentType: contentType,
IsSet: true,
Size: fh.Size,
Filename: fh.Filename,
}, nil
m.data = value.Interface().(*T)
return errors
}

func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaType) (FormFile, *ErrorDetail) {
fileHeaders := m.Form.File[key]
func readSingleFile(fileHeaders []*multipart.FileHeader, key string, opMediaType *MediaType) (FormFile, *ErrorDetail) {
if len(fileHeaders) == 0 {
if opMediaType.Schema.requiredMap[key] {
return FormFile{}, &ErrorDetail{Message: "File required", Location: key}
Expand All @@ -117,16 +138,15 @@ func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaTyp
}
} else if len(fileHeaders) == 1 {
validator := NewMimeTypeValidator(opMediaType.Encoding[key])
return m.readFile(fileHeaders[0], key, validator)
return readFile(fileHeaders[0], key, validator)
}
return FormFile{}, &ErrorDetail{
Message: "Multiple files received but only one was expected",
Location: key,
}
}

func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *MediaType) ([]FormFile, []error) {
fileHeaders := m.Form.File[key]
func readMultipleFiles(fileHeaders []*multipart.FileHeader, key string, opMediaType *MediaType) ([]FormFile, []error) {
var (
files = make([]FormFile, len(fileHeaders))
errors []error
Expand All @@ -136,7 +156,7 @@ func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *Media
}
validator := NewMimeTypeValidator(opMediaType.Encoding[key])
for i, fh := range fileHeaders {
file, err := m.readFile(
file, err := readFile(
fh,
fmt.Sprintf("%s[%d]", key, i),
validator,
Expand All @@ -150,47 +170,26 @@ func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *Media
return files, errors
}

func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

// Decodes multipart.Form data into *T, returning []*ErrorDetail if any
// Schema is used to check for validation constraints
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
var (
dataType = reflect.TypeOf(m.data).Elem()
value = reflect.New(dataType)
errors []error
)
for i := 0; i < dataType.NumField(); i++ {
field := value.Elem().Field(i)
structField := dataType.Field(i)
key := structField.Tag.Get("form")
if key == "" {
key = structField.Name
}
switch {
case field.Type() == reflect.TypeOf(FormFile{}):
file, err := m.readSingleFile(key, opMediaType)
if err != nil {
errors = append(errors, err)
continue
}
field.Set(reflect.ValueOf(file))
case field.Type() == reflect.TypeOf([]FormFile{}):
files, errs := m.readMultipleFiles(key, opMediaType)
if errs != nil {
errors = append(errors, errs...)
continue
}
field.Set(reflect.ValueOf(files))

default:
continue
}
func readFile(
fh *multipart.FileHeader,
location string,
validator MimeTypeValidator,
) (FormFile, *ErrorDetail) {
f, err := fh.Open()
if err != nil {
return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location}
}
m.data = value.Interface().(*T)
return errors
contentType, validationErr := validator.Validate(fh, location)
if validationErr != nil {
return FormFile{}, validationErr
}
return FormFile{
File: f,
ContentType: contentType,
IsSet: true,
Size: fh.Size,
Filename: fh.Filename,
}, nil
}

func formDataFieldName(f reflect.StructField) string {
Expand All @@ -208,7 +207,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
Properties: make(map[string]*Schema, nFields),
requiredMap: make(map[string]bool, nFields),
}
requiredFields := make([]string, nFields)
requiredFields := make([]string, 0, nFields)
for i := 0; i < nFields; i++ {
f := t.Field(i)
name := formDataFieldName(f)
Expand All @@ -227,7 +226,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
}

if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required", false) {
requiredFields[i] = name
requiredFields = append(requiredFields, name)
schema.requiredMap[name] = true
}
}
Expand Down
Loading
Loading