]> git.scottworley.com Git - pinch/blame - pinch.py
Specify license
[pinch] / pinch.py
CommitLineData
05a3d87b
SW
1# pinch: PIN CHannels - a replacement for `nix-channel --update`
2#
3# This program is free software: you can redistribute it and/or modify it
4# under the terms of the GNU General Public License as published by the
5# Free Software Foundation, version 3.
6
7
0e5e611d 8import argparse
f15e458d 9import configparser
2f96f32a
SW
10import filecmp
11import functools
736c25eb 12import getpass
2f96f32a
SW
13import hashlib
14import operator
15import os
16import os.path
9a78329e 17import shlex
2f96f32a 18import shutil
73bec7e8 19import subprocess
9d7844bb 20import sys
0afcdb2a 21import tarfile
2f96f32a 22import tempfile
89e79125 23import types
2f96f32a
SW
24import urllib.parse
25import urllib.request
26import xml.dom.minidom
27
28from typing import (
9d2c406b 29 Callable,
2f96f32a
SW
30 Dict,
31 Iterable,
32 List,
7f4c3ace 33 Mapping,
0a4ff8dd 34 NamedTuple,
73bec7e8 35 NewType,
d7cfdb22 36 Optional,
567a6783 37 Set,
2f96f32a 38 Tuple,
7f4c3ace 39 Type,
567a6783 40 TypeVar,
0a4ff8dd 41 Union,
2f96f32a
SW
42)
43
d06918bc
SW
44import git_cache
45
3603dde2
SW
# Use xdg module when it's less painful to have as a dependency


class XDG(NamedTuple):
    """The subset of XDG base-directory settings this program uses."""
    XDG_CACHE_HOME: str


# The XDG Base Directory Specification says an unset *or empty*
# $XDG_CACHE_HOME falls back to ~/.cache.  os.getenv's default argument only
# covers the unset case, so use `or` to also catch the empty string.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv('XDG_CACHE_HOME') or os.path.expanduser('~/.cache'))
26125a28
SW
57
58
2f96f32a
SW
class VerificationError(Exception):
    """Raised when a verification step reports failure."""


class Verification:
    """Progress reporter for a sequence of checks, written to stderr.

    A step prints its description with status(), then result() prints a
    right-aligned colored OK/FAIL marker.  A failing step raises
    VerificationError.
    """

    def __init__(self) -> None:
        # Number of columns already printed on the current stderr line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* followed by a space, without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # NOTE: len() counts code points, not terminal display columns.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR color code *c*."""
        return f'\033[{c:2d}m{s}\033[00m'

    def result(self, r: bool) -> None:
        """Finish the current line with OK (r true) or FAIL (r false).

        Raises:
            VerificationError: if *r* is false.
        """
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        width = shutil.get_terminal_size().columns or 80
        # Right-align the marker; the modulo wraps if the line is too long.
        padding = (width - self.line_length - len(message)) % width
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Report a whole step at once: description *s*, outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Mark the pending step as successful."""
        self.result(True)
92
93
faff8642
SW
# Hex-encoded (base16) and Nix base32-encoded digests, kept as distinct types.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's file table.

    ``url``, ``digest`` and ``size`` are read from the page; ``absolute_url``
    and ``file`` are filled in later, when the resource is fetched.
    """
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str


class AliasPin(NamedTuple):
    """Pin for an alias channel; it carries no state of its own."""


class SymlinkPin(NamedTuple):
    """Pin for a symlink channel; it carries no state of its own."""

    @property
    def release_name(self) -> str:
        """Symlink channels always use the fixed release name ``link``."""
        return 'link'


class GitPin(NamedTuple):
    """A pinned git revision and the release name derived from it."""
    git_revision: str
    release_name: str


class ChannelPin(NamedTuple):
    """A pinned channel release: its git revision plus the tarball to fetch."""
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str


