]> git.scottworley.com Git - pinch/blame - pinch.py
Remove assert that's now handled by types
[pinch] / pinch.py
CommitLineData
0e5e611d 1import argparse
f15e458d 2import configparser
2f96f32a
SW
3import filecmp
4import functools
736c25eb 5import getpass
2f96f32a
SW
6import hashlib
7import operator
8import os
9import os.path
9a78329e 10import shlex
2f96f32a 11import shutil
73bec7e8 12import subprocess
9d7844bb 13import sys
2f96f32a 14import tempfile
89e79125 15import types
2f96f32a
SW
16import urllib.parse
17import urllib.request
18import xml.dom.minidom
19
20from typing import (
2f96f32a
SW
21 Dict,
22 Iterable,
23 List,
7f4c3ace 24 Mapping,
0a4ff8dd 25 NamedTuple,
73bec7e8 26 NewType,
d7cfdb22 27 Optional,
567a6783 28 Set,
2f96f32a 29 Tuple,
7f4c3ace 30 Type,
567a6783 31 TypeVar,
0a4ff8dd 32 Union,
2f96f32a
SW
33)
34
3603dde2
SW
35# Use xdg module when it's less painful to have as a dependency
36
37
class XDG(NamedTuple):
    """The subset of XDG Base Directory locations that pinch uses."""

    # Base directory for pinch's caches (git mirrors, tarball records).
    XDG_CACHE_HOME: str


def _xdg_cache_home() -> str:
    """Return the cache base directory per the XDG Base Directory spec.

    The spec treats an unset *or empty* $XDG_CACHE_HOME as "use the default"
    and says relative paths in these variables are invalid and must be
    ignored.  A plain os.getenv(name, default) would return '' for an empty
    variable, silently making every cache path relative to the working
    directory.
    """
    value = os.environ.get('XDG_CACHE_HOME', '')
    if value and os.path.isabs(value):
        return value
    return os.path.expanduser('~/.cache')


xdg = XDG(XDG_CACHE_HOME=_xdg_cache_home())
26125a28
SW
46
47
2f96f32a
SW
class VerificationError(Exception):
    """Raised by Verification.result() when a verification step fails."""
50
51
class Verification:
    """A running, human-readable log of verification steps on stderr.

    Each step is announced with status() and finished with result(), which
    right-aligns a coloured OK/FAIL marker and raises VerificationError when
    the step failed.
    """

    def __init__(self) -> None:
        # Characters printed so far on the current output line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* plus a trailing space, without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, which may not equal display width.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for colour code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK or FAIL; raise VerificationError on FAIL."""
        markers = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = markers[r]
        columns = shutil.get_terminal_size().columns or 80
        # Right-align the marker, accounting for lines that already wrapped.
        padding = (columns - (self.line_length + len(message))) % columns
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step *s* and immediately report outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report the current step as successful."""
        self.result(True)
81
82
faff8642
SW
# SHA-256 digest as a hex (base16) string, e.g. hashlib's hexdigest().
Digest16 = NewType('Digest16', str)
# SHA-256 digest in Nix's base32 form, as printed by nix-prefetch-url.
Digest32 = NewType('Digest32', str)
85
86
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page.

    parse_channel() sets url, digest and size; fetch_resources() later adds
    absolute_url and file.
    """
    absolute_url: str  # url resolved against the channel's forwarded URL
    digest: Digest16  # SHA-256 from the row's digest column
    file: str  # store path returned by nix-prefetch-url
    size: int  # integer from the row's size column
    url: str  # href exactly as it appears in the table
93
94
0a4ff8dd
SW
class AliasPin(NamedTuple):
    """Pin for an alias channel; an alias has no state of its own to pin."""


class GitPin(NamedTuple):
    """A pinned git revision plus the release name derived for it."""
    git_revision: str
    release_name: str


class ChannelPin(NamedTuple):
    """A pinned channel release: its git revision and nixexprs tarball."""
    git_revision: str
    release_name: str
    tarball_url: str
    tarball_sha256: str


