]> git.scottworley.com Git - pinch/blob - pinch.py
9555d74567f7b27ea76cb700beb87a31ed1db0b8
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tarfile
15 import tempfile
16 import types
17 import urllib.parse
18 import urllib.request
19 import xml.dom.minidom
20
21 from typing import (
22 Callable,
23 Dict,
24 Iterable,
25 List,
26 Mapping,
27 NamedTuple,
28 NewType,
29 Optional,
30 Set,
31 Tuple,
32 Type,
33 TypeVar,
34 Union,
35 )
36
37 import git_cache
38
39 # Use xdg module when it's less painful to have as a dependency
40
41
42 class XDG(NamedTuple):
43 XDG_CACHE_HOME: str
44
45
46 xdg = XDG(
47 XDG_CACHE_HOME=os.getenv(
48 'XDG_CACHE_HOME',
49 os.path.expanduser('~/.cache')))
50
51
class VerificationError(Exception):
    """Raised by Verification.result when a verification step fails."""
54
55
class Verification:
    """Prints a running log of verification steps to stderr.

    Each step is announced with status() and concluded with result(), which
    right-aligns a colored OK/FAIL marker on the current line and raises
    VerificationError when the step failed.
    """

    def __init__(self) -> None:
        # Characters printed so far on the current stderr line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print *s* followed by a space, without ending the line."""
        print(s, end=' ', file=sys.stderr, flush=True)
        # Counts code points, not display columns (wide characters are off).
        self.line_length = self.line_length + len(s) + 1

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap *s* in the ANSI SGR escape sequence for color code *c*."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """End the current line with OK or FAIL; raise if *r* is false."""
        if r:
            message, color = 'OK ', 92
        else:
            message, color = 'FAIL', 91
        width = shutil.get_terminal_size().columns or 80
        used = self.line_length + len(message)
        padding = ' ' * ((width - used) % width)
        print(padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Announce step *s* and immediately report its outcome *r*."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Report success for the current step."""
        self.result(True)
85
86
# Hex (base16) SHA-256 digest string, e.g. as produced by hashlib hexdigest().
Digest16 = NewType('Digest16', str)
# SHA-256 digest in Nix's base32 encoding (converted with `nix to-base16`).
Digest32 = NewType('Digest32', str)
89
90
class ChannelTableEntry(types.SimpleNamespace):
    """One row of the file table on a channel's HTML page.

    parse_channel() sets url, digest and size; fetch_resources() later sets
    absolute_url and file on the entries it downloads.
    """
    # url resolved against the (forwarded) channel URL; set by fetch_resources.
    absolute_url: str
    # Base16 SHA-256 digest listed in the table.
    digest: Digest16
    # Local (Nix store) path of the downloaded file; set by fetch_resources.
    file: str
    # Size value listed in the table.
    size: int
    # href from the table row, possibly relative.
    url: str
97
98
class AliasPin(NamedTuple):
    """Pin for an alias channel; aliases record no pinned state of their own."""
101
102
class SymlinkPin(NamedTuple):
    """Pin for a symlink channel; it has no fields to record."""

    @property
    def release_name(self) -> str:
        """Symlink channels are always released under the fixed name 'link'."""
        return 'link'
107
108
class GitPin(NamedTuple):
    """Pinned state of a git channel: a revision and its release name."""
    # Git commit the channel is pinned to.
    git_revision: str
    # Release name (from the channel page title, or git_revision_name()).
    release_name: str
112
113
class ChannelPin(NamedTuple):
    """Pinned state of a channel: git revision plus its nixexprs tarball."""
    # Git commit the channel release corresponds to.
    git_revision: str
    # Release name extracted from the channel page.
    release_name: str
    # Absolute URL of the release's nixexprs.tar.xz.
    tarball_url: str
    # Base16 SHA-256 digest of nixexprs.tar.xz as listed on the channel page.
    tarball_sha256: str
