Coverage for src / rhiza / models.py: 99%
119 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-12 20:13 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-12 20:13 +0000
1"""Data models for Rhiza configuration.
3This module defines dataclasses that represent the structure of Rhiza
4configuration files, making it easier to work with them without frequent
5YAML parsing.
6"""
8from dataclasses import dataclass, field
9from pathlib import Path
10from typing import Any
12import yaml # type: ignore[import-untyped]
14__all__ = [
15 "BundleDefinition",
16 "RhizaBundles",
17 "RhizaTemplate",
18]
21def _normalize_to_list(value: str | list[str] | None) -> list[str]:
22 r"""Convert a value to a list of strings.
24 Handles the case where YAML multi-line strings (using |) are parsed as
25 a single string instead of a list. Splits the string by newlines and
26 strips whitespace from each item.
28 Args:
29 value: A string, list of strings, or None.
31 Returns:
32 A list of strings. Empty list if value is None or empty.
34 Examples:
35 >>> _normalize_to_list(None)
36 []
37 >>> _normalize_to_list([])
38 []
39 >>> _normalize_to_list(['a', 'b', 'c'])
40 ['a', 'b', 'c']
41 >>> _normalize_to_list('single line')
42 ['single line']
43 >>> _normalize_to_list('line1\\n' + 'line2\\n' + 'line3')
44 ['line1', 'line2', 'line3']
45 >>> _normalize_to_list(' item1 \\n' + ' item2 ')
46 ['item1', 'item2']
47 """
48 if value is None:
49 return []
50 if isinstance(value, list):
51 return value
52 if isinstance(value, str):
53 # Split by newlines and filter out empty strings
54 # Handle both actual newlines (\n) and literal backslash-n (\\n)
55 if "\\n" in value and "\n" not in value:
56 # Contains literal \n but not actual newlines
57 items = value.split("\\n")
58 else:
59 # Contains actual newlines or neither
60 items = value.split("\n")
61 return [item.strip() for item in items if item.strip()]
62 return []
@dataclass
class BundleDefinition:
    """One named bundle as declared in template-bundles.yml.

    Attributes:
        name: Identifier of the bundle (e.g., "core", "tests", "github").
        description: Free-text, human-readable summary of the bundle.
        files: File paths that belong to the bundle.
        workflows: Workflow file paths that belong to the bundle.
        depends_on: Names of other bundles this bundle depends on.
    """

    name: str
    description: str
    # Each instance gets its own fresh list; the defaults are never shared.
    files: list[str] = field(default_factory=list)
    workflows: list[str] = field(default_factory=list)
    depends_on: list[str] = field(default_factory=list)

    def all_paths(self) -> list[str]:
        """Return every path of the bundle: files first, then workflows."""
        combined = self.files + self.workflows
        return combined
