diff --git a/edgy/contrib/multi_tenancy/metaclasses.py b/edgy/contrib/multi_tenancy/metaclasses.py index f40e8ad4..068f4e22 100644 --- a/edgy/contrib/multi_tenancy/metaclasses.py +++ b/edgy/contrib/multi_tenancy/metaclasses.py @@ -51,6 +51,7 @@ def __new__( bases: tuple[type, ...], attrs: Any, on_conflict: Literal["error", "replace", "keep"] = "error", + skip_registry: bool = False, **kwargs: Any, ) -> Any: database: Union[Literal["keep"], None, Database, bool] = attrs.get("database", "keep") @@ -61,7 +62,8 @@ def __new__( new_model.meta.is_tenant = _check_model_inherited_tenancy(bases) if ( - new_model.meta.registry + not skip_registry + and new_model.meta.registry and not new_model.meta.abstract and not new_model.__is_proxy_model__ ): diff --git a/edgy/core/db/models/mixins/db.py b/edgy/core/db/models/mixins/db.py index b0c369f8..a9e07085 100644 --- a/edgy/core/db/models/mixins/db.py +++ b/edgy/core/db/models/mixins/db.py @@ -296,7 +296,6 @@ def copy_edgy_model( registry: Optional[Registry] = None, name: str = "", unlink_same_registry: bool = True, - meta_info: MetaInfo | None = None, on_conflict: Literal["keep", "replace", "error"] = "error", **kwargs: Any, ) -> type[Model]: @@ -318,7 +317,7 @@ def copy_edgy_model( __name__=name or cls.__name__, __module__=cls.__module__, __definitions__=attrs, - __metadata__=meta_info, + __metadata__=cls.meta, __bases__=cls.__bases__, skip_registry=True, **kwargs, @@ -357,9 +356,9 @@ def copy_edgy_model( # unreference _copy.meta.fields[field_name].through = through_model = _copy.meta.fields[ field_name - ].through.copy_edgy_model( - meta_info=MetaInfo(registry=False), - ) + ].through.copy_edgy_model() + # we want to set the registry explicit + through_model.meta.registry = False if src_field.from_foreign_key in through_model.meta.fields: # explicit set through_model.meta.fields[src_field.from_foreign_key].target = _copy diff --git a/tests/contrib/autoreflection/test_reflecting_models.py b/tests/contrib/autoreflection/test_reflecting_models.py index aa6cc7af..f983eb43 100644 --- a/tests/contrib/autoreflection/test_reflecting_models.py +++ b/tests/contrib/autoreflection/test_reflecting_models.py @@ -101,6 +101,69 @@ class Meta: ) +async def test_basic_reflection_after_copy(): + reflected = edgy.Registry(DATABASE_URL) + + class AutoAll(AutoReflectModel): + class Meta: + registry = reflected + + class AutoNever(AutoReflectModel): + non_matching = edgy.CharField(max_length=40) + + class Meta: + registry = reflected + template = r"AutoNever" + + class AutoNever2(AutoReflectModel): + id = edgy.CharField(max_length=40, primary_key=True) + + class Meta: + registry = reflected + template = r"AutoNever2" + + class AutoNever3(AutoReflectModel): + class Meta: + registry = reflected + template = r"AutoNever3" + exclude_pattern = r".*" + + class AutoFoo(AutoReflectModel): + class Meta: + registry = reflected + include_pattern = r"^foos$" + + class AutoBar(AutoReflectModel): + class Meta: + registry = reflected + include_pattern = r"^bars" + template = r"{tablename}_{tablename}" + + assert AutoBar.meta.template + + reflected = reflected.__copy__() + + assert len(reflected.reflected) == 0 + async with reflected: + assert ( + sum( + 1 for model in reflected.reflected.values() if model.__name__.startswith("AutoAll") + ) + == 3 + ) + assert "bars_bars" in reflected.reflected + assert "AutoNever" not in reflected.reflected + assert "AutoNever2" not in reflected.reflected + assert "AutoNever3" not in reflected.reflected + + assert ( + sum( + 1 for model in reflected.reflected.values() if model.__name__.startswith("AutoFoo") + ) + == 1 + ) + + async def test_extra_reflection(): reflected = edgy.Registry(DATABASE_ALTERNATIVE_URL, extra={"another": DATABASE_URL}) diff --git a/tests/contrib/multi_tenancy/test_migrate.py b/tests/contrib/multi_tenancy/test_migrate.py index 1d65f6bb..d205947a 100644 --- a/tests/contrib/multi_tenancy/test_migrate.py +++ b/tests/contrib/multi_tenancy/test_migrate.py @@ -56,6 +56,20 @@ async def test_migrate_objs_main_only(): assert len(registry.metadata_by_name[None].tables.keys()) == 2 +async def test_migrate_objs_main_only_after_copy(): + tenant = await Tenant.query.create( + schema_name="migrate_edgy", + domain_url="https://edgy.dymmond.com", + tenant_name="migrate_edgy", + ) + + assert tenant.schema_name == "migrate_edgy" + assert tenant.tenant_name == "migrate_edgy" + + registry = edgy.get_migration_prepared_registry(models.__copy__()) + assert len(registry.metadata_by_name[None].tables.keys()) == 2 + + async def test_migrate_objs_all(): tenant = await Tenant.query.create( schema_name="migrate_edgy", @@ -79,6 +93,27 @@ async def test_migrate_objs_all(): } +async def test_migrate_objs_all_after_copy(): + tenant = await Tenant.query.create( + schema_name="migrate_edgy", + domain_url="https://edgy.dymmond.com", + tenant_name="migrate_edgy", + ) + + assert tenant.schema_name == "migrate_edgy" + assert tenant.tenant_name == "migrate_edgy" + + edgy.monkay.set_instance(Instance(registry=models.__copy__())) + with edgy.monkay.with_settings(edgy.monkay.settings.model_copy(update={"multi_schema": True})): + registry = edgy.get_migration_prepared_registry() + + assert set(registry.metadata_by_name[None].tables.keys()) == { + "tenants", + "migrate_edgy.products", + "products", + } + + async def test_migrate_objs_namespace_only(): tenant = await Tenant.query.create( schema_name="migrate_edgy", diff --git a/tests/contrib/multi_tenancy/test_mt_models.py b/tests/contrib/multi_tenancy/test_mt_models.py index 387f3bee..4487d491 100644 --- a/tests/contrib/multi_tenancy/test_mt_models.py +++ b/tests/contrib/multi_tenancy/test_mt_models.py @@ -71,6 +71,14 @@ class Meta: is_tenant = True +class Cart(TenantModel): + products = fields.ManyToMany(Product) + + class Meta: + registry = models + is_tenant = True + + async def test_create_a_tenant_schema(): tenant = await Tenant.query.create( schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy" @@ -80,6 +88,47 @@ async def test_create_a_tenant_schema(): assert tenant.tenant_name == "edgy" +async def test_create_a_tenant_schema_copy(): + copied = models.__copy__() + tenant = await copied.get_model("Tenant").query.create( + schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy" + ) + + assert tenant.schema_name == "edgy" + assert tenant.tenant_name == "edgy" + NewProduct = copied.get_model("Product") + NewCart = copied.get_model("Cart") + assert NewCart.meta.fields["products"].target is NewProduct + assert NewCart.meta.fields["products"].through is not Cart.meta.fields["products"].through + cart = await NewCart.query.using(schema=tenant.schema_name).create() + for i in range(5): + await cart.products.add( + await NewProduct.query.using(schema=tenant.schema_name).create(name=f"product-{i}") + ) + + products = await cart.products.using(schema=tenant.schema_name).all() + assert len(products) == 5 + + total = await NewProduct.query.using(schema=tenant.schema_name).all() + + assert len(total) == 5 + + total = await NewProduct.query.all() + + assert len(total) == 0 + + for i in range(15): + await NewProduct.query.create(name=f"product-{i}") + + total = await NewProduct.query.all() + + assert len(total) == 15 + + total = await NewProduct.query.using(schema=tenant.schema_name).all() + + assert len(total) == 5 + + async def test_raises_ModelSchemaError_on_public_schema(): with pytest.raises(ModelSchemaError) as raised: await Tenant.query.create(