]> git.scottworley.com Git - pinch/blame - pinch.py
Verify unknown config fields raise errors
[pinch] / pinch.py
CommitLineData
0e5e611d 1import argparse
f15e458d 2import configparser
2f96f32a
SW
3import filecmp
4import functools
736c25eb 5import getpass
2f96f32a
SW
6import hashlib
7import operator
8import os
9import os.path
9a78329e 10import shlex
2f96f32a 11import shutil
73bec7e8 12import subprocess
9d7844bb 13import sys
2f96f32a 14import tempfile
89e79125 15import types
2f96f32a
SW
16import urllib.parse
17import urllib.request
18import xml.dom.minidom
19
20from typing import (
9d2c406b 21 Callable,
2f96f32a
SW
22 Dict,
23 Iterable,
24 List,
7f4c3ace 25 Mapping,
0a4ff8dd 26 NamedTuple,
73bec7e8 27 NewType,
d7cfdb22 28 Optional,
567a6783 29 Set,
2f96f32a 30 Tuple,
7f4c3ace 31 Type,
567a6783 32 TypeVar,
0a4ff8dd 33 Union,
2f96f32a
SW
34)
35
3603dde2
SW
# Use xdg module when it's less painful to have as a dependency


class XDG(NamedTuple):
    """The XDG base-directory settings this program consults."""

    # Root directory for per-user cache data.
    XDG_CACHE_HOME: str


# Honour $XDG_CACHE_HOME when it is set, otherwise fall back to ~/.cache.
# Evaluated once, at import time.
xdg = XDG(XDG_CACHE_HOME=os.getenv('XDG_CACHE_HOME',
                                   os.path.expanduser('~/.cache')))
26125a28
SW
47
48
2f96f32a
SW
class VerificationError(Exception):
    """Raised by Verification.result() when a reported check fails."""
51
52
class Verification:
    """Progress reporter for a sequence of checks, written to stderr.

    Each check is announced with status() and concluded with result(), which
    right-aligns a coloured OK/FAIL marker on the terminal line and raises
    VerificationError when the check failed.
    """

    def __init__(self) -> None:
        # Characters printed on the current line so far; used to right-align
        # the result marker.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print s followed by a space, without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # NOTE: len() counts code points, not terminal columns (Unicode??).
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in the ANSI SGR escape sequence for colour code c."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK or FAIL.

        Raises:
            VerificationError: if r is false.
        """
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(message)
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce s and immediately report r."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current line."""
        self.result(True)
82
83
faff8642
SW
# sha256 digests in two encodings: hex (base 16) and Nix's base 32.  Distinct
# NewTypes keep mypy from letting one be passed where the other is expected.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)


class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's file table, plus download results.

    parse_channel() fills in url, digest and size; fetch_resources() later
    adds absolute_url (url resolved against the redirected page URL) and
    file (the downloaded Nix store path).
    """
    absolute_url: str
    digest: Digest16
    file: str
    size: int
    url: str
94
95
0a4ff8dd
SW
class AliasPin(NamedTuple):
    """Pin for an alias channel; it carries no state of its own."""


class GitPin(NamedTuple):
    """Pin of a git-backed search path to one commit."""
    # Commit id the channel is pinned to.
    git_revision: str
    # Human-readable name; also the tarball's top-level directory.
    release_name: str


class ChannelPin(NamedTuple):
    """Pin of a published channel: its commit and its verified tarball."""
    git_revision: str
    release_name: str
    # Absolute URL of the channel's nixexprs.tar.xz.
    tarball_url: str
    # Hex-encoded sha256 of that tarball.
    tarball_sha256: str


# Any pin; which one is used depends on the section's search-path type.
Pin = Union[AliasPin, GitPin, ChannelPin]
113
114
class AliasSearchPath(NamedTuple):
    """A channel that re-exports the channel named by alias_of."""
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin; always returns an empty AliasPin."""
        return AliasPin()
2fa9cbea
SW
121
122
class GitSearchPath(NamedTuple):
    """A channel that tracks branch git_ref of the repository git_repo."""
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch git_ref and pin to its current head.

        When old_pin is given it must be a GitPin; git_fetch then verifies
        that the new head descends from the old pinned revision.
        """
        previous: Optional[str] = None
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            previous = old_pin.git_revision

        revision = git_fetch(v, self, None, previous)
        name = git_revision_name(v, self, revision)
        return GitPin(release_name=name, git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return a Nix store path of a tarball of the pinned revision."""
        assert isinstance(pin, GitPin)
        ensure_git_rev_available(v, self, pin, None)
        return git_get_tarball(v, self, pin)