# Any pin kind a config section can carry.
Pin = Union[AliasPin, GitPin, ChannelPin]
112
113
class AliasSearchPath(NamedTuple):
    """A channel that only names another channel (config key alias_of)."""
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Return the empty alias pin; nothing is fetched or verified."""
        pin = AliasPin()
        return pin
2fa9cbea
SW
120
121
class GitSearchPath(NamedTuple):
    """A channel that follows a ref (git_ref) of a git repository (git_repo)."""
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch git_ref and pin its current tip.

        When *old_pin* is given, the new revision must descend from its
        revision (checked inside git_fetch).
        """
        old_revision: Optional[str] = None
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            old_revision = old_pin.git_revision

        new_revision = git_fetch(v, self, None, old_revision)
        release_name = git_revision_name(v, self, new_revision)
        return GitPin(release_name=release_name, git_revision=new_revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return the Nix store path of a tarball of the pinned revision."""
        assert isinstance(pin, GitPin)
        ensure_git_rev_available(v, self, pin, None)
        tarball = git_get_tarball(v, self, pin)
        return tarball
b3ea39b8 139
b3ea39b8 140
567a6783
SW
class ChannelSearchPath(NamedTuple):
    """A Nix channel URL, cross-checked against the git repo it is built from."""
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Fetch the channel, verify it against git, and pin its tarball.

        When *old_pin* is given, the channel's git revision must descend
        from the previously pinned one.
        """
        old_revision: Optional[str] = None
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)
            old_revision = old_pin.git_revision

        channel_html, forwarded_url = fetch_channel(v, self)
        table, new_gitpin = parse_channel(v, channel_html)
        fetch_resources(v, new_gitpin, forwarded_url, table)
        ensure_git_rev_available(v, self, new_gitpin, old_revision)
        check_channel_contents(v, self, table, new_gitpin)

        nixexprs = table['nixexprs.tar.xz']
        return ChannelPin(
            release_name=new_gitpin.release_name,
            tarball_url=nixexprs.absolute_url,
            tarball_sha256=nixexprs.digest,
            git_revision=new_gitpin.git_revision)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Download the pinned nixexprs tarball; return its store path."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
168
169
# Every kind of channel a config section can describe.
SearchPath = Union[AliasSearchPath, GitSearchPath, ChannelSearchPath]
# The kinds backed by a git repository (they carry git_repo and git_ref).
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
b3ea39b8 172
4a82be40 173
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare two directory trees file by file.

    Gathers every file, plus every symlink that os.walk lists as a
    directory (it does not descend into those), found under either *a* or
    *b*. Paths under a top-level .git/ directory are skipped. The union
    is handed to filecmp.cmpfiles with shallow=False.

    Returns filecmp.cmpfiles' (match, mismatch, errors) lists, where errors
    holds paths that could not be compared (e.g. present on one side only).
    Errors while walking either tree are raised, not ignored.
    """
    def reraise(error: OSError) -> None:
        raise error

    def under(parent: str, child: str) -> str:
        # relpath() reports the walk root as '.'; avoid a './' prefix.
        return child if parent == '.' else os.path.join(parent, child)

    def tree_paths(root: str) -> Set[str]:
        paths: Set[str] = set()
        for dirpath, dirnames, filenames in os.walk(root, onerror=reraise):
            relative = os.path.relpath(dirpath, start=root)
            paths.update(under(relative, name) for name in filenames)
            paths.update(under(relative, name) for name in dirnames
                         if os.path.islink(under(dirpath, name)))
        return {path for path in paths if not path.startswith('.git/')}

    all_paths = tree_paths(a) | tree_paths(b)
    return filecmp.cmpfiles(a, b, all_paths, shallow=False)
200
201
567a6783
SW
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download a channel's HTML page.

    Verifies the HTTP status is 200 and that the channel URL was forwarded
    (redirected) somewhere else. Returns (page body, final URL).

    NOTE(review): response.read() returns bytes despite the str annotation;
    parse_channel hands it to xml.dom.minidom.parseString, which accepts
    bytes.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # (the previous code never closed it).
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel_html = request.read()
        forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url
2f96f32a
SW
211
212
567a6783
SW
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Parse a channel page (as XML) into its file table and a git pin.

    Checks that the release name in the first <title> agrees with the one in
    the first <h1>, and that the first <tt> (the git commit) is preceded by
    the text "Git commit ". Every <tr> after the first becomes a
    ChannelTableEntry keyed by file name. Cells are addressed by
    childNodes index, so this relies on the page's exact markup.
    """
    v.status('Parsing channel description as XML')
    doc = xml.dom.minidom.parseString(channel_html)
    v.ok()

    def third_word_of_first(tag: str) -> str:
        # The release name is taken as the third word of the element's text.
        text = doc.getElementsByTagName(tag)[0].firstChild.nodeValue
        return text.split()[2]

    v.status('Extracting release name:')
    title_name = third_word_of_first('title')
    h1_name = third_word_of_first('h1')
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    tt_node = doc.getElementsByTagName('tt')[0]
    git_revision = tt_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    label = tt_node.previousSibling.nodeValue
    v.result(label == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    rows = doc.getElementsByTagName('tr')[1:]  # skip the header row
    for row in rows:
        cells = row.childNodes
        link = cells[0].firstChild
        name = link.firstChild.nodeValue
        url = link.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=title_name, git_revision=git_revision)
