]> git.scottworley.com Git - pinch/blob - pinch.py
Remove assert that's now handled by types
[pinch] / pinch.py
1 import argparse
2 import configparser
3 import filecmp
4 import functools
5 import getpass
6 import hashlib
7 import operator
8 import os
9 import os.path
10 import shlex
11 import shutil
12 import subprocess
13 import sys
14 import tempfile
15 import types
16 import urllib.parse
17 import urllib.request
18 import xml.dom.minidom
19
20 from typing import (
21 Dict,
22 Iterable,
23 List,
24 Mapping,
25 NamedTuple,
26 NewType,
27 Optional,
28 Set,
29 Tuple,
30 Type,
31 TypeVar,
32 Union,
33 )
34
35 # Use xdg module when it's less painful to have as a dependency
36
37
class XDG(NamedTuple):
    """The XDG Base Directory values this program uses (just the cache dir)."""
    # Base directory for user-specific cached data.
    XDG_CACHE_HOME: str
41
# Resolve XDG_CACHE_HOME once at import time, defaulting to ~/.cache per the
# XDG Base Directory Specification.
xdg = XDG(
    XDG_CACHE_HOME=os.getenv(
        'XDG_CACHE_HOME',
        os.path.expanduser('~/.cache')))
46
47
class VerificationError(Exception):
    """Raised by Verification.result when a verification step fails."""
    pass
50
51
class Verification:
    """Reports progress to stderr and enforces verification outcomes.

    status() accumulates text on the current line; result() right-aligns a
    colored OK/FAIL marker on that line, then raises VerificationError if
    the check failed.
    """

    def __init__(self) -> None:
        # Running count of characters printed on the current status line.
        self.line_length = 0

    def status(self, s: str) -> None:
        """Print a status fragment without a newline."""
        print(s, end=' ', file=sys.stderr, flush=True)
        self.line_length += 1 + len(s)  # Unicode??

    @staticmethod
    def _color(s: str, c: int) -> str:
        """Wrap s in the ANSI escape for color code c."""
        return '\033[%2dm%s\033[00m' % (c, s)

    def result(self, r: bool) -> None:
        """Finish the status line with OK/FAIL; raise on failure."""
        message, color = ('OK  ', 92) if r else ('FAIL', 91)
        columns = shutil.get_terminal_size().columns or 80
        padding = (columns - (self.line_length + len(message))) % columns
        print(' ' * padding + self._color(message, color), file=sys.stderr)
        self.line_length = 0
        if not r:
            raise VerificationError()

    def check(self, s: str, r: bool) -> None:
        """Print a status message and record its result in one call."""
        self.status(s)
        self.result(r)

    def ok(self) -> None:
        """Mark the current status line as successful."""
        self.result(True)
82
# sha256 digests in hexadecimal (base16) and Nix's base32 encodings.
Digest16 = NewType('Digest16', str)
Digest32 = NewType('Digest32', str)
85
86
class ChannelTableEntry(types.SimpleNamespace):
    """One row of a channel page's downloadable-files table.

    url, digest and size are filled in by parse_channel; absolute_url and
    file are added later by fetch_resources.
    """
    absolute_url: str  # url resolved against the forwarded channel URL
    digest: Digest16   # expected sha256 of the resource
    file: str          # local store path once fetched
    size: int          # size in bytes as advertised on the channel page
    url: str           # possibly-relative link from the channel page
93
94
class AliasPin(NamedTuple):
    """An alias channel needs no pinned state of its own."""
    pass
97
98
class GitPin(NamedTuple):
    """Pinned state for a git search path."""
    git_revision: str  # full commit hash
    release_name: str  # name derived from repo, commit date and hash
102
103
class ChannelPin(NamedTuple):
    """Pinned state for a channel search path."""
    git_revision: str    # commit the channel page advertised
    release_name: str    # release name from the channel page
    tarball_url: str     # absolute URL of nixexprs.tar.xz
    tarball_sha256: str  # expected sha256 of that tarball
