]> git.scottworley.com Git - pinch/blob - pinch.py
Don't require experimental-features=nix-command in system nix.conf
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tarfile
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Callable,
23 Dict,
24 Iterable,
25 List,
26 Mapping,
27 NamedTuple,
28 NewType,
29 Optional,
30 Set,
31 Tuple,
32 Type,
33 TypeVar,
34 Union,
35 )
36
37 import git_cache
38
# TODO: switch to the `xdg` module once depending on it is less painful.


class XDG(NamedTuple):
    """The XDG base-directory values that pinch needs."""

    # Base directory for user-specific cache data.
    XDG_CACHE_HOME: str


# Honour $XDG_CACHE_HOME, defaulting to ~/.cache when it is unset.
xdg = XDG(XDG_CACHE_HOME=os.getenv('XDG_CACHE_HOME',
                                   os.path.expanduser('~/.cache')))
50
51
class VerificationError(Exception):
    """Raised by Verification.result when a check reports failure."""
54
55
class Verification:
    """Progress reporter for a sequence of checks, written to stderr.

    ``status`` prints fragments of the current line. ``result`` ends the
    line with a right-aligned, colored OK/FAIL marker and raises
    VerificationError when the check failed.
    """

    def __init__(self) -> None:
        # Characters already written on the current stderr line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* plus a trailing space without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # len() counts code points, not display cells, so wide characters
        # may throw off the alignment of the result marker.
        self.line_length += len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for color code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the line with OK or FAIL; raise VerificationError if not *r*."""
        markers = {True: ('OK ', 92), False: ('FAIL', 91)}
        message, color = markers[r]
        cols = shutil.get_terminal_size().columns or 80
        # Pad so the marker ends at the right edge of the (possibly
        # wrapped) terminal line.
        pad = (cols - (self.line_length + len(message))) % cols
        print(' ' * pad + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Report status *s* and result *r* in one call."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current line."""
        self.result(True)
86
# Hexadecimal (base16) rendering of a SHA-256 digest.
Digest16 = NewType('Digest16', str)
# Nix base32 rendering of a SHA-256 digest.
Digest32 = NewType('Digest32', str)
89
90
class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's file table, plus the fetch results."""

    absolute_url: str   # url resolved against the channel's forwarded URL
    digest: Digest16    # SHA-256 as listed in the table
    file: str           # path returned by nix-prefetch-url after fetching
    size: int           # size as listed in the table
    url: str            # href as written in the table
97
98
class AliasPin(NamedTuple):
    """Pin for an alias channel; aliases carry no pinned state."""
102
class SymlinkPin(NamedTuple):
    """Pin for a symlink channel: no fields, only a fixed release name."""

    @property
    def release_name(self) -> str:
        """Constant release name reported for symlink channels."""
        return 'link'
107
108
class GitPin(NamedTuple):
    """Pinned state of a git-backed channel."""

    git_revision: str   # commit the channel is pinned to
    release_name: str   # human-readable name of that release
112
113
class ChannelPin(NamedTuple):
    """Pinned state of a Nix channel."""

    git_revision: str     # commit the channel release was built from
    release_name: str     # release name shown on the channel page
    tarball_url: str      # absolute URL of nixexprs.tar.xz
    tarball_sha256: str   # hex SHA-256 of nixexprs.tar.xz
119
120
# Any of the pin kinds a config section may carry.
Pin = Union[AliasPin,
            SymlinkPin,
            GitPin,
            ChannelPin]
122
123
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add *filename* to the Nix store with ``nix-store --add``.

    Returns the store path that nix-store prints. A non-zero exit status
    is reported through *v*, which raises.
    """
    v.status('Putting tarball in Nix store')
    completed = subprocess.run(['nix-store', '--add', filename],
                               stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    store_path = completed.stdout.decode().strip()
    return store_path  # type: ignore # (for old mypy)
130
131
def symlink_archive(v: Verification, path: str) -> str:
    """Build a .tar.gz holding a single symlink ``link -> path``.

    The archive is assembled in a temporary directory and copied into
    the Nix store. Returns the store path of the archive.
    """
    with tempfile.TemporaryDirectory() as tmp:
        tarball = os.path.join(tmp, 'link.tar.gz')
        link = os.path.join(tmp, 'link')
        os.symlink(path, link)
        with tarfile.open(tarball, mode='x:gz') as archive:
            archive.add(link, arcname='link')
        return copy_to_nix_store(v, tarball)
139
140
class AliasSearchPath(NamedTuple):
    """Channel that reuses another section's channel expression."""

    alias_of: str  # name of the section this one aliases

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin, so return an empty AliasPin."""
        return AliasPin()
147
148
class SymlinkSearchPath(NamedTuple):
    """Channel whose contents are a symlink to a local path."""

    path: str  # target of the symlink placed in the channel archive

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Symlinks have no pinned state, so return an empty SymlinkPin."""
        return SymlinkPin()

    def fetch(self, v: Verification, _: Pin) -> str:
        """Return the Nix store path of an archive containing the symlink."""
        return symlink_archive(v, self.path)