2f96f32a
SW
244
245
dc038df0
SW
def digest_string(s: bytes) -> Digest16:
    """Return the hex SHA-256 digest of the bytes *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
248
249
73bec7e8
SW
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of the file at *filename*.

    The file is read in 4096-byte blocks rather than all at once.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as f:
        while True:
            block = f.read(4096)
            if not block:
                break
            hasher.update(block)
    return Digest16(hasher.hexdigest())
257
258
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 SHA-256 digest to hex with `nix to-base16`.

    NOTE(review): newer Nix releases deprecate `nix to-base16` in favour of
    `nix hash convert` and gate `nix` subcommands behind the nix-command
    experimental feature -- confirm against the Nix versions in use.
    """
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    converted = process.stdout.decode().strip()
    return Digest16(converted)
265
266
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a hex SHA-256 digest to Nix base32 with `nix to-base32`.

    NOTE(review): newer Nix releases deprecate `nix to-base32` in favour of
    `nix hash convert` -- confirm against the Nix versions in use.
    """
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    converted = process.stdout.decode().strip()
    return Digest32(converted)
273
274
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* into the Nix store and verify it against *digest*.

    Checks both the digest nix-prefetch-url reports and an independent
    SHA-256 of the stored file. Returns the store path.
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Expected output: "<base32 digest>\n<store path>\n".
    output_lines = process.stdout.decode().split('\n')
    prefetch_digest, path, empty = output_lines
    assert empty == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore # (for old mypy)
2f96f32a 291
73bec7e8 292
9343cf48
SW
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each table URL is resolved against *forwarded_url*. The resulting
    absolute_url and the downloaded store path (file) are recorded on the
    table entry. Finally, checks that the downloaded git-revision file
    matches the commit shown on the channel page (*pin*).
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaking the handle.
    with open(table['git-revision'].file) as revision_file:
        table_revision = revision_file.read(999)
    v.result(table_revision == pin.git_revision)
2f96f32a 305
971d3659 306
def git_cachedir(git_repo: str) -> str:
    """Return the local cache directory (a bare repo) for *git_repo*.

    Lives under <XDG cache>/pinch/git, named by the SHA-256 of the repo URL.
    """
    repo_key = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_key)
312
313
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the file recording the store path of a generated git tarball.

    Named "<sha256 of repo URL>-<revision>-<release name>" under
    <XDG cache>/pinch/git-tarball.
    """
    repo_key = digest_string(channel.git_repo.encode())
    name = '%s-%s-%s' % (repo_key, pin.git_revision, pin.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', name)
971d3659
SW
322
323
d7cfdb22
SW
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        new_revision: str,
        old_revision: Optional[str]) -> None:
    """Check *new_revision*'s place in the cached repo's history.

    *new_revision* must be an ancestor of (or equal to) channel.git_ref.
    When *old_revision* is given, *old_revision* must be an ancestor of
    *new_revision*, i.e. the update only moves forward.
    """
    cachedir = git_cachedir(channel.git_repo)

    def is_ancestor(ancestor: str, descendant: str) -> bool:
        # `git merge-base --is-ancestor A B` exits 0 when A is an ancestor of B.
        process = subprocess.run(['git', '-C', cachedir, 'merge-base',
                                  '--is-ancestor', ancestor, descendant])
        return process.returncode == 0

    v.status('Verifying rev is an ancestor of ref')
    v.result(is_ancestor(new_revision, channel.git_ref))

    if old_revision is not None:
        # The old message ("rev is an ancestor of previous rev") described
        # the reverse of what is actually checked.
        v.status(
            'Verifying previous rev %s is an ancestor of rev' %
            old_revision)
        v.result(is_ancestor(old_revision, new_revision))
