diff --git a/dbcat/api.py b/dbcat/api.py index 42e0eb7..49cfdf8 100644 --- a/dbcat/api.py +++ b/dbcat/api.py @@ -170,6 +170,26 @@ def add_postgresql_source( ) +def add_oracle_source( + catalog: Catalog, + name: str, + username: str, + password: str, + database: str, + uri: str, + port: Optional[int] = None, +) -> CatSource: + with catalog.commit_context: + return catalog.add_source( + name=name, + username=username, + password=password, + database=database, + uri=uri, + port=port, + source_type="oracle", + ) + def add_mysql_source( catalog: Catalog, name: str, diff --git a/dbcat/catalog/db.py b/dbcat/catalog/db.py index 8869b18..dd4e576 100644 --- a/dbcat/catalog/db.py +++ b/dbcat/catalog/db.py @@ -17,6 +17,7 @@ SnowflakeMetadataExtractor, ) from databuilder.extractor.sql_alchemy_extractor import SQLAlchemyExtractor +from databuilder.extractor.oracle_metadata_extractor import OracleMetadataExtractor from databuilder.models.table_metadata import TableMetadata from pyhocon import ConfigFactory, ConfigTree from sqlalchemy.orm.exc import NoResultFound @@ -57,6 +58,8 @@ def __init__( self._extractor, self._conf = DbScanner._create_sqlite_extractor(source) elif source.source_type == "athena": self._extractor, self._conf = DbScanner._create_athena_extractor(source) + elif source.source_type == 'oracle': + self._extractor, self._conf = DbScanner._create_oracle_extractor(source) else: raise ValueError("{} is not supported".format(source.source_type)) @@ -262,6 +265,29 @@ def _create_mysql_extractor( return extractor, conf + @staticmethod + def _create_oracle_extractor(source: CatSource) -> Tuple[Extractor, Any]: + where_clause_suffix = """ + WHERE c.owner NOT IN ('AUDSYS','CTXSYS','DBSFWUSER','APPQOSSYS','DBSNMP','DVSYS','GSMADMIN_INTERNAL','LBACSYS', + 'ALL_SA_AUDIT_OPTIONS','MDSYS','OJVMSYS','OLAPSYS','ORDDATA','ORDSYS','OUTLN','SYS','SYSTEM', + 'WMSYS','XDB') + """ + + extractor = OracleMetadataExtractor() + scope = extractor.get_scope() + conn_string_key = f"{scope}.{SQLAlchemyExtractor().get_scope()}.{SQLAlchemyExtractor.CONN_STRING}" + conf = ConfigFactory.from_dict( + { + conn_string_key: source.conn_string, + f"{scope}.{OracleMetadataExtractor.CLUSTER_KEY}": source.cluster, + f"{scope}.{OracleMetadataExtractor.DATABASE_KEY}": source.database, + f"{scope}.{OracleMetadataExtractor.WHERE_CLAUSE_SUFFIX_KEY}": where_clause_suffix, + } + ) + + return extractor, conf + + @staticmethod def _create_postgres_extractor(source: CatSource) -> Tuple[Extractor, Any]: where_clause_suffix = """ diff --git a/test/connections.yaml b/test/connections.yaml index a48a85d..eb501c4 100644 --- a/test/connections.yaml +++ b/test/connections.yaml @@ -33,4 +33,11 @@ connections: aws_access_key_id: dummy_key aws_secret_access_key: dummy_secret region_name: us-east-1 - s3_staging_dir: s3://dummy \ No newline at end of file + s3_staging_dir: s3://dummy + - name: oracle + source_type: oracle + database: db_database + username: db_user + password: db_password + port: db_port + uri: db_uri \ No newline at end of file diff --git a/test/test_catalog.py b/test/test_catalog.py index 4e67baa..70cf708 100644 --- a/test/test_catalog.py +++ b/test/test_catalog.py @@ -445,7 +445,7 @@ def test_add_sources(open_catalog_connection): catalog.add_source(**c) connections = catalog.search_sources(source_like="%") - assert len(connections) == 7 + assert len(connections) == 8 # pg pg_connection = connections[1] @@ -502,6 +502,17 @@ def test_add_sources(open_catalog_connection): assert athena_conn.s3_staging_dir == "s3://dummy" + # oracle + oracle_conn = connections[7] + assert oracle_conn.name == "oracle" + assert oracle_conn.source_type == "oracle" + assert oracle_conn.database == "db_database" + assert oracle_conn.username == "db_user" + assert oracle_conn.password == "db_password" + assert oracle_conn.port == "db_port" + assert oracle_conn.uri == "db_uri" + + @pytest.fixture(scope="module") def load_job_and_executions(save_catalog): catalog = save_catalog