b3ea39b8 140
b3ea39b8 141
567a6783
SW
class ChannelSearchPath(NamedTuple):
    """A published channel page, cross-checked against its git repository."""
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Fetch and verify the channel's current release, and pin it.

        When old_pin is given it must be a ChannelPin; the ancestry checks
        then also require the new revision to descend from the old one.
        """
        previous: Optional[str] = None
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)
            previous = old_pin.git_revision

        html, final_url = fetch_channel(v, self)
        table, gitpin = parse_channel(v, html)
        fetch_resources(v, gitpin, final_url, table)
        ensure_git_rev_available(v, self, gitpin, previous)
        check_channel_contents(v, self, table, gitpin)
        tarball = table['nixexprs.tar.xz']
        return ChannelPin(
            git_revision=gitpin.git_revision,
            release_name=gitpin.release_name,
            tarball_url=tarball.absolute_url,
            tarball_sha256=tarball.digest)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Download the pinned tarball (digest-checked); return its store path."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
169
170
# Every kind of configured channel.
SearchPath = Union[
    AliasSearchPath,
    GitSearchPath,
    ChannelSearchPath,
]
# The channels backed by a git repository (they have git_repo and git_ref),
# from which a tarball can be produced.
TarrableSearchPath = Union[
    GitSearchPath,
    ChannelSearchPath,
]
b3ea39b8 173
4a82be40 174
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the file trees rooted at a and b by content.

    Files (and symlinks to directories) found under either root, except
    those under a top-level .git/, are compared with
    filecmp.cmpfiles(shallow=False).  Returns (match, mismatch, errors) as
    paths relative to the roots; errors holds paths that could not be
    compared, e.g. ones present on only one side.  Any error while walking
    a tree is raised.
    """

    def reraise(error: OSError) -> None:
        raise error

    def relative_join(prefix: str, name: str) -> str:
        if prefix == '.':
            return name
        return os.path.join(prefix, name)

    def walk_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for dirpath, subdirs, filenames in os.walk(root, onerror=reraise):
            prefix = os.path.relpath(dirpath, start=root)
            found.extend(relative_join(prefix, name) for name in filenames)
            # os.walk reports symlinks to directories in subdirs without
            # descending into them; compare them as plain entries instead.
            found.extend(relative_join(prefix, name) for name in subdirs
                         if os.path.islink(relative_join(dirpath, name)))
        return found

    def comparable(root: str) -> Set[str]:
        return set(p for p in walk_files(root) if not p.startswith('.git/'))

    candidates = comparable(a) | comparable(b)
    return filecmp.cmpfiles(a, b, candidates, shallow=False)
201
202
567a6783
SW
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel page and find the URL it was redirected to.

    Verifies that the HTTP status is 200 and that the final URL differs
    from channel.channel_url (i.e. the request was forwarded).

    Returns:
        (page body, final URL).  NOTE: the body is the raw bytes returned by
        HTTPResponse.read(), despite the str annotation.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # even if reading the body fails.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel_html = request.read()
        forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url