9836141c 352
2f96f32a 353
d7cfdb22
SW
def git_fetch(
        v: Verification,
        channel: TarrableSearchPath,
        desired_revision: Optional[str],
        old_revision: Optional[str]) -> str:
    """Fetch channel.git_ref into the local bare cache repo.

    Creates the cache repo on first use. If *desired_revision* is given,
    verifies the fetch made that object available. Returns the fetched
    ref's tip revision after checking its ancestry with
    verify_git_ancestry.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if desired_revision is not None:
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', desired_revision])
        v.result(process.returncode == 0)

    # Ask git for the ref's value instead of reading refs/heads/<ref> as a
    # file. The loose ref file does not exist once refs have been packed
    # (git gc / git pack-refs) and this fetch did not rewrite the ref. The
    # old open(...) also leaked its file handle.
    v.status('Resolving ref "%s"' % channel.git_ref)
    process = subprocess.run(
        ['git', '-C', cachedir, 'rev-parse', '--verify',
         'refs/heads/' + channel.git_ref],
        stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    new_revision = process.stdout.decode().strip()

    verify_git_ancestry(v, channel, new_revision, old_revision)

    return new_revision
dc038df0 400
971d3659 401
7d2c278f
SW
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        old_revision: Optional[str]) -> None:
    """Make sure pin.git_revision is present in the local git cache.

    If the cache already has the revision, only its ancestry is verified.
    Otherwise the channel's ref is fetched, and git_fetch also verifies
    ancestry.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        returncode = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision]
        ).returncode
        # Exit status 0 is treated as "present", 1 as "absent"; any other
        # status fails the check.
        answer = {0: 'yes', 1: 'no'}.get(returncode)
        if answer is not None:
            v.status(answer)
        v.result(returncode in (0, 1))
        if returncode == 0:
            verify_git_ancestry(v, channel, pin.git_revision, old_revision)
            return
    git_fetch(v, channel, pin.git_revision, old_revision)
