diff --git a/lms/tasks/roster.py b/lms/tasks/roster.py index 808d87c76c..f531017084 100644 --- a/lms/tasks/roster.py +++ b/lms/tasks/roster.py @@ -237,6 +237,15 @@ def fetch_course_roster(*, lms_course_id) -> None: lms_course = request.db.get(LMSCourse, lms_course_id) roster_service.fetch_course_roster(lms_course) + # Check the if course has any sections, if it does, schedule fetching its rosters + if request.db.scalars( + select(LMSSegment.id).where( + LMSSegment.lms_course_id == lms_course_id, + LMSSegment.type == "canvas_section", + ) + ).first(): + fetch_canvas_sections_roster.delay(lms_course_id=lms_course_id) + @app.task( acks_late=True, @@ -268,3 +277,19 @@ def fetch_segment_roster(*, lms_segment_id) -> None: with request.tm: assignment = request.db.get(LMSSegment, lms_segment_id) roster_service.fetch_canvas_group_roster(assignment) + + +@app.task( + acks_late=True, + autoretry_for=(Exception,), + max_retries=2, + retry_backoff=3600, + retry_backoff_max=7200, +) +def fetch_canvas_sections_roster(*, lms_course_id) -> None: + """Fetch the roster for all sections of a given course.""" + with app.request_context() as request: + roster_service: RosterService = request.find_service(RosterService) + with request.tm: + lms_course = request.db.get(LMSCourse, lms_course_id) + roster_service.fetch_canvas_sections_roster(lms_course) diff --git a/tests/unit/lms/tasks/roster_test.py b/tests/unit/lms/tasks/roster_test.py index 5ed7776e6f..7e6fe3fc09 100644 --- a/tests/unit/lms/tasks/roster_test.py +++ b/tests/unit/lms/tasks/roster_test.py @@ -8,6 +8,7 @@ fetch_assignment_roster, fetch_course_roster, fetch_segment_roster, + fetch_canvas_sections_roster, schedule_fetching_assignment_rosters, schedule_fetching_course_rosters, schedule_fetching_rosters, @@ -25,6 +26,20 @@ def test_fetch_course_roster(self, roster_service, db_session): roster_service.fetch_course_roster.assert_called_once_with(lms_course) + def test_fetch_course_roster_with_sections( + self, roster_service, db_session, fetch_canvas_sections_roster + ): + lms_course = factories.LMSCourse() + factories.LMSSegment(lms_course=lms_course, type="canvas_section") + db_session.flush() + + fetch_course_roster(lms_course_id=lms_course.id) + + roster_service.fetch_course_roster.assert_called_once_with(lms_course) + fetch_canvas_sections_roster.delay.assert_called_once_with( + lms_course_id=lms_course.id + ) + def test_fetch_assignment_roster(self, roster_service, db_session): assignment = factories.Assignment() db_session.flush() @@ -41,6 +56,14 @@ def test_fetch_segment_roster(self, roster_service, db_session): roster_service.fetch_canvas_group_roster.assert_called_once_with(lms_segment) + def test_fetch_canvas_sections_roster(self, roster_service, db_session): + lms_course = factories.LMSCourse() + db_session.flush() + + fetch_canvas_sections_roster(lms_course_id=lms_course.id) + + roster_service.fetch_canvas_sections_roster.assert_called_once_with(lms_course) + def test_schedule_fetching_rosters( self, schedule_fetching_assignment_rosters, @@ -302,6 +325,10 @@ def fetch_assignment_roster(self, patch): def fetch_segment_roster(self, patch): return patch("lms.tasks.roster.fetch_segment_roster") + @pytest.fixture + def fetch_canvas_sections_roster(self, patch): + return patch("lms.tasks.roster.fetch_canvas_sections_roster") + @pytest.fixture def schedule_fetching_segment_rosters(self, patch): return patch("lms.tasks.roster.schedule_fetching_segment_rosters")