]> git.scottworley.com Git - pinch/blame - pinch.py
Appease new stricter minidom typechecks
[pinch] / pinch.py
CommitLineData
05a3d87b
SW
1# pinch: PIN CHannels - a replacement for `nix-channel --update`
2#
3# This program is free software: you can redistribute it and/or modify it
4# under the terms of the GNU General Public License as published by the
5# Free Software Foundation, version 3.
6
7
0e5e611d 8import argparse
f15e458d 9import configparser
2f96f32a
SW
10import filecmp
11import functools
736c25eb 12import getpass
2f96f32a
SW
13import hashlib
14import operator
15import os
16import os.path
9a78329e 17import shlex
2f96f32a 18import shutil
73bec7e8 19import subprocess
9d7844bb 20import sys
0afcdb2a 21import tarfile
2f96f32a 22import tempfile
89e79125 23import types
2f96f32a
SW
24import urllib.parse
25import urllib.request
26import xml.dom.minidom
27
28from typing import (
9d2c406b 29 Callable,
2f96f32a
SW
30 Dict,
31 Iterable,
32 List,
7f4c3ace 33 Mapping,
0a4ff8dd 34 NamedTuple,
73bec7e8 35 NewType,
d7cfdb22 36 Optional,
567a6783 37 Set,
2f96f32a 38 Tuple,
7f4c3ace 39 Type,
567a6783 40 TypeVar,
0a4ff8dd 41 Union,
2f96f32a
SW
42)
43
d06918bc
SW
44import git_cache
45
3603dde2
SW
46# Use xdg module when it's less painful to have as a dependency
47
48
class XDG(NamedTuple):
    """The XDG base-directory settings that pinch consults."""
    XDG_CACHE_HOME: str


# $XDG_CACHE_HOME if it is set (even to an empty string), else ~/.cache.
xdg = XDG(XDG_CACHE_HOME=os.getenv('XDG_CACHE_HOME',
                                   os.path.expanduser('~/.cache')))
26125a28
SW
57
58
2f96f32a
SW
class VerificationError(Exception):
    """Raised by Verification.result() when a step reports failure."""


class Verification:
    """Step-by-step progress reporter that writes to stderr.

    A step is one or more ``status()`` calls followed by exactly one
    ``result()`` (or ``ok()``); ``check()`` does both. A failed result prints
    FAIL and raises VerificationError.
    """

    def __init__(self) -> None:
        # Characters written to the current stderr line so far; used to
        # right-align the OK/FAIL marker.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Write *s* followed by a space to the current stderr line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # Counts code points, not terminal columns, so wide characters
        # can misalign the marker.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for color code *c*."""
        return f'\033[{c:2d}m{s}\033[00m'

    def result(self, r: bool) -> None:
        """Finish the line with a right-aligned, colored OK or FAIL.

        Raises:
            VerificationError: if *r* is false (after printing FAIL).
        """
        message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
        width = shutil.get_terminal_size().columns or 80
        # Modulo keeps the marker at the right edge of the last terminal row
        # even when the status text has already wrapped.
        padding = (width - (self.line_length + len(message))) % width
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Shorthand for ``status(s)`` followed by ``result(r)``."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current step."""
        self.result(True)
