Coverage for src/rhiza_tools/commands/_shared.py: 100%

57 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-30 13:37 +0000

1"""Shared utilities for rhiza-tools commands. 

2 

3This module provides common helpers used across multiple command modules 

4(bump, release, rollback) to avoid duplication and ensure consistency. 

5 

6Utilities: 

7 - COOL_STYLE: Shared questionary styling for interactive prompts 

8 - run_git_command: Execute git commands with standard error handling 

9 - get_current_version: Read the project version from pyproject.toml 

10 - get_current_git_branch: Safely determine the current git branch 

11 - get_latest_remote_version: Highest semver tag published on the remote 

12 - validate_pyproject_exists: Guard against missing pyproject.toml 

13""" 

14 

15import subprocess # nosec B404 - subprocess needed for git operations 

16import tomllib 

17from pathlib import Path 

18 

19import questionary as qs 

20import semver 

21import typer 

22 

23from rhiza_tools import console 

24 

# The win32 import only succeeds on Windows, so exactly one platform exercises
# each side of this branch; the fallback can never be covered on Windows.
try:
    from prompt_toolkit.output.win32 import (  # type: ignore[attr-defined]
        NoConsoleScreenBufferError as _WinConsoleError,
    )
except (ImportError, AssertionError):  # pragma: no cover
    # Off Windows the win32 output module cannot be imported. Both ImportError
    # and AssertionError are treated as "not on Windows".
    # NOTE(review): AssertionError presumably comes from a platform assertion
    # inside prompt_toolkit's win32 module — confirm against the installed version.

    class _WinConsoleError(Exception):  # type: ignore[no-redef]
        """Sentinel: never raised outside of Windows environments."""


# Tuple of exceptions indicating a non-interactive environment (no TTY).
# Use this in except clauses instead of bare ``EOFError`` so that Windows CI
# (which raises ``NoConsoleScreenBufferError`` instead of ``EOFError``) is
# handled consistently.
NON_INTERACTIVE_ERRORS: tuple[type[BaseException], ...] = (EOFError, _WinConsoleError)

# Shared questionary style for interactive prompts, so every command's prompts
# look the same. Each entry maps a questionary style class to a
# prompt_toolkit style string.
COOL_STYLE = qs.Style(
    [
        ("separator", "fg:#cc5454"),
        ("qmark", "fg:#2FA4A9 bold"),
        ("question", ""),
        ("selected", "fg:#2FA4A9 bold"),
        ("pointer", "fg:#2FA4A9 bold"),
        ("highlighted", "fg:#2FA4A9 bold"),
        ("answer", "fg:#2FA4A9 bold"),
        ("text", "fg:#ffffff"),
        ("disabled", "fg:#858585 italic"),
    ]
)

56 

57 

def run_git_command(command: list[str], check: bool = True) -> subprocess.CompletedProcess[str]:
    """Execute a git command and return the finished process with text output.

    The subprocess itself always runs with ``check=False``. Exit-status
    handling is done here so that a failing command is reported through
    ``console.error`` before the exception is raised.

    Args:
        command: Full argument vector to execute, e.g. ``["git", "status"]``.
        check: When True, a non-zero exit status is reported and raised.

    Returns:
        The completed process, exposing ``stdout``, ``stderr`` and ``returncode``.

    Raises:
        subprocess.CalledProcessError: If ``check`` is True and the command
            exits with a non-zero status.

    Example:
        >>> result = run_git_command(["git", "status", "--porcelain"])  # doctest: +SKIP
        >>> print(result.stdout)  # doctest: +SKIP
    """
    proc = subprocess.run(command, capture_output=True, text=True, check=False)  # nosec B603 - git commands are trusted # noqa: S603

    # Success, or the caller explicitly opted out of checking: hand it back.
    if not check or proc.returncode == 0:
        return proc

    # Tell the user what failed before the exception propagates.
    console.error(f"Git command failed: {' '.join(command)}")
    console.error(f"Error: {proc.stderr}")
    raise subprocess.CalledProcessError(
        returncode=proc.returncode,
        cmd=command,
        output=proc.stdout,
        stderr=proc.stderr,
    )

81 

82 

83def get_current_version() -> str: 

84 """Read current version from pyproject.toml. 

85 

86 Returns: 

87 The current version string from the project.version field. 

88 

89 Raises: 

90 typer.Exit: If pyproject.toml cannot be read or parsed. 

91 

92 Example: 

93 >>> version = get_current_version() # doctest: +SKIP 

94 >>> print(version) # doctest: +SKIP 

95 0.1.0 

96 """ 

97 try: 

98 with open("pyproject.toml", "rb") as f: 

99 data = tomllib.load(f) 

100 return str(data["project"]["version"]) 

101 except (OSError, tomllib.TOMLDecodeError, KeyError, TypeError) as e: 