109
110
111 Pin = Union[AliasPin, GitPin, ChannelPin]
112
113
class AliasSearchPath(NamedTuple):
    """A search path that simply refers to another channel by name."""
    alias_of: str

    # pylint: disable=no-self-use
    def pin(self, _: Verification, __: Optional[Pin]) -> AliasPin:
        """Aliases carry no state of their own, so pinning is a no-op."""
        return AliasPin()
120
121
class GitSearchPath(NamedTuple):
    """A search path fetched directly from a git repository."""
    git_ref: str   # ref to track, e.g. a branch name
    git_repo: str  # repository URL

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> GitPin:
        """Fetch the ref and pin whatever revision it currently points at.

        The previous pin, if any, is passed along so ancestry (no forced
        update) can be verified.
        """
        if old_pin is not None:
            assert isinstance(old_pin, GitPin)
        old_revision = old_pin.git_revision if old_pin is not None else None

        new_revision = git_fetch(v, self, None, old_revision)
        return GitPin(release_name=git_revision_name(v, self, new_revision),
                      git_revision=new_revision)

    def fetch(self, v: Verification, pin: Pin) -> str:
        """Return a nix-store path containing the pinned revision's tarball."""
        assert isinstance(pin, GitPin)
        ensure_git_rev_available(v, self, pin, None)
        return git_get_tarball(v, self, pin)
139
140
class ChannelSearchPath(NamedTuple):
    """A search path backed by a NixOS-style channel URL.

    git_repo/git_ref are used to cross-check the channel contents against
    upstream git history.
    """
    channel_url: str
    git_ref: str
    git_repo: str

    def pin(self, v: Verification, old_pin: Optional[Pin]) -> ChannelPin:
        """Fetch the channel page and pin its advertised release/revision.

        Also verifies the channel tarball against the corresponding git
        revision before committing to the pin.
        """
        if old_pin is not None:
            assert isinstance(old_pin, ChannelPin)
        old_revision = old_pin.git_revision if old_pin is not None else None

        channel_html, forwarded_url = fetch_channel(v, self)
        table, new_gitpin = parse_channel(v, channel_html)
        fetch_resources(v, new_gitpin, forwarded_url, table)
        ensure_git_rev_available(v, self, new_gitpin, old_revision)
        check_channel_contents(v, self, table, new_gitpin)
        return ChannelPin(
            release_name=new_gitpin.release_name,
            tarball_url=table['nixexprs.tar.xz'].absolute_url,
            tarball_sha256=table['nixexprs.tar.xz'].digest,
            git_revision=new_gitpin.git_revision)

    # pylint: disable=no-self-use
    def fetch(self, v: Verification, pin: Pin) -> str:
        """Fetch the pinned tarball into the nix store; return its path."""
        assert isinstance(pin, ChannelPin)

        return fetch_with_nix_prefetch_url(
            v, pin.tarball_url, Digest16(pin.tarball_sha256))
168
169
# All search-path variants, and the subset that can produce a git tarball.
SearchPath = Union[AliasSearchPath, GitSearchPath, ChannelSearchPath]
TarrableSearchPath = Union[GitSearchPath, ChannelSearchPath]
172
173
def compare(a: str, b: str) -> Tuple[List[str], List[str], List[str]]:
    """Recursively compare the contents of directories a and b.

    Returns (matching, differing, incomparable) lists of relative file
    names, as filecmp.cmpfiles does, computed over the union of files in
    both trees.  Symlinks to directories are compared as files, anything
    under .git/ is ignored, and os.walk errors are raised rather than
    swallowed.
    """

    def raise_error(error: OSError) -> None:
        raise error

    def relative_join(parent: str, child: str) -> str:
        # Avoid a leading "./" on files at the tree root.
        return child if parent == '.' else os.path.join(parent, child)

    def walk_files(root: str) -> Set[str]:
        found: Set[str] = set()
        for path, dirs, files in os.walk(root, onerror=raise_error):
            rel = os.path.relpath(path, start=root)
            for f in files:
                found.add(relative_join(rel, f))
            # os.walk does not descend into symlinked directories, so
            # record them here to be compared as files.
            for d in dirs:
                if os.path.islink(os.path.join(path, d)):
                    found.add(relative_join(rel, d))
        return found

    files = {f for f in walk_files(a) | walk_files(b)
             if not f.startswith('.git/')}
    return filecmp.cmpfiles(a, b, files, shallow=False)