92
93
faff8642
SW
# A SHA-256 digest in hex (base16) and in the base32 form produced by Nix.
# The two are separate types so that they cannot be mixed up by accident.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page.

    parse_channel sets ``url``, ``digest`` and ``size``. fetch_resources
    later adds ``absolute_url`` (``url`` resolved against the forwarded
    channel URL) and ``file`` (the downloaded store path).
    """
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str
104
105
0a4ff8dd
SW
class AliasPin(NamedTuple):
    """Pin for an alias channel; an alias stores no pin data of its own."""
108
109
0afcdb2a
SW
class SymlinkPin(NamedTuple):
    """Pin for a symlink channel; it has no stored fields."""

    @property
    def release_name(self) -> str:
        """Constant release name, the same as the 'link' entry symlink_archive writes."""
        return 'link'
114
115
0a4ff8dd
SW
class GitPin(NamedTuple):
    """A pinned git revision and the release name derived for it."""
    git_revision: str
    release_name: str
119
0a4ff8dd
SW
120
class ChannelPin(NamedTuple):
    """A pinned channel release.

    Besides the git revision and release name, it records where the
    release's nixexprs tarball lives and its hex SHA-256 digest.
    """
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str
126
127
0afcdb2a
SW
# Any of the pin kinds stored in a config section.
Pin = Union[AliasPin,
            SymlinkPin,
            GitPin,
            ChannelPin]
129
130
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add *filename* to the Nix store with ``nix-store --add``.

    A non-zero exit status is reported through *v*, which raises
    VerificationError. Returns the path that nix-store prints (the store
    path), with surrounding whitespace stripped.
    """
    v.status('Putting tarball in Nix store')
    added = subprocess.run(['nix-store', '--add', filename],
                           stdout=subprocess.PIPE)
    v.result(added.returncode == 0)
    store_path = added.stdout.decode().strip()
    return store_path  # type: ignore # (for old mypy)
0a4ff8dd
SW
137
138
96063a51
SW
def symlink_archive(v: Verification, path: str) -> str:
    """Store a one-entry tarball holding a symlink to *path*.

    The archive is a gzipped tar whose only member is a symlink named
    'link' pointing at *path*. It is added to the Nix store, and the store
    path is returned.
    """
    with tempfile.TemporaryDirectory() as workdir:
        tarball = os.path.join(workdir, 'link.tar.gz')
        link = os.path.join(workdir, 'link')
        os.symlink(path, link)
        with tarfile.open(tarball, mode='x:gz') as archive:
            archive.add(link, arcname='link')
        return copy_to_nix_store(v, tarball)
146
147
class AliasSearchPath(NamedTuple):
    """A channel that is another name for the channel called *alias_of*."""
    alias_of: str

    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin, so always return an empty AliasPin."""
        return AliasPin()
2fa9cbea
SW
153
154
0afcdb2a
SW
class SymlinkSearchPath(NamedTuple):
    """A channel that is a symlink to the local path *path*."""
    path: str

    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Nothing is recorded for a symlink channel; return an empty pin."""
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        """Return the store path of a tarball containing a symlink to *path*."""
        return symlink_archive(v, self.path)
0afcdb2a
SW
163
164
class GitSearchPath(NamedTuple):
    """A channel that tracks *git_ref* in the git repository *git_repo*."""
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch *git_ref* and pin the revision it currently points to.

        If a previous pin exists, the new revision must have the old one as
        an ancestor (checked by verify_git_ancestry).
        """
        _, revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        name = git_revision_name(v, self, revision)
        return GitPin(release_name=name, git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Make sure *pin*'s revision is cached and return its tarball's store path."""
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, pin.git_revision)
        return git_get_tarball(v, self, pin)
b3ea39b8 182
b3ea39b8 183
567a6783
SW
class ChannelSearchPath(NamedTuple):
    """A Nix channel at *channel_url*, cross-checked against *git_repo*."""
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Pin the channel's current release after verifying it.

        If the channel still points at *old_pin*'s git revision, *old_pin*
        is returned unchanged. Otherwise this function:

        - downloads and checks the channel's resources;
        - makes the revision available in the git cache;
        - verifies that it descends from the old pin, if there is one;
        - compares the tarball's contents with git.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        channel_html, forwarded_url = fetch_channel(v, self)
        table, new_gitpin = parse_channel(v, channel_html)
        revision = new_gitpin.git_revision
        if old_pin is not None and old_pin.git_revision == revision:
            return old_pin  # The channel has not moved.

        fetch_resources(v, new_gitpin, forwarded_url, table)
        git_cache.ensure_rev_available(self.git_repo, self.git_ref, revision)
        if old_pin is not None:
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        check_channel_contents(v, self, table, new_gitpin)

        nixexprs = table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=new_gitpin.release_name,
            tarball_url=nixexprs.absolute_url,
            tarball_sha256=nixexprs.digest,
            git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Fetch the pinned tarball via nix-prefetch-url and return its store path."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
