]> git.scottworley.com Git - pinch/blame - pinch.py
Start on 3.1.0
[pinch] / pinch.py
CommitLineData
05a3d87b
SW
1# pinch: PIN CHannels - a replacement for `nix-channel --update`
2#
3# This program is free software: you can redistribute it and/or modify it
4# under the terms of the GNU General Public License as published by the
5# Free Software Foundation, version 3.
6
7
0e5e611d 8import argparse
f15e458d 9import configparser
2f96f32a
SW
10import filecmp
11import functools
736c25eb 12import getpass
2f96f32a
SW
13import hashlib
14import operator
15import os
16import os.path
9a78329e 17import shlex
2f96f32a 18import shutil
73bec7e8 19import subprocess
9d7844bb 20import sys
0afcdb2a 21import tarfile
2f96f32a 22import tempfile
89e79125 23import types
2f96f32a
SW
24import urllib.parse
25import urllib.request
26import xml.dom.minidom
27
28from typing import (
9d2c406b 29 Callable,
2f96f32a
SW
30 Dict,
31 Iterable,
32 List,
7f4c3ace 33 Mapping,
0a4ff8dd 34 NamedTuple,
73bec7e8 35 NewType,
d7cfdb22 36 Optional,
567a6783 37 Set,
2f96f32a 38 Tuple,
7f4c3ace 39 Type,
567a6783 40 TypeVar,
0a4ff8dd 41 Union,
2f96f32a
SW
42)
43
d06918bc
SW
44import git_cache
45
3603dde2
SW
46# Use xdg module when it's less painful to have as a dependency
47
48
9f936f16 49class XDG(NamedTuple):
3603dde2
SW
50 XDG_CACHE_HOME: str
51
52
53xdg = XDG(
54 XDG_CACHE_HOME=os.getenv(
55 'XDG_CACHE_HOME',
56 os.path.expanduser('~/.cache')))
26125a28
SW
57
58
2f96f32a
SW
class VerificationError(Exception):
    """Raised by Verification.result() when a checked step fails."""
61
62
class Verification:
    """Running, colorized log of verification steps, written to stderr.

    A step is announced with one or more status() calls on the same line
    and concluded with result(), which right-aligns an OK/FAIL tag and
    raises VerificationError on failure.
    """

    def __init__(self) -> None:
        # Characters printed so far on the current stderr line; used to
        # right-align the OK/FAIL tag.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* and a trailing space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, not terminal cells, so wide characters
        # can throw the alignment off.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in ANSI color code *c*, followed by a reset."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with a right-aligned OK or FAIL tag.

        Raises:
            VerificationError: if *r* is false (after printing FAIL).
        """
        tags = {False: ('FAIL', 91), True: ('OK ', 92)}
        message, color = tags[r]
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(message)
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step *s* and immediately report its outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the step announced so far."""
        self.result(True)
92
93
faff8642
SW
# Hex (base-16) SHA-256 digest, as produced by hashlib's hexdigest().
Digest16 = NewType('Digest16', str)
# Nix base-32 SHA-256 digest, as printed by nix-prefetch-url.
Digest32 = NewType('Digest32', str)
96
97
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page.

    parse_channel() sets url, digest and size; fetch_resources() later
    fills in absolute_url and file.
    """
    absolute_url: str  # url resolved against the forwarded channel URL
    digest: Digest16  # hex SHA-256 of the file, from the table
    file: str  # local path returned by nix-prefetch-url once fetched
    size: int  # size column of the table
    url: str  # href exactly as written in the table (possibly relative)
104
105
0a4ff8dd
SW
class AliasPin(NamedTuple):
    """Pin for an alias channel; it records no state of its own."""
108
109
0afcdb2a
SW
class SymlinkPin(NamedTuple):
    """Pin for a symlink channel; it records no state of its own."""

    @property
    def release_name(self) -> str:
        """Fixed name, matching the 'link' entry symlink_archive() creates."""
        return 'link'
114
115
0a4ff8dd
SW
class GitPin(NamedTuple):
    """Pin recording one git commit.

    Field order is significant: _asdict() output is written back into the
    channels file in this order.
    """

    git_revision: str  # commit id
    release_name: str  # human-readable name for the release
119
0a4ff8dd
SW
120
class ChannelPin(NamedTuple):
    """Pin recording a channel release and the tarball it points to.

    Field order is significant: _asdict() output is written back into the
    channels file in this order.
    """

    git_revision: str  # commit the release was built from
    release_name: str  # release name shown on the channel page
    tarball_url: str  # absolute URL of nixexprs.tar.xz
    tarball_sha256: str  # hex SHA-256 of that tarball
126
127
0afcdb2a
SW
# Any per-type pin record; read_config_section() picks the concrete class
# from the section's 'type' key.
Pin = Union[AliasPin, SymlinkPin, GitPin, ChannelPin]
129
130
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add *filename* to the Nix store and return the printed store path.

    A non-zero exit from ``nix-store --add`` is reported through *v*,
    which raises VerificationError.
    """
    v.status('Putting tarball in Nix store')
    added = subprocess.run(['nix-store', '--add', filename],
                           stdout=subprocess.PIPE)
    v.result(added.returncode == 0)
    store_path = added.stdout.decode().strip()
    return store_path  # type: ignore # (for old mypy)