200
201
def fetch_channel(
        v: Verification, channel: ChannelSearchPath) -> Tuple[str, str]:
    """Download the channel's HTML page; return (channel_html, forwarded_url).

    NOTE(review): request.read() returns bytes despite the declared str;
    xml.dom.minidom accepts both — confirm before tightening the types.
    """
    v.status('Fetching channel')
    request = urllib.request.urlopen(channel.channel_url, timeout=10)
    channel_html = request.read()
    forwarded_url = request.geturl()
    v.result(request.status == 200)  # type: ignore # (for old mypy)
    # A channel URL is expected to redirect to a concrete release page; no
    # redirect means the channel is in a surprising state.
    v.check('Got forwarded', channel.channel_url != forwarded_url)
    return channel_html, forwarded_url
211
212
def parse_channel(v: Verification, channel_html: str) \
        -> Tuple[Dict[str, ChannelTableEntry], GitPin]:
    """Extract the file table and a GitPin from a channel's HTML page.

    Relies on the exact page structure channels.nixos.org emits: the
    release name in both <title> and <h1>, the git commit in the first
    <tt>, and one table row per downloadable resource.
    """
    v.status('Parsing channel description as XML')
    d = xml.dom.minidom.parseString(channel_html)
    v.ok()

    # The release name appears in both <title> and <h1>; extract both and
    # insist they agree, to catch page-structure changes.
    v.status('Extracting release name:')
    title_name = d.getElementsByTagName(
        'title')[0].firstChild.nodeValue.split()[2]
    h1_name = d.getElementsByTagName('h1')[0].firstChild.nodeValue.split()[2]
    v.status(title_name)
    v.result(title_name == h1_name)

    v.status('Extracting git commit:')
    git_commit_node = d.getElementsByTagName('tt')[0]
    git_revision = git_commit_node.firstChild.nodeValue
    v.status(git_revision)
    v.ok()
    # Make sure the <tt> we grabbed is actually the git-commit field.
    v.status('Verifying git commit label')
    v.result(git_commit_node.previousSibling.nodeValue == 'Git commit ')

    # Each data row is (link, size, digest); [1:] skips the header row.
    v.status('Parsing table')
    table: Dict[str, ChannelTableEntry] = {}
    for row in d.getElementsByTagName('tr')[1:]:
        name = row.childNodes[0].firstChild.firstChild.nodeValue
        url = row.childNodes[0].firstChild.getAttribute('href')
        size = int(row.childNodes[1].firstChild.nodeValue)
        digest = Digest16(row.childNodes[2].firstChild.firstChild.nodeValue)
        table[name] = ChannelTableEntry(url=url, digest=digest, size=size)
    v.ok()
    return table, GitPin(release_name=title_name, git_revision=git_revision)
244
245
def digest_string(s: bytes) -> Digest16:
    """Return the sha256 of s as a hex (base16) digest."""
    return Digest16(hashlib.sha256(s).hexdigest())
248
249
def digest_file(filename: str) -> Digest16:
    """Return the sha256 of the named file's contents as a hex digest."""
    hasher = hashlib.sha256()
    # Hash in 4 KiB chunks so arbitrarily large files use constant memory.
    with open(filename, 'rb') as f:
        chunk = f.read(4096)
        while chunk:
            hasher.update(chunk)
            chunk = f.read(4096)
    return Digest16(hasher.hexdigest())
