diff --git a/synapseclient/models/mixins/storable_container.py b/synapseclient/models/mixins/storable_container.py index e58e813dd..4f30d7ee5 100644 --- a/synapseclient/models/mixins/storable_container.py +++ b/synapseclient/models/mixins/storable_container.py @@ -54,6 +54,30 @@ class StorableContainer(StorableContainerSynchronousProtocol): async def get_async(self, *, synapse_client: Optional[Synapse] = None) -> None: """Used to satisfy the usage in this mixin from the parent class.""" + async def worker( + self, + name: str, + queue: asyncio.Queue, + failure_strategy: FailureStrategy, + synapse_client: Synapse, + ): + while True: + # Get a "work item" out of the queue. + work_item = await queue.get() + + print(f"{name} working on {work_item}. File queue Size: {queue.qsize()}") + + result = await work_item + + self._resolve_sync_from_synapse_result( + result=result, + failure_strategy=failure_strategy, + synapse_client=synapse_client, + ) + + # Notify the queue that the "work item" has been processed. + queue.task_done() + @otel_trace_method( method_to_trace_name=lambda self, **kwargs: f"{self.__class__.__name__}_sync_from_synapse: {self.id}" ) @@ -67,6 +91,7 @@ async def sync_from_synapse_async( include_activity: bool = True, follow_link: bool = False, link_hops: int = 1, + queue: asyncio.Queue = None, *, synapse_client: Optional[Synapse] = None, ) -> Self: @@ -224,9 +249,10 @@ async def sync_from_synapse_async( ``` """ + syn = Synapse.get_client(synapse_client=synapse_client) if not self._last_persistent_instance: - await self.get_async(synapse_client=synapse_client) - Synapse.get_client(synapse_client=synapse_client).logger.info( + await self.get_async(synapse_client=syn) + syn.logger.info( f"Syncing {self.__class__.__name__} ({self.id}:{self.name}) from Synapse." ) path = os.path.expanduser(path) if path else None @@ -236,10 +262,13 @@ async def sync_from_synapse_async( None, lambda: self._retrieve_children( follow_link=follow_link, - synapse_client=synapse_client, + synapse_client=syn, ), ) + create_workers = not queue + + queue = queue or asyncio.Queue() pending_tasks = [] self.folders = [] self.files = [] @@ -253,10 +282,11 @@ async def sync_from_synapse_async( download_file=download_file, if_collision=if_collision, failure_strategy=failure_strategy, - synapse_client=synapse_client, + synapse_client=syn, include_activity=include_activity, follow_link=follow_link, link_hops=link_hops, + queue=queue, ) ) @@ -265,8 +295,30 @@ async def sync_from_synapse_async( self._resolve_sync_from_synapse_result( result=result, failure_strategy=failure_strategy, - synapse_client=synapse_client, + synapse_client=syn, ) + + # After all folders have been resolved start the file download process: + # Create three worker tasks to process the queue concurrently. + if create_workers: + worker_tasks = [] + for i in range(max(syn.max_threads * 2, 1)): + task = asyncio.create_task( + self.worker( + name=f"worker-{i}", + queue=queue, + failure_strategy=failure_strategy, + synapse_client=syn, + ) + ) + worker_tasks.append(task) + + # Wait until the queue is fully processed. + await queue.join() + + for task in worker_tasks: + task.cancel() + return self def flatten_file_list(self) -> List["File"]: @@ -381,6 +433,7 @@ def _retrieve_children( async def _wrap_recursive_get_children( self, folder: "Folder", + queue: asyncio.Queue, recursive: bool = False, path: Optional[str] = None, download_file: bool = False, @@ -413,11 +466,13 @@ async def _wrap_recursive_get_children( follow_link=follow_link, link_hops=link_hops, synapse_client=synapse_client, + queue=queue, ) def _create_task_for_child( self, child, + queue: asyncio.Queue, recursive: bool = False, path: Optional[str] = None, download_file: bool = False, @@ -487,6 +542,7 @@ def _create_task_for_child( follow_link=follow_link, link_hops=link_hops, synapse_client=synapse_client, + queue=queue, ) ) ) @@ -508,13 +564,11 @@ def _create_task_for_child( if if_collision: file.if_collision = if_collision - pending_tasks.append( - asyncio.create_task( - wrap_coroutine( - file.get_async( - include_activity=include_activity, - synapse_client=synapse_client, - ) + queue.put_nowait( + wrap_coroutine( + file.get_async( + include_activity=include_activity, + synapse_client=synapse_client, ) ) ) @@ -533,6 +587,7 @@ def _create_task_for_child( include_activity=include_activity, follow_link=follow_link, link_hops=link_hops - 1, + queue=queue, ) ) ) @@ -543,6 +598,7 @@ def _create_task_for_child( async def _follow_link( self, child, + queue: asyncio.Queue, recursive: bool = False, path: Optional[str] = None, download_file: bool = False, @@ -595,6 +651,7 @@ async def _follow_link( include_activity=include_activity, follow_link=follow_link, link_hops=link_hops, + queue=queue, synapse_client=synapse_client, ) for task in asyncio.as_completed(pending_tasks):