# Any pin that can be recorded in a channels file.
Pin = Union[AliasPin, SymlinkPin, GitPin, ChannelPin]
129
130
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add *filename* to the Nix store and return the resulting store path.

    Runs ``nix-store --add``; a non-zero exit status fails the verification
    step (v.result raises).
    """
    v.status('Putting tarball in Nix store')
    added = subprocess.run(['nix-store', '--add', filename],
                           stdout=subprocess.PIPE)
    v.result(added.returncode == 0)
    store_path = added.stdout.decode().strip()
    return store_path  # type: ignore # (for old mypy)
0a4ff8dd
SW
137
138
96063a51
SW
def symlink_archive(v: Verification, path: str) -> str:
    """Package a symlink to *path* as a Nix store tarball.

    The tarball holds a single entry named ``link`` pointing at *path*.
    Returns the store path of the tarball.
    """
    with tempfile.TemporaryDirectory() as scratch:
        link_path = os.path.join(scratch, 'link')
        tarball_path = os.path.join(scratch, 'link.tar.gz')
        os.symlink(path, link_path)
        with tarfile.open(tarball_path, mode='x:gz') as archive:
            archive.add(link_path, arcname='link')
        return copy_to_nix_store(v, tarball_path)
146
147
class AliasSearchPath(NamedTuple):
    """A channel that re-exports another channel under a different name."""
    alias_of: str  # section name of the channel being aliased

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin, so this always yields an empty AliasPin."""
        return AliasPin()
