Coverage for src / rhiza / models.py: 99%

119 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-12 20:13 +0000

1"""Data models for Rhiza configuration. 

2 

3This module defines dataclasses that represent the structure of Rhiza 

4configuration files, making it easier to work with them without frequent 

5YAML parsing. 

6""" 

7 

8from dataclasses import dataclass, field 

9from pathlib import Path 

10from typing import Any 

11 

12import yaml # type: ignore[import-untyped] 

13 

# Names exported by ``from <this module> import *``; the private helper
# ``_normalize_to_list`` is intentionally not listed.
__all__ = [
    "BundleDefinition",
    "RhizaBundles",
    "RhizaTemplate",
]

19 

20 

21def _normalize_to_list(value: str | list[str] | None) -> list[str]: 

22 r"""Convert a value to a list of strings. 

23 

24 Handles the case where YAML multi-line strings (using |) are parsed as 

25 a single string instead of a list. Splits the string by newlines and 

26 strips whitespace from each item. 

27 

28 Args: 

29 value: A string, list of strings, or None. 

30 

31 Returns: 

32 A list of strings. Empty list if value is None or empty. 

33 

34 Examples: 

35 >>> _normalize_to_list(None) 

36 [] 

37 >>> _normalize_to_list([]) 

38 [] 

39 >>> _normalize_to_list(['a', 'b', 'c']) 

40 ['a', 'b', 'c'] 

41 >>> _normalize_to_list('single line') 

42 ['single line'] 

43 >>> _normalize_to_list('line1\\n' + 'line2\\n' + 'line3') 

44 ['line1', 'line2', 'line3'] 

45 >>> _normalize_to_list(' item1 \\n' + ' item2 ') 

46 ['item1', 'item2'] 

47 """ 

48 if value is None: 

49 return [] 

50 if isinstance(value, list): 

51 return value 

52 if isinstance(value, str): 

53 # Split by newlines and filter out empty strings 

54 # Handle both actual newlines (\n) and literal backslash-n (\\n) 

55 if "\\n" in value and "\n" not in value: 

56 # Contains literal \n but not actual newlines 

57 items = value.split("\\n") 

58 else: 

59 # Contains actual newlines or neither 

60 items = value.split("\n") 

61 return [item.strip() for item in items if item.strip()] 

62 return [] 

63 

64 

@dataclass
class BundleDefinition:
    """One named bundle entry as declared in template-bundles.yml.

    Attributes:
        name: The bundle identifier (e.g., "core", "tests", "github").
        description: Human-readable description of the bundle.
        files: List of file paths included in this bundle.
        workflows: List of workflow file paths included in this bundle.
        depends_on: List of bundle names that this bundle depends on.
    """

    name: str
    description: str
    files: list[str] = field(default_factory=list)
    workflows: list[str] = field(default_factory=list)
    depends_on: list[str] = field(default_factory=list)

    def all_paths(self) -> list[str]:
        """Return a new list holding the files followed by the workflows."""
        return [*self.files, *self.workflows]

86 

87 