257
258
def to_Digest16(v: Verification, digest32: Digest32) -> Digest16:
    """Convert a base32 sha256 digest to base16 via `nix to-base16`."""
    v.status('Converting digest to base16')
    command = ['nix', 'to-base16', '--type', 'sha256', digest32]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest16(process.stdout.decode().strip())
265
266
def to_Digest32(v: Verification, digest16: Digest16) -> Digest32:
    """Convert a base16 sha256 digest to base32 via `nix to-base32`."""
    v.status('Converting digest to base32')
    command = ['nix', 'to-base32', '--type', 'sha256', digest16]
    process = subprocess.run(command, stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    return Digest32(process.stdout.decode().strip())
273
274
def fetch_with_nix_prefetch_url(
        v: Verification,
        url: str,
        digest: Digest16) -> str:
    """Fetch url into the nix store and return the resulting store path.

    The content is verified against digest three ways: nix-prefetch-url's
    own check, its reported digest, and an independent re-hash of the
    stored file.
    """
    v.status('Fetching %s' % url)
    process = subprocess.run(
        ['nix-prefetch-url', '--print-path', url, digest], stdout=subprocess.PIPE)
    v.result(process.returncode == 0)
    # Output is exactly two newline-terminated lines: digest, then path.
    prefetch_digest, path, empty = process.stdout.decode().split('\n')
    assert empty == ''
    v.check("Verifying nix-prefetch-url's digest",
            to_Digest16(v, Digest32(prefetch_digest)) == digest)
    v.status("Verifying file digest")
    file_digest = digest_file(path)
    v.result(file_digest == digest)
    return path  # type: ignore # (for old mypy)
291
292
def fetch_resources(
        v: Verification,
        pin: GitPin,
        forwarded_url: str,
        table: Dict[str, ChannelTableEntry]) -> None:
    """Fetch the channel's resources into the nix store.

    Fills in each table entry's absolute_url and file attributes, then
    checks that the fetched git-revision file agrees with the pin.
    """
    for resource in ['git-revision', 'nixexprs.tar.xz']:
        fields = table[resource]
        # Links on the channel page may be relative to the forwarded URL.
        fields.absolute_url = urllib.parse.urljoin(forwarded_url, fields.url)
        fields.file = fetch_with_nix_prefetch_url(
            v, fields.absolute_url, fields.digest)
    v.status('Verifying git commit on main page matches git commit in table')
    # Use a context manager so the file handle is closed promptly rather
    # than leaking until garbage collection.
    with open(table['git-revision'].file) as rev_file:
        v.result(rev_file.read(999) == pin.git_revision)
305
306
def git_cachedir(git_repo: str) -> str:
    """Return the local bare-clone cache directory for git_repo."""
    # Key the cache directory on a digest of the repo URL so any URL is a
    # safe directory name.
    repo_digest = digest_string(git_repo.encode())
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git', repo_digest)
312
313
def tarball_cache_file(channel: TarrableSearchPath, pin: GitPin) -> str:
    """Return the cache-file path that records this tarball's store path."""
    basename = '%s-%s-%s' % (
        digest_string(channel.git_repo.encode()),
        pin.git_revision,
        pin.release_name)
    return os.path.join(xdg.XDG_CACHE_HOME, 'pinch/git-tarball', basename)
322
323
def verify_git_ancestry(
        v: Verification,
        channel: TarrableSearchPath,
        new_revision: str,
        old_revision: Optional[str]) -> None:
    """Check new_revision is on the channel's ref and descends from old.

    Both checks use `git merge-base --is-ancestor` against the local cache
    clone; the second check guards against history rewrites between pins.
    """
    cachedir = git_cachedir(channel.git_repo)
    v.status('Verifying rev is an ancestor of ref')
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'merge-base',
                              '--is-ancestor',
                              new_revision,
                              channel.git_ref])
    v.result(process.returncode == 0)

    if old_revision is not None:
        v.status(
            'Verifying rev is an ancestor of previous rev %s' %
            old_revision)
        process = subprocess.run(['git',
                                  '-C',
                                  cachedir,
                                  'merge-base',
                                  '--is-ancestor',
                                  old_revision,
                                  new_revision])
        v.result(process.returncode == 0)