0a4ff8dd
SW
137
138
96063a51
SW
def symlink_archive(v: Verification, path: str) -> str:
    """Build a .tar.gz holding one symlink named 'link' that points at
    *path*, add it to the Nix store, and return the store path."""
    with tempfile.TemporaryDirectory() as scratch:
        link = os.path.join(scratch, 'link')
        archive = os.path.join(scratch, 'link.tar.gz')
        os.symlink(path, link)
        with tarfile.open(archive, mode='x:gz') as tar:
            tar.add(link, arcname='link')
        # Must run before the scratch directory is removed.
        return copy_to_nix_store(v, archive)
146
147
class AliasSearchPath(NamedTuple):
    """A channel that reuses another channel's contents under its own name."""

    alias_of: str  # section name of the channel being mirrored

    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin; always return an empty AliasPin."""
        return AliasPin()
2fa9cbea
SW
153
154
0afcdb2a
SW
class SymlinkSearchPath(NamedTuple):
    """A channel whose contents are a symlink to the local *path*."""

    path: str

    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Nothing to record; return an empty SymlinkPin."""
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        """Return a store path of an archive holding a symlink to path."""
        return symlink_archive(v, self.path)
0afcdb2a
SW
163
164
class GitSearchPath(NamedTuple):
    """A channel that tracks a ref of a git repository."""

    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch git_ref and pin the commit it currently names.

        With a previous pin, the new commit must descend from it.
        """
        _, revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        name = git_revision_name(v, self, revision)
        return GitPin(release_name=name, git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return a Nix store path of a tarball of the pinned commit."""
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, pin.git_revision)
        return git_get_tarball(v, self, pin)
b3ea39b8 182
b3ea39b8 183
567a6783
SW
class ChannelSearchPath(NamedTuple):
    """A Nix channel, cross-checked against the git repo it is built from."""

    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Pin the channel's current release.

        The channel page is fetched and parsed.  If its git revision is
        unchanged, *old_pin* is returned as-is.  Otherwise the release
        files are downloaded and verified, the new revision must descend
        from the old one, and the tarball must match the git checkout.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        html, forwarded_url = fetch_channel(v, self)
        table, release = parse_channel(v, html)
        if old_pin is not None and old_pin.git_revision == release.git_revision:
            return old_pin

        fetch_resources(v, release, forwarded_url, table)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, release.git_revision)
        if old_pin is not None:
            verify_git_ancestry(
                v, self, old_pin.git_revision, release.git_revision)
        check_channel_contents(v, self, table, release)

        nixexprs = table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=release.release_name,
            tarball_url=nixexprs.absolute_url,
            tarball_sha256=nixexprs.digest,
            git_revision=release.git_revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Download the pinned nixexprs tarball, checking its SHA-256."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
215
216
0afcdb2a
SW
# Any configured channel source, selected by a section's 'type' key.
SearchPath = Union[AliasSearchPath,
                   SymlinkSearchPath,
                   GitSearchPath,
                   ChannelSearchPath]
