diff --git a/src/wei/resources_interface.py b/src/wei/resources_interface.py index 4ee9ada1..824c68a6 100644 --- a/src/wei/resources_interface.py +++ b/src/wei/resources_interface.py @@ -12,6 +12,7 @@ Plate, Pool, Queue, + ResourceContainerBase, Stack, ) @@ -39,48 +40,19 @@ def __init__( SQLModel.metadata.create_all(self.engine) print(f"Resources Database started on: {database_url}") - def add_resource(self, resource: SQLModel) -> SQLModel: + def add_resource(self, resource: ResourceContainerBase): """ - Add a resource to the database if it doesn't already exist, - and link it to the Asset. + Add a resource to the database using the add_resource method + in ResourceContainerBase. Args: - resource (SQLModel): The resource to add. + resource (ResourceContainerBase): The resource to add. Returns: - SQLModel: The existing or newly added resource. + ResourceContainerBase: The saved or existing resource. """ with self.session as session: - # Check if the resource already exists by name and module_name - resource_class = type(resource) - existing_resource = session.exec( - select(resource_class).where( - resource_class.name == resource.name, - resource_class.module_name == resource.module_name, - ) - ).one_or_none() - - if existing_resource: - print(f"Using existing resource: {existing_resource.name}") - return existing_resource # Return the existing resource if found - - # Add the new resource since it doesn't exist - session.add(resource) - session.commit() - session.refresh(resource) - - # Automatically create and link an Asset entry - asset = Asset( - name=f"{resource.name}", - id=resource.id, - module_name=resource.module_name, - ) - session.add(asset) - session.commit() - session.refresh(asset) - - print(f"Added new resource: {resource.name}") - return resource + return resource.add_resource(session) def get_resource(self, resource_name: str, module_name: str) -> Optional[SQLModel]: """ @@ -237,15 +209,14 @@ def get_all_resources(self, resource_type: Type[SQLModel]) -> List[SQLModel]: """ with self.session as session: statement = select(resource_type) - # If the resource type is Asset, load all relationships if resource_type is Asset: statement = statement.options( selectinload(Asset.stack), selectinload(Asset.queue), selectinload(Asset.pool), selectinload(Asset.collection), - selectinload(Asset.plate), ) + resources = session.exec(statement).all() return resources @@ -383,14 +354,14 @@ def pop_from_queue(self, queue: Queue) -> Asset: return queue.pop(session) def insert_into_collection( - self, collection: Collection, location: int, asset: Asset + self, collection: Collection, location: str, asset: Asset ) -> None: """ Insert an asset into a collection resource. Args: collection (Collection): The collection resource to update. - location (int): The location within the collection to insert the asset. + location (str): The location within the collection to insert the asset. asset (Asset): The asset to insert. """ with self.session as session: @@ -404,14 +375,14 @@ def insert_into_collection( session.refresh(collection) def retrieve_from_collection( - self, collection: Collection, location: int + self, collection: Collection, location: str ) -> Optional[Asset]: """ Retrieve an asset from a collection resource. Args: collection (Collection): The collection resource to update. - location (int): The location within the collection to retrieve the asset from. + location (str): The location within the collection to retrieve the asset from. Returns: Asset: The retrieved asset. @@ -434,10 +405,18 @@ def update_plate_well(self, plate: Plate, well_id: str, quantity: float) -> None quantity (float): The new quantity for the well. """ with self.session as session: - session.add(plate) # Re-attach plate to the session - plate.set_wells({well_id: quantity}, session) # Use the updated set_wells - session.commit() # Commit the changes - session.refresh(plate) # Refresh the plate object to reflect changes + # Step 1: Find the corresponding collection (plate) in the database + collection = session.query(Collection).filter_by(name=plate.name).first() + + if not collection: + raise ValueError(f"Collection for plate {plate.name} not found.") + + # Step 2: Use the set_wells function to update the well quantity + plate.set_wells({well_id: quantity}, session) + + # Step 3: Commit the changes and refresh the collection object if needed + session.commit() + session.refresh(collection) def update_plate_contents( self, plate: Plate, new_contents: Dict[str, float] @@ -450,10 +429,24 @@ def update_plate_contents( new_contents (Dict[str, float]): A dictionary with well IDs as keys and quantities as values. """ with self.session as session: - session.add(plate) # Re-attach plate to the session - plate.set_wells(new_contents, session) # Use the updated set_wells - session.commit() # Commit the changes - session.refresh(plate) # Refresh the plate object to reflect changes + # First, retrieve the corresponding Collection from the database + collection = ( + session.query(Collection) + .filter_by( + name=plate.name, + module_name=plate.module_name, + ) + .first() + ) + + if not collection: + raise ValueError(f"Collection for Plate {plate.name} not found.") + + # Now, use the plate object (in memory) to update the wells + plate.set_wells(new_contents, session) # Use the Plate's set_wells logic + + # Make sure to update the Collection's quantity as well + session.commit() def get_well_quantity(self, plate: Plate, well_id: str) -> Optional[float]: """ @@ -490,10 +483,17 @@ def increase_well(self, plate: Plate, well_id: str, quantity: float) -> None: quantity (float): The amount to increase the well quantity by. """ with self.session as session: - session.add(plate) # Re-attach plate to the session - plate.increase(well_id, quantity, session) # Use the increase method - session.commit() # Commit the changes - session.refresh(plate) # Refresh the plate object to reflect changes + # Find the corresponding collection (plate) + collection = session.query(Collection).filter_by(name=plate.name).first() + + if not collection: + raise ValueError(f"Collection for plate {plate.name} not found.") + + # Delegate the task of increasing the well's quantity to the Plate class + plate.increase_well(well_id, quantity, session) + + session.commit() + session.refresh(collection) # Refresh the collection object if needed def decrease_well(self, plate: Plate, well_id: str, quantity: float) -> None: """ @@ -505,10 +505,17 @@ def decrease_well(self, plate: Plate, well_id: str, quantity: float) -> None: quantity (float): The amount to decrease the well quantity by. """ with self.session as session: - session.add(plate) # Re-attach plate to the session - plate.decrease(well_id, quantity, session) # Use the decrease method - session.commit() # Commit the changes - session.refresh(plate) # Refresh the plate object to reflect changes + # Find the corresponding collection (plate) in the database + collection = session.query(Collection).filter_by(name=plate.name).first() + + if not collection: + raise ValueError(f"Collection for plate {plate.name} not found.") + + # Delegate the task of decreasing the well's quantity to the Plate class + plate.decrease_well(well_id, quantity, session) + + session.commit() + session.refresh(collection) def get_wells(self, plate: Plate) -> Dict[str, Pool]: """ @@ -522,7 +529,7 @@ def get_wells(self, plate: Plate) -> Dict[str, Pool]: """ with self.session as session: # Ensure the plate is attached to the current session - plate = session.merge(plate) + # plate = session.merge(plate) # Use the get_wells method from the PlateBase class to retrieve all wells wells = plate.get_wells(session) @@ -612,10 +619,10 @@ def get_wells(self, plate: Plate) -> Dict[str, Pool]: resource_interface.insert_into_collection(collection, location="1", asset=asset3) # Retrieve an asset from the Collection - # retrieved_asset = resource_interface.retrieve_from_collection( - # collection, location=1 - # ) - # print("\nRetrieved Asset from Collection:", retrieved_asset) + retrieved_asset = resource_interface.retrieve_from_collection( + collection, location=1 + ) + print("\nRetrieved Asset from Collection:", retrieved_asset) # Create a Plate resource plate = Plate( @@ -651,7 +658,7 @@ def get_wells(self, plate: Plate) -> Dict[str, Pool]: print(f"Updated wells: {updated_wells}") resource_interface.update_plate_well(plate, well_id="A1", quantity=80.0) - all_plates = resource_interface.get_all_resources(Plate) + all_plates = resource_interface.get_all_resources(Collection) print("\nAll Plates after modification:", all_plates) # all_asset = resource_interface.get_all_resources(Asset) diff --git a/src/wei/types/resource_types.py b/src/wei/types/resource_types.py index 3f217601..910890af 100644 --- a/src/wei/types/resource_types.py +++ b/src/wei/types/resource_types.py @@ -1,5 +1,6 @@ """Resources Data Classes""" +import warnings from datetime import datetime, timezone from typing import Any, Dict, List, Optional @@ -8,6 +9,9 @@ from sqlmodel import Field as SQLField from sqlmodel import Session, SQLModel +# Suppress all DeprecationWarnings +warnings.filterwarnings("ignore", category=DeprecationWarning) + class AssetBase(SQLModel): """ @@ -169,17 +173,147 @@ def save(self, session: Session): Args: session (Session): SQLAlchemy session to use for saving. """ + if isinstance(self, Plate): + # Handling the Plate as a Collection for database operations + collection = ( + session.query(Collection) + .filter_by( + name=self.name, + module_name=self.module_name, + ) + .first() + ) + + if not collection: + # Create a new Collection if it doesn't exist + collection = Collection( + id=self.id, + name=self.name, + description=self.description, + capacity=self.capacity, + quantity=self.quantity, + module_name=self.module_name, + ) + session.add(collection) + else: + # Update the existing Collection attributes + collection.description = self.description + collection.capacity = self.capacity + collection.quantity = self.quantity + collection.module_name = self.module_name + + session.commit() + session.refresh(collection) + + # Update the corresponding Asset (if it exists) + asset = session.get(Asset, collection.id) + if asset: + asset.time_updated = datetime.now(timezone.utc) + session.commit() + + return collection # Return the updated Collection + + # Handle normal resources session.add(self) session.commit() + # Update the corresponding Asset (if it exists) asset = session.get(Asset, self.id) if asset: - asset.time_updated = datetime.now( - timezone.utc - ) # Set time_updated to current time + asset.time_updated = datetime.now(timezone.utc) session.commit() session.refresh(self) + return self # Return the updated resource + + def add_resource(self, session: Session) -> "ResourceContainerBase": + """ + Check if a resource with the same name and module_name exists. + If it exists, return the existing resource. Otherwise, create the resource + and link it with an Asset. + + Args: + session (Session): SQLAlchemy session to use for database operations. + + Returns: + ResourceContainerBase: The saved or existing resource. + """ + # If this is a Plate, treat it as a Collection for database operations + if isinstance(self, Plate): + print("Handling Plate as a Collection") + # Create a new Collection object from the Plate data + collection_resource = Collection( + id=self.id, + name=self.name, + description=self.description, + capacity=self.capacity, + module_name=self.module_name, + quantity=self.quantity, + ) + + # Check if the collection already exists + existing_resource = ( + session.query(Collection) + .filter_by( + name=self.name, + module_name=self.module_name, + ) + .first() + ) + + if existing_resource: + print(f"Using existing collection resource: {existing_resource.name}") + return self # Returning the Plate object itself + + # If the resource doesn't exist, create and save a new collection + session.add(collection_resource) + session.commit() + session.refresh(collection_resource) + + # Create and link an Asset entry for the resource + asset = Asset( + name=collection_resource.name, + id=collection_resource.id, + module_name=collection_resource.module_name, + ) + session.add(asset) + session.commit() + session.refresh(asset) + + print(f"Added new collection resource: {collection_resource.name}") + return self # Returning the Plate object itself + + # Handle normal resources (non-Plate) + existing_resource = ( + session.query(type(self)) + .filter_by( + name=self.name, + module_name=self.module_name, + ) + .first() + ) + + if existing_resource: + print(f"Using existing resource: {existing_resource.name}") + return existing_resource + + # If the resource doesn't exist, create and save a new one + session.add(self) + session.commit() + session.refresh(self) + + # Automatically create and link an Asset entry for the resource + asset = Asset( + name=self.name, + id=self.id, + module_name=self.module_name, + ) + session.add(asset) + session.commit() + session.refresh(asset) + + print(f"Added new resource: {self.name}") + return self class PoolBase(ResourceContainerBase): @@ -579,7 +713,7 @@ def insert(self, location: str, asset: Asset, session: Session) -> None: self.quantity = len(contents) # Update the quantity self.save(session) - def retrieve(self, location: int, session: Session) -> Optional[Dict[str, Any]]: + def retrieve(self, location: str, session: Session) -> Optional[Dict[str, Any]]: """ Retrieve an asset from the collection at the specified location. @@ -593,9 +727,13 @@ def retrieve(self, location: int, session: Session) -> Optional[Dict[str, Any]]: Raises: ValueError: If the location is invalid or the asset is not found. """ + location_str = str(location) + # Perform the query with string comparison allocation = ( session.query(AssetAllocation) - .filter_by(resource_id=self.id, resource_type="collection", index=location) + .filter_by( + resource_id=self.id, resource_type="collection", index=location_str + ) .first() ) @@ -605,9 +743,9 @@ def retrieve(self, location: int, session: Session) -> Optional[Dict[str, Any]]: if asset: # Deallocate the asset from the collection asset.deallocate(session) - # Update the quantity after removing the asset, ensuring it does not drop below 0 + # Update the quantity after removing the asset contents = self.get_contents(session) - self.quantity = max(len(contents) - 1, 0) # Prevent negative quantity + self.quantity = max(len(contents) - 1, 0) self.save(session) return { @@ -615,7 +753,7 @@ def retrieve(self, location: int, session: Session) -> Optional[Dict[str, Any]]: "name": asset.name, "module_name": module_name, "resource_type": "collection", - "location": location, + "location": location_str, } else: raise ValueError( @@ -623,7 +761,7 @@ def retrieve(self, location: int, session: Session) -> Optional[Dict[str, Any]]: ) else: raise ValueError( - f"Location {location} not found in collection {self.name}." + f"Location {location_str} not found in collection {self.name}." ) @@ -661,7 +799,13 @@ def get_wells(self, session: Session) -> Dict[str, Pool]: Returns: Dict[str, Pool]: A dictionary of wells keyed by their location. """ - return self.get_contents(session) # Since Plate is a collection, we reuse this + # Query the Pool table for all wells associated with this plate + wells = session.query(Pool).filter_by(module_name=self.name).all() + + # Create a dictionary of wells, using the well ID (name) as the key + wells_dict = {well.name: well for well in wells} + + return wells_dict def set_wells(self, wells_dict: Dict[str, float], session: Session): """ @@ -672,37 +816,53 @@ def set_wells(self, wells_dict: Dict[str, float], session: Session): wells_dict (Dict[str, float]): A dictionary of well IDs and quantities. session (Session): SQLAlchemy session passed from the interface layer. """ - current_wells = self.get_wells(session) + # Find the corresponding collection for this plate + collection = session.query(Collection).filter_by(name=self.name).first() + # Iterate over wells to add or update them for well_id, quantity in wells_dict.items(): - if well_id in current_wells: - # Update existing well - current_wells[well_id].quantity = quantity + # Check if the well already exists in the Collection + existing_well = ( + session.query(Pool) + .filter_by(name=well_id, module_name=self.name) + .first() + ) + if existing_well: + # If the well exists, update its quantity + existing_well.quantity = quantity + session.commit() else: - # Create a new well + # Create a new Pool (well) if it doesn't exist new_well = Pool( + name=well_id, description=f"Well {well_id}", - name=f"{well_id}", capacity=self.well_capacity, quantity=quantity, - module_name=self.name, # Bug with self.name & self.module_name + module_name=self.name, # Plate's name as module_name ) + session.add(new_well) + session.commit() + + # Create an Asset entry for the new well asset = Asset( - name=f"{new_well.name}", + name=new_well.name, id=new_well.id, module_name=new_well.module_name, ) - session.add(new_well) # Add the new well to the session session.add(asset) - session.commit() # Commit to generate the new_well ID + session.commit() - # Insert the well into the plate's collection resource (indexed by well_id) - self.insert(location=str(well_id), asset=asset, session=session) + # Insert the well into the Collection at the specified location + collection.insert(location=str(well_id), asset=asset, session=session) - # Update the plate's total quantity (number of wells) - self.quantity = len(self.get_wells(session)) + # Update the Plate's total quantity and commit + total_quantity = ( + session.query(func.sum(Pool.quantity)) + .filter(Pool.module_name == self.name) + .scalar() + ) + self.quantity = total_quantity session.commit() - session.refresh(self) def increase_well(self, well_id: str, quantity: float, session: Session) -> None: """ @@ -716,19 +876,15 @@ def increase_well(self, well_id: str, quantity: float, session: Session) -> None Raises: ValueError: If the addition exceeds the well's capacity. """ - wells = self.get_wells(session) + existing_well = ( + session.query(Pool).filter_by(name=well_id, module_name=self.name).first() + ) - # Check if the well exists - if well_id in wells: - well = wells[well_id] - well.increase(quantity, session) # Use Pool's increase method - else: - raise ValueError(f"Well {well_id} does not exist in plate {self.name}.") + if not existing_well: + raise ValueError(f"Well {well_id} not found in plate {self.name}.") - # Update the total quantity of the plate (number of wells) - self.quantity = len(self.get_wells(session)) - session.commit() - session.refresh(self) + # Step 2: Increase the quantity of the existing well + existing_well.quantity += quantity def decrease_well(self, well_id: str, quantity: float, session: Session) -> None: """ @@ -742,19 +898,22 @@ def decrease_well(self, well_id: str, quantity: float, session: Session) -> None Raises: ValueError: If the decrease would result in a negative quantity. """ - wells = self.get_wells(session) + # Step 1: Find the corresponding well (Pool) in the Collection + existing_well = ( + session.query(Pool).filter_by(name=well_id, module_name=self.name).first() + ) - # Check if the well exists - if well_id in wells: - well = wells[well_id] - well.decrease(quantity, session) # Use Pool's decrease method - else: - raise ValueError(f"Well {well_id} does not exist in plate {self.name}.") + if not existing_well: + raise ValueError(f"Well {well_id} not found in plate {self.name}.") - # Update the total quantity of the plate (number of wells) - self.quantity = len(self.get_wells(session)) - session.commit() - session.refresh(self) + # Step 2: Check if the well has sufficient quantity to decrease + if existing_well.quantity < quantity: + raise ValueError( + f"Well {well_id} does not have enough quantity to decrease by {quantity}." + ) + + # Step 3: Decrease the quantity of the existing well + existing_well.quantity -= quantity # class PlateTable(PlateBase, table=True):