215
216
0afcdb2a
SW
# Every kind of channel a config section can describe.
SearchPath = Union[AliasSearchPath, SymlinkSearchPath,
                   GitSearchPath, ChannelSearchPath]
# The channel kinds that have a git_repo and can be turned into tarballs.
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
b3ea39b8 222
4a82be40 223
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file by file.

    The paths compared are every file, plus every symlink to a directory,
    found under either *a* or *b*. Anything under a top-level ``.git/``
    directory is skipped. Paths are relative to the tree roots.

    Returns:
        ``(match, mismatch, errors)`` from ``filecmp.cmpfiles`` with full
        content comparison: paths whose contents are equal, paths whose
        contents differ, and paths that could not be compared (e.g. present
        on only one side). Each list is in sorted order.

    Raises:
        OSError: if walking either tree fails.
    """

    def throw(error: OSError) -> None:
        raise error

    def join(x: str, y: str) -> str:
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            # os.walk does not descend into symlinked directories, but it
            # still lists them in dirs. Record the links themselves.
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = functools.reduce(
        operator.or_, (set(
            exclude_dot_git(
                recursive_files(x))) for x in [a, b]))
    # cmpfiles preserves input order. Sort so the result lists (shown in
    # user-facing messages) are stable instead of following hash-randomized
    # set iteration order.
    return filecmp.cmpfiles(a, b, sorted(files), shallow=False)
250
251
567a6783
SW
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the HTML page at the channel's URL.

    The response must have status 200, and the final URL must differ from
    the configured channel URL (the channel has to redirect). Both checks
    go through *v*.

    Returns:
        (html, final_url), where *final_url* is the URL after redirects.
    """
    v.status(f'Fetching channel from {channel.channel_url}')
    with urllib.request.urlopen(channel.channel_url, timeout=10) as response:
        body = response.read().decode()
        final_url = response.geturl()
        v.result(response.status == 200)
    v.check('Got forwarded', channel.channel_url != final_url)
    return body, final_url