# Sources backed by a git repository, which can produce git tarballs.
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
b3ea39b8 222
4a82be40 223
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file by file, by content.

    Compared paths are every regular file, plus every symlink to a
    directory (os.walk lists those as directories but does not descend
    into them), found under either *a* or *b*.  Anything under a
    top-level ``.git/`` directory is skipped.  Paths are relative to the
    tree roots.

    Returns:
        ``(match, mismatch, errors)`` as from ``filecmp.cmpfiles``, each in
        sorted path order so reports are deterministic.  ``errors`` holds
        paths that could not be compared (e.g. present in only one tree).

    Raises:
        OSError: if either tree cannot be walked.
    """

    def throw(error: OSError) -> None:
        raise error

    def recursive_files(root: str) -> Set[str]:
        found: Set[str] = set()
        for path, dirs, files in os.walk(root, onerror=throw):
            rel = os.path.relpath(path, start=root)
            prefix = '' if rel == '.' else rel
            found.update(os.path.join(prefix, f) for f in files)
            found.update(
                os.path.join(prefix, d) for d in dirs
                if os.path.islink(os.path.join(path, d)))
        return found

    # Sorted (rather than set order) so the returned lists, which are
    # shown to the user, come out in a stable order.
    files = sorted(
        f for f in recursive_files(a) | recursive_files(b)
        if not f.startswith('.git/'))
    return filecmp.cmpfiles(a, b, files, shallow=False)
250
251
567a6783
SW
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel's HTML page.

    Returns the page text and the URL it was finally served from.  The
    request must succeed with status 200 and must have been redirected
    away from channel_url; otherwise VerificationError is raised via *v*.
    """
    v.status(f'Fetching channel from {channel.channel_url}')
    with urllib.request.urlopen(channel.channel_url, timeout=10) as response:
        html = response.read().decode()
        final_url = response.geturl()
        v.result(response.status == 200)
    was_redirected = channel.channel_url != final_url
    v.check('Got forwarded', was_redirected)
    return html, final_url
2f96f32a
SW
261
262
567a6783
SW
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Extract release name, git commit and file table from a channel page.

    The page is parsed as XML.  The release name is the third word of the
    <title>, cross-checked against the third word of the first <h1>.  The
    git commit is the text of the first <tt>, whose preceding text must be
    'Git commit '.  Every <tr> after the first is a file-table row whose
    cells are (linked name, size, linked digest).

    Returns the table keyed by file name and a GitPin for the release.
    Any failed check raises VerificationError via *v*.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel_html)
    v.ok()

    def first_text_child(tag: str) -> xml.dom.minidom.CharacterData:
        # Verify (and narrow for mypy) that the first <tag> starts with text.
        node = doc.getElementsByTagName(tag)[0].firstChild
        v.result(isinstance(node, xml.dom.minidom.CharacterData))
        assert isinstance(node, xml.dom.minidom.CharacterData)
        return node

    v.status('Finding release name (1)')
    release_name = first_text_child('title').nodeValue.split()[2]
    v.status('Finding release name (2)')
    heading = first_text_child('h1')
    v.status('Verifying release name:')
    v.status(release_name)
    v.result(release_name == heading.nodeValue.split()[2])

    v.status('Finding git commit')
    commit_element = doc.getElementsByTagName('tt')[0]
    commit_text = commit_element.firstChild
    v.result(isinstance(commit_text, xml.dom.minidom.CharacterData))
    assert isinstance(commit_text, xml.dom.minidom.CharacterData)
    v.status('Extracting git commit:')
    git_revision = commit_text.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_element.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    for row in doc.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=release_name, git_revision=git_revision)
2f96f32a
SW
308
309
dc038df0
SW
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded SHA-256 of the byte string *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
312
313
73bec7e8
SW
def digest_file(filename: str) -> Digest16:
    """Return the hex-encoded SHA-256 of the file at *filename*.

    The file is read in 4 KiB chunks, so large tarballs need not fit in
    memory.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
321
322
b1e8c1b0
SW
@functools.lru_cache
def _experimental_flag_needed(v: Verification) -> bool:
    """Return whether this `nix` accepts --experimental-features.

    Detected by searching ``nix --help`` output for the flag.  The result
    is memoized per *v* argument, so the probe runs once per Verification.
    """
    v.status('Checking Nix version')
    help_run = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
    v.result(help_run.returncode == 0)
    return help_run.stdout.find(b'--experimental-features') != -1
1d48f551
SW
329
330
def _nix_command(v: Verification) -> List[str]:
    """Return a fresh argv prefix for invoking `nix` subcommands.

    ``--experimental-features nix-command`` is added when the installed
    Nix advertises that flag.
    """
    command = ['nix']
    if _experimental_flag_needed(v):
        command += ['--experimental-features', 'nix-command']
    return command
1d48f551
SW
334
335
73bec7e8
SW
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base-32 SHA-256 digest to hex with `nix to-base16`."""
    v.status('Converting digest to base16')
    argv = _nix_command(v) + ['to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest16(converted.stdout.decode().strip())
346
347
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to base-32 with `nix to-base32`."""
    v.status('Converting digest to base32')
    argv = _nix_command(v) + ['to-base32', '--type', 'sha256', digest16]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest32(converted.stdout.decode().strip())
358
359
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* with nix-prefetch-url and return the printed path.

    Both the hash nix-prefetch-url reports and a fresh SHA-256 of the
    downloaded file must equal *digest*; otherwise VerificationError is
    raised via *v*.
    """
    v.status(f'Fetching {url}')
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # Output is exactly two newline-terminated lines: hash, then path.
    reported_digest, path, trailer = prefetch.stdout.decode().split('\n')
    assert trailer == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(reported_digest)) == digest)
    v.status(f"Verifying digest of {path}")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
