diff --git a/backend/src/services/dataset_review/repositories/__tests__/test_session_repository.py b/backend/src/services/dataset_review/repositories/__tests__/test_session_repository.py index 9ac49b13..77d968ea 100644 --- a/backend/src/services/dataset_review/repositories/__tests__/test_session_repository.py +++ b/backend/src/services/dataset_review/repositories/__tests__/test_session_repository.py @@ -14,7 +14,9 @@ from src.models.dataset_review import ( FindingArea, FindingSeverity, ReadinessState, - RecommendedAction + RecommendedAction, + SessionCollaborator, + SessionCollaboratorRole ) from src.services.dataset_review.repositories.session_repository import DatasetReviewSessionRepository @@ -75,6 +77,32 @@ def test_load_session_detail_ownership(db_session): loaded_wrong = repo.load_session_detail(session.session_id, "wrong_user") assert loaded_wrong is None +def test_load_session_detail_collaborator(db_session): + # @PURPOSE: Verify collaborator access in detail loading. + repo = DatasetReviewSessionRepository(db_session) + session = DatasetReviewSession( + user_id="user1", environment_id="env1", source_kind="superset_link", + source_input="http://link", dataset_ref="dataset1" + ) + repo.create_session(session) + + # Add collaborator + collab_user = User(id="collab1", username="collab", email="c@e.com", password_hash="p") + db_session.add(collab_user) + + collaborator = SessionCollaborator( + session_id=session.session_id, + user_id="collab1", + role=SessionCollaboratorRole.REVIEWER + ) + db_session.add(collaborator) + db_session.commit() + + # Collaborator access + loaded = repo.load_session_detail(session.session_id, "collab1") + assert loaded is not None + assert loaded.session_id == session.session_id + def test_save_preview_marks_stale(db_session): # @PURPOSE: Verify that saving a new preview marks old ones as stale. repo = DatasetReviewSessionRepository(db_session) @@ -127,6 +155,23 @@ def test_save_profile_and_findings(db_session): assert updated_session.profile.dataset_name == "Test DS" assert len(updated_session.findings) == 1 assert updated_session.findings[0].code == "ERR1" + + # Verify removal of old findings + new_finding = ValidationFinding( + session_id=session.session_id, + area=FindingArea.DATASET_PROFILE, + severity=FindingSeverity.WARNING, + code="WARN1", + title="Warning", + message="Something" + ) + + repo.save_profile_and_findings(session.session_id, "user1", profile, [new_finding]) + + db_session.expire_all() + final_session = repo.load_session_detail(session.session_id, "user1") + assert len(final_session.findings) == 1 + assert final_session.findings[0].code == "WARN1" def test_save_run_context(db_session): # @PURPOSE: Verify saving of run context. diff --git a/backend/src/services/dataset_review/repositories/session_repository.py b/backend/src/services/dataset_review/repositories/session_repository.py index 35116c56..74b7af14 100644 --- a/backend/src/services/dataset_review/repositories/session_repository.py +++ b/backend/src/services/dataset_review/repositories/session_repository.py @@ -10,13 +10,15 @@ # @POST: session aggregate reads are structurally consistent and writes preserve ownership and version semantics. from typing import Optional, List +from sqlalchemy import or_ from sqlalchemy.orm import Session, joinedload from src.models.dataset_review import ( DatasetReviewSession, DatasetProfile, ValidationFinding, CompiledPreview, - DatasetRunContext + DatasetRunContext, + SessionCollaborator ) from src.core.logger import belief_scope @@ -44,8 +46,9 @@ class DatasetReviewSessionRepository: @PRE: user_id must match session owner or authorized collaborator. """ with belief_scope("DatasetReviewSessionRepository.load_session_detail"): - # Note: We check user_id to enforce the ownership_scope invariant. + # Check if user is owner or collaborator return self.db.query(DatasetReviewSession)\ + .outerjoin(SessionCollaborator, DatasetReviewSession.session_id == SessionCollaborator.session_id)\ .options( joinedload(DatasetReviewSession.profile), joinedload(DatasetReviewSession.findings), @@ -60,7 +63,12 @@ class DatasetReviewSessionRepository: joinedload(DatasetReviewSession.run_contexts) )\ .filter(DatasetReviewSession.session_id == session_id)\ - .filter(DatasetReviewSession.user_id == user_id)\ + .filter( + or_( + DatasetReviewSession.user_id == user_id, + SessionCollaborator.user_id == user_id + ) + )\ .first() def save_profile_and_findings(self, session_id: str, user_id: str, profile: DatasetProfile, findings: List[ValidationFinding]) -> DatasetReviewSession: @@ -77,12 +85,21 @@ class DatasetReviewSessionRepository: raise ValueError("Session not found or access denied") if profile: + # Ensure we update existing profile by session_id if it exists + existing_profile = self.db.query(DatasetProfile).filter_by(session_id=session_id).first() + if existing_profile: + profile.profile_id = existing_profile.profile_id self.db.merge(profile) - # For findings, we might want to sync them (remove old ones if not in new list, or update) - # Simplest for now: add/merge findings + # Remove old findings for this session to avoid stale data + self.db.query(ValidationFinding).filter( + ValidationFinding.session_id == session_id + ).delete() + + # Add new findings for finding in findings: - self.db.merge(finding) + finding.session_id = session_id + self.db.add(finding) self.db.commit() return self.load_session_detail(session_id, user_id)