2f96f32a
SW
261
262
567a6783
SW
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Parse a channel's HTML page and cross-check what it claims.

    The release name must match between <title> and <h1>, and the first
    <tt> must hold the git commit and be labelled 'Git commit '.

    Returns:
        (table, pin). *table* maps each file name in the page's table to a
        ChannelTableEntry with ``url``, ``digest`` and ``size`` set. *pin*
        is a GitPin with the release name and git revision.

    Raises:
        VerificationError: if a cross-check fails.
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel_html)
    v.ok()

    # The release name is the third whitespace-separated word of both the
    # <title> text and the <h1> text, and the two must agree.
    v.status('Finding release name (1)')
    title = d.getElementsByTagName('title')[0].firstChild
    v.result(isinstance(title, xml.dom.minidom.CharacterData))
    assert isinstance(title, xml.dom.minidom.CharacterData)
    release_name = title.nodeValue.split()[2]
    v.status('Finding release name (2)')
    h1 = d.getElementsByTagName('h1')[0].firstChild
    v.result(isinstance(h1, xml.dom.minidom.CharacterData))
    assert isinstance(h1, xml.dom.minidom.CharacterData)
    v.status('Verifying release name:')
    v.status(release_name)
    v.result(release_name == h1.nodeValue.split()[2])

    v.status('Finding git commit')
    git_commit_node = d.getElementsByTagName('tt')[0]
    v.result(
        isinstance(
            git_commit_node.firstChild,
            xml.dom.minidom.CharacterData))
    assert isinstance(
        git_commit_node.firstChild,
        xml.dom.minidom.CharacterData)
    v.status('Extracting git commit:')
    git_revision = git_commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    # The node just before the <tt> must be the text 'Git commit '. Check for
    # a missing sibling explicitly, so a malformed page fails verification
    # instead of crashing with AttributeError.
    label = git_commit_node.previousSibling
    v.result(label is not None and label.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    # The first <tr> is skipped (presumably a header row). In each remaining
    # row, cell 0 holds a link element (href = url, text = file name),
    # cell 1's text is the size, and cell 2 wraps the digest in one more
    # element.
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=release_name, git_revision=git_revision)
2f96f32a
SW
308
309
dc038df0
SW
def digest_string(s: bytes) -> Digest16:
    """Return the hex SHA-256 digest of *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
312
313
73bec7e8
SW
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of the file *filename*.

    The file is read in 4096-byte chunks, so it never has to fit in memory
    all at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while block := f.read(4096):
            hasher.update(block)
    return Digest16(hasher.hexdigest())
321
322
b1e8c1b0
SW
@functools.lru_cache
def _experimental_flag_needed(v: Verification) -> bool:
    """Return True if ``nix --help`` mentions ``--experimental-features``.

    The result is cached per Verification object, so the probe runs at
    most once per object.
    """
    v.status('Checking Nix version')
    help_run = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
    v.result(help_run.returncode == 0)
    return b'--experimental-features' in help_run.stdout
1d48f551
SW
329
330
def _nix_command(v: Verification) -> List[str]:
    """Return the argv prefix for running ``nix`` subcommands.

    The nix-command feature is enabled explicitly when this nix appears to
    need it.
    """
    if _experimental_flag_needed(v):
        return ['nix', '--experimental-features', 'nix-command']
    return ['nix']
1d48f551
SW
334
335
73bec7e8
SW
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 SHA-256 digest to hex with ``nix to-base16``."""
    v.status('Converting digest to base16')
    argv = _nix_command(v) + ['to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest16(converted.stdout.decode().strip())
346
347
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to base32 with ``nix to-base32``."""
    v.status('Converting digest to base32')
    argv = _nix_command(v) + ['to-base32', '--type', 'sha256', digest16]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest32(converted.stdout.decode().strip())
358
359
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* with nix-prefetch-url and verify it against *digest*.

    *digest* (hex SHA-256) is passed to nix-prefetch-url as the expected
    hash. Afterwards, two values are checked against it:

    - the digest nix-prefetch-url reports;
    - a digest computed here from the downloaded file.

    Returns:
        The path printed by ``nix-prefetch-url --print-path``.
    """
    v.status(f'Fetching {url}')
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # The output must be exactly "<base32 digest>\n<path>\n".
    reported_digest, path, trailing = prefetch.stdout.decode().split('\n')
    assert trailing == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(reported_digest)) == digest)
    v.status(f"Verifying digest of {path}")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