158
159
class GitSearchPath(NamedTuple):
    """Channel built directly from a ref of a git repository."""

    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch *git_ref* and pin the revision it currently points at.

        If there is a previous pin, the new revision must descend from it;
        verify_git_ancestry reports failure through *v*.
        """
        _, revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        name = git_revision_name(v, self, revision)
        return GitPin(release_name=name, git_revision=revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return the Nix store path of a tarball of the pinned revision."""
        assert isinstance(pin, GitPin)
        git_cache.ensure_rev_available(self.git_repo, self.git_ref,
                                       pin.git_revision)
        return git_get_tarball(v, self, pin)
177
178
class ChannelSearchPath(NamedTuple):
    """Nix channel published as a web page and cross-checked against git."""

    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Pin the release currently served at *channel_url*.

        If the channel still points at the previously pinned revision, the
        old pin is returned unchanged. Otherwise the channel's resources
        are fetched, the revision is made available in the git cache and
        must descend from the old pin's revision, and the tarball's
        contents are compared with the git checkout.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        page, final_url = fetch_channel(v, self)
        table, candidate = parse_channel(v, page)
        if old_pin is not None and old_pin.git_revision == candidate.git_revision:
            return old_pin

        fetch_resources(v, candidate, final_url, table)
        git_cache.ensure_rev_available(self.git_repo, self.git_ref,
                                       candidate.git_revision)
        if old_pin is not None:
            verify_git_ancestry(v, self, old_pin.git_revision,
                                candidate.git_revision)
        check_channel_contents(v, self, table, candidate)

        nixexprs = table['nixexprs.tar.xz']
        return ChannelPin(release_name=candidate.release_name,
                          tarball_url=nixexprs.absolute_url,
                          tarball_sha256=nixexprs.digest,
                          git_revision=candidate.git_revision)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Download the pinned nixexprs tarball and verify its SHA-256."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
211
212
# Every kind of search path a config section can describe.
SearchPath = Union[AliasSearchPath, SymlinkSearchPath,
                   GitSearchPath, ChannelSearchPath]
