Coverage for src/rhiza_tools/commands/_shared.py: 100%
57 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-30 13:37 +0000
"""Shared utilities for rhiza-tools commands.

This module provides common helpers used across multiple command modules
(bump, release, rollback) to avoid duplication and ensure consistency.

Utilities:
    - NON_INTERACTIVE_ERRORS: Exceptions signalling a non-interactive (no TTY) session
    - COOL_STYLE: Shared questionary styling for interactive prompts
    - run_git_command: Execute git commands with standard error handling
    - get_current_version: Read the project version from pyproject.toml
    - get_current_git_branch: Safely determine the current git branch
    - get_latest_remote_version: Highest semver tag published on the remote
    - parse_semver_or_exit: Parse a semantic version or exit with a consistent error
    - validate_pyproject_exists: Guard against missing pyproject.toml
"""
15import subprocess # nosec B404 - subprocess needed for git operations
16import tomllib
17from pathlib import Path
19import questionary as qs
20import semver
21import typer
23from rhiza_tools import console
# ``prompt_toolkit.output.win32`` only imports successfully on Windows, so each
# platform runs exactly one side of this try/except. The fallback branch can
# never execute on Windows and is therefore excluded from coverage.
# AssertionError is caught as well, presumably because the module asserts the
# platform at import time.
try:
    from prompt_toolkit.output.win32 import NoConsoleScreenBufferError as _WinConsoleError  # type: ignore[attr-defined]
except (ImportError, AssertionError):  # pragma: no cover

    class _WinConsoleError(Exception):  # type: ignore[no-redef]
        """Placeholder exception type; nothing raises it off Windows."""


