diff --git a/capyc/django/serializer.py b/capyc/django/serializer.py index 3eb052da6..67aefc53c 100644 --- a/capyc/django/serializer.py +++ b/capyc/django/serializer.py @@ -1,4 +1,5 @@ import base64 +import re from collections.abc import Iterable, Mapping from copy import copy from datetime import datetime, timedelta @@ -171,10 +172,41 @@ def get_cache(key: Optional[str] = None) -> dict[str, ModelCached] | ModelCached class ModelFieldMixin: depth = 1 - request = None + request: Optional[HttpRequest | AsyncRequest] = None model: Optional[models.Model] = None fields = {"default": tuple()} - # exclude = () + + @classmethod + def _get_args(cls, key: str) -> dict[str, Any]: + if cls.request is None: + return {} + + expands = cls.request.GET.get("expand", "") + if not expands: + return {} + + result = re.findall(r"[\w]+\[.*?\]|[\w]+", expands) + res = {} + for item in result: + paths = item.split(".") + head = res + for path in paths[:-1]: + if path not in head: + head[path] = {} + + head = head[path] + + item = item[-1] + + if "[" in item and "]" in item: + fields = item.replace("]", "").split("[") + + head[fields[0]] = {"_sets": fields[1].split(",")} + + else: + head[item] = {"_sets": []} + + return res @classmethod def _get_related_fields(cls, key: str): @@ -307,7 +339,9 @@ def _serialize(self, instance: models.Model) -> dict: if isinstance(data[field], Collection): many = True - data[field] = ser(data=data[field], many=many).data + ser.init(data=data[field], many=many) + + data[field] = ser.data else: data[field] = pk_serializer(data[field]) @@ -371,6 +405,7 @@ def __init__( request: Optional[HttpRequest | AsyncRequest] = None, sets: Optional[Collection[str]] = None, expand: Optional[Collection[str]] = None, + link: Optional[str] = None, **kwargs, ): if sets is not None: @@ -378,6 +413,23 @@ def __init__( else: self._parsed_fields = set() + self.init(instance, many, data, context, required, request, expand, link) + + super().__init__(**kwargs) + + def init( + self, + instance: Optional[QuerySet | models.Model] = None, + many: bool = False, + data: Optional[Iterable | Mapping | QuerySet | models.Model] = None, + context: Optional[Mapping] = None, + required: bool = True, + request: Optional[HttpRequest | AsyncRequest] = None, + expand: Optional[Collection[str]] = None, + link: Optional[str] = None, + ) -> None: + self.link = link + if expand is not None: self._expands = set(expand) else: @@ -391,5 +443,3 @@ def __init__( self.context = context or {} self.required = required self.request = request - - super().__init__(**kwargs) diff --git a/capyc/tests/django/tests_serializer.py b/capyc/tests/django/tests_serializer.py index 5a6906694..a24f6aef7 100644 --- a/capyc/tests/django/tests_serializer.py +++ b/capyc/tests/django/tests_serializer.py @@ -1,11 +1,35 @@ import pytest from rest_framework.test import APIRequestFactory -from breathecode.admissions.models import Academy, Cohort +from breathecode.admissions.models import Academy, Cohort, Country +from breathecode.payments.models import Currency from breathecode.tests.mixins.breathecode_mixin.breathecode import Breathecode from capyc.django.serializer import Serializer +class CountrySerializer(Serializer): + model = Country + fields = { + "default": ("code",), + "info": ("name",), + } + filters = ("code", "name") + depth = 2 + + +class CurrencySerializer(Serializer): + model = Currency + fields = { + "default": ("id", "code"), + "info": ("name", "decimals"), + "list": ("countries",), + } + filters = ("slug", "name") + depth = 2 + + countries = CountrySerializer(many=True) + + class AcademySerializer(Serializer): model = Academy fields = { @@ -16,6 +40,8 @@ class AcademySerializer(Serializer): filters = ("slug", "name") depth = 2 + country = CountrySerializer(many=True) + class CohortSerializer(Serializer): model = Cohort @@ -27,7 +53,8 @@ class CohortSerializer(Serializer): filters = ("slug", "name", "academy__*") depth = 2 - academy = AcademySerializer + academy = AcademySerializer() + # academy = AcademySerializer(sets=["available_as_saas"]) @pytest.fixture(autouse=True) @@ -158,7 +185,7 @@ async def test_two_sets_expanded___(bc: Breathecode): model = await bc.database.acreate(cohort=2) factory = APIRequestFactory() - request = factory.get("/notes/547/?sets=intro,ids&expand=academy[contact,saas],syllabus_version") + request = factory.get("/notes/547/?sets=intro,ids,academy[contact,saas]&expand=academy,syllabus_version") qs = Cohort.objects.all() serializer = CohortSerializer(data=qs, many=True, request=request) @@ -183,3 +210,36 @@ async def test_two_sets_expanded___(bc: Breathecode): } for x in model.cohort ] + + +# @pytest.mark.asyncio +# @pytest.mark.django_db(reset_sequences=True) +# async def test_two_sets_expanded___(bc: Breathecode): +# model = await bc.database.acreate(cohort=2) + +# factory = APIRequestFactory() +# request = factory.get("/notes/547/?sets=intro,ids&expand=academy[contact,saas],syllabus_version") + +# qs = Cohort.objects.all() +# serializer = CohortSerializer(data=qs, many=True, request=request) + +# assert await serializer.adata == [ +# { +# "id": x.id, +# "slug": x.slug, +# "name": x.name, +# "intro_video": x.intro_video, +# "available_as_saas": x.available_as_saas, +# "academy": { +# "id": x.academy.id, +# "name": x.academy.name, +# "slug": x.academy.slug, +# "street_address": x.academy.street_address, +# "feedback_email": x.academy.feedback_email, +# "available_as_saas": x.academy.available_as_saas, +# "is_hidden_on_prework": x.academy.is_hidden_on_prework, +# }, +# "syllabus_version": None, +# } +# for x in model.cohort +# ]