2f96f32a 376
73bec7e8 377
9343cf48
SW
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Sets absolute_url and file on those table entries, then checks that
    the downloaded git-revision file matches pin.git_revision.
    """
    for name in ('git-revision', 'nixexprs.tar.xz'):
        entry = table[name]
        entry.absolute_url = urllib.parse.urljoin(forwarded_url, entry.url)
        entry.file = fetch_with_nix_prefetch_url(
            v, entry.absolute_url, entry.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    with open(table['git-revision'].file, encoding='utf-8') as rev_file:
        recorded_revision = rev_file.read(999)
    v.result(recorded_revision == pin.git_revision)
2f96f32a 391
971d3659 392
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the path of the file recording the store path of the tarball
    built from *channel*'s repo at *pin* (see git_get_tarball)."""
    repo_key = digest_string(channel.git_repo.encode())
    leaf = f'{repo_key}-{pin.git_revision}-{pin.release_name}'
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', leaf)
971d3659
SW
398
399
d7cfdb22
SW
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Require *old_revision* to be an ancestor of *new_revision*.

    Runs ``git merge-base --is-ancestor`` in the repo's git_cache
    directory; a non-zero exit raises VerificationError via *v*.
    """
    repo_dir = git_cache.git_cachedir(channel.git_repo)
    v.status(f'Verifying rev is an ancestor of previous rev {old_revision}')
    ancestry = subprocess.run(
        ['git', '-C', repo_dir, 'merge-base', '--is-ancestor',
         old_revision, new_revision])
    v.result(ancestry.returncode == 0)
415
dc038df0 416
ac38f5a2
SW
def broken_symlinks_are_identical(root1: str, root2: str, path: str) -> bool:
    """True if *path* is a dangling symlink under both roots, with the
    same link target text in each."""
    first = os.path.join(root1, path)
    second = os.path.join(root2, path)
    if not (os.path.islink(first) and os.path.islink(second)):
        return False
    if os.path.exists(first) or os.path.exists(second):
        return False  # at least one link resolves, so it isn't broken
    return os.readlink(first) == os.readlink(second)
425
426
925c801b
SW
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check the unpacked channel tarball against a git checkout.

    The tarball tree lives under a directory named pin.release_name.  At
    least one file must match and none may differ.  Incomparable files
    are allowed only if they are on a fixed expected list (every entry of
    which must actually be incomparable) or are broken symlinks with the
    same target in both trees.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check(f'{len(match)} files match', len(match) > 0)
    v.check(f'{len(mismatch)} files differ', len(mismatch) == 0)

    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_expected_errors = []
    for candidate in expected_errors:
        if candidate in errors:
            errors.remove(candidate)
            benign_expected_errors.append(candidate)

    unexpected = [
        path for path in errors
        if not broken_symlinks_are_identical(tarball_root, git_contents, path)]
    v.check(
        f'{len(unexpected)} unexpected incomparable files: {unexpected}',
        len(unexpected) == 0)

    found, wanted = len(benign_expected_errors), len(expected_errors)
    v.check(
        f'({found} of {wanted} expected incomparable files)',
        found == wanted)
925c801b
SW
459
460
7d2c278f
SW
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the fetched nixexprs.tar.xz listed in *table* into *dest*."""
    archive = table['nixexprs.tar.xz'].file
    v.status(f"Extracting tarball {archive}")
    shutil.unpack_archive(archive, dest)
    v.ok()