2fa9cbea
SW
154
155
0afcdb2a
SW
class SymlinkSearchPath(NamedTuple):
    """A channel that is a symlink to a local path."""
    path: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Symlinks have nothing to pin, so this always yields an empty SymlinkPin."""
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        """Wrap the symlink in a store tarball and return its store path."""
        return symlink_archive(v, self.path)
0afcdb2a
SW
165
166
class GitSearchPath(NamedTuple):
    """A channel built directly from a ref of a git repository."""
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch the ref and pin its current revision.

        If a previous pin exists, the new revision must descend from it.
        """
        _, revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        name = git_revision_name(v, self, revision)
        return GitPin(release_name=name, git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return the store path of a tarball of the pinned revision."""
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(self.git_repo, self.git_ref, pin.git_revision)
        return git_get_tarball(v, self, pin)
b3ea39b8 184
b3ea39b8 185
567a6783
SW
class ChannelSearchPath(NamedTuple):
    """A Nix channel, cross-checked against the git repository it is built from."""
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Resolve the channel's current release and return a pin for it.

        If the channel still names the previously pinned git revision, the old
        pin is returned unchanged.  Otherwise the following happens:
          * the release's resources are fetched and verified;
          * the new revision must descend from the old one;
          * the tarball contents are compared against git.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        html, final_url = fetch_channel(v, self)
        entries, latest = parse_channel(v, html)
        if old_pin is not None and old_pin.git_revision == latest.git_revision:
            return old_pin
        fetch_resources(v, latest, final_url, entries)
        git_cache.ensure_rev_available(
            self.git_repo, self.git_ref, latest.git_revision)
        if old_pin is not None:
            verify_git_ancestry(
                v, self, old_pin.git_revision, latest.git_revision)
        check_channel_contents(v, self, entries, latest)
        nixexprs = entries['nixexprs.tar.xz']
        return ChannelPin(
            release_name=latest.release_name,
            tarball_url=nixexprs.absolute_url,
            tarball_sha256=nixexprs.digest,
            git_revision=latest.git_revision)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Download the pinned nixexprs tarball, checking its recorded digest."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
218
219
0afcdb2a
SW
# Every kind of search path a channels file can describe.
SearchPath = Union[AliasSearchPath, SymlinkSearchPath, GitSearchPath, ChannelSearchPath]
# The search-path kinds backed by a git repository (they have git_repo/git_ref),
# which the git-based tarball helpers accept.
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
b3ea39b8 225
4a82be40 226
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file by file (by content).

    The file set is the union of regular files and symlinked directories
    found under *a* and *b*, excluding anything under ``.git/``.  Returns
    ``filecmp.cmpfiles``' (match, mismatch, errors) lists.  Errors include
    files that could not be compared, e.g. because one side lacks them.
    Walk errors propagate.
    """
    def raise_error(err: OSError) -> None:
        raise err

    def rel_join(base: str, name: str) -> str:
        if base == '.':
            return name
        return os.path.join(base, name)

    def walk_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for dirpath, subdirs, filenames in os.walk(root, onerror=raise_error):
            relative = os.path.relpath(dirpath, start=root)
            found.extend(rel_join(relative, name) for name in filenames)
            # os.walk reports symlinks to directories among the directory
            # names (without descending); record them as entries too.
            found.extend(rel_join(relative, name) for name in subdirs
                         if os.path.islink(rel_join(dirpath, name)))
        return found

    def outside_git(paths: Iterable[str]) -> Iterable[str]:
        return (p for p in paths if not p.startswith('.git/'))

    in_a = set(outside_git(walk_files(a)))
    in_b = set(outside_git(walk_files(b)))
    return filecmp.cmpfiles(a, b, in_a | in_b, shallow=False)
253
254
567a6783
SW
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel's landing page.

    Returns (page body, URL the page was finally served from).  The request
    must succeed with status 200, and it must have been redirected: the final
    URL has to differ from ``channel.channel_url``.
    """
    url = channel.channel_url
    v.status(f'Fetching channel from {url}')
    with urllib.request.urlopen(url, timeout=10) as response:
        body = response.read().decode()
        final_url = response.geturl()
        v.result(response.status == 200)
    v.check('Got forwarded', url != final_url)
    return body, final_url
2f96f32a
SW
264
265
567a6783
SW
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Parse a channel landing page into its file table and git pin.

    The page is parsed as XML.  The release name is the third
    whitespace-separated word of both <title> and <h1>, which must agree.
    The git revision is the text of the first <tt>, whose preceding text must
    be exactly "Git commit ".  Each <tr> after the first becomes a table
    entry; its cells are (link to file, size, digest).
    """
    v.status('Parsing channel description as XML')
    dom = xml.dom.minidom.parseString(channel_html)
    v.ok()

    v.status('Extracting release name:')
    release_name = dom.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    heading_name = dom.getElementsByTagName(
        'h1')[0].firstChild.nodeValue.split()[2]
    v.status(release_name)
    v.result(release_name == heading_name)

    v.status('Extracting git commit:')
    commit_node = dom.getElementsByTagName('tt')[0]
    revision = commit_node.firstChild.nodeValue
    v.status(revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    for row in dom.getElementsByTagName('tr')[1:]:  # first row is the header
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        href = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=href, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=release_name, git_revision=revision)
2f96f32a
SW
297
298
dc038df0
SW
def digest_string(s: bytes) -> Digest16:
    """Return the hex (base16) SHA-256 digest of *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
301
302
73bec7e8
SW
def digest_file(filename: str) -> Digest16:
    """Return the hex (base16) SHA-256 digest of the file at *filename*.

    The file is read in 4096-byte chunks so large files are not loaded whole.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while chunk := stream.read(4096):
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
310
311
b1e8c1b0
SW
@functools.lru_cache
def _experimental_flag_needed(v: Verification) -> bool:
    """Report whether the installed `nix` mentions --experimental-features.

    Decided by searching the output of `nix --help`.  The result is memoized
    per Verification object by lru_cache.
    """
    v.status('Checking Nix version')
    help_run = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
    v.result(help_run.returncode == 0)
    return b'--experimental-features' in help_run.stdout
1d48f551
SW
318
319
def _nix_command(v: Verification) -> List[str]:
    """Return the argv prefix for running `nix` subcommands.

    Adds ``--experimental-features nix-command`` when this nix supports it.
    """
    if not _experimental_flag_needed(v):
        return ['nix']
    return ['nix', '--experimental-features', 'nix-command']
1d48f551
SW
323
324
73bec7e8
SW
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 with `nix to-base16`."""
    v.status('Converting digest to base16')
    argv = _nix_command(v) + ['to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest16(converted.stdout.decode().strip())
335
336
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 with `nix to-base32`."""
    v.status('Converting digest to base32')
    argv = _nix_command(v) + ['to-base32', '--type', 'sha256', digest16]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest32(converted.stdout.decode().strip())
347
348
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* into the Nix store and return its store path.

    The expected *digest* is handed to nix-prefetch-url.  Afterwards, both the
    digest nix reports and an independent SHA-256 of the stored file are
    checked against it.
    """
    v.status(f'Fetching {url}')
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # Output is exactly two newline-terminated lines: the base32 digest, then
    # the store path.
    reported_digest, path, trailing = prefetch.stdout.decode().split('\n')
    assert trailing == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(reported_digest)) == digest)
    v.status(f"Verifying digest of {path}")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
2f96f32a 365
73bec7e8 366
9343cf48
SW
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download and verify the channel's git-revision and nixexprs files.

    Fills in ``absolute_url`` and ``file`` on the corresponding *table*
    entries, then checks that the downloaded git-revision file names the
    same commit as *pin*.
    """
    for resource in ('git-revision', 'nixexprs.tar.xz'):
        entry = table[resource]
        # Resolve the table's link against the post-redirect page URL.
        entry.absolute_url = urllib.parse.urljoin(forwarded_url, entry.url)
        entry.file = fetch_with_nix_prefetch_url(
            v, entry.absolute_url, entry.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    with open(table['git-revision'].file, encoding='utf-8') as revision_file:
        recorded = revision_file.read(999)
    v.result(recorded == pin.git_revision)
2f96f32a 380
971d3659 381
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Path of the cache entry recording the store path of *pin*'s tarball.

    The entry name combines a SHA-256 of the repository URL, the revision,
    and the release name.
    """
    repo_key = digest_string(channel.git_repo.encode())
    entry_name = f'{repo_key}-{pin.git_revision}-{pin.release_name}'
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry_name)
971d3659
SW
387
388
d7cfdb22
SW
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Fail verification unless *old_revision* is an ancestor of *new_revision*.

    Uses `git merge-base --is-ancestor` in the channel's cached repository.
    """
    repo_dir = git_cache.git_cachedir(channel.git_repo)
    v.status(f'Verifying rev is an ancestor of previous rev {old_revision}')
    merge_base = subprocess.run(
        ['git', '-C', repo_dir, 'merge-base', '--is-ancestor',
         old_revision, new_revision])
    v.result(merge_base.returncode == 0)
404
dc038df0 405
925c801b
SW
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the unpacked channel tarball agrees with the git checkout.

    The tarball tree is the release-name directory under *channel_contents*.
    The following must hold:
      * at least one file matches;
      * no file differs;
      * the only incomparable files are a fixed list of expected extras,
        and all of those must be present.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check(f'{len(match)} files match', len(match) > 0)
    v.check(f'{len(mismatch)} files differ', len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    for name in benign_errors:
        errors.remove(name)
    v.check(f'{len(errors)} unexpected incomparable files', len(errors) == 0)
    v.check(
        f'({len(benign_errors)} of {len(expected_errors)} expected incomparable files)',
        len(benign_errors) == len(expected_errors))
432
433
7d2c278f
SW
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the already-downloaded nixexprs tarball into *dest*."""
    tarball = table['nixexprs.tar.xz'].file
    v.status(f"Extracting tarball {tarball}")
    shutil.unpack_archive(tarball, dest)
    v.ok()
441
442
7d2c278f
SW
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Write the pinned revision's tree into *dest* via `git archive | tar x`."""
    v.status('Checking out corresponding git revision')
    repo_dir = git_cache.git_cachedir(channel.git_repo)
    archive_cmd = ['git', '-C', repo_dir, 'archive', pin.git_revision]
    untar_cmd = ['tar', 'x', '-C', dest, '-f', '-']
    with subprocess.Popen(archive_cmd, stdout=subprocess.PIPE) as git:
        with subprocess.Popen(untar_cmd, stdin=git.stdout) as tar:
            # Drop our copy of the pipe's read end so git gets SIGPIPE,
            # rather than blocking, if tar exits early.
            if git.stdout:
                git.stdout.close()
            tar.wait()
        git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
925c801b
SW
458
459
9343cf48
SW
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return the store path of an xz-compressed tarball of *pin*'s revision.

    The tarball's top-level directory is named after the release.  An on-disk
    cache (see tarball_cache_file) maps (repo, revision, release name) to the
    store path.  A cached path is reused only while it still exists.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file, encoding='utf-8') as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this file's descriptor, so
        # open it in binary mode.
        with open(output_filename, 'wb') as output_file:
            v.status(f'Generating tarball for git revision {pin.git_revision}')
            with subprocess.Popen(
                    ['git', '-C', git_cache.git_cachedir(channel.git_repo),
                     'archive', f'--prefix={pin.release_name}/', pin.git_revision],
                    stdout=subprocess.PIPE) as git:
                with subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file) as xz:
                    # Close our copy of the pipe's read end.  Otherwise, if
                    # xz exits early, git blocks forever writing into a full
                    # pipe and git.wait() below never returns.
                    if git.stdout:
                        git.stdout.close()
                    xz.wait()
                git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        store_tarball = copy_to_nix_store(v, output_filename)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w', encoding='utf-8') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
736c25eb
SW
491
492
f9cd7bdc
SW
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Verify the metadata files inside the unpacked channel tarball.

    ``.git-revision`` must equal the pinned revision, and ``.version-suffix``
    must be a suffix of the release name.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    def read_metadata(name: str) -> str:
        with open(os.path.join(release_dir, name), encoding='utf-8') as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == pin.git_revision)

    v.status(
        f'Verifying version-suffix is a suffix of release name {pin.release_name}:')
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
f9cd7bdc
SW
509
510
7d2c278f
SW
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify that the channel tarball matches git at the pinned revision.

    Both are unpacked into scratch directories that are deleted afterwards.
    """
    with tempfile.TemporaryDirectory() as tarball_dir, \
            tempfile.TemporaryDirectory() as checkout_dir:
        extract_tarball(v, table, tarball_dir)
        check_channel_metadata(v, pin, tarball_dir)
        git_checkout(v, channel, pin, checkout_dir)
        compare_tarball_and_git(v, pin, tarball_dir, checkout_dir)
        v.status('Removing temporary directories')
    v.ok()
528
529
d7cfdb22
SW
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name for *git_revision*.

    The name has the form ``<repo basename>-<%ct>-<%h>``: the committer Unix
    timestamp and the abbreviated hash (``--abbrev=11``) from `git log`.
    """
    v.status('Getting commit date')
    log_run = subprocess.run(
        ['git', '-C', git_cache.git_cachedir(channel.git_repo), 'log', '-n1',
         '--format=%ct-%h', '--abbrev=11', '--no-show-signature', git_revision],
        stdout=subprocess.PIPE)
    v.result(log_run.returncode == 0 and log_run.stdout != b'')
    suffix = log_run.stdout.decode().strip()
    return f'{os.path.basename(channel.git_repo)}-{suffix}'