2f96f32a
SW
212
213
567a6783
SW
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Extract the release name, git commit and file table from a channel page.

    The page is parsed as XML.  The release name is the third
    whitespace-separated word of the first <title>, and must equal the same
    word of the first <h1>.  The git commit is the text of the first <tt>,
    whose preceding text node must be 'Git commit '.  Every <tr> after the
    first (header) row yields a ChannelTableEntry keyed by its file name,
    with cells (link, size, digest).
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel_html)
    v.ok()

    v.status('Extracting release name:')
    titles = doc.getElementsByTagName('title')
    headings = doc.getElementsByTagName('h1')
    title_name = titles[0].firstChild.nodeValue.split()[2]
    h1_name = headings[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    commit_node = doc.getElementsByTagName('tt')[0]
    git_revision = commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = commit_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    rows = doc.getElementsByTagName('tr')[1:]
    for row in rows:
        cells = row.childNodes
        link = cells[0].firstChild
        file_name = link.firstChild.nodeValue
        href = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        table[file_name] = ChannelTableEntry(url=href, digest=digest, size=size)
    v.ok()
    pin = GitPin(release_name=title_name, git_revision=git_revision)
    return table, pin
2f96f32a
SW
245
246
dc038df0
SW
def digest_string(s: bytes) -> Digest16:
    """Return the hex-encoded sha256 digest of the byte string s."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
249
250
73bec7e8
SW
def digest_file(filename: str) -> Digest16:
    """Return the hex sha256 digest of a file, reading it in 4096-byte chunks."""
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
258
259
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base-32 sha256 digest to hex with `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest16(converted)
266
267
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex sha256 digest to base 32 with `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    converted = completed.stdout.decode().strip()
    return Digest32(converted)
274
275
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download url into the Nix store and return its store path.

    The expected digest is passed to nix-prefetch-url.  It is then checked
    against the base-32 digest that nix-prefetch-url prints, and against a
    fresh sha256 of the stored file.
    """
    v.status('Fetching %s' % url)
    fetched = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(fetched.returncode == 0)
    # Output is "<base-32 digest>\n<store path>\n".
    output_lines = fetched.stdout.decode().split('\n')
    prefetch_digest, path, trailer = output_lines
    assert trailer == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
2f96f32a 292
73bec7e8 293
9343cf48
SW
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each entry's url is resolved against forwarded_url (the redirected
    channel page URL) into absolute_url.  The file is fetched into the Nix
    store with its digest verified, and its store path recorded as file.
    Finally, the downloaded git-revision must match the commit shown on the
    page (pin.git_revision).
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaving it to the garbage collector.
    with open(table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == pin.git_revision)
2f96f32a 306
971d3659 307
def git_cachedir(git_repo: str) -> str:
    """Return the local cache directory for git_repo.

    Caches are keyed by the sha256 of the repository URL and live under
    $XDG_CACHE_HOME/pinch/git.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
313
314
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the path of the file recording the generated tarball's store path.

    The name combines the repository URL's sha256, the pinned revision and
    the release name, under $XDG_CACHE_HOME/pinch/git-tarball.
    """
    repo_key = digest_string(channel.git_repo.encode())
    leaf = '%s-%s-%s' % (repo_key, pin.git_revision, pin.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', leaf)
971d3659
SW
323
324
d7cfdb22
SW
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        new_revision: str,
        old_revision: Optional[str]) -> None:
    """Check that new_revision is on channel.git_ref and moves forward.

    new_revision must be an ancestor of (i.e. reachable from) git_ref in the
    local cache.  When old_revision is given, it must in turn be an ancestor
    of new_revision, so the update is a fast-forward.  Any non-zero exit
    from `git merge-base --is-ancestor` counts as a failure.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  ancestor,
                                  descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(new_revision, channel.git_ref))

    if old_revision is not None:
        # The check is "old is an ancestor of new"; the message now says so
        # (it previously stated the relationship backwards).
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            old_revision)
        v.result(is_ancestor(old_revision, new_revision))
9836141c 353
2f96f32a 354
d7cfdb22
SW
def git_fetch(
        v: Verification,
        channel: TarrableSearchPath,
        desired_revision: Optional[str],
        old_revision: Optional[str]) -> str:
    """Fetch channel.git_ref into the local bare cache and return its head.

    The cache repository is created on first use.  If desired_revision is
    given, it must be present after the fetch.  The new head is checked with
    verify_git_ancestry (against git_ref and, if given, old_revision) before
    being returned.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if desired_revision is not None:
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', desired_revision])
        v.result(process.returncode == 0)

    # Ask git for the ref's value rather than reading refs/heads/<ref>
    # directly: that loose file no longer exists once git packs its refs
    # (e.g. during an automatic gc), whereas rev-parse handles both forms.
    v.status('Reading new head of ref "%s"' % channel.git_ref)
    process = subprocess.run(
        ['git', '-C', cachedir, 'rev-parse', '--verify',
         'refs/heads/%s' % channel.git_ref],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    new_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel, new_revision, old_revision)

    return new_revision
