Skip to content

Commit

Permalink
Do not write export files if no columns are selected
Browse files Browse the repository at this point in the history
Sometimes zero columns are selected: only open data is exported,
so when columns are protected by a scope, those columns are
excluded from the export. In that case no export file should be written.

Furthermore, the temporal filter clause was not applied to the jsonlines
export, so the whole history was exported instead of only the current records.
  • Loading branch information
jjmurre committed Feb 6, 2024
1 parent 31a461c commit b6507a6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 23 deletions.
38 changes: 26 additions & 12 deletions src/schematools/exports/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@

from datetime import date
from pathlib import Path
from typing import IO
from typing import IO, Iterable

import psycopg2
from sqlalchemy import MetaData, Table
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.engine import Connection
from sqlalchemy.sql.elements import ClauseElement

from schematools.factories import tables_factory
from schematools.types import _PUBLIC_SCOPE, DatasetFieldSchema, DatasetSchema, DatasetTableSchema
Expand Down Expand Up @@ -78,25 +79,24 @@ def __init__(
)
self.sa_tables = tables_factory(dataset_schema, metadata)

def _get_column(self, sa_table: Table, field: DatasetFieldSchema):
def _get_column(self, sa_table: Table, field: DatasetFieldSchema) -> Column:
column = getattr(sa_table.c, field.db_name)
# apply all processors
for processor in self.processors:
column = processor(field, column)

return column

# processor = self.geo_modifier if field.is_geo else lambda col, _fn: col
# return processor(column, field.db_name)

def _get_columns(self, sa_table: Table, table: DatasetTableSchema):
def _get_columns(self, sa_table: Table, table: DatasetTableSchema) -> Iterable[Column]:
for field in _get_fields(self.dataset_schema, table, self.scopes):
try:
yield self._get_column(sa_table, field)
except AttributeError:
pass # skip unavailable columns

def _get_temporal_clause(self, sa_table: Table, table: DatasetTableSchema):
def _get_temporal_clause(
self, sa_table: Table, table: DatasetTableSchema
) -> ClauseElement | None:
if not table.is_temporal:
return None
temporal = table.temporal
Expand All @@ -112,13 +112,27 @@ def export_tables(self):
if table.has_geometry_fields and srid is None:
raise ValueError("Table has geo fields, but srid is None.")
sa_table = self.sa_tables[table.id]
columns = list(self._get_columns(sa_table, table))
if not columns:
return
with open(
self.base_dir / f"{table.db_name}.{self.extension}", "w", encoding="utf8"
) as file_handle:
self.write_rows(file_handle, table, sa_table, srid)

def write_rows(
self, file_handle: IO[str], table: DatasetTableSchema, sa_table: Table, srid: str
self.write_rows(
file_handle,
table,
columns,
self._get_temporal_clause(sa_table, table),
srid,
)

def write_rows( # noqa: D102
self,
file_handle: IO[str],
table: DatasetTableSchema,
columns: Iterable[Column],
temporal_clause: ClauseElement | None,
srid: str,
):
raise NotImplementedError

Expand Down
22 changes: 15 additions & 7 deletions src/schematools/exports/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import csv
from datetime import date
from typing import IO
from typing import IO, Iterable

from geoalchemy2 import functions as gfunc # ST_AsEWKT
from sqlalchemy import MetaData, Table, func, select
from sqlalchemy import Column, MetaData, Table, func, select
from sqlalchemy.engine import Connection
from sqlalchemy.sql.elements import ClauseElement

from schematools.exports import BaseExporter, enable_datetime_cast
from schematools.exports import BaseExporter, NoContentException, enable_datetime_cast
from schematools.naming import toCamelCase
from schematools.types import DatasetFieldSchema, DatasetSchema

Expand Down Expand Up @@ -43,15 +44,22 @@ def datetime_modifier(field: DatasetFieldSchema, column):
processors = (geo_modifier, id_modifier, datetime_modifier)

def write_rows( # noqa: D102
self, file_handle: IO[str], table: DatasetSchema, sa_table: Table, srid: str
self,
file_handle: IO[str],
table: DatasetTableSchema,
columns: Iterable[Column],
temporal_clause: ClauseElement | None,
srid: str,
):
columns = list(self._get_columns(sa_table, table))
if not columns:
raise NoContentException()

field_names = [c.name for c in columns]
writer = csv.DictWriter(file_handle, field_names, extrasaction="ignore")
# Use capitalize() on headers, because csv export does the same
writer.writerow({fn: toCamelCase(fn).capitalize() for fn in field_names})
query = select(self._get_columns(sa_table, table))
temporal_clause = self._get_temporal_clause(sa_table, table)
query = select(columns)
# temporal_clause = self._get_temporal_clause(sa_table, table)
if temporal_clause is not None:
query = query.where(temporal_clause)
if self.size is not None:
Expand Down
2 changes: 2 additions & 0 deletions src/schematools/exports/geopackage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def export_geopackages(
for field in _get_fields(dataset_schema, table, scopes)
if field.db_name != "schema"
)
if not field_names.seq:
return
table_name = sql.Identifier(table.db_name)
query = sql.SQL("SELECT {field_names} from {table_name}").format(
field_names=field_names, table_name=table_name
Expand Down
14 changes: 10 additions & 4 deletions src/schematools/exports/jsonlines.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from decimal import Decimal
from typing import IO, Any
from typing import IO, Any, Iterable

import jsonlines
import orjson
from geoalchemy2 import functions as func
from sqlalchemy import MetaData, Table, select
from sqlalchemy import Column, MetaData, Table, select
from sqlalchemy.engine import Connection
from sqlalchemy.sql.elements import ClauseElement

from schematools.exports import BaseExporter
from schematools.exports.csv import DatasetFieldSchema, enable_datetime_cast
Expand Down Expand Up @@ -60,11 +61,16 @@ def _get_row_modifier(self, table: DatasetTableSchema):
return lookup

def write_rows( # noqa: D102
self, file_handle: IO[str], table: DatasetTableSchema, sa_table: Table, srid: str
self,
file_handle: IO[str],
table: DatasetTableSchema,
columns: Iterable[Column],
temporal_clause: ClauseElement | None,
srid: str,
):
writer = jsonlines.Writer(file_handle, dumps=_dumps)
row_modifier = self._get_row_modifier(table)
query = select(self._get_columns(sa_table, table))
query = select(columns)
if self.size is not None:
query = query.limit(self.size)
result = self.connection.execution_options(yield_per=1000).execute(query)
Expand Down

0 comments on commit b6507a6

Please sign in to comment.