352
353
def git_fetch(
        v: Verification,
        channel: TarrableSearchPath,
        desired_revision: Optional[str],
        old_revision: Optional[str]) -> str:
    """Fetch the channel's git ref into the local cache clone.

    Returns the revision now at the tip of the ref, after verifying that
    desired_revision (if given) was retrieved and that ancestry against
    old_revision holds.
    """
    # It would be nice if we could share the nix git cache, but as of the time
    # of writing it is transitioning from gitv2 (deprecated) to gitv3 (not ready
    # yet), and trying to straddle them both is too far into nix implementation
    # details for my comfort. So we re-implement here half of nix.fetchGit.
    # :(

    cachedir = git_cachedir(channel.git_repo)
    if not os.path.exists(cachedir):
        v.status("Initializing git repo")
        process = subprocess.run(
            ['git', 'init', '--bare', cachedir])
        v.result(process.returncode == 0)

    v.status('Fetching ref "%s" from %s' % (channel.git_ref, channel.git_repo))
    # We don't use --force here because we want to abort and freak out if forced
    # updates are happening.
    process = subprocess.run(['git',
                              '-C',
                              cachedir,
                              'fetch',
                              channel.git_repo,
                              '%s:%s' % (channel.git_ref,
                                         channel.git_ref)])
    v.result(process.returncode == 0)

    if desired_revision is not None:
        v.status('Verifying that fetch retrieved this rev')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', desired_revision])
        v.result(process.returncode == 0)

    # Read the updated ref directly from the bare repo; use a context
    # manager so the file handle is closed promptly.
    with open(os.path.join(cachedir,
                           'refs',
                           'heads',
                           channel.git_ref)) as ref_file:
        new_revision = ref_file.read(999).strip()

    verify_git_ancestry(v, channel, new_revision, old_revision)

    return new_revision
400
401
def ensure_git_rev_available(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        old_revision: Optional[str]) -> None:
    """Make sure pin.git_revision is present in the cache clone.

    If it is already there only ancestry is re-verified; otherwise the
    channel's ref is fetched.  `git cat-file -e` exits 0 when the object
    exists, 1 when it does not; anything else is an error.
    """
    cachedir = git_cachedir(channel.git_repo)
    if os.path.exists(cachedir):
        v.status('Checking if we already have this rev:')
        process = subprocess.run(
            ['git', '-C', cachedir, 'cat-file', '-e', pin.git_revision])
        if process.returncode == 0:
            v.status('yes')
        if process.returncode == 1:
            v.status('no')
        v.result(process.returncode == 0 or process.returncode == 1)
        if process.returncode == 0:
            verify_git_ancestry(v, channel, pin.git_revision, old_revision)
            return
    git_fetch(v, channel, pin.git_revision, old_revision)
421
422
def compare_tarball_and_git(
        v: Verification,
        pin: GitPin,
        channel_contents: str,
        git_contents: str) -> None:
    """Diff the extracted channel tarball against the git checkout.

    A few files are generated during channel assembly and are expected to
    be present on only one side; any other difference is a failure.
    """
    v.status('Comparing channel tarball with git checkout')
    match, mismatch, errors = compare(os.path.join(
        channel_contents, pin.release_name), git_contents)
    v.ok()
    v.check('%d files match' % len(match), len(match) > 0)
    v.check('%d files differ' % len(mismatch), len(mismatch) == 0)
    # "errors" are files that could not be compared (typically present in
    # only one tree); these are the ones channel assembly is known to add.
    expected_errors = [
        '.git-revision',
        '.version-suffix',
        'nixpkgs',
        'programs.sqlite',
        'svn-revision']
    benign_errors = []
    for ee in expected_errors:
        if ee in errors:
            errors.remove(ee)
            benign_errors.append(ee)
    v.check(
        '%d unexpected incomparable files' %
        len(errors),
        len(errors) == 0)
    v.check(
        '(%d of %d expected incomparable files)' %
        (len(benign_errors),
         len(expected_errors)),
        len(benign_errors) == len(expected_errors))