dc038df0 401
971d3659 402
7d2c278f
SW
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        old_revision: Optional[str]) -> None:
    """Make sure pin.git_revision is present in the local git cache.

    If the cache already holds the commit, only the ancestry checks are
    run.  Otherwise the ref is fetched with git_fetch, which also runs them.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        lookup = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision])
        code = lookup.returncode
        # Exit status 0 is treated as "present" and 1 as "absent"; any other
        # status fails verification.
        answer = {0: 'yes', 1: 'no'}.get(code)
        if answer is not None:
            v.status(answer)
        v.result(code in (0, 1))
        if code == 0:
            verify_git_ancestry(v, channel, pin.git_revision, old_revision)
            return
    git_fetch(v, channel, pin.git_revision, old_revision)
7d889b12 422
dc038df0 423
925c801b
SW
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify that the extracted channel tarball matches the git checkout.

    At least one file must match and none may differ.  The only paths
    allowed to be incomparable are the expected ones listed below, and all
    of them must be.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # Paths that are expected to be incomparable between the two trees.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    unexpected = [name for name in errors if name not in benign_errors]
    v.check(
        '%d unexpected incomparable files' %
        len(unexpected),
        len(unexpected) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))
455
456
7d2c278f
SW
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the downloaded nixexprs.tar.xz into dest."""
    archive = table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % archive)
    shutil.unpack_archive(archive, dest)
    v.ok()
464
465
7d2c278f
SW
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Extract the tree of pin.git_revision into dest (git archive | tar x)."""
    v.status('Checking out corresponding git revision')
    archive_command = ['git', '-C', git_cachedir(channel.git_repo),
                       'archive', pin.git_revision]
    git = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    tar = subprocess.Popen(['tar', 'x', '-C', dest, '-f', '-'],
                           stdin=git.stdout)
    # Drop our copy of the pipe's read end so only tar holds it.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    both_ok = git.returncode == 0 and tar.returncode == 0
    v.result(both_ok)
485
486
9343cf48
SW
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return a Nix store path of a .tar.xz of pin.git_revision.

    The tarball's top-level directory is pin.release_name.  Its store path
    is remembered in tarball_cache_file().  If that file names a path that
    still exists, the path is returned without regenerating the tarball.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as cache:
            cached_tarball = cache.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # xz writes compressed bytes straight to this file's descriptor.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # As in git_checkout: drop our copy of the pipe's read end, so
            # git gets SIGPIPE instead of blocking forever if xz exits early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    # Close (and therefore flush) the cache entry deterministically.
    with open(cache_file, 'w') as cache:
        cache.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
736c25eb
SW
525
526
f9cd7bdc
SW
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check the metadata files inside the extracted channel tarball.

    <release_name>/.git-revision must equal pin.git_revision, and
    <release_name>/.version-suffix must be a suffix of pin.release_name.
    Only the first 999 characters of each file are read.
    """

    def read_metadata(name: str) -> str:
        path = os.path.join(channel_contents, pin.release_name, name)
        # Close the file instead of leaking the handle.
        with open(path) as metadata:
            return metadata.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
f9cd7bdc
SW
549
550
7d2c278f
SW
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Cross-check the downloaded channel tarball against git.

    Unpacks the tarball and a git checkout of pin.git_revision into two
    temporary directories.  It then verifies the tarball's metadata files
    and compares the two trees file by file.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:
        extract_tarball(v, table, channel_contents)
        check_channel_metadata(v, pin, channel_contents)
        git_checkout(v, channel, pin, git_contents)
        compare_tarball_and_git(v, pin, channel_contents, git_contents)
        # The directories are deleted when the with-block exits.
        v.status('Removing temporary directories')
    v.ok()
