Skip to content

Commit

Permalink
Refactored process wrapper, including using replica set.
Browse files Browse the repository at this point in the history
Updated `monogd` wrapper to start with replica set (if configured.)
Updated connection string to include replica set (if configured.)
Also updated typing for many functions.

Assuming this works as expected, should resolve kaizendorks#80
  • Loading branch information
infinityredux committed Jan 4, 2025
1 parent a60be8f commit 4b71ca0
Showing 1 changed file with 51 additions and 56 deletions.
107 changes: 51 additions & 56 deletions pymongo_inmemory/mongod.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,37 +49,40 @@ def clean_before_kill(signum, stack):

class MongodConfig:
def __init__(self, pim_context: Context):
self._pim_context = pim_context
self._context = pim_context
self.local_address = "127.0.0.1"
self.engine = pim_context.storage_engine

@property
def port(self):
set_port = self._pim_context.mongod_port
def port(self) -> str | None:
set_port = self._context.mongod_port
if set_port is None:
return str(find_open_port(range(27017, 28000)))
else:
return str(set_port)

@property
def connection_string(self):
if self._pim_context.mongo_client_host is not None:
if self._pim_context.mongo_client_host.startswith("mongodb://"):
return self._pim_context.mongo_client_host
else:
self.local_address = self._pim_context.mongo_client_host
def replica_set(self) -> str | None:
return self._context.replica_set

if self.local_address is not None and self.port is not None:
if self._pim_context.dbname is None:
return "mongodb://{host}:{port}".format(
host=self.local_address, port=self.port
)
@property
def connection_string(self) -> str | None:
host = self.local_address
if self._context.mongo_client_host is not None:
if self._context.mongo_client_host.startswith("mongodb://"):
return self._context.mongo_client_host
else:
return "mongodb://{host}:{port}/{dbname}".format(
host=self.local_address,
port=self.port,
dbname=self._pim_context.dbname,
)
host = self._context.mongo_client_host

port = self.port # Make sure it only runs once
if host is not None and port is not None:
url = f"mongodb://{host}:{port}"
if self._context.dbname is not None:
url += f"/{self._context.dbname}"
if self._context.replica_set is not None:
url += f"?replicaSet={self._context.replica_set}"

return url


class Mongod:
Expand All @@ -91,28 +94,31 @@ class Mongod:
with `atexit` module to ensure clean up.
"""

def __init__(self, pim_context: Context):
self._pim_context = Context() if pim_context is None else pim_context
def __init__(self, context: Context | None):
if context is None:
context = Context()
logger.info("Running MongoD in the following context")
logger.info(self._pim_context)
logger.info(context)

logger.info("Checking binary")
if self._pim_context.use_local_mongod:
logger.warn("Using local mongod instance")
if context.use_local_mongod:
logger.warning("Using local mongod instance")
self._bin_folder = ""
else:
self._bin_folder = download(self._pim_context)
self._bin_folder = download(context)

self._proc = None
self._connection_string = None

self.config = MongodConfig(self._pim_context)
self.config: MongodConfig = MongodConfig(context)
self.connection_string: str | None = self.config.connection_string

self._proc = None
self._connection_string: str | None = None
self._temp_data_folder = TemporaryDirectory(prefix="pymongoim")
self._using_tmp_folder = self._pim_context.mongod_data_folder is None

self._using_tmp_folder = context.mongod_data_folder is None
self._client = pymongo.MongoClient(self.connection_string)

self.data_folder: str = self._temp_data_folder.name if self._using_tmp_folder else context.mongod_data_folder
self.log_path = os.path.join(self.data_folder, "mongod.log")

def __enter__(self):
self.start()
return self
Expand All @@ -122,10 +128,10 @@ def __exit__(self, *args):

def start(self):
self._check_lock()
self.log_path = os.path.join(self.data_folder, "mongod.log")

logger.info("Starting mongod with {cs}...".format(cs=self.connection_string))
boot_command = [
# noinspection SpellCheckingInspection
boot_command: list[str] = [
os.path.join(self._bin_folder, "mongod"),
"--dbpath",
self.data_folder,
Expand All @@ -139,11 +145,20 @@ def start(self):
if self.config.engine is not None:
boot_command.append("--storageEngine")
boot_command.append(self.config.engine)
if self.config.replica_set is not None:
boot_command.append("--replSetName")
boot_command.append(self.config.replica_set)
logger.debug(boot_command)
self._proc = subprocess.Popen(boot_command)
_popen_objs.append(self._proc)

count = 0
while not self.is_healthy:
pass
time.sleep(0.1)
count += 1
if count >= 200:
raise RuntimeError("Mongo server failed to start within 20 seconds, please check logs.")

logger.info("Started mongod.")
logger.info("Connect with: {cs}".format(cs=self.connection_string))

Expand All @@ -155,26 +170,6 @@ def stop(self):
time.sleep(1)
self._clean_up()

@property
def data_folder(self):
if self._using_tmp_folder:
return self._temp_data_folder.name
else:
return self._pim_context.mongod_data_folder

@property
def connection_string(self):
if self._connection_string is not None:
return self._connection_string

self._connection_string = (
self.config.connection_string
if self.config.connection_string is not None
else None
)

return self._connection_string

@property
def is_locked(self):
return os.path.exists(os.path.join(self.data_folder, "mongod.lock"))
Expand Down Expand Up @@ -246,8 +241,8 @@ def _check_lock(self):
if __name__ == "__main__":
# This part is used for integrity tests too.
logging.basicConfig(level=logging.DEBUG)
context = Context()
with Mongod(context) as md:
main_context = Context()
with Mongod(main_context) as md:
try:
while True:
pass
Expand Down

0 comments on commit 4b71ca0

Please sign in to comment.