119
120
# Any pin type; read_config_section() picks one based on the section's 'type'.
Pin = Union[AliasPin, SymlinkPin, GitPin, ChannelPin]
122
123
def copy_to_nix_store(v: Verification, filename: str) -> str:
    """Add *filename* to the Nix store with `nix-store --add`.

    Returns the resulting store path. The step fails (v.result raises) if
    nix-store exits non-zero.
    """
    v.status('Putting tarball in Nix store')
    completed = subprocess.run(['nix-store', '--add', filename],
                               stdout=subprocess.PIPE)
    v.result(completed.returncode == 0)
    store_path = completed.stdout.decode().strip()
    return store_path  # type: ignore # (for old mypy)
130
131
def symlink_archive(v: Verification, path: str) -> str:
    """Pack a single symlink named 'link' pointing at *path* into a gzipped
    tarball, add it to the Nix store, and return the store path."""
    with tempfile.TemporaryDirectory() as scratch:
        tarball = os.path.join(scratch, 'link.tar.gz')
        link = os.path.join(scratch, 'link')
        os.symlink(path, link)
        with tarfile.open(tarball, mode='x:gz') as archive:
            archive.add(link, arcname='link')
        return copy_to_nix_store(v, tarball)
139
140
class AliasSearchPath(NamedTuple):
    """A channel that reuses another channel's expression.

    ``alias_of`` is the name of the section being aliased.
    """
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases have nothing to pin, so this is always an empty AliasPin."""
        empty_pin = AliasPin()
        return empty_pin
147
148
class SymlinkSearchPath(NamedTuple):
    """A channel that is a symlink to the local directory ``path``."""
    path: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> SymlinkPin:
        """Nothing to record for a symlink; return an empty SymlinkPin."""
        empty_pin = SymlinkPin()
        return empty_pin

    def fetch(self, v: Verification, _: Pin) -> str:
        """Return the store path of a tarball holding the symlink."""
        target = self.path
        return symlink_archive(v, target)
158
159
class GitSearchPath(NamedTuple):
    """A channel built directly from ``git_ref`` of ``git_repo``."""
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch the ref and pin its current revision.

        When *old_pin* is given, the new revision must descend from the
        previously pinned revision (checked by verify_git_ancestry).
        """
        _, revision = git_cache.fetch(self.git_repo, self.git_ref)
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        name = git_revision_name(v, self, revision)
        return GitPin(git_revision=revision, release_name=name)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Ensure the pinned revision is cached locally and return the store
        path of a tarball of it."""
        assert isinstance(pin, GitPin)
        revision = pin.git_revision
        git_cache.ensure_rev_available(self.git_repo, self.git_ref, revision)
        return git_get_tarball(v, self, pin)
177
178
class ChannelSearchPath(NamedTuple):
    """A channel published at ``channel_url`` and cross-checked against
    ``git_ref`` of ``git_repo``."""
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Pin the channel's current release.

        If the channel still points at *old_pin*'s revision, *old_pin* is
        returned unchanged. Otherwise the release files are downloaded and
        verified, the new revision must descend from the old one (when there
        is one), and the tarball contents are compared against git.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)

        channel_html, forwarded_url = fetch_channel(v, self)
        table, release = parse_channel(v, channel_html)
        revision = release.git_revision
        if old_pin is not None and old_pin.git_revision == revision:
            return old_pin

        fetch_resources(v, release, forwarded_url, table)
        git_cache.ensure_rev_available(self.git_repo, self.git_ref, revision)
        if old_pin is not None:
            verify_git_ancestry(v, self, old_pin.git_revision, revision)
        check_channel_contents(v, self, table, release)

        tarball = table['nixexprs.tar.xz']
        return ChannelPin(
            git_revision=revision,
            release_name=release.release_name,
            tarball_url=tarball.absolute_url,
            tarball_sha256=tarball.digest)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Download the pinned nixexprs tarball, verifying its digest."""
        assert isinstance(pin, ChannelPin)
        expected = Digest16(pin.tarball_sha256)
        return fetch_with_nix_prefetch_url(v, pin.tarball_url, expected)
211
212
# Any channel source; read_config_section() picks one by the section's 'type'.
SearchPath = Union[AliasSearchPath,
                   SymlinkSearchPath,
                   GitSearchPath,
                   ChannelSearchPath]