e3cae769
SW
547
548
567a6783
SW
K = TypeVar('K')  # key type for the dict helpers below
V = TypeVar('V')  # value type for the dict helpers below


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into two dicts.

    Returns (entries where pred(key, value) is true, all other entries).
    Both keep *d*'s iteration order.
    """
    selected: Dict[K, V] = {}
    remaining: Dict[K, V] = {}
    for key, value in d.items():
        (selected if pred(key, value) else remaining)[key] = value
    return selected, remaining


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (entries whose key is in *fields*, all other entries)."""
    return partition_dict(lambda key, _: key in fields, d)
568
569
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path, and its pin if any, from one config section.

    The section's ``type`` key selects the classes.  Keys that name pin
    fields go to the pin; every other key goes to the search path.  A pin is
    returned only if some pin field is present, or if the pin type has no
    fields at all.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    search_path_type, pin_type = mapping[conf['type']]
    _, settings = filter_dict(dict(conf.items()), {'type'})
    pin_fields, search_path_fields = filter_dict(settings, set(pin_type._fields))
    pin: Optional[Pin] = None
    if pin_fields or pin_type._fields == ():
        # Error suppression works around https://github.com/python/mypy/issues/9007
        pin = pin_type(**pin_fields)  # type: ignore
    return search_path_type(**search_path_fields), pin