454
455
def extract_tarball(
        v: Verification,
        table: Dict[str, ChannelTableEntry],
        dest: str) -> None:
    """Unpack the channel's nixexprs tarball into dest."""
    tarball = table['nixexprs.tar.xz'].file
    v.status('Extracting tarball %s' % tarball)
    shutil.unpack_archive(tarball, dest)
    v.ok()
463
464
def git_checkout(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin,
        dest: str) -> None:
    """Extract the pinned revision's tree into dest via `git archive | tar`."""
    v.status('Checking out corresponding git revision')
    git = subprocess.Popen(['git',
                            '-C',
                            git_cachedir(channel.git_repo),
                            'archive',
                            pin.git_revision],
                           stdout=subprocess.PIPE)
    tar = subprocess.Popen(
        ['tar', 'x', '-C', dest, '-f', '-'], stdin=git.stdout)
    # Drop our handle on the pipe so git sees a broken pipe if tar exits
    # early, instead of blocking forever.
    if git.stdout:
        git.stdout.close()
    tar.wait()
    git.wait()
    v.result(git.returncode == 0 and tar.returncode == 0)
484
485
def git_get_tarball(
        v: Verification,
        channel: TarrableSearchPath,
        pin: GitPin) -> str:
    """Produce a nix-store tarball for the pinned revision; return its path.

    The resulting store path is memoized in a small cache file so repeated
    runs can skip the expensive `git archive | xz` pipeline.
    """
    cache_file = tarball_cache_file(channel, pin)
    if os.path.exists(cache_file):
        with open(cache_file) as f:
            cached_tarball = f.read(9999)
        # The store path may have been garbage-collected since it was
        # cached; only trust the cache entry if it still exists.
        if os.path.exists(cached_tarball):
            return cached_tarball

    with tempfile.TemporaryDirectory() as output_dir:
        output_filename = os.path.join(
            output_dir, pin.release_name + '.tar.xz')
        with open(output_filename, 'w') as output_file:
            v.status(
                'Generating tarball for git revision %s' %
                pin.git_revision)
            git = subprocess.Popen(['git',
                                    '-C',
                                    git_cachedir(channel.git_repo),
                                    'archive',
                                    '--prefix=%s/' % pin.release_name,
                                    pin.git_revision],
                                   stdout=subprocess.PIPE)
            xz = subprocess.Popen(['xz'], stdin=git.stdout, stdout=output_file)
            # Close our handle on the pipe (as git_checkout does) so git
            # sees a broken pipe if xz exits early.
            if git.stdout:
                git.stdout.close()
            xz.wait()
            git.wait()
            v.result(git.returncode == 0 and xz.returncode == 0)

        v.status('Putting tarball in Nix store')
        process = subprocess.run(
            ['nix-store', '--add', output_filename], stdout=subprocess.PIPE)
        v.result(process.returncode == 0)
        store_tarball = process.stdout.decode().strip()

        os.makedirs(os.path.dirname(cache_file), exist_ok=True)
        with open(cache_file, 'w') as f:
            f.write(store_tarball)
        return store_tarball  # type: ignore # (for old mypy)