# Channel sources that carry git_repo / git_ref and can be tarred from git.
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
218
219
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Deep-compare the file trees under *a* and *b*.

    Every regular file and every symlink-to-directory found in either tree
    (excluding anything under a top-level ``.git/``) is compared by content
    with filecmp.cmpfiles. Returns its (match, mismatch, errors) lists of
    relative paths. Walk errors are raised rather than ignored.
    """

    def raise_error(err: OSError) -> None:
        raise err

    def rel_join(parent: str, child: str) -> str:
        # os.path.relpath yields '.' for the root; don't prefix with './'.
        if parent == '.':
            return child
        return os.path.join(parent, child)

    def walk_files(root: str) -> Iterable[str]:
        found: List[str] = []
        for dirpath, subdirs, filenames in os.walk(root, onerror=raise_error):
            relative = os.path.relpath(dirpath, start=root)
            for name in filenames:
                found.append(rel_join(relative, name))
            # os.walk lists symlinks to directories under subdirs without
            # descending into them; treat them as files to compare.
            for name in subdirs:
                if os.path.islink(rel_join(dirpath, name)):
                    found.append(rel_join(relative, name))
        return found

    names: Set[str] = set()
    for root in (a, b):
        names.update(f for f in walk_files(root) if not f.startswith('.git/'))
    return filecmp.cmpfiles(a, b, names, shallow=False)
246
247
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel's HTML page.

    Returns the decoded page body and the URL the response was finally
    served from. The step fails (v.result raises) unless the HTTP status is
    200 and the request was redirected away from the configured URL.
    """
    v.status('Fetching channel')
    # Use the response as a context manager so the connection is closed
    # once the body has been read, instead of leaking the socket.
    with urllib.request.urlopen(channel.channel_url, timeout=10) as response:
        channel_html = response.read().decode()
        forwarded_url = response.geturl()
        status = response.status  # type: ignore # (for old mypy)
    v.result(status == 200)
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url
257
258
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Parse a channel's HTML page (as XML).

    Returns a table of the listed files keyed by file name, and a GitPin
    holding the release name and git commit shown on the page. Each step is
    reported through *v*; a failed check raises VerificationError.
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel_html)
    v.ok()

    v.status('Extracting release name:')
    # The release name is the third whitespace-separated word of the text in
    # both <title> and the first <h1>; the two must agree.
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    # The commit id is the text of the first <tt> element...
    git_commit_node = d.getElementsByTagName('tt')[0]
    git_revision = git_commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    v.status('Verifying git commit label')
    # ...which must be directly preceded by the text 'Git commit '.
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    # Skips the first <tr> (presumably the header row). The indexing assumes
    # cells are the row's direct children with no whitespace text nodes:
    # cell 0 wraps an element with href (its text is the file name), cell 1
    # holds the size, cell 2 wraps the digest in one more element.
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=title_name, git_revision=git_revision)
290
291
def digest_string(s: bytes) -> Digest16:
    """Return the hex SHA-256 digest of the byte string *s*."""
    hasher = hashlib.sha256()
    hasher.update(s)
    return Digest16(hasher.hexdigest())
294
295
def digest_file(filename: str) -> Digest16:
    """Return the hex SHA-256 digest of the file at *filename*.

    The file is hashed in 4096-byte chunks rather than read whole.
    """
    hasher = hashlib.sha256()
    with open(filename, 'rb') as stream:
        while chunk := stream.read(4096):
            hasher.update(chunk)
    return Digest16(hasher.hexdigest())
303
304
@functools.lru_cache
def _experimental_flag_needed(v: Verification) -> bool:
    """Report whether `nix --help` mentions --experimental-features.

    Memoized per Verification object (Verification hashes by identity), so
    `nix --help` runs only once per command invocation.
    """
    v.status('Checking Nix version')
    help_run = subprocess.run(['nix', '--help'], stdout=subprocess.PIPE)
    v.result(help_run.returncode == 0)
    needed = b'--experimental-features' in help_run.stdout
    return needed
311
312
def _nix_command(v: Verification) -> List[str]:
    """Return the argv prefix for running `nix`, enabling the nix-command
    experimental feature when this Nix understands that flag."""
    command = ['nix']
    if _experimental_flag_needed(v):
        command += ['--experimental-features', 'nix-command']
    return command