@dataclass
class RhizaBundles:
    """Represents the structure of template-bundles.yml.

    Attributes:
        version: Version string of the bundles configuration format.
        bundles: Dictionary mapping bundle names to their definitions.
    """

    version: str
    bundles: dict[str, BundleDefinition] = field(default_factory=dict)

    @classmethod
    def from_yaml(cls, file_path: Path) -> "RhizaBundles":
        """Load RhizaBundles from a YAML file.

        Args:
            file_path: Path to the template-bundles.yml file.

        Returns:
            The loaded bundles configuration.

        Raises:
            FileNotFoundError: If the file does not exist.
            yaml.YAMLError: If the YAML is malformed.
            ValueError: If the file is empty or missing required fields.
            TypeError: If the document, the ``bundles`` entry, or an
                individual bundle is not a mapping.
        """
        # Decode as UTF-8 explicitly instead of relying on the platform's
        # locale-dependent default encoding.
        with open(file_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        if not config:
            raise ValueError("Bundles file is empty")  # noqa: TRY003

        # A top-level list or scalar has no .get(); reject it with a clear
        # error rather than letting an AttributeError escape.
        if not isinstance(config, dict):
            msg = "Bundles file must contain a mapping at the top level"
            raise TypeError(msg)

        version = config.get("version")
        if not version:
            raise ValueError("Bundles file missing required field: version")  # noqa: TRY003

        bundles_config = config.get("bundles", {})
        if not isinstance(bundles_config, dict):
            msg = "Bundles must be a dictionary"
            raise TypeError(msg)

        bundles: dict[str, BundleDefinition] = {}
        for bundle_name, bundle_data in bundles_config.items():
            if not isinstance(bundle_data, dict):
                msg = f"Bundle '{bundle_name}' must be a dictionary"
                raise TypeError(msg)

            # Each list-valued field may be written as a YAML list or as a
            # multi-line string; _normalize_to_list accepts both.
            bundles[bundle_name] = BundleDefinition(
                name=bundle_name,
                description=bundle_data.get("description", ""),
                files=_normalize_to_list(bundle_data.get("files")),
                workflows=_normalize_to_list(bundle_data.get("workflows")),
                depends_on=_normalize_to_list(bundle_data.get("depends-on")),
            )

        return cls(version=version, bundles=bundles)

    def resolve_dependencies(self, bundle_names: list[str]) -> list[str]:
        """Resolve bundle dependencies using topological sort.

        Args:
            bundle_names: List of bundle names to resolve.

        Returns:
            Ordered list of bundle names with dependencies first, no duplicates.

        Raises:
            ValueError: If a bundle doesn't exist or circular dependency detected.
        """
        # Reject unknown requested names up front, before walking anything.
        for name in bundle_names:
            if name not in self.bundles:
                raise ValueError(f"Bundle '{name}' not found in template-bundles.yml")  # noqa: TRY003

        resolved: list[str] = []
        visiting: set[str] = set()  # bundles on the current depth-first path
        visited: set[str] = set()  # bundles already appended to `resolved`

        def visit(bundle_name: str) -> None:
            if bundle_name in visited:
                return
            # Meeting a bundle that is still on the current path means a cycle.
            if bundle_name in visiting:
                raise ValueError(f"Circular dependency detected involving '{bundle_name}'")  # noqa: TRY003

            visiting.add(bundle_name)
            bundle = self.bundles[bundle_name]

            for dep in bundle.depends_on:
                if dep not in self.bundles:
                    raise ValueError(f"Bundle '{bundle_name}' depends on unknown bundle '{dep}'")  # noqa: TRY003
                visit(dep)

            # Post-order append: every dependency is already in `resolved`.
            visiting.remove(bundle_name)
            visited.add(bundle_name)
            resolved.append(bundle_name)

        for name in bundle_names:
            visit(name)

        return resolved

    def resolve_to_paths(self, bundle_names: list[str]) -> list[str]:
        """Convert bundle names to deduplicated file paths.

        Args:
            bundle_names: List of bundle names to resolve.

        Returns:
            Deduplicated list of file paths from all bundles and their
            dependencies, in first-seen order (dependencies first).

        Raises:
            ValueError: If a bundle doesn't exist or circular dependency detected.
        """
        resolved_bundles = self.resolve_dependencies(bundle_names)
        # dict.fromkeys keeps the first occurrence of each path and preserves
        # insertion order, which is exactly an order-preserving dedupe.
        unique_paths = dict.fromkeys(
            path for bundle_name in resolved_bundles for path in self.bundles[bundle_name].all_paths()
        )
        return list(unique_paths)
221@dataclass
222class RhizaTemplate:
223 """Represents the structure of .rhiza/template.yml.
225 Attributes:
226 template_repository: The GitHub or GitLab repository containing templates (e.g., "jebel-quant/rhiza").
227 Can be None if not specified in the template file.
228 template_branch: The branch to use from the template repository.
229 Can be None if not specified in the template file (defaults to "main" when creating).
230 template_host: The git hosting platform ("github" or "gitlab").
231 Defaults to "github" if not specified in the template file.
232 include: List of paths to include from the template repository (path-based mode).
233 exclude: List of paths to exclude from the template repository (default: empty list).
234 templates: List of template names to include (template-based mode).
235 Can be used together with include to merge paths.
236 """
238 template_repository: str | None = None
239 template_branch: str | None = None
240 template_host: str = "github"
241 include: list[str] = field(default_factory=list)
242 exclude: list[str] = field(default_factory=list)
243 templates: list[str] = field(default_factory=list)
245 @classmethod
246 def from_yaml(cls, file_path: Path) -> "RhizaTemplate":
247 """Load RhizaTemplate from a YAML file.
249 Args:
250 file_path: Path to the template.yml file.
252 Returns:
253 The loaded template configuration.
255 Raises:
256 FileNotFoundError: If the file does not exist.
257 yaml.YAMLError: If the YAML is malformed.
258 ValueError: If the file is empty.
259 """
260 with open(file_path) as f:
261 config = yaml.safe_load(f)
263 if not config:
264 raise ValueError("Template file is empty") # noqa: TRY003
266 return cls(
267 template_repository=config.get("template-repository"),
268 template_branch=config.get("template-branch"),
269 template_host=config.get("template-host", "github"),
270 include=_normalize_to_list(config.get("include")),
271 exclude=_normalize_to_list(config.get("exclude")),
272 templates=_normalize_to_list(config.get("templates")),
273 )
275 def to_yaml(self, file_path: Path) -> None:
276 """Save RhizaTemplate to a YAML file.
278 Args:
279 file_path: Path where the template.yml file should be saved.
280 """
281 # Ensure parent directory exists
282 file_path.parent.mkdir(parents=True, exist_ok=True)
284 # Convert to dictionary with YAML-compatible keys
285 config: dict[str, Any] = {}
287 # Only include template-repository if it's not None
288 if self.template_repository:
289 config["template-repository"] = self.template_repository
291 # Only include template-branch if it's not None
292 if self.template_branch:
293 config["template-branch"] = self.template_branch
295 # Only include template-host if it's not the default "github"
296 if self.template_host and self.template_host != "github":
297 config["template-host"] = self.template_host
299 # Write templates if present
300 if self.templates:
301 config["templates"] = self.templates
303 # Write include if present (can coexist with templates)
304 if self.include:
305 config["include"] = self.include
307 # Only include exclude if it's not empty
308 if self.exclude:
309 config["exclude"] = self.exclude
311 with open(file_path, "w") as f:
312 yaml.dump(config, f, default_flow_style=False, sort_keys=False)