]> git.scottworley.com Git - pinch/blame - pinch.py
Show which files are unexpectedly incomparable
[pinch] / pinch.py
CommitLineData
05a3d87b
SW
1# pinch: PIN CHannels - a replacement for `nix-channel --update`
2#
3# This program is free software: you can redistribute it and/or modify it
4# under the terms of the GNU General Public License as published by the
5# Free Software Foundation, version 3.
6
7
0e5e611d 8import argparse
f15e458d 9import configparser
2f96f32a
SW
10import filecmp
11import functools
736c25eb 12import getpass
2f96f32a
SW
13import hashlib
14import operator
15import os
16import os.path
9a78329e 17import shlex
2f96f32a 18import shutil
73bec7e8 19import subprocess
9d7844bb 20import sys
0afcdb2a 21import tarfile
2f96f32a 22import tempfile
89e79125 23import types
2f96f32a
SW
24import urllib.parse
25import urllib.request
26import xml.dom.minidom
27
28from typing import (
9d2c406b 29 Callable,
2f96f32a
SW
30 Dict,
31 Iterable,
32 List,
7f4c3ace 33 Mapping,
0a4ff8dd 34 NamedTuple,
73bec7e8 35 NewType,
d7cfdb22 36 Optional,
567a6783 37 Set,
2f96f32a 38 Tuple,
7f4c3ace 39 Type,
567a6783 40 TypeVar,
0a4ff8dd 41 Union,
2f96f32a
SW
42)
43
d06918bc
SW
44import git_cache
45
3603dde2
SW
46# Use xdg module when it's less painful to have as a dependency
47
48
9f936f16 49class XDG(NamedTuple):
3603dde2
SW
50 XDG_CACHE_HOME: str
51
52
53xdg = XDG(
54 XDG_CACHE_HOME=os.getenv(
55 'XDG_CACHE_HOME',
56 os.path.expanduser('~/.cache')))
26125a28
SW
57
58
2f96f32a
SW
59class VerificationError(Exception):
60 pass
61
62
63class Verification:
64
65 def __init__(self) -> None:
66 self.line_length = 0
67
68 def status(self, s: str) -> None:
9d7844bb 69 print(s, end=' ', file=sys.stderr, flush=True)
2f96f32a
SW
70 self.line_length += 1 + len(s) # Unicode??
71
72 @staticmethod
73 def _color(s: str, c: int) -> str:
cb28d8e5 74 return f'\033[{c:2d}m{s}\033[00m'
2f96f32a
SW
75
76 def result(self, r: bool) -> None:
77 message, color = {True: ('OK ', 92), False: ('FAIL', 91)}[r]
78 length = len(message)
dd1026fe 79 cols = shutil.get_terminal_size().columns or 80
2f96f32a 80 pad = (cols - (self.line_length + length)) % cols
9d7844bb 81 print(' ' * pad + self._color(message, color), file=sys.stderr)
2f96f32a
SW
82 self.line_length = 0
83 if not r:
84 raise VerificationError()
85
86 def check(self, s: str, r: bool) -> None:
87 self.status(s)
88 self.result(r)
89
90 def ok(self) -> None:
91 self.result(True)
92
93
faff8642
SW
94Digest16 = NewType('Digest16', str)
95Digest32 = NewType('Digest32', str)
96
97
98class ChannelTableEntry(types.SimpleNamespace):
99 absolute_url: str
100 digest: Digest16
101 file: str
102 size: int
103 url: str
104
105
0a4ff8dd
SW
106class AliasPin(NamedTuple):
107 pass
108
109
0afcdb2a
SW
110class SymlinkPin(NamedTuple):
111 @property
112 def release_name(self) -> str:
113 return 'link'
114
115
0a4ff8dd
SW
116class GitPin(NamedTuple):
117 git_revision: str
faff8642
SW
118 release_name: str
119
0a4ff8dd
SW
120
121class ChannelPin(NamedTuple):
122 git_revision: str
123 release_name: str
124 tarball_url: str
125 tarball_sha256: str
126
127
0afcdb2a
SW
128Pin = Union[AliasPin, SymlinkPin, GitPin, ChannelPin]
129
130
131def copy_to_nix_store(v: Verification, filename: str) -> str:
132 v.status('Putting tarball in Nix store')
133 process = subprocess.run(
134 ['nix-store', '--add', filename], stdout=subprocess.PIPE)
135 v.result(process.returncode == 0)
530104d7 136 return process.stdout.decode().strip() # type: ignore # (for old mypy)
0a4ff8dd
SW
137
138
96063a51
SW
139def symlink_archive(v: Verification, path: str) -> str:
140 with tempfile.TemporaryDirectory() as td:
141 archive_filename = os.path.join(td, 'link.tar.gz')
142 os.symlink(path, os.path.join(td, 'link'))
143 with tarfile.open(archive_filename, mode='x:gz') as t:
144 t.add(os.path.join(td, 'link'), arcname='link')
145 return copy_to_nix_store(v, archive_filename)
146
147
567a6783 148class AliasSearchPath(NamedTuple):
faff8642 149 alias_of: str
2fa9cbea 150
567a6783 151 def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
0a4ff8dd 152 return AliasPin()
2fa9cbea
SW
153
154
0afcdb2a
SW
155class SymlinkSearchPath(NamedTuple):
156 path: str
157
0afcdb2a
SW
158 def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
159 return SymlinkPin()
160
161 def fetch(self, v: Verification, _: Pin) -> str:
96063a51 162 return symlink_archive(v, self.path)
0afcdb2a
SW
163
164
567a6783 165class GitSearchPath(NamedTuple):
faff8642
SW
166 git_ref: str
167 git_repo: str
faff8642 168
567a6783 169 def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
d06918bc 170 _, new_revision = git_cache.fetch(self.git_repo, self.git_ref)
567a6783
SW
171 if old_pin is not None:
172 assert isinstance(old_pin, GitPin)
d06918bc 173 verify_git_ancestry(v, self, old_pin.git_revision, new_revision)
d7cfdb22
SW
174 return GitPin(release_name=git_revision_name(v, self, new_revision),
175 git_revision=new_revision)
4a82be40 176
567a6783
SW
177 def fetch(self, v: Verification, pin: Pin) -> str:
178 assert isinstance(pin, GitPin)
d06918bc
SW
179 git_cache.ensure_rev_available(
180 self.git_repo, self.git_ref, pin.git_revision)
567a6783 181 return git_get_tarball(v, self, pin)
b3ea39b8 182
b3ea39b8 183
567a6783
SW
184class ChannelSearchPath(NamedTuple):
185 channel_url: str
186 git_ref: str
187 git_repo: str
4a82be40 188
567a6783
SW
189 def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
190 if old_pin is not None:
191 assert isinstance(old_pin, ChannelPin)
a67cfec9 192
567a6783
SW
193 channel_html, forwarded_url = fetch_channel(v, self)
194 table, new_gitpin = parse_channel(v, channel_html)
53a27350
SW
195 if old_pin is not None and old_pin.git_revision == new_gitpin.git_revision:
196 return old_pin
567a6783 197 fetch_resources(v, new_gitpin, forwarded_url, table)
d06918bc
SW
198 git_cache.ensure_rev_available(
199 self.git_repo, self.git_ref, new_gitpin.git_revision)
200 if old_pin is not None:
201 verify_git_ancestry(
202 v, self, old_pin.git_revision, new_gitpin.git_revision)
567a6783 203 check_channel_contents(v, self, table, new_gitpin)
0a4ff8dd 204 return ChannelPin(
b17278e3 205 release_name=new_gitpin.release_name,
567a6783
SW
206 tarball_url=table['nixexprs.tar.xz'].absolute_url,
207 tarball_sha256=table['nixexprs.tar.xz'].digest,
3258ff2c 208 git_revision=new_gitpin.git_revision)
4a82be40 209
567a6783
SW
210 def fetch(self, v: Verification, pin: Pin) -> str:
211 assert isinstance(pin, ChannelPin)
b3ea39b8
SW
212
213 return fetch_with_nix_prefetch_url(
567a6783
SW
214 v, pin.tarball_url, Digest16(pin.tarball_sha256))
215
216
0afcdb2a
SW
217SearchPath = Union[AliasSearchPath,
218 SymlinkSearchPath,
219 GitSearchPath,
220 ChannelSearchPath]
567a6783 221TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
b3ea39b8 222
4a82be40 223
dc038df0 224def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
2f96f32a
SW
225
226 def throw(error: OSError) -> None:
227 raise error
228
229 def join(x: str, y: str) -> str:
230 return y if x == '.' else os.path.join(x, y)
231
232 def recursive_files(d: str) -> Iterable[str]:
233 all_files: List[str] = []
234 for path, dirs, files in os.walk(d, onerror=throw):
235 rel = os.path.relpath(path, start=d)
236 all_files.extend(join(rel, f) for f in files)
237 for dir_or_link in dirs:
238 if os.path.islink(join(path, dir_or_link)):
239 all_files.append(join(rel, dir_or_link))
240 return all_files
241
242 def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
243 return (f for f in files if not f.startswith('.git/'))
244
245 files = functools.reduce(
246 operator.or_, (set(
247 exclude_dot_git(
248 recursive_files(x))) for x in [a, b]))
249 return filecmp.cmpfiles(a, b, files, shallow=False)
250
251
567a6783
SW
252def fetch_channel(
253 v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
cb28d8e5 254 v.status(f'Fetching channel from {channel.channel_url}')
fb27ccc7
SW
255 with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
256 channel_html = request.read().decode()
257 forwarded_url = request.geturl()
258 v.result(request.status == 200)
567a6783
SW
259 v.check('Got forwarded', channel.channel_url != forwarded_url)
260 return channel_html, forwarded_url
2f96f32a
SW
261
262
567a6783
SW
263def parse_channel(v: Verification, channel_html: str) \
264 -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
2f96f32a 265 v.status('Parsing channel description as XML')
567a6783 266 d = xml.dom.minidom.parseString(channel_html)
2f96f32a
SW
267 v.ok()
268
3e6421c4
SW
269 v.status('Extracting release name:')
270 title_name = d.getElementsByTagName(
271 'title')[0].firstChild.nodeValue.split()[2]
272 h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
273 v.status(title_name)
274 v.result(title_name == h1_name)
3e6421c4
SW
275
276 v.status('Extracting git commit:')
2f96f32a 277 git_commit_node = d.getElementsByTagName('tt')[0]
3258ff2c
SW
278 git_revision = git_commit_node.firstChild.nodeValue
279 v.status(git_revision)
2f96f32a
SW
280 v.ok()
281 v.status('Verifying git commit label')
282 v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')
283
284 v.status('Parsing table')
567a6783 285 table: Dict[str, ChannelTableEntry] = {}
2f96f32a
SW
286 for row in d.getElementsByTagName('tr')[1:]:
287 name = row.childNodes[0].firstChild.firstChild.nodeValue
288 url = row.childNodes[0].firstChild.getAttribute('href')
289 size = int(row.childNodes[1].firstChild.nodeValue)
73bec7e8 290 digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
567a6783 291 table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
2f96f32a 292 v.ok()
567a6783 293 return table, GitPin(release_name=title_name, git_revision=git_revision)
2f96f32a
SW
294
295
dc038df0
SW
296def digest_string(s: bytes) -> Digest16:
297 return Digest16(hashlib.sha256(s).hexdigest())
298
299
73bec7e8
SW
300def digest_file(filename: str) -> Digest16:
301 hasher = hashlib.sha256()
302 with open(filename, 'rb') as f:
303 # pylint: disable=cell-var-from-loop
304 for block in iter(lambda: f.read(4096), b''):
305 hasher.update(block)
306 return Digest16(hasher.hexdigest())
307
308
b1e8c1b0
SW
309@functools.lru_cache
310def _experimental_flag_needed(v: Verification) -> bool:
311 v.status('Checking Nix version')
312 process = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
313 v.result(process.returncode == 0)
314 return b'--experimental-features' in process.stdout
1d48f551
SW
315
316
317def _nix_command(v: Verification) -> List[str]:
1d48f551 318 return ['nix', '--experimental-features',
b1e8c1b0 319 'nix-command'] if _experimental_flag_needed(v) else ['nix']
1d48f551
SW
320
321
73bec7e8
SW
322def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
323 v.status('Converting digest to base16')
1d48f551
SW
324 process = subprocess.run(_nix_command(v) + [
325 'to-base16',
326 '--type',
327 'sha256',
328 digest32],
329 stdout=subprocess.PIPE)
73bec7e8
SW
330 v.result(process.returncode == 0)
331 return Digest16(process.stdout.decode().strip())
332
333
334def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
335 v.status('Converting digest to base32')
1d48f551
SW
336 process = subprocess.run(_nix_command(v) + [
337 'to-base32',
338 '--type',
339 'sha256',
340 digest16],
341 stdout=subprocess.PIPE)
73bec7e8
SW
342 v.result(process.returncode == 0)
343 return Digest32(process.stdout.decode().strip())
344
345
346def fetch_with_nix_prefetch_url(
347 v: Verification,
348 url: str,
349 digest: Digest16) -> str:
cb28d8e5 350 v.status(f'Fetching {url}')
73bec7e8 351 process = subprocess.run(
ba596fc0 352 ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
73bec7e8
SW
353 v.result(process.returncode == 0)
354 prefetch_digest, path, empty = process.stdout.decode().split('\n')
355 assert empty == ''
356 v.check("Verifying nix-prefetch-url's digest",
357 to_Digest16(v, Digest32(prefetch_digest)) == digest)
cb28d8e5 358 v.status(f"Verifying digest of {path}")
73bec7e8
SW
359 file_digest = digest_file(path)
360 v.result(file_digest == digest)
7c4de64c 361 return path # type: ignore # (for old mypy)
2f96f32a 362
73bec7e8 363
9343cf48
SW
364def fetch_resources(
365 v: Verification,
567a6783
SW
366 pin: GitPin,
367 forwarded_url: str,
368 table: Dict[str, ChannelTableEntry]) -> None:
2f96f32a 369 for resource in ['git-revision', 'nixexprs.tar.xz']:
567a6783
SW
370 fields = table[resource]
371 fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
e434d96d
SW
372 fields.file = fetch_with_nix_prefetch_url(
373 v, fields.absolute_url, fields.digest)
73bec7e8 374 v.status('Verifying git commit on main page matches git commit in table')
fb27ccc7
SW
375 with open(table['git-revision'].file, encoding='utf-8') as rev_file:
376 v.result(rev_file.read(999) == pin.git_revision)
2f96f32a 377
971d3659 378
b17278e3 379def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
eb0c6f1b
SW
380 return os.path.join(
381 xdg.XDG_CACHE_HOME,
382 'pinch/git-tarball',
cb28d8e5 383 f'{digest_string(channel.git_repo.encode())}-{pin.git_revision}-{pin.release_name}')
971d3659
SW
384
385
d7cfdb22
SW
386def verify_git_ancestry(
387 v: Verification,
388 channel: TarrableSearchPath,
d06918bc
SW
389 old_revision: str,
390 new_revision: str) -> None:
391 cachedir = git_cache.git_cachedir(channel.git_repo)
cb28d8e5 392 v.status(f'Verifying rev is an ancestor of previous rev {old_revision}')
971d3659
SW
393 process = subprocess.run(['git',
394 '-C',
395 cachedir,
396 'merge-base',
397 '--is-ancestor',
d06918bc
SW
398 old_revision,
399 new_revision])
971d3659
SW
400 v.result(process.returncode == 0)
401
dc038df0 402
925c801b
SW
403def compare_tarball_and_git(
404 v: Verification,
a72fdca9 405 pin: GitPin,
925c801b
SW
406 channel_contents: str,
407 git_contents: str) -> None:
408 v.status('Comparing channel tarball with git checkout')
409 match, mismatch, errors = compare(os.path.join(
a72fdca9 410 channel_contents, pin.release_name), git_contents)
925c801b 411 v.ok()
cb28d8e5
SW
412 v.check(f'{len(match)} files match', len(match) > 0)
413 v.check(f'{len(mismatch)} files differ', len(mismatch) == 0)
925c801b
SW
414 expected_errors = [
415 '.git-revision',
416 '.version-suffix',
417 'nixpkgs',
418 'programs.sqlite',
419 'svn-revision']
420 benign_errors = []
421 for ee in expected_errors:
422 if ee in errors:
423 errors.remove(ee)
424 benign_errors.append(ee)
ee11f936
SW
425 v.check(
426 f'{len(errors)} unexpected incomparable files: {errors}',
427 len(errors) == 0)
925c801b 428 v.check(
cb28d8e5 429 f'({len(benign_errors)} of {len(expected_errors)} expected incomparable files)',
925c801b
SW
430 len(benign_errors) == len(expected_errors))
431
432
7d2c278f
SW
433def extract_tarball(
434 v: Verification,
567a6783 435 table: Dict[str, ChannelTableEntry],
7d2c278f 436 dest: str) -> None:
cb28d8e5 437 v.status(f"Extracting tarball {table['nixexprs.tar.xz'].file}")
567a6783 438 shutil.unpack_archive(table['nixexprs.tar.xz'].file, dest)
925c801b
SW
439 v.ok()
440
441
7d2c278f
SW
442def git_checkout(
443 v: Verification,
444 channel: TarrableSearchPath,
3258ff2c 445 pin: GitPin,
7d2c278f 446 dest: str) -> None:
925c801b 447 v.status('Checking out corresponding git revision')
fb27ccc7
SW
448 with subprocess.Popen(
449 ['git', '-C', git_cache.git_cachedir(channel.git_repo), 'archive', pin.git_revision],
450 stdout=subprocess.PIPE) as git:
451 with subprocess.Popen(['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout) as tar:
452 if git.stdout:
453 git.stdout.close()
454 tar.wait()
455 git.wait()
456 v.result(git.returncode == 0 and tar.returncode == 0)
925c801b
SW
457
458
9343cf48
SW
459def git_get_tarball(
460 v: Verification,
461 channel: TarrableSearchPath,
462 pin: GitPin) -> str:
b17278e3 463 cache_file = tarball_cache_file(channel, pin)
eb0c6f1b 464 if os.path.exists(cache_file):
fb27ccc7
SW
465 with open(cache_file, encoding='utf-8') as f:
466 cached_tarball = f.read(9999)
467 if os.path.exists(cached_tarball):
468 return cached_tarball
eb0c6f1b 469
736c25eb
SW
470 with tempfile.TemporaryDirectory() as output_dir:
471 output_filename = os.path.join(
9343cf48 472 output_dir, pin.release_name + '.tar.xz')
fb27ccc7 473 with open(output_filename, 'w', encoding='utf-8') as output_file:
cb28d8e5 474 v.status(f'Generating tarball for git revision {pin.git_revision}')
fb27ccc7
SW
475 with subprocess.Popen(
476 ['git', '-C', git_cache.git_cachedir(channel.git_repo),
cb28d8e5 477 'archive', f'--prefix={pin.release_name}/', pin.git_revision],
fb27ccc7
SW
478 stdout=subprocess.PIPE) as git:
479 with subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file) as xz:
480 xz.wait()
481 git.wait()
482 v.result(git.returncode == 0 and xz.returncode == 0)
736c25eb 483
0afcdb2a 484 store_tarball = copy_to_nix_store(v, output_filename)
eb0c6f1b
SW
485
486 os.makedirs(os.path.dirname(cache_file), exist_ok=True)
fb27ccc7
SW
487 with open(cache_file, 'w', encoding='utf-8') as f:
488 f.write(store_tarball)
7c4de64c 489 return store_tarball # type: ignore # (for old mypy)
736c25eb
SW
490
491
f9cd7bdc
SW
492def check_channel_metadata(
493 v: Verification,
a72fdca9 494 pin: GitPin,
f9cd7bdc
SW
495 channel_contents: str) -> None:
496 v.status('Verifying git commit in channel tarball')
fb27ccc7
SW
497 with open(os.path.join(channel_contents, pin.release_name, '.git-revision'),
498 encoding='utf-8') as f:
499 v.result(f.read(999) == pin.git_revision)
f9cd7bdc
SW
500
501 v.status(
cb28d8e5 502 f'Verifying version-suffix is a suffix of release name {pin.release_name}:')
fb27ccc7
SW
503 with open(os.path.join(channel_contents, pin.release_name, '.version-suffix'),
504 encoding='utf-8') as f:
505 version_suffix = f.read(999)
f9cd7bdc 506 v.status(version_suffix)
a72fdca9 507 v.result(pin.release_name.endswith(version_suffix))
f9cd7bdc
SW
508
509
7d2c278f
SW
510def check_channel_contents(
511 v: Verification,
a72fdca9 512 channel: TarrableSearchPath,
567a6783 513 table: Dict[str, ChannelTableEntry],
a72fdca9 514 pin: GitPin) -> None:
dc038df0
SW
515 with tempfile.TemporaryDirectory() as channel_contents, \
516 tempfile.TemporaryDirectory() as git_contents:
925c801b 517
567a6783 518 extract_tarball(v, table, channel_contents)
a72fdca9 519 check_channel_metadata(v, pin, channel_contents)
f9cd7bdc 520
3258ff2c 521 git_checkout(v, channel, pin, git_contents)
925c801b 522
a72fdca9 523 compare_tarball_and_git(v, pin, channel_contents, git_contents)
925c801b 524
dc038df0 525 v.status('Removing temporary directories')
2f96f32a
SW
526 v.ok()
527
528
d7cfdb22
SW
529def git_revision_name(
530 v: Verification,
531 channel: TarrableSearchPath,
532 git_revision: str) -> str:
e3cae769
SW
533 v.status('Getting commit date')
534 process = subprocess.run(['git',
535 '-C',
d06918bc 536 git_cache.git_cachedir(channel.git_repo),
bed32182 537 'log',
e3cae769
SW
538 '-n1',
539 '--format=%ct-%h',
540 '--abbrev=11',
88af5903 541 '--no-show-signature',
d7cfdb22 542 git_revision],
ba596fc0 543 stdout=subprocess.PIPE)
de68382a 544 v.result(process.returncode == 0 and process.stdout != b'')
cb28d8e5 545 return f'{os.path.basename(channel.git_repo)}-{process.stdout.decode().strip()}'
e3cae769
SW
546
547
567a6783
SW
548K = TypeVar('K')
549V = TypeVar('V')
550
551
9d2c406b
SW
552def partition_dict(pred: Callable[[K, V], bool],
553 d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
567a6783
SW
554 selected: Dict[K, V] = {}
555 remaining: Dict[K, V] = {}
556 for k, v in d.items():
9d2c406b 557 if pred(k, v):
567a6783
SW
558 selected[k] = v
559 else:
560 remaining[k] = v
561 return selected, remaining
562
563
9d2c406b
SW
564def filter_dict(d: Dict[K, V], fields: Set[K]
565 ) -> Tuple[Dict[K, V], Dict[K, V]]:
566 return partition_dict(lambda k, v: k in fields, d)
567
568
d815b199 569def read_config_section(
567a6783
SW
570 conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
571 mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
572 'alias': (AliasSearchPath, AliasPin),
573 'channel': (ChannelSearchPath, ChannelPin),
574 'git': (GitSearchPath, GitPin),
0afcdb2a 575 'symlink': (SymlinkSearchPath, SymlinkPin),
7f4c3ace 576 }
567a6783
SW
577 SP, P = mapping[conf['type']]
578 _, all_fields = filter_dict(dict(conf.items()), set(['type']))
579 pin_fields, remaining_fields = filter_dict(all_fields, set(P._fields))
580 # Error suppression works around https://github.com/python/mypy/issues/9007
9e4ad890 581 pin_present = pin_fields or P._fields == ()
530104d7 582 pin = P(**pin_fields) if pin_present else None # type: ignore
567a6783 583 return SP(**remaining_fields), pin
f8f5b125
SW
584
585
e8bd4979
SW
586def read_pinned_config_section(
587 section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
588 sp, pin = read_config_section(conf)
589 if pin is None:
352cba96 590 raise RuntimeError(
cb28d8e5 591 f'Cannot update unpinned channel "{section}" (Run "pin" before "update")')
e8bd4979
SW
592 return sp, pin
593
594
01ba0eb2
SW
595def read_config(filename: str) -> configparser.ConfigParser:
596 config = configparser.ConfigParser()
fb27ccc7
SW
597 with open(filename, encoding='utf-8') as f:
598 config.read_file(f, filename)
01ba0eb2
SW
599 return config
600
601
4603b1a7
SW
602def read_config_files(
603 filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
604 merged_config: Dict[str, configparser.SectionProxy] = {}
605 for file in filenames:
606 config = read_config(file)
607 for section in config.sections():
608 if section in merged_config:
352cba96 609 raise RuntimeError('Duplicate channel "{section}"')
4603b1a7
SW
610 merged_config[section] = config[section]
611 return merged_config
612
613
41b87c9c 614def pinCommand(args: argparse.Namespace) -> None:
2f96f32a 615 v = Verification()
01ba0eb2 616 config = read_config(args.channels_file)
5cfa8e11 617 for section in config.sections():
98853153
SW
618 if args.channels and section not in args.channels:
619 continue
736c25eb 620
d815b199 621 sp, old_pin = read_config_section(config[section])
17906b27 622
567a6783 623 config[section].update(sp.pin(v, old_pin)._asdict())
8fca6c28 624
fb27ccc7 625 with open(args.channels_file, 'w', encoding='utf-8') as configfile:
e434d96d 626 config.write(configfile)
2f96f32a
SW
627
628
41b87c9c 629def updateCommand(args: argparse.Namespace) -> None:
736c25eb 630 v = Verification()
da135b07 631 exprs: Dict[str, str] = {}
3b2117a3
SW
632 profile_manifest = os.path.join(args.profile, "manifest.nix")
633 search_paths: List[str] = [
634 "-I", "pinch_profile=" + args.profile,
635 "-I", "pinch_profile_manifest=" + os.readlink(profile_manifest)
636 ] if os.path.exists(profile_manifest) else []
9d2c406b
SW
637 config = {
638 section: read_pinned_config_section(section, conf) for section,
639 conf in read_config_files(
640 args.channels_file).items()}
641 alias, nonalias = partition_dict(
642 lambda k, v: isinstance(v[0], AliasSearchPath), config)
643
436195f0 644 for section, (sp, pin) in sorted(nonalias.items()):
9d2c406b
SW
645 assert not isinstance(sp, AliasSearchPath) # mypy can't see through
646 assert not isinstance(pin, AliasPin) # partition_dict()
567a6783 647 tarball = sp.fetch(v, pin)
cb28d8e5
SW
648 search_paths.extend(
649 ["-I", f"pinch_tarball_for_{pin.release_name}={tarball}"])
4603b1a7 650 exprs[section] = (
cb28d8e5
SW
651 f'f: f {{ name = "{pin.release_name}"; channelName = "%s"; '
652 f'src = builtins.storePath "{tarball}"; }}')
4603b1a7 653
9d2c406b
SW
654 for section, (sp, pin) in alias.items():
655 assert isinstance(sp, AliasSearchPath) # For mypy
656 exprs[section] = exprs[sp.alias_of]
17906b27 657
9a78329e
SW
658 command = [
659 'nix-env',
660 '--profile',
9e8ed0ed 661 args.profile,
9a78329e
SW
662 '--show-trace',
663 '--file',
664 '<nix/unpack-channel.nix>',
665 '--install',
fc168c34 666 '--remove-all',
436195f0
SW
667 ] + search_paths + ['--from-expression'] + [
668 exprs[name] % name for name in sorted(exprs.keys())]
9a78329e
SW
669 if args.dry_run:
670 print(' '.join(map(shlex.quote, command)))
671 else:
672 v.status('Installing channels with nix-env')
673 process = subprocess.run(command)
674 v.result(process.returncode == 0)
736c25eb
SW
675
676
0e5e611d
SW
677def main() -> None:
678 parser = argparse.ArgumentParser(prog='pinch')
679 subparsers = parser.add_subparsers(dest='mode', required=True)
680 parser_pin = subparsers.add_parser('pin')
681 parser_pin.add_argument('channels_file', type=str)
98853153 682 parser_pin.add_argument('channels', type=str, nargs='*')
41b87c9c 683 parser_pin.set_defaults(func=pinCommand)
736c25eb 684 parser_update = subparsers.add_parser('update')
9a78329e 685 parser_update.add_argument('--dry-run', action='store_true')
9e8ed0ed 686 parser_update.add_argument('--profile', default=(
cb28d8e5 687 f'/nix/var/nix/profiles/per-user/{getpass.getuser()}/channels'))
01ba0eb2 688 parser_update.add_argument('channels_file', type=str, nargs='+')
41b87c9c 689 parser_update.set_defaults(func=updateCommand)
0e5e611d
SW
690 args = parser.parse_args()
691 args.func(args)
692
693
b5964ec3
SW
694if __name__ == '__main__':
695 main()