2f96f32a 376
73bec7e8 377
9343cf48
SW
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each table entry gets two attributes: ``absolute_url``, resolved against
    *forwarded_url*, and ``file``, the downloaded path. The downloaded
    git-revision file must then match *pin*'s revision.
    """
    for name in ('git-revision', 'nixexprs.tar.xz'):
        entry = table[name]
        entry.absolute_url = urllib.parse.urljoin(forwarded_url, entry.url)
        entry.file = fetch_with_nix_prefetch_url(
            v, entry.absolute_url, entry.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    with open(table['git-revision'].file, encoding='utf-8') as rev_file:
        # Bounded read. The file is expected to hold only the revision.
        v.result(rev_file.read(999) == pin.git_revision)
2f96f32a 391
971d3659 392
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the cache file that records the store path of *pin*'s tarball.

    The file lives under $XDG_CACHE_HOME/pinch/git-tarball. Its name joins
    a hash of the repository URL, the revision, and the release name.
    """
    repo_hash = digest_string(channel.git_repo.encode())
    entry_name = f'{repo_hash}-{pin.git_revision}-{pin.release_name}'
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
971d3659
SW
398
399
d7cfdb22
SW
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Check via *v* that *old_revision* is an ancestor of *new_revision*.

    Runs ``git merge-base --is-ancestor`` in the repository's git_cache
    directory, so a new pin that does not descend from the previous one is
    rejected.
    """
    repo_dir = git_cache.git_cachedir(channel.git_repo)
    v.status(f'Verifying rev is an ancestor of previous rev {old_revision}')
    ancestry = subprocess.run(
        ['git', '-C', repo_dir, 'merge-base', '--is-ancestor',
         old_revision, new_revision])
    v.result(ancestry.returncode == 0)
415
dc038df0 416
925c801b
SW
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that an extracted channel tarball matches a git checkout.

    The tarball's ``<release_name>/`` directory inside *channel_contents* is
    compared with *git_contents*. At least one file must match and none may
    differ. Incomparable files are allowed only if they appear in the lists
    below.

    Raises:
        VerificationError: if any check fails.
    """
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, pin.release_name), git_contents)
    v.ok()
    v.check(f'{len(match)} files match', len(match) > 0)
    v.check(f'{len(mismatch)} files differ', len(mismatch) == 0)
    # Every one of these must be incomparable. check_channel_metadata reads
    # .git-revision and .version-suffix from the tarball; the others are
    # presumably channel-only artifacts as well.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    # These may be incomparable, but need not be.
    permitted_errors = [
        'pkgs/test/nixpkgs-check-by-name/tests/symlink-invalid/pkgs/by-name/fo/foo/foo.nix',
    ]
    benign_expected_errors: List[str] = []
    benign_permitted_errors: List[str] = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_expected_errors.append(ee)
    for pe in permitted_errors:
        if pe in errors:
            errors.remove(pe)
            benign_permitted_errors.append(pe)
    v.check(
        f'{len(errors)} unexpected incomparable files: {errors}',
        len(errors) == 0)
    v.check(
        f'({len(benign_expected_errors)} of {len(expected_errors)} expected incomparable files)',
        len(benign_expected_errors) == len(expected_errors))
    # Always passes (each permitted path is counted at most once). It only
    # reports the count.
    v.check(
        f'({len(benign_permitted_errors)} of {len(permitted_errors)} permitted incomparable files)',
        len(benign_permitted_errors) <= len(permitted_errors))
925c801b
SW
456
457
7d2c278f
SW
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the already-downloaded nixexprs.tar.xz from *table* into *dest*."""
    tarball = table['nixexprs.tar.xz'].file
    v.status(f"Extracting tarball {tarball}")
    shutil.unpack_archive(tarball, dest)
    v.ok()
465
466
7d2c278f
SW
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Extract the tree of *pin*'s git revision into *dest*.

    ``git archive`` from the git_cache repository is piped into ``tar x``.
    Both must exit with status 0.
    """
    v.status('Checking out corresponding git revision')
    archive_cmd = ['git', '-C', git_cache.git_cachedir(channel.git_repo),
                   'archive', pin.git_revision]
    extract_cmd = ['tar', 'x', '-C', dest, '-f', '-']
    with subprocess.Popen(archive_cmd, stdout=subprocess.PIPE) as archiver:
        with subprocess.Popen(extract_cmd, stdin=archiver.stdout) as extractor:
            # Drop the parent's copy of the pipe so git gets SIGPIPE if tar
            # exits early, rather than blocking on a full pipe.
            if archiver.stdout:
                archiver.stdout.close()
            extractor.wait()
            archiver.wait()
            v.result(archiver.returncode == 0 and extractor.returncode == 0)