524
525
def check_channel_metadata(
        v: Verification,
        pin: GitPin,
        channel_contents: str) -> None:
    """Verify the extracted tarball's metadata files agree with the pin.

    Checks .git-revision matches the pinned commit and .version-suffix is
    a suffix of the release name.
    """
    release_dir = os.path.join(channel_contents, pin.release_name)

    v.status('Verifying git commit in channel tarball')
    # Use context managers so the file handles are closed promptly rather
    # than leaking until garbage collection.
    with open(os.path.join(release_dir, '.git-revision')) as rev_file:
        v.result(rev_file.read(999) == pin.git_revision)

    v.status(
        'Verifying version-suffix is a suffix of release name %s:' %
        pin.release_name)
    with open(os.path.join(release_dir, '.version-suffix')) as suffix_file:
        version_suffix = suffix_file.read(999)
    v.status(version_suffix)
    v.result(pin.release_name.endswith(version_suffix))
548
549
def check_channel_contents(
        v: Verification,
        channel: TarrableSearchPath,
        table: Dict[str, ChannelTableEntry],
        pin: GitPin) -> None:
    """Verify the channel tarball matches the corresponding git revision.

    Extracts the tarball and a git checkout into temporary directories,
    checks the tarball's metadata files, and diffs the two trees.
    """
    with tempfile.TemporaryDirectory() as channel_contents, \
            tempfile.TemporaryDirectory() as git_contents:

        extract_tarball(v, table, channel_contents)
        check_channel_metadata(v, pin, channel_contents)

        git_checkout(v, channel, pin, git_contents)

        compare_tarball_and_git(v, pin, channel_contents, git_contents)

        v.status('Removing temporary directories')
        v.ok()
567
568
def git_revision_name(
        v: Verification,
        channel: TarrableSearchPath,
        git_revision: str) -> str:
    """Build a release name "<repo-basename>-<commit-time>-<abbrev-hash>".

    The commit timestamp (%ct) and 11-character abbreviated hash (%h) come
    from `git log` against the local cache clone.
    """
    v.status('Getting commit date')
    process = subprocess.run(['git',
                              '-C',
                              git_cachedir(channel.git_repo),
                              'log',
                              '-n1',
                              '--format=%ct-%h',
                              '--abbrev=11',
                              '--no-show-signature',
                              git_revision],
                             stdout=subprocess.PIPE)
    # An unknown revision exits 0 with empty output, so check both.
    v.result(process.returncode == 0 and process.stdout != b'')
    return '%s-%s' % (os.path.basename(channel.git_repo),
                      process.stdout.decode().strip())
587
588
K = TypeVar('K')
V = TypeVar('V')


def filter_dict(d: Dict[K, V], fields: Set[K]
                ) -> Tuple[Dict[K, V], Dict[K, V]]:
    """Split d into (entries whose key is in fields, all other entries)."""
    selected = {k: v for k, v in d.items() if k in fields}
    remaining = {k: v for k, v in d.items() if k not in fields}
    return selected, remaining
603
604
def read_search_path(
        conf: configparser.SectionProxy) -> Tuple[SearchPath, Optional[Pin]]:
    """Build a SearchPath (and its Pin, if the section is pinned).

    The section's 'type' key selects the classes; the remaining keys are
    split between pin fields and search-path fields by field name.  A pin
    is returned only when pin fields are present (or the pin type has no
    fields at all, as with aliases).
    """
    mapping: Mapping[str, Tuple[Type[SearchPath], Type[Pin]]] = {
        'alias': (AliasSearchPath, AliasPin),
        'channel': (ChannelSearchPath, ChannelPin),
        'git': (GitSearchPath, GitPin),
    }
    SP, P = mapping[conf['type']]
    _, all_fields = filter_dict(dict(conf.items()), set(['type']))
    pin_fields, remaining_fields = filter_dict(all_fields, set(P._fields))
    # Error suppression works around https://github.com/python/mypy/issues/9007
    pin_present = pin_fields != {} or P._fields == ()
    pin = P(**pin_fields) if pin_present else None  # type:ignore[call-arg]
    return SP(**remaining_fields), pin