316
317
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a Nix base32 SHA-256 digest to base16 via `nix to-base16`."""
    v.status('Converting digest to base16')
    argv = _nix_command(v) + ['to-base16', '--type', 'sha256', digest32]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest16(converted.stdout.decode().strip())
328
329
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 SHA-256 digest to Nix base32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    argv = _nix_command(v) + ['to-base32', '--type', 'sha256', digest16]
    converted = subprocess.run(argv, stdout=subprocess.PIPE)
    v.result(converted.returncode == 0)
    return Digest32(converted.stdout.decode().strip())
340
341
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Download *url* into the Nix store and return its store path.

    Both the digest nix-prefetch-url reports and the SHA-256 of the stored
    file must equal the expected base16 *digest*.
    """
    v.status('Fetching %s' % url)
    prefetch = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest],
        stdout=subprocess.PIPE)
    v.result(prefetch.returncode == 0)
    # Expected output: the digest line, then the store-path line, then nothing.
    prefetch_digest, path, trailing = prefetch.stdout.decode().split('\n')
    assert trailing == ''
    reported = to_Digest16(v, Digest32(prefetch_digest))
    v.check("Verifying nix-prefetch-url's digest", reported == digest)
    v.status("Verifying file digest")
    v.result(digest_file(path) == digest)
    return path  # type: ignore # (for old mypy)
358
359
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Download the channel's git-revision and nixexprs.tar.xz files.

    Each of those *table* entries gains ``absolute_url`` (its url resolved
    against *forwarded_url*) and ``file`` (the downloaded store path). The
    commit in the git-revision file must then equal *pin*'s commit.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Close the file deterministically instead of leaking the handle.
    with open(table['git-revision'].file) as f:
        listed_revision = f.read(999)
    v.result(listed_revision == pin.git_revision)
372
373
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the path of the cache entry for *pin*'s tarball.

    The entry (written by git_get_tarball) stores a Nix store path. Its
    name combines a SHA-256 of the repo string, the revision and the release
    name.
    """
    repo_hash = digest_string(channel.git_repo.encode())
    entry = '%s-%s-%s' % (repo_hash, pin.git_revision, pin.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', entry)
382
383
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        old_revision: str,
        new_revision: str) -> None:
    """Check that *new_revision* descends from *old_revision*.

    Runs `git merge-base --is-ancestor old new` in the channel's cached
    repository. The step fails (v.result raises) when *old_revision* is not
    an ancestor of *new_revision*, i.e. the pin would move backwards or onto
    an unrelated history.
    """
    cachedir = git_cache.git_cachedir(channel.git_repo)
    # `merge-base --is-ancestor A B` succeeds when A is an ancestor of B, so
    # this confirms the new rev is a descendant of the previous one.
    v.status('Verifying rev is a descendant of previous rev %s' % old_revision)
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              old_revision,
                              new_revision])
    v.result(process.returncode == 0)
399
400
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Compare an unpacked channel tarball with a git checkout.

    There must be at least one matching file and no differing files. The
    incomparable files (e.g. present on one side only) must be exactly the
    known channel-generated ones listed below.
    """
    v.status('Comparing channel tarball with git checkout')
    tarball_root = os.path.join(channel_contents, pin.release_name)
    match, mismatch, errors = compare(tarball_root, git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = [name for name in expected_errors if name in errors]
    unexpected_errors = [name for name in errors if name not in benign_errors]
    v.check(
        '%d unexpected incomparable files' % len(unexpected_errors),
        len(unexpected_errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors), len(expected_errors)),
        len(benign_errors) == len(expected_errors))
432
433
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the downloaded nixexprs.tar.xz recorded in *table* into *dest*."""
    tarball = table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
441
442
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Export the tree of *pin*'s revision into *dest*.

    Pipes `git archive` from the cached repository into `tar x`; the step
    fails unless both processes exit successfully.
    """
    v.status('Checking out corresponding git revision')
    repo = git_cache.git_cachedir(channel.git_repo)
    archiver = subprocess.Popen(
        ['git', '-C', repo, 'archive', pin.git_revision],
        stdout=subprocess.PIPE)
    extractor = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=archiver.stdout)
    # Drop our copy of the pipe so git gets SIGPIPE if tar exits early.
    if archiver.stdout:
        archiver.stdout.close()
    extractor.wait()
    archiver.wait()
    v.result(archiver.returncode == 0 and extractor.returncode == 0)