f8f5b125
SW
585
586
e8bd4979
SW
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but the section must already be pinned.

    Raises:
        Exception: if section *section* carries no pin.
    """
    search_path, pin = read_config_section(conf)
    if pin is not None:
        return search_path, pin
    raise Exception(
        f'Cannot update unpinned channel "{section}" (Run "pin" before "update")')
594
595
01ba0eb2
SW
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename* (read as UTF-8)."""
    config = configparser.ConfigParser()
    with open(filename, encoding='utf-8') as f:
        config.read_file(f, filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several channels files into one mapping.

    Sections are keyed by name, in file order and then section order.

    Raises:
        Exception: if the same section (channel) name appears in more than
            one of the files.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                raise Exception(f'Duplicate channel "{section}"')
            merged_config[section] = config[section]
    return merged_config
613
614
def pinCommand(args: argparse.Namespace) -> None:
    """Implement `pinch pin`: refresh the pins stored in args.channels_file.

    Only the sections named in args.channels are re-pinned; if none are
    named, all sections are.  The file is then rewritten in place.
    """
    v = Verification()
    config = read_config(args.channels_file)
    requested = set(args.channels)
    for section in config.sections():
        if requested and section not in requested:
            continue
        search_path, old_pin = read_config_section(config[section])
        new_pin = search_path.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w', encoding='utf-8') as configfile:
        config.write(configfile)
2f96f32a
SW
628
629
def updateCommand(args: argparse.Namespace) -> None:
    """Implement `pinch update`: install every pinned channel into a profile.

    Non-alias channels are fetched, in section-name order, to Nix store
    paths.  Alias channels reuse the expression of the channel they name.
    All channels are then installed with one `nix-env --install --remove-all`
    run.  With --dry-run the command is printed instead of executed.
    """
    v = Verification()
    exprs: Dict[str, str] = {}

    manifest = os.path.join(args.profile, "manifest.nix")
    search_paths: List[str] = []
    if os.path.exists(manifest):
        # Expose the current profile and its manifest as -I search-path entries.
        search_paths = [
            "-I", "pinch_profile=" + args.profile,
            "-I", "pinch_profile_manifest=" + os.readlink(manifest),
        ]

    pinned = {
        section: read_pinned_config_section(section, conf)
        for section, conf in read_config_files(args.channels_file).items()}
    aliases, channels = partition_dict(
        lambda _section, entry: isinstance(entry[0], AliasSearchPath), pinned)

    for section, (search_path, pin) in sorted(channels.items()):
        # partition_dict() hides the split from mypy, so narrow explicitly.
        assert not isinstance(search_path, AliasSearchPath)
        assert not isinstance(pin, AliasPin)
        tarball = search_path.fetch(v, pin)
        search_paths += ["-I", f"pinch_tarball_for_{pin.release_name}={tarball}"]
        # "%s" is deliberately left in place for channelName.  It is filled
        # with the section name when the command is assembled, so an alias
        # gets its own channel name.
        # NOTE(review): a '%' in release_name would break that formatting.
        exprs[section] = (
            f'f: f {{ name = "{pin.release_name}"; channelName = "%s"; '
            f'src = builtins.storePath "{tarball}"; }}')

    for section, (search_path, _) in aliases.items():
        assert isinstance(search_path, AliasSearchPath)  # For mypy
        # NOTE(review): alias_of must name a section already in exprs.  An
        # alias of an alias works only if the target alias came earlier;
        # otherwise this raises KeyError.
        exprs[section] = exprs[search_path.alias_of]

    command = [
        'nix-env',
        '--profile', args.profile,
        '--show-trace',
        '--file', '<nix/unpack-channel.nix>',
        '--install',
        '--remove-all',
        *search_paths,
        '--from-expression',
        *(exprs[name] % name for name in sorted(exprs)),
    ]
    if args.dry_run:
        print(' '.join(shlex.quote(word) for word in command))
        return
    v.status('Installing channels with nix-env')
    installed = subprocess.run(command)
    v.result(installed.returncode == 0)
736c25eb
SW
676
677
0e5e611d
SW
def main() -> None:
    """Command-line entry point: dispatch to the `pin` or `update` subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = f'/nix/var/nix/profiles/per-user/{getpass.getuser()}/channels'
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    args.func(args)


if __name__ == '__main__':
    main()