# Exceptions that signal a non-interactive session (no TTY attached).
# Catch this tuple rather than a bare ``EOFError``: on Windows CI the prompt
# fails with ``NoConsoleScreenBufferError`` instead, and both cases must be
# handled the same way.
NON_INTERACTIVE_ERRORS: tuple[type[BaseException], ...] = (EOFError, _WinConsoleError)
# Shared questionary styling for interactive prompts, written as a
# token -> style mapping. Dicts preserve insertion order, so ``.items()``
# yields the (token, style) rules in exactly the order listed here.
COOL_STYLE = qs.Style(
    list(
        {
            "separator": "fg:#cc5454",
            "qmark": "fg:#2FA4A9 bold",
            "question": "",
            "selected": "fg:#2FA4A9 bold",
            "pointer": "fg:#2FA4A9 bold",
            "highlighted": "fg:#2FA4A9 bold",
            "answer": "fg:#2FA4A9 bold",
            "text": "fg:#ffffff",
            "disabled": "fg:#858585 italic",
        }.items()
    )
)
def run_git_command(command: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
    """Execute a git command, capturing its output as text.

    Args:
        command: Full argument vector, e.g. ``["git", "status"]``.
        check: When True, a non-zero exit status is logged and turned into an
            exception; when False the result is returned regardless.

    Returns:
        The ``CompletedProcess`` carrying ``stdout``, ``stderr`` and ``returncode``.

    Raises:
        subprocess.CalledProcessError: When ``check`` is True and the command
            exits with a non-zero status.

    Example:
        >>> result = run_git_command(["git", "status", "--porcelain"])  # doctest: +SKIP
        >>> print(result.stdout)  # doctest: +SKIP
    """
    completed = subprocess.run(command, capture_output=True, text=True, check=False)  # nosec B603 - git commands are trusted # noqa: S603

    # Success, or the caller asked us not to enforce the exit status.
    if not check or completed.returncode == 0:
        return completed

    console.error(f"Git command failed: {' '.join(command)}")
    console.error(f"Error: {completed.stderr}")
    raise subprocess.CalledProcessError(completed.returncode, command, completed.stdout, completed.stderr)
83def get_current_version() -> str:
84 """Read current version from pyproject.toml.
86 Returns:
87 The current version string from the project.version field.
89 Raises:
90 typer.Exit: If pyproject.toml cannot be read or parsed.
92 Example:
93 >>> version = get_current_version() # doctest: +SKIP
94 >>> print(version) # doctest: +SKIP
95 0.1.0
96 """
97 try:
98 with open("pyproject.toml", "rb") as f:
99 data = tomllib.load(f)
100 return str(data["project"]["version"])
101 except (OSError, tomllib.TOMLDecodeError, KeyError, TypeError) as e:
102 console.error(f"Failed to read version from pyproject.toml: {e}")
103 raise typer.Exit(code=1) from None
def get_current_git_branch() -> str:
    """Get the current git branch name.

    This is the safe variant that returns ``"unknown"`` on failure,
    suitable for display purposes. For strict validation use
    :func:`run_git_command` directly.

    Returns:
        Current branch name, or ``"unknown"`` if it cannot be determined
        (git exits non-zero, prints nothing, or cannot be launched at all).
    """
    try:
        result = run_git_command(["git", "rev-parse", "--abbrev-ref", "HEAD"], check=False)
    except OSError:
        # subprocess raises FileNotFoundError/PermissionError (OSError) when the
        # git executable cannot be started; honour the "safe" contract instead
        # of crashing.
        return "unknown"
    branch = result.stdout.strip()
    return branch if result.returncode == 0 and branch else "unknown"
def get_latest_remote_version(remote: str = "origin") -> semver.Version | None:
    """Return the highest semantic-version tag published on *remote*.

    This reads tags directly from the remote with ``git ls-remote --tags``
    (no local fetch, no working-tree changes) and returns the greatest valid
    semver among them. It is the source of truth for "what is the latest
    released version" and exists to stop the release tooling from trusting a
    potentially stale local ``pyproject.toml`` (see issue #1126).

    Both ``v``-prefixed (``v1.2.3``) and bare (``1.2.3``) tags are accepted.
    Tags that are not valid semantic versions are ignored. Pre-release tags
    (e.g. ``v1.3.0-rc.1``) are valid semver and do take part in the
    comparison. The function never raises for the common failure modes (no
    remote configured, no tags, network unavailable, git executable missing);
    it returns ``None`` so callers can degrade gracefully.

    Args:
        remote: The git remote to query. Defaults to ``"origin"``.

    Returns:
        The highest :class:`semver.Version` found on the remote, or ``None`` if
        the remote cannot be reached or has no valid version tags.

    Example:
        >>> latest = get_latest_remote_version()  # doctest: +SKIP
        >>> print(latest)  # doctest: +SKIP
        0.4.0
    """
    try:
        result = run_git_command(["git", "ls-remote", "--tags", remote], check=False)
    except OSError:
        # git could not be launched at all (not installed / not on PATH);
        # treat it like any other "remote unavailable" case.
        return None
    if result.returncode != 0 or not result.stdout.strip():
        return None

    versions: list[semver.Version] = []
    for line in result.stdout.splitlines():
        # Each line is "<sha>\trefs/tags/<tag>"; annotated tags also appear
        # peeled as "refs/tags/<tag>^{}", which normalises to the same tag.
        ref = line.split("\t")[-1].strip()
        if not ref.startswith("refs/tags/"):
            continue
        tag = ref.removeprefix("refs/tags/").removesuffix("^{}").removeprefix("v")
        try:
            versions.append(semver.Version.parse(tag))
        except ValueError:
            # Non-semver tags (e.g. "latest", date stamps) are not releases.
            continue

    return max(versions, default=None)
def parse_semver_or_exit(version_str: str, *, strip_v_prefix: bool = False) -> semver.Version:
    """Parse *version_str* as a semantic version, or terminate the command.

    Several commands share this parse-then-exit behaviour. Routing them all
    through this helper keeps the report for an unparseable version identical
    everywhere.

    Args:
        version_str: Text to parse, e.g. ``"1.2.3"``, or ``"v1.2.3"`` together
            with ``strip_v_prefix=True``.
        strip_v_prefix: When True, a single leading ``"v"`` is removed before
            parsing.

    Returns:
        The parsed :class:`semver.Version`.

    Raises:
        typer.Exit: If the text is not a valid semantic version.

    Example:
        >>> parse_semver_or_exit("1.2.3")  # doctest: +SKIP
        Version(major=1, minor=2, patch=3, ...)
    """
    text = version_str.removeprefix("v") if strip_v_prefix else version_str
    try:
        parsed = semver.Version.parse(text)
    except ValueError:
        # Report the caller's original input, not the prefix-stripped text.
        console.error(f"Invalid semantic version: {version_str}")
        raise typer.Exit(code=1) from None
    return parsed
def validate_pyproject_exists() -> None:
    """Validate that pyproject.toml exists in the current directory.

    Only a regular file (or a symlink to one) satisfies the check. A directory
    that happens to be named ``pyproject.toml`` is rejected, because it cannot
    be opened and parsed as TOML.

    Raises:
        typer.Exit: If pyproject.toml is not found as a file.
    """
    # ``is_file`` rather than ``exists``: ``exists`` also accepts directories.
    if not Path("pyproject.toml").is_file():
        console.error("pyproject.toml not found in current directory")
        raise typer.Exit(code=1)