619
620
def read_config(filename: str) -> configparser.ConfigParser:
    """Parse the named channels file into a ConfigParser.

    Raises OSError if the file cannot be opened and configparser.Error on
    malformed content.
    """
    config = configparser.ConfigParser()
    # Use a context manager so the file handle is closed promptly rather
    # than leaking until garbage collection.
    with open(filename) as config_file:
        config.read_file(config_file, filename)
    return config
625
626
def read_config_files(
        filenames: Iterable[str]) -> Dict[str, configparser.SectionProxy]:
    """Merge the sections of several config files, rejecting duplicates."""
    merged: Dict[str, configparser.SectionProxy] = {}
    for filename in filenames:
        parsed = read_config(filename)
        for section in parsed.sections():
            if section in merged:
                raise Exception('Duplicate channel "%s"' % section)
            merged[section] = parsed[section]
    return merged
637
638
def pinCommand(args: argparse.Namespace) -> None:
    """`pinch pin`: re-pin (selected) channels and rewrite the channels file."""
    v = Verification()
    config = read_config(args.channels_file)
    for section in config.sections():
        # An empty channel list on the command line means "pin everything".
        if args.channels and section not in args.channels:
            continue

        sp, old_pin = read_search_path(config[section])

        # Store the new pin's fields back into the section in place.
        config[section].update(sp.pin(v, old_pin)._asdict())

    with open(args.channels_file, 'w') as configfile:
        config.write(configfile)
652
653
def updateCommand(args: argparse.Namespace) -> None:
    """`pinch update`: install all pinned channels into the user's profile.

    Builds one nix expression per channel and hands them all to a single
    nix-env invocation (or prints it with --dry-run).
    """
    v = Verification()
    exprs: Dict[str, str] = {}
    config = read_config_files(args.channels_file)
    for section in config:
        sp, pin = read_search_path(config[section])
        if pin is None:
            raise Exception(
                'Cannot update unpinned channel "%s" (Run "pin" before "update")' %
                section)
        if isinstance(sp, AliasSearchPath):
            continue
        tarball = sp.fetch(v, pin)
        # channelName is left as a %%s placeholder and substituted with the
        # section name below, so aliases can reuse the same expression.
        exprs[section] = (
            'f: f { name = "%s"; channelName = "%%s"; src = builtins.storePath "%s"; }' %
            (config[section]['release_name'], tarball))

    # Aliases reuse the expression of the channel they point at.
    for section in config:
        if 'alias_of' in config[section]:
            exprs[section] = exprs[str(config[section]['alias_of'])]

    command = [
        'nix-env',
        '--profile',
        '/nix/var/nix/profiles/per-user/%s/channels' %
        getpass.getuser(),
        '--show-trace',
        '--file',
        '<nix/unpack-channel.nix>',
        '--install',
        '--from-expression'] + [exprs[name] % name for name in sorted(exprs.keys())]
    if args.dry_run:
        print(' '.join(map(shlex.quote, command)))
    else:
        v.status('Installing channels with nix-env')
        process = subprocess.run(command)
        v.result(process.returncode == 0)
691
692
def main() -> None:
    """Command-line entry point: dispatch to the pin or update subcommand."""
    parser = argparse.ArgumentParser(prog='pinch')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    parser_pin = subparsers.add_parser('pin')
    parser_pin.add_argument('channels_file', type=str)
    parser_pin.add_argument('channels', type=str, nargs='*')
    parser_pin.set_defaults(func=pinCommand)

    parser_update = subparsers.add_parser('update')
    parser_update.add_argument('--dry-run', action='store_true')
    parser_update.add_argument('channels_file', type=str, nargs='+')
    parser_update.set_defaults(func=updateCommand)

    arguments = parser.parse_args()
    arguments.func(arguments)


if __name__ == '__main__':
    main()