7d889b12 421
dc038df0 422
925c801b
SW
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Verify an extracted channel tarball matches a git checkout.

    The tarball's files are compared from its <release_name>/ directory.
    The checks are:
    - at least one file matches;
    - no file differs;
    - every incomparable file is one of a fixed list of expected names;
    - all of those expected names are present.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)

    # Paths known to be incomparable between the tarball and git.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    unexpected = list(errors)
    benign_errors = []
    for name in expected_errors:
        if name in unexpected:
            unexpected.remove(name)
            benign_errors.append(name)
    v.check(
        '%d unexpected incomparable files' % len(unexpected),
        len(unexpected) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
454
455
7d2c278f
SW
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the downloaded nixexprs.tar.xz into the directory *dest*."""
    tarball = table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
463
464
7d2c278f
SW
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Extract pin.git_revision from the git cache into *dest*.

    Streams `git archive` straight into `tar x`; no work tree is created.
    """
    v.status('Checking out corresponding git revision')
    archive_command = ['git', '-C', git_cachedir(channel.git_repo),
                       'archive', pin.git_revision]
    git = subprocess.Popen(archive_command, stdout=subprocess.PIPE)
    tar = subprocess.Popen(['tar', 'x', '-C', dest, '-f', '-'],
                           stdin=git.stdout)
    # Drop our copy of the pipe's read end so git sees a broken pipe if tar
    # exits early, rather than blocking.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
484
485
9343cf48
SW
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return a Nix store path of an xz tarball of pin.git_revision.

    The tarball's contents sit under a top-level <release_name>/ directory.
    The store path is recorded in tarball_cache_file() and reused as long as
    that store path still exists.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        # Binary mode: xz writes compressed bytes to this file's descriptor.
        with open(output_filename, 'wb') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe's read end (as git_checkout does);
            # otherwise, if xz exits early, git blocks on a full pipe and
            # git.wait() never returns.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
736c25eb
SW
524
525
f9cd7bdc
SW
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check the extracted tarball's metadata files against *pin*.

    <release_name>/.git-revision must equal pin.git_revision, and
    <release_name>/.version-suffix must be a suffix of pin.release_name.
    Both files are read with `with`, so they are closed promptly.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    v.status('Verifying git commit in channel tarball')
    with open(os.path.join(release_dir, '.git-revision')) as f:
        tarball_revision = f.read(999)
    v.result(tarball_revision == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as f:
        version_suffix = f.read(999)
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
f9cd7bdc
SW
548
549
7d2c278f
SW
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify the channel tarball against the git revision it claims.

    The steps are:
    - extract the tarball into a temporary directory;
    - check its metadata files;
    - extract pin.git_revision from git into a second temporary directory;
    - compare the two trees.
    Both directories are removed before the final OK.
    """
    with tempfile.TemporaryDirectory() as tarball_dir:
        with tempfile.TemporaryDirectory() as checkout_dir:
            extract_tarball(v, table, tarball_dir)
            check_channel_metadata(v, pin, tarball_dir)
            git_checkout(v, channel, pin, checkout_dir)
            compare_tarball_and_git(v, pin, tarball_dir, checkout_dir)
            v.status('Removing temporary directories')
    v.ok()
567
568
d7cfdb22
SW
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name: "<repo basename>-<commit time>-<short hash>".

    The commit time is git's %ct (committer date as a Unix timestamp). The
    hash is abbreviated with --abbrev=11.
    """
    v.status('Getting commit date')
    log_command = ['git',
                   '-C', git_cachedir(channel.git_repo),
                   'log', '-n1', '--format=%ct-%h', '--abbrev=11',
                   '--no-show-signature', git_revision]
    process = subprocess.run(log_command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0 and process.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    return '%s-%s' % (repo_name, process.stdout.decode().strip())
587
588
567a6783
SW
K = TypeVar('K')
V = TypeVar('V')


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* by whether each key is in *fields*.

    Returns (selected, remaining). Both keep *d*'s iteration order, and
    together they contain every item of *d* exactly once.
    """
    selected = {key: value for key, value in d.items() if key in fields}
    remaining = {key: value for key, value in d.items() if key not in fields}
    return selected, remaining
603
604
def read_search_path(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path and optional pin described by a config section.

    The section's 'type' (alias, channel or git) selects the classes. Keys
    that are fields of the pin class go to the pin; the rest go to the
    search path. The pin is None when none of its fields are present,
    except for alias, whose pin has no fields and so is always present.

    Raises Exception naming the section when 'type' is missing or unknown,
    instead of a bare KeyError.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
    }
    kind = conf.get('type')
    if kind not in mapping:
        raise Exception(
            'Channel "%s" has unknown or missing type %r (expected one of: %s)' %
            (conf.name, kind, ', '.join(sorted(mapping))))
    SP, P = mapping[kind]
    _, all_fields = filter_dict(dict(conf.items()), {'type'})
    pin_fields, remaining_fields = filter_dict(all_fields, set(P._fields))
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin_present = pin_fields != {} or P._fields == ()
    pin = P(**pin_fields) if pin_present else None  # type:ignore[call-arg]
    return SP(**remaining_fields), pin
f8f5b125
SW
619
620
01ba0eb2
SW
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*.

    *filename* is also passed to read_file() as the source name used in
    parse error messages.
    """
    config = configparser.ConfigParser()
    # Close the file as soon as it has been parsed instead of leaking it.
    with open(filename) as f:
        config.read_file(f, filename)
    return config
625
626
4603b1a7
SW
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files into one {section name: section} mapping.

    Raises Exception if a section name appears in more than one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for parsed in map(read_config, filenames):
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
637
638
def pinCommand(args: argparse.Namespace) -> None:
    """Handle `pinch pin`: (re)pin channels and write the pins back.

    Pins every section of args.channels_file, or only the sections named in
    args.channels when any are given. The file is written once, after the
    loop, so an exception while pinning leaves it untouched.
    """
    v = Verification()
    config = read_config(args.channels_file)
    requested = args.channels
    for section in config.sections():
        if requested and section not in requested:
            continue

        search_path, old_pin = read_search_path(config[section])
        new_pin = search_path.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
2f96f32a
SW
652
653
def updateCommand(args: argparse.Namespace) -> None:
    """Handle `pinch update`: install the pinned channels with nix-env.

    The steps are:
    - fetch every pinned non-alias channel;
    - point each alias at its target's expression;
    - install all channels into the current user's channels profile.
    With --dry-run, the nix-env command is printed instead of run.

    Raises Exception if any channel has not been pinned yet.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        sp, pin = read_search_path(config[section])
        if pin is None:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)
        if isinstance(sp, AliasSearchPath):
            continue
        tarball = sp.fetch(v, pin)
        # "%%s" survives this formatting as "%s"; it is filled in with the
        # channel name when the command line is assembled below.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    # NOTE(review): an alias whose target is itself an alias resolves only if
    # the target section is processed first; a missing target raises KeyError.
    for section in config:
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression']
    command.extend(exprs[name] % name for name in sorted(exprs))
    if args.dry_run:
        print(' '.join(shlex.quote(part) for part in command))
        return
    v.status('Installing channels with nix-env')
    process = subprocess.run(command)
    v.result(process.returncode == 0)
736c25eb
SW
691
692
0e5e611d
SW
def main() -> None:
    """Command-line entry point: dispatch to the `pin` or `update` subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    parsed = parser.parse_args()
    parsed.func(parsed)


if __name__ == '__main__':
    main()