# Search paths backed by a git repository (both have git_repo and git_ref).
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
218
219
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Compare the directory trees *a* and *b* by file content.

    Each non-directory entry, plus each symlink to a directory (os.walk
    lists these among the directories but does not descend into them),
    found under either tree is considered. Paths inside a top-level
    ``.git/`` directory are excluded. The union of relative paths is
    passed to filecmp.cmpfiles with shallow=False.

    Returns:
        (match, mismatch, errors) as produced by filecmp.cmpfiles. Each
        list is in sorted path order, so results are deterministic.

    Raises:
        OSError: if either tree cannot be walked.
    """

    def throw(error: OSError) -> None:
        # os.walk ignores errors by default; surface them instead.
        raise error

    def join(x: str, y: str) -> str:
        # Avoid a leading "./" on entries at the top of the tree.
        return y if x == '.' else os.path.join(x, y)

    def recursive_files(d: str) -> Iterable[str]:
        all_files: List[str] = []
        for path, dirs, files in os.walk(d, onerror=throw):
            rel = os.path.relpath(path, start=d)
            all_files.extend(join(rel, f) for f in files)
            for dir_or_link in dirs:
                if os.path.islink(join(path, dir_or_link)):
                    all_files.append(join(rel, dir_or_link))
        return all_files

    def exclude_dot_git(files: Iterable[str]) -> Iterable[str]:
        return (f for f in files if not f.startswith('.git/'))

    files = (set(exclude_dot_git(recursive_files(a)))
             | set(exclude_dot_git(recursive_files(b))))
    # Sorting fixes the order of cmpfiles' output lists; iterating a set
    # of str would vary between runs because of hash randomization.
    return filecmp.cmpfiles(a, b, sorted(files), shallow=False)
246
247
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel page at channel.channel_url.

    Returns:
        The decoded page body and the final URL after redirects.

    Fails through *v* if the response status is not 200 or if the
    request was not redirected away from channel.channel_url.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # instead of leaking until garbage collection.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as request:
        channel_html = request.read().decode()
        forwarded_url = request.geturl()
        status = request.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url
257
258
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Extract the release name, git revision and file table from a page.

    The page is parsed as XML. The release name is the third word of the
    first <title>, which must equal the third word of the first <h1>.
    The git revision is the text of the first <tt>, whose preceding
    sibling text must be 'Git commit '. Every <tr> after the first is read
    as a row of (link cell, size cell, digest cell).

    Returns:
        The table keyed by file name, and a GitPin for the release.
    """
    v.status('Parsing channel description as XML')
    dom = xml.dom.minidom.parseString(channel_html)
    v.ok()

    def third_word(tag: str) -> str:
        text = dom.getElementsByTagName(tag)[0].firstChild.nodeValue
        return text.split()[2]

    v.status('Extracting release name:')
    release = third_word('title')
    heading_release = third_word('h1')
    v.status(release)
    v.result(release == heading_release)

    v.status('Extracting git commit:')
    tt = dom.getElementsByTagName('tt')[0]
    revision = tt.firstChild.nodeValue
    v.status(revision)
    v.ok()
    v.status('Verifying git commit label')
    v.result(tt.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    entries: Dict[str, ChannelTableEntry] = {}
    # The first <tr> is skipped (presumably the header row).
    for row in dom.getElementsByTagName('tr')[1:]:
        cells = row.childNodes
        anchor = cells[0].firstChild
        name = anchor.firstChild.nodeValue
        url = anchor.getAttribute('href')
        size = int(cells[1].firstChild.nodeValue)
        digest = Digest16(cells[2].firstChild.firstChild.nodeValue)
        entries[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return entries, GitPin(release_name=release, git_revision=revision)
290
291
def digest_string(s: bytes) -> Digest16:
    """Return the hexadecimal SHA-256 of the bytes *s*."""
    hexdigest = hashlib.sha256(s).hexdigest()
    return Digest16(hexdigest)
294
295
def digest_file(filename: str) -> Digest16:
    """Return the hexadecimal SHA-256 of *filename*, read in 4096-byte chunks."""
    sha = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while True:
            chunk = stream.read(4096)
            if not chunk:
                break
            sha.update(chunk)
    return Digest16(sha.hexdigest())
303
304
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 SHA-256 digest to base16 with ``nix to-base16``.

    The nix-command experimental feature is enabled on the command line,
    so the system nix.conf does not have to enable it. A non-zero exit
    status is reported through *v*.
    """
    v.status('Converting digest to base16')
    argv = ['nix', '--experimental-features', 'nix-command',
            'to-base16', '--type', 'sha256', digest32]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    return Digest16(completed.stdout.decode().strip())
317
318
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 SHA-256 digest to base32 with ``nix to-base32``.

    The nix-command experimental feature is enabled on the command line,
    so the system nix.conf does not have to enable it. A non-zero exit
    status is reported through *v*.
    """
    v.status('Converting digest to base32')
    argv = ['nix', '--experimental-features', 'nix-command',
            'to-base32', '--type', 'sha256', digest16]
    completed = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    return Digest32(completed.stdout.decode().strip())
331
332
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch *url* into the Nix store and verify it against *digest*.

    nix-prefetch-url receives the expected hash. The base32 hash it prints
    is converted and compared with *digest*, and the downloaded file is
    also hashed independently.

    Returns:
        The store path printed by nix-prefetch-url.
    """
    v.status('Fetching %s' % url)
    completed = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    # Output is expected to be exactly "<base32 hash>\n<store path>\n".
    prefetch_digest, path, empty = completed.stdout.decode().split('\n')
    assert empty == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
349
350
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each entry's href is resolved against *forwarded_url*, and the
    entry's absolute_url and file fields are filled in place. The
    downloaded git-revision file must then contain pin.git_revision;
    otherwise *v* reports failure.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file promptly instead of leaking the handle.
    with open(table['git-revision'].file) as f:
        table_revision = f.read(999)
    v.result(table_revision == pin.git_revision)
363
364
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the path of the cache entry for the tarball built for *pin*.

    The entry lives under $XDG_CACHE_HOME/pinch/git-tarball. Its name is
    the SHA-256 of the repository URL, the revision, and the release name,
    joined with '-'.
    """
    repo_hash = digest_string(channel.git_repo.encode())
    entry = '%s-%s-%s' % (repo_hash, pin.git_revision, pin.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry)
373
374
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Fail through *v* unless *old_revision* is an ancestor of *new_revision*.

    Runs ``git merge-base --is-ancestor`` in the channel's git cache.
    """
    cachedir = git_cache.git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of previous rev %s' % old_revision)
    cmd = ['git', '-C', cachedir, 'merge-base', '--is-ancestor',
           old_revision, new_revision]
    v.result(subprocess.run(cmd).returncode == 0)
390
391
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Check that the extracted channel tarball matches the git checkout.

    The tarball tree is expected under channel_contents/<release_name>.
    At least one file must match and none may differ. The files that
    cannot be compared must be exactly the expected set listed below.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # Files that are expected to be incomparable (presumably present only
    # in the channel tarball).
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [ee for ee in expected_errors if ee in errors]
    for ee in benign_errors:
        errors.remove(ee)
    v.check('%d unexpected incomparable files' % len(errors),
            len(errors) == 0)
    v.check('(%d of %d expected incomparable files)' %
            (len(benign_errors), len(expected_errors)),
            len(benign_errors) == len(expected_errors))
423
424
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the fetched nixexprs.tar.xz into the directory *dest*."""
    tarball = table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
432
433
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Export pin.git_revision from the git cache into the directory *dest*.

    Pipes ``git archive`` into ``tar x``. Fails through *v* unless both
    processes exit 0.
    """
    v.status('Checking out corresponding git revision')
    cachedir = git_cache.git_cachedir(channel.git_repo)
    archive = subprocess.Popen(['git', '-C', cachedir,
                                'archive', pin.git_revision],
                               stdout=subprocess.PIPE)
    extract = subprocess.Popen(['tar', 'x', '-C', dest, '-f', '-'],
                               stdin=archive.stdout)
    # Drop the parent's copy of the pipe so git receives SIGPIPE if tar
    # exits early.
    if archive.stdout:
        archive.stdout.close()
    extract.wait()
    archive.wait()
    v.result(archive.returncode == 0 and extract.returncode == 0)
453
454
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return the Nix store path of a .tar.xz of pin.git_revision.

    A cache file (named by tarball_cache_file) records the store path from
    an earlier run, and that path is reused if it still exists. Otherwise
    ``git archive --prefix=<release_name>/`` is compressed with xz, added
    to the Nix store, and the resulting path is written to the cache file.
    Failures of git or xz are reported through *v*.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cache.git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close the parent's read end of the pipe. Otherwise, if xz exits
            # early, git blocks on a full pipe and git.wait() never returns.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        store_tarball = copy_to_nix_store(v, output_filename)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    # Close (and so flush) the cache entry deterministically.
    with open(cache_file, 'w') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
489
490
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check the .git-revision and .version-suffix files of a channel tarball.

    Both files are read from channel_contents/<release_name>/.
    .git-revision must equal pin.git_revision, and .version-suffix must be
    a suffix of pin.release_name. Failures are reported through *v*.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    def read_metadata(name: str) -> str:
        # Read at most 999 characters and close the file promptly.
        with open(os.path.join(release_dir, name)) as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
513
514
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify the channel tarball against the corresponding git revision.

    Extracts nixexprs.tar.xz and checks its metadata files, exports the
    revision from the git cache, and compares the two trees. All of this
    happens in temporary directories that are removed afterwards.
    """
    with tempfile.TemporaryDirectory() as tarball_dir, \
            tempfile.TemporaryDirectory() as checkout_dir:
        extract_tarball(v, table, tarball_dir)
        check_channel_metadata(v, pin, tarball_dir)
        git_checkout(v, channel, pin, checkout_dir)
        compare_tarball_and_git(v, pin, tarball_dir, checkout_dir)
        v.status('Removing temporary directories')
    v.ok()
532
533
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name '<repo basename>-<commit time>-<short hash>'.

    The commit time is git's %ct (committer timestamp), and the hash is
    abbreviated to at least 11 characters. Fails through *v* if git fails
    or prints nothing.
    """
    v.status('Getting commit date')
    cachedir = git_cache.git_cachedir(channel.git_repo)
    log = subprocess.run(['git', '-C', cachedir, 'log', '-n1',
                          '--format=%ct-%h', '--abbrev=11',
                          '--no-show-signature', git_revision],
                         stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    stamp = log.stdout.decode().strip()
    return '%s-%s' % (os.path.basename(channel.git_repo), stamp)
552
553
# Generic key and value types for the dict helpers.
K = TypeVar('K')
V = TypeVar('V')
556
557
def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (items for which pred(key, value) holds, the rest).

    Both results keep the insertion order of *d*.
    """
    selected: Dict[K, V] = {}
    remaining: Dict[K, V] = {}
    for key, value in d.items():
        target = selected if pred(key, value) else remaining
        target[key] = value
    return selected, remaining
568
569
def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (entries whose key is in *fields*, all other entries)."""
    return partition_dict(lambda key, _value: key in fields, d)
573
574
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path and optional pin described by a config section.

    The section's 'type' key selects the search-path and pin classes. Keys
    that are fields of the pin class form the pin, and all other keys
    except 'type' go to the search path. The pin is None when no pin
    fields are present, except that pin classes with no fields at all
    always count as pinned.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    search_path_class, pin_class = mapping[conf['type']]
    _, settings = filter_dict(dict(conf.items()), {'type'})
    pin_settings, path_settings = filter_dict(settings, set(pin_class._fields))
    has_pin = bool(pin_settings) or pin_class._fields == ()
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin = pin_class(**pin_settings) if has_pin else None  # type: ignore
    return search_path_class(**path_settings), pin
590
591
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but require the section to be pinned.

    Raises:
        Exception: if section *section* has no pin yet.
    """
    sp, pin = read_config_section(conf)
    if pin is not None:
        return sp, pin
    raise Exception(
        'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
        section)
600
601
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*.

    The file is closed once parsing finishes. *filename* is also passed as
    the source name that configparser uses in its error messages.
    """
    config = configparser.ConfigParser()
    with open(filename) as f:
        config.read_file(f, filename)
    return config
606
607
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files into one mapping from section name to section.

    Raises:
        Exception: if a section name appears in more than one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
618
619
def pinCommand(args: argparse.Namespace) -> None:
    """Pin channels and write the pins back into args.channels_file.

    If args.channels is non-empty, only the sections it names are
    (re)pinned. All other sections are written back unchanged.
    """
    v = Verification()
    config = read_config(args.channels_file)
    wanted = args.channels
    for section in config.sections():
        if wanted and section not in wanted:
            continue
        sp, old_pin = read_config_section(config[section])
        new_pin = sp.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
633
634
def updateCommand(args: argparse.Namespace) -> None:
    """Install every pinned channel from the given files into a nix-env profile.

    Non-alias channels are fetched, and each yields a Nix store path.
    Alias channels reuse the expression of the section they alias. With
    --dry-run the nix-env command line is printed instead of run.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = {
        section: read_pinned_config_section(section, conf)
        for section, conf in read_config_files(args.channels_file).items()}
    alias, nonalias = partition_dict(
        lambda _section, entry: isinstance(entry[0], AliasSearchPath), config)

    for section, (sp, pin) in nonalias.items():
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        # "%%s" survives this formatting as "%s", a slot for the channel
        # name that is filled in when the command line is built.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (pin.release_name, tarball))

    for section, (sp, _pin) in alias.items():
        assert isinstance(sp, AliasSearchPath)  # For mypy
        exprs[section] = exprs[sp.alias_of]

    # Each alias gets its own channel name even though it shares the
    # aliased section's expression.
    expressions = [exprs[name] % name for name in sorted(exprs)]
    command = ['nix-env', '--profile', args.profile, '--show-trace',
               '--file', '<nix/unpack-channel.nix>',
               '--install', '--from-expression'] + expressions
    if args.dry_run:
        print(' '.join(shlex.quote(word) for word in command))
        return
    v.status('Installing channels with nix-env')
    v.result(subprocess.run(command).returncode == 0)
672
673
def main() -> None:
    """Parse the command line and dispatch to the 'pin' or 'update' command."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = ('/nix/var/nix/profiles/per-user/%s/channels' %
                       getpass.getuser())
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    args = parser.parse_args()
    args.func(args)
689
690
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()