diff --git a/aicmt/git_operations.py b/aicmt/git_operations.py index 5a0440a..d36ddfc 100644 --- a/aicmt/git_operations.py +++ b/aicmt/git_operations.py @@ -1,15 +1,24 @@ -import os -from typing import List, Dict, NamedTuple, Optional, Tuple +from contextlib import contextmanager +from typing import List, NamedTuple, Optional, Tuple, Union, Any import git from git import Repo from rich.console import Console from pathlib import Path +from enum import Enum console = Console() class Change(NamedTuple): - """Represents a change in a Git repository""" + """Represents a change in a Git repository + + Attributes: + file: Path to the changed file + status: Current status of the file ('modified', 'deleted', 'new file', etc.) + diff: Actual changes or special messages ('[File deleted]', '[Binary file]') + insertions: Number of inserted lines + deletions: Number of deleted lines + """ file: str status: str @@ -18,7 +27,37 @@ class Change(NamedTuple): deletions: int +class FileStatus(str, Enum): + """Enum representing possible file statuses""" + + MODIFIED = "modified" + DELETED = "deleted" + NEW_FILE = "new file" + NEW_BINARY = "new file (binary)" + ERROR = "error" + + +# Constants +BINARY_MESSAGE = "[Binary file]" +DELETED_MESSAGE = "[File deleted]" +DEFAULT_REMOTE = "origin" +MAX_BINARY_CHECK_SIZE = 1024 * 1024 # 1MB + + +@contextmanager +def safe_file_operation(file_path: Union[str, Path]) -> Any: + """Context manager for safe file operations with proper error handling""" + try: + yield + except UnicodeDecodeError: + return FileStatus.NEW_BINARY, BINARY_MESSAGE + except IOError as e: + console.print(f"[red]Error reading file {file_path}: {str(e)}[/red]") + raise + + class GitOperations: + def __init__(self, repo_path: str = "."): """Initialize GitOperations with a repository path @@ -33,10 +72,46 @@ def __init__(self, repo_path: str = "."): self.repo = Repo(repo_path) self.git = self.repo.git except git.InvalidGitRepositoryError: - raise git.InvalidGitRepositoryError(f"'{repo_path}' is not a valid Git repository") + raise git.InvalidGitRepositoryError( + f"'{repo_path}' is not a valid Git repository") except git.NoSuchPathError: raise git.NoSuchPathError(f"Path '{repo_path}' does not exist") + def _is_binary_file(self, file_path: Path) -> bool: + """Check if a file is binary + + Args: + file_path: Path to the file to check + + Returns: + bool: True if the file is binary, False otherwise + """ + if not file_path.exists(): + return False + try: + with file_path.open("rb") as f: + chunk = f.read(MAX_BINARY_CHECK_SIZE) + return b"\0" in chunk or not chunk.decode("utf-8", + errors="ignore") + + except IOError: + return False + + def _get_file_content(self, file_path: Path) -> Tuple[str, str]: + """Get file content with proper status + + Args: + file_path: Path to the file + + Returns: + Tuple[str, str]: (status, content) + """ + with safe_file_operation(file_path): + if self._is_binary_file(file_path): + return FileStatus.NEW_BINARY, BINARY_MESSAGE + + return FileStatus.NEW_FILE, file_path.read_text(encoding="utf-8") + def get_unstaged_changes(self) -> List[Change]: """Get all unstaged changes in the repository @@ -51,41 +126,45 @@ def get_unstaged_changes(self) -> List[Change]: IOError: If there is an error reading a file git.GitCommandError: If there is an error executing a git command """ - changes = [] + changes: List[Change] = [] - # Get modified files - modified_files = {item.a_path for item in self.repo.index.diff(None)} - - # Get untracked files - untracked_files = set(self.repo.untracked_files) - - for file_path in modified_files.union(untracked_files): + # Handle modified and deleted files + for item in self.repo.index.diff(None): try: - file_path_obj = Path(file_path) - status = "" - diff = "" - - if file_path in modified_files: - status, diff = self._handle_modified_file(file_path, file_path_obj) - else: - status, diff = self._handle_untracked_file(file_path, file_path_obj) - - if diff and not diff.startswith("["): - insertions, deletions = self._calculate_diff_stats(diff) - else: - insertions = deletions = 0 - + path_obj = Path(item.a_path) + file_status, diff = self._handle_modified_file( + item.a_path, path_obj) + insertions, deletions = (0, 0) if diff.startswith( + "[") else self._calculate_diff_stats(diff) changes.append( - Change( - file=file_path, - status=status, - diff=diff, - insertions=insertions, - deletions=deletions, - ) + Change(file=item.a_path, + status=file_status, + diff=diff, + insertions=insertions, + deletions=deletions)) + except Exception as e: + console.print( + f"[yellow]Warning: Could not process {item.a_path}: {str(e)}[/yellow]" ) + + # Handle untracked files separately + for file_path in self.repo.untracked_files: + try: + path_obj = Path(file_path) + file_status, diff = self._handle_untracked_file( + file_path, path_obj) + insertions = len( + diff.splitlines()) if not diff.startswith("[") else 0 + changes.append( + Change(file=file_path, + status=file_status, + diff=diff, + insertions=insertions, + deletions=0)) except Exception as e: - console.print(f"[yellow]Warning: Could not process {file_path}: {str(e)}[/yellow]") + console.print( + f"[yellow]Warning: Could not process {file_path}: {str(e)}[/yellow]" + ) return changes @@ -108,18 +187,24 @@ def get_staged_changes(self) -> List[Change]: diff_index = self.repo.index.diff(None, staged=True) for diff in diff_index: - status = "error" + status = FileStatus.ERROR content = "" insertions = 0 deletions = 0 try: - status, content, insertions, deletions = self._process_file_diff(diff) + status, content, insertions, deletions = self._process_file_diff( + diff) except Exception as e: - status = "error" + status = FileStatus.ERROR content = f"[Unexpected error: {str(e)}]" - changes.append(Change(file=diff.b_path or diff.a_path, status=status, diff=content, insertions=insertions, deletions=deletions)) + changes.append( + Change(file=diff.b_path or diff.a_path, + status=status, + diff=content, + insertions=insertions, + deletions=deletions)) return changes @@ -137,50 +222,48 @@ def _process_file_diff(self, diff) -> Tuple: - insertions: number of lines inserted - deletions: number of lines deleted """ - status = "" - content = "" - insertions = 0 - deletions = 0 - if diff.deleted_file: - status = "deleted" - content = "[File deleted]" - insertions, deletions = 0, len(diff.a_blob.data_stream.read().decode("utf-8").splitlines()) - elif diff.new_file: - if diff.b_blob and diff.b_blob.mime_type != "text/plain": - status = "new file (binary)" - content = "[Binary file]" - else: - status = "new file" - file_path = os.path.join(self.repo.working_dir, diff.b_path) - try: - with open(file_path, "r", encoding="utf-8") as f: - content = f.read() - insertions = len(content.splitlines()) - deletions = 0 - except IOError as e: - content = f"[Error reading file: {str(e)}]" - else: - status = "modified" try: - # Check if the file is modified in the staging area - staged_diff = self.repo.git.diff("--cached", diff.a_path) - if staged_diff: - content = staged_diff - # stats = self.repo.git.diff("--cached", "--numstat",diff.a_path).split() - else: - # If the file is not modified in the staging area, compare with the parent commit - content = self.repo.git.diff("HEAD^", "HEAD", diff.a_path) - # stats = self.repo.git.diff("HEAD^", "HEAD", "--numstat",diff.a_path).split() + return (FileStatus.DELETED, DELETED_MESSAGE, 0, + len(diff.a_blob.data_stream.read().decode( + "utf-8").splitlines())) + except Exception: + return FileStatus.DELETED, DELETED_MESSAGE, 0, 0 - insertions, deletions = self._calculate_diff_stats(content) - except git.GitCommandError as e: - content = f"[Error getting diff: {str(e)}]" + if diff.new_file: + if diff.b_blob and diff.b_blob.mime_type != "text/plain": + return FileStatus.NEW_BINARY, BINARY_MESSAGE, 0, 0 + + file_path = Path(self.repo.working_dir) / diff.b_path + with safe_file_operation(file_path): + content = file_path.read_text(encoding="utf-8") + return FileStatus.NEW_FILE, content, len( + content.splitlines()), 0 + + # Handle modified files + try: + # Check if the file is modified in the staging area + staged_diff = self.repo.git.diff("--cached", diff.a_path) + if staged_diff: + content = staged_diff + else: + # If the file is not modified in the staging area, compare with the parent commit + content = self.repo.git.diff("HEAD^", "HEAD", diff.a_path) - return status, content, insertions, deletions + insertions, deletions = self._calculate_diff_stats(content) + return FileStatus.MODIFIED, content, insertions, deletions + except git.GitCommandError as e: + return FileStatus.ERROR, f"[Error getting diff: {str(e)}]", 0, 0 def _calculate_diff_stats(self, diff_content: str) -> Tuple[int, int]: - """Caculates the number of inserted and deleted lines in a diff content""" + """Calculates the number of inserted and deleted lines in a diff content + + Args: + diff_content: The diff content to analyze + + Returns: + Tuple[int, int]: Number of insertions and deletions + """ insertions = deletions = 0 for line in diff_content.split("\n"): if line.startswith("+") and not line.startswith("+++"): @@ -189,7 +272,8 @@ def _calculate_diff_stats(self, diff_content: str) -> Tuple[int, int]: deletions += 1 return insertions, deletions - def _handle_modified_file(self, file_path: str, file_path_obj: Path) -> Tuple[str, str]: + def _handle_modified_file(self, file_path: str, + file_path_obj: Path) -> Tuple[str, str]: """Handle modified file status and diff generation Args: @@ -202,15 +286,16 @@ def _handle_modified_file(self, file_path: str, file_path_obj: Path) -> Tuple[st try: # Try to get diff first diff = self.git.diff(file_path) - return "modified", diff + return FileStatus.MODIFIED, diff except git.exc.GitCommandError: # If file doesn't exist, treat it as deleted if not file_path_obj.exists(): - return "deleted", "[File deleted]" + return FileStatus.DELETED, DELETED_MESSAGE # If file exists but diff failed, something else is wrong raise IOError(f"Failed to get diff for {file_path}") - def _handle_untracked_file(self, file_path: str, file_path_obj: Path) -> Tuple[str, str]: + def _handle_untracked_file(self, file_path: str, + file_path_obj: Path) -> Tuple[str, str]: """Handle untracked file status and content reading Args: @@ -221,20 +306,21 @@ def _handle_untracked_file(self, file_path: str, file_path_obj: Path) -> Tuple[s Tuple[str, str]: (status, content) """ if not file_path_obj.is_file(): - return "deleted", "[File deleted]" + return FileStatus.DELETED, DELETED_MESSAGE try: # Check if file is binary with open(file_path, "rb") as f: - content = f.read(1024 * 1024) # Read first MB to check for binary content + content = f.read( + 1024 * 1024) # Read first MB to check for binary content if b"\0" in content: - return "new file (binary)", "[Binary file]" + return FileStatus.NEW_BINARY, BINARY_MESSAGE # File is not binary, try to read as text with open(file_path, "r", encoding="utf-8") as f: - return "new file", f.read() + return FileStatus.NEW_FILE, f.read() except UnicodeDecodeError: - return "new file (binary)", "[Binary file]" + return FileStatus.NEW_BINARY, BINARY_MESSAGE def stage_files(self, files: List[str]) -> None: """Stage specified files @@ -247,14 +333,15 @@ def stage_files(self, files: List[str]) -> None: """ if not files: raise ValueError("No files to stage!") - + try: # Get current status of files status = self.repo.git.status("--porcelain").splitlines() for file in files: # Find status for this file - file_status = next((s for s in status if s.split()[-1] == file), None) + file_status = next( + (s for s in status if s.split()[-1] == file), None) if file_status and file_status.startswith(" D"): # File is deleted, use remove self.repo.index.remove([file]) @@ -262,7 +349,8 @@ def stage_files(self, files: List[str]) -> None: # File is modified or new, use add self.repo.index.add([file]) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to stage files: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to stage files: {str(e)}", + e.status, e.stderr) from e def commit_changes(self, message: str) -> None: """Create a commit with the staged changes @@ -276,9 +364,12 @@ def commit_changes(self, message: str) -> None: try: self.repo.index.commit(message) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to commit changes: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to commit changes: {str(e)}", + e.status, e.stderr) - def push_changes(self, remote: str = "origin", branch: Optional[str] = None) -> None: + def push_changes(self, + remote: str = DEFAULT_REMOTE, + branch: Optional[str] = None) -> None: """Push commits to remote repository Args: @@ -294,7 +385,8 @@ def push_changes(self, remote: str = "origin", branch: Optional[str] = None) -> origin = self.repo.remote(remote) origin.push(branch) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to push changes: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to push changes: {str(e)}", + e.status, e.stderr) def get_current_branch(self) -> str: """Get the name of the current branch @@ -308,7 +400,8 @@ def get_current_branch(self) -> str: try: return self.repo.active_branch.name except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to get current branch: {str(e)}", e.status, e.stderr) + raise git.GitCommandError( + f"Failed to get current branch: {str(e)}", e.status, e.stderr) def checkout_branch(self, branch_name: str, create: bool = False) -> None: """Checkout a branch @@ -325,34 +418,8 @@ def checkout_branch(self, branch_name: str, create: bool = False) -> None: self.repo.create_head(branch_name) self.repo.git.checkout(branch_name) except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to checkout branch: {str(e)}", e.status, e.stderr) - - def get_commit_history(self, max_count: int = 10) -> List[Dict]: - """Get commit history - - Args: - max_count: Maximum number of commits to return - - Returns: - List[Dict]: List of commits with their details - - Raises: - git.GitCommandError: If there is an error getting the commit history - """ - try: - commits = [] - for commit in self.repo.iter_commits(max_count=max_count): - commits.append( - { - "hash": commit.hexsha, - "message": commit.message.strip(), - "author": str(commit.author), - "date": commit.committed_datetime.isoformat(), - } - ) - return commits - except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to get commit history: {str(e)}", e.status, e.stderr) + raise git.GitCommandError(f"Failed to checkout branch: {str(e)}", + e.status, e.stderr) def get_commit_changes(self, commit_hash: str) -> List[Change]: """Get changes from a specific commit @@ -368,24 +435,32 @@ def get_commit_changes(self, commit_hash: str) -> List[Change]: """ try: commit = self.repo.commit(commit_hash) - parent = commit.parents[0] if commit.parents else self.repo.tree("4b825dc642cb6eb9a060e54bf8d69288fbee4904") + parent = commit.parents[0] if commit.parents else self.repo.tree( + "4b825dc642cb6eb9a060e54bf8d69288fbee4904") changes = [] diff_index = parent.diff(commit) for diff in diff_index: - status = "error" + status = FileStatus.ERROR content = "" insertions = 0 deletions = 0 try: - status, content, insertions, deletions = self._process_file_diff(diff) + status, content, insertions, deletions = self._process_file_diff( + diff) except Exception as e: - status = "error" + status = FileStatus.ERROR content = f"[Unexpected error: {str(e)}]" - changes.append(Change(file=diff.b_path or diff.a_path, status=status, diff=content, insertions=insertions, deletions=deletions)) + changes.append( + Change(file=diff.b_path or diff.a_path, + status=status, + diff=content, + insertions=insertions, + deletions=deletions)) return changes except git.GitCommandError as e: - raise git.GitCommandError(f"Failed to get commit changes: {str(e)}", e.status, e.stderr) + raise git.GitCommandError( + f"Failed to get commit changes: {str(e)}", e.status, e.stderr) diff --git a/tests/test_cli.py b/tests/test_cli.py index 1bf914f..8b02188 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -130,12 +130,14 @@ def test_run_no_approved_groups(mock_confirm, assistant): with patch.multiple(assistant.cli, display_commit_groups=lambda x: []): assistant.run() + def test_create_new_commits(assistant, capsys): commit_groups = [{"files": [], "commit_message": "test commit", "description": "test description"}] assistant._create_new_commits(commit_groups) captured = capsys.readouterr() assert "No files to stage!" in captured.out + def test_commit_creation_failure(assistant, capsys): changes = [Change(file="test.py", status="modified", diff="test diff", insertions=1, deletions=0)] commit_groups = [{"files": ["test.py"], "commit_message": "test commit", "description": "test description"}] diff --git a/tests/test_git_operations.py b/tests/test_git_operations.py index 1eb7327..c194ed3 100644 --- a/tests/test_git_operations.py +++ b/tests/test_git_operations.py @@ -94,7 +94,7 @@ def test_stage_files(temp_git_repo): with pytest.raises(ValueError) as excinfo: git_ops.stage_files([]) assert str(excinfo.value) == "No files to stage!" - + finally: # Restore the original directory os.chdir(current_dir) @@ -137,29 +137,6 @@ def test_checkout_branch(temp_git_repo): assert git_ops.get_current_branch() == "test-branch" -def test_get_commit_history(temp_git_repo): - """Test getting commit history""" - git_ops = GitOperations(temp_git_repo) - - # Switch to the repository directory - current_dir = os.getcwd() - os.chdir(temp_git_repo) - - try: - # Create multiple commits - for i in range(3): - with open(f"history_test_{i}.txt", "w") as f: - f.write(f"test content {i}") - git_ops.stage_files([f"history_test_{i}.txt"]) - git_ops.commit_changes(f"Test commit {i}") - - history = git_ops.get_commit_history(max_count=4) # Modified to 4 to get all commits - assert len(history) == 4 # 3 new commits + 1 initial commit - finally: - # Restore the original directory - os.chdir(current_dir) - - def test_deleted_file_changes(temp_git_repo): """Test handling of deleted files""" git_ops = GitOperations(temp_git_repo) @@ -581,8 +558,8 @@ def test_file_reading(temp_git_repo): changes = git_ops.get_staged_changes() # Test that IOError is raised when trying to get diff new_file = next(c for c in changes if c.file == "new_file.txt") - assert new_file.status == "new file" - assert "Error reading file:" in new_file.diff + assert new_file.status == "error" + assert "No such file" in new_file.diff git_ops.commit_changes("Add file to modify")