468
469
7d2c278f
SW
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Export the tree of pin.git_revision into *dest*.

    Equivalent to ``git archive REV | tar x -C DEST`` run against the
    repo's git_cache directory.
    """
    v.status('Checking out corresponding git revision')
    repo_dir = git_cache.git_cachedir(channel.git_repo)
    archive_cmd = ['git', '-C', repo_dir, 'archive', pin.git_revision]
    untar_cmd = ['tar', 'x', '-C', dest, '-f', '-']
    with subprocess.Popen(archive_cmd, stdout=subprocess.PIPE) as git:
        with subprocess.Popen(untar_cmd, stdin=git.stdout) as tar:
            # Drop our copy of the pipe's read end so git gets SIGPIPE if
            # tar exits early, rather than blocking on a full pipe.
            if git.stdout:
                git.stdout.close()
            tar.wait()
            git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
925c801b
SW
485
486
9343cf48
SW
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return a Nix store path of a .tar.xz of pin.git_revision.

    Archive entries live under a top-level ``<release_name>/`` directory.
    The store path is recorded in tarball_cache_file() and reused on later
    calls as long as that path still exists.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file, encoding='utf-8') as f:
            cached_tarball = f.read(9999)
        # The cached store path may have been removed since it was recorded.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this file's descriptor;
        # open it in binary mode to say so.
        with open(output_filename, 'wb') as output_file:
            v.status(f'Generating tarball for git revision {pin.git_revision}')
            with subprocess.Popen(
                    ['git', '-C', git_cache.git_cachedir(channel.git_repo),
                     'archive', f'--prefix={pin.release_name}/',
                     pin.git_revision],
                    stdout=subprocess.PIPE) as git:
                with subprocess.Popen(
                        ['xz'], stdin=git.stdout, stdout=output_file) as xz:
                    # Close our copy of the pipe's read end (as
                    # git_checkout does) so git receives SIGPIPE if xz
                    # exits early instead of blocking forever.
                    if git.stdout:
                        git.stdout.close()
                    xz.wait()
                    git.wait()
                    v.result(git.returncode == 0 and xz.returncode == 0)

        # Must happen before output_dir is deleted.
        store_tarball = copy_to_nix_store(v, output_filename)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w', encoding='utf-8') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
736c25eb
SW
518
519
f9cd7bdc
SW
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check metadata files in the unpacked channel tarball.

    ``.git-revision`` must equal pin.git_revision and ``.version-suffix``
    must be a suffix of pin.release_name (at most 999 characters of each
    file are read).
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    def read_meta(name: str) -> str:
        with open(os.path.join(release_dir, name), encoding='utf-8') as meta:
            return meta.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_meta('.git-revision') == pin.git_revision)

    v.status(
        f'Verifying version-suffix is a suffix of release name {pin.release_name}:')
    version_suffix = read_meta('.version-suffix')
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
f9cd7bdc
SW
536
537
7d2c278f
SW
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Unpack the channel tarball and the pinned git revision into scratch
    directories and verify they agree; the directories are removed after."""
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, table, channel_contents)
            check_channel_metadata(v, pin, channel_contents)

            git_checkout(v, channel, pin, git_contents)

            compare_tarball_and_git(v, pin, channel_contents, git_contents)

            v.status('Removing temporary directories')
    v.ok()
555
556
d7cfdb22
SW
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Return '<repo basename>-<committer Unix time>-<abbreviated hash>'.

    The hash is abbreviated to at least 11 characters (git ``--abbrev=11``).
    """
    v.status('Getting commit date')
    log = subprocess.run(
        ['git', '-C', git_cache.git_cachedir(channel.git_repo),
         'log', '-n1', '--format=%ct-%h', '--abbrev=11',
         '--no-show-signature', git_revision],
        stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    stamp = log.stdout.decode().strip()
    return f'{os.path.basename(channel.git_repo)}-{stamp}'
e3cae769
SW
574
575
567a6783
SW
K = TypeVar('K')
V = TypeVar('V')


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (items where pred(key, value) holds, all other items).

    *pred* is called exactly once per item, and both results keep *d*'s
    iteration order.
    """
    chosen: Dict[K, V] = {}
    rest: Dict[K, V] = {}
    for key, value in d.items():
        (chosen if pred(key, value) else rest)[key] = value
    return chosen, rest


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (entries whose key is in *fields*, the rest)."""
    return partition_dict(lambda key, _value: key in fields, d)
595
596
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path, and its pin if any, described by *conf*.

    The 'type' key selects the classes.  Keys naming pin fields go to the
    pin; all other keys except 'type' go to the search path.  A pin is
    returned only when at least one pin field is present or the pin type
    has no fields at all; otherwise the pin is None.
    """
    classes_by_type: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    search_path_cls, pin_cls = classes_by_type[conf['type']]
    _, fields = filter_dict(dict(conf.items()), {'type'})
    pin_fields, search_path_fields = filter_dict(fields, set(pin_cls._fields))
    pin: Optional[Pin] = None
    if pin_fields or pin_cls._fields == ():
        # Error suppression works around https://github.com/python/mypy/issues/9007
        pin = pin_cls(**pin_fields)  # type: ignore
    return search_path_cls(**search_path_fields), pin