102 console.error(f"Failed to read version from pyproject.toml: {e}") 

103 raise typer.Exit(code=1) from None 

104 

105 

def get_current_git_branch() -> str:
    """Get the current git branch name.

    This is the safe variant: it does not raise, and returns ``"unknown"``
    when the branch cannot be determined, which makes it suitable for display
    purposes. For strict validation use :func:`run_git_command` directly.

    Returns:
        The output of ``git rev-parse --abbrev-ref HEAD`` (whitespace-stripped),
        or ``"unknown"`` if git cannot be executed, the command fails, or it
        prints nothing.
    """
    try:
        result = run_git_command(["git", "rev-parse", "--abbrev-ref", "HEAD"], check=False)
    except OSError:
        # A missing or non-executable ``git`` makes subprocess raise instead of
        # returning a non-zero exit code; map it to the documented fallback.
        return "unknown"

    branch = result.stdout.strip()
    return branch if result.returncode == 0 and branch else "unknown"

118 

119 

def get_latest_remote_version(remote: str = "origin") -> semver.Version | None:
    """Return the highest semantic-version tag published on *remote*.

    This lists tags directly from the remote with ``git ls-remote --tags``
    (no local fetch, no working-tree changes) and returns the greatest valid
    semver among them. A single leading ``v`` is stripped from each tag
    before parsing, so both ``v1.2.3`` and ``1.2.3`` are recognised. It is the
    source of truth for "what is the latest released version" and exists to
    stop the release tooling from trusting a potentially stale local
    ``pyproject.toml`` (see issue #1126).

    Tags that are not valid semantic versions are ignored. The function does
    not raise for the common failure modes (git not installed, no remote
    configured, no tags, network unavailable); it returns ``None`` so callers
    can degrade gracefully.

    Args:
        remote: The git remote to query. Defaults to ``"origin"``.

    Returns:
        The highest :class:`semver.Version` found on the remote, or ``None`` if
        the remote cannot be reached or has no valid version tags.

    Example:
        >>> latest = get_latest_remote_version()  # doctest: +SKIP
        >>> print(latest)  # doctest: +SKIP
        0.4.0
    """
    # NOTE(review): for remotes that need credentials, git may prompt
    # interactively here; consider GIT_TERMINAL_PROMPT=0 if this runs unattended.
    try:
        result = run_git_command(["git", "ls-remote", "--tags", remote], check=False)
    except OSError:
        # ``git`` missing or not executable: subprocess raises rather than
        # returning a non-zero status. Honour the "does not raise" contract.
        return None
    if result.returncode != 0 or not result.stdout.strip():
        return None

    versions: list[semver.Version] = []
    for line in result.stdout.splitlines():
        # Each line is "<sha>\trefs/tags/<tag>"; annotated tags also appear
        # peeled as "refs/tags/<tag>^{}" which we normalise away.
        ref = line.split("\t")[-1].strip()
        if not ref.startswith("refs/tags/"):
            continue
        tag = ref.removeprefix("refs/tags/").removesuffix("^{}").removeprefix("v")
        try:
            versions.append(semver.Version.parse(tag))
        except ValueError:
            # Non-semver tags (e.g. "latest", date stamps) are not releases.
            continue

    return max(versions) if versions else None

166 

167 

def parse_semver_or_exit(version_str: str, *, strip_v_prefix: bool = False) -> semver.Version:
    """Convert *version_str* to a :class:`semver.Version` or abort the command.

    This is the single place where commands turn user-supplied version text
    into a semver object, so every invalid version is reported the same way.

    Args:
        version_str: Version text such as ``"1.2.3"`` or ``"v1.2.3"``.
        strip_v_prefix: When True, a single leading ``"v"`` is removed before
            parsing.

    Returns:
        The parsed :class:`semver.Version`.

    Raises:
        typer.Exit: If the (optionally de-prefixed) text is not valid semver.

    Example:
        >>> parse_semver_or_exit("1.2.3")  # doctest: +SKIP
        Version(major=1, minor=2, patch=3, ...)
    """
    text = version_str.removeprefix("v") if strip_v_prefix else version_str
    try:
        parsed = semver.Version.parse(text)
    except ValueError:
        # Report the caller's original input, not the de-prefixed form.
        console.error(f"Invalid semantic version: {version_str}")
        raise typer.Exit(code=1) from None
    return parsed

194 

195 

def validate_pyproject_exists() -> None:
    """Validate that pyproject.toml exists in the current directory.

    The check requires a regular file (following symlinks). A directory that
    happens to be named ``pyproject.toml`` is rejected, because later steps
    need to open and read it.

    Raises:
        typer.Exit: If pyproject.toml is not found as a file.
    """
    if not Path("pyproject.toml").is_file():
        console.error("pyproject.toml not found in current directory")
        raise typer.Exit(code=1)