462
463
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Return the Nix store path of an xz tarball of *pin*'s revision.

    Entries in the tarball are prefixed with ``<release_name>/``. A cache
    file (see tarball_cache_file) remembers the store path and is reused
    while that path still exists. Otherwise the tarball is regenerated with
    `git archive | xz`, added to the store, and recorded in the cache.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        # The store path may have been garbage-collected since it was cached.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cache.git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our copy of the pipe so git gets SIGPIPE if xz exits
            # early, instead of blocking on a full pipe forever.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        store_tarball = copy_to_nix_store(v, output_filename)

    os.makedirs(os.path.dirname(cache_file), exist_ok=True)
    with open(cache_file, 'w') as f:
        f.write(store_tarball)
    return store_tarball  # type: ignore # (for old mypy)
498
499
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Check metadata files inside an unpacked channel tarball.

    ``<release_name>/.git-revision`` must contain exactly *pin*'s commit,
    and ``<release_name>/.version-suffix`` must be a suffix of the release
    name.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    def read_metadata(name: str) -> str:
        # Read at most 999 characters and close the file promptly.
        with open(os.path.join(release_dir, name)) as f:
            return f.read(999)

    v.status('Verifying git commit in channel tarball')
    v.result(read_metadata('.git-revision') == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    version_suffix = read_metadata('.version-suffix')
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
522
523
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify a downloaded channel tarball against its git revision.

    Unpacks the tarball and checks its metadata, exports the git revision,
    and compares the two trees. Both scratch directories are removed before
    the final step is reported OK.
    """
    with tempfile.TemporaryDirectory() as channel_contents:
        with tempfile.TemporaryDirectory() as git_contents:
            extract_tarball(v, table, channel_contents)
            check_channel_metadata(v, pin, channel_contents)
            git_checkout(v, channel, pin, git_contents)
            compare_tarball_and_git(v, pin, channel_contents, git_contents)
            v.status('Removing temporary directories')
    v.ok()
541
542
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name for *git_revision*.

    The result is ``<repo basename>-<committer timestamp>-<abbrev hash>``,
    with the hash abbreviated to at least 11 hex digits.
    """
    v.status('Getting commit date')
    repo = git_cache.git_cachedir(channel.git_repo)
    log = subprocess.run(
        ['git', '-C', repo, 'log', '-n1', '--format=%ct-%h', '--abbrev=11',
         '--no-show-signature', git_revision],
        stdout=subprocess.PIPE)
    v.result(log.returncode == 0 and log.stdout != b'')
    repo_name = os.path.basename(channel.git_repo)
    return '%s-%s' % (repo_name, log.stdout.decode().strip())
561
562
# Key and value type variables for the generic dict helpers that follow.
K = TypeVar('K')
V = TypeVar('V')
565
566
def partition_dict(pred: Callable[[K, V], bool],
                   d: Dict[K, V]) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (entries where pred(key, value) is true, the rest).

    Both result dicts preserve *d*'s iteration order.
    """
    matching: Dict[K, V] = {}
    others: Dict[K, V] = {}
    for key, value in d.items():
        target = matching if pred(key, value) else others
        target[key] = value
    return matching, others
577
578
def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split *d* into (entries whose key is in *fields*, the rest)."""
    def key_selected(key: K, _value: V) -> bool:
        return key in fields
    return partition_dict(key_selected, d)
582
583
def read_config_section(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build the search path and (if recorded) the pin for one config section.

    The section's ``type`` key selects the classes. Keys naming the pin
    type's fields build the pin; all other keys except ``type`` go to the
    search path. The pin is None when none of its fields are present, unless
    the pin type has no fields at all.
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
        'symlink': (SymlinkSearchPath, SymlinkPin),
    }
    SP, P = mapping[conf['type']]
    settings = {key: value for key, value in conf.items() if key != 'type'}
    pin_keys = set(P._fields)
    pin_fields = {k: val for k, val in settings.items() if k in pin_keys}
    sp_fields = {k: val for k, val in settings.items() if k not in pin_keys}
    has_pin = bool(pin_fields) or not P._fields
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin = P(**pin_fields) if has_pin else None  # type: ignore
    return SP(**sp_fields), pin
599
600
def read_pinned_config_section(
        section: str, conf: configparser.SectionProxy) -> Tuple[SearchPath, Pin]:
    """Like read_config_section, but the section must already be pinned.

    Raises:
        Exception: if no pin is recorded for *section*.
    """
    sp, pin = read_config_section(conf)
    if pin is not None:
        return sp, pin
    raise Exception(
        'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
        section)
609
610
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the INI-style channels file *filename*.

    *filename* is also passed to read_file as the source name. The file is
    closed once parsed rather than left for the garbage collector.
    """
    config = configparser.ConfigParser()
    with open(filename) as f:
        config.read_file(f, filename)
    return config
615
616
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Read several channels files into one mapping of section name to
    section.

    Raises:
        Exception: if the same section name appears in more than one file.
    """
    merged_config: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for name in parsed.sections():
            if name in merged_config:
                raise Exception('Duplicate channel "%s"' % name)
            merged_config[name] = parsed[name]
    return merged_config
627
628
def pinCommand(args: argparse.Namespace) -> None:
    """Handle `pinch pin`: pin the selected channels and save the file.

    When no channel names are given, every section of args.channels_file
    is pinned. The new pin fields are written back into that same file.
    """
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        wanted = not args.channels or section in args.channels
        if not wanted:
            continue
        sp, old_pin = read_config_section(config[section])
        new_pin = sp.pin(v, old_pin)
        config[section].update(new_pin._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
642
643
def updateCommand(args: argparse.Namespace) -> None:
    """Handle `pinch update`: install every pinned channel into a profile.

    Each non-alias channel is fetched and turned into a nix-env
    --from-expression argument. Aliases reuse the expression of the channel
    they refer to, following chains of aliases regardless of file order.
    With --dry-run the nix-env command is printed instead of run.

    Raises:
        Exception: if a channel is unpinned, an alias refers to an unknown
            channel, or aliases form a cycle.
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = {
        section: read_pinned_config_section(section, conf) for section,
        conf in read_config_files(
            args.channels_file).items()}
    # Parameter names chosen so they don't shadow the Verification `v`.
    alias, nonalias = partition_dict(
        lambda _section, sp_and_pin: isinstance(sp_and_pin[0], AliasSearchPath),
        config)

    for section, (sp, pin) in nonalias.items():
        assert not isinstance(sp, AliasSearchPath)  # mypy can't see through
        assert not isinstance(pin, AliasPin)  # partition_dict()
        tarball = sp.fetch(v, pin)
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (pin.release_name, tarball))

    for section, (sp, _) in alias.items():
        assert isinstance(sp, AliasSearchPath)  # For mypy
        # Follow alias -> alias -> ... until reaching a channel whose
        # expression is known, so aliases of aliases work in any order.
        visited = [section]
        target = sp.alias_of
        while target not in exprs:
            if target in visited:
                raise Exception(
                    'Alias cycle: %s' % ' -> '.join(visited + [target]))
            if target not in alias:
                raise Exception(
                    'Channel "%s" is an alias of unknown channel "%s"' %
                    (visited[-1], target))
            visited.append(target)
            target_sp = alias[target][0]
            assert isinstance(target_sp, AliasSearchPath)  # For mypy
            target = target_sp.alias_of
        exprs[section] = exprs[target]

    command = [
        'nix-env',
        '--profile',
        args.profile,
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
681
682
def main() -> None:
    """Command-line entry point: parse arguments and run `pin` or `update`."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    pin_parser = subparsers.add_parser('pin')
    pin_parser.add_argument('channels_file', type=str)
    pin_parser.add_argument('channels', type=str, nargs='*')
    pin_parser.set_defaults(func=pinCommand)

    update_parser = subparsers.add_parser('update')
    update_parser.add_argument('--dry-run', action='store_true')
    default_profile = (
        '/nix/var/nix/profiles/per-user/%s/channels' % getpass.getuser())
    update_parser.add_argument('--profile', default=default_profile)
    update_parser.add_argument('channels_file', type=str, nargs='+')
    update_parser.set_defaults(func=updateCommand)

    parsed = parser.parse_args()
    parsed.func(parsed)
698
699
# Run the command-line interface only when executed as a script.
if __name__ == '__main__':
    main()