f8f5b125
SW
612
613
e8bd4979
SW
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section(), but the section must already be pinned.

    Raises:
        RuntimeError: if *section* has no pin yet.
    """
    search_path, pin = read_config_section(conf)
    if pin is not None:
        return search_path, pin
    raise RuntimeError(
        f'Cannot update unpinned channel "{section}" (Run "pin" before "update")')
621
622
01ba0eb2
SW
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the channels file *filename* (INI syntax, UTF-8).

    Unlike ConfigParser.read(), a missing or unreadable file raises
    rather than being silently skipped.
    """
    config = configparser.ConfigParser()
    with open(filename, encoding='utf-8') as f:
        config.read_file(f, filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files and merge their sections by name.

    Raises:
        RuntimeError: if one section (channel) name appears in more than
            one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                # f-string: the original literal lacked the f prefix and
                # reported '{section}' verbatim.
                raise RuntimeError(f'Duplicate channel "{section}"')
            merged_config[section] = config[section]
    return merged_config
640
641
def pinCommand(args: argparse.Namespace) -> None:
    """Handle `pinch pin`: (re)pin channels and save the pins.

    Only sections named in args.channels are pinned, or every section
    when none are named.  The updated pins are written back to
    args.channels_file.
    """
    v = Verification()
    config = read_config(args.channels_file)
    wanted = args.channels
    for section in config.sections():
        if wanted and section not in wanted:
            continue

        search_path, old_pin = read_config_section(config[section])

        new_pin = search_path.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w', encoding='utf-8') as configfile:
        config.write(configfile)
2f96f32a
SW
655
656
def updateCommand(args: argparse.Namespace) -> None:
    """Handle `pinch update`: install all pinned channels into a profile.

    Each non-alias channel is fetched into the Nix store; an alias reuses
    the tarball of the channel it names.  Everything is installed with a
    single ``nix-env --install --remove-all`` into args.profile.  With
    --dry-run, the command line is printed instead of run.

    Raises:
        RuntimeError: if a channel is unpinned, or an alias names a
            channel that has not been resolved.
    """
    v = Verification()
    # section name -> (release name, tarball store path) to install
    sources: Dict[str, Tuple[str, str]] = {}
    profile_manifest = os.path.join(args.profile, "manifest.nix")
    search_paths: List[str] = [
        "-I", "pinch_profile=" + args.profile,
        "-I", "pinch_profile_manifest=" + os.readlink(profile_manifest)
    ] if os.path.exists(profile_manifest) else []
    config = {
        section: read_pinned_config_section(section, conf)
        for section, conf in read_config_files(args.channels_file).items()}
    alias, nonalias = partition_dict(
        lambda k, v: isinstance(v[0], AliasSearchPath), config)

    for section, (sp, pin) in sorted(nonalias.items()):
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        search_paths.extend(
            ["-I", f"pinch_tarball_for_{pin.release_name}={tarball}"])
        sources[section] = (pin.release_name, tarball)

    for section, (sp, pin) in alias.items():
        assert isinstance(sp, AliasSearchPath)  # For mypy
        if sp.alias_of not in sources:
            raise RuntimeError(
                f'Channel "{section}" is an alias of unknown or unresolved '
                f'channel "{sp.alias_of}"')
        sources[section] = sources[sp.alias_of]

    def channel_expression(name: str) -> str:
        # Built directly instead of via a '%s' template so that a '%' in
        # the release name (e.g. from a percent-encoded repo URL) cannot
        # break the formatting.
        release_name, tarball = sources[name]
        return (f'f: f {{ name = "{release_name}"; channelName = "{name}"; '
                f'src = builtins.storePath "{tarball}"; }}')

    command = [
        'nix-env',
        '--profile',
        args.profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--remove-all',
    ] + search_paths + ['--from-expression'] + [
        channel_expression(name) for name in sorted(sources)]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
736c25eb
SW
703
704
0e5e611d
SW
def main() -> None:
    """Parse the command line and run the selected subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = (
        f'/nix/var/nix/profiles/per-user/{getpass.getuser()}/channels')
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    args.func(args)
720
721
b5964ec3
SW
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()