568
569
d7cfdb22
SW
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Name a revision as '<repo basename>-<commit time>-<abbreviated hash>'.

    The commit time is git's %ct (committer date as a Unix timestamp), and
    the hash is abbreviated with --abbrev=11.
    """
    v.status('Getting commit date')
    log = subprocess.run(['git',
                          '-C',
                          git_cachedir(channel.git_repo),
                          'log',
                          '-n1',
                          '--format=%ct-%h',
                          '--abbrev=11',
                          '--no-show-signature',
                          git_revision],
                         stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    repo_basename = os.path.basename(channel.git_repo)
    date_and_hash = log.stdout.decode().strip()
    return '%s-%s' % (repo_basename, date_and_hash)
588
589
567a6783
SW
K = TypeVar('K')
V = TypeVar('V')


def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split d into (items where pred(key, value) is true, all other items).

    pred is called exactly once per item, in d's iteration order, and both
    results keep that order.
    """
    selected: Dict[K, V] = {}
    remaining: Dict[K, V] = {}
    for key, value in d.items():
        destination = selected if pred(key, value) else remaining
        destination[key] = value
    return selected, remaining


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split d into (items whose key is in fields, all other items)."""
    return partition_dict(lambda key, _value: key in fields, d)
609
610
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build a search path and, if present, its pin from one config section.

    The section's 'type' selects the SearchPath and Pin classes.  Keys that
    name fields of that Pin class form the pin; all remaining keys (other
    than 'type') are passed to the SearchPath constructor, so an unknown key
    makes that constructor raise TypeError.  A section with no pin fields is
    unpinned (pin is None), unless its Pin class has no fields at all.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
    }
    search_path_class, pin_class = mapping[conf['type']]
    _, all_fields = filter_dict(dict(conf.items()), {'type'})
    pin_fields, sp_fields = filter_dict(all_fields, set(pin_class._fields))
    has_pin = bool(pin_fields) or not pin_class._fields
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin = pin_class(**pin_fields) if has_pin else None  # type:ignore[call-arg]
    return search_path_class(**sp_fields), pin
f8f5b125
SW
625
626
e8bd4979
SW
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but the section must already be pinned.

    Raises:
        Exception: if the section named `section` has no pin.
    """
    sp, pin = read_config_section(conf)
    if pin is not None:
        return sp, pin
    raise Exception(
        'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
        section)
635
636
01ba0eb2
SW
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse one channels file."""
    config = configparser.ConfigParser()
    # Close the file once parsed instead of leaking the handle.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config


def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several channels files into one mapping.

    Raises:
        Exception: if the same section (channel) name appears in more than
            one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for file in filenames:
        config = read_config(file)
        for section in config.sections():
            if section in merged_config:
                raise Exception('Duplicate channel "%s"' % section)
            merged_config[section] = config[section]
    return merged_config
653
654
def pinCommand(args: argparse.Namespace) -> None:
    """Pin channels in args.channels_file and write the pins back.

    Only the sections named in args.channels are pinned, or every section
    when that list is empty.  The file is rewritten once, after all
    selected sections have been pinned.
    """
    v = Verification()
    config = read_config(args.channels_file)
    selected = [section for section in config.sections()
                if not args.channels or section in args.channels]
    for section in selected:
        sp, old_pin = read_config_section(config[section])
        new_pin = sp.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
2f96f32a
SW
668
669
def updateCommand(args: argparse.Namespace) -> None:
    """Install the pinned channels into the user's nix-env channels profile.

    Every non-alias channel is fetched first.  Each alias then reuses the
    source of the channel it refers to, possibly through other aliases,
    under its own channelName.  With --dry-run the nix-env command is
    printed instead of run.

    Raises:
        Exception: if an alias refers to an unknown channel or aliases form
            a cycle.
    """
    v = Verification()
    config = {
        section: read_pinned_config_section(section, conf)
        for section, conf in read_config_files(args.channels_file).items()}
    alias, nonalias = partition_dict(
        lambda _section, entry: isinstance(entry[0], AliasSearchPath), config)

    # section name -> (release name, Nix store path of its tarball)
    sources: Dict[str, Tuple[str, str]] = {}
    for section, (sp, pin) in nonalias.items():
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        sources[section] = (pin.release_name, tarball)

    def resolve_alias(section: str) -> str:
        """Follow alias_of links from section to a non-alias channel."""
        chain = [section]
        target = section
        while target in alias:
            alias_sp = alias[target][0]
            assert isinstance(alias_sp, AliasSearchPath)  # For mypy
            target = alias_sp.alias_of
            if target in chain:
                raise Exception(
                    'Alias cycle: %s' % ' -> '.join(chain + [target]))
            chain.append(target)
        if target not in nonalias:
            raise Exception('Channel "%s" is an alias of unknown channel "%s"'
                            % (section, target))
        return target

    # Resolving to a non-alias channel makes alias processing order-independent.
    for section in alias:
        sources[section] = sources[resolve_alias(section)]

    def channel_expr(name: str) -> str:
        # Format in one step so '%' characters in the values are harmless.
        release_name, tarball = sources[name]
        return ('f: f { name = "%s"; channelName = "%s"; '
                'src = builtins.storePath "%s"; }' %
                (release_name, name, tarball))

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [channel_expr(name) for name in sorted(sources)]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
736c25eb
SW
708
709
0e5e611d
SW
def main() -> None:
    """Command-line entry point: dispatch to the `pin` or `update` subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    parsed = parser.parse_args()
    parsed.func(parsed)


if __name__ == '__main__':
    main()