Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/upload file max size config #2798

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
7 changes: 7 additions & 0 deletions docs/requests.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ async with request.form(max_files=1000, max_fields=1000):
...
```

You can configure maximum size per file uploaded with the parameter `max_file_size`:

```python
async with request.form(max_file_size=100*1024*1024): # 100 MB limit per file
...
```

!!! info
These limits are for security reasons, allowing an unlimited number of fields or files could lead to a denial of service attack by consuming a lot of CPU and memory parsing too many empty fields.

Expand Down
19 changes: 14 additions & 5 deletions starlette/formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ async def parse(self) -> FormData:


class MultiPartParser:
max_file_size = 1024 * 1024 # 1MB
max_part_size = 1024 * 1024 # 1MB
default_max_field_size = 1024 * 1024 # 1MB
default_max_file_mem_size = 1024 * 1024 # 1MB
default_max_file_disk_size = 1024 * 1024 * 1024 # 1GB

def __init__(
self,
Expand All @@ -132,12 +133,18 @@ def __init__(
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_field_size: int | float | None = None,
max_file_mem_size: int | float | None = None,
max_file_disk_size: int | float | None = None,
) -> None:
assert multipart is not None, "The `python-multipart` library must be installed to use form parsing."
self.headers = headers
self.stream = stream
self.max_files = max_files
self.max_fields = max_fields
self.max_field_size: int | float = max_field_size or self.default_max_field_size
self.max_file_mem_size: int | float = max_file_mem_size or self.default_max_file_mem_size
self.max_file_disk_size: int | float = max_file_disk_size or self.default_max_file_disk_size
self.items: list[tuple[str, str | UploadFile]] = []
self._current_files = 0
self._current_fields = 0
Expand All @@ -155,8 +162,8 @@ def on_part_begin(self) -> None:
def on_part_data(self, data: bytes, start: int, end: int) -> None:
message_bytes = data[start:end]
if self._current_part.file is None:
if len(self._current_part.data) + len(message_bytes) > self.max_part_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_part_size / 1024)}KB.")
if len(self._current_part.data) + len(message_bytes) > self.max_field_size:
raise MultiPartException(f"Part exceeded maximum size of {int(self.max_field_size / 1024)}KB.")
self._current_part.data.extend(message_bytes)
else:
self._file_parts_to_write.append((self._current_part, message_bytes))
Expand Down Expand Up @@ -201,7 +208,7 @@ def on_headers_finished(self) -> None:
if self._current_files > self.max_files:
raise MultiPartException(f"Too many files. Maximum number of files is {self.max_files}.")
filename = _user_safe_decode(options[b"filename"], self._charset)
tempfile = SpooledTemporaryFile(max_size=self.max_file_size)
tempfile = SpooledTemporaryFile(max_size=self.max_file_mem_size)
self._files_to_close_on_error.append(tempfile)
self._current_part.file = UploadFile(
file=tempfile, # type: ignore[arg-type]
Expand Down Expand Up @@ -255,6 +262,8 @@ async def parse(self) -> FormData:
# the main thread.
for part, data in self._file_parts_to_write:
assert part.file # for type checkers
if part.file.size is not None and part.file.size + len(data) > self.max_file_disk_size:
raise MultiPartException(f"File exceeds maximum size of {self.max_file_disk_size} bytes.")
await part.file.write(data)
for part in self._file_parts_to_finish:
assert part.file # for type checkers
Expand Down
41 changes: 38 additions & 3 deletions starlette/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,15 @@ async def json(self) -> typing.Any:
self._json = json.loads(body)
return self._json

async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | float = 1000) -> FormData:
async def _get_form(
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_field_size: int | float | None,
max_file_mem_size: int | float | None,
max_file_disk_size: int | float | None,
) -> FormData:
if self._form is None:
assert (
parse_options_header is not None
Expand All @@ -264,6 +272,9 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
self.stream(),
max_files=max_files,
max_fields=max_fields,
max_field_size=max_field_size,
max_file_mem_size=max_file_mem_size,
max_file_disk_size=max_file_disk_size,
)
self._form = await multipart_parser.parse()
except MultiPartException as exc:
Expand All @@ -278,9 +289,33 @@ async def _get_form(self, *, max_files: int | float = 1000, max_fields: int | fl
return self._form

def form(
self, *, max_files: int | float = 1000, max_fields: int | float = 1000
self,
*,
max_files: int | float = 1000,
max_fields: int | float = 1000,
max_field_size: int | float | None = None,
max_file_mem_size: int | float | None = None,
max_file_disk_size: int | float | None = None,
) -> AwaitableOrContextManager[FormData]:
return AwaitableOrContextManagerWrapper(self._get_form(max_files=max_files, max_fields=max_fields))
"""
Return a FormData instance, representing the form data in the request.

:param max_files: The maximum number of files that can be parsed.
:param max_fields: The maximum number of fields that can be parsed.
:param max_field_size: The maximum size of each field part in bytes.
:param max_file_mem_size: The maximum memory size for each file part in bytes.
:param max_file_disk_size: The maximum disk size for each file part in bytes.
https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile
"""
return AwaitableOrContextManagerWrapper(
self._get_form(
max_files=max_files,
max_fields=max_fields,
max_field_size=max_field_size,
max_file_mem_size=max_file_mem_size,
max_file_disk_size=max_file_disk_size,
)
)

async def close(self) -> None:
if self._form is not None: # pragma: no branch
Expand Down
33 changes: 33 additions & 0 deletions tests/test_formparsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing
from contextlib import nullcontext as does_not_raise
from pathlib import Path
from secrets import token_bytes

import pytest

Expand Down Expand Up @@ -127,6 +128,14 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
return app


def make_app_max_file_size(max_file_size: int) -> ASGIApp:
async def app(scope: Scope, receive: Receive, send: Send) -> None:
request = Request(scope, receive)
await request.form(max_file_size=max_file_size)

return app


def test_multipart_request_data(tmpdir: Path, test_client_factory: TestClientFactory) -> None:
client = test_client_factory(app)
response = client.post("/", data={"some": "data"}, files=FORCE_MULTIPART)
Expand Down Expand Up @@ -580,6 +589,30 @@ def test_too_many_files_and_fields_raise(
assert res.text == "Too many files. Maximum number of files is 1000."


@pytest.mark.parametrize(
"app,expectation",
[
(make_app_max_file_size(1024), pytest.raises(MultiPartException)),
(Starlette(routes=[Mount("/", app=make_app_max_file_size(1024))]), does_not_raise()),
],
)
def test_max_part_file_size_raise(
tmpdir: Path,
app: ASGIApp,
expectation: typing.ContextManager[Exception],
test_client_factory: TestClientFactory,
) -> None:
path = os.path.join(tmpdir, "test.txt")
with open(path, "wb") as file:
file.write(token_bytes(1024 + 1))

client = test_client_factory(app)
with open(path, "rb") as f, expectation:
response = client.post("/", files={"test": f})
assert response.status_code == 400
assert response.text == "File exceeds maximum size of 1024 bytes."


@pytest.mark.parametrize(
"app,expectation",
[
Expand Down
Loading