925c801b
SW
482
483
9343cf48
SW
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return the store path of a ``.tar.xz`` of *pin*'s git revision.

    A store path recorded in tarball_cache_file() is reused if it still
    exists. Otherwise ``git archive --prefix=<release_name>/`` is piped
    through ``xz``. The result is added to the Nix store, and the store path
    is recorded in the cache file.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file, encoding='utf-8') as f:
            cached_tarball = f.read(9999)
            if os.path.exists(cached_tarball):
                return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this file's descriptor, so
        # open it in binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(f'Generating tarball for git revision {pin.git_revision}')
            with subprocess.Popen(
                    ['git', '-C', git_cache.git_cachedir(channel.git_repo),
                     'archive', f'--prefix={pin.release_name}/', pin.git_revision],
                    stdout=subprocess.PIPE) as git:
                with subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file) as xz:
                    # Drop the parent's copy of the pipe's read end. If xz
                    # exits early, git then gets SIGPIPE instead of blocking
                    # forever on a full pipe (and git.wait() hanging). This
                    # matches git_checkout.
                    if git.stdout:
                        git.stdout.close()
                    xz.wait()
                    git.wait()
                    v.result(git.returncode == 0 and xz.returncode == 0)

        store_tarball = copy_to_nix_store(v, output_filename)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w', encoding='utf-8') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
736c25eb
SW
515
516
f9cd7bdc
SW
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check the metadata files in an extracted channel tarball.

    The files are read from ``<channel_contents>/<release_name>/``.
    ``.git-revision`` must equal *pin*'s revision, and ``.version-suffix``
    must be a suffix of *pin*'s release name.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    v.status('Verifying git commit in channel tarball')
    with open(os.path.join(release_dir, '.git-revision'),
              encoding='utf-8') as revision_file:
        v.result(revision_file.read(999) == pin.git_revision)

    v.status(
        f'Verifying version-suffix is a suffix of release name {pin.release_name}:')
    with open(os.path.join(release_dir, '.version-suffix'),
              encoding='utf-8') as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
f9cd7bdc
SW
533
534
7d2c278f
SW
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify a downloaded channel tarball against its git revision.

    The tarball and a git checkout of *pin* are extracted into temporary
    directories. The tarball's metadata files are checked, then the two
    trees are compared.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:
        extract_tarball(v, table, channel_contents)
        check_channel_metadata(v, pin, channel_contents)
        git_checkout(v, channel, pin, git_contents)
        compare_tarball_and_git(v, pin, channel_contents, git_contents)
        v.status('Removing temporary directories')
    # Both directories have been removed once the with-statement exits.
    v.ok()
552
553
d7cfdb22
SW
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name for *git_revision*.

    The name has the form ``<repo basename>-<%ct>-<%h>``: the committer
    timestamp and abbreviated hash (``--abbrev=11``) come from ``git log``.
    """
    v.status('Getting commit date')
    log = subprocess.run(
        ['git', '-C', git_cache.git_cachedir(channel.git_repo), 'log',
         '-n1', '--format=%ct-%h', '--abbrev=11', '--no-show-signature',
         git_revision],
        stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    stamp_and_hash = log.stdout.decode().strip()
    return f'{os.path.basename(channel.git_repo)}-{stamp_and_hash}'
e3cae769
SW
571
572
567a6783
SW
K = TypeVar('K')
V = TypeVar('V')


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into entries for which ``pred(key, value)`` holds and the rest.

    Both result dicts keep *d*'s iteration order.
    """
    selected: Dict[K, V] = {}
    remaining: Dict[K, V] = {}
    for key, value in d.items():
        destination = selected if pred(key, value) else remaining
        destination[key] = value
    return selected, remaining


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into entries whose key is in *fields* and the rest."""
    return partition_dict(lambda key, _value: key in fields, d)
592
593
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path and, if present, the pin from one config section.

    The ``type`` key selects the kind of channel. Keys that are fields of
    that kind's Pin class make up the pin; the remaining keys are passed to
    the SearchPath class.

    Returns:
        (search_path, pin). *pin* is None when the section has no pin
        fields, except for pin kinds that have no fields at all, which are
        always built.
    """
    kinds: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    SP, P = kinds[conf['type']]
    _, all_fields = filter_dict(dict(conf.items()), {'type'})
    pin_fields, remaining_fields = filter_dict(all_fields, set(P._fields))
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin_present = pin_fields or P._fields == ()
    pin = P(**pin_fields) if pin_present else None  # type: ignore
    return SP(**remaining_fields), pin