@dataclass
class RhizaBundles:
    """Represents the structure of template-bundles.yml.

    Attributes:
        version: Version string of the bundles configuration format.
        bundles: Dictionary mapping bundle names to their definitions.
    """

    version: str
    bundles: dict[str, BundleDefinition] = field(default_factory=dict)

    @classmethod
    def from_yaml(cls, file_path: Path) -> "RhizaBundles":
        """Load RhizaBundles from a YAML file.

        Args:
            file_path: Path to the template-bundles.yml file.

        Returns:
            The loaded bundles configuration.

        Raises:
            FileNotFoundError: If the file does not exist.
            yaml.YAMLError: If the YAML is malformed.
            ValueError: If the file is empty or missing required fields.
            TypeError: If the top-level document, the ``bundles`` section, or
                an individual bundle is not a mapping.
        """
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale encoding.
        with open(file_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        if not config:
            raise ValueError("Bundles file is empty")  # noqa: TRY003

        # A top-level scalar or list would otherwise fail below with an
        # opaque AttributeError on ``.get``.
        if not isinstance(config, dict):
            msg = "Bundles file must contain a mapping at the top level"
            raise TypeError(msg)

        version = config.get("version")
        if not version:
            raise ValueError("Bundles file missing required field: version")  # noqa: TRY003

        bundles_config = config.get("bundles", {})
        if not isinstance(bundles_config, dict):
            msg = "Bundles must be a dictionary"
            raise TypeError(msg)

        bundles: dict[str, BundleDefinition] = {}
        for bundle_name, bundle_data in bundles_config.items():
            if not isinstance(bundle_data, dict):
                msg = f"Bundle '{bundle_name}' must be a dictionary"
                raise TypeError(msg)

            bundles[bundle_name] = BundleDefinition(
                name=bundle_name,
                # ``description: null`` in the YAML yields None; fall back to
                # an empty string so the field stays a str.
                description=bundle_data.get("description") or "",
                files=_normalize_to_list(bundle_data.get("files")),
                workflows=_normalize_to_list(bundle_data.get("workflows")),
                depends_on=_normalize_to_list(bundle_data.get("depends-on")),
            )

        return cls(version=version, bundles=bundles)

    def resolve_dependencies(self, bundle_names: list[str]) -> list[str]:
        """Resolve bundle dependencies using topological sort.

        Args:
            bundle_names: List of bundle names to resolve.

        Returns:
            Ordered list of bundle names with dependencies first, no duplicates.

        Raises:
            ValueError: If a bundle doesn't exist or circular dependency detected.
        """
        # Validate all requested bundles up front so an unknown name is
        # reported before any traversal starts.
        for name in bundle_names:
            if name not in self.bundles:
                raise ValueError(f"Bundle '{name}' not found in template-bundles.yml")  # noqa: TRY003

        resolved: list[str] = []
        visiting: set[str] = set()  # bundles on the current DFS path
        visited: set[str] = set()  # bundles already appended to ``resolved``

        def visit(bundle_name: str) -> None:
            if bundle_name in visited:
                return
            # Re-entering a bundle that is still on the current path means a cycle.
            if bundle_name in visiting:
                raise ValueError(f"Circular dependency detected involving '{bundle_name}'")  # noqa: TRY003

            visiting.add(bundle_name)
            bundle = self.bundles[bundle_name]

            for dep in bundle.depends_on:
                if dep not in self.bundles:
                    raise ValueError(f"Bundle '{bundle_name}' depends on unknown bundle '{dep}'")  # noqa: TRY003
                visit(dep)

            # Post-order append: every dependency is already in ``resolved``.
            visiting.remove(bundle_name)
            visited.add(bundle_name)
            resolved.append(bundle_name)

        for name in bundle_names:
            visit(name)

        return resolved

    def resolve_to_paths(self, bundle_names: list[str]) -> list[str]:
        """Convert bundle names to deduplicated file paths.

        Args:
            bundle_names: List of bundle names to resolve.

        Returns:
            Deduplicated list of file paths from all bundles and their
            dependencies, in first-seen order.

        Raises:
            ValueError: If a bundle doesn't exist or circular dependency detected.
        """
        resolved_bundles = self.resolve_dependencies(bundle_names)
        paths: list[str] = []
        seen: set[str] = set()

        for bundle_name in resolved_bundles:
            bundle = self.bundles[bundle_name]
            for path in bundle.all_paths():
                if path not in seen:
                    paths.append(path)
                    seen.add(path)

        return paths

219 

220 

@dataclass
class RhizaTemplate:
    """Represents the structure of .rhiza/template.yml.

    Attributes:
        template_repository: The GitHub or GitLab repository containing templates (e.g., "jebel-quant/rhiza").
            Can be None if not specified in the template file.
        template_branch: The branch to use from the template repository.
            Can be None if not specified in the template file (defaults to "main" when creating).
        template_host: The git hosting platform ("github" or "gitlab").
            Defaults to "github" if not specified in the template file.
        include: List of paths to include from the template repository (path-based mode).
        exclude: List of paths to exclude from the template repository (default: empty list).
        templates: List of template names to include (template-based mode).
            Can be used together with include to merge paths.
    """

    template_repository: str | None = None
    template_branch: str | None = None
    template_host: str = "github"
    include: list[str] = field(default_factory=list)
    exclude: list[str] = field(default_factory=list)
    templates: list[str] = field(default_factory=list)

    @classmethod
    def from_yaml(cls, file_path: Path) -> "RhizaTemplate":
        """Load RhizaTemplate from a YAML file.

        Args:
            file_path: Path to the template.yml file.

        Returns:
            The loaded template configuration.

        Raises:
            FileNotFoundError: If the file does not exist.
            yaml.YAMLError: If the YAML is malformed.
            ValueError: If the file is empty.
            TypeError: If the top-level YAML document is not a mapping.
        """
        # Decode explicitly as UTF-8 so parsing does not depend on the
        # platform's locale encoding.
        with open(file_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)

        if not config:
            raise ValueError("Template file is empty")  # noqa: TRY003

        # A top-level scalar or list would otherwise fail below with an
        # opaque AttributeError on ``.get``.
        if not isinstance(config, dict):
            msg = "Template file must contain a mapping at the top level"
            raise TypeError(msg)

        return cls(
            template_repository=config.get("template-repository"),
            template_branch=config.get("template-branch"),
            # ``or`` also covers an explicit ``template-host: null`` entry,
            # which ``.get(key, default)`` would pass through as None.
            template_host=config.get("template-host") or "github",
            include=_normalize_to_list(config.get("include")),
            exclude=_normalize_to_list(config.get("exclude")),
            templates=_normalize_to_list(config.get("templates")),
        )

    def to_yaml(self, file_path: Path) -> None:
        """Save RhizaTemplate to a YAML file.

        Unset or empty fields, and a ``template_host`` equal to the default
        "github", are omitted from the output. Missing parent directories
        are created.

        Args:
            file_path: Path where the template.yml file should be saved.
        """
        # Ensure parent directory exists
        file_path.parent.mkdir(parents=True, exist_ok=True)

        # Convert to dictionary with YAML-compatible (hyphenated) keys.
        # Insertion order is the order written, since sort_keys=False below.
        config: dict[str, Any] = {}

        # Only include template-repository if it is set (non-empty)
        if self.template_repository:
            config["template-repository"] = self.template_repository

        # Only include template-branch if it is set (non-empty)
        if self.template_branch:
            config["template-branch"] = self.template_branch

        # Only include template-host if it's not the default "github"
        if self.template_host and self.template_host != "github":
            config["template-host"] = self.template_host

        # Write templates if present
        if self.templates:
            config["templates"] = self.templates

        # Write include if present (can coexist with templates)
        if self.include:
            config["include"] = self.include

        # Only include exclude if it's not empty
        if self.exclude:
            config["exclude"] = self.exclude

        # Write as UTF-8 so the file round-trips with from_yaml on every platform.
        with open(file_path, "w", encoding="utf-8") as f:
            yaml.dump(config, f, default_flow_style=False, sort_keys=False)