-
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
106 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
.git | ||
.env | ||
|
||
env/ | ||
|
||
dist/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
from injector import inject | ||
from sqlalchemy import Pool, QueuePool, StaticPool, create_engine | ||
from sqlalchemy.orm import sessionmaker, Session | ||
|
||
from ai_assistant_core.infrastructure.migrator import run_database_migration | ||
|
||
from .sqlalchemy import Base | ||
from urllib.parse import urlparse | ||
|
||
|
||
@inject | ||
class SqlAlchemySessionFactory: | ||
|
||
def __init__(self, database_url: str) -> None: | ||
self.database_url = database_url | ||
|
||
def create( | ||
self, | ||
) -> Session: | ||
database_url = self.database_url | ||
self._create_sqlite_path() | ||
engine = create_engine( | ||
database_url, | ||
poolclass=self._get_pool_class(database_url), | ||
) | ||
|
||
Base.metadata.bind = engine | ||
Base.metadata.create_all(engine) | ||
|
||
with engine.begin() as connection: | ||
run_database_migration(connection=connection) | ||
|
||
return sessionmaker(autocommit=False, bind=engine)() | ||
|
||
def _get_pool_class(self, database_url: str) -> Pool: | ||
if database_url.startswith("sqlite"): | ||
return StaticPool | ||
else: | ||
return QueuePool | ||
|
||
def _create_sqlite_path(self) -> str: | ||
parsed_url = urlparse(self.database_url) | ||
|
||
if parsed_url.scheme in ["sqlite"]: | ||
path = parsed_url.path.removeprefix("/") | ||
database_path = os.path.abspath(path) | ||
directory = os.path.dirname(database_path) | ||
|
||
if not os.path.exists(directory): | ||
os.makedirs(directory) |
40 changes: 4 additions & 36 deletions
40
core/ai_assistant_core/infrastructure/sqlalchemy_module.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,48 +1,16 @@ | ||
import os | ||
from injector import Module, provider, singleton | ||
from sqlalchemy import Pool, QueuePool, StaticPool, create_engine | ||
from sqlalchemy.orm import sessionmaker, Session | ||
from sqlalchemy.orm import Session | ||
|
||
|
||
from ai_assistant_core.app_configuration import AppConfiguration | ||
from ai_assistant_core.infrastructure.migrator import run_database_migration | ||
|
||
from .sqlalchemy import Base | ||
from urllib.parse import urlparse | ||
from ai_assistant_core.infrastructure.session_factory import SqlAlchemySessionFactory | ||
|
||
|
||
class SqlAlchemyModule(Module): | ||
|
||
@singleton | ||
@provider | ||
def provide_sqlalchemy_session(self, configuration: AppConfiguration) -> Session: | ||
database_url = configuration.database_url | ||
self.create_sqlite_path(database_url) | ||
engine = create_engine( | ||
database_url, | ||
poolclass=self.get_pool_class(database_url), | ||
) | ||
|
||
Base.metadata.bind = engine | ||
Base.metadata.create_all(engine) | ||
|
||
with engine.begin() as connection: | ||
run_database_migration(connection=connection) | ||
|
||
return sessionmaker(autocommit=False, bind=engine)() | ||
|
||
def get_pool_class(self, database_url: str) -> Pool: | ||
if database_url.startswith("sqlite"): | ||
return StaticPool | ||
else: | ||
return QueuePool | ||
|
||
def create_sqlite_path(self, database_url: str) -> str: | ||
parsed_url = urlparse(database_url) | ||
|
||
if parsed_url.scheme in ["sqlite"]: | ||
database_path = os.path.abspath(parsed_url.path) | ||
directory = os.path.dirname(database_path) | ||
factory = SqlAlchemySessionFactory(database_url) | ||
|
||
if not os.path.exists(directory): | ||
os.makedirs(directory) | ||
return factory.create() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import os | ||
import pytest | ||
from unittest.mock import patch | ||
from ai_assistant_core.infrastructure.session_factory import SqlAlchemySessionFactory | ||
from platformdirs import user_data_dir | ||
|
||
|
||
@pytest.mark.skipif(os.name != "posix", reason="Test runs only on Unix-based systems") | ||
class TestUnix: | ||
local_db_path = "/tmp/some/sub-dir/test.db" | ||
|
||
def test_local_database_directory_is_created(self): | ||
instance = SqlAlchemySessionFactory( | ||
database_url=f"sqlite:///{self.local_db_path}" | ||
) | ||
with patch("os.makedirs") as mock_makedirs, patch( | ||
"os.path.exists", side_effect=lambda path: path != "/tmp/some/sub-dir" | ||
): | ||
instance._create_sqlite_path() | ||
|
||
mock_makedirs.assert_called_once_with("/tmp/some/sub-dir") | ||
|
||
|
||
@pytest.mark.skipif(os.name != "nt", reason="Test runs only on Windows") | ||
class TestWindows: | ||
local_db_path = f"{user_data_dir()}\\some\\test.db" | ||
|
||
def test_local_database_directory_is_created(self): | ||
instance = SqlAlchemySessionFactory( | ||
database_url=f"sqlite:///{self.local_db_path}" | ||
) | ||
with patch("os.makedirs") as mock_makedirs, patch( | ||
"os.path.exists", | ||
side_effect=lambda path: path != f"{user_data_dir()}\\some", | ||
): | ||
instance._create_sqlite_path() | ||
|
||
mock_makedirs.assert_called_once_with(f"{user_data_dir()}\\some") |