f8f5b125
SW
609
610
e8bd4979
SW
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but the section must already be pinned.

    Raises:
        RuntimeError: if the section named *section* has no pin.
    """
    sp, pin = read_config_section(conf)
    if pin is not None:
        return sp, pin
    raise RuntimeError(
        f'Cannot update unpinned channel "{section}" (Run "pin" before "update")')
618
619
01ba0eb2
SW
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename* (read as UTF-8)."""
    parser = configparser.ConfigParser()
    with open(filename, encoding='utf-8') as source:
        parser.read_file(source, filename)
    return parser
625
626
4603b1a7
SW
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files and merge their sections by name.

    Raises:
        RuntimeError: if a section name appears in more than one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                # f-string: the message must name the offending section.
                raise RuntimeError(f'Duplicate channel "{section}"')
            merged_config[section] = config[section]
    return merged_config
637
638
def pinCommand(args: argparse.Namespace) -> None:
    """Handle ``pinch pin``: pin channels and write the pins back to the file.

    Only the sections named in ``args.channels`` are pinned, or every
    section if none are named. The channels file is rewritten only after
    every requested pin has succeeded.
    """
    v = Verification()
    config = read_config(args.channels_file)
    wanted = args.channels
    for section in config.sections():
        if wanted and section not in wanted:
            continue
        sp, old_pin = read_config_section(config[section])
        new_pin = sp.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w', encoding='utf-8') as configfile:
        config.write(configfile)
2f96f32a
SW
652
653
def updateCommand(args: argparse.Namespace) -> None:
    """Handle ``pinch update``: install every pinned channel into a profile.

    Each non-alias channel is fetched, and each alias reuses the tarball of
    the channel it names. All channels are then installed with a single
    ``nix-env --install --remove-all`` into ``args.profile``. With
    ``--dry-run`` the command is printed instead of run.

    Raises:
        RuntimeError: if a channel is unpinned or defined in two files.
        KeyError: if an alias names an unknown channel.
        VerificationError: if fetching or installing fails.
    """
    v = Verification()
    # Channel name -> (release name, tarball path). The nix-env expressions
    # are rendered at the end with f-strings only. Earlier code applied '%'
    # formatting afterwards, which broke on any '%' in a release name (and
    # git_revision_name derives release names from the repo URL's basename).
    sources: Dict[str, Tuple[str, str]] = {}
    profile_manifest = os.path.join(args.profile, "manifest.nix")
    search_paths: List[str] = [
        "-I", "pinch_profile=" + args.profile,
        "-I", "pinch_profile_manifest=" + os.readlink(profile_manifest)
    ] if os.path.exists(profile_manifest) else []
    config = {
        section: read_pinned_config_section(section, conf) for section,
        conf in read_config_files(
            args.channels_file).items()}
    alias, nonalias = partition_dict(
        lambda _name, entry: isinstance(entry[0], AliasSearchPath), config)

    for section, (sp, pin) in sorted(nonalias.items()):
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        search_paths.extend(
            ["-I", f"pinch_tarball_for_{pin.release_name}={tarball}"])
        sources[section] = (pin.release_name, tarball)

    for section, (sp, _) in alias.items():
        assert isinstance(sp, AliasSearchPath)  # For mypy
        sources[section] = sources[sp.alias_of]

    def channel_expr(name: str) -> str:
        """Render the nix-env expression that installs channel *name*."""
        release_name, tarball = sources[name]
        return (f'f: f {{ name = "{release_name}"; channelName = "{name}"; '
                f'src = builtins.storePath "{tarball}"; }}')

    command = [
        'nix-env',
        '--profile',
        args.profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--remove-all',
    ] + search_paths + ['--from-expression'] + [
        channel_expr(name) for name in sorted(sources)]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
736c25eb
SW
700
701
0e5e611d
SW
def main() -> None:
    """Parse the command line and run the selected subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = (
        f'/nix/var/nix/profiles/per-user